# Preamble

In [1]:
library(tidyverse)
library(stringr)
library(data.table)
library(stringr)
library(dplyr)
library(qs)
library(parallel)
library(clustermq)
library(ggpubr)
library(DALEX)
library('iBreakDown')

── [1mAttaching packages[22m ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── tidyverse 1.3.0 ──

[32m✔[39m [34mggplot2[39m 3.3.2     [32m✔[39m [34mpurrr  [39m 0.3.4
[32m✔[39m [34mtibble [39m 3.0.4     [32m✔[39m [34mdplyr  [39m 1.0.2
[32m✔[39m [34mtidyr  [39m 1.1.2     [32m✔[39m [34mstringr[39m 1.4.0
[32m✔[39m [34mreadr  [39m 1.4.0     [32m✔[39m [34mforcats[39m 0.5.0

── [1mConflicts[22m ───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── tidyverse_conflicts() ──
[31m✖[39m [34mdplyr[39m::[32mfilter()[39m masks [34mstats[39m::filter()
[31m✖[39m [34mdplyr[39m::[32mlag()[39m    masks [34mstats[39m::lag()


Attaching package: ‘data.table’


The following objects are masked from ‘package:dplyr’:

    between, first, last


The following object is masked from ‘package:pur

In [2]:
# clustermq configuration: scheduler name and path of the job template file.
options(
    clustermq.scheduler = "sge",
    clustermq.template = "~/.clustermq.tmpl"
)

In [3]:
# Source every R script found in ../scripts/.
# `pattern` is a regular expression, not a shell glob: '\\.R$' keeps only the
# files whose name ends in ".R" (or ".r", because of ignore.case = TRUE).
file.sources <- list.files('../scripts/', pattern = '\\.R$', full.names = TRUE, ignore.case = TRUE)
for (f in file.sources) {
    source(f)
}
# remove the loop helpers from the workspace
rm('file.sources', 'f')

# Data

In [4]:
# Presumably the names of the sample metadata columns; `meta` is not used in
# the visible part of this notebook.
meta <- c(
    'dataset', 'Sample', 'age', 'gender',
    'country', 'BMI', 'westernized', 'number_reads'
)

In [5]:
# Load previously saved objects (qs format) from ../tmp/.
# NOTE(review): `rules` is not referenced again in the visible part of this
# notebook - confirm it is still needed.
rules <- qread('../tmp/ruleExec_full.qs')
# `X` is handed to wrapComp() as `data_ori` when the jobs are submitted below.
X <- qread('../tmp/X.qs')

In [6]:
# Named list mapping each variable to the names considered "related" to it.
related_taxa <- qread('../tmp/related_taxa_full.qs')

# For the entries named 's_...', keep only the related names that start with
# 's' or 'g'. The alternation is grouped on purpose: '^s|g' would parse as
# "(starts with s) OR (contains a g anywhere)".
is_s <- str_which(names(related_taxa), pattern = '^s_')
related_taxa[is_s] <- lapply(related_taxa[is_s], function(x){str_subset(x, pattern = '^(s|g)')})

# The four group variables are only related to themselves.
related_taxa$groupa <- 'groupa'
related_taxa$groupb <- 'groupb'
related_taxa$groupc <- 'groupc'
related_taxa$groupd <- 'groupd'

In [7]:
# Drop entries whose content repeats an earlier entry: duplicated() compares
# the list elements, not their names, and the first occurrence is kept.
# NOTE(review): the names of the dropped entries disappear from the list, so a
# later lookup by one of those names finds nothing - confirm this is intended.
related_taxa <- related_taxa[!duplicated(related_taxa) ]

In [8]:
# Full paths of the files in the simulation directory whose name matches 'simu';
# each one is later processed by its own wrapComp() job.
fnames <- list.files('/ebio/abt3_projects/temp_data/aruaud/MtgSimu50/p005_B10/', full.names = TRUE, pattern = 'simu')

In [9]:
# Number of simulation files found (also used as the number of jobs below).
length(fnames)

# Functions

In [10]:
# Interpolate a tp-vs-fp curve on the integer grid fp = 1..lim and anchor it
# at the origin.
#
# x:   a list whose first element holds the components `fp` and `tp`
#      (only x[[1]] is read).
# lim: largest fp value of the output grid.
#
# Returns a data frame with the columns `fp` and `tp`. Rows are appended in the
# order they are computed and are NOT re-sorted by fp.
getInterpo <- function(x, lim){
    # linear interpolation of tp at fp = 1..lim; when several points share the
    # same fp, the largest tp is used (ties = max). Grid points outside the
    # observed fp range come back as NA.
    tmp <- approx(x[[1]]$fp, x[[1]]$tp, xout = 1:lim, ties = max)
    tmp <- as.data.frame(do.call(cbind, tmp))
        
    # right-hand side: grid points beyond the largest interpolated fp get the
    # maximum observed tp
    mafp <- max(tmp$x[!is.na(tmp$y)])
    tmp$y[tmp$x>mafp] <- max(x[[1]]$tp)
    
    # left-hand side: drop the remaining NA rows (grid points below the
    # smallest observed fp) ...
    tmp <- tmp[complete.cases(tmp),]
    mifp <- min(tmp$x[!is.na(tmp$y)])
    # ... and, if the first remaining fp is larger than 2, fill the gap with
    # points on the straight line from the origin to (mifp, mitp).
    # (mifp is a scalar here, so max(mifp) is just mifp.)
    if (max(mifp)-1 > 1){
        mitp <- min(tmp$y, na.rm = TRUE)
        # NOTE: from here on `x` no longer is the input list and `mifp`/`mitp`
        # become vectors that end at the former scalar values
        x <- 1:(mifp+mitp)
        mifp <- x*mifp/max(x)
        mitp <- x*mitp/max(x)
        tmp <- rbind(tmp, as.data.frame(do.call(cbind,approx(mifp, mitp, xout = 1:(max(mifp)-1), ties = max)) ))
    }
    
    # anchor the curve at (0, 0)
    tmp <- tmp %>% add_row(x = 0, y = 0)

    colnames(tmp) <- c('fp', 'tp')
    return(tmp)
}

In [11]:
# Count true positives, false positives and false negatives among the
# variables selected at one importance threshold.
#
# thr:          threshold; variables with `val >= thr` are selected.
# res:          data frame with the columns `var` (variable name) and `val`
#               (importance).
# related_taxa: named list; each entry holds the names accepted as a hit for
#               the true variable the entry is named after.
#
# Returns a named vector c(tp, fp, fn). All three are NA when no variable
# reaches the threshold.
getPR <- function(thr, res, related_taxa){

    # rows reaching the threshold (rows with a missing `val` are dropped)
    selected <- res[which(res$val >= thr), , drop = FALSE]
    if (nrow(selected) == 0){
        # same names and length as the regular result, so that results for
        # several thresholds can be stacked
        return(c('tp' = NA_integer_, 'fp' = NA_integer_, 'fn' = NA_integer_))
    }
    nodes <- unique(as.character(selected$var))

    # those that should not be but are = selected but not related to any truth.
    # The list is flattened first: matching against the list itself would
    # only ever match the entries that hold a single name.
    truth_members <- unlist(related_taxa, use.names = FALSE)
    fp <- sum(!(nodes %in% truth_members))

    # those that should be and are = truth entries with at least one related
    # name among the selected variables
    tp <- sum(vapply(related_taxa, function(members){any(members %in% nodes)}, logical(1)))

    # those that should be but are not = truth names absent from the selection
    # NOTE(review): this looks at the entry names only, whereas `tp` also
    # accepts related names, so tp + fn can exceed length(related_taxa) -
    # confirm this is the intended definition.
    fn <- sum(!(names(related_taxa) %in% nodes))

    return(c('tp' = tp, 'fp' = fp, 'fn' = fn))

}


In [12]:
### Stand-in for the DALEX function that skips its initial check:
### calls iBreakDown::shap() directly and tags the result as 'predict_parts'.
# X:          the observation to explain (passed on as `new_observation`)
# explain_rf: the explainer object (passed on as `x`)
shap_bab <- function(X, explain_rf) {
  contributions <- iBreakDown::shap(x = explain_rf, new_observation = X)
  structure(contributions, class = c('predict_parts', class(contributions)))
}

In [13]:
# Summarise the SHAP output of one observation.
#
# i:   position of the observation in `res`
# res: list of SHAP results, one element per observation
#
# Returns one row per (variable_name, variable_value) pair with the mean
# contribution, plus a `sample` column holding `i`.
formatSingleSHAP <- function(i, res){
    shap_i <- select(res[[i]], variable_name, variable_value, contribution)
    shap_i <- group_by(shap_i, variable_name, variable_value)
    shap_i <- ungroup(summarise_all(shap_i, mean))
    shap_i$sample <- i
    shap_i
}

In [14]:
# Run the SHAP analysis of one simulation and score it against the ground truth.
#
# fname:        path of the simulation file (qs format); the seed is taken
#               from the digits between 'simu' and '_' in the name. The stored
#               object must provide `true_edges`, `rf`, `data` and `target`.
# data_ori:     original data set; only its number of columns ends up in the
#               result (`n_nodes`).
# related_taxa: named list of the names accepted as hits for each true variable.
# path:         output directory prefix for the raw SHAP results; it is pasted
#               in front of the file name as is, so it must end with '/'.
# n_cores:      number of local worker processes used for the SHAP computation.
#
# Returns a list with `seedOri`, `tp_nodes`, `n_nodes`, `raw_shap`, `res_shap`
# and `nodes` (tp/fp/fn for each importance threshold).
wrapComp <- function(fname, data_ori, related_taxa, path, n_cores = 10){

    res <- list()
    # The exit handler returns the current content of `res` whenever the
    # function exits - apparently so that partial results survive a failing step.
    on.exit(return(res))

    # get data
    message('Data preparation...')
    seedOri <- as.numeric(str_extract(fname, pattern = '(?<=simu)[:digit:]+(?=\\_)'))
    res$seedOri <- seedOri
    set.seed(seedOri)
    # NOTE(review): the shuffled rows are not used below (only ncol(data_ori)
    # is); the call is kept so that the random stream is left untouched.
    data_ori <- data_ori[sample(seq_len(nrow(data_ori))),]
    simu <- qread(fname)
    message(paste0("Let's go with seed ", seedOri))

    # ground truth: names of the true edges, cut at the first double underscore
    message('Ground truth...')
    tn <- unique( str_replace(unlist(simu$true_edges), pattern = '\\_{2}.*', replacement = '') )
    # NOTE(review): a name in `tn` that is missing from `related_taxa` yields
    # an empty, unnamed entry here - confirm every true node has an entry.
    related_taxa <- related_taxa[tn]
    res$tp_nodes <- length(tn)
    res$n_nodes <- ncol(data_ori)

    # shap
    message('DALEX...')
    explain_rf <- DALEX::explain( model = simu$rf, data = simu$data, y = simu$target == "1")
    # one list element per observation (row) of the simulated data
    listed_data <- lapply(seq_len(nrow(simu$data)), function(i){simu$data[i,]})

    message('shap_bab...')
    cl <- makeCluster(n_cores)
    # Register the clean-up right away so that the workers are also stopped if
    # loading a package on them fails; after = FALSE runs it before the
    # return() handler registered above.
    on.exit(stopCluster(cl), add = TRUE, after = FALSE)
    clusterEvalQ(cl, library(iBreakDown))
    clusterEvalQ(cl, library(DALEX))
    clusterEvalQ(cl, library(randomForest))
    res_shap <- parLapply(cl = cl, X = listed_data, fun = shap_bab, explain_rf = explain_rf)

    res$raw_shap <- res_shap
    qsave(res_shap, paste0(path, 'raw_shap_', seedOri, '.qs'))

    # format res: one summarised table per observation, stacked together
    message('Formatting...')
    res_shap <- lapply(seq_along(res_shap), formatSingleSHAP, res = res_shap)
    res_shap <- as.data.frame(do.call(rbind, res_shap))
    res_shap$contribution <- as.numeric(res_shap$contribution)
    res_shap$variable_value <- as.numeric(res_shap$variable_value)
    res$res_shap <- res_shap

    # PR curves
    # variable importance = mean absolute contribution; every distinct
    # importance value is used as a threshold
    message('PR curves...')
    resM <- group_by(res_shap, variable_name) %>% summarise(val = mean(abs(contribution)))
    colnames(resM)[1] <- 'var'
    thr <- sort(unique(resM$val))
    # one row of counts per threshold
    pr_nodes <- as.data.frame(do.call(rbind, lapply(thr, getPR, res = resM, related_taxa = related_taxa)))
    pr_nodes <- arrange(pr_nodes, tp, fp)

    res$nodes <- pr_nodes
    return(res)

}

# Run the SHAP comparison on the cluster

In [15]:
# Values handed to clustermq as `template` for the job submission below
# (presumably filled into the placeholders of ~/.clustermq.tmpl - confirm).
# `cores = 20` matches the `n_cores = 20` passed to wrapComp().
tmpl <- list(conda = "r-ml", cores = 20, job_time = '96:00:00', job_mem = '1G')

In [18]:
# Submit one clustermq job per simulation file: wrapComp() is called once for
# each element of `fnames`, with the arguments in `const` shared by all calls.
res <- Q(
    wrapComp,
    fname = fnames,
    const = list(
        'data_ori' = X,
        'related_taxa' = related_taxa,
        'path' = '/ebio/abt3_projects/temp_data/aruaud/MtgSimu50/comparison_SHAP_rf/',
        'n_cores' = 20
    ),
    # helper functions and packages made available on the workers
    export = c('getPR' = getPR, 'shap_bab' = shap_bab, 'formatSingleSHAP' = formatSingleSHAP),
    pkgs = c('DALEX', 'iBreakDown', 'qs', 'stringr', 'tidyverse', 'parallel', 'randomForest'),
    n_jobs = length(fnames),
    template = tmpl,
    fail_on_error = TRUE,
    log_worker = TRUE
)

Submitting 50 worker jobs (ID: cmq8392) ...

Running 50 calculations (7 objs/8.6 Mb common; 1 calls/chunk) ...


[---------------------------------------------------]   0% (1/50 wrk) eta:  ?s

[---------------------------------------------------]   0% (2/50 wrk) eta:  ?s

[---------------------------------------------------]   0% (3/50 wrk) eta:  ?s

[---------------------------------------------------]   0% (4/50 wrk) eta:  ?s

[---------------------------------------------------]   0% (5/50 wrk) eta:  ?s

[---------------------------------------------------]   0% (6/50 wrk) eta:  ?s

[---------------------------------------------------]   0% (7/50 wrk) eta:  ?s

[---------------------------------------------------]   0% (8/50 wrk) eta:  ?s

[---------------------------------------------------]   0% (9/50 wrk) eta:  ?s

[--------------------------------------------------]   0% (10/50 wrk) eta:  ?s

[--------------------------------------------------]   0% (11/50 wrk) eta:  ?s

[------

                                                                              

Master: [4798.1s 3.6% CPU]; Worker: [avg 0.8% CPU, max 1002.8 Mb]



In [22]:
# Copy of the results without the 'raw_shap' element of each run.
res_trimmed <- lapply(res, function(run) {
    keep <- names(run) != 'raw_shap'
    run[keep]
})

In [16]:
# Save the trimmed results (everything except the raw SHAP objects) to ../tmp/.
qsave(res_trimmed, '../tmp/comparison_shapRF.qs')