In [17]:
import os
import argparse
from utils import bool_ext, load_dataset, split_dataset, evaluate, checkCorrelations
from models import CITRUS
import pickle
import torch
import numpy as np
import pandas as pd

# Run on the first CUDA device when PyTorch reports one, otherwise on the CPU.
# `device_name` is the GPU's reported name, or the literal 'cpu'.
if torch.cuda.is_available():
    device, device_name = 'cuda', torch.cuda.get_device_name(0)
else:
    device = device_name = 'cpu'


# Hyperparameters and run options, one row per option:
# (flag, help text, type converter, default). Rows are registered in this order.
_ARGUMENT_SPECS = [
    ("--input_dir", "directory of input files", str, "./data"),
    ("--output_dir", "directory of output files", str, "./output"),
    ("--embedding_size", "embedding dimension of genes and tumors", int, 512),
    ("--hidden_size", "hidden layer dimension of MLP decoder", int, 400),
    ("--attention_size", "size of attention parameter beta_j", int, 256),
    ("--attention_head", "number of attention heads", int, 32),
    ("--learning_rate", "learning rate for Adam", float, 1e-3),
    ("--max_iter", "maximum number of training iterations", int, 1000),
    ("--max_fscore", "Max F1 score to early stop model from training", float, 0.7),
    ("--batch_size", "training batch size", int, 100),
    ("--test_batch_size", "test batch size", int, 100),
    ("--test_inc_size", "increment interval size between log outputs", int, 256),
    ("--dropout_rate", "dropout rate", float, 0.2),
    ("--input_dropout_rate", "dropout rate", float, 0.2),
    ("--weight_decay", "coefficient of l2 regularizer", float, 1e-5),
    ("--activation", "activation function used in hidden layer", str, "tanh"),
    ("--patience", "earlystopping patience", int, 30),
    ("--mask01", "wether to ignore the float value and convert mask to 01", bool_ext, True),
    ("--gep_normalization", "how to normalize gep", str, "scaleRow"),
    ("--attention", "whether to use attention mechanism or not", bool_ext, True),
    ("--cancer_type", "whether to use cancer type or not", bool_ext, True),
    ("--train_model", "whether to train model or load model", bool_ext, True),
    ("--dataset_name", "the dataset name loaded and saved", str, "dataset_CITRUS"),
    ("--tag", "a tag passed from command line", str, ""),
    ("--run_count", "the count for training", str, "1"),
]

parser = argparse.ArgumentParser()
for flag, help_text, arg_type, default in _ARGUMENT_SPECS:
    parser.add_argument(flag, help=help_text, type=arg_type, default=default)

# Parse an explicit empty argument list, so every option takes its default
# regardless of how the process itself was launched.
args = parser.parse_args([])

# Make sure the output directory exists. exist_ok=True replaces the separate
# existence check, so a directory created between check and creation is not an error.
os.makedirs(args.output_dir, exist_ok=True)

# Load the dataset named by --dataset_name from --input_dir. load_dataset
# returns a pair; only the first element (`dataset`) is used in the cells shown here.
print("Loading dataset...")
dataset, dataset_test = load_dataset(
    input_dir=args.input_dir,
    mask01=args.mask01,
    dataset_name=args.dataset_name,
    gep_normalization=args.gep_normalization,
)


# Read a tab-separated clinical data table (presumably the TCGA BRCA PanCancer
# Atlas cohort, going by the file name — TODO confirm).
# NOTE(review): `_df` is only referenced by the disabled relabelling code below,
# so this read currently has no consumer in the visible cells.
# NOTE(review): the path is absolute and user-specific; the notebook will not
# run elsewhere unless this is made relative/configurable.
_df = pd.read_csv('/ihome/hosmanbeyoglu/kor11/tools/CITRUS/data/brca_tcga_pan_can_atlas_2018_clinical_data.tsv', sep='\t')

# Disabled experiment (kept for reference): shift every existing cancer code up
# by 5, then overwrite the code of each tumour found in `_df` with a BRCA
# subtype code (LumA=1, LumB=2, Basal=3, her2=4, Normal=5).

# dataset['can'] = dataset['can'] + 5

# for tmr, subtype in _df[['Patient ID', 'Subtype']].values:
#     if tmr in dataset['tmr']:
#         if subtype == 'BRCA_LumA':
#             dataset['can'][dataset['tmr'].index(tmr)] = 1
#         if subtype == 'BRCA_LumB':
#             dataset['can'][dataset['tmr'].index(tmr)] = 2
#         if subtype == 'BRCA_Basal':
#             dataset['can'][dataset['tmr'].index(tmr)] = 3
#         if subtype == 'BRCA_her2':
#             dataset['can'][dataset['tmr'].index(tmr)] = 4
#         if subtype == 'BRCA_Normal':
#             dataset['can'][dataset['tmr'].index(tmr)] = 5
            


# Split the loaded dataset into two parts; ratio=0.66 is forwarded to split_dataset.
train_set, test_set = split_dataset(dataset, ratio=0.66)

# Derive the model dimensions from the data and store them on `args`.
# hidden_size is overwritten here, so the --hidden_size default is not what the
# model ends up using.
vars(args).update(
    can_size=dataset["can"].max(),  # cancer type dimension
    sga_size=dataset["sga"].max(),  # SGA dimension
    gep_size=dataset["gep"].shape[1],  # GEP output dimension
    num_max_sga=dataset["sga"].shape[1],  # maximum number of SGAs in a tumor
    hidden_size=dataset["tf_gene"].shape[0],
)

print("Hyperparameters:")
print(args)

# Attached only after printing, so the printed Namespace does not contain tf_gene.
args.tf_gene = dataset["tf_gene"]


# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary Python objects;
# only use it with a trusted maps.npy.
masks = np.load('./pnet_prostate_paper/train/maps.npy', allow_pickle=True)

# Rebuild the three trained CITRUS models (checkpoints trained_modelx1..3).
# The evaluation cell below iterates over `models`; with this code disabled it
# only worked on a kernel that still held `models` from an earlier run.
# NOTE(review): the checkpoint path is absolute and user-specific.
models = []
for i in [1, 2, 3]:
    model = CITRUS(args, masks)  # initialize CITRUS model
    model.build(device=device)  # build CITRUS model
    model.to(device)

    # Checkpoint tensors are mapped to CPU storage on load; load_state_dict then
    # copies them into the model's existing parameters.
    model.load_state_dict(torch.load(f'/ihome/hosmanbeyoglu/kor11/tools/CITRUS/output/trained_modelx{i}.pth',
                                     map_location=torch.device('cpu')))
    model.eval()  # inference mode (e.g. disables dropout)

    models.append(model)

Loading dataset...
Hyperparameters:
Namespace(activation='tanh', attention=True, attention_head=32, attention_size=256, batch_size=100, can_size=17, cancer_type=True, dataset_name='dataset_CITRUS', dropout_rate=0.2, embedding_size=512, gep_normalization='scaleRow', gep_size=5541, hidden_size=320, input_dir='./data', input_dropout_rate=0.2, learning_rate=0.001, mask01=True, max_fscore=0.7, max_iter=1000, num_max_sga=1396, output_dir='./output', patience=30, run_count='1', sga_size=11998, tag='', test_batch_size=100, test_inc_size=256, train_model=True, weight_decay=1e-05)


In [20]:
from utils import Data

# CSV inputs for Data, keyed by the keyword argument each path is passed as.
_data_files = {
    'fGEP_SGA': 'data/CITRUS_GEP_SGAseparated.csv',
    'fgene_tf_SGA': 'data/CITRUS_gene_tf_SGAseparated.csv',
    'fcancerType_SGA': 'data/CITRUS_canType_SGAseparated.csv',
    'fSGA_SGA': 'data/CITRUS_SGA_SGAseparated.csv',
}
data = Data(**_data_files)

In [21]:
from sklearn.decomposition import PCA
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from utils import checkCorrelations
from scipy.stats import ttest_1samp as ttest
import warnings 
from sklearn import metrics
# Install a global "ignore" filter: every warning category is silenced from here on.
# NOTE(review): this also hides deprecation and numerical warnings from
# pandas/sklearn/torch; consider narrowing it to the specific warnings that are noisy.
warnings.filterwarnings("ignore")

In [22]:
# Rows of data.cancerType_sga selected by dataset['tmr'] (in that order), with the
# flattened dataset['can'] values attached as an 'index' column.
d = data.cancerType_sga.loc[dataset['tmr']].assign(index=dataset['can'].reshape(-1))

In [23]:
# Load the dataset pickle and take its 'idx2can' entry.
# The context manager closes the file handle, which the bare open() call leaked.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
# NOTE(review): the path is absolute and user-specific.
with open("/ihome/hosmanbeyoglu/kor11/tools/CITRUS/data/dataset_CITRUS.pkl", "rb") as fh:
    daata = pickle.load(fh)
cancers = daata['idx2can']

In [36]:
# Cancer types to score, in the order their (corr, mse) pairs are appended.
eval_cancers = ['BLCA', 'BRCA', 'CESC', 'COAD', 'ESCA',
                'GBM', 'HNSC', 'KIRC', 'KIRP',
                'LIHC', 'LUAD', 'LUSC', 'PCPG',
                'PRAD', 'STAD', 'THCA', 'UCEC']

# One result list per model, so the cell works for any number of loaded models
# rather than exactly three.
perfs = [[] for _ in models]

for idx, model in enumerate(models):
    # Inference only: no_grad avoids building an autograd graph for the full
    # test-set forward pass. The returned values are unchanged.
    with torch.no_grad():
        preds, hid_tmr, tf, _, _ = model.forward(torch.tensor(test_set['sga']), torch.from_numpy(test_set['can']))
    genes_ = test_set['gep'].shape[1]
    # Column layout: [0, genes_) observed GEP, column genes_ the cancer code,
    # (genes_, end) predicted GEP — assumes test_set['can'] is a single column.
    # .cpu() makes the conversion work when the predictions live on a GPU.
    test_df = pd.DataFrame(np.concatenate([test_set['gep'], test_set['can'], preds.detach().cpu().numpy()], axis=1))

    test_cancers = {}
    for ix, canc in cancers.items():
        # Select this cancer type's rows once; the cancer column holds ix + 1.
        rows = test_df[test_df[genes_] == ix + 1].values
        test_cancers[canc] = {
            'test': rows[:, :genes_],
            'pred': rows[:, genes_ + 1:],
        }

    for canc in eval_cancers:
        corr = checkCorrelations(test_cancers[canc]['test'], test_cancers[canc]['pred'], return_value=True)
        mse = metrics.mean_squared_error(test_cancers[canc]['test'], test_cancers[canc]['pred'])
        perfs[idx].append((corr, mse))

In [38]:
# Turn the nested per-model lists into one array. With scalar metrics the axes
# are (model, cancer type, metric), where metric is (corr, mse).
perfs = np.asarray(perfs)

In [41]:
# Aggregate across models (axis 0). Transposing puts the metric axis first, so
# unpacking yields one per-cancer vector for correlation and one for MSE.
corr, mse = perfs.mean(axis=0).T
corr_std, mse_std = perfs.std(axis=0).T

In [50]:
# Per-cancer summary table: one row per cancer type, in the order the metrics
# were computed.
# NOTE(review): 'MeanExpression' / 'StdExpression' hold the mean and std of the
# correlation values (`corr`, `corr_std`), not of expression levels.
summary_index = ['BLCA', 'BRCA', 'CESC', 'COAD', 'ESCA',
                 'GBM', 'HNSC', 'KIRC', 'KIRP',
                 'LIHC', 'LUAD', 'LUSC', 'PCPG',
                 'PRAD', 'STAD', 'THCA', 'UCEC']
r = pd.DataFrame(
    np.column_stack([corr, corr_std, mse, mse_std]),
    index=summary_index,
    columns=['MeanExpression', 'StdExpression', 'MSE', 'StdMSE'],
)
r

Unnamed: 0,MeanExpression,StdExpression,MSE,StdMSE
BLCA,0.909137,0.000335,0.186613,0.000593
BRCA,0.928268,0.001118,0.150283,0.001777
CESC,0.910333,0.001241,0.184669,0.002467
COAD,0.939443,0.000961,0.134263,0.001858
ESCA,0.920318,0.00213,0.168926,0.003936
GBM,0.899825,0.002974,0.200847,0.005101
HNSC,0.930309,0.000192,0.148573,0.000196
KIRC,0.939717,0.000483,0.131947,0.000947
KIRP,0.917299,0.001025,0.172904,0.001987
LIHC,0.906193,0.001962,0.189299,0.003871
