# Scenario 2. Binary classification

In [1]:
R.version
library(tidyverse)
library(caret)
library(randomForest)
library(reticulate)
library(mltools)
np <- import("numpy")

               _                           
platform       x86_64-pc-linux-gnu         
arch           x86_64                      
os             linux-gnu                   
system         x86_64, linux-gnu           
status                                     
major          3                           
minor          5.3                         
year           2019                        
month          03                          
day            11                          
svn rev        76217                       
language       R                           
version.string R version 3.5.3 (2019-03-11)
nickname       Great Truth                 

── [1mAttaching packages[22m ─────────────────────────────────────── tidyverse 1.2.1 ──

[32m✔[39m [34mggplot2[39m 3.2.1     [32m✔[39m [34mpurrr  [39m 0.3.2
[32m✔[39m [34mtibble [39m 2.1.3     [32m✔[39m [34mdplyr  [39m 0.8.3
[32m✔[39m [34mtidyr  [39m 1.0.2     [32m✔[39m [34mstringr[39m 1.4.0
[32m✔[39m [34mreadr  [39m 1.3.1     [32m✔[39m [34mforcats[39m 0.4.0

── [1mConflicts[22m ────────────────────────────────────────── tidyverse_conflicts() ──
[31m✖[39m [34mdplyr[39m::[32mfilter()[39m masks [34mstats[39m::filter()
[31m✖[39m [34mdplyr[39m::[32mlag()[39m    masks [34mstats[39m::lag()

Loading required package: lattice


Attaching package: ‘caret’


The following object is masked from ‘package:purrr’:

    lift


randomForest 4.6-14

Type rfNews() to see new features/changes/bug fixes.


Attaching package: ‘randomForest’


The following object is masked from ‘package:dplyr’:

    combine


The following object is masked from ‘package:g

### Configuration

In [2]:
data_path = "./data/simulation/s2"
path_genus = "./data/genus48"
count_path = 'data/simulation/count/'

y_path = sprintf('%s/%s', data_path, 'y.csv')
tree_info_path = './data/genus48/genus48_dic.csv'
count_path = './data/simulation/count'
count_list_path = './data/simulation/gcount_list.csv'
idx_path = './data/simulation/s1/idx.csv'

num_classes = 0 # regression
tree_level_list = c('Genus', 'Family', 'Order', 'Class', 'Phylum')

In [3]:
# # Read phylogenetic tree information

# phylogenetic_tree_info = read.csv(tree_info_path)
# phylogenetic_tree_info = phylogenetic_tree_info %>% select(tree_level_list)

# print(sprintf('Phylogenetic tree level list: %s', str_c(phylogenetic_tree_info %>% colnames, collapse = ', ')))

### Read dataset

#### Read training, test dataset

In [4]:
read_dataset <- function(x_path, y_path, sim){
    print(str_c('Load data for repetition ', sim))
    x = read.csv(x_path)
    y = read.csv(y_path)[,sim]
    x = (x - max(x)) / (max(x) - min(x))

    idxs = idxs_total[, sim]
    remain_idxs = setdiff(seq(1, dim(x)[1]), idxs)

    x_train = x[idxs,]
    x_test = x[remain_idxs,]
    y_train = y[idxs]
    y_test = y[remain_idxs]
    
    return (list(x_train, x_test, y_train, y_test))
}

In [5]:
idxs_total = read.csv(idx_path)
number_of_fold = dim(idxs_total)[2]; number_of_fold
x_list = read.csv(count_list_path, header = FALSE)
x_path = x_list$V1 %>% sprintf('%s/%s', count_path, .)

#### Read true tree weight

In [6]:
tw_1 = np$load(sprintf('%s/tw_1.npy', data_path))

## Random Forest

### Importance type

* See <https://stats.stackexchange.com/questions/92419/relative-importance-of-a-set-of-predictors-in-a-random-forests-classification-in>

Here are the definitions of the variable importance measures.

- `type=1`: **Mean decrease in accuracy**
    - The first measure is computed from permuting Out-of-bag (OOB) data.
    - For each tree, the prediction error on the out-of-bag portion of the data is recorded (error rate for classification, MSE for regression). Then the same is done after permuting each predictor variable. 
    - The difference between the two are then averaged over all trees, and normalized by the standard deviation of the differences.
    - If the standard deviation of the differences is equal to 0 for a variable, the division is not done (but the average is almost always equal to 0 in that case).

- `type=2`: **Mean decrease in node impurity**
    - The second measure is the total decrease in node impurities from splitting on the variable, averaged over all trees. 
    - For classification, the node impurity is measured by the Gini index. 
    - For regression, it is measured by residual sum of squares.
    
### Feature selection

* `vi_f`: variable importance by Gini importance (`type=2`)
* `relative_vi_f` : relative variable importance
* `thrd`: threshold for relative variable importance
* Select features which have relative variable importance `relative_vi_f` equal or larger than threshold `thrd`

### Simulate for all $n$

In [7]:
random_forest_res <- function(fold, importance_type=2, fs_thrd = 0.1){
    print(sprintf('-----------------------------------------------------------------'))
    print(sprintf('Random Forest computation for %dth repetition', fold))

    dataset = read_dataset(x_path[fold], y_path, fold)
    x_train = dataset[[1]]
    x_test = dataset[[2]]
    y_train = dataset[[3]]
    y_test = dataset[[4]]
    
    # Binary classification
    y_train = factor(y_train, levels = c(0,1), ordered=TRUE)
    y_test = factor(y_test, levels = c(0,1), ordered=TRUE)

    fit.rf <- randomForest(y_train~.,data=x_train, ntree=1000,  mtry=10, importance=TRUE)
    train.pred <- fit.rf$predicted
    test.pred <- predict(fit.rf,x_test)

    train_conf <- table(train.pred, y_train)
    train_sensitivity <- sensitivity(train_conf)
    train_specificity <- specificity(train_conf)
    train_gmeasure <- sqrt(train_sensitivity*train_specificity)
    train_accuracy <- sum(diag(train_conf))/sum(train_conf)
    train_auc <- auc_roc(train.pred, y_train)

    test_conf <- table(test.pred, y_test)
    test_sensitivity <- sensitivity(test_conf)
    test_specificity <- specificity(test_conf)
    test_gmeasure <- sqrt(test_sensitivity*test_specificity)
    test_accuracy <- sum(diag(test_conf))/sum(test_conf)
    test_auc <- auc_roc(test.pred, y_test)
    
    # Feature selection
    ## variable importance
    vi_f = importance(fit.rf, type=importance_type)
    relative_vi_f <- vi_f / sum(vi_f)
    selected_genus <- ifelse(relative_vi_f >= fs_thrd, 1, 0)
    
    order <- order(relative_vi_f, decreasing = TRUE)
    sorted_relative_vi_f <- relative_vi_f[order]
    names(sorted_relative_vi_f) <- colnames(x_train)[order]
    print(sorted_relative_vi_f)

    fold_genus = apply(tw_1[fold,,], 1, sum)
    names(fold_genus) <- x_train %>% colnames

    fs_conf_table <- table(selected_genus, fold_genus)
    
    fs_sensitivity <- sensitivity(fs_conf_table) 
    fs_specificity <- specificity(fs_conf_table)
    fs_gmeasure <- sqrt(fs_sensitivity*fs_specificity)
    fs_accuracy <- sum(diag(fs_conf_table))/sum(fs_conf_table)

    print(sprintf('Train sensitivity: %s, Train sensitivity: %s, Train gmeasure: %s, Train accuracy: %s, Train AUC: %s',
                  train_sensitivity, train_specificity, train_gmeasure, train_accuracy, train_auc))
    print(sprintf('Test sensitivity: %s, Test sensitivity: %s, Test gmeasure: %s, Test accuracy: %s, Test AUC: %s',
                  test_sensitivity, test_specificity, test_gmeasure, test_accuracy, test_auc))
    print(sprintf('FS sensitivity: %s, FS sensitivity: %s, FS gmeasure: %s, FS accuracy: %s',
                  fs_sensitivity, fs_specificity, fs_gmeasure, fs_accuracy))
    
    return (c(train_sensitivity, train_specificity, train_gmeasure, train_accuracy, train_auc, 
              test_sensitivity, test_specificity, test_gmeasure, test_accuracy, test_auc,
              fs_sensitivity, fs_specificity, fs_gmeasure, fs_accuracy))
}

In [8]:
set.seed(100)
# res <- sapply(seq(1,1), random_forest_res)
res <- sapply(seq(1,10), random_forest_res)
# res <- sapply(seq(1,number_of_fold), random_forest_res)

[1] "-----------------------------------------------------------------"
[1] "Random Forest computation for 1th repetition"
[1] "Load data for repetition 1"
                       Rothia                  Oribacterium 
                  0.330196439                   0.110725797 
            Propionibacterium                    Tropheryma 
                  0.050529145                   0.050367816 
                     Moryella                   Actinomyces 
                  0.032971117                   0.026832935 
                   Filifactor                Pseudonocardia 
                  0.025765962                   0.018555369 
              Corynebacterium                     Treponema 
                  0.016087377                   0.015729644 
                Porphyromonas                    Prevotella 
                  0.014964231                   0.014694758 
                Streptococcus                 Campylobacter 
                  0.014018473                   0.0

In [9]:
results_table = res %>% t %>% data.frame
colnames(results_table) = c('Train sensitivity', 'Train sensitivity', 'Train gmeasure', 'Train accuracy', 'Train AUC',
                            'Test sensitivity', 'Test sensitivity', 'Test gmeasure', 'Test accuracy', 'Test AUC',
                            'Taxa selection sensitivity','Taxa selection sensitivity',
                            'Taxa selection gmeasure', 'Taxa selection accuracy')
results_table

Train sensitivity,Train sensitivity.1,Train gmeasure,Train accuracy,Train AUC,Test sensitivity,Test sensitivity.1,Test gmeasure,Test accuracy,Test AUC,Taxa selection sensitivity,Taxa selection sensitivity.1,Taxa selection gmeasure,Taxa selection accuracy
0.7801724,0.9710425,0.870391,0.912,0.8756074,0.8395062,0.9881657,0.910808,0.94,0.9138359,1,0.10526316,0.3244428,0.6458333
0.7156863,0.974359,0.8350661,0.904,0.8450226,0.7236842,0.9827586,0.8433308,0.904,0.8532214,1,0.10526316,0.3244428,0.6458333
0.7698745,0.9549902,0.8574512,0.896,0.8624323,0.7916667,0.9719101,0.8771709,0.92,0.8817884,1,0.10526316,0.3244428,0.6458333
0.8045455,0.9716981,0.8841806,0.9226667,0.8881218,0.8607595,0.9707602,0.9141067,0.936,0.9157599,1,0.10526316,0.3244428,0.6458333
0.7545455,0.9698113,0.8554336,0.9066667,0.8621784,0.7733333,0.9714286,0.8667399,0.912,0.872381,1,0.10526316,0.3244428,0.6458333
0.7822222,0.9733333,0.8725612,0.916,0.8777778,0.7972973,0.9772727,0.88271,0.924,0.887285,1,0.10526316,0.3244428,0.6458333
0.7804878,0.9541284,0.8629517,0.9066667,0.8673081,0.775,0.9705882,0.867298,0.908,0.8727941,1,0.05263158,0.2294157,0.625
0.7904762,0.9703704,0.8758166,0.92,0.8804233,0.8,0.9666667,0.8793937,0.92,0.8833333,1,0.10526316,0.3244428,0.6458333
0.7982456,0.9597701,0.8752898,0.9106667,0.8790079,0.8666667,0.9714286,0.9175537,0.94,0.9190476,1,0.05263158,0.2294157,0.625
0.8247863,0.9593023,0.8895052,0.9173333,0.8920443,0.8089888,0.9689441,0.8853614,0.912,0.8889664,1,0.10526316,0.3244428,0.6458333


In [10]:
print('Mean')
apply(results_table, 2, mean)

[1] "Mean"


In [11]:
print('SD')
apply(results_table, 2, sd)

[1] "SD"
