# Scenario 3. Multicategory classification

In [1]:
R.version
library(tidyverse)
library(caret)
library(randomForest)
library(reticulate)
library(mltools)
library(mltest)
library(pROC)
np <- import("numpy")

               _                           
platform       x86_64-pc-linux-gnu         
arch           x86_64                      
os             linux-gnu                   
system         x86_64, linux-gnu           
status                                     
major          3                           
minor          5.3                         
year           2019                        
month          03                          
day            11                          
svn rev        76217                       
language       R                           
version.string R version 3.5.3 (2019-03-11)
nickname       Great Truth                 

── [1mAttaching packages[22m ─────────────────────────────────────── tidyverse 1.2.1 ──

[32m✔[39m [34mggplot2[39m 3.2.1     [32m✔[39m [34mpurrr  [39m 0.3.2
[32m✔[39m [34mtibble [39m 2.1.3     [32m✔[39m [34mdplyr  [39m 0.8.3
[32m✔[39m [34mtidyr  [39m 1.0.2     [32m✔[39m [34mstringr[39m 1.4.0
[32m✔[39m [34mreadr  [39m 1.3.1     [32m✔[39m [34mforcats[39m 0.4.0

── [1mConflicts[22m ────────────────────────────────────────── tidyverse_conflicts() ──
[31m✖[39m [34mdplyr[39m::[32mfilter()[39m masks [34mstats[39m::filter()
[31m✖[39m [34mdplyr[39m::[32mlag()[39m    masks [34mstats[39m::lag()

Loading required package: lattice


Attaching package: ‘caret’


The following object is masked from ‘package:purrr’:

    lift


randomForest 4.6-14

Type rfNews() to see new features/changes/bug fixes.


Attaching package: ‘randomForest’


The following object is masked from ‘package:dplyr’:

    combine


The following object is masked from ‘package:g

### Configuration

In [2]:
data_path = "./data/simulation/s3"
path_genus = "./data/genus48"
count_path = 'data/simulation/count/'

y_path = sprintf('%s/%s', data_path, 'y.csv')
tree_info_path = './data/genus48/genus48_dic.csv'
count_path = './data/simulation/count'
count_list_path = './data/simulation/gcount_list.csv'
idx_path = './data/simulation/s1/idx.csv'

num_classes = 0 # regression
tree_level_list = c('Genus', 'Family', 'Order', 'Class', 'Phylum')

In [3]:
# # Read phylogenetic tree information

# phylogenetic_tree_info = read.csv(tree_info_path)
# phylogenetic_tree_info = phylogenetic_tree_info %>% select(tree_level_list)

# print(sprintf('Phylogenetic tree level list: %s', str_c(phylogenetic_tree_info %>% colnames, collapse = ', ')))

### Read dataset

#### Read training, test dataset

In [4]:
read_dataset <- function(x_path, y_path, sim){
    print(str_c('Load data for repetition ', sim))
    x = read.csv(x_path)
    y = read.csv(y_path)[,sim]
    x = (x - max(x)) / (max(x) - min(x))

    idxs = idxs_total[, sim]
    remain_idxs = setdiff(seq(1, dim(x)[1]), idxs)

    x_train = x[idxs,]
    x_test = x[remain_idxs,]
    y_train = y[idxs]
    y_test = y[remain_idxs]
    
    return (list(x_train, x_test, y_train, y_test))
}

In [5]:
idxs_total = read.csv(idx_path)
number_of_fold = dim(idxs_total)[2]; number_of_fold
x_list = read.csv(count_list_path, header = FALSE)
x_path = x_list$V1 %>% sprintf('%s/%s', count_path, .)

#### Read true tree weight

In [6]:
tw_1 = np$load(sprintf('%s/tw_1.npy', data_path))

## Random Forest

### Importance type

* See <https://stats.stackexchange.com/questions/92419/relative-importance-of-a-set-of-predictors-in-a-random-forests-classification-in>

Here are the definitions of the variable importance measures.

- `type=1`: **Mean decrease in accuracy**
    - The first measure is computed from permuting Out-of-bag (OOB) data.
    - For each tree, the prediction error on the out-of-bag portion of the data is recorded (error rate for classification, MSE for regression). Then the same is done after permuting each predictor variable. 
    - The difference between the two are then averaged over all trees, and normalized by the standard deviation of the differences.
    - If the standard deviation of the differences is equal to 0 for a variable, the division is not done (but the average is almost always equal to 0 in that case).

- `type=2`: **Mean decrease in node impurity**
    - The second measure is the total decrease in node impurities from splitting on the variable, averaged over all trees. 
    - For classification, the node impurity is measured by the Gini index. 
    - For regression, it is measured by residual sum of squares.
    
### Feature selection

* `vi_f`: variable importance by Gini importance (`type=2`)
* `relative_vi_f` : relative variable importance
* `thrd`: threshold for relative variable importance
* Select features which have relative variable importance `relative_vi_f` equal or larger than threshold `thrd`

### Simulate for all $n$

In [7]:
random_forest_res <- function(fold, importance_type=2, fs_thrd = 0.1){
    print(sprintf('-----------------------------------------------------------------'))
    print(sprintf('Random Forest computation for %dth repetition', fold))

    dataset = read_dataset(x_path[fold], y_path, fold)
    x_train = dataset[[1]]
    x_test = dataset[[2]]
    y_train = dataset[[3]]
    y_test = dataset[[4]]
    
    # Multicategory classification
    y_train = factor(y_train, levels = c(0,1,2), ordered=TRUE)
    y_test = factor(y_test, levels = c(0,1,2), ordered=TRUE)

    fit.rf <- randomForest(y_train~.,data=x_train, ntree=1000,  mtry=10, importance=TRUE)
    train.pred <- fit.rf$predicted
    test.pred <- predict(fit.rf,x_test)

    ml_res <- ml_test(train.pred, y_train)
    ml_roc <- multiclass.roc(y_train, factor(train.pred, levels=c(0,1,2), ordered=TRUE))
    train_sensitivity <- ml_res$recall
    train_specificity <- ml_res$specificitye
    train_gmeasure <- sqrt(train_sensitivity*train_specificity)
    train_accuracy <- ml_res$accuracy
    train_auc <- ml_roc$auc

    ml_res <- ml_test(test.pred, y_test)
    ml_roc <- multiclass.roc(y_test, factor(test.pred, levels=c(0,1,2), ordered=TRUE))
    test_sensitivity <- ml_res$recall
    test_specificity <- ml_res$specificitye
    test_gmeasure <- sqrt(test_sensitivity*test_specificity)
    test_accuracy <- ml_res$accuracy
    test_auc <- ml_roc$auc
    
    # Feature selection
    ## variable importance
    vi_f = importance(fit.rf, type=importance_type)
    relative_vi_f <- vi_f / sum(vi_f)
    selected_genus <- ifelse(relative_vi_f >= fs_thrd, 1, 0)
    
    order <- order(relative_vi_f, decreasing = TRUE)
    sorted_relative_vi_f <- relative_vi_f[order]
    names(sorted_relative_vi_f) <- colnames(x_train)[order]
    print(sorted_relative_vi_f)

    fold_genus = apply(tw_1[fold,,], 1, sum)
    names(fold_genus) <- x_train %>% colnames

    fs_conf_table <- table(selected_genus, fold_genus)
    
    fs_sensitivity <- sensitivity(fs_conf_table) 
    fs_specificity <- specificity(fs_conf_table)
    fs_gmeasure <- sqrt(fs_sensitivity*fs_specificity)
    fs_accuracy <- sum(diag(fs_conf_table))/sum(fs_conf_table)

    print(sprintf('Train sensitivity: %s, Train sensitivity: %s, Train gmeasure: %s, Train accuracy: %s, Train AUC: %s',
                  train_sensitivity, train_specificity, train_gmeasure, train_accuracy, train_auc))
    print(sprintf('Test sensitivity: %s, Test sensitivity: %s, Test gmeasure: %s, Test accuracy: %s, Test AUC: %s',
                  test_sensitivity, test_specificity, test_gmeasure, test_accuracy, test_auc))
    print(sprintf('FS sensitivity: %s, FS sensitivity: %s, FS gmeasure: %s, FS accuracy: %s',
                  fs_sensitivity, fs_specificity, fs_gmeasure, fs_accuracy))
    
    return (c(train_sensitivity, train_specificity, train_gmeasure, train_accuracy, train_auc, 
              test_sensitivity, test_specificity, test_gmeasure, test_accuracy, test_auc,
              fs_sensitivity, fs_specificity, fs_gmeasure, fs_accuracy))
}

In [8]:
set.seed(100)
# res <- sapply(seq(1,1), random_forest_res)
res <- sapply(seq(1,10), random_forest_res)
# res <- sapply(seq(1,number_of_fold), random_forest_res)

[1] "-----------------------------------------------------------------"
[1] "Random Forest computation for 1th repetition"
[1] "Load data for repetition 1"


Setting direction: controls < cases

Setting direction: controls < cases

Setting direction: controls < cases

Setting direction: controls < cases

Setting direction: controls < cases

Setting direction: controls < cases



                       Rothia                  Oribacterium 
                  0.312434606                   0.072892816 
            Propionibacterium                    Tropheryma 
                  0.058668299                   0.049747378 
                     Moryella                Pseudonocardia 
                  0.030999674                   0.028320096 
                  Actinomyces                    Filifactor 
                  0.028075688                   0.026275622 
                Porphyromonas     TM7_genera_incertae_sedis 
                  0.023310208                   0.015565563 
                   Prevotella                     Treponema 
                  0.015562889                   0.015345723 
                Fusobacterium                 Streptococcus 
                  0.014754285                   0.014420247 
                  Selenomonas                   Veillonella 
                  0.014207864                   0.014157096 
                      Ge

Setting direction: controls < cases

Setting direction: controls < cases

Setting direction: controls < cases

Setting direction: controls < cases

Setting direction: controls < cases

Setting direction: controls < cases



                       Rothia                  Oribacterium 
                  0.303797013                   0.088741798 
            Propionibacterium                    Tropheryma 
                  0.055180461                   0.049187035 
                Porphyromonas                      Moryella 
                  0.035837560                   0.028448292 
                   Filifactor                   Actinomyces 
                  0.025079776                   0.023650570 
                  Veillonella                    Prevotella 
                  0.021401594                   0.017207955 
                  Selenomonas     OD1_genera_incertae_sedis 
                  0.016468168                   0.016184653 
                Lactobacillus                     Treponema 
                  0.015900957                   0.015424201 
    TM7_genera_incertae_sedis                 Streptococcus 
                  0.015041359                   0.014703300 
               Granulica

Setting direction: controls < cases

Setting direction: controls < cases

Setting direction: controls < cases

Setting direction: controls < cases

Setting direction: controls < cases

Setting direction: controls < cases



                       Rothia                  Oribacterium 
                  0.292229936                   0.082482077 
                   Tropheryma             Propionibacterium 
                  0.063747788                   0.054275324 
                  Actinomyces                Pseudonocardia 
                  0.032111245                   0.024443231 
                   Filifactor                 Porphyromonas 
                  0.024029716                   0.022809393 
                     Moryella                   Selenomonas 
                  0.022261514                   0.018512101 
                   Prevotella     TM7_genera_incertae_sedis 
                  0.016758270                   0.016662169 
                Campylobacter                   Veillonella 
                  0.016625984                   0.016493556 
                    Treponema                     Catonella 
                  0.015561658                   0.015390505 
                    Atop

Setting direction: controls < cases

Setting direction: controls < cases

Setting direction: controls < cases

Setting direction: controls < cases

Setting direction: controls < cases

Setting direction: controls < cases



                       Rothia                  Oribacterium 
                  0.291550639                   0.071935322 
            Propionibacterium                    Tropheryma 
                  0.068595145                   0.048039933 
                     Moryella                    Filifactor 
                  0.031853877                   0.028545455 
               Pseudonocardia                   Actinomyces 
                  0.026455266                   0.024314213 
                Porphyromonas                    Prevotella 
                  0.022497914                   0.020180250 
                Streptococcus                     Treponema 
                  0.020060977                   0.017659856 
                Fusobacterium                   Selenomonas 
                  0.015960133                   0.015684098 
                Lactobacillus                 Campylobacter 
                  0.015445324                   0.015441575 
                  Veillo

Setting direction: controls < cases

Setting direction: controls < cases

Setting direction: controls < cases

Setting direction: controls < cases

Setting direction: controls < cases

Setting direction: controls < cases



                       Rothia                  Oribacterium 
                 0.3020377532                  0.0715073463 
            Propionibacterium                    Tropheryma 
                 0.0558770051                  0.0428667925 
                   Filifactor                Pseudonocardia 
                 0.0390551510                  0.0306939336 
                     Moryella                 Porphyromonas 
                 0.0296513860                  0.0285074674 
                  Actinomyces                     Treponema 
                 0.0212224102                  0.0183283200 
                Streptococcus     TM7_genera_incertae_sedis 
                 0.0176784826                  0.0160054585 
                  Veillonella                   Selenomonas 
                 0.0152428450                  0.0148713224 
                   Prevotella                       Gemella 
                 0.0144649291                  0.0143979330 
    OD1_genera_incertae_

Setting direction: controls < cases

Setting direction: controls < cases

Setting direction: controls < cases

Setting direction: controls < cases

Setting direction: controls < cases

Setting direction: controls < cases



                       Rothia                  Oribacterium 
                 0.2969671813                  0.0898490157 
            Propionibacterium                    Tropheryma 
                 0.0529325298                  0.0454574687 
                  Actinomyces                 Porphyromonas 
                 0.0360041451                  0.0347882698 
                     Moryella                    Filifactor 
                 0.0256984189                  0.0210845609 
               Pseudonocardia                 Streptococcus 
                 0.0187760384                  0.0181409873 
               Granulicatella               Corynebacterium 
                 0.0167893716                  0.0166921642 
                Campylobacter                    Prevotella 
                 0.0166486754                  0.0161871433 
    TM7_genera_incertae_sedis                 Fusobacterium 
                 0.0160123157                  0.0159707641 
                  Veillo

Setting direction: controls < cases

Setting direction: controls < cases

Setting direction: controls < cases

Setting direction: controls < cases

Setting direction: controls < cases

Setting direction: controls < cases



                       Rothia                  Oribacterium 
                  0.295165255                   0.071533152 
            Propionibacterium                    Tropheryma 
                  0.068291228                   0.042741392 
                Porphyromonas                Pseudonocardia 
                  0.032191053                   0.030247613 
                     Moryella                   Actinomyces 
                  0.028542077                   0.026780738 
                      Gemella                    Prevotella 
                  0.021875024                   0.020592511 
                Fusobacterium                 Streptococcus 
                  0.019987852                   0.017095996 
                  Selenomonas     TM7_genera_incertae_sedis 
                  0.016913282                   0.016745044 
                    Treponema                   Veillonella 
                  0.016508409                   0.016153822 
                   Filif

Setting direction: controls < cases

Setting direction: controls < cases

Setting direction: controls < cases

Setting direction: controls < cases

Setting direction: controls < cases

Setting direction: controls < cases



                       Rothia                  Oribacterium 
                  0.311243234                   0.083811410 
            Propionibacterium                    Tropheryma 
                  0.058521051                   0.044634703 
               Pseudonocardia                 Porphyromonas 
                  0.030507589                   0.029154977 
                     Moryella                   Actinomyces 
                  0.028339111                   0.026242650 
                Streptococcus                 Fusobacterium 
                  0.018391598                   0.017037620 
                    Treponema     TM7_genera_incertae_sedis 
                  0.015842646                   0.015657748 
                Campylobacter                    Filifactor 
                  0.015546632                   0.015182717 
                  Selenomonas                    Prevotella 
                  0.014962423                   0.014679511 
                    Cato

Setting direction: controls < cases

Setting direction: controls < cases

Setting direction: controls < cases

Setting direction: controls < cases

Setting direction: controls < cases

Setting direction: controls < cases



                       Rothia             Propionibacterium 
                  0.283570571                   0.076327137 
                 Oribacterium                    Tropheryma 
                  0.059313531                   0.057557284 
               Pseudonocardia                   Actinomyces 
                  0.035686443                   0.034234622 
                     Moryella                    Filifactor 
                  0.032950027                   0.028885575 
                Porphyromonas                    Prevotella 
                  0.022090515                   0.018303345 
                Fusobacterium                Granulicatella 
                  0.016138116                   0.015857687 
                Campylobacter                 Streptococcus 
                  0.015418578                   0.014711492 
               Capnocytophaga     TM7_genera_incertae_sedis 
                  0.014693472                   0.014431803 
                    Trep

Setting direction: controls < cases

Setting direction: controls < cases

Setting direction: controls < cases

Setting direction: controls < cases

Setting direction: controls < cases

Setting direction: controls < cases



                       Rothia                  Oribacterium 
                 0.3277749165                  0.0817062589 
            Propionibacterium                    Tropheryma 
                 0.0505821868                  0.0461964237 
                   Filifactor                 Porphyromonas 
                 0.0315109622                  0.0277459448 
                     Moryella                   Actinomyces 
                 0.0262503598                  0.0246799222 
    TM7_genera_incertae_sedis                Pseudonocardia 
                 0.0178009642                  0.0164281490 
                Streptococcus                    Prevotella 
                 0.0162993604                  0.0161351316 
                    Catonella                 Fusobacterium 
                 0.0153648078                  0.0153062654 
    OD1_genera_incertae_sedis                   Veillonella 
                 0.0151877041                  0.0145804181 
               Granulica

In [9]:
results_table = res %>% t %>% data.frame
colnames(results_table) = c('Train sensitivity', 'Train sensitivity', 'Train gmeasure', 'Train accuracy', 'Train AUC',
                            'Test sensitivity', 'Test sensitivity', 'Test gmeasure', 'Test accuracy', 'Test AUC',
                            'Taxa selection sensitivity','Taxa selection sensitivity',
                            'Taxa selection gmeasure', 'Taxa selection accuracy')
results_table

Train sensitivity,Train sensitivity.1,Train gmeasure,Train accuracy,Train AUC,Test sensitivity,Test sensitivity.1,Test gmeasure,Test accuracy,Test AUC,Taxa selection sensitivity,Taxa selection sensitivity.1,Taxa selection gmeasure,Taxa selection accuracy
0.9622642,0.872807,0.4591837,0.8693333,0.7719096,0.9552239,0.8860759,0.4324324,0.856,0.7356903,1,0.05263158,0.2294157,0.625
0.9688995,0.8484848,0.5373134,0.86,0.8030542,0.9635036,0.8428571,0.4651163,0.844,0.752269,1,0.05263158,0.2294157,0.625
0.9528536,0.8917749,0.4396552,0.8546667,0.7817346,0.9801325,0.884058,0.4,0.884,0.7735377,1,0.05263158,0.2294157,0.625
0.9528302,0.8779343,0.4424779,0.8546667,0.7511397,0.9473684,0.8860759,0.5526316,0.868,0.7886551,1,0.05263158,0.2294157,0.625
0.9413203,0.9043062,0.5151515,0.856,0.7901291,0.9370629,0.8888889,0.5714286,0.872,0.8129195,1,0.05263158,0.2294157,0.625
0.9636804,0.8584475,0.4915254,0.8586667,0.7829609,0.9770992,0.8472222,0.5957447,0.868,0.8316697,1,0.05263158,0.2294157,0.625
0.9625293,0.8917526,0.3875969,0.8453333,0.7617833,0.9259259,0.8441558,0.6578947,0.86,0.8349595,1,0.05263158,0.2294157,0.625
0.9610092,0.8627451,0.4090909,0.8533333,0.7565616,0.9770992,0.8955224,0.4230769,0.84,0.7789064,1,0.05263158,0.2294157,0.625
0.9574468,0.8590909,0.4672897,0.8586667,0.7784846,0.9469697,0.8918919,0.5681818,0.864,0.7909244,1,0.05263158,0.2294157,0.625
0.9264706,0.9118943,0.5304348,0.8613333,0.7933565,0.9485294,0.9277108,0.483871,0.884,0.7919543,1,0.05263158,0.2294157,0.625


In [10]:
print('Mean')
apply(results_table, 2, mean)

[1] "Mean"


In [11]:
print('SD')
apply(results_table, 2, sd)

[1] "SD"
