In [1]:
#!/home/was966/micromamba/envs/responder/bin/python
#sbatch --mem 64G -c 4 -t 100:00:00 -p gpu_quad --gres=gpu:rtx8000:1 ./ctct_dense16.py

import sys

sys.path.insert(0, '/home/was966/Research/mims-conceptor/')
from conceptor.utils import plot_embed_with_label
from conceptor import PreTrainer, FineTuner, loadconceptor #, get_minmal_epoch
from conceptor.utils import plot_embed_with_label,plot_performance, score2
from conceptor.tokenizer import CANCER_CODE

import os
from tqdm import tqdm
from itertools import chain
import pandas as pd
import numpy as np
import random, torch
import matplotlib.pyplot as plt
import seaborn as sns
sns.set(style = 'white', font_scale=1.3)
import warnings
warnings.filterwarnings("ignore")

def onehot(S):
    """One-hot encode a categorical Series, keeping missing values missing.

    Parameters
    ----------
    S : pd.Series
        Categorical labels; NaN/None entries are treated as missing.

    Returns
    -------
    pd.DataFrame
        Float (0./1.) indicator frame indexed like ``S`` with one column per
        observed category, ordered by descending frequency (ties keep the
        column order produced by ``pd.get_dummies``). Rows where ``S`` is
        missing are all-NaN instead of all-zero.

    Raises
    ------
    AssertionError
        If ``S`` is not a pandas Series.
    """
    assert isinstance(S, pd.Series), 'Input type should be pd.Series'
    # Cast to float first: recent pandas returns bool dummies, and writing NaN
    # into a bool column upcasts it to object dtype (with a FutureWarning).
    dummies = pd.get_dummies(S, dummy_na=True).astype(float)
    # Positional mask rather than index labels, so a duplicated index label
    # cannot propagate NaN onto rows whose value is present.
    is_missing = dummies[np.nan].to_numpy(dtype=bool)
    dummies = dummies.drop(columns=[np.nan])
    dummies.loc[is_missing, :] = np.nan
    # Most frequent category first; stable sort makes tie order deterministic.
    order = dummies.sum().sort_values(ascending=False, kind='stable').index.tolist()
    return dummies[order]



pretrainer = loadconceptor('../../../checkpoint/latest/pretrainer.pt')
data_path = '../../../../paper/00_data/'
df_label = pd.read_pickle(os.path.join(data_path, 'ITRP.PATIENT.TABLE'))
df_tpm = pd.read_pickle(os.path.join(data_path, 'ITRP.TPM.TABLE'))

# Model input: cancer-type code plus the TPM expression columns, left-joined
# onto the patient index of `df_label`.
dfcx = df_label.cancer_type.map(CANCER_CODE).to_frame('cancer_code').join(df_tpm)

# One-hot response labels (missing labels become all-NaN rows).
df_task = onehot(df_label.response_label)

# Patients per cohort, computed once and reused below.
cohort_n = df_label.groupby('cohort').size()
# "<cohort>\n(n = <count>)" display strings (presumably plot labels; not used in this cell).
size = cohort_n.index + "\n(n = " + cohort_n.astype(str) + ")"
# Cohorts ordered from smallest to largest.
cohorts = cohort_n.sort_values().index.tolist()
# cohorts = ['Choueiri']  # debug: restrict the run to a single cohort


def cohort_to_cohort(cohorts):
    """Pair every cohort with all of the *other* cohorts.

    Returns a list of ``(cohort, others)`` tuples in input order, where
    ``others`` is ``cohorts`` with that single position removed (same
    sequence type as the input).
    """
    pairs = []
    for pos, held_out in enumerate(cohorts):
        remaining = cohorts[:pos] + cohorts[pos + 1:]
        pairs.append((held_out, remaining))
    return pairs
# One (training cohort, list of all other cohorts) tuple per cohort, in `cohorts` order.
train_test_cohorts = cohort_to_cohort(cohorts)


# Fine-tuning configuration passed to FineTuner as keyword arguments.
# `mode` and `seed` are reassigned per run by the evaluation loop.
params = dict(
    mode='PFT',
    seed=42,
    lr=1e-2,
    device='cuda',
    weight_decay=1e-1,
    batch_size=8,
    max_epochs=10,
    task_loss_type='ce_loss',  # alternative: 'focal_loss'
    task_type='c',
    load_decoder=False,
    task_dense_layer=[16],
    task_batch_norms=True,
    entropy_weight=0.0,
    with_wandb=False,
    save_best_model=False,
    verbose=False,
)

seed = 42

# (index, features, labels) for every cohort, built once: each trained model
# is scored on every cohort, so recomputing the slices per pair is wasted work.
cohort_data = {}
for cohort in cohorts:
    idx = df_label.index[df_label['cohort'] == cohort]
    cohort_data[cohort] = (idx, dfcx.loc[idx], df_task.loc[idx])

for mode in ['PFT']:

    print('Evaluation on Model %s' % mode)

    params['mode'] = mode
    params['seed'] = seed

    work_dir = './CTCT/CTCT_%s_%s' % (mode, seed)
    os.makedirs(work_dir, exist_ok=True)

    res = []
    for train_cohort, test_cohorts in train_test_cohorts:
        # NOTE(review): `test_cohorts` (all cohorts except the training one) is
        # not used: each model is scored on *all* `cohorts`, so rows with
        # train_cohort == test_cohort are in-sample scores.
        cohort_idx, train_X, train_y = cohort_data[train_cohort]

        # Per-run copy instead of mutating the shared `params` dict.
        run_params = dict(params, batch_size=16 if len(cohort_idx) > 100 else 8)

        # Fine-tune a fresh copy of the ORIGINAL pretrained model. Rebinding
        # `pretrainer = pretrainer.copy()` meant that, if FineTuner updates the
        # model it is given in place (TODO confirm), every cohort after the
        # first started from the previous cohort's fine-tuned weights.
        # Assumes `.copy()` returns an independent (deep) copy — TODO confirm.
        base_model = pretrainer.copy()
        finetuner = FineTuner(base_model, **run_params,
                              work_dir=work_dir,
                              task_name='%s' % train_cohort)

        # min_mcc is forwarded to tune(); the run log shows training stopping
        # early once its "minimal requirements" are met.
        finetuner = finetuner.tune(dfcx_train=train_X,
                                   dfy_train=train_y, min_mcc=0.6)

        for test_cohort in cohorts:
            _, test_X, test_y = cohort_data[test_cohort]

            _, pred_testy = finetuner.predict(test_X, batch_size=16)

            pred_testy['train_cohort'] = train_cohort
            pred_testy['test_cohort'] = test_cohort
            pred_testy['best_epoch'] = finetuner.best_epoch
            pred_testy['n_trainable_params'] = finetuner.count_parameters()
            pred_testy['mode'] = mode
            pred_testy['seed'] = seed
            pred_testy['batch_size'] = run_params['batch_size']
            pred_testy['task_dense_layer'] = str(run_params['task_dense_layer'])
            pair_pred = test_y.join(pred_testy)

            # y_true, y_prob, y_pred = pair_pred['R'], pair_pred[1], pair_pred[[0, 1]].idxmax(axis=1)
            # fig = plot_performance(y_true, y_prob, y_pred)
            # fig.suptitle('cohort to cohort transfer: train: %s, test: %s' % (train_cohort, test_cohort), fontsize=16)
            # fig.savefig(os.path.join(work_dir, 'CTCT_train_%s_test_%s.jpg' % (train_cohort, test_cohort)))
            res.append(pair_pred)

    dfs = pd.concat(res)
    # Score each (train, test) pair: label column 'R' from df_task, predicted
    # probability column 1, predicted class = argmax over prediction columns 0/1.
    dfp = dfs.groupby(['train_cohort', 'test_cohort']).apply(
        lambda x: score2(x['R'], x[1], x[[0, 1]].idxmax(axis=1)))

    # score2 yields (roc, prc, f1, acc, mcc); expand into named columns.
    dfp = dfp.apply(pd.Series)
    dfp.columns = ['ROC', 'PRC', 'F1', 'ACC', 'MCC']
    dfp = dfp.reset_index()

    dfs.to_csv(os.path.join(work_dir, 'source_performance.tsv'), sep='\t')
    dfp.to_csv(os.path.join(work_dir, 'metric_performance.tsv'), sep='\t')



Evaludation on Model PFT


100%|##########| 10/10 [00:14<00:00,  1.47s/it]
100%|##########| 1/1 [00:00<00:00,  2.54it/s]
100%|##########| 2/2 [00:00<00:00,  4.46it/s]
100%|##########| 2/2 [00:00<00:00,  4.63it/s]
100%|##########| 2/2 [00:00<00:00,  4.35it/s]
100%|##########| 2/2 [00:00<00:00,  4.47it/s]
100%|##########| 2/2 [00:00<00:00,  4.55it/s]
100%|##########| 3/3 [00:00<00:00,  6.41it/s]
100%|##########| 3/3 [00:00<00:00,  5.61it/s]
100%|##########| 3/3 [00:00<00:00,  6.32it/s]
100%|##########| 4/4 [00:00<00:00,  7.75it/s]
100%|##########| 5/5 [00:00<00:00,  8.03it/s]
100%|##########| 6/6 [00:00<00:00,  9.59it/s]
100%|##########| 7/7 [00:00<00:00,  9.91it/s]
100%|##########| 7/7 [00:00<00:00,  9.30it/s]
100%|##########| 11/11 [00:00<00:00, 13.05it/s]
100%|##########| 19/19 [00:01<00:00, 14.78it/s]
 60%|######    | 6/10 [00:09<00:06,  1.62s/it]

Stopping early at epoch  7. Meet minimal requirements by: f1=1.00,mcc=1.00,prc=1.00, roc=1.00



100%|##########| 1/1 [00:00<00:00,  2.44it/s]
100%|##########| 2/2 [00:00<00:00,  4.56it/s]
100%|##########| 2/2 [00:00<00:00,  4.47it/s]
100%|##########| 2/2 [00:00<00:00,  4.47it/s]
100%|##########| 2/2 [00:00<00:00,  4.14it/s]
100%|##########| 2/2 [00:00<00:00,  4.33it/s]
100%|##########| 3/3 [00:00<00:00,  6.47it/s]
100%|##########| 3/3 [00:00<00:00,  6.03it/s]
100%|##########| 3/3 [00:00<00:00,  5.99it/s]
100%|##########| 4/4 [00:00<00:00,  7.83it/s]
100%|##########| 5/5 [00:00<00:00,  8.28it/s]
100%|##########| 6/6 [00:00<00:00,  9.00it/s]
100%|##########| 7/7 [00:00<00:00,  9.80it/s]
100%|##########| 7/7 [00:00<00:00,  9.48it/s]
100%|##########| 11/11 [00:00<00:00, 12.99it/s]
100%|##########| 19/19 [00:01<00:00, 13.97it/s]
 70%|#######   | 7/10 [00:11<00:04,  1.62s/it]

Stopping early at epoch  8. Meet minimal requirements by: f1=0.80,mcc=0.75,prc=0.97, roc=0.98



100%|##########| 1/1 [00:00<00:00,  2.14it/s]
100%|##########| 2/2 [00:00<00:00,  4.29it/s]
100%|##########| 2/2 [00:00<00:00,  4.23it/s]
100%|##########| 2/2 [00:00<00:00,  4.38it/s]
100%|##########| 2/2 [00:00<00:00,  4.03it/s]
100%|##########| 2/2 [00:00<00:00,  4.64it/s]
100%|##########| 3/3 [00:00<00:00,  6.32it/s]
100%|##########| 3/3 [00:00<00:00,  6.28it/s]
100%|##########| 3/3 [00:00<00:00,  5.94it/s]
100%|##########| 4/4 [00:00<00:00,  7.22it/s]
100%|##########| 5/5 [00:00<00:00,  8.07it/s]
100%|##########| 6/6 [00:00<00:00,  9.49it/s]
100%|##########| 7/7 [00:00<00:00, 10.27it/s]
100%|##########| 7/7 [00:00<00:00,  9.96it/s]
100%|##########| 11/11 [00:00<00:00, 12.32it/s]
100%|##########| 19/19 [00:01<00:00, 14.46it/s]
 70%|#######   | 7/10 [00:13<00:05,  1.96s/it]

Stopping early at epoch  8. Meet minimal requirements by: f1=0.93,mcc=0.91,prc=0.97, roc=0.98



100%|##########| 1/1 [00:00<00:00,  2.38it/s]
100%|##########| 2/2 [00:00<00:00,  4.11it/s]
100%|##########| 2/2 [00:00<00:00,  4.79it/s]
100%|##########| 2/2 [00:00<00:00,  4.17it/s]
100%|##########| 2/2 [00:00<00:00,  4.46it/s]
100%|##########| 2/2 [00:00<00:00,  4.47it/s]
100%|##########| 3/3 [00:00<00:00,  6.46it/s]
100%|##########| 3/3 [00:00<00:00,  5.86it/s]
100%|##########| 3/3 [00:00<00:00,  5.78it/s]
100%|##########| 4/4 [00:00<00:00,  7.53it/s]
100%|##########| 5/5 [00:00<00:00,  7.18it/s]
100%|##########| 6/6 [00:00<00:00,  9.65it/s]
100%|##########| 7/7 [00:00<00:00, 10.94it/s]
100%|##########| 7/7 [00:00<00:00,  9.76it/s]
100%|##########| 11/11 [00:00<00:00, 11.95it/s]
100%|##########| 19/19 [00:01<00:00, 14.98it/s]
 70%|#######   | 7/10 [00:14<00:06,  2.06s/it]

Stopping early at epoch  8. Meet minimal requirements by: f1=0.88,mcc=0.78,prc=0.99, roc=0.99



100%|##########| 1/1 [00:00<00:00,  2.44it/s]
100%|##########| 2/2 [00:00<00:00,  4.54it/s]
100%|##########| 2/2 [00:00<00:00,  4.52it/s]
100%|##########| 2/2 [00:00<00:00,  4.46it/s]
100%|##########| 2/2 [00:00<00:00,  4.77it/s]
100%|##########| 2/2 [00:00<00:00,  4.46it/s]
100%|##########| 3/3 [00:00<00:00,  5.88it/s]
100%|##########| 3/3 [00:00<00:00,  5.92it/s]
100%|##########| 3/3 [00:00<00:00,  5.69it/s]
100%|##########| 4/4 [00:00<00:00,  7.74it/s]
100%|##########| 5/5 [00:00<00:00,  8.04it/s]
100%|##########| 6/6 [00:00<00:00,  9.31it/s]
100%|##########| 7/7 [00:00<00:00,  9.92it/s]
100%|##########| 7/7 [00:00<00:00, 10.69it/s]
100%|##########| 11/11 [00:00<00:00, 13.00it/s]
100%|##########| 19/19 [00:01<00:00, 15.09it/s]
 60%|######    | 6/10 [00:12<00:08,  2.00s/it]

Stopping early at epoch  7. Meet minimal requirements by: f1=0.87,mcc=0.70,prc=1.00, roc=1.00



100%|##########| 1/1 [00:00<00:00,  2.39it/s]
100%|##########| 2/2 [00:00<00:00,  4.05it/s]
100%|##########| 2/2 [00:00<00:00,  4.51it/s]
100%|##########| 2/2 [00:00<00:00,  4.20it/s]
100%|##########| 2/2 [00:00<00:00,  4.47it/s]
100%|##########| 2/2 [00:00<00:00,  4.58it/s]
100%|##########| 3/3 [00:00<00:00,  5.85it/s]
100%|##########| 3/3 [00:00<00:00,  6.36it/s]
100%|##########| 3/3 [00:00<00:00,  5.89it/s]
100%|##########| 4/4 [00:00<00:00,  7.32it/s]
100%|##########| 5/5 [00:00<00:00,  8.57it/s]
100%|##########| 6/6 [00:00<00:00,  9.05it/s]
100%|##########| 7/7 [00:00<00:00,  9.70it/s]
100%|##########| 7/7 [00:00<00:00,  9.63it/s]
100%|##########| 11/11 [00:00<00:00, 12.72it/s]
100%|##########| 19/19 [00:01<00:00, 15.74it/s]
 70%|#######   | 7/10 [00:16<00:07,  2.36s/it]

Stopping early at epoch  8. Meet minimal requirements by: f1=0.83,mcc=0.73,prc=0.97, roc=0.98



100%|##########| 1/1 [00:00<00:00,  2.12it/s]
100%|##########| 2/2 [00:00<00:00,  4.25it/s]
100%|##########| 2/2 [00:00<00:00,  4.65it/s]
100%|##########| 2/2 [00:00<00:00,  4.43it/s]
100%|##########| 2/2 [00:00<00:00,  4.08it/s]
100%|##########| 2/2 [00:00<00:00,  4.37it/s]
100%|##########| 3/3 [00:00<00:00,  5.60it/s]
100%|##########| 3/3 [00:00<00:00,  5.36it/s]
100%|##########| 3/3 [00:00<00:00,  6.01it/s]
100%|##########| 4/4 [00:00<00:00,  7.73it/s]
100%|##########| 5/5 [00:00<00:00,  8.26it/s]
100%|##########| 6/6 [00:00<00:00,  9.78it/s]
100%|##########| 7/7 [00:00<00:00, 10.31it/s]
100%|##########| 7/7 [00:00<00:00, 10.61it/s]
100%|##########| 11/11 [00:00<00:00, 13.06it/s]
100%|##########| 19/19 [00:01<00:00, 14.23it/s]
100%|##########| 10/10 [00:20<00:00,  2.06s/it]
100%|##########| 1/1 [00:00<00:00,  2.37it/s]
100%|##########| 2/2 [00:00<00:00,  4.76it/s]
100%|##########| 2/2 [00:00<00:00,  4.88it/s]
100%|##########| 2/2 [00:00<00:00,  4.31it/s]
100%|##########| 2/2 [00:00

Stopping early at epoch  8. Meet minimal requirements by: f1=0.86,mcc=0.81,prc=0.97, roc=0.98



100%|##########| 1/1 [00:00<00:00,  2.21it/s]
100%|##########| 2/2 [00:00<00:00,  4.70it/s]
100%|##########| 2/2 [00:00<00:00,  4.53it/s]
100%|##########| 2/2 [00:00<00:00,  4.23it/s]
100%|##########| 2/2 [00:00<00:00,  3.89it/s]
100%|##########| 2/2 [00:00<00:00,  4.15it/s]
100%|##########| 3/3 [00:00<00:00,  6.17it/s]
100%|##########| 3/3 [00:00<00:00,  6.13it/s]
100%|##########| 3/3 [00:00<00:00,  5.50it/s]
100%|##########| 4/4 [00:00<00:00,  7.15it/s]
100%|##########| 5/5 [00:00<00:00,  7.49it/s]
100%|##########| 6/6 [00:00<00:00,  9.21it/s]
100%|##########| 7/7 [00:00<00:00, 10.02it/s]
100%|##########| 7/7 [00:00<00:00,  9.93it/s]
100%|##########| 11/11 [00:00<00:00, 12.72it/s]
100%|##########| 19/19 [00:01<00:00, 15.29it/s]
 60%|######    | 6/10 [00:19<00:12,  3.18s/it]

Stopping early at epoch  7. Meet minimal requirements by: f1=0.89,mcc=0.87,prc=0.97, roc=0.99



100%|##########| 1/1 [00:00<00:00,  2.50it/s]
100%|##########| 2/2 [00:00<00:00,  4.72it/s]
100%|##########| 2/2 [00:00<00:00,  4.44it/s]
100%|##########| 2/2 [00:00<00:00,  4.79it/s]
100%|##########| 2/2 [00:00<00:00,  4.23it/s]
100%|##########| 2/2 [00:00<00:00,  4.62it/s]
100%|##########| 3/3 [00:00<00:00,  6.19it/s]
100%|##########| 3/3 [00:00<00:00,  6.29it/s]
100%|##########| 3/3 [00:00<00:00,  6.34it/s]
100%|##########| 4/4 [00:00<00:00,  7.10it/s]
100%|##########| 5/5 [00:00<00:00,  8.69it/s]
100%|##########| 6/6 [00:00<00:00,  8.56it/s]
100%|##########| 7/7 [00:00<00:00, 10.02it/s]
100%|##########| 7/7 [00:00<00:00,  8.92it/s]
100%|##########| 11/11 [00:00<00:00, 11.98it/s]
100%|##########| 19/19 [00:01<00:00, 14.95it/s]
 90%|######### | 9/10 [00:35<00:03,  3.98s/it]

Stopping early at epoch 10. Meet minimal requirements by: f1=0.86,mcc=0.69,prc=0.96, roc=0.94



100%|##########| 1/1 [00:00<00:00,  2.49it/s]
100%|##########| 2/2 [00:00<00:00,  4.98it/s]
100%|##########| 2/2 [00:00<00:00,  4.46it/s]
100%|##########| 2/2 [00:00<00:00,  4.83it/s]
100%|##########| 2/2 [00:00<00:00,  4.52it/s]
100%|##########| 2/2 [00:00<00:00,  4.28it/s]
100%|##########| 3/3 [00:00<00:00,  6.36it/s]
100%|##########| 3/3 [00:00<00:00,  6.18it/s]
100%|##########| 3/3 [00:00<00:00,  6.05it/s]
100%|##########| 4/4 [00:00<00:00,  7.22it/s]
100%|##########| 5/5 [00:00<00:00,  9.06it/s]
100%|##########| 6/6 [00:00<00:00,  8.96it/s]
100%|##########| 7/7 [00:00<00:00,  9.83it/s]
100%|##########| 7/7 [00:00<00:00, 10.05it/s]
100%|##########| 11/11 [00:00<00:00, 12.57it/s]
100%|##########| 19/19 [00:01<00:00, 14.61it/s]
100%|##########| 10/10 [00:42<00:00,  4.21s/it]
100%|##########| 1/1 [00:00<00:00,  2.36it/s]
100%|##########| 2/2 [00:00<00:00,  4.18it/s]
100%|##########| 2/2 [00:00<00:00,  4.09it/s]
100%|##########| 2/2 [00:00<00:00,  4.39it/s]
100%|##########| 2/2 [00:00

In [4]:
# Scratch: list the working directory via IPython's `ls` automagic.
ls

[0m[01;34mCTCT[0m/            run_fft.ipynb  run_nft.ipynb
ctct_test.ipynb  run_lft.ipynb  run_pft.ipynb


In [3]:
# NOTE(review): bare literal with no visible origin (scratch value?) — document its source or delete this cell.
0.163

0.163