In [None]:
%run 00_Utils.ipynb
%run 00_MonotherapyUtils.ipynb

In [None]:
import torch
import random
import datetime
import time
import matplotlib.pyplot as plt
import pickle

import pandas as pd
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

import csv

import torch.multiprocessing as mp

from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader, BatchSampler, RandomSampler,random_split
from torch.multiprocessing import Pool, Process, set_start_method

from tqdm import tqdm

import scipy.stats

import os.path


In [None]:
class CombinationDataset(Dataset):
    """Batch-indexed dataset of drug-pair (combination) experiments.

    Each row of ``dataframe`` describes one combination experiment and is read
    through the columns NSC1, NSC2, CONCENTRATION1, CONCENTRATION2, CELLNAME
    and VIABILITY. Per-drug features come from two lookup tables:

    * ``pathway_attention_df`` - indexed by (drug, cell line);
    * ``viability_df`` - indexed by (drug, concentration, cell line).

    ``__getitem__`` expects ``idx`` to select several rows at once (e.g. a list
    of positions produced by a ``BatchSampler``). With a scalar position
    ``iloc`` returns a Series, which ``pd.MultiIndex.from_frame`` rejects.
    """

    def __init__(self, dataframe, pathway_attention_df, viability_df):
        self.df, self.pathway_attention_df, self.viability_df = (
            dataframe,
            pathway_attention_df,
            viability_df,
        )

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``([attention1, attention2, viability1, viability2], viability)`` for the selected rows."""
        batch = self.df.iloc[idx]

        def attention_for(drug_col):
            # Look up one attention vector per row, keyed by (drug, cell line).
            keys = pd.MultiIndex.from_frame(batch[[drug_col, 'CELLNAME']])
            rows = self.pathway_attention_df.loc[keys].values
            return torch.Tensor(np.array(list(rows)))

        def viability_for(drug_col, conc_col):
            # Monotherapy viability keyed by (drug, concentration, cell line).
            keys = pd.MultiIndex.from_frame(batch[[drug_col, conc_col, 'CELLNAME']])
            return torch.Tensor(self.viability_df.loc[keys].values).type(torch.float)

        features = [
            attention_for('NSC1'),
            attention_for('NSC2'),
            viability_for('NSC1', 'CONCENTRATION1'),
            viability_for('NSC2', 'CONCENTRATION2'),
        ]
        target = torch.Tensor(batch[['VIABILITY']].values).type(torch.float)
        return features, target

In [None]:
def CombinationDataset2device(sample, device):
    """Move a ``CombinationDataset`` sample onto ``device``.

    Args:
        sample: pair ``(data, viability)`` where ``data`` holds exactly four
            tensors ``[pathway_attention1, pathway_attention2, viability1, viability2]``.
        device: any target accepted by ``Tensor.to``.

    Returns:
        ``(moved_data, moved_viability)``; ``moved_data`` is a new list of the
        four moved tensors, in the same order.
    """
    features, target = sample
    # Explicit 4-way unpack so a malformed sample fails loudly (ValueError).
    attention1, attention2, mono_viability1, mono_viability2 = features
    moved = [t.to(device) for t in (attention1, attention2, mono_viability1, mono_viability2)]
    return moved, target.to(device)

In [None]:
class CombinationTherapyModel(nn.Module):
    """Predict drug-combination viability from per-drug pathway attention and monotherapy viability.

    The prediction is ``1 - (monotherapy_effect + synergy_effect)`` where

    * ``monotherapy_effect`` mixes the two single-drug efficacies
      (efficacy = 1 - viability) with softmax weights derived from each drug's
      pathway attention re-weighted by an embedding of its partner's attention;
    * ``synergy_effect`` is an MLP over both drugs' efficacy-scaled pathway attention.

    Args:
        num_pathway: width of each pathway-attention vector (default 186).
        hidden_dims: hidden-layer widths of the monotherapy and synergy MLP
            heads (default ``(64, 16, 4)``). With the defaults the submodule layout,
            and hence ``state_dict`` keys, match earlier checkpoints.
    """

    def __init__(self, num_pathway=186, hidden_dims=(64, 16, 4)):
        super().__init__()

        self.num_pathway = num_pathway

        # Pathway activity embedding network: produces a softmax weighting over
        # pathways that is applied to the partner drug's attention in forward().
        self.embedding_network = nn.Sequential(
            nn.Linear(num_pathway, num_pathway),
            nn.BatchNorm1d(num_pathway),
            nn.ReLU(),
            nn.Tanh(),
            nn.Softmax(dim=1),
        )

        # Monotherapy coefficient network: one scalar per drug, shared weights for both drugs.
        self.monotherapy_network = self._build_head(num_pathway, hidden_dims)
        self.monotherapy_coefficient_network = nn.Softmax(dim=1)

        # Synergy network over both drugs' (efficacy-scaled) attention vectors.
        self.synergy_network = self._build_head(2 * num_pathway, hidden_dims)

        # Not used by forward(); kept so existing references to these attributes
        # keep working. They are parameter-free, so they add no state_dict entries.
        self.efficacy_relu = nn.ReLU()
        self.viability_relu = nn.ReLU()

    @staticmethod
    def _build_head(in_features, hidden_dims):
        """Build ``[Linear -> BatchNorm1d -> ReLU] * len(hidden_dims) -> Linear(., 1)``."""
        layers = []
        width = in_features
        for hidden in hidden_dims:
            layers += [nn.Linear(width, hidden), nn.BatchNorm1d(hidden), nn.ReLU()]
            width = hidden
        layers.append(nn.Linear(width, 1))
        return nn.Sequential(*layers)

    def forward(self, input_feature):
        """Run the model on one batch.

        Args:
            input_feature: indexable whose first four items are
                ``[pathway_attention1, pathway_attention2, viability1, viability2]``.
                The attentions have ``num_pathway`` columns. The viabilities must be
                column tensors ``(batch, 1)``: they are concatenated along dim 1 and
                multiplied with the ``(batch, 1)`` monotherapy-head outputs.

        Returns:
            ``(coefficient_mono, synergy_effect, combination_viability)``:
            softmax mixing weights ``(batch, 2)``, the synergy term and the
            predicted combination viability.
        """
        pathway_attention1 = input_feature[0]
        pathway_attention2 = input_feature[1]
        viability1 = input_feature[2]
        viability2 = input_feature[3]

        # Efficacy of each drug = 1 - viability.
        efficacy1 = 1 - viability1
        efficacy2 = 1 - viability2

        # Cross-drug re-weighting. The call order (drug 2 first) is kept because in
        # train mode each call updates the shared BatchNorm running statistics.
        embedding2to1 = self.embedding_network(pathway_attention2)
        embedding1to2 = self.embedding_network(pathway_attention1)
        pathway1with2 = pathway_attention1 * embedding2to1
        pathway2with1 = pathway_attention2 * embedding1to2

        processed_pathway1 = self.monotherapy_network(pathway1with2)
        processed_pathway2 = self.monotherapy_network(pathway2with1)

        processed_pathways = torch.cat([processed_pathway1, processed_pathway2], dim=1)
        efficacies = torch.cat([efficacy1, efficacy2], dim=1)
        coefficient_mono = self.monotherapy_coefficient_network(processed_pathways * efficacies)
        # batch_dot comes from the %run utility notebooks; presumably a row-wise dot
        # product returning (batch, 1) so it adds cleanly to synergy_effect - confirm.
        monotherapy_effect = batch_dot(coefficient_mono, efficacies)

        pathway_concat = torch.cat(
            [pathway_attention1 * efficacy1, pathway_attention2 * efficacy2], dim=1
        )
        synergy_effect = self.synergy_network(pathway_concat)

        combination_efficacy = monotherapy_effect + synergy_effect
        combination_viability = 1 - combination_efficacy

        return coefficient_mono, synergy_effect, combination_viability

In [None]:
def train_comb(dataloader, model, loss_fn, optimizer, device=None):
    """Run one training epoch of a combination-therapy model.

    Args:
        dataloader: iterable of ``(data, viability)`` samples in the layout
            accepted by ``CombinationDataset2device``.
        model: module whose forward returns ``(coefficient_mono, synergy, prediction)``.
        loss_fn: callable ``loss_fn(pred, target) -> (loss, mse, corr)``; only
            ``loss`` is back-propagated.
        optimizer: optimizer over ``model``'s parameters.
        device: where each batch is moved. Defaults to the device of the model's
            first parameter instead of an implicit notebook-global ``device``.

    Returns:
        Mean per-batch loss as a Python float (``nan`` if the loader is empty).
        Leaves ``model`` in train mode.
    """
    if device is None:
        device = next(model.parameters()).device
    model.train()
    total_loss = 0.0
    num_batches = 0
    for sample in dataloader:
        X, y = CombinationDataset2device(sample, device)
        _, _, pred = model(X)
        loss, _, _ = loss_fn(pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # .item() detaches: summing the loss tensor itself chains every batch's
        # autograd graph together and keeps it alive for the whole epoch.
        total_loss += loss.item()
        num_batches += 1
    return total_loss / num_batches if num_batches else float('nan')

def test_comb(dataloader, model, loss_fn):
    model.train()
    X,y = CombinationDataset2device(sample,device)
    tmp1,tmp2,tmp3 = model(X)
    model.eval()
    num_batches = len(dataloader)
    test_loss = 0
    real_list=[]
    predicted_list=[]
    with torch.no_grad():
        for sample in dataloader:
            X,y=CombinationDataset2device(sample,device)            
            real_list.append(y)
            coefficient_mono, gamma, pred = model(X)
            predicted_list.append(pred)
            loss,mse,corr=loss_fn(pred,y)
            test_loss += loss
    test_loss /= num_batches
    real_concat=get_tensor_value(torch.cat(real_list).view(-1))
    predicted_concat=get_tensor_value(torch.cat(predicted_list).view(-1))
    
    pcc=stats.pearsonr(predicted_concat,real_concat)[0]
    rmse=mean_squared_error(predicted_concat,real_concat)**0.5
    r2=r2_score(predicted_concat,real_concat)
    
    print('========Test result========')
    print('----Average Total Loss: '+ str(get_tensor_value(test_loss)))
    print('----PCC '+ str(pcc))
    print('----RMSE: '+ str(rmse))
    print('----R2: '+str(r2))
    return test_loss,pcc,rmse

def predict_comb(dataloader, model, device=None):
    """Collect observed and predicted combination viabilities over a data loader.

    Runs ``model`` in eval mode without gradients and leaves it in eval mode.

    Args:
        dataloader: iterable of ``(data, viability)`` samples in the layout
            accepted by ``CombinationDataset2device``.
        model: module whose forward returns ``(coefficient_mono, synergy, prediction)``.
        device: where each batch is moved. Defaults to the device of the model's
            first parameter instead of an implicit notebook-global ``device``.

    Returns:
        ``(real_concat, predicted_concat)``: flattened 1-D tensors of observed and
        predicted viability, in loader order.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    if device is None:
        device = next(model.parameters()).device
    # No train-mode "warm-up" forward here: the earlier version referenced the loop
    # variable `sample` before the loop (UnboundLocalError) and, had it run, would
    # have updated BatchNorm running statistics with evaluation data.
    model.eval()
    real_list = []
    predicted_list = []
    with torch.no_grad():
        for sample in dataloader:
            X, y = CombinationDataset2device(sample, device)
            _, _, pred = model(X)
            real_list.append(y)
            predicted_list.append(pred)
    if not real_list:
        raise ValueError('predict_comb: dataloader yielded no batches')
    real_concat = torch.cat(real_list).view(-1)
    predicted_concat = torch.cat(predicted_list).view(-1)

    return real_concat, predicted_concat