In [7]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import sklearn
import matplotlib.pyplot as plt
from datetime import datetime
from torch.nn.functional import mse_loss
from torch.utils.data import Dataset, DataLoader
from torchsummary import summary
from sklearn.model_selection import train_test_split
from utils_pp import replace_cell_names_with_id
from utils_pp import Encoder
from utils_pp import EarlyStopper
from utils_pp import Dataset_from_pd
from utils_pp import train_one_epoch
from utils_pp import AE_DNN
from torch.utils.tensorboard import SummaryWriter
%load_ext autoreload
%autoreload 2

if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

df_train, df_val, df_test

In [9]:
columns = ["cell_line", "drugA_name", "drugB_name", "drugA_conc", "drugB_conc", "target"]
data_train = pd.read_csv("../data_raw/oneil.csv", usecols=(1,2,3,4,5,12)).iloc[:,[0,1,3,2,4,5]].set_axis(columns, axis=1)
data_test = pd.read_csv("../data/test_yosua.csv").set_axis(columns + ["std"], axis=1).convert_dtypes()
data_train = replace_cell_names_with_id(dataframe=data_train, mapping_file="../data/mappingccl.csv")
data_test = replace_cell_names_with_id(dataframe=data_test, mapping_file="../data/mappingccl.csv")
drug_data = pd.read_pickle("../data/drug_data.pkl.compress", compression="gzip")
cell_data = pd.read_pickle("../data/cell_line_data.pkl.compress", compression="gzip")
data_train = data_train[data_train.cell_line.isin(cell_data.index)]

df_train, df_val = train_test_split(data_train, test_size=0.2, shuffle=True, random_state=42)
df_test = data_test

cell_data = cell_data[cell_data.index.isin(pd.concat([df_train.cell_line, df_test.cell_line]))]
drug_data = drug_data[drug_data.index.isin(pd.concat([df_train.drugA_name, df_train.drugB_name,df_test.drugA_name, df_test.drugB_name]))]
print("oneil", df_train.memory_usage().sum()/1e6, df_train.shape,"\n", df_train.dtypes)
print("drug_feat", drug_data.memory_usage().sum()/1e6, drug_data.shape)
print("cell_feat", cell_data.memory_usage().sum()/1e6, cell_data.shape)
DRUG_LENGTH = drug_data.shape[1]
CELL_LENGTH = cell_data.shape[1]
EMBED_SIZE = 770

oneil 13.13536 (234560, 6) 
 cell_line      object
drugA_name     object
drugB_name     object
drugA_conc    float64
drugB_conc    float64
target        float64
dtype: object
drug_feat 0.427236 (42, 2412)
cell_feat 0.686136 (32, 5011)


Loading model

In [28]:
model = AE_DNN([770,256,256,256, 128,128,128,64,64,64], drug_length=DRUG_LENGTH, cell_length=CELL_LENGTH)
model.load_state_dict(torch.load("../models/ae_dnn_model.pt"))

<All keys matched successfully>

In [None]:
def training(data_train, data_val, L, n_epochs=100):
    batch_size = 1024
    train_set  = Dataset_from_pd(data_train, drug_data, cell_data)
    val_set = Dataset_from_pd(data_val, drug_data, cell_data)
    test_set  = Dataset_from_pd(df_test, drug_data, cell_data)
    train_dl = DataLoader(train_set, batch_size=batch_size)
    xi, yi = next(iter(train_dl))
    val_dl = DataLoader(val_set, batch_size=batch_size)
    test_dl = DataLoader(test_set, batch_size=batch_size)

    model = AE_DNN([770,256,256,256, 128,128,128,64,64,64], DRUG_LENGTH, CELL_LENGTH)
    model.drug_encoder = Encoder(model.drug_encoder.h_sizes)
    model.drug_encoder.encoder.load_state_dict(torch.load("../models/drug_encoder.pt"))
    model.cell_encoder = Encoder(model.cell_encoder.h_sizes)
    model.cell_encoder.encoder.load_state_dict(torch.load("../models/cell_encoder.pt"))
    model.drug_encoder.eval()
    model.cell_encoder.eval()
    # print(summary(model.to("cuda"), 770))
    optimizer = torch.optim.Adam(model.hidden.parameters(), lr=1e-3)
    loss_fn = torch.nn.MSELoss()
    mae_fn = torch.nn.L1Loss()
    # Initializing in a separate cell so we can easily add more epochs to the same run
    # timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    # writer = SummaryWriter('runs/fashion_trainer_{}'.format(timestamp))
    early_stopper = EarlyStopper(patience=10)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, factor=0.5, verbose=True, patience=5, min_lr=1e-7)

    epoch_number = 0
    EPOCHS = n_epochs

    for epoch in range(EPOCHS):

        # Make sure gradient tracking is on, and do a pass over the data
        model.hidden.train(True)
        model = model.to(device=device)
        avg_loss = train_one_epoch(model, epoch_number, "writer", train_dl, optimizer, loss_fn, device, 0)
        # Set the model to evaluation mode, disabling dropout and using population
        # statistics for batch normalization.
        model.eval()

        running_vloss = 0.0
        running_MAE = 0.
        running_PCC = 0.
        # Disable gradient computation and reduce memory consumption.
        with torch.no_grad():
            for i, vdata in enumerate(val_dl):

                vinputs, vlabels = vdata
                vinputs = vinputs.to(device)
                vlabels = vlabels.to(device)
                voutputs = model(vinputs)
                vloss = loss_fn(voutputs, vlabels)
                running_vloss += vloss.item()

                vx = voutputs - torch.mean(voutputs)
                vy = vlabels - torch.mean(vlabels)
                cost = torch.sum(vx * vy) / (torch.sqrt(torch.sum(vx ** 2)) * torch.sqrt(torch.sum(vy ** 2)))
                running_PCC += cost.item()

                mae_loss = mae_fn(voutputs, vlabels)
                running_MAE += mae_loss.item()

        avg_vMAE = running_MAE/(i+1)
        avg_vPCC = running_PCC/(i+1)
        avg_vloss = running_vloss / (i + 1)
        print('epoch {} mse {:.{round}f} vmse {:.{round}f} vmae {:.{round}f} vpcc {:.{round}f} '.format(epoch_number+1, avg_loss, avg_vloss, avg_vMAE, avg_vPCC, round=4))
        scheduler.step(avg_loss)

        if early_stopper.early_stop(avg_vloss):             
            break
        epoch_number += 1
    
    model.eval()
    with torch.no_grad():
        for i, data in enumerate(test_dl):
            inputs, labels = data
            inputs = inputs.to(device=device)
            labels = labels.to(device=device)
            outputs = model(inputs)
            loss_test = loss_fn(outputs, labels).item()
            mae_test = mae_fn(outputs, labels).item()
            vx = voutputs - torch.mean(outputs)
            vy = vlabels - torch.mean(labels)
            pcc_test = torch.sum(vx * vy) / (torch.sqrt(torch.sum(vx ** 2)) * torch.sqrt(torch.sum(vy ** 2))).item()
    print([avg_vloss, avg_vMAE, avg_vPCC, loss_test, mae_test, pcc_test])
    L.append([avg_vloss, avg_vMAE, avg_vPCC, loss_test, mae_test, pcc_test, outputs.to("cpu").numpy().reshape(-1)])
    return model


LPO/LTO cv

In [None]:
unique_pairs = data_train.loc[:,["drugA_name","drugB_name"]].drop_duplicates()
unique_triplets = data_train.loc[:,["cell_line", "drugA_name","drugB_name"]].drop_duplicates()

from sklearn.model_selection import KFold
kf = KFold(n_splits=5, shuffle=True)
batch_size = 2048

L_leave_triplet = []
for train_index, val_index in kf.split(unique_triplets):
    train_unique_triplets = unique_triplets.iloc[train_index]
    val_unique_triplets =unique_triplets.iloc[val_index]
    combined_train = train_unique_triplets.loc[:,"cell_line"].str.cat([train_unique_triplets.loc[:,"drugA_name"], train_unique_triplets.loc[:,"drugB_name"]], sep= " + ")
    combined_val = val_unique_triplets.loc[:,"cell_line"].str.cat([val_unique_triplets.loc[:,"drugA_name"], val_unique_triplets.loc[:,"drugB_name"]], sep= " + ")
    
    data_train = df_train[data_train.loc[:,"cell_line"].str.cat([data_train.loc[:,"drugA_name"],data_train.loc[:,"drugB_name"]],sep=" + ").isin(combined_train)]
    data_val = df_train[data_train.loc[:,"cell_line"].str.cat([data_train.loc[:,"drugA_name"],data_train.loc[:,"drugB_name"]],sep=" + ").isin(combined_val)]
    print(data_train.shape, data_val.shape)
    training(data_train, data_val, L_leave_triplet)
L_leave_pair = []
for train_index, val_index in kf.split(unique_pairs):
    train_unique_pairs = unique_pairs.iloc[train_index]
    val_unique_pairs =unique_pairs.iloc[val_index]
    combined_train = train_unique_pairs.loc[:,"drugA_name"].str.cat(train_unique_pairs.loc[:,"drugB_name"], sep= " + ")
    combined_val = val_unique_pairs.loc[:,"drugA_name"].str.cat(val_unique_pairs.loc[:,"drugB_name"], sep= " + ")
    
    data_train = df_train[data_train.loc[:,"drugA_name"].str.cat(data_train.loc[:,"drugB_name"],sep=" + ").isin(combined_train)]
    data_val = df_train[data_train.loc[:,"drugA_name"].str.cat(data_train.loc[:,"drugB_name"],sep=" + ").isin(combined_val)]
    print(data_train.shape, data_val.shape)
    training(data_train, data_val, L_leave_pair)

L_leave_pair_records = np.empty((5,4))
for i in range(5):
    for j in range(4): L_leave_pair_records[i,j]=L_leave_pair[i][j]
np.save("Leave_pair_records", L_leave_pair_records)
new_array = np.load("Leave_pair_records.npy")
print(new_array)
L_leave_triplet_records = np.empty((5,4))
for i in range(5):
    for j in range(4): L_leave_triplet_records[i,j]=L_leave_triplet[i][j]
np.save("Leave_triplet_records", L_leave_triplet_records)
new_array = np.load("Leave_triplet_records.npy")
print(new_array)

test_pair_outputs = np.empty((5,24))
for i in range(5):
    for j in range(24): test_pair_outputs[i,j]=L_leave_pair[i][6][j]
np.save("test_pair_outputs", test_pair_outputs)
new_array = np.load("test_pair_outputs.npy")
print(new_array)
test_triplet_outputs = np.empty((5,24))
for i in range(5):
    for j in range(24): test_triplet_outputs[i,j]=L_leave_triplet[i][6][j]
np.save("test_triplet_outputs", test_triplet_outputs)
new_array = np.load("test_triplet_outputs.npy")
print(new_array)


FULL TRAINING LPO

In [67]:
L_leave_pair=[]
unique_pairs = data_train.loc[:,["drugA_name","drugB_name"]].drop_duplicates()
train_unique_pairs, val_unique_pairs = train_test_split(unique_pairs, test_size=0.2, random_state=42)
combined_train = train_unique_pairs.loc[:,"drugA_name"].str.cat(train_unique_pairs.loc[:,"drugB_name"], sep= " + ")
combined_val = val_unique_pairs.loc[:,"drugA_name"].str.cat(val_unique_pairs.loc[:,"drugB_name"], sep= " + ")

df_train = data_train[data_train.loc[:,"drugA_name"].str.cat(data_train.loc[:,"drugB_name"],sep=" + ").isin(combined_train)]
df_val = data_train[data_train.loc[:,"drugA_name"].str.cat(data_train.loc[:,"drugB_name"],sep=" + ").isin(combined_val)]
print(df_train.shape, df_val.shape)
training(df_train,  df_val, L_leave_pair, 200)

(168344, 6) (42732, 6)



epoch 1 mse 0.0892 vmse 0.0981 vmae 0.2629 vpcc 0.4269 
epoch 2 mse 0.0730 vmse 0.0783 vmae 0.2332 vpcc 0.5460 
epoch 3 mse 0.0636 vmse 0.0740 vmae 0.2235 vpcc 0.6018 
epoch 4 mse 0.0596 vmse 0.0686 vmae 0.2129 vpcc 0.6218 
epoch 5 mse 0.0574 vmse 0.0618 vmae 0.1998 vpcc 0.6490 
epoch 6 mse 0.0540 vmse 0.0596 vmae 0.1968 vpcc 0.6612 
epoch 7 mse 0.0527 vmse 0.0583 vmae 0.1930 vpcc 0.6731 
epoch 8 mse 0.0522 vmse 0.0575 vmae 0.1893 vpcc 0.6717 
epoch 9 mse 0.0503 vmse 0.0558 vmae 0.1860 vpcc 0.6760 
epoch 10 mse 0.0488 vmse 0.0547 vmae 0.1833 vpcc 0.6833 
epoch 11 mse 0.0488 vmse 0.0535 vmae 0.1808 vpcc 0.6895 
epoch 12 mse 0.0470 vmse 0.0530 vmae 0.1802 vpcc 0.6943 
epoch 13 mse 0.0460 vmse 0.0517 vmae 0.1767 vpcc 0.7023 
epoch 14 mse 0.0465 vmse 0.0522 vmae 0.1757 vpcc 0.7024 
epoch 15 mse 0.0461 vmse 0.0510 vmae 0.1734 vpcc 0.7045 
epoch 16 mse 0.0451 vmse 0.0498 vmae 0.1698 vpcc 0.7149 
epoch 17 mse 0.0442 vmse 0.0494 vmae 0.1694 vpcc 0.7151 
epoch 18 mse 0

AE_DNN(
  (drug_encoder): Encoder(
    (encoder): Sequential(
      (0): Linear(in_features=2412, out_features=512, bias=True)
      (1): ReLU()
      (2): Dropout(p=0.2, inplace=False)
      (3): Linear(in_features=512, out_features=512, bias=True)
      (4): ReLU()
      (5): Dropout(p=0.2, inplace=False)
      (6): Linear(in_features=512, out_features=256, bias=True)
    )
  )
  (cell_encoder): Encoder(
    (encoder): Sequential(
      (0): Linear(in_features=5011, out_features=512, bias=True)
      (1): ReLU()
      (2): Dropout(p=0.2, inplace=False)
      (3): Linear(in_features=512, out_features=512, bias=True)
      (4): ReLU()
      (5): Dropout(p=0.2, inplace=False)
      (6): Linear(in_features=512, out_features=256, bias=True)
    )
  )
  (hidden): ModuleList(
    (0): Linear(in_features=770, out_features=256, bias=True)
    (1): Dropout(p=0.1, inplace=False)
    (2): ReLU()
    (3): Linear(in_features=256, out_features=256, bias=True)
    (4): Dropout(p=0.1, inplace=False)


FULL TRAINING LTO

FULL training

In [66]:
L_full_train_ep = []
df_train, df_val = train_test_split(data_train, test_size=0.2, shuffle=True, random_state=42)

model = training(df_train, df_val, L_full_train_ep, 100)




epoch 1 mse 0.0887 vmse 0.0884 vmae 0.2510 vpcc 0.4438 


KeyboardInterrupt: 

FULL training without validation

Saving model

In [18]:
from datetime import datetime
time = str(datetime.now()).replace(" ", "_").replace(":","_")
time = "model_finetune"
path  ="../models/ae_dnn_{}.pt".format(time)
torch.save(model.state_dict(), path)

Fine-tuning with testing data

In [33]:
def training_test(model,data_train, n_epochs=100, patience=10):
    batch_size = 1024
    test_set  = Dataset_from_pd(data_train, drug_data, cell_data)
    test_dl = DataLoader(test_set, batch_size=batch_size)    
    print(df_test)
    optimizer = torch.optim.Adam(model.hidden.parameters(), lr=1e-3)
    loss_fn = torch.nn.MSELoss()
    mae_fn = torch.nn.L1Loss()
    # Initializing in a separate cell so we can easily add more epochs to the same run
    # timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    # writer = SummaryWriter('runs/fashion_trainer_{}'.format(timestamp))
    early_stopper = EarlyStopper(patience=patience)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, factor=0.5, verbose=True, patience=patience//2, min_lr=1e-7)

    epoch_number = 0
    EPOCHS = n_epochs

    for epoch in range(EPOCHS):

        # Make sure gradient tracking is on, and do a pass over the data
        model.hidden.train(True)
        model = model.to(device=device)
        avg_loss = train_one_epoch(model, epoch_number, "writer", test_dl, optimizer, loss_fn, device, 0, print_every=1)
        # Set the model to evaluation mode, disabling dropout and using population
        # statistics for batch normalization.
        model.eval()

        running_vloss = 0.0
        running_MAE = 0.
        running_PCC = 0.
        # Disable gradient computation and reduce memory consumption.
        with torch.no_grad():
            for i, vdata in enumerate(test_dl):

                vinputs, vlabels = vdata
                vinputs = vinputs.to(device)
                vlabels = vlabels.to(device)
                voutputs = model(vinputs)
                vloss = loss_fn(voutputs, vlabels)
                running_vloss += vloss.item()

                vx = voutputs - torch.mean(voutputs)
                vy = vlabels - torch.mean(vlabels)
                cost = torch.sum(vx * vy) / (torch.sqrt(torch.sum(vx ** 2)) * torch.sqrt(torch.sum(vy ** 2)))
                running_PCC += cost.item()

                mae_loss = mae_fn(voutputs, vlabels)
                running_MAE += mae_loss.item()

        avg_vMAE = running_MAE/(i+1)
        avg_vPCC = running_PCC/(i+1)
        avg_vloss = running_vloss / (i + 1)
        print('epoch {} mse {:.{round}f} vmse {:.{round}f} vmae {:.{round}f} vpcc {:.{round}f} '.format(epoch_number+1, avg_loss, avg_vloss, avg_vMAE, avg_vPCC, round=4))
        scheduler.step(avg_loss)

        if early_stopper.early_stop(avg_vloss):             
            break
        epoch_number += 1

training_test(model, df_test,n_epochs=300, patience=20)


     cell_line   drugA_name    drugB_name  drugA_conc  drugB_conc    target  \
0   ACH-000768       GW9662      PD168393       2.975      4.0675   1.02498   
1   ACH-000768       GW9662      PD168393        5.95       8.135  0.788251   
2   ACH-000768       GW9662      PD168393       8.925     12.2025  0.711825   
3   ACH-000768       GW9662      PD168393        11.9       16.27  0.542289   
4   ACH-000768       GW9662  Rocilinostat       2.975       4.395  0.970456   
5   ACH-000768       GW9662  Rocilinostat        5.95        8.79   0.79969   
6   ACH-000768       GW9662  Rocilinostat       8.925      13.185  0.523822   
7   ACH-000768       GW9662  Rocilinostat        11.9       17.58  0.211632   
8   ACH-000768       GW9662   Saracatinib       2.975      0.0125  1.951693   
9   ACH-000768       GW9662   Saracatinib        5.95       0.025  1.492152   
10  ACH-000768       GW9662   Saracatinib       8.925      0.0375  1.485964   
11  ACH-000768       GW9662   Saracatinib        11