In [2]:
from parsing_utils import *

# Drug resistance classification

In [None]:
# generate_hyper_ft_report presumably comes from the `parsing_utils` star import;
# its three return values are unpacked positionally (report, std report, result dict).
(
    drug_auc_ft_report,
    drug_auc_ft_report_std,
    drug_auc_ft_result_dict,
) = generate_hyper_ft_report(
    metric_name='auroc',
    measurement='AUC',
)

In [None]:
import pandas as pd
import numpy as np
import data_config
import data
import ml_baseline



In [None]:
# Gene-expression feature table; the first CSV column becomes the row index.
gex_features_df = pd.read_csv(data_config.gex_feature_file, index_col=0)

# Drug codes processed by the baseline loop, and the metric key extracted from its results.
drugs = ['gem', 'fu', 'cis', 'tem']
metric = 'auroc'

# Per-drug value passed as `threshold=` to data.get_labeled_dataloaders together with
# ccle_measurement='AUC' (presumably the cut-off used to binarize responses — TODO confirm).
auc_drug_dict = dict(gem=0.8, fu=0.65, cis=0.95, tem=0.9)

# drug code -> metric values returned by ml_baseline.n_time_cv for the elastic-net baseline.
enet_auc_result_dict = {}

In [None]:
def _dataset_arrays(dataloader):
    """Return ``dataloader.dataset.tensors[0]`` and ``[1]`` converted to NumPy arrays.

    NOTE(review): presumably (features, labels) of a TensorDataset — confirm against
    data.get_labeled_dataloaders.
    """
    tensors = dataloader.dataset.tensors
    return tensors[0].numpy(), tensors[1].numpy()


# Elastic-net baseline: train on labeled CCLE data, evaluate on labeled TCGA data,
# repeated n=10 times per drug; keep the values stored under `metric`.
for drug in drugs:
    print(f"{drug}")
    ccle_loader, tcga_loader = data.get_labeled_dataloaders(
        gex_features_df=gex_features_df,
        seed=2020,
        batch_size=64,
        drug=drug,
        threshold=auc_drug_dict[drug],
        ft_flag=False,
        ccle_measurement='AUC',
    )
    # n_time_cv's result is indexable; element [1] is keyed by metric name.
    cv_output = ml_baseline.n_time_cv(
        model_fn=ml_baseline.classify_with_enet,
        n=10,
        train_data=_dataset_arrays(ccle_loader),
        test_data=_dataset_arrays(tcga_loader),
        metric=metric,
    )
    enet_auc_result_dict[drug] = cv_output[1][metric]





In [None]:
# Attach the elastic-net baseline scores to each drug's report under the 'en' key
# (upper-cased to 'EN' when the plotting frames are built).
for drug_name in drugs:
    enet_scores = enet_auc_result_dict[drug_name]
    drug_auc_ft_report[drug_name]['en'] = enet_scores


In [None]:
# Long-format AUROC table (one row per method/score) for Gemcitabine and Fluorouracil.
drug_dict = {'gem': 'Gemcitabine', 'fu': 'Fluorouracil', 'cis': 'Cisplatin', 'tem': 'Temozolomide'}
# Internal method keys -> display names; all names are upper-cased afterwards.
method_rename = {'dsn': "dsn-mmd", 'dsnw': 'dsn-adv', 'code_base': 'code-ae-base', 'code_mmd': 'code-ae-mmd', 'code_adv': 'code-ae-adv'}
method_order = ['EN', 'MLP', 'AE', 'DAE', 'VAE', 'CORAL', 'ADAE', 'DSN-MMD', 'DSN-ADV', 'CODE-AE-BASE', 'CODE-AE-MMD', 'CODE-AE-ADV']

indicator_frames = []
for cat in ['gem', 'fu']:
    # NOTE(review): assumes drug_auc_ft_report[cat] maps method key -> list of scores.
    wide_df = pd.DataFrame.from_dict(drug_auc_ft_report[cat])
    wide_df = wide_df.rename(columns=method_rename)
    wide_df = wide_df.rename(columns={k: k.upper() for k in wide_df.columns})
    long_df = pd.melt(wide_df).assign(drug=drug_dict[cat], metric='AUC')
    indicator_frames.append(long_df)

# Concatenate once instead of growing a frame inside the loop.
indicator_result_df = pd.concat(indicator_frames).rename(columns={'variable': 'method', 'value': 'auroc'})

# Ordered categorical so bars and legend follow method_order. set_categories(..., inplace=True)
# was removed in pandas 2.0, so assign the returned Series instead.
# Methods not listed in method_order become NaN here and drop out of the plot.
indicator_result_df['method'] = (
    indicator_result_df['method'].astype('category').cat.set_categories(method_order)
)
indicator_result_df = indicator_result_df.sort_values(by='method')

In [None]:
# Long-format AUROC table (one row per method/score) for Cisplatin and Temozolomide.
drug_dict = {'gem': 'Gemcitabine', 'fu': 'Fluorouracil', 'cis': 'Cisplatin', 'tem': 'Temozolomide'}
# Internal method keys -> display names; all names are upper-cased afterwards.
method_rename = {'dsn': "dsn-mmd", 'dsnw': 'dsn-adv', 'code_base': 'code-ae-base', 'code_mmd': 'code-ae-mmd', 'code_adv': 'code-ae-adv'}
method_order = ['EN', 'MLP', 'AE', 'DAE', 'VAE', 'CORAL', 'ADAE', 'DSN-MMD', 'DSN-ADV', 'CODE-AE-BASE', 'CODE-AE-MMD', 'CODE-AE-ADV']

days_frames = []
for cat in ['cis', 'tem']:
    # NOTE(review): assumes drug_auc_ft_report[cat] maps method key -> list of scores.
    wide_df = pd.DataFrame.from_dict(drug_auc_ft_report[cat])
    wide_df = wide_df.rename(columns=method_rename)
    wide_df = wide_df.rename(columns={k: k.upper() for k in wide_df.columns})
    long_df = pd.melt(wide_df).assign(drug=drug_dict[cat], metric='AUC')
    days_frames.append(long_df)

# Concatenate once instead of growing a frame inside the loop.
days_result_df = pd.concat(days_frames).rename(columns={'variable': 'method', 'value': 'auroc'})

# Ordered categorical so bars and legend follow method_order. set_categories(..., inplace=True)
# was removed in pandas 2.0, so assign the returned Series instead.
# Methods not listed in method_order become NaN here and drop out of the plot.
days_result_df['method'] = (
    days_result_df['method'].astype('category').cat.set_categories(method_order)
)
days_result_df = days_result_df.sort_values(by='method')

In [None]:
import itertools
import seaborn as sns
import matplotlib.pyplot as plt
# Apply every seaborn setting in ONE call. Each sns.set / sns.set_theme call re-applies all of
# its defaults (context, style, font_scale=1, ...), so the previous sequence of three calls left
# font_scale back at 1 and only kept style="white" and the figure size.
sns.set_theme(style="white", font_scale=1.5, rc={'figure.figsize': (15, 10)})


In [None]:
# Bar colour per (upper-cased) method name; several methods deliberately share a colour.
_palette_groups = (
    ('purple', ('EN', 'MLP')),
    ('lightgreen', ('AE', 'DAE', 'VAE')),
    ('orange', ('CORAL',)),
    ('blue', ('ADAE',)),
    ('red', ('DSN-MMD', 'DSN-ADV')),
    ('yellow', ('CODE-AE-BASE', 'CODE-AE-MMD', 'CODE-AE-ADV')),
)
palette = {method: colour for colour, members in _palette_groups for method in members}

In [None]:
# AUROC per method for Gemcitabine / Fluorouracil.
# Bug fix: the filter previously referenced the undefined name `result_df` (NameError).
auc_indicator_df = indicator_result_df.loc[indicator_result_df['metric'] == 'AUC']
ax = sns.barplot(x='drug', y='auroc', hue='method', data=auc_indicator_df, palette=palette)

# Change hatch pattern once per method so methods sharing a colour stay distinguishable.
# Assumes seaborn adds one bar patch per drug for each method, grouped by method
# (the original hard-coded 2 drugs) — NOTE(review): confirm for the installed seaborn version.
hatch_patterns = ['', '+', '/']
n_drugs = auc_indicator_df['drug'].nunique()
for i, bar in enumerate(ax.patches):
    bar.set_hatch(hatch_patterns[(i // n_drugs) % len(hatch_patterns)])

plt.ylim(0.4, 0.85)
plt.xlabel('')
plt.ylabel('AUROC', fontsize=15, weight='bold')
plt.xticks(fontsize=18, weight='bold')
plt.yticks(fontsize=15)
plt.legend(loc='upper right', bbox_to_anchor=(1.2, 1.01), fontsize=14)
#plt.savefig('../paper/tcga_auc_bar.png', format='png', dpi=350,bbox_inches='tight')

In [None]:
# AUROC per method for Cisplatin / Temozolomide.
# Bug fix: the filter previously referenced the undefined name `result_df` (NameError).
auc_days_df = days_result_df.loc[days_result_df['metric'] == 'AUC']
ax = sns.barplot(x='drug', y='auroc', hue='method', data=auc_days_df, palette=palette)

# Change hatch pattern once per method so methods sharing a colour stay distinguishable.
# Assumes seaborn adds one bar patch per drug for each method, grouped by method
# (the original hard-coded 2 drugs) — NOTE(review): confirm for the installed seaborn version.
hatch_patterns = ['', '+', '/']
n_drugs = auc_days_df['drug'].nunique()
for i, bar in enumerate(ax.patches):
    bar.set_hatch(hatch_patterns[(i // n_drugs) % len(hatch_patterns)])

plt.ylim(0.4, 0.85)
plt.xlabel('')
plt.ylabel('AUROC', fontsize=15, weight='bold')
plt.xticks(fontsize=18, weight='bold')
plt.yticks(fontsize=15)
plt.legend(loc='upper right', bbox_to_anchor=(1.2, 1.01), fontsize=14)
#plt.savefig('../paper/cis_auc_bar.png', format='png', dpi=350,bbox_inches='tight')

In [None]:
from scipy.stats import ttest_ind, ttest_ind_from_stats
def generate_p_val(a, b):
    """Print Welch's two-sample t-test (unequal variances) for raw samples.

    Args:
        a: Array-like of observations for the first group.
        b: Array-like of observations for the second group.

    Returns:
        None; the t statistic and p-value are printed.
    """
    welch = ttest_ind(a, b, equal_var=False)
    t_stat, p_value = welch
    line = "ttest_ind:            t = %g  p = %g" % (t_stat, p_value)
    print(line)
    
def generate_p_val2(a_mu, a_std, b_mu, b_std, a_n=10, b_n=10):
    """Print Welch's t-test computed from per-group summary statistics.

    Args:
        a_mu, a_std: Mean and standard deviation of group A.
        b_mu, b_std: Mean and standard deviation of group B.
        a_n, b_n: Number of observations in each group (default 10 each).

    Returns:
        None; the t statistic and p-value are printed.
    """
    t_stat, p_value = ttest_ind_from_stats(
        mean1=a_mu, std1=a_std, nobs1=a_n,
        mean2=b_mu, std2=b_std, nobs2=b_n,
        equal_var=False,
    )
    print("ttest_ind_from_stats: t = %g  p = %g" % (t_stat, p_value))

In [None]:
# Welch t-tests between pairs of methods on the raw per-run scores.
# NOTE(review): `auc_ft_result_dict` is not defined in any cell visible here; the first code
# cell binds the third return value of generate_hyper_ft_report to `drug_auc_ft_result_dict`.
# The 'adsn' key also does not appear in the method renaming used for the plots — confirm
# both the dict name and the keys before uncommenting.
# generate_p_val(a=auc_ft_result_dict['gem']['adsn'], b=auc_ft_result_dict['gem']['dsnw'])

# generate_p_val(a=auc_ft_result_dict['fu']['adsn'], b=auc_ft_result_dict['fu']['vae'])

# generate_p_val(a=auc_ft_result_dict['cis']['adsn'], b=auc_ft_result_dict['cis']['dsnw'])

# generate_p_val(a=auc_ft_result_dict['tem']['dsnw'], b=auc_ft_result_dict['tem']['adae'])

In [None]:
# Welch t-tests from hard-coded (mean, std) pairs, using the default a_n=b_n=10 observations.
# NOTE(review): these numbers are not derived from any variable in this notebook; record their
# source (or compute them from the result dicts) before uncommenting. The n=10 default presumably
# matches the 10 repeats used elsewhere — confirm.
# generate_p_val2(a_mu=0.931891, a_std=0.001785, b_mu=0.903831, b_std=0.008108)

# generate_p_val2(a_mu=0.940007, a_std=0.007376, b_mu=0.944379, b_std=0.005745)

# generate_p_val2(a_mu=0.973023, a_std=0.000666, b_mu=0.962145, b_std=0.007951)

# generate_p_val2(a_mu=0.983963, a_std=0.003298, b_mu=0.986240, b_std=0.002010)

# Deconfounded representation transferability

In [None]:
# generate_hyper_ml_report presumably comes from the `parsing_utils` star import;
# keep the report and its std counterpart, discard the third return value.
(
    de_auroc_ml_report,
    de_auroc_ml_report_std,
    _,
) = generate_hyper_ml_report(
    metric_name='auroc',
    measurement='AUC',
)

In [None]:
# Same report as above but scored by AUPRC; the third return value is discarded.
(
    de_auprc_ml_report,
    de_auprc_ml_report_std,
    _,
) = generate_hyper_ml_report(
    metric_name='auprc',
    measurement='AUC',
)