# Preamble

## libraries

In [1]:
library(tidyverse)
library(stringr)
library(caret)
library(data.table)
library(stringr)
library(dplyr)
library(randomForest)
library(qs)
library(parallel)
library(clustermq)
library(inTrees)
library(RRF)

── [1mAttaching packages[22m ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── tidyverse 1.3.0 ──

[32m✔[39m [34mggplot2[39m 3.3.2     [32m✔[39m [34mpurrr  [39m 0.3.4
[32m✔[39m [34mtibble [39m 3.0.4     [32m✔[39m [34mdplyr  [39m 1.0.2
[32m✔[39m [34mtidyr  [39m 1.1.2     [32m✔[39m [34mstringr[39m 1.4.0
[32m✔[39m [34mreadr  [39m 1.4.0     [32m✔[39m [34mforcats[39m 0.5.0

── [1mConflicts[22m ───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── tidyverse_conflicts() ──
[31m✖[39m [34mdplyr[39m::[32mfilter()[39m masks [34mstats[39m::filter()
[31m✖[39m [34mdplyr[39m::[32mlag()[39m    masks [34mstats[39m::lag()

Loading required package: lattice


Attaching package: ‘caret’


The following object is masked from ‘package:purrr’:

    lift



Attaching package: ‘data.table’


T

In [2]:
# Print the R version, platform and attached package versions used for this run
sessionInfo()

R version 4.0.3 (2020-10-10)
Platform: x86_64-conda-linux-gnu (64-bit)
Running under: Ubuntu 18.04.6 LTS

Matrix products: default
BLAS/LAPACK: /ebio/abt3_projects/Methanogen_SCFA/Metagenomes_methanogen/envs/r-ml/lib/libopenblasp-r0.3.10.so

locale:
 [1] LC_CTYPE=en_US.UTF-8       LC_NUMERIC=C              
 [3] LC_TIME=en_US.UTF-8        LC_COLLATE=en_US.UTF-8    
 [5] LC_MONETARY=en_US.UTF-8    LC_MESSAGES=en_US.UTF-8   
 [7] LC_PAPER=en_US.UTF-8       LC_NAME=C                 
 [9] LC_ADDRESS=C               LC_TELEPHONE=C            
[11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C       

attached base packages:
[1] parallel  stats     graphics  grDevices utils     datasets  methods  
[8] base     

other attached packages:
 [1] RRF_1.9.1           inTrees_1.3         clustermq_0.8.95.1 
 [4] qs_0.23.4           randomForest_4.6-14 data.table_1.13.4  
 [7] caret_6.0-86        lattice_0.20-41     forcats_0.5.0      
[10] stringr_1.4.0       dplyr_1.0.2         purrr_0.3.4      

In [3]:
# clustermq configuration: use the SGE scheduler with the job template stored
# in the user's home directory
options(clustermq.scheduler = "sge", clustermq.template = "~/.clustermq.tmpl")

# Data

In [4]:
# Read the two serialized inputs with qs::qread():
#  - taxa: provides the f/g/s columns used further down to build the taxonomic lists
#  - pp:   table holding the features together with the metadata columns removed below
taxa <- qread('../data/taxa_table.qs')
pp <- qread('../data/tax_meta.qs')

In [5]:
# remove the MetaCyc pathways
# Columns whose name starts with a lowercase letter followed by an underscore
# are located, and every column from the LAST such match up to the final
# column of pp is dropped.
# NOTE(review): the removed range starts at the last matching column itself, so
# that column is dropped too - confirm this is intended (otherwise the range
# should start at max + 1).
to_rm <- seq(pp %>% colnames %>% str_which(pattern = '^[:lower:]\\_{1}') %>% max, ncol(pp), 1)
pp <- pp[,-to_rm]
# display the remaining dimensions
pp %>% dim

In [6]:
# Drop the sample metadata columns so that X only holds the features
meta <- c('dataset_name','Sample','age','gender','country','BMI','westernized')
X <- pp %>% select(-all_of(meta))

In [7]:
# Keep only the taxa detected (non-zero) in more than a quarter of the samples
n_detected <- colSums(X != 0)
tmp <- names(n_detected)[which(n_detected > nrow(X)/4)]
X <- X %>% select(all_of(tmp))
dim(X)

In [8]:
# shuffle samples
# Rows with missing values are dropped first and the permutation is then drawn
# from the number of rows that remain. Drawing the permutation from the
# unfiltered row count would index past the end of the filtered table and
# insert all-NA rows whenever at least one incomplete row had been removed.
set.seed(0)
X <- X[complete.cases(X), ]
X <- X[sample(seq_len(nrow(X))), ]

# Make the target

In [9]:
# Draw, without replacement, 9 taxa among those detected in more than half of
# the samples; they are the candidates for the group-specific target rules
set.seed(1209)
prevalent <- which(colSums(X != 0) > nrow(X)/2)
var_ix <- sample(prevalent, size = 9, replace = FALSE)
var_n <- names(X)[var_ix]

In [10]:
# Display the names of the drawn taxa
var_n

In [11]:
# Split the (already shuffled) samples into four consecutive groups: 'a', 'b'
# and 'c' each get floor(n/4) samples, 'd' gets the remainder
nr <- nrow(X)
ng <- floor(nr/4)
group_labels <- rep(c('a', 'b', 'c', 'd'), times = c(ng, ng, ng, nr-3*ng))
X <- as.data.table(X)[, 'group' := group_labels]

In [12]:
# One row per sample: its group and a target class initialised to '1'
target <- data.frame(group = X$group, tc = '1', stringsAsFactors = FALSE)

In [13]:
# Within each group the target class is '1' when the group's own rule on the
# drawn taxa holds and '-1' otherwise. One row mask per group is computed once.
in_a <- target$group == 'a'
in_b <- target$group == 'b'
in_c <- target$group == 'c'
in_d <- target$group == 'd'

# group a: both taxa present
target$tc[in_a] <- ifelse(
    X$s_Marvinbryantia_sp900066075[in_a] > 0 & X$g_Alistipes_A[in_a] > 0
    , '1', '-1')

# group b: both taxa above their abundance threshold
target$tc[in_b] <- ifelse(
    X$f_Bacteroidaceae[in_b] > 10^-(1) & X$g_Dialister[in_b] > 10^-(2.5)
    , '1', '-1')

# group c: (both taxa present) OR the third taxon above its threshold
target$tc[in_c] <- ifelse(
    (X$s_Oscillibacter_sp001916835[in_c] > 0 & X$s_Bacteroides_clarus[in_c] > 0)
    | X$s_Faecalibacterium_prausnitzii_G[in_c] >10^-2
    , '1', '-1')

# group d: first taxon at or below its threshold AND second taxon present
target$tc[in_d] <- ifelse(
    X$s_Lawsonibacter_sp000177015[in_d] <= 10^-3.4 & X$f_Anaerovoracaceae[in_d] > 0
    , '1', '-1')

In [14]:
# Cross-tabulate the target class (rows) against the original groups (columns)
table(target$tc, target$group)

    
       a   b   c   d
  -1 319 299 214 361
  1  217 237 322 178

In [15]:
# randomise group labels
# Each sample is flagged with probability 0.05; a flagged sample gets a new
# label drawn among the labels other than its current one, which includes the
# otherwise unused 'e'. Only the flagged rows are visited (an empty table is
# then a no-op instead of an error), and the generator is re-seeded with the
# row index before each draw, exactly as before, so the draws are unchanged.
groups <- c('a', 'b', 'c', 'd', 'e')
set.seed(0)
brnounou <- rbinom(n = length(X$group), size = 1,prob = 0.05)
for (i in which(brnounou == 1)){
    set.seed(i)
    X$group[i] <- sample(groups[groups != X$group[i]], 1)
}

In [16]:
# Same cross-tabulation after the group labels have been partly randomised
table(target$tc, X$group)

    
       a   b   c   d   e
  -1 312 296 226 347  12
  1  219 232 318 175  10

# Get taxonomic lists

In [20]:
# Load shared helper code
# NOTE(review): presumably this script defines getFamilies(), getGenera() and
# getSpecies(), which are called below but not defined in this file - confirm
source('../../Common_scripts//get_taxa_lists.R')

In [21]:
# Family / genus / species names of every taxon, prefixed with their rank
# ('f_', 'g_', 's_') and passed through endoR::compatibleNames()
tax_names <- select(taxa, f, g, s)
for (lvl in c('f', 'g', 's')){
    tax_names[[lvl]] <- endoR::compatibleNames(paste0(lvl, '_', tax_names[[lvl]]))
}

“replacing previous import ‘data.table::last’ by ‘dplyr::last’ when loading ‘endoR’”
“replacing previous import ‘data.table::first’ by ‘dplyr::first’ when loading ‘endoR’”
“replacing previous import ‘data.table::between’ by ‘dplyr::between’ when loading ‘endoR’”
“replacing previous import ‘dplyr::union’ by ‘igraph::union’ when loading ‘endoR’”
“replacing previous import ‘dplyr::as_data_frame’ by ‘igraph::as_data_frame’ when loading ‘endoR’”
“replacing previous import ‘dplyr::groups’ by ‘igraph::groups’ when loading ‘endoR’”


In [22]:
# For every distinct family, genus and species name, call the matching getter
# and store the result in a list named after the taxon
by_taxon <- function(ids, getter){
    setNames(lapply(ids, getter, tax_names = tax_names), ids)
}
families <- by_taxon(unique(tax_names$f), getFamilies)
genera <- by_taxon(unique(tax_names$g), getGenera)
species <- by_taxon(unique(tax_names$s), getSpecies)

# Train 

## data

In [23]:
# Encode the target class and the group label as factors
target_c <- factor(target$tc)
X <- X[, 'group' := factor(group)]

In [24]:
# Expand the factor column(s) of X into indicator (dummy) columns with
# caret::dummyVars and keep the result as a data.table
dummies <- as.data.table(predict(dummyVars(~ ., data = X), newdata = X))

In [25]:
# Strip every dot from the column names
names(dummies) <- str_replace_all(names(dummies), pattern = '\\.', replacement ='')

## CV

In [26]:
# Values handed to the clustermq job template (passed as `template` to Q() below):
# conda environment, cores, wall time and memory requested per job
tmpl <- list(conda = "r-ml", cores = 1, job_time = '00:59:00', job_mem = '1G')

In [27]:
# Feature selection with a guided regularized random forest (RRF package),
# followed by a random forest fitted on the selected features, for a single
# train/test split.
#
# Arguments
#   ix         row indices of the training samples; the remaining rows of
#              `data` are the test set. ix[1] is also used as random seed.
#   data       feature table, one row per sample
#   target     class labels, one per row of `data` (called with a factor)
#   families, genera, species
#              named lists looked up by feature name; each entry gives feature
#              names (presumably the taxonomically related ones - see the
#              helpers that build these lists) over which the branch maximum
#              of the importance is taken
#   ntree      number of trees of the final random forest
#   gamma      weight of the importance-driven part of the regularization term
#   k          exponent given to the per-branch normalized importance; the
#              globally normalized importance gets exponent 1 - k
#
# Returns a list with
#   confirmed       names of the selected features
#   rf_performance  the `overall` element of caret::confusionMatrix computed
#                   on the test rows
wf <- function(ix, data, target, families, genera, species, ntree = 500, gamma = 1, k = 0.5){
    set.seed(ix[1])
    train_x <- data[ix,]
    train_y <- as.factor(target[ix])

    # Feature selection
    message('Feature selection')
    # first pass without regularization, only used for its Gini importances
    base_rf <- RRF(train_x, train_y, flagReg=0)
    gini <- base_rf$importance[,"MeanDecreaseGini"]
    regterm <- data.frame(Feature = names(gini), imp = gini)

    # normalization across all features (min-max scaling)
    imp_min <- min(regterm$imp)
    imp_max <- max(regterm$imp)
    regterm$imp_norm <- (regterm$imp - imp_min)/(imp_max - imp_min)

    # normalization per branch: importance divided by the largest importance
    # found among the features listed for that feature in the taxonomic lists
    regterm$imp_tax <- NA
    regterm$mb <- NA
    feats <- as.character(regterm$Feature)
    for (i in seq_along(feats)){
        relatives <- unique(c(families[[feats[i]]], genera[[feats[i]]], species[[feats[i]]]))
        in_branch <- regterm$Feature %in% relatives
        # features without any listed relative in the data are handled below
        if (!any(in_branch)) next
        regterm$mb[i] <- max(regterm$imp[in_branch])
        regterm$imp_tax[i] <- regterm$imp[i] / regterm$mb[i]
    }
    # fall back on the global normalization where no branch value is available
    no_branch <- is.na(regterm$imp_tax)
    regterm$imp_tax[no_branch] <- regterm$imp_norm[no_branch]

    # regularization term
    regterm$lambda <- (1-gamma) + gamma*regterm$imp_norm^(1-k)*regterm$imp_tax^k
    guided_rrf <- RRF(train_x, train_y, flagReg=1, coefReg=regterm$lambda)

    # Select data
    message('Subset data')
    to_keep <- colnames(data)[guided_rrf$feaSet]
    X_fs <- select(data, all_of(to_keep))

    # RF
    message('RF')
    rf_fs <- randomForest(x = X_fs[ix,], y = target[ix], ntree = ntree)
    pred <- predict(object = rf_fs, newdata = X_fs[-ix, ])
    perf <- confusionMatrix(data = pred, reference = target[-ix])

    list(confirmed = to_keep, rf_performance = perf$overall)
}

In [28]:
# Run wf() on every data partition for each value of k, at a fixed gamma.
#
# Arguments
#   trainIx    list of training-row index vectors, one per partition
#   data, target, families, genera, species, ntree
#              forwarded unchanged to wf()
#   gamma      single gamma value used for every call
#   k          vector of k values to try
#
# Returns a list with one element per value of k, named as.character(k); each
# element is the list of wf() results over `trainIx` (names of trainIx kept).
gammaTuning <- function(trainIx, data, target, families, genera, species, gamma = 1, k = seq(0,1,0.25), ntree = 500){
    k_labels <- as.character(k)
    res <- list()
    for (j in seq_along(k)){
        run_partition <- function(ix){
            wf(ix, data=data, target=target
               , families = families, genera = genera, species = species
               , gamma=gamma, k = k[j], ntree=ntree)
        }
        res[[k_labels[j]]] <- lapply(trainIx, run_partition)
    }
    res
}

In [29]:
# Ten random partitions of the samples, each keeping 70% of them for training;
# the indices are returned as a list (one vector of training rows per partition)
set.seed(0)
trainIx <- createDataPartition(y = target_c, times = 10, p = .7, list = TRUE)

In [30]:
# Hyper-parameter grids explored in the cross-validation:
# gamma from 0 to 1 in steps of 0.05, k from 0 to 1 in steps of 0.25
gammas <- seq(0,1, by = 0.05)
ks <- seq(0,1,0.25)

In [33]:
# Run the cross-validation on the cluster with clustermq.
# gammaTuning is called once per value of `gammas` (the iterated argument),
# with the data, partitions, taxonomic lists and the full `ks` grid passed as
# constants; one worker job is requested per gamma. wf() is exported to the
# workers because gammaTuning calls it.
# res is then a list with one element per gamma, each being the list returned
# by gammaTuning (one element per k).
res <- Q(gammaTuning
  , gamma = gammas
  , const = list('data'= dummies, 'target' = target_c, 'trainIx' = trainIx
                 , 'families' = families, 'genera' = genera, 'species' = species
                 , 'k' = ks, 'ntree' = 500)
  , export = list('wf' = wf)
  , n_jobs= length(gammas)
  , pkgs=c('caret', 'randomForest', 'dplyr', 'RRF')
  , log_worker=FALSE
  , template = tmpl
 )

Submitting 21 worker jobs (ID: cmq6799) ...

Running 21 calculations (5 objs/8.7 Mb common; 1 calls/chunk) ...


[---------------------------------------------------]   0% (1/21 wrk) eta:  ?s

[---------------------------------------------------]   0% (2/21 wrk) eta:  ?s

[---------------------------------------------------]   0% (3/21 wrk) eta:  ?s

[---------------------------------------------------]   0% (4/21 wrk) eta:  ?s

[---------------------------------------------------]   0% (5/21 wrk) eta:  ?s

[---------------------------------------------------]   0% (6/21 wrk) eta:  ?s

[---------------------------------------------------]   0% (7/21 wrk) eta:  ?s

[---------------------------------------------------]   0% (8/21 wrk) eta:  ?s

[---------------------------------------------------]   0% (9/21 wrk) eta:  ?s

[--------------------------------------------------]   0% (10/21 wrk) eta:  ?s

[--------------------------------------------------]   0% (11/21 wrk) eta:  ?s

[------

# Results

In [25]:
# Summarise the cross-validation: one row per (gamma, k) pair with the mean and
# standard deviation, across the data partitions, of the test-set accuracy and
# kappa returned by wf().
# res[[i]] holds the results for gammas[i] and res[[i]][[j]] those for ks[j];
# since the labels are attached by position, the lengths are checked first so
# that a mismatch cannot silently mislabel rows.
stopifnot(length(res) == length(gammas))
all <- lapply(seq_along(res), function(i){
    stopifnot(length(res[[i]]) == length(ks))
    per_k <- lapply(seq_along(res[[i]]), function(j){
        # one row per partition, one column per metric of confusionMatrix()$overall
        perf <- do.call(rbind, lapply(res[[i]][[j]], function(x) x$rf_performance))
        data.frame(meanAcc = mean(perf[, 'Accuracy']), sdAcc = sd(perf[, 'Accuracy'])
                   , meanK = mean(perf[, 'Kappa']), sdK = sd(perf[, 'Kappa']), ks = ks[j])
    })
    per_k <- do.call(rbind, per_k)
    per_k$gamma <- gammas[i]
    per_k
})
all <- do.call(rbind, all)

In [44]:
# Average every summary column over the gamma grid, separately for each k
# (the gamma column of the output is therefore just the mean of the grid)
all %>% group_by(ks) %>% summarise_all(mean)

ks,meanAcc,sdAcc,meanK,sdK,gamma
<dbl>,<dbl>,<dbl>,<dbl>,<dbl>,<dbl>
0.0,0.7814782,0.03154637,0.5525106,0.06493909,0.5
0.25,0.6759017,0.01527241,0.3353188,0.03161351,0.5
0.5,0.6752722,0.01455877,0.3333558,0.03010379,0.5
0.75,0.6738132,0.01201444,0.3296025,0.02468818,0.5
1.0,0.666563,0.01118714,0.3135224,0.02221691,0.5


In [49]:
# Best (gamma, k) combinations by mean kappa, restricted to k > 0
all %>% subset(ks>0) %>% arrange(-meanK) %>% head

Unnamed: 0_level_0,meanAcc,sdAcc,meanK,sdK,ks,gamma
Unnamed: 0_level_1,<dbl>,<dbl>,<dbl>,<dbl>,<dbl>,<dbl>
1,0.7306376,0.03185376,0.4449608,0.06526011,0.25,0.25
2,0.7297045,0.02737052,0.4434044,0.05731672,0.25,0.2
3,0.7283048,0.01515033,0.4411198,0.03150921,0.25,0.15
4,0.7214619,0.01205664,0.4266573,0.02483762,0.5,0.15
5,0.7150855,0.01462228,0.4140012,0.03114154,0.25,0.1
6,0.7138414,0.01415915,0.4113256,0.02816726,0.75,0.15


# FS all data

In [58]:
# Hyper-parameters used for the feature selection on the full data set: the
# combination with the highest mean kappa among k > 0 in the CV summary
gamma <- 0.25
k <- 0.25

In [51]:
# Forest without regularization on the full data set; only its Gini
# importances are used, to build the regularization term.
# No seed is set here, so the fit depends on the RNG state left by earlier cells.
RF <- RRF(dummies, as.factor(target_c), flagReg=0)
gini <- RF$importance[,"MeanDecreaseGini"]
regterm <- data.frame(Feature = names(gini), imp = gini)

In [52]:
# Min-max scaling of the importance across all features
regterm$imp_norm <- with(regterm, (imp - min(imp))/(max(imp) - min(imp)))

In [56]:
# Normalization per branch: each importance is divided by the largest
# importance found among the features listed for that feature in the
# families / genera / species lists
regterm$imp_tax <- NA
regterm$mb <- NA
feats <- as.character(regterm$Feature)
for (i in seq_along(feats)){
    relatives <- unique(c(families[[feats[i]]], genera[[feats[i]]], species[[feats[i]]]))
    in_branch <- regterm$Feature %in% relatives
    # features without any listed relative in the data are handled below
    if (!any(in_branch)) next
    regterm$mb[i] <- max(regterm$imp[in_branch])
    regterm$imp_tax[i] <- regterm$imp[i] / regterm$mb[i]
}
# fall back on the global normalization where no branch value is available
no_branch <- is.na(regterm$imp_tax)
regterm$imp_tax[no_branch] <- regterm$imp_norm[no_branch]

In [59]:
# Per-feature regularization coefficient, then the regularized forest that
# performs the feature selection
regterm$lambda <- with(regterm, (1-gamma) + gamma*imp_norm^(1-k)*imp_tax^k)
GRRF <- RRF(dummies, as.factor(target_c), flagReg=1, coefReg=regterm$lambda)

In [60]:
# Names of the features retained by the regularized forest, and how many there are
message('Subset data')
to_keep <- names(dummies)[GRRF$feaSet]
length(to_keep)

Subset data

