In [4]:
%run 00_Utils.ipynb

In [1]:
import torch
import random
import datetime
import time
import matplotlib.pyplot as plt
import pickle

import pandas as pd
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

import csv

import torch.multiprocessing as mp

from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader, BatchSampler, RandomSampler,random_split
from torch.multiprocessing import Pool, Process, set_start_method

from tqdm import tqdm

import scipy.stats

import os.path


In [2]:
class MonotherapyDataset(Dataset):
    """Monotherapy dose-response samples, served one mini-batch per index.

    ``__getitem__`` is meant to receive a list of row positions (e.g. from a
    ``BatchSampler``) and returns the whole batch at once. A scalar position is
    also accepted and yields a batch of size one.

    Args:
        dataframe: one row per sample with columns 'CELLNAME', 'NSC',
            'CONCENTRATION' and 'VIABILITY'.
        pathway_list: columns of ``geneexpression_df_by_cellline`` to return,
            in this order.
        geneexpression_df_by_cellline: indexed by cell-line name; each entry of
            a pathway column is expected to be that pathway's expression vector.
        fingerprint_df: indexed by NSC drug id; one fingerprint vector per row.
    """

    def __init__(self, dataframe, pathway_list, geneexpression_df_by_cellline, fingerprint_df):
        self.df = dataframe
        self.pathway_list = pathway_list
        self.geneexpression_df_by_cellline = geneexpression_df_by_cellline
        self.fingerprint_df = fingerprint_df

    def __len__(self):
        """Number of samples (rows of ``dataframe``), not number of batches."""
        return len(self.df)

    def __getitem__(self, idx):
        """Build the model inputs and targets for the rows at positions ``idx``.

        Args:
            idx: a list-like of row positions, or a single integer position.

        Returns:
            ``(data, viability)``, where ``data`` is
            ``[pathway_tensors, drug_tensor, concentration_tensor]``.

            - ``pathway_tensors`` is a list with one float tensor per entry of
              ``pathway_list``. Each tensor has one row per sample; its columns
              are the stacked per-row expression values.
            - ``drug_tensor`` is ``(batch, n_fingerprint_columns)``.
            - ``concentration_tensor`` and ``viability`` are ``(batch, 1)``.
        """
        # Promote a scalar position to a one-element list so .iloc returns a
        # DataFrame. On a scalar it returns a Series, and the column lookups
        # below would yield plain values without `.values`.
        if np.ndim(idx) == 0:
            idx = [idx]
        rows = self.df.iloc[idx]

        expression = self.geneexpression_df_by_cellline.loc[rows['CELLNAME'].values]
        pathway_tensors = [
            # Stack the per-row expression vectors of this pathway into one 2-D array.
            torch.tensor(np.array([vec for vec in expression[pathway].values]), dtype=torch.float)
            for pathway in self.pathway_list
        ]
        drug_tensor = torch.tensor(self.fingerprint_df.loc[rows['NSC'].values].values, dtype=torch.float)
        concentration = torch.tensor(rows[['CONCENTRATION']].values, dtype=torch.float)
        viability = torch.tensor(rows[['VIABILITY']].values, dtype=torch.float)
        return [pathway_tensors, drug_tensor, concentration], viability

In [3]:
def MonotherapyDataset2device(sample, device):
    """Move one ``[data, viability]`` sample onto ``device``.

    ``data`` is ``[geneset_tensors, fingerprint_tensor, concentration_tensor]``
    and ``geneset_tensors`` is a list with one tensor per pathway. The lists
    inside ``sample`` are updated in place, and the same ``data`` list is
    returned together with the moved ``viability`` tensor.

    Args:
        sample: ``[data, viability]`` as produced by ``MonotherapyDataset``.
        device: anything accepted by ``torch.device`` (e.g. ``'cpu'``, ``'cuda:0'``).

    Returns:
        ``(data, viability)`` with every tensor on ``device``.
    """
    target = torch.device(device)
    data, viability = sample
    gene_sets = data[0]
    # Slice assignment keeps the very same list object (in-place update).
    gene_sets[:] = [tensor.to(target) for tensor in gene_sets]
    data[1], data[2] = data[1].to(target), data[2].to(target)
    return data, viability.to(target)

In [8]:
def batch_dot(tensor1, tensor2, batch_size=1024):
    """Row-wise dot product of two broadcast-compatible tensors, as a column.

    For two ``(batch, n)`` tensors the result is ``(batch, 1)``, with row ``i``
    equal to ``tensor1[i] @ tensor2[i]``.

    Args:
        tensor1, tensor2: tensors whose product is summed over the last dim.
        batch_size: unused; kept only for call compatibility.
    """
    products = torch.mul(tensor1, tensor2)
    return products.sum(dim=-1).reshape(-1, 1)

In [19]:
class MonotherapyModel(nn.Module):
    """Predicts viability from pathway gene expression, a drug fingerprint and a dose.

    The network produces four per-sample curve parameters (y_max, y_min, slope,
    IC50). It then evaluates
    ``y_min + (y_max - y_min) * sigmoid(-slope * (dose - IC50))`` at ``dose``.

    Args:
        GeneSet: dict mapping pathway name -> member genes of that pathway. The
            dict's key order fixes the order in which ``forward`` expects the
            per-pathway expression tensors.
    """
    def __init__(self,GeneSet):
        super(MonotherapyModel, self).__init__()

        #GeneSet: Dictionary whose key is pathway name and its values are the member gene for each pathway (key)
        self.GeneSet=GeneSet
        self.pathway_list=list(GeneSet.keys())
        self.num_pathway=len(self.pathway_list)
        # NOTE: the order in which sub-modules are created below fixes both the
        # state_dict layout and the order of random weight initialisation, so
        # reordering it breaks saved checkpoints and seeded reproducibility.

        drug_modules = []
        #1st drug layer (input width 512 = expected drug fingerprint length)
        drug_modules.append(nn.Linear(512, 256))
        drug_modules.append(nn.BatchNorm1d(256))
        drug_modules.append(nn.ReLU())
        #2nd drug layer
        drug_modules.append(nn.Linear(256, 128))
        drug_modules.append(nn.BatchNorm1d(128))
        drug_modules.append(nn.ReLU())
        #create drug block: fingerprint -> 128-d embedding, fed to the final concatenation
        self.drug_block = nn.Sequential(*drug_modules)


        drug_modules_new = []
        #new 1st layer
        drug_modules_new.append(nn.Linear(512, 128))
        drug_modules_new.append(nn.BatchNorm1d(128))
        drug_modules_new.append(nn.ReLU())
        #new 2nd drug layer
        drug_modules_new.append(nn.Linear(128, 32))
        drug_modules_new.append(nn.BatchNorm1d(32))
        drug_modules_new.append(nn.ReLU())
        #create new drug block: fingerprint -> 32-d embedding that conditions the attention layers
        self.new_drug_block = nn.Sequential(*drug_modules_new)


        #operations like mul, add, subtract, dot only happen in forward, so the per-pathway
        #sub-models are saved in ModuleDicts (registered as submodules) and combined there
        #gene set model
        # NOTE(review): torch rejects submodule names containing '.', so pathway
        # names used as ModuleDict keys presumably must be dot-free - confirm.
        self.drug_gene_set_blocks = nn.ModuleDict()
        self.gene_attention_blocks = nn.ModuleDict()
        self.gene_dot_blocks = nn.ModuleDict()
        for pathway in self.pathway_list:
            #input_size = number of member genes per pathway
            input_size = int(len(self.GeneSet[pathway]))
            #width of the drug embedding that conditions this pathway's attention
            drug_for_pathway_size = int(input_size/4)+1

            drug_gene_set_modules = []
            #add layers to module: 32-d drug embedding -> drug_for_pathway_size
            drug_gene_set_modules.append(nn.Linear(32, drug_for_pathway_size))
            drug_gene_set_modules.append(nn.BatchNorm1d(drug_for_pathway_size))
            drug_gene_set_modules.append(nn.ReLU())
            drug_gene_set_block = nn.Sequential(*drug_gene_set_modules)
            #store geneset sub block
            self.drug_gene_set_blocks[pathway]=drug_gene_set_block

            gene_attention_modules = []
            #add layers to module: [genes, drug embedding] -> one weight per gene, softmax-normalised over genes
            gene_attention_modules.append(nn.Linear(input_size + drug_for_pathway_size, input_size))
            gene_attention_modules.append(nn.BatchNorm1d(input_size))
            gene_attention_modules.append(nn.Tanh())
            gene_attention_modules.append(nn.Softmax(dim=1))
            gene_attention_block = nn.Sequential(*gene_attention_modules)
            #store gene attention sub block
            self.gene_attention_blocks[pathway]=gene_attention_block

            gene_dot_modules = []
            #add layers to module -> dot product will have size of 1
            gene_dot_modules.append(nn.BatchNorm1d(1))
            gene_dot_modules.append(nn.ReLU())
            gene_dot_block = nn.Sequential(*gene_dot_modules)
            #save the post-dot-product sub block
            self.gene_dot_blocks[pathway]=gene_dot_block

        #width of the drug embedding that conditions the pathway-level attention
        drug_for_pathway_size=int(self.num_pathway/16+1)
        #drug dense sample layers
        drug_dense_sample_modules = []
        drug_dense_sample_modules.append(nn.Linear(32, drug_for_pathway_size))
        drug_dense_sample_modules.append(nn.BatchNorm1d(drug_for_pathway_size))
        drug_dense_sample_modules.append(nn.ReLU())
        #drug dense sample block
        self.drug_dense_sample_block = nn.Sequential(*drug_dense_sample_modules)

        #sample attention layers: [pathway scores, drug embedding] -> one softmax weight per pathway
        sample_attention_modules = []
        sample_attention_modules.append(nn.Linear(self.num_pathway + drug_for_pathway_size,self.num_pathway))
        sample_attention_modules.append(nn.BatchNorm1d(self.num_pathway))
        sample_attention_modules.append(nn.Tanh())
        sample_attention_modules.append(nn.Softmax(dim=1))
        #sample attention block
        self.sample_attention_block = nn.Sequential(*sample_attention_modules)

        #sample multiplied layers
        # NOTE(review): sample_multiplied_block is registered (so it appears in the
        # state_dict and parameters()) but forward() never calls it. Confirm whether
        # it was meant to be applied to `sample_multiplied`. Enabling it would change
        # the outputs of already-trained checkpoints.
        sample_multiplied_modules = []
        sample_multiplied_modules.append(nn.BatchNorm1d(self.num_pathway))
        sample_multiplied_modules.append(nn.ReLU())
        self.sample_multiplied_block = nn.Sequential(*sample_multiplied_modules)

        #concatenated model layers: [attended pathway scores, 128-d drug embedding] -> 128
        concatenated_modules = []
        concatenated_modules.append(nn.Linear(128+self.num_pathway, 128))
        concatenated_modules.append(nn.BatchNorm1d(128))
        concatenated_modules.append(nn.ReLU())
        self.concatenated_block = nn.Sequential(*concatenated_modules)

        final_modules = []
        #1st final layer
        final_modules.append(nn.Linear(128, 32))
        final_modules.append(nn.BatchNorm1d(32))
        final_modules.append(nn.ReLU())
        #2nd final layer
        final_modules.append(nn.Linear(32, 8))
        final_modules.append(nn.BatchNorm1d(8))
        final_modules.append(nn.ReLU())
        #3rd final layer
        final_modules.append(nn.Linear(8, 2))
        final_modules.append(nn.BatchNorm1d(2))
        final_modules.append(nn.ReLU())
        #create final block
        self.final_block = nn.Sequential(*final_modules)

        #curve parameter layers: four linear heads on the shared 2-d embedding
        self.final_y_max = nn.Linear(2, 1)

        self.final_y_min = nn.Linear(2, 1)
        self.final_slope = nn.Linear(2, 1)
        self.final_IC50 = nn.Linear(2, 1)

    def forward(self, input_feature):
        """Predict viability for a batch.

        Args:
            input_feature: ``[gene_expression_list, drug_fp, dose]``.
                ``gene_expression_list[i]`` is a 2-D ``(batch, n_genes)`` tensor
                for ``self.pathway_list[i]``, where ``n_genes`` must equal
                ``len(GeneSet[pathway])``. ``drug_fp`` is ``(batch, 512)``.
                ``dose`` presumably is ``(batch, 1)``, matching the
                ``(batch, 1)`` curve parameters it is broadcast against.

        Returns:
            Predicted viability, broadcast from the ``(batch, 1)`` curve
            parameters and ``dose``.
        """
        #input_feature = [gene_set_list, drug_fp, dose]
        #-gene_expression_list: list of gene expression tensors, one per pathway, in pathway_list order
        gene_expression_list=input_feature[0]
        #-drug_fp: Morgan fingerprint of input drug
        drug_fp=input_feature[1]
        #-dose: dosage information
        dose=input_feature[2]
        drug_embed = self.drug_block(drug_fp)
        new_drug_embed = self.new_drug_block(drug_fp)
        #gene set model: one (batch, 1) score per pathway
        attention_dots = []

        for idx,pathway in enumerate(self.pathway_list):

            #gene set calculation
            gene_expression=gene_expression_list[idx]
            drug_gene_set_model=self.drug_gene_set_blocks[pathway]
            gene_attention_model=self.gene_attention_blocks[pathway]
            gene_dot_model=self.gene_dot_blocks[pathway]

            drug_gene_set_embed = drug_gene_set_model(new_drug_embed)
            #drug feature to gene attention
            gene_concat = torch.cat((gene_expression, drug_gene_set_embed), dim=1)

            #gene attention calculation: per-gene weights summing to 1 over genes
            gene_attention = gene_attention_model(gene_concat)

            #attention-weighted sum of the pathway's expression -> (batch, 1)
            attention_dot = batch_dot(gene_expression, gene_attention)
            gene_dot = gene_dot_model(attention_dot)

            #gene set model
            attention_dots.append(gene_dot)

        drug_dense_embed = self.drug_dense_sample_block(new_drug_embed)
        #copy so appending the drug embedding does not alter attention_dots
        drug_effected_model_for_attention = attention_dots.copy()
        drug_effected_model_for_attention.append(drug_dense_embed)

        #(batch, num_pathway) pathway scores
        gene_set_concat = torch.cat(attention_dots,dim=1)
        #(batch, num_pathway + drug_for_pathway_size) input for the pathway attention
        drug_effected_concat = torch.cat(drug_effected_model_for_attention,dim=1)

        sample_attention = self.sample_attention_block(drug_effected_concat)

        #re-weight pathway scores by the drug-conditioned attention
        sample_multiplied = torch.mul(gene_set_concat, sample_attention)

        total_concat = torch.cat([sample_multiplied, drug_embed],dim=1)

        concat_embed = self.concatenated_block(total_concat)

        final_embed = self.final_block(concat_embed)

        final_y_max = self.final_y_max(final_embed)
        final_y_min = self.final_y_min(final_embed)
        final_slope = self.final_slope(final_embed)
        final_ic50 = self.final_IC50(final_embed)

        #final calculations: y_min + (y_max - y_min) * sigmoid(-slope * (dose - ic50))
        final_1 = torch.sub(dose, final_ic50)
        final_2 = torch.mul(final_slope, final_1)
        final_neg = torch.mul(final_2, -1)
        final_sigmoid = torch.sigmoid(final_neg)
        final_scale = torch.sub(final_y_max, final_y_min)
        final_3 = torch.mul(final_scale, final_sigmoid)
        final_4 = torch.add(final_3, final_y_min)
        return final_4

In [None]:
def train_mono(dataloader, model, loss_fn, optimizer):
    """Train `model` for one pass over `dataloader`, printing progress about 10 times.

    Args:
        dataloader: sized iterable of ``[data, viability]`` samples in the
            MonotherapyDataset layout.
        model: module called as ``model(X)``; switched to train mode here.
        loss_fn: callable ``loss_fn(pred, y) -> (loss, mse, corr)``. ``loss`` is
            back-propagated; ``mse`` and ``corr`` are only printed.
        optimizer: optimizer over the model's parameters; stepped once per batch.

    Uses the notebook globals ``device``, ``MonotherapyDataset2device``,
    ``timestamp2datetime`` and ``get_tensor_value`` (the last two presumably
    come from 00_Utils).
    """
    model.train()
    num_batches = len(dataloader)
    # Report on roughly every tenth batch. Clamped to 1: with fewer than 10
    # batches int(num_batches / 10) is 0 and `batch % 0` raises ZeroDivisionError.
    report_every = max(1, num_batches // 10)
    time_list = []
    print('Training function called at ' + str(timestamp2datetime(time.time())))
    current_loss = 0
    for batch, sample in enumerate(dataloader):
        batch_start = time.time()
        report = batch % report_every == 0
        if report:
            print('==========Current batch is ' + str(batch) + '==========')
        X, y = MonotherapyDataset2device(sample, device)
        pred = model(X)
        loss, mse, corr = loss_fn(pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        time_list.append(time.time() - batch_start)
        # Accumulate a detached copy. Adding the live `loss` would link every
        # batch's autograd history into current_loss for the whole epoch.
        current_loss += loss.detach()
        if report:
            print('Currently, each batch takes ' + str(np.around(np.mean(time_list), 1)) + ' seconds')
            print('rmse: ' + str(get_tensor_value(mse) ** 0.5) + '  pcc: ' + str(get_tensor_value(corr)) + '   average loss: ' + str(get_tensor_value(current_loss / (batch + 1))))
    # Skip the first (warm-up) batch in the average, unless it is the only one.
    print('Time per batch: ' + str(np.mean(time_list[1:] or time_list)) + ' seconds')

def test_mono(dataloader, model, loss_fn, num_initial_run=1):
    """Evaluate `model` on `dataloader` and print loss, PCC, RMSE and R^2.

    Args:
        dataloader: sized iterable of ``[data, viability]`` samples.
        model: module called as ``model(X)``; switched to eval mode here.
        loss_fn: callable ``loss_fn(pred, y) -> (loss, mse, corr)``. Only
            ``loss`` is used, averaged over batches.
        num_initial_run: unused; kept for call compatibility.

    Returns:
        ``(test_loss, pcc, rmse)``. ``test_loss`` is the mean per-batch loss;
        ``pcc`` and ``rmse`` are computed over all samples pooled together.

    Uses the notebook globals ``device``, ``MonotherapyDataset2device``,
    ``get_tensor_value``, ``mean_squared_error`` and ``r2_score``. The last
    three presumably come from 00_Utils / scikit-learn.
    """
    model.eval()
    num_batches = len(dataloader)
    test_loss = 0
    real_list = []
    predicted_list = []
    with torch.no_grad():
        for sample in dataloader:
            X, y = MonotherapyDataset2device(sample, device)
            real_list.append(y)
            pred = model(X)
            predicted_list.append(pred)
            loss, mse, corr = loss_fn(pred, y)
            test_loss += loss
    test_loss /= num_batches
    real_concat = get_tensor_value(torch.cat(real_list).view(-1))
    predicted_concat = get_tensor_value(torch.cat(predicted_list).view(-1))

    # `scipy.stats` is imported by this notebook; the bare name `stats` is not.
    pcc = scipy.stats.pearsonr(predicted_concat, real_concat)[0]
    rmse = mean_squared_error(real_concat, predicted_concat) ** 0.5
    # r2_score(y_true, y_pred) is not symmetric: observed values must come first.
    r2 = r2_score(real_concat, predicted_concat)

    print('========Test result========')
    print('----Average Total Loss: ' + str(get_tensor_value(test_loss)))
    print('----PCC ' + str(pcc))
    print('----RMSE: ' + str(rmse))
    print('----R2: ' + str(r2))
    return test_loss, pcc, rmse

def predict_mono(dataloader, model, num_initial_run=1):
    """Run `model` in eval mode over `dataloader` and collect targets and predictions.

    Any iterable of samples is accepted; it does not need to support ``len()``.

    Args:
        dataloader: iterable of ``[data, viability]`` samples (MonotherapyDataset layout).
        model: module called as ``model(X)``; switched to eval mode here.
        num_initial_run: unused; kept for call compatibility.

    Returns:
        ``(real, predicted)``: two flattened 1-D tensors in dataloader order,
        on the device the samples were moved to (the notebook global ``device``).
    """
    model.eval()
    targets, outputs = [], []
    with torch.no_grad():
        for sample in dataloader:
            features, target = MonotherapyDataset2device(sample, device)
            targets.append(target)
            outputs.append(model(features))
    return torch.cat(targets).view(-1), torch.cat(outputs).view(-1)
