# Scenario 1. Regression design

## Strategy 1. Microbiome taxa associated with outcome $y$ are clustered at the phylum level.

In [1]:
R.version
library(tidyverse)
library(caret)
library(randomForest)
library(reticulate)
np <- import("numpy")

               _                           
platform       x86_64-pc-linux-gnu         
arch           x86_64                      
os             linux-gnu                   
system         x86_64, linux-gnu           
status                                     
major          3                           
minor          5.3                         
year           2019                        
month          03                          
day            11                          
svn rev        76217                       
language       R                           
version.string R version 3.5.3 (2019-03-11)
nickname       Great Truth                 

── Attaching packages ─────────────────────────────────────── tidyverse 1.2.1 ──

✔ ggplot2 3.2.1     ✔ purrr   0.3.2
✔ tibble  2.1.3     ✔ dplyr   0.8.3
✔ tidyr   1.0.2     ✔ stringr 1.4.0
✔ readr   1.3.1     ✔ forcats 0.4.0

── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
✖ dplyr::filter() masks stats::filter()
✖ dplyr::lag()    masks stats::lag()

Loading required package: lattice


Attaching package: ‘caret’


The following object is masked from ‘package:purrr’:

    lift


randomForest 4.6-14

Type rfNews() to see new features/changes/bug fixes.


Attaching package: ‘randomForest’


The following object is masked from ‘package:dplyr’:

    combine


The following object is masked from ‘package:g

### Configuration

In [2]:
data_path = "./data/simulation/s0"
path_genus = "./data/genus48"

y_path = sprintf('%s/%s', data_path, 'y.csv')
tree_info_path = './data/genus48/genus48_dic.csv'
count_path = './data/simulation/count'
count_list_path = './data/simulation/gcount_list.csv'
idx_path = './data/simulation/s0/idx.csv'

num_classes = 0 # regression
tree_level_list = c('Genus', 'Family', 'Order', 'Class', 'Phylum')

In [3]:
# # Read phylogenetic tree information

# phylogenetic_tree_info = read.csv(tree_info_path)
# phylogenetic_tree_info = phylogenetic_tree_info %>% select(tree_level_list)

# print(sprintf('Phylogenetic tree level list: %s', str_c(phylogenetic_tree_info %>% colnames, collapse = ', ')))

### Read dataset

#### Read training, test dataset

In [4]:
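# Load one repetition: scale the count table, then split its rows into training and test sets.
# Relies on the global `idxs_total` (training-row indices, one column per repetition),
# which is read in the next cell.
read_dataset <- function(x_path, y_path, sim){
    print(str_c('Load data for repetition ', sim))
    x = read.csv(x_path)
    y = read.csv(y_path)[,sim]
    # Min-max scale the whole table to [0, 1]: subtract the global minimum, not the maximum
    x = (x - min(x)) / (max(x) - min(x))

    # Rows listed in idxs_total form the training set; all remaining rows form the test set
    idxs = idxs_total[, sim]
    remain_idxs = setdiff(seq_len(nrow(x)), idxs)

    x_train = x[idxs,]
    x_test = x[remain_idxs,]
    y_train = y[idxs]
    y_test = y[remain_idxs]
    
    return (list(x_train, x_test, y_train, y_test))
}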
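
The scaling and splitting steps in `read_dataset` can be sketched on a toy table. This is a minimal sketch, assuming conventional min-max scaling over the whole table (global minimum mapped to 0, global maximum mapped to 1). The data frame `toy_x` and the index vector `train_idx` are hypothetical stand-ins for a count table and one column of `idxs_total`.

```r
# Hypothetical toy count table: 5 samples x 2 taxa.
toy_x <- data.frame(taxon_a = c(0, 5, 10, 20, 40),
                    taxon_b = c(8, 0, 4, 16, 12))

# Global min-max scaling: every entry ends up in [0, 1].
x_min <- min(toy_x)
x_max <- max(toy_x)
scaled_x <- (toy_x - x_min) / (x_max - x_min)

# Hypothetical training rows; the remaining rows form the test set.
train_idx <- c(1, 3, 4)
test_idx <- setdiff(seq_len(nrow(scaled_x)), train_idx)

x_train <- scaled_x[train_idx, ]
x_test <- scaled_x[test_idx, ]
print(range(scaled_x))
print(test_idx)
```

Tree-based models such as random forests split on thresholds, so a shift or rescaling of the predictors should not change which samples fall on each side of a split. The scaling mainly matters for keeping the input comparable with the other methods in the summary tables.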

In [5]:
idxs_total = read.csv(idx_path)
number_of_fold = dim(idxs_total)[2]; number_of_fold
x_list = read.csv(count_list_path, header = FALSE)
x_path = x_list$V1 %>% sprintf('%s/%s', count_path, .)

#### Read true tree weight

In [6]:
tw_1 = np$load(sprintf('%s/tw_1.npy', data_path))

## Random Forest

### Importance type

* See <https://stats.stackexchange.com/questions/92419/relative-importance-of-a-set-of-predictors-in-a-random-forests-classification-in>

Here are the definitions of the variable importance measures.

- `type=1`: **Mean decrease in accuracy**
    - This measure is computed by permuting out-of-bag (OOB) data.
    - For each tree, the prediction error on the out-of-bag portion of the data is recorded (error rate for classification, MSE for regression). Then the same is done after permuting each predictor variable.
    - The differences between the two errors are then averaged over all trees and normalized by the standard deviation of the differences.
    - If the standard deviation of the differences is equal to 0 for a variable, the division is not done (but the average is almost always equal to 0 in that case).

- `type=2`: **Mean decrease in node impurity**
    - The second measure is the total decrease in node impurities from splitting on the variable, averaged over all trees. 
    - For classification, the node impurity is measured by the Gini index. 
    - For regression, it is measured by residual sum of squares.
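
The permutation idea behind `type=1` can be sketched in base R. This is only an illustration under stated assumptions: `lm` stands in for the forest, a held-out split stands in for the out-of-bag data, and the result is the raw mean increase in MSE without the standard-deviation normalization. The simulated data and the helper `mse` are hypothetical.

```r
# Permutation importance sketch in base R (lm stands in for the forest).
set.seed(1)
n <- 200
dat <- data.frame(x1 = rnorm(n), x2 = rnorm(n), x3 = rnorm(n))
dat$y <- 3 * dat$x1 + 0.5 * dat$x2 + rnorm(n, sd = 0.5)  # x3 is pure noise

train <- dat[1:150, ]
held_out <- dat[151:200, ]   # plays the role of the out-of-bag data
fit <- lm(y ~ x1 + x2 + x3, data = train)

# Hypothetical helper: mean squared prediction error on a data frame.
mse <- function(model, data) mean((data$y - predict(model, newdata = data))^2)
baseline <- mse(fit, held_out)

# Increase in MSE after shuffling one predictor, averaged over 50 shuffles.
perm_importance <- sapply(c("x1", "x2", "x3"), function(v) {
  mean(replicate(50, {
    shuffled <- held_out
    shuffled[[v]] <- sample(shuffled[[v]])
    mse(fit, shuffled) - baseline
  }))
})
print(round(perm_importance, 3))
```

A predictor that the model relies on heavily (here `x1`) shows a large increase in error when shuffled, while a noise predictor (`x3`) stays close to zero.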
    
### Feature selection

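* `vi_f`: variable importance by mean decrease in node impurity (`type=2`); for regression this is based on the residual sum of squares rather than the Gini index
* `relative_vi_f`: relative variable importance, `vi_f / sum(vi_f)`
* `thrd`: threshold for relative variable importance (`fs_thrd` in the code, default 0.1)
* Select the features whose relative variable importance `relative_vi_f` is greater than or equal to the threshold `thrd`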
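
The selection rule above can be sketched on a small vector. The importance scores below are hypothetical and are not results from this notebook; only the normalization and the threshold comparison mirror the steps listed.

```r
# Hypothetical raw importance scores for five taxa (not results from this notebook).
vi <- c(taxon_a = 12.0, taxon_b = 3.5, taxon_c = 0.5, taxon_d = 3.0, taxon_e = 1.0)

relative_vi <- vi / sum(vi)   # shares of total importance; they sum to 1
thrd <- 0.1                   # same value as the default threshold
selected <- as.integer(relative_vi >= thrd)
names(selected) <- names(vi)

print(sort(relative_vi, decreasing = TRUE))
print(selected)
```

Because the shares sum to 1 and node-impurity importances are non-negative, at most `1 / thrd` features can pass the threshold. With `thrd = 0.1` that is at most 10 features, which is worth keeping in mind when the number of truly associated taxa is larger than that.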

### Simulate for all $n$

In [7]:
random_forest_res <- function(fold, importance_type=2, fs_thrd = 0.1){
    print(sprintf('-----------------------------------------------------------------'))
    print(sprintf('Random Forest computation for %dth repetition', fold))

    dataset = read_dataset(x_path[fold], y_path, fold)
    x_train = dataset[[1]]
    x_test = dataset[[2]]
    y_train = dataset[[3]]
    y_test = dataset[[4]]

    fit.rf <- randomForest(y_train~.,data=x_train, ntree=1000,  mtry=10, importance=TRUE)
    # fit.rf$predicted holds out-of-bag predictions, so the "train" metrics below are OOB estimates
    train.pred <- fit.rf$predicted
    test.pred <- predict(fit.rf,x_test)

    train.mse <- mean((y_train - train.pred)^2)
    train.cor <- cor(y_train, train.pred)

    test.mse <- mean((y_test - test.pred)^2)
    test.cor <- cor(y_test, test.pred)
    
    # Feature selection
    ## variable importance
    vi_f = importance(fit.rf, type=importance_type)
    relative_vi_f <- vi_f / sum(vi_f)
    selected_genus <- ifelse(relative_vi_f >= fs_thrd, 1, 0)
    
    # Use `ord` rather than `order` so the base function is not shadowed
    ord <- order(relative_vi_f, decreasing = TRUE)
    sorted_relative_vi_f <- relative_vi_f[ord]
    names(sorted_relative_vi_f) <- colnames(x_train)[ord]
    print(sorted_relative_vi_f)

    fold_genus = apply(tw_1[fold,,], 1, sum)
    names(fold_genus) <- x_train %>% colnames

    fs_conf_table <- table(selected_genus, fold_genus)
    
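    # NOTE: for a table, caret's sensitivity() treats the first row level ("0", not selected)
    # as the positive class by default; pass positive = "1" (and negative = "0" to
    # specificity()) to score selected taxa as the positives.
    fs_sensitivity <- sensitivity(fs_conf_table) 
    fs_specificity <- specificity(fs_conf_table)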
    fs_gmeasure <- sqrt(fs_sensitivity*fs_specificity)
    fs_accuracy <- sum(diag(fs_conf_table))/sum(fs_conf_table)

    print(sprintf('Train mse: %s, Train Correlation: %s', train.mse, train.cor))
    print(sprintf('Test mse: %s, Test Correlation: %s', test.mse, test.cor))
    print(sprintf('FS sensitivity: %s, FS specificity: %s, FS gmeasure: %s, FS accuracy: %s',
                  fs_sensitivity, fs_specificity, fs_gmeasure, fs_accuracy))
    
    return (c(train.mse, train.cor, test.mse, test.cor, fs_sensitivity, fs_specificity, fs_gmeasure, fs_accuracy))
}
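
The taxa-selection metrics can also be computed by hand, which makes the choice of positive class explicit. This is a minimal sketch in base R, assuming that "selected" (1) is the positive class; the vectors `selected` and `truth` are hypothetical and not taken from the simulation.

```r
# Hypothetical selection (1 = selected) and ground truth (1 = truly associated).
selected <- c(1, 1, 0, 0, 0, 0, 1, 0)
truth    <- c(1, 1, 1, 1, 0, 0, 0, 0)

# Fix both factors to levels 0/1 so the table is always 2 x 2,
# even if no taxon (or every taxon) is selected.
conf <- table(selected = factor(selected, levels = c(0, 1)),
              truth    = factor(truth,    levels = c(0, 1)))

tp <- conf["1", "1"]; fn <- conf["0", "1"]
tn <- conf["0", "0"]; fp <- conf["1", "0"]

sens <- tp / (tp + fn)            # share of true taxa that were selected
spec <- tn / (tn + fp)            # share of null taxa that were left out
gmeasure <- sqrt(sens * spec)     # geometric mean of the two
accuracy <- (tp + tn) / sum(conf)
print(c(sensitivity = sens, specificity = spec,
        gmeasure = gmeasure, accuracy = accuracy))
```

Spelling out the four cells avoids depending on a library default for which level counts as positive, and fixing the factor levels keeps the table 2 x 2 in degenerate folds.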

In [8]:
set.seed(100)
# res <- sapply(seq(1,1), random_forest_res)
res <- sapply(seq(1,10), random_forest_res)
# res <- sapply(seq(1,number_of_fold), random_forest_res)

[1] "-----------------------------------------------------------------"
[1] "Random Forest computation for 1th repetition"
[1] "Load data for repetition 1"
                   Tropheryma                 Porphyromonas 
                  0.516617258                   0.146956055 
                    Neisseria                   Veillonella 
                  0.019556406                   0.018810278 
                 Oribacterium                Pseudonocardia 
                  0.018447398                   0.016529154 
                Solobacterium                 Streptococcus 
                  0.014603530                   0.012727190 
                  Actinomyces                Granulicatella 
                  0.012249896                   0.010424407 
                Fusobacterium     OD1_genera_incertae_sedis 
                  0.010391410                   0.010038766 
                  Selenomonas                    Parvimonas 
                  0.009700077                   0.0

In [9]:
results_table = res %>% t %>% data.frame
colnames(results_table) = c('Training MSE', 'Training Correlation', 'Test MSE', 'Test Correlation', 
                            'Taxa selection sensitivity','Taxa selection specificity',
                            'Taxa selection gmeasure', 'Taxa selection accuracy')
results_table

Training MSE,Training Correlation,Test MSE,Test Correlation,Taxa selection sensitivity,Taxa selection sensitivity.1,Taxa selection gmeasure,Taxa selection accuracy
0.06279524,0.9549467,0.05962205,0.875114,1,0.06451613,0.2540003,0.3958333
0.08096974,0.9028268,0.0836503,0.9606535,1,0.06451613,0.2540003,0.3958333
0.0718343,0.9356166,0.04600076,0.8763665,1,0.06451613,0.2540003,0.3958333
0.07237074,0.9204837,0.01753623,0.9416163,1,0.06451613,0.2540003,0.3958333
0.07409322,0.9043939,0.06539333,0.9269943,1,0.06451613,0.2540003,0.3958333
0.05716759,0.9396911,0.04576656,0.8940942,1,0.06451613,0.2540003,0.3958333
0.0773566,0.9272322,0.03239857,0.8931325,1,0.06451613,0.2540003,0.3958333
0.06505412,0.9367583,0.10564457,0.8517679,1,0.06451613,0.2540003,0.3958333
0.0805552,0.8980913,0.07778365,0.9672162,1,0.06451613,0.2540003,0.3958333
0.05570161,0.8980644,0.1136503,0.8980928,1,0.06451613,0.2540003,0.3958333


In [10]:
print('Mean')
apply(results_table, 2, mean)

[1] "Mean"


In [11]:
print('SD')
apply(results_table, 2, sd)

[1] "SD"
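
The per-column mean and standard deviation can be collected into one table, which is convenient when filling in a "mean / sd" summary. This is a small sketch with hypothetical values; `toy_results` is a stand-in for `results_table` and does not contain results from this notebook.

```r
# Hypothetical per-repetition metrics (rows = repetitions); not results from this notebook.
toy_results <- data.frame(test_mse = c(0.10, 0.20, 0.30),
                          test_cor = c(0.80, 0.90, 0.85))

# One row per metric, with its mean and sample standard deviation.
summary_table <- data.frame(mean = sapply(toy_results, mean),
                            sd   = sapply(toy_results, sd))
print(round(summary_table, 3))
```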


### Summarize

#### Prediction performance

|                                   | Testing |       |             |       | Training |       |             |       |
|-----------------------------------|:-------:|-------|-------------|-------|:--------:|-------|-------------|-------|
| Method                            |   MSE   |       | Correlation |       |    MSE   |       | Correlation |       |
|                                   | mean    | sd    | mean        | sd    | mean     | sd    | mean        | sd    |
| Linear Regression                 | 0.104   | 0.024 | 0.824       | 0.049 | 0.087    | 0.011 | 0.851       | 0.023 |
| Ridge                             | 0.104   | 0.022 | 0.824       | 0.049 | 0.09     | 0.012 | 0.851       | 0.023 |
| Lasso                             | 0.100   | 0.023 | 0.833       | 0.048 | 0.092    | 0.013 | 0.843       | 0.025 |
| Elastic Net                       | 0.100   | 0.023 | 0.833       | 0.048 | 0.092    | 0.012 | 0.844       | 0.025 |
| Random Forest                     |         |       |             |       |          |       |             |       |
| Random Forest + Feature Selection |         |       |             |       |          |       |             |       |
| DNN                               | 0.076   | 0.040 | 0.874       | 0.077 | 0.032    | 0.034 | 0.947       | 0.067 |
| DNN+$\ell_1$                      | 0.075   | 0.040 | 0.875       | 0.073 | 0.034    | 0.039 | 0.945       | 0.068 |
| DeepBiome                         | 0.071   | 0.036 | 0.882       | 0.069 | 0.043    | 0.034 | 0.929       | 0.061 |

#### Taxa selection performance

|              | PhyloTree | No. true taxa (total) | Sensitivity |       | Specificity |       | g-Measure |       |  ACC  |       |
|--------------|-----------|:---------------------:|:-----------:|-------|:-----------:|-------|:---------:|-------|:-----:|-------|
| Method       |           |                       | mean        | sd    | mean        | sd    | mean      | sd    | mean  | sd    |
| Lasso        | Genus     |        31 (48)        | 0.380       | 0.136 | 0.812       | 0.15  | 0.536     | 0.083 | 0.533 | 0.064 |
|              | Family    |        23 (40)        | 0.474       | 0.150 | 0.812       | 0.150 | 0.602     | 0.086 | 0.618 | 0.065 |
|              | Order     |         9 (23)        | 0.637       | 0.169 | 0.783       | 0.169 | 0.688     | 0.092 | 0.726 | 0.084 |
|              | Class     |         7 (17)        | 0.739       | 0.161 | 0.730       | 0.197 | 0.715     | 0.105 | 0.734 | 0.099 |
| Elastic-Net  | Genus     |        31 (48)        | 0.389       | 0.138 | 0.803       | 0.158 | 0.540     | 0.075 | 0.536 | 0.063 |
|              | Family    |        23 (40)        | 0.484       | 0.149 | 0.803       | 0.158 | 0.605     | 0.078 | 0.620 | 0.063 |
|              | Order     |         9 (23)        | 0.646       | 0.159 | 0.774       | 0.174 | 0.691     | 0.088 | 0.724 | 0.088 |
|              | Class     |         7 (17)        | 0.750       | 0.149 | 0.720       | 0.201 | 0.717     | 0.107 | 0.733 | 0.104 |
| Random Forest| Genus     |        31 (48)        |             |       |             |       |           |       |       |       |
|              | Family    |        23 (40)        |             |       |             |       |           |       |       |       |
|              | Order     |         9 (23)        |             |       |             |       |           |       |       |       |
|              | Class     |         7 (17)        |             |       |             |       |           |       |       |       |
| Random Forest| Genus     |        31 (48)        |             |       |             |       |           |       |       |       |
| + Feature    | Family    |        23 (40)        |             |       |             |       |           |       |       |       |
| Selection    | Order     |         9 (23)        |             |       |             |       |           |       |       |       |
|              | Class     |         7 (17)        |             |       |             |       |           |       |       |       |
| DNN+$\ell_1$ | Genus     |        31 (48)        | 0.967       | 0.032 | 0.034       | 0.006 | 0.181     | 0.016 | 0.049 | 0.006 |
|              | Family    |        23 (40)        | 0.970       | 0.036 | 0.031       | 0.006 | 0.174     | 0.017 | 0.055 | 0.006 |
|              | Order     |         9 (23)        | 0.972       | 0.055 | 0.026       | 0.008 | 0.156     | 0.026 | 0.048 | 0.008 |
|              | Class     |         7 (17)        | 0.978       | 0.056 | 0.021       | 0.012 | 0.136     | 0.047 | 0.065 | 0.012 |
| DeepBiome    | Genus     |        31 (48)        | 0.954       | 0.042 | 0.669       | 0.087 | 0.797     | 0.053 | 0.673 | 0.085 |
|              | Family    |        23 (40)        | 0.967       | 0.037 | 0.828       | 0.062 | 0.894     | 0.037 | 0.832 | 0.060 |
|              | Order     |         9 (23)        | 0.970       | 0.058 | 0.855       | 0.057 | 0.910     | 0.042 | 0.858 | 0.056 |
|              | Class     |         7 (17)        | 0.983       | 0.050 | 0.835       | 0.063 | 0.905     | 0.043 | 0.842 | 0.060 |