In [None]:
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from torch import nn 
import random
import os

import matplotlib.pyplot as plt 
from src.utils.g2d_diff_genodrug_dataset import *
from src.g2d_diff_ce import *
from ignite.handlers.param_scheduler import create_lr_scheduler_with_warmup



def set_seed(seed: int = 42) -> None:
    """Seed every random number generator this notebook relies on.

    Seeds NumPy, Python's ``random`` module and PyTorch (CPU generator plus the
    CUDA generators), switches cuDNN to deterministic mode with auto-tuning
    disabled, stores the seed in the ``PYTHONHASHSEED`` environment variable,
    and prints a confirmation message.

    Args:
        seed: Integer seed applied to every generator.
    """
    # Same order as a manual sequence of calls: NumPy, stdlib, torch CPU, torch CUDA.
    seeders = (
        np.random.seed,
        random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    os.environ["PYTHONHASHSEED"] = str(seed)
    print(f"Random seed set as {seed}")


 

In [None]:
# Show the installed PyTorch version; the bare expression on the last line is
# rendered as the cell output.
import torch
torch.__version__

In [None]:
# IPython shell escape: run the `nvidia-smi` command and show its output in the cell.
!nvidia-smi

## CE Pretraining Wrapper

In [None]:
import logging
import datetime
from torch import autograd
from torch.autograd import Variable
import matplotlib.pyplot as plt
import torch.nn.functional as F

import numpy
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.utils import shuffle
import copy


class G2D_DIFF_CE_Pretrain(nn.Module):
    """Contrastive pre-training wrapper for the condition encoder and the drug encoder.

    The outputs of both encoders are passed through a linear projection head
    into a shared ``cond_dim``-dimensional space, L2-normalised and compared by
    dot product. Training minimises a cross-entropy loss over the resulting
    similarity matrix whose targets are the matching (condition, drug) pairs on
    the diagonal.
    """

    def __init__(self, num_of_genotypes = 1, num_of_dcls = 5, cond_dim = 128,  drug_dim = 128, use_nest_info = True,  device = 'cuda'):
        """Build the two encoders and their projection heads.

        Args:
            num_of_genotypes: Forwarded to ``Condition_Encoder``.
            num_of_dcls: Forwarded to ``Condition_Encoder``.
            cond_dim: Input and output width of the condition projection head
                and output width of the drug projection head.
            drug_dim: Input width of the drug projection head; also passed to
                ``DrugEncoder`` as ``input_dim``.
            use_nest_info: Forwarded to ``Condition_Encoder`` as ``neighbor_info``.
            device: Device string handed to both encoders and used to place
                every training batch.
        """
        super(G2D_DIFF_CE_Pretrain, self).__init__()

        # Projection heads mapping both modalities into the shared space.
        self.cond_mapper = nn.Linear(cond_dim, cond_dim)
        self.drug_mapper = nn.Linear(drug_dim, cond_dim)
        self.drug_encoder = DrugEncoder(input_dim = drug_dim,  device = device)

        ## If you want to get attention results, pass get_att = True here
        self.condition_encoder = Condition_Encoder(num_of_genotypes=num_of_genotypes, num_of_dcls=num_of_dcls,
                                                   device = device, neighbor_info = use_nest_info, get_att = False)

        self.num_of_genotypes = num_of_genotypes
        self.num_of_dcls = num_of_dcls
        self.condim = cond_dim
        self.device_name = device

    @staticmethod
    def _similar_pair_mask(classes, cell_names, drug_names, device):
        """Return a (B, B) bool mask of off-diagonal pairs that must not act as negatives.

        Entry ``[k, j]`` (``k != j``) is True when sample ``j`` has the same
        class and the same cell name as sample ``k``, or the same drug name.
        The diagonal (the positive pair) is always False.
        """
        cell_arr = np.array(cell_names)
        drug_arr = np.array(drug_names)
        same_cell = torch.from_numpy(cell_arr[:, None] == cell_arr[None, :]).to(device)
        same_drug = torch.from_numpy(drug_arr[:, None] == drug_arr[None, :]).to(device)
        flat_classes = classes.reshape(-1)  # one class label per sample
        same_class = flat_classes.unsqueeze(1) == flat_classes.unsqueeze(0)
        off_diagonal = ~torch.eye(same_cell.shape[0], dtype=torch.bool, device=device)
        return ((same_class & same_cell) | same_drug) & off_diagonal

    @staticmethod
    def _random_negative_mask(batch_size, num_neg, device):
        """Return a (B, B) bool mask that randomly drops candidate negatives.

        Every off-diagonal entry is dropped (True) independently with
        probability ``1 - num_neg / batch_size``, so about ``num_neg`` entries
        per row survive. The diagonal is never dropped.
        """
        drop_probability = 1 - (float(num_neg) / batch_size)
        dropped = torch.rand(batch_size, batch_size, device=device) < drop_probability
        off_diagonal = ~torch.eye(batch_size, dtype=torch.bool, device=device)
        return dropped & off_diagonal

    def _move_batch_to_device(self, batch):
        """Move every tensor of ``batch`` onto ``self.device_name`` in place.

        Keys containing ``'genotype'`` hold a nested dict of tensors; the
        ``'cell_name'`` and ``'drug_name'`` entries are not tensors and are
        left untouched.
        """
        for key in list(batch.keys()):
            if 'genotype' in key:
                genotype_tensors = batch[key]
                for mut in list(genotype_tensors.keys()):
                    genotype_tensors[mut] = genotype_tensors[mut].to(self.device_name)
            elif key in ('cell_name', 'drug_name'):
                continue
            else:
                batch[key] = batch[key].to(self.device_name)
        return batch

    def train(self, dataset_obj = True, collate_fn = None, train_config: dict = None, sampler = None):
        """Run contrastive pre-training, or toggle train/eval mode.

        This method shadows ``nn.Module.train(mode)``, which ``nn.Module.eval()``
        and parent modules call with a single positional bool. To keep those
        working, a bool first argument is forwarded to ``nn.Module.train``.

        Args:
            dataset_obj: Dataset handed to the ``DataLoader``. When a bool is
                passed instead, it is treated as the ``mode`` flag of
                ``nn.Module.train`` and no training is run.
            collate_fn: Collate function for the ``DataLoader``. The produced
                batch must be a dict providing ``'class'``, ``'cell_name'``,
                ``'drug_name'`` and ``'drug'`` plus whatever the condition
                encoder reads.
            train_config: Dict with keys ``'num_neg'``, ``'warmup_epoch'``,
                ``'cons_temp'``, ``'lr'``, ``'batch_size'``, ``'epoch'``,
                ``'max_step'`` and ``'seed'``. ``'current_epoch'`` is written
                into it (reset to 0 at start, then updated after every epoch).
            sampler: Sampler for the ``DataLoader``.

        Returns:
            ``self`` in mode-toggle calls, otherwise ``None``.

        Raises:
            TypeError: If a dataset is given without ``train_config``.
        """
        if isinstance(dataset_obj, bool):
            return super().train(dataset_obj)
        if train_config is None:
            raise TypeError("train() requires `train_config` when a dataset is given")

        num_neg = train_config['num_neg']
        warmup_epoch = train_config['warmup_epoch']
        temperature = train_config['cons_temp']
        lr = train_config['lr']
        batch_size = train_config['batch_size']
        epochs = train_config['epoch']
        max_step = train_config['max_step']
        train_config['current_epoch'] = 0
        seed = train_config['seed']  # only needed by the optional checkpoint path below

        sim_loss = nn.CrossEntropyLoss()

        C_solver = optim.Adam(
            list(self.cond_mapper.parameters()) + list(self.drug_mapper.parameters())
            + list(self.condition_encoder.parameters()) + list(self.drug_encoder.parameters()),
            lr=lr)

        # Linear warm-up from 0 to `lr` over the first `warmup_epoch` epochs,
        # constant afterwards; advanced once per optimisation step.
        lr_scheduler = create_lr_scheduler_with_warmup(
            optim.lr_scheduler.LambdaLR(C_solver, lr_lambda=[lambda epoch: 1]),
            warmup_start_value=0.0,
            warmup_duration=warmup_epoch * max_step,
            warmup_end_value=lr)

        # drop_last=True guarantees that every batch holds exactly `batch_size` samples.
        tr_loader = DataLoader(dataset_obj, batch_size=batch_size, drop_last=True, collate_fn=collate_fn, sampler = sampler)

        fin_c_losses = []  # mean loss of every finished epoch

        for epoch in range(train_config['current_epoch'], epochs):

            C_losses = []
            print ("Epoch: %d" %epoch)

            for i, batch in tqdm(enumerate(tr_loader), total = max_step):
                lr_scheduler(None)

                ## Batch data load to device
                self._move_batch_to_device(batch)

                ## Final gene embeddings, Final cond encoding, Final layer attention, Whole attention list
                _, cond_orig, _, _ = self.condition_encoder(batch)
                drug_orig = self.drug_encoder(batch['drug'])

                cfeat_norm = nn.functional.normalize(self.cond_mapper(cond_orig), dim = -1)
                dfeat_norm = nn.functional.normalize(self.drug_mapper(drug_orig), dim = -1)

                # Temperature-scaled cosine similarities: rows are conditions, columns are drugs.
                scores1 = torch.mm(cfeat_norm, dfeat_norm.transpose(0, 1)) / temperature

                # Remove look-alike pairs and a random subset of negatives from the
                # softmax; the diagonal is never masked, so no row or column is all -inf.
                sim_masks = self._similar_pair_mask(batch['class'], batch['cell_name'],
                                                    batch['drug_name'], self.device_name)
                neg_masks = self._random_negative_mask(batch_size, num_neg, self.device_name)
                scores1 = scores1.masked_fill(sim_masks | neg_masks, -float('inf'))

                scores2 = scores1.transpose(0, 1)
                labels = torch.arange(batch_size, device=self.device_name)

                similarity_loss1 = sim_loss(scores1, labels)
                similarity_loss2 = sim_loss(scores2, labels)

                C_loss = 0.9 * similarity_loss1 + 0.1 * similarity_loss2

                C_solver.zero_grad()
                C_loss.backward()
                C_solver.step()

                C_losses.append(C_loss.item())

                # Stop after exactly `max_step` optimisation steps in this epoch.
                if i + 1 >= max_step:
                    break

            if C_losses:
                fin_c_losses.append(float(np.mean(C_losses)))

            train_config['current_epoch'] = epoch
            ckpt_dict = {
            'condition_state_dict': self.condition_encoder.state_dict(),
            'dencoder_state_dict': self.drug_encoder.state_dict(),
            'cond_mapper_state_dict': self.cond_mapper.state_dict(),
            'drug_mapper_state_dict': self.drug_mapper.state_dict(),
            'csolver_state_dict': C_solver.state_dict(),
            'C_losses': fin_c_losses,
            'configs' : train_config
            }

        ## Use here to save the model
        #torch.save(ckpt_dict, "../DAS_DATA/model_ckpts/reproduce_10/seed_"+str(seed)+"_0914_%d.pth"%(epoch))

        return None
    
  
        

## Data Loading

In [None]:

import pickle
# Genotype kinds handed to the batch collator.
PREDIFINED_GENOTYPES = ['mut', 'cna', 'cnd']

# Drug-response table; rows containing any missing value are discarded.
nci_data = pd.read_csv("data/drug_response_data/CC_drug_response.csv").dropna()

# Cell lines held out of training; all of their rows go to the validation split.
valid_celllines = ['TK10_KIDNEY', 'OVCAR5_OVARY', 'HOP92_LUNG', 'SKMEL2_SKIN', 'HS578T_BREAST']
_held_out_rows = nci_data['ccle_name'].isin(valid_celllines)

nci_data_train = nci_data[~_held_out_rows]
nci_data_val = nci_data[_held_out_rows]


def _read_genotype_table(csv_path):
    """Read one genotype CSV, using its first column as the index and renaming any 'index' column to 'ccle_name'."""
    table = pd.read_csv(csv_path, index_col = 0)
    return table.rename(columns={'index':'ccle_name'})


cell2mut = _read_genotype_table("data/drug_response_data/original_cell2mut.csv")
cell2cna = _read_genotype_table("data/drug_response_data/original_cell2cna.csv")
cell2cnd = _read_genotype_table("data/drug_response_data/original_cell2cnd.csv")

# Table passed to GenoDrugDataset as the drug lookup (presumably drug name -> SMILES; confirm against the CSV).
drug2smi = pd.read_csv("data/drug_response_data/CC_drug2smi.csv")






## Training

In [None]:
seed = 44
set_seed(seed)

# Hyper-parameters read by G2D_DIFF_CE_Pretrain.train.
train_config = {
    'lr': 5e-5,
    'batch_size': 128,
    'epoch': 53,
    'max_step': 2000,
    'num_neg': 10,
    'warmup_epoch': 2,
    'cons_temp': 0.3,
    'seed': seed,
}

device = "cuda:0"
framework = G2D_DIFF_CE_Pretrain(num_of_genotypes = 3, num_of_dcls = 5, cond_dim = 128,  device = device)
framework = framework.to(device).to(torch.float)

dataset_obj = GenoDrugDataset(nci_data_train, cell2mut, drug2smi, cna=cell2cna, cnd=cell2cnd)
collate_fn = GenoDrugCollator(genotypes=PREDIFINED_GENOTYPES)

# Inverse-frequency sampling weights: every row is weighted by 1 / (size of its
# auc_label class), so the sampler draws the five classes about equally often.
auc_labels = nci_data_train['auc_label']
class_count = np.array([(auc_labels == label).sum() for label in range(5)])
weight = 1. / class_count
samples_weight = torch.from_numpy(weight[auc_labels.to_numpy()])

sampler = torch.utils.data.WeightedRandomSampler(samples_weight.double(), len(samples_weight))

framework.train(dataset_obj = dataset_obj, collate_fn = collate_fn,
                train_config = train_config, sampler = sampler)