In [None]:
%run 00_MonotherapyUtils.ipynb

In [None]:
torch.set_float32_matmul_precision('high')

# Placeholder: replace with the DD-PRiSM directory. It must end with a path
# separator, because every path below is built by plain string concatenation.
base_directory='Base directory that DD-PRiSM located'
batch_size = 1024
num_workers = 0

# Seed the RNGs consumed below: DataFrame.sample (numpy global RNG) for the
# shuffles, and random_split / RandomSampler (torch default generator). Without
# this, the train/validation split changes every run, while the density cache
# built from the training split is reused across runs.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

expression_df = pd.read_csv(base_directory + 'NCI_ALMANAC_mono/Processed/Expression_ZNormalized.csv', index_col=0)
# Gene symbols that have expression data; used to filter the gene sets below.
valid_gene_list = expression_df.columns

KEGG_legacy_file = 'c2.cp.kegg_legacy.v2023.2.Hs.symbols.gmt'  # 186 gene sets
GeneSetFile = base_directory + 'Raw/' + KEGG_legacy_file

# GMT is tab-separated, one gene set per line: <set name> \t <second field,
# not used here> \t <gene> \t <gene> ...
# Blank lines are skipped; csv.reader yields [] for them, which would otherwise
# crash on row[0] below.
with open(GeneSetFile, newline='') as f:
    GeneSet_List = [row for row in csv.reader(f, delimiter='\t') if row]

# Gene-set name -> gene symbols (fields 2 onward).
GeneSet_Dic = {row[0]: row[2:] for row in GeneSet_List}

# Keep only genes that appear in the expression matrix. Each value is a
# pd.Series whose index is the gene's original position within its set.
GeneSet_Dic_valid = {}
for GeneSet, genes in GeneSet_Dic.items():
    GeneSet_tmp = pd.Series(genes)
    GeneSet_Dic_valid[GeneSet] = GeneSet_tmp[GeneSet_tmp.isin(valid_gene_list)]

pathway_list = list(GeneSet_Dic_valid.keys())

# Per-pathway expression tables read from Input/<pathway>.csv, keyed by pathway name.
geneexpression_df_dic = {
    pathway_name: pd.read_csv(base_directory + 'Input/' + pathway_name + '.csv', index_col=0)
    for pathway_name in pathway_list
}

# Drug fingerprint table (Input/Fingerprint_Morgan512.csv).
fingerprint_df = pd.read_csv(base_directory + 'Input/Fingerprint_Morgan512.csv', index_col=0)

# Default float dtype, and the GPU if one is visible (otherwise the CPU).
dtype = torch.float
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Cache for the per-cell-line pathway expression table built below.
geneexpression_tensor_df_path = base_directory + 'NCI_ALMANAC_mono/Processed/GeneExpression_Tensor.pickle'

# Pre-made splits: the train/validation pool plus the four held-out test sets.
trainval_df_path       = base_directory + 'NCI_ALMANAC_mono/Training/TrainVal.csv'
unseenpair_df_path     = base_directory + 'NCI_ALMANAC_mono/Training/UnseenPair.csv'
unseencellline_df_path = base_directory + 'NCI_ALMANAC_mono/Training/UnseenCellLine.csv'
unseendrug_df_path     = base_directory + 'NCI_ALMANAC_mono/Training/UnseenDrug.csv'
unseenboth_df_path     = base_directory + 'NCI_ALMANAC_mono/Training/UnseenBoth.csv'

if os.path.isfile(geneexpression_tensor_df_path):
    # Cache hit. The pickle is this notebook's own output; never point this
    # path at an untrusted file, because pickle.load can execute code.
    with open(geneexpression_tensor_df_path, "rb") as fr:
        geneexpression_df_by_cellline = pickle.load(fr)
else:
    # Cache miss: build one row per cell line and one column per pathway. Each
    # cell holds that cell line's row (.values) from the pathway's table.
    # Cell lines are taken from the first pathway's table index.
    cellline_list = list(geneexpression_df_dic[pathway_list[0]].index)
    geneexpression_dic_by_cellline = {
        cellline: [geneexpression_df_dic[pathway].loc[cellline].values for pathway in pathway_list]
        for cellline in tqdm(cellline_list)
    }
    geneexpression_df_by_cellline = pd.DataFrame.from_dict(geneexpression_dic_by_cellline, orient='index')
    geneexpression_df_by_cellline.columns = pathway_list

    with open(geneexpression_tensor_df_path, "wb") as fw:
        pickle.dump(geneexpression_df_by_cellline, fw)

# A cache written from a different gene-set file would silently misalign the
# pathway columns. Fail loudly instead (delete the pickle to rebuild it).
assert list(geneexpression_df_by_cellline.columns) == pathway_list, (
    'Cached gene-expression table does not match the current pathway_list; '
    'delete ' + geneexpression_tensor_df_path + ' to rebuild it.'
)
    
def _load_shuffled(csv_path):
    """Read a split CSV and return its rows in random order, re-indexed 0..n-1."""
    frame = pd.read_csv(csv_path, index_col=0)
    frame = frame.sample(frac=1)
    return frame.reset_index(drop=True)

# Shuffle order matches the original, so the global RNG is consumed identically.
df_TrainVal       = _load_shuffled(trainval_df_path)
df_UnseenPair     = _load_shuffled(unseenpair_df_path)
df_UnseenCellLine = _load_shuffled(unseencellline_df_path)
df_UnseenDrug     = _load_shuffled(unseendrug_df_path)
df_UnseenBoth     = _load_shuffled(unseenboth_df_path)

def _mono_dataset(frame):
    """Wrap a split DataFrame with the shared pathway/expression/fingerprint inputs."""
    return MonotherapyDataset(frame, pathway_list, geneexpression_df_by_cellline, fingerprint_df)

trainval_dataset = _mono_dataset(df_TrainVal)

# Random 8:1 train/validation split of the TrainVal pool. Both results are
# Subset views over trainval_dataset.
n_trainval = len(trainval_dataset)
len_train = int(n_trainval * 8 / 9)
len_val = n_trainval - len_train
training_set, validation_set = random_split(trainval_dataset, [len_train, len_val])

# Held-out evaluation sets.
pair_set     = _mono_dataset(df_UnseenPair)
cellline_set = _mono_dataset(df_UnseenCellLine)
drug_set     = _mono_dataset(df_UnseenDrug)
both_set     = _mono_dataset(df_UnseenBoth)

def _batched_loader(dataset):
    """Return (sampler, loader) that yields whole shuffled batches of `dataset`.

    The BatchSampler hands a list of indices to the dataset per step, so
    automatic batching is disabled with batch_size=None. Passing
    multiprocessing_context='spawn' is an option if num_workers is raised.
    """
    sampler = BatchSampler(RandomSampler(dataset), batch_size=batch_size, drop_last=False)
    loader = DataLoader(
        dataset,
        sampler=sampler,
        batch_size=None,
        num_workers=num_workers,
        pin_memory=True,
    )
    return sampler, loader

training_sampler, training_dataloader = _batched_loader(training_set)
validation_sampler, validation_dataloader = _batched_loader(validation_set)
pair_sampler, pair_dataloader = _batched_loader(pair_set)
cellline_sampler, cellline_dataloader = _batched_loader(cellline_set)
drug_sampler, drug_dataloader = _batched_loader(drug_set)
both_sampler, both_dataloader = _batched_loader(both_set)

density_path = base_directory + 'NCI_ALMANAC_mono/Processed/Density_dic.pickle'
try:
    # Trusted local cache written by this cell (pickle.load can execute code).
    with open(density_path, "rb") as fr:
        density_dic = pickle.load(fr)
except (OSError, EOFError, pickle.UnpicklingError):
    # Only a missing or unreadable cache triggers a rebuild. A bare `except:`
    # would also swallow KeyboardInterrupt and genuine bugs in this cell.
    # NOTE(review): training_set is a Subset of trainval_dataset, so
    # training_set.dataset.df is presumably the whole TrainVal frame, and the
    # density also sees validation rows. The cache is also not tied to a
    # particular split. Confirm whether this is intended.
    precision = 2
    density_dic = estimate_density(training_set.dataset.df.VIABILITY, precision)
    with open(density_path, "wb") as fw:
        pickle.dump(density_dic, fw)

In [None]:
def set_bn_eval(module):
    """Switch ``module`` to eval mode if it is a BatchNorm layer; ignore anything else.

    Meant for ``nn.Module.apply``: it visits every submodule. BatchNorm layers
    stop updating their running statistics, while all other layers keep their
    current train/eval mode.
    """
    is_batchnorm = isinstance(module, torch.nn.modules.batchnorm._BatchNorm)
    if not is_batchnorm:
        return
    module.eval()

In [None]:
# Weights passed to CustomLoss (their meaning is defined in 00_MonotherapyUtils).
# The learning rate is set together with the training hyper-parameters below.
alpha = 1
beta = 0.5
gamma = 0.75

# Warm-start from the NCI60 checkpoint whose logged loss is lowest.
loss_df = pd.read_csv(base_directory + 'NCI60/Result/AdamW/Training_Log.csv', index_col=0)
best_epoch = loss_df['Loss'].idxmin()  # first epoch with the minimum loss
mono_model = MonotherapyModel(GeneSet_Dic_valid).to(device)
# map_location allows a checkpoint saved on a GPU to load on a CPU-only host.
mono_model.load_state_dict(
    torch.load(base_directory + 'NCI60/Result/AdamW/Model/' + str(best_epoch) + '.pt', map_location=device)
)
mono_model.eval()

# 'Frozen' trains only the last N_UNFROZEN_PARAMS parameter tensors, in
# named_parameters() order. Any other mode leaves requires_grad as loaded.
mode = 'Frozen'
N_UNFROZEN_PARAMS = 20
if mode == 'Frozen':
    child_list = [name for name, _ in mono_model.named_parameters()]
    unfrozen_params = child_list[-N_UNFROZEN_PARAMS:]
    for name, param in mono_model.named_parameters():
        param.requires_grad = name in unfrozen_params

# Put BatchNorm layers in eval mode so their running statistics are kept.
# NOTE(review): if train_mono() calls mono_model.train(), BN goes back to
# training mode and this one-off call has no lasting effect. Confirm in
# 00_MonotherapyUtils.
mono_model.apply(set_bn_eval)

# Fine-tuning hyper-parameters.
num_epochs = 1000
learning_rate = 0.001

# torch.save does not create directories; make sure the checkpoint folder exists.
checkpoint_dir = base_directory + 'NCI_ALMANAC_mono/Result/AdamW/Model/'
os.makedirs(checkpoint_dir, exist_ok=True)

optimizer = AdamW(mono_model.parameters(), lr=learning_rate)
loss_fn = CustomLoss(density_dic, alpha, beta, gamma)

# epoch -> [val_loss, pcc, rmse, learning_rate]
val_loss_dic = {}
best_loss = np.inf            # no validation loss seen yet
threshold = 0.0005            # minimum decrease that counts as an improvement
max_patience_lr = 10          # non-improving epochs before the LR is cut 10x
max_patience_es = 20          # non-improving epochs before training stops
patience_lr = max_patience_lr  # remaining budget until the next LR cut
patience_es = max_patience_es  # remaining budget until early stopping

for epoch in range(num_epochs):
    train_mono(training_dataloader, mono_model, loss_fn, optimizer)
    # Every epoch is checkpointed, so the best one can be picked from the log afterwards.
    torch.save(mono_model.state_dict(), checkpoint_dir + str(epoch) + '.pt')
    val_loss, pcc, rmse = test_mono(validation_dataloader, mono_model, loss_fn)
    val_loss_dic[epoch] = [val_loss, pcc, rmse, learning_rate]
    print('best_loss: ' + str(best_loss))
    print('current_val_loss: ' + str(val_loss))

    if val_loss < best_loss - threshold:
        best_loss = val_loss
        patience_lr = max_patience_lr
        patience_es = max_patience_es
    else:
        patience_lr -= 1
        patience_es -= 1

    if patience_lr <= 0:
        patience_lr = max_patience_lr
        learning_rate = learning_rate * 0.1
        # Update every param group, not only the first one.
        for group in optimizer.param_groups:
            group['lr'] = learning_rate
        print('Learning Rate is changed into ' + str(learning_rate))

    if patience_es <= 0:
        print('Early Stopping')
        break


In [None]:
# One row per epoch: validation loss (converted by get_tensor_value), PCC, RMSE and LR.
df_tmp = pd.DataFrame.from_dict(val_loss_dic, orient='index')
log_df = pd.DataFrame({
    'Loss': [get_tensor_value(x) for x in df_tmp[0]],
    'PCC': df_tmp[1],
    'RMSE': df_tmp[2],
    'lr': df_tmp[3].values,
})
log_df.to_csv(base_directory + 'NCI_ALMANAC_mono/Result/AdamW/Training_Log.csv')

# validation_set is a Subset of trainval_dataset. Its .dataset.df is the whole
# TrainVal frame, training rows included, so select only the validation rows.
# NOTE(review): assumes dataset item i corresponds to row i of .df (df_TrainVal
# was reset_index'd before wrapping). Confirm in MonotherapyDataset.
validation_df = validation_set.dataset.df.iloc[list(validation_set.indices)]
with open(base_directory + 'NCI_ALMANAC_mono/Result/AdamW/validation_df.pickle', "wb") as fw:
    pickle.dump(validation_df, fw)