In [None]:
%run 00_CombinationtherapyUtils.ipynb

In [None]:
torch.set_float32_matmul_precision('high')

base_directory='Base directory that DD-PRiSM located'    
batch_size=1024
num_workers=0

dtype=torch.float
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Cached predictions of the pretrained monotherapy model, reused by the combination model.
# NOTE: despite the .csv extension, both files are pickled DataFrames (written with pickle.dump
# at the end of this cell), not CSV text.
predicted_pathway_attention_path=base_directory+'NCI_ALMANAC_combination/Processed/PredictedPathwayAttention_AdamW.csv'
predicted_viability_path=base_directory+'NCI_ALMANAC_combination/Processed/PredictedViability_AdamW.csv'

if os.path.isfile(predicted_pathway_attention_path) and os.path.isfile(predicted_viability_path):
    # pickle.load can execute arbitrary code: only load cache files this pipeline wrote itself.
    with open(predicted_pathway_attention_path,"rb") as fr:
        nci_almanac_single_attention=pickle.load(fr)
    with open(predicted_viability_path,"rb") as fr:
        nci_almanac_single_viability=pickle.load(fr)

else:
    expression_df=pd.read_csv(base_directory+'NCI_ALMANAC_mono/Processed/Expression_ZNormalized.csv',index_col=0)
    valid_gene_list=expression_df.columns
    
    KEGG_legacy_file='c2.cp.kegg_legacy.v2023.2.Hs.symbols.gmt' #186 gene sets
    
    # Each tab-separated line of the .gmt file is one gene set: field 0 is its name, field 1 is
    # skipped, and the remaining fields are gene symbols. Blank lines are ignored.
    GeneSetFile=base_directory+'Raw/'+KEGG_legacy_file
    with open(GeneSetFile) as f:
        GeneSet_List=[row for row in csv.reader(f,delimiter='\t') if row]
    
    GeneSet_Dic={GeneSet[0]:GeneSet[2:] for GeneSet in GeneSet_List}
    
    # Keep only the genes of each set that have expression values.
    GeneSet_Dic_valid={}
    for GeneSet,genes in GeneSet_Dic.items():
        GeneSet_tmp=pd.Series(genes)
        GeneSet_Dic_valid[GeneSet]=GeneSet_tmp[GeneSet_tmp.isin(valid_gene_list)]
    
    pathway_list=list(GeneSet_Dic_valid.keys())
    
    # Per-pathway expression inputs and drug fingerprints consumed by MonotherapyDataset.
    geneexpression_df_dic={pathway:pd.read_csv(base_directory+'Input/'+pathway+'.csv',index_col=0) for pathway in pathway_list}
    fingerprint_df=pd.read_csv(base_directory+'Input/Fingerprint_Morgan512.csv',index_col=0)
        
    # Cell line x pathway table: each entry holds that cell line's row of the pathway's expression
    # frame. Built once and cached as a pickle.
    geneexpression_tensor_df_path=base_directory+'NCI_ALMANAC_mono/Processed/GeneExpression_Tensor.pickle'
        
    if os.path.isfile(geneexpression_tensor_df_path):
        with open(geneexpression_tensor_df_path,"rb") as fr:    
            geneexpression_df_by_cellline=pickle.load(fr)
    
    else:    
        cellline_list=list(geneexpression_df_dic[pathway_list[0]].index)
        geneexpression_dic_by_cellline={}
        for cellline in tqdm(cellline_list):
            geneexpression_dic_by_cellline[cellline]=[geneexpression_df_dic[pathway].loc[cellline].values for pathway in pathway_list]
        geneexpression_df_by_cellline=pd.DataFrame.from_dict(geneexpression_dic_by_cellline,orient='index')
        geneexpression_df_by_cellline.columns=pathway_list
        
        with open(geneexpression_tensor_df_path,"wb") as fw:    
            pickle.dump(geneexpression_df_by_cellline, fw)

    # Every (drug, concentration, cell line) that appears as either drug of a combination.
    # Concentrations are rounded to 5 decimals BEFORE de-duplication, so values that only differ
    # beyond the 5th decimal collapse into one row instead of being predicted several times
    # (the outputs below are keyed by the rounded value anyway).
    nci_almanac_comb=pd.read_csv(base_directory+'NCI_ALMANAC_combination/Processed/NCI_ALMANAC_combination.csv',index_col=0)
    nci_almanac_comb1=nci_almanac_comb[['NSC1','CONCENTRATION1','CELLNAME']].rename(columns={'NSC1':'NSC','CONCENTRATION1':'CONCENTRATION'})
    nci_almanac_comb2=nci_almanac_comb[['NSC2','CONCENTRATION2','CELLNAME']].rename(columns={'NSC2':'NSC','CONCENTRATION2':'CONCENTRATION'})
    nci_almanac_single=pd.concat([nci_almanac_comb1,nci_almanac_comb2],axis=0)
    nci_almanac_single=(nci_almanac_single
                        .assign(CONCENTRATION=nci_almanac_single['CONCENTRATION'].round(5))
                        .drop_duplicates()
                        .assign(VIABILITY=1)) #Dummy viability value for the dataloader
    
    # SequentialSampler (not RandomSampler) keeps the predictions in the same row order as
    # nci_almanac_single, so each prediction can be matched back to its drug/concentration/cell line.
    nci_almanac_single_dataset = MonotherapyDataset(nci_almanac_single,pathway_list,geneexpression_df_by_cellline,fingerprint_df)
    nci_almanac_single_sampler = BatchSampler(SequentialSampler(nci_almanac_single_dataset),batch_size=batch_size,drop_last=False)
    nci_almanac_single_dataloader = DataLoader(nci_almanac_single_dataset,sampler=nci_almanac_single_sampler,batch_size=None,num_workers=num_workers,pin_memory=True)#,multiprocessing_context='spawn')

    # Restore the monotherapy checkpoint whose logged loss is lowest.
    log_df=pd.read_csv(base_directory+'NCI_ALMANAC_mono/Result/AdamW/Training_Log.csv',index_col=0)
    best_epoch=log_df.sort_values(by='Loss').index[0]
    pretrained_mono_model_path=base_directory+'NCI_ALMANAC_mono/Result/AdamW/Model/'+str(best_epoch)+'.pt'
    pretrained_mono_model=MonotherapyModel(GeneSet_Dic_valid).to(device)
    # map_location lets a checkpoint saved on GPU be restored when `device` fell back to CPU.
    pretrained_mono_model.load_state_dict(torch.load(pretrained_mono_model_path,map_location=device))
    pretrained_mono_model.eval()
    
    # get_intermediate_output (utils notebook) presumably returns the output of the named submodule
    # together with the final viability prediction — see its definition.
    target_module='sample_attention_block'
    predicted_pathway_attention, predicted_viability = get_intermediate_output(nci_almanac_single_dataloader, pretrained_mono_model, target_module)
    
    attention_values=pd.DataFrame(get_tensor_value(predicted_pathway_attention))
    viability_values=pd.DataFrame(get_tensor_value(predicted_viability))
    # pd.concat(axis=1) below would silently pad with NaN on a length mismatch; fail loudly instead.
    assert len(attention_values)==len(nci_almanac_single), 'attention predictions do not line up with nci_almanac_single'
    assert len(viability_values)==len(nci_almanac_single), 'viability predictions do not line up with nci_almanac_single'
    
    # Pathway attention is kept once per (drug, cell line): the first row of each pair wins.
    nci_almanac_single_attention=nci_almanac_single[['NSC','CELLNAME']].reset_index(drop=True)
    nci_almanac_single_attention=pd.concat([nci_almanac_single_attention,attention_values],axis=1).drop_duplicates(subset=['NSC','CELLNAME'])
    nci_almanac_single_attention.index=pd.MultiIndex.from_frame(nci_almanac_single_attention[['NSC','CELLNAME']])
    nci_almanac_single_attention=nci_almanac_single_attention[nci_almanac_single_attention.columns[2:]] #First two columns are NSC and CELLNAME
    
    # Predicted viability, indexed by (NSC, rounded CONCENTRATION, CELLNAME).
    nci_almanac_single_viability=nci_almanac_single[['NSC','CONCENTRATION','CELLNAME']].reset_index(drop=True)
    nci_almanac_single_viability=pd.concat([nci_almanac_single_viability,viability_values],axis=1).drop_duplicates(subset=['NSC','CONCENTRATION','CELLNAME'])
    nci_almanac_single_viability.columns=['NSC','CONCENTRATION','CELLNAME','PREDICTED_VIABILITY']
    nci_almanac_single_viability.index=pd.MultiIndex.from_frame(nci_almanac_single_viability[['NSC','CONCENTRATION','CELLNAME']])
    nci_almanac_single_viability=nci_almanac_single_viability[['PREDICTED_VIABILITY']]
    
    with open(predicted_pathway_attention_path,"wb") as fw:
        pickle.dump(nci_almanac_single_attention, fw)
        
    with open(predicted_viability_path,"wb") as fw:
        pickle.dump(nci_almanac_single_viability, fw)


In [None]:
# Train/validation data plus the five held-out "unseen" evaluation settings.
unseen_setting_directory=base_directory+'NCI_ALMANAC_combination/Training/UnseenSetting/'
trainval_df_path=unseen_setting_directory+'TrainVal.csv'
unseen_pair_path=unseen_setting_directory+'UnseenPair.csv'
unseen_cellline_path=unseen_setting_directory+'UnseenCellLine.csv'
unseen_drug1_path=unseen_setting_directory+'UnseenDrug1.csv'
unseen_drug2_path=unseen_setting_directory+'UnseenDrug2.csv'
unseen_both_path=unseen_setting_directory+'UnseenBoth.csv'

# Fixed seed for the row shuffles and the train/validation split, so reruns reproduce the same split.
split_seed=42

def load_combination_split(path,seed):
    """Read one combination split CSV, shuffle its rows and round both concentrations.

    Concentrations are rounded to 5 decimals, the same rounding applied to the monotherapy
    table whose predictions are indexed by (NSC, CONCENTRATION, CELLNAME); presumably
    CombinationDataset looks each drug/concentration up in that index.

    Args:
        path: CSV file whose first column is the index.
        seed: random_state used for the row shuffle.
    Returns:
        The shuffled DataFrame with a fresh 0..n-1 index.
    """
    df=pd.read_csv(path,index_col=0).sample(frac=1,random_state=seed).reset_index(drop=True)
    return df.assign(CONCENTRATION1=df['CONCENTRATION1'].round(5),
                     CONCENTRATION2=df['CONCENTRATION2'].round(5))

df_TrainVal=load_combination_split(trainval_df_path,split_seed)
df_UnseenPair=load_combination_split(unseen_pair_path,split_seed)
df_UnseenCellline=load_combination_split(unseen_cellline_path,split_seed)
df_UnseenDrug1=load_combination_split(unseen_drug1_path,split_seed)
df_UnseenDrug2=load_combination_split(unseen_drug2_path,split_seed)
df_UnseenBoth=load_combination_split(unseen_both_path,split_seed)

# 7/8 of TrainVal for training, 1/8 for validation.
trainval_dataset=CombinationDataset(df_TrainVal,nci_almanac_single_attention,nci_almanac_single_viability)
len_train=int(len(trainval_dataset)*7/8)
len_val=len(trainval_dataset)-len_train
training_set, validation_set=random_split(trainval_dataset,[len_train,len_val],generator=torch.Generator().manual_seed(split_seed))
df_Train=df_TrainVal.iloc[training_set.indices]
df_Val=df_TrainVal.iloc[validation_set.indices]

# Drug-order augmentation: add every training combination a second time with drug 1 and drug 2
# swapped. pd.concat aligns columns by NAME, so the swap must be done by renaming the columns;
# merely reordering them would leave every value in place and just duplicate the rows.
drug_order_swap={'NSC1':'NSC2','CONCENTRATION1':'CONCENTRATION2','NSC2':'NSC1','CONCENTRATION2':'CONCENTRATION1'}
df_Train_copy=df_Train[['NSC1','CONCENTRATION1','NSC2','CONCENTRATION2','CELLNAME','VIABILITY']].rename(columns=drug_order_swap)
df_Train=pd.concat([df_Train,df_Train_copy],axis=0)

def make_combination_loader(df,shuffle):
    """Wrap df in a CombinationDataset and a DataLoader that yields whole batches.

    The BatchSampler hands the dataset a list of batch_size indices at a time, which is why
    the DataLoader itself is built with batch_size=None.

    Args:
        df: combination rows to serve.
        shuffle: True draws rows in random order (RandomSampler); False keeps row order
            (SequentialSampler) so predictions can be matched back to df.
    Returns:
        (dataset, dataloader)
    """
    dataset=CombinationDataset(df,nci_almanac_single_attention,nci_almanac_single_viability)
    base_sampler=RandomSampler(dataset) if shuffle else SequentialSampler(dataset)
    sampler=BatchSampler(base_sampler,batch_size=batch_size,drop_last=False)
    dataloader=DataLoader(dataset,sampler=sampler,batch_size=None,num_workers=num_workers,pin_memory=True)#,multiprocessing_context='spawn')
    return dataset,dataloader

training_set,training_dataloader=make_combination_loader(df_Train,shuffle=True)
validation_set,validation_dataloader=make_combination_loader(df_Val,shuffle=True)

pair_set,pair_dataloader=make_combination_loader(df_UnseenPair,shuffle=False)
cellline_set,cellline_dataloader=make_combination_loader(df_UnseenCellline,shuffle=False)
drug1_set,drug1_dataloader=make_combination_loader(df_UnseenDrug1,shuffle=False)
drug2_set,drug2_dataloader=make_combination_loader(df_UnseenDrug2,shuffle=False)
both_set,both_dataloader=make_combination_loader(df_UnseenBoth,shuffle=False)


In [None]:
# Estimate the distribution of the training viabilities; the resulting density_dic is handed to
# CustomLoss in the training cell. `precision` is passed straight to estimate_density
# (presumably a rounding/binning precision — see its definition in the utils notebook).
precision = 2
density_dic = estimate_density(training_set.df['VIABILITY'], precision)

In [None]:
# Loss weights handed to CustomLoss together with the training viability density.
alpha=1
beta=0.5
gamma=0.75
       
num_epochs=1000
learning_rate=0.01

comb_model=CombinationTherapyModel().to(device)
optimizer=AdamW(comb_model.parameters(),lr=learning_rate)
loss_fn = CustomLoss(density_dic,alpha, beta, gamma)

# A checkpoint is written after every epoch; the evaluation cell later reloads the one with the
# lowest validation loss. Create the directory up front so torch.save does not fail on a fresh tree.
model_directory=base_directory+'NCI_ALMANAC_combination/Result/AdamW/Model/'
os.makedirs(model_directory,exist_ok=True)

val_loss_dic={} #epoch -> [validation loss, PCC, RMSE, learning rate used during that epoch]
best_loss=np.inf #Initially, it is inf
threshold=0.0005 #Validation loss must drop by more than this to count as an improvement
max_patience_lr=10 #Non-improving epochs tolerated before the learning rate is divided by 10
max_patience_es=20 #Non-improving epochs tolerated before training stops early
patience_lr=max_patience_lr #Remaining non-improving epochs before the next learning-rate decay
patience_es=max_patience_es #Remaining non-improving epochs before early stopping

for epoch in range(num_epochs):
    train_comb(training_dataloader,comb_model,loss_fn,optimizer)
    torch.save(comb_model.state_dict(),model_directory+str(epoch)+'.pt')
    val_loss,pcc,rmse=test_comb(validation_dataloader,comb_model,loss_fn)
    val_loss_dic[epoch]=[val_loss,pcc,rmse,learning_rate]
    print('best_loss: '+str(best_loss))
    print('current_val_loss: '+str(val_loss))
    
    if val_loss<best_loss-threshold:
        best_loss=val_loss
        patience_lr=max_patience_lr
        patience_es=max_patience_es
    else:
        patience_lr-=1
        patience_es-=1
    if patience_lr<=0:
        patience_lr=max_patience_lr
        learning_rate=learning_rate*0.1
        # Apply the decayed rate to every parameter group, not only the first one.
        for param_group in optimizer.param_groups:
            param_group['lr']=learning_rate
        print('Learning Rate is changed into '+str(learning_rate))
    if patience_es<=0:
        print('Early Stopping')
        break

In [None]:
# Tabulate the per-epoch record {epoch: [val_loss, pcc, rmse, lr]} written by the training loop and
# save it; the evaluation cell selects the best checkpoint from this CSV.
epoch_metrics = pd.DataFrame.from_dict(val_loss_dic, orient='index')
loss_values = [get_tensor_value(loss) for loss in epoch_metrics[0]]
log_df = pd.DataFrame({
    'Loss': loss_values,
    'PCC': epoch_metrics[1],
    'RMSE': epoch_metrics[2],
    'lr': epoch_metrics[3].to_numpy(),
})
log_df.to_csv(base_directory+'NCI_ALMANAC_combination/Result/AdamW/Training_Log.csv')

# Keep the validation rows so the exact validation split can be inspected later.
with open(base_directory+'NCI_ALMANAC_combination/Result/AdamW/validation_idx.pickle',"wb") as validation_file:
    pickle.dump(validation_set.df, validation_file)

In [None]:
# Score the best checkpoint on one held-out setting. Swap in cellline_dataloader, drug1_dataloader,
# drug2_dataloader or both_dataloader to evaluate the other unseen settings.
test_dataloader=pair_dataloader

log_df=pd.read_csv(base_directory+'NCI_ALMANAC_combination/Result/AdamW/Training_Log.csv',index_col=0)
best_epoch=log_df.sort_values(by='Loss').index[0]
comb_model=CombinationTherapyModel().to(device)
# map_location lets a checkpoint saved on GPU be restored when `device` fell back to CPU.
comb_model.load_state_dict(torch.load(base_directory+'NCI_ALMANAC_combination/Result/AdamW/Model/'+str(best_epoch)+'.pt',map_location=device))
# Switch layers such as dropout/batch-norm to inference behaviour, matching how the pretrained
# monotherapy model is prepared; predict_comb's internals are not visible here to rely on.
comb_model.eval()

real,predicted=predict_comb(test_dataloader, comb_model)
# Detach, move to host and flatten: scipy.stats.pearsonr expects 1-D inputs, and .numpy()
# refuses tensors that still require grad.
real=real.detach().cpu().numpy().ravel()
predicted=predicted.detach().cpu().numpy().ravel()

print('RMSE: '+str(mean_squared_error(real,predicted)**0.5))
print('PCC: '+str(stats.pearsonr(real,predicted)[0]))
print('R2: '+str(r2_score(real,predicted)))