In [1]:
import numpy as np
import pandas as pd
import random
import learn2learn as l2l
import copy
from copy import deepcopy
from matplotlib import pyplot
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from matplotlib import pyplot
import pickle as pkl
from sklearn.metrics import roc_curve,auc
import matplotlib.pyplot as plt
import pickle
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import DataLoader, Dataset, random_split, ConcatDataset
from torch.autograd import Variable
from sklearn.model_selection import KFold
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn import metrics 
from sklearn.metrics import roc_auc_score, roc_curve,precision_recall_curve, auc
from sklearn.metrics import f1_score

## Loading processed data

In [2]:
#cell line level data
df =  pd.read_pickle("E:/data/cl_label_data.pickle") 
drug_features =  pd.read_pickle('E:/data/drug_feature.pickle')
cell_features =  pd.read_pickle('E:/data/cellline_feature.pickle')
cuda=True
device = torch.device('cuda'if torch.cuda.is_available() else "cpu")

#patient level data

## Model

In [3]:
#Two layers of fully connected layers
class FC2(nn.Module):
    def __init__(self, in_features, out_features, dropout):
        super(FC2, self).__init__()
        
        self.fc1 = nn.Linear(in_features, int(in_features/2))
        self.fc2 = nn.Linear(int(in_features/2),out_features)
        self.dropout= nn.Dropout(dropout)
                
    def forward(self, x):
        x = self.dropout(x)
        x = self.dropout(F.relu(self.fc1(x)))
        x = self.fc2(x)
        
        return x
    
#Classification predictor
class COMBFC2(nn.Module):
    def __init__(self, in_features, out_features, dropout):
        super(COMBFC2, self).__init__()
        
        self.fc1 = nn.Linear(in_features, int(in_features/2))
        self.fc2 = nn.Linear(int(in_features/2), int(in_features/2))
        self.fc3= nn.Linear(int(in_features/2),out_features)
        self.dropout= nn.Dropout(dropout)
        self.sigmoid= nn.Sigmoid()
                
    def forward(self, x):

        x = self.dropout(F.relu(self.fc1(x)))
        x = self.dropout(self.fc2(x))
        x = self.fc3(x)
        x = self.sigmoid(x)
        
        return x

#DrugEncoder
class DrugEncoder(nn.Module):
    def __init__(self,
                drug_sm_conv = 96,
                 num_comp_char=48,
                 out_size=64,
                 dropout=0.3):
        super(DrugEncoder, self).__init__()
        
        self.dropout= dropout
        
        #smiles embedding 
        self.embed_comp = nn.Embedding(num_comp_char,num_comp_char, padding_idx=0)#padding's idx=0,
        #smiles 卷积
        self.conv1 = nn.Conv1d(in_channels=48,out_channels=32,kernel_size=4) ##stride=1
        self.conv2 = nn.Conv1d(in_channels=32,out_channels=64,kernel_size=6)
        self.conv3 = nn.Conv1d(in_channels=64,out_channels=96,kernel_size=8)
        self.maxpool = nn.AdaptiveMaxPool1d(1) 
 
     
        self.FC2 = FC2(96, out_size, dropout)
        
    def forward(self, d_list):
        """
            id: bsz*1
            fp: bsz*num_drug_fp
            sm: bsz*drug_sm_len
        """
        id, sm = d_list
        
        sm = self.embed_comp(sm) #becomes 2-dimensional after embedding
#         sm = sm.reshape(bsz,drug_sm_len,num_comp_char) 
        sm = sm.permute(0,2,1) #Dimension conversion (batch size,in_channels: dimension of word vector, data length: sentence length)
        sm = F.relu(self.conv1(sm))
        sm = F.relu(self.conv2(sm))
        sm = F.relu(self.conv3(sm))
        sm = self.maxpool(sm)
        sm = sm.squeeze(-1) 
        
        
        x = self.FC2(sm)
        
        return x
        
#CellEncoder      
class CellEncoder(nn.Module):
    def __init__(self,
               out_size=64,
               dropout=0.3):
        super(CellEncoder,self).__init__()
        
        self.dropout = dropout
        

        self.fc1 = nn.Linear(14753,2950)
        self.fc2 = nn.Linear(2950,590)
        self.fc3 = nn.Linear(590,64)
        self.dropout= nn.Dropout(dropout)
        
        
        
    def forward(self,c_list):
        id,ge  = c_list
        
        ge = self.dropout(ge)
        ge = self.dropout(F.relu(self.fc1(ge)))
        ge = self.dropout(F.relu(self.fc2(ge)))
        x = self.fc3(ge)

        return x

class Comb(nn.Module):
    def __init__(self,
              out_size = 64,
              dropout = 0.3):
        super(Comb, self).__init__()
        
        self.dropout = dropout
        #drug 
        self.DrugEncoder = DrugEncoder()
        #cell
        self.CellEncoder = CellEncoder()
        #fc
        self.fc_response = COMBFC2(out_size*3, 1, dropout) 
        
    def forward(self,d1_list,d2_list,c_list):
        d1 = self.DrugEncoder(d1_list)
        d2 = self.DrugEncoder(d2_list)
        c = self.CellEncoder(c_list)
        alll = torch.cat((d1, d2, c),1)
        y = self.fc_response(alll)
        
        return y
    

## Dataset

In [4]:
##取样本及其对应特征
class DrugCombDataset(Dataset):
    def __init__(self, df, drug_features, cell_features):
        self.df = df
        self.drug_features = drug_features
        self.cell_features = cell_features
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        
        d1 = self.df.iloc[idx, 0]
        d2 = self.df.iloc[idx, 1]
        cell = self.df.iloc[idx,2]
        label = self.df.iloc[idx,3]
       
        
        #external feature
        d1_sm = np.array(self.drug_features.loc[d1,'smiles'])

        
        d2_sm = np.array(self.drug_features.loc[d2,'smiles'])

        
        c_ge= np.array(self.cell_features.iloc[cell][:])  ##表达谱信息放置
        
        sample = {
            'd1': d1,
            'd1_sm': d1_sm,

            
            'd2': d2,
            'd2_sm': d2_sm,

            
            'cell': cell,
            'c_ge': c_ge,
            
#             'css': css,
            'label':label
        }
        
        return sample

In [6]:
singledrug = df['cell_line_name'].value_counts().reset_index()
singledrug.columns = ['cell_line_name','frequency']
singledrug_s = singledrug[singledrug['frequency'] > 50]
singledrug_t = singledrug[(singledrug['frequency'] < 50)&(singledrug['frequency'] > 10) ]
source_data = df[df['cell_line_name'].isin(singledrug_s['cell_line_name'])]
target_data = df[df['cell_line_name'].isin(singledrug_t['cell_line_name'])]
train_tasks = []
cell_lines = source_data['cell_line_name'].unique()
for cell_line in cell_lines:
    task = source_data[source_data['cell_line_name'] == cell_line]
    train_tasks.append(task)
test_tasks = []
cell_lines = target_data['cell_line_name'].unique()
for cell_line in cell_lines:
    task = target_data[target_data['cell_line_name'] == cell_line]
    test_tasks.append(task)

In [7]:
def label_sampling(df,k):
    df_positive = df[df['label'] == 1]
    df_negative = df[df['label'] == 0]

    num_positive = df_positive.shape[0]
    num_negative = df_negative.shape[0]

    num_positive_sample = min(num_positive, k)
    num_negative_sample = min(num_negative, k)

    if num_positive_sample < k:
        num_negative_sample = min(num_negative, 2*k - num_positive_sample)
    elif num_negative_sample < k:
        num_positive_sample = min(num_positive, 2*k - num_negative_sample)

    sample_positive = df_positive.sample(n=num_positive_sample)
    sample_negative = df_negative.sample(n=num_negative_sample)

    sample_df = pd.concat([sample_positive, sample_negative])
    
    return sample_df

## Meta-training

#### 源域的MAML

In [13]:
net_meta = Comb().to(device)
meta_lr = 1e-3
fast_lr = 0.1
criterion = nn.BCELoss()
maml = l2l.algorithms.MAML(net_meta, lr=fast_lr,allow_unused=True) 
meta_optimizer = optim.Adam(maml.parameters(), lr=meta_lr)

In [23]:
for counter in range(300):
    t0 = time.time()
    total_loss = 0.0
    all_train_task = []
    for i in range(0,len(train_tasks)):
        cell_task = train_tasks[i]
        cell_task  =label_sampling(cell_task,25)
        pos = cell_task[cell_task['label'] == 1]
        neg = cell_task[cell_task['label'] == 0]
        pos_train, pos_test = train_test_split(pos, test_size=0.5)
        neg_test, neg_train = train_test_split(neg, test_size=0.5) ##Balanced positive and negative sample numbers
        train = pd.concat([pos_train,neg_train],axis = 0)
        test = pd.concat([pos_test,neg_test],axis = 0)
        temp_dict={'train':train ,
               'test':test,
              }
        all_train_task.append(temp_dict)
        
    for j,d in enumerate(all_train_task):
        net_copy = maml.clone()
        support_set = d['train']
        querry_set = d['test']
        supportdata = DrugCombDataset(support_set,drug_features,cell_features)
        querrydata = DrugCombDataset(querry_set,drug_features,cell_features)
        support_loader = DataLoader(supportdata, batch_size=25, shuffle=True)
        querry_loader = DataLoader(querrydata, batch_size=25, shuffle=True)
        
        ##Gradient descent of support sets
        for _ in range(1):
            for iteration, sample in enumerate(support_loader):
                support_d1=Variable(sample['d1'])
                support_d1_sm = Variable(sample['d1_sm'])

        
                support_d2=Variable(sample['d2'])
                support_d2_sm = Variable(sample['d2_sm'])


        
                support_cell = Variable(sample['cell'])
                support_c_ge = Variable(sample['c_ge'].float())
        
                support_label = Variable(sample['label'].float())
        
        
                support_d1=support_d1.to(device)
                support_d1_sm = support_d1_sm.to(device)

         
                support_d2=support_d2.to(device)
                support_d2_sm = support_d2_sm.to(device)

                support_cell = support_cell.to(device)
                support_c_ge = support_c_ge.to(device)
        
                support_label = support_label.to(device)
                support_pred = net_meta((support_d1,support_d1_sm), (support_d2, support_d2_sm), (support_cell,support_c_ge))
                support_loss = criterion(support_pred, support_label.view(-1,1))
                support_pred = support_pred.cpu().detach().numpy()
                support_label = support_label.view(-1,1).cpu().detach().numpy()
    
                net_copy.adapt(support_loss)
                
        ##Accumulating loss of query sets    
        for iteration, sample in enumerate(querry_loader):
            querry_d1=Variable(sample['d1'])
            querry_d1_sm = Variable(sample['d1_sm'])

        
            querry_d2=Variable(sample['d2'])
            querry_d2_sm = Variable(sample['d2_sm'])


        
            querry_cell = Variable(sample['cell'])
            querry_c_ge = Variable(sample['c_ge'].float())
        
            querry_label = Variable(sample['label'].float())
        
        
            querry_d1=querry_d1.to(device)
            querry_d1_sm = querry_d1_sm.to(device)

         
            querry_d2=querry_d2.to(device)
            querry_d2_sm = querry_d2_sm.to(device)


        
            querry_cell = querry_cell.to(device)
            querry_c_ge = querry_c_ge.to(device)
        
            querry_label = querry_label.to(device)
            querry_pred = net_meta((querry_d1,querry_d1_sm), (querry_d2, querry_d2_sm), (querry_cell,querry_c_ge))
            querry_loss = criterion(querry_pred, querry_label.view(-1,1))
            querry_pred = querry_pred.cpu().detach().numpy()
            querry_label = querry_label.view(-1,1).cpu().detach().numpy()
            total_loss += querry_loss
                
    meta_loss = querry_loss
    
    print(f"Patch: {counter+1}, Meta Train Loss: {meta_loss.item()}")
    meta_optimizer.zero_grad()
    total_loss.backward(retain_graph=True)
    meta_optimizer.step()

    
torch.save(net_meta.state_dict(), "E:/result/meta_model.pth")

Patch: 1, Meta Train Loss: 0.787327766418457
Patch: 2, Meta Train Loss: 0.76658695936203
Patch: 3, Meta Train Loss: 0.7010809779167175
Patch: 4, Meta Train Loss: 0.6850055456161499
Patch: 5, Meta Train Loss: 0.6884840726852417
Patch: 6, Meta Train Loss: 0.6691299676895142
Patch: 7, Meta Train Loss: 0.7063966393470764
Patch: 8, Meta Train Loss: 0.7058179378509521
Patch: 9, Meta Train Loss: 0.7011981010437012
Patch: 10, Meta Train Loss: 0.6909726858139038
Patch: 11, Meta Train Loss: 0.6830936670303345
Patch: 12, Meta Train Loss: 0.6997944116592407
Patch: 13, Meta Train Loss: 0.702325165271759
Patch: 14, Meta Train Loss: 0.7120388746261597
Patch: 15, Meta Train Loss: 0.6953700184822083
Patch: 16, Meta Train Loss: 0.6870925426483154
Patch: 17, Meta Train Loss: 0.6963696479797363
Patch: 18, Meta Train Loss: 0.7059096097946167
Patch: 19, Meta Train Loss: 0.6915417313575745
Patch: 20, Meta Train Loss: 0.6957017183303833
Patch: 21, Meta Train Loss: 0.6971877813339233
Patch: 22, Meta Train Loss

Patch: 175, Meta Train Loss: 0.5885035991668701
Patch: 176, Meta Train Loss: 0.6503459811210632
Patch: 177, Meta Train Loss: 0.6192144155502319
Patch: 178, Meta Train Loss: 0.5800887942314148
Patch: 179, Meta Train Loss: 0.7001382112503052
Patch: 180, Meta Train Loss: 0.7011526823043823
Patch: 181, Meta Train Loss: 0.6814320087432861
Patch: 182, Meta Train Loss: 0.5755913257598877
Patch: 183, Meta Train Loss: 0.5411161184310913
Patch: 184, Meta Train Loss: 0.7367717623710632
Patch: 185, Meta Train Loss: 0.6010003685951233
Patch: 186, Meta Train Loss: 0.6404076218605042
Patch: 187, Meta Train Loss: 0.6350972652435303
Patch: 188, Meta Train Loss: 0.5611374378204346
Patch: 189, Meta Train Loss: 0.5453050136566162
Patch: 190, Meta Train Loss: 0.6219230890274048
Patch: 191, Meta Train Loss: 0.7567082643508911
Patch: 192, Meta Train Loss: 0.5734137892723083
Patch: 193, Meta Train Loss: 0.4775654971599579
Patch: 194, Meta Train Loss: 0.5325498580932617
Patch: 195, Meta Train Loss: 0.585676431

## Meta-testing

### challeng 1

In [7]:
#Divide tasks according to cell line  types
test_tasks = []
cell_lines = target_data['cell_line_name'].unique()
for cell_line in cell_lines:
    task = target_data[target_data['cell_line_name'] == cell_line]
    test_tasks.append(task)

In [112]:
preds = []
labels = []
all_roc = []
all_pr = []
for i in range (0,100):
    all_test_task = []
    for t in range(0,len(test_tasks)):
        cell_task = test_tasks[t]
        train, test = train_test_split(cell_task, test_size=0.5)
        temp_dict={'train':train,
               'test':test,
              }
        all_test_task.append(temp_dict)
    for j,d in enumerate(all_test_task):
        net_meta = Comb().to(device)
        net_meta.load_state_dict(torch.load("E:/result/meta_model.pth"))
        criterion = nn.BCELoss()
        optimizer = optim.Adam(net_meta.parameters(), lr=0.001)
        support_set = d['train']
        querry_set = d['test']
        supportdata = DrugCombDataset(support_set,drug_features,cell_features)
        querrydata = DrugCombDataset(querry_set,drug_features,cell_features)
        support_loader = DataLoader(supportdata, batch_size=len(support_set), shuffle=True)
        querry_loader = DataLoader(querrydata, batch_size=len(querry_set), shuffle=True)
        
    #fine-tuning meta model
        net_meta.train()
        for epoch in range(3):  # 3 gradient descent of support sets
            for iteration, sample in enumerate(support_loader):
                support_d1=Variable(sample['d1'])
                support_d1_sm = Variable(sample['d1_sm'])


        
                support_d2=Variable(sample['d2'])
                support_d2_sm = Variable(sample['d2_sm'])
     

        
                support_cell = Variable(sample['cell'])
                support_c_ge = Variable(sample['c_ge'].float())
        
                support_label = Variable(sample['label'].float())
        
        
                support_d1=support_d1.to(device)
                support_d1_sm = support_d1_sm.to(device)

         
                support_d2=support_d2.to(device)
                support_d2_sm = support_d2_sm.to(device)

        
                support_cell = support_cell.to(device)
                support_c_ge = support_c_ge.to(device)
        
                support_label = support_label.to(device)
                support_pred = net_meta((support_d1,support_d1_sm), (support_d2, support_d2_sm), (support_cell,support_c_ge))
                support_loss = criterion(support_pred, support_label.view(-1,1))
                support_loss.backward()
                optimizer.step()
                
     # model-predicting
        net_meta.eval()
        with torch.no_grad():
            for iteration, sample in enumerate(querry_loader):
                querry_d1=Variable(sample['d1'])
                querry_d1_sm = Variable(sample['d1_sm'])

                querry_d2=Variable(sample['d2'])
                querry_d2_sm = Variable(sample['d2_sm'])

        
                querry_cell = Variable(sample['cell'])
                querry_c_ge = Variable(sample['c_ge'].float())
        
                querry_label = Variable(sample['label'].float())
        
        
                querry_d1=querry_d1.to(device)
                querry_d1_sm = querry_d1_sm.to(device)

                querry_d2=querry_d2.to(device)
                querry_d2_sm = querry_d2_sm.to(device)

                querry_cell = querry_cell.to(device)
                querry_c_ge = querry_c_ge.to(device)
        
                querry_label = querry_label.to(device)
                querry_pred = net_meta((querry_d1,querry_d1_sm), (querry_d2, querry_d2_sm), (querry_cell,querry_c_ge))
                preds.extend(querry_pred.cpu().detach().numpy())
                labels.extend(querry_label.cpu().detach().numpy())
                
                
                
    fpr, tpr, thersholds = roc_curve(labels, preds)
    roc_auc = auc(fpr, tpr)
    precision, recall, _ = precision_recall_curve(labels, preds)
    aupr = auc(recall, precision)



    all_roc.append(roc_auc)
    all_pr.append(aupr)
    
mean_roc = np.mean(all_roc)
mean_pr = np.mean(all_pr)


print('MEAN auroc of the model on all the task data: %f' % mean_roc)   
print('MEAN auor of the model on all the task data: %f' % mean_pr)  

MEAN auroc of the model on all the task data: 0.750177
MEAN auor of the model on all the task data: 0.723584
