##### This notebook creates a Transfer Learning (TL) experiment: a base model is trained on RTECS and then transferred to predict on ClinTox

Using split data already saved.

The RTECS dataset cannot be provided, so its file paths have been removed (see the `# cannot provide` placeholders).

Notebook shows results for seed = 124, but we also ran on seed 122, 123. 

To create the TL model, the base model was trained as an MTDNN with two shared layers followed by two separate layers for each task. The two shared layers were then extracted and frozen, and two additional layers were added and trained on ClinTox for 1 epoch.

Before use, define the desired paths for saving models:
- path variable, in "Create checkpoint" section for models
- writer variable, in "Train the neural network model" section for tensorboard summary
- base_modelpath variable, in "Save base model" section 

In [None]:
# general and data handling
import numpy as np
import pandas as pd
import os
from collections import Counter

# Required RDKit modules
import rdkit as rd
from rdkit import DataStructs
from rdkit.Chem import AllChem

# modeling
import sklearn as sk
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score

# Graphing
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
import time
import random
import joblib

In [None]:
import torch
device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  

In [None]:
# Load the TensorBoard notebook extension
%load_ext tensorboard

In [None]:
# To ensure runs on GPU
import tensorflow as tf
import datetime, os

##### Settings

In [None]:
# Reproducibility: seed every random number generator used in this notebook.
seed_value = 124  # 122, 123 and 124 were run, as used in MoleculeNet

# Python and NumPy generators
random.seed(seed_value)
np.random.seed(seed_value)

# PyTorch generators (CPU and CUDA)
torch.manual_seed(seed_value)
torch.cuda.manual_seed(seed_value)

# cuDNN: request deterministic kernels and switch the backend off
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.enabled = False

In [None]:
# number of bits in each Morgan fingerprint (passed as nBits when the fingerprints are built)
morgan_bits = 4096

In [None]:
# radius used for the Morgan fingerprints
morgan_radius = 2

In [None]:
# number of epochs for training the base model (reset to 1 before the transfer-learning stage)
train_epoch = 50

In [None]:
# mini-batch size handed to the training DataLoaders (RTECS and ClinTox)
batch = 512

In [None]:
# No limit on max columns displayed in a dataframe
pd.set_option('display.max_columns', None)

##### Load data

In [None]:
# Raw ClinTox table and the name of the toxicity task
clintox_task = ['CT_TOX']
clintox_file = '../../data/datasets/clintox/raw_data/clintox.csv'

# read the csv, report how many rows were loaded, and preview the first two
clintox_data = pd.read_csv(clintox_file)
print('Reading {}... {} data loaded.'.format(clintox_file, len(clintox_data)))
clintox_data.head(2)

In [None]:
# The RTECS file is not distributed with this notebook: set a_oral_file to its
# path before running this cell (the placeholder line below is not runnable as is).
# presumably the acute-oral toxicity subset, judging by the task name — verify
a_oral_file = # cannot provide 

a_oral_data = pd.read_csv(a_oral_file)
# label column(s) used as the base-model task(s)
a_oral_tasks = ['toxic_a_oral'] 
a_oral_data.head()

### <font color = "blue"> Create and train base model </font>

##### Setting all tasks 

In [None]:
# NOTE: `data` is rebound to the saved train/test/valid splits in the "Load split data" cell
data = [a_oral_data]

In [None]:
# label columns predicted by the base model; the MTDNN builds one task head per entry
all_tasks = a_oral_tasks 

##### Load split data

In [None]:
######## ONLY RTECS ########
# Load the previously saved RTECS train/test/valid splits.
# data_path must be set to the directory holding the three .pth files; the file
# names are appended directly, so it needs to end with a path separator.
# NOTE(review): torch.load unpickles these files — only load split files you trust.
data_path = # cannot provide 
train_data=torch.load(data_path + 'train_data_rtecs.pth')
test_data=torch.load(data_path + 'test_data_rtecs.pth')
valid_data=torch.load(data_path + 'valid_data_rtecs.pth')

# order used throughout: [train, test, valid]
data = [train_data, test_data, valid_data]

In [None]:
# Report the number of rows in each split and in the three splits combined.
split_sizes = [split.shape[0] for split in data[:3]]
for split_name, n_examples in zip(("train", "test", "valid"), split_sizes):
    print("Total number of examples, " + split_name + ": " + str(n_examples))
print("Total number of examples, train+test+valid: " + str(sum(split_sizes)))

##### Construct Morgan Fingerprints 

In [None]:
%%time
# Add an RDKit molecule, a Morgan bit-vector fingerprint and the fingerprint's
# bitInfo dict to every row of every split (columns 'mol', 'morgan', 'bitInfo').
for split in data:
    split['mol'] = [rd.Chem.MolFromSmiles(smiles) for smiles in split['smiles']]

    # one (initially empty) dict per molecule, passed to RDKit as bitInfo
    bit_infos = [{} for _ in range(len(split))]
    split['morgan'] = [
        AllChem.GetMorganFingerprintAsBitVect(mol, morgan_radius, nBits=morgan_bits, bitInfo=info)
        for mol, info in zip(split['mol'], bit_infos)
    ]
    split['bitInfo'] = bit_infos

##### Create train, valid and test set 

In [None]:
# re-collect the three splits (now carrying the fingerprint columns) as [train, test, valid]
data = [train_data, test_data, valid_data]

In [None]:
# Missing values (NA) become -1 in all three splits; labels are later masked
# with `label >= 0` before the Binary Cross-Entropy loss is computed.
for split_idx in range(3):
    data[split_idx] = data[split_idx].fillna(-1)

In [None]:
# name the NA-filled splits again: data is ordered [train, test, valid]
train_data, test_data, valid_data = data[0], data[1], data[2]

In [None]:
## Train-set arrays used for the DNN
# Copy every RDKit explicit bit vector into a numpy row.
train_rows = []
for bitvect in train_data['morgan']:
    row = np.zeros((1,))  # destination array that ConvertToNumpyArray writes the bits into
    DataStructs.ConvertToNumpyArray(bitvect, row)
    train_rows.append(row)

# stack the rows and subtract 0.5 from every entry
x_train = np.array(train_rows) - 0.5

# labels: one column per task
y_train = train_data[all_tasks].values

In [None]:
## Test-set arrays used for the DNN
# Copy every RDKit explicit bit vector into a numpy row.
test_rows = []
for bitvect in test_data['morgan']:
    row = np.zeros((1,))  # destination array that ConvertToNumpyArray writes the bits into
    DataStructs.ConvertToNumpyArray(bitvect, row)
    test_rows.append(row)

# stack the rows and subtract 0.5 from every entry
x_test = np.array(test_rows) - 0.5

# labels: one column per task
y_test = test_data[all_tasks].values

In [None]:
## Validation-set arrays used for the DNN
# Copy every RDKit explicit bit vector into a numpy row.
valid_rows = []
for bitvect in valid_data['morgan']:
    row = np.zeros((1,))  # destination array that ConvertToNumpyArray writes the bits into
    DataStructs.ConvertToNumpyArray(bitvect, row)
    valid_rows.append(row)

# stack the rows and subtract 0.5 from every entry
x_valid = np.array(valid_rows) - 0.5

# labels: one column per task
y_valid = valid_data[all_tasks].values

In [None]:
# Per task (column): how many examples carry a label, i.e. a value >= 0 (missing labels are -1)
N_train = (y_train >= 0).sum(axis=0)
N_test = (y_test >= 0).sum(axis=0)
N_valid = (y_valid >= 0).sum(axis=0)

#### Deep Neural Network (pytorch)

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader

In [None]:
# Cast features and labels of every split to float32 before handing them to PyTorch
x_train_torch, y_train_torch = x_train.astype(np.float32), y_train.astype(np.float32)
x_test_torch, y_test_torch = x_test.astype(np.float32), y_test.astype(np.float32)
x_valid_torch, y_valid_torch = x_valid.astype(np.float32), y_valid.astype(np.float32)

In [None]:
# number of input features per example; sizes the first Linear layer of the model
input_shape = x_train_torch.shape[1]
input_shape

In [None]:
# Dataset wrapper for the MTDNN inputs
class MTDNNData(Dataset):
    """Map-style dataset pairing a feature container with a label container.

    ``x`` and ``y`` are stored as given. Item ``idx`` is the tuple
    ``(x[idx], y[idx])`` and the dataset length is ``len(x)``.
    """

    def __init__(self, x, y):
        self.x, self.y = x, y

    def __len__(self):
        n_samples = len(self.x)
        return n_samples

    def __getitem__(self, idx):
        features = self.x[idx]
        labels = self.y[idx]
        return features, labels

In [None]:
# Wrap each split in a dataset ...
training_set = MTDNNData(x_train_torch, y_train_torch)
testing_set = MTDNNData(x_test_torch, y_test_torch)
valid_set = MTDNNData(x_valid_torch, y_valid_torch)

# ... and build the loaders: training batches are shuffled and of size `batch`;
# the test and validation sets are each served as one un-shuffled batch.
training_generator = DataLoader(training_set, shuffle=True, batch_size=batch)
testing_generator = DataLoader(testing_set, shuffle=False, batch_size=len(testing_set))
valid_generator = DataLoader(valid_set, shuffle=False, batch_size=len(valid_set))

In [None]:
### MTDNN model class
### 2 shared layers for all tasks, followed by 2 layers for each separate task (1 in RTECS)

class MTDNN(torch.nn.Module):
    """Multi-task DNN: two shared layers feeding one two-layer head per task.

    Every hidden Linear layer is followed by BatchNorm1d and LeakyReLU(0.05);
    each head ends in a single unit passed through a sigmoid.

    Args:
        input_shape: number of input features.
        all_tasks: iterable of tasks; one head is created per entry.
    """

    def __init__(self, input_shape, all_tasks):
        super().__init__()
        nn = torch.nn

        # shared layers: input -> 2048 -> 1024
        # (creation order and attribute names are kept: they fix the state_dict key order)
        self.shared_1 = nn.Linear(input_shape, 2048)
        self.batchnorm_1 = nn.BatchNorm1d(2048)

        self.shared_2 = nn.Linear(2048, 1024)
        self.batchnorm_2 = nn.BatchNorm1d(1024)

        # per-task heads: 1024 -> 512 -> 256 -> 1
        self.hidden_3 = nn.ModuleList([nn.Linear(1024, 512) for _ in all_tasks])
        self.batchnorm_3 = nn.ModuleList([nn.BatchNorm1d(512) for _ in all_tasks])

        self.hidden_4 = nn.ModuleList([nn.Linear(512, 256) for _ in all_tasks])
        self.batchnorm_4 = nn.ModuleList([nn.BatchNorm1d(256) for _ in all_tasks])

        self.output = nn.ModuleList([nn.Linear(256, 1) for _ in all_tasks])

        # activation shared by all hidden layers
        self.leakyReLU = nn.LeakyReLU(0.05)

    def forward(self, x):
        """Return a list holding one sigmoid output tensor per task head."""
        # shared layers
        trunk = self.leakyReLU(self.batchnorm_1(self.shared_1(x)))
        trunk = self.leakyReLU(self.batchnorm_2(self.shared_2(trunk)))

        # every head consumes the same shared activation
        y_pred = []
        heads = zip(self.hidden_3, self.batchnorm_3, self.hidden_4, self.batchnorm_4, self.output)
        for fc3, bn3, fc4, bn4, out_layer in heads:
            hidden = self.leakyReLU(bn3(fc3(trunk)))
            hidden = self.leakyReLU(bn4(fc4(hidden)))
            y_pred.append(torch.sigmoid(out_layer(hidden)))

        return y_pred
    
# instantiate the base model (one head per entry of all_tasks) and move it to `device`
model = MTDNN(input_shape, all_tasks).to(device)

###### Create checkpoint - saving and loading best model 

In [None]:
## Method from : https://gist.github.com/vsay01/45dfced69687077be53dbdd4987b6b17

import shutil
def save_ckp(state, is_best, checkpoint_path, best_model_path):
    """
    Write a checkpoint to disk and, when requested, duplicate it as a "best" model.

    state: checkpoint object to serialise with torch.save
    is_best: when truthy, the saved file is also copied to best_model_path
    checkpoint_path: file the checkpoint is written to
    best_model_path: file the checkpoint is copied to when is_best is truthy
    """
    torch.save(state, checkpoint_path)
    if not is_best:
        return
    # duplicate the file that was just written
    shutil.copyfile(checkpoint_path, best_model_path)

In [None]:
def load_ckp(checkpoint_fpath, input_model, optimizer):
    """
    Restore a model and an optimizer from a saved checkpoint.

    The checkpoint must be a dict with the keys 'state_dict', 'optimizer',
    'epoch' and 'train_loss_min'.

    checkpoint_fpath: path of the checkpoint file to load
    input_model: model that receives the checkpoint's 'state_dict' (updated in place)
    optimizer: optimizer that receives the checkpoint's 'optimizer' state (updated in place)

    Returns (input_model, optimizer, epoch, train_loss_min) with train_loss_min
    converted to a Python float.
    """
    # load check point
    checkpoint = torch.load(checkpoint_fpath)
    # copy the stored parameters into the model and the optimizer
    input_model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    # float() handles plain floats as well as numpy scalars and 0-d tensors,
    # whereas .item() only exists on the latter two
    train_loss_min = float(checkpoint['train_loss_min'])
    # return the model that was actually loaded (previously the global `model` was returned)
    return input_model, optimizer, checkpoint['epoch'], train_loss_min

In [None]:
# TODO: set `path` to the directory for the base-model checkpoints; it is created
# if missing and the checkpoint file names are appended to it with a '/'.
path = # define pathway

In [None]:
# Create the checkpoint directory (and any missing parents). exist_ok=True makes
# this a no-op when the directory already exists and avoids the check-then-create race.
os.makedirs(path, exist_ok=True)

In [None]:
###### Files the training loop writes inside `path`
# checkpoint written at the end of every epoch
checkpoint_path = f'{path}/current_checkpoint.pt'

# copy made when train_epoch_loss <= train_loss_min
bestmodel_path = f'{path}/best_model.pt'

# copy made at the minimum validation loss
bestmodel_byvalid = f'{path}/best_model_by_valid.pt'

# copy made when train_epoch_loss >= val_epoch_loss
bestmodel_byvalid_crossed = f'{path}/best_model_by_valid-crossed.pt'

##### Train the neural network model

In [None]:
# Loss: binary cross-entropy, applied to the sigmoid output of each task head
criterion = torch.nn.BCELoss()

# Optimizer: Adam over all parameters of the base model, learning rate 1e-3
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
from torch.utils.tensorboard import SummaryWriter

# TensorBoard summaries from the training loop are written through this writer;
# replace the placeholder below with the desired log directory.
writer = SummaryWriter('define-pathway/')

In [None]:
%%time
##################### With Tensorboard ######################
# Train the base model for `train_epoch` epochs. Each epoch:
#   1. one optimisation pass over training_generator,
#   2. one gradient-free pass over valid_generator,
#   3. epoch loss / accuracy are recorded, printed and sent to TensorBoard,
#   4. a checkpoint is written and copied to the "best" files when applicable.
loss_history = []
correct_history = []
val_loss_history = []
val_correct_history = []
train_loss_min = np.inf
val_loss_min = np.inf
global_step = 0  # optimisation steps so far; x-axis of the per-batch loss curve

for e in range(train_epoch):

    # ---------------- training pass ----------------
    model.train()
    running_train_loss = 0
    y_train_true = []
    y_train_pred = []
    # The counter is named batch_num (not `batch`) so that the batch-size setting
    # `batch` keeps its value for the DataLoaders built later in the notebook.
    batch_num = 0
    for x_batch, y_batch in training_generator:
        batch_num += 1
        global_step += 1
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)

        # Forward pass: the model returns one prediction tensor per task
        y_pred = model(x_batch)

        # Sum the loss over all tasks
        loss = 0
        for i in range(len(all_tasks)):
            y_batch_task = y_batch[:, i]
            y_pred_task = y_pred[i][:, 0]

            # only labelled examples contribute (missing labels are encoded as -1)
            indice_valid = y_batch_task >= 0
            loss_task = criterion(y_pred_task[indice_valid], y_batch_task[indice_valid]) / N_train[i]
            loss += loss_task

            # collect rounded predictions and targets for the epoch accuracy
            pred_train = np.round(y_pred_task[indice_valid].detach().cpu().numpy())
            target_train = y_batch_task[indice_valid].float()
            y_train_true.extend(target_train.tolist())
            y_train_pred.extend(pred_train.reshape(-1).tolist())

        # per-batch loss (previously logged under the tag "Accuracy/train")
        writer.add_scalar("Loss/train_batch", loss.item(), global_step)

        # Zero gradients, perform a backward pass, and update the weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # sum up the losses from each batch
        running_train_loss += loss.item()

    # ---------------- validation pass ----------------
    running_valid_loss = 0
    y_valid_true = []
    y_valid_pred = []
    model.eval()
    with torch.no_grad():
        for val_x_batch, val_y_batch in valid_generator:
            val_x_batch, val_y_batch = val_x_batch.to(device), val_y_batch.to(device)

            val_output = model(val_x_batch)

            # loss over all tasks
            val_loss = 0
            for i in range(len(all_tasks)):
                val_y_batch_task = val_y_batch[:, i]
                val_output_task = val_output[i][:, 0]

                # only labelled examples contribute
                indice_valid = val_y_batch_task >= 0
                val_loss_task = criterion(val_output_task[indice_valid], val_y_batch_task[indice_valid]) / N_valid[i]
                val_loss += val_loss_task

                pred_valid = np.round(val_output_task[indice_valid].detach().cpu().numpy())
                target_valid = val_y_batch_task[indice_valid].float()
                y_valid_true.extend(target_valid.tolist())
                y_valid_pred.extend(pred_valid.reshape(-1).tolist())

            running_valid_loss += val_loss.item()

    # ---------------- epoch metrics ----------------
    # The losses are sums over the batches of the epoch. np.mean of a scalar only
    # converts it to a numpy float, which keeps `.item()` available on the values
    # stored in the checkpoint.
    train_epoch_loss = np.mean(running_train_loss)
    val_epoch_loss = np.mean(running_valid_loss)

    train_epoch_acc = accuracy_score(y_train_true, y_train_pred)
    val_epoch_acc = accuracy_score(y_valid_true, y_valid_pred)

    # one TensorBoard point per epoch for each curve
    writer.add_scalar("Loss/train", train_epoch_loss, e)
    writer.add_scalar("Loss/valid", val_epoch_loss, e)
    writer.add_scalar("Accuracy/train", train_epoch_acc, e)
    writer.add_scalar("Accuracy/valid", val_epoch_acc, e)

    # history
    loss_history.append(train_epoch_loss)
    correct_history.append(train_epoch_acc)
    val_loss_history.append(val_epoch_loss)
    val_correct_history.append(val_epoch_acc)

    print("Epoch:", e, "Training Loss:", train_epoch_loss, "Valid Loss:", val_epoch_loss)
    print("Training Acc:", train_epoch_acc, "Valid Acc:", val_epoch_acc)

    # ---------------- checkpointing ----------------
    checkpoint = {
        'epoch': e + 1,
        'train_loss_min': train_epoch_loss,
        'val_loss_min': val_epoch_loss,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }

    # always keep the latest epoch
    save_ckp(checkpoint, False, checkpoint_path, bestmodel_path)

    # best model by training loss
    if train_epoch_loss <= train_loss_min:
        print('Training loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(train_loss_min,train_epoch_loss))
        save_ckp(checkpoint, True, checkpoint_path, bestmodel_path)
        train_loss_min = train_epoch_loss

    # latest epoch at which the training loss is not below the validation loss.
    # train_loss_min is deliberately left alone here: resetting it in this branch
    # could raise the tracked minimum and let a worse epoch overwrite best_model.pt.
    if train_epoch_loss >= val_epoch_loss:
        print('Training loss greater than validation loss ({:.6f} --> {:.6f}).  Saving model ...'.format(train_epoch_loss,val_epoch_loss))
        save_ckp(checkpoint, True, checkpoint_path, bestmodel_byvalid_crossed)

    # best model by validation loss
    if val_epoch_loss <= val_loss_min:
        print('Validation loss decreased ({:.6f} --> {:.6f}). Saving model ...'.format(val_loss_min,val_epoch_loss))
        save_ckp(checkpoint, True, checkpoint_path, bestmodel_byvalid)
        val_loss_min = val_epoch_loss

In [None]:
# Loads the checkpoint saved at the lowest validation loss.
# load_ckp copies the stored weights into `model` and the stored state into `optimizer` in place.
loaded_model, optimizer, start_epoch, train_loss_min = load_ckp(bestmodel_byvalid, model, optimizer)

In [None]:
# Show what was restored from the checkpoint
for label, restored in (
    ("model = ", loaded_model),
    ("optimizer = ", optimizer),
    ("start_epoch = ", start_epoch),
    ("train_loss_min = ", train_loss_min),
):
    print(label, restored)
# the minimum training loss once more, with six decimals
print("train_loss_min = {:.6f}".format(train_loss_min))

##### Evaluate on test set

In [None]:
# Compute and print the test loss with the weights currently held by `model`.
# The model is switched to eval mode and moved to the CPU (and stays there).
# torch.no_grad() keeps autograd from recording a graph over the whole test set,
# which the loader serves as a single batch.
# NOTE: the loop rebinds x_test_torch / y_test_torch to the tensors of that batch.
model.eval()
model.cpu()
with torch.no_grad():
    for x_test_torch, y_test_torch in testing_generator:
        y_test_pred = model(x_test_torch)

        # Compute loss over all tasks
        loss = 0
        for i in range(len(all_tasks)):
            y_test_task = y_test_torch[:,i]
            y_pred_task  = y_test_pred[i][:,0]

            # compute loss for labels that are not NA
            indice_valid = y_test_task >= 0
            loss_task = criterion(y_pred_task[indice_valid], y_test_task[indice_valid]) / N_test[i]

            loss += loss_task

print(loss.item())

In [None]:
results = {}
# Collects performance metrics for all tasks on test set in base model.
# Per task, results[task] = [auc, acc, bacc, tn, tp, pr, rc, f1] where tn / tp are
# row-normalised rates and pr / rc are precision / recall from raw counts.
for i in range(len(all_tasks)):

    # keep only the examples that carry a label for this task (missing labels are -1)
    valid_datapoints = y_test[:,i] >= 0
    y_test_task = y_test[valid_datapoints,i]
    y_test_pred_task = y_test_pred[i].detach().numpy()[valid_datapoints,0]
    # predicted probabilities rounded to hard 0/1 labels
    y_test_label_task = np.round(y_test_pred_task)

    acc = accuracy_score(y_test_task, y_test_label_task)
    print('Accuracy for MTDNN on Morgan Fingerprint:', acc)

    bacc = sk.metrics.balanced_accuracy_score(y_test_task, y_test_label_task)

    f1 = f1_score(y_test_task, y_test_label_task, pos_label=1)
    print('F1 for MTDNN on Morgan Fingerprint:', f1)

    cfm_counts = sk.metrics.confusion_matrix(y_test_task, y_test_label_task)

    # precision and recall come from the raw counts
    tn_count, fp_count, fn_count, tp_count = cfm_counts.ravel()
    pr = tp_count / (tp_count + fp_count)
    rc = tp_count / (tp_count + fn_count)

    # Normalise each row (true class) to sum to 1. keepdims=True makes the division
    # act per row; without it the row sums broadcast across columns and the
    # off-diagonal rates are wrong. (np.float no longer exists in current NumPy.)
    cfm = cfm_counts / cfm_counts.sum(axis=1, keepdims=True)

    tn, fp, fn, tp = cfm.ravel()
    print(' True Positive:', tp)
    print(' True Negative:', tn)
    print('False Positive:', fp)
    print('False Negative:', fn)


    auc = roc_auc_score(y_test_task, y_test_pred_task)
    print('Test ROC AUC ({}):'.format(all_tasks[i]), auc)

    results[all_tasks[i]] = [auc, acc, bacc, tn, tp, pr, rc, f1]

    # ROC curve against the chance diagonal
    fpr, tpr, threshold = sk.metrics.roc_curve(y_test_task, y_test_pred_task)
    plt.plot(fpr, tpr, 'b', label = 'AUC')
    plt.legend(loc = 'lower right')
    plt.plot([0, 1], [0, 1],'r--')
    plt.xlim([0, 1])
    plt.ylim([0, 1])
    plt.ylabel('True Positive Rate')
    plt.xlabel('False Positive Rate')
    plt.show()

In [None]:
# Header row, then one row per task with its stored metrics rounded to 3 decimals
column_titles = ['  AUC ', ' ACC ', ' BACC ', ' TN  ', ' TP  ', ' PR  ', ' RC  ', ' F1  ']
print('Task'.ljust(10), '\t', *column_titles)
for task_name, task_metrics in results.items():
    print(task_name.ljust(10), '\t', np.round(task_metrics, 3))

##### See Valid set performance

In [None]:
# Compute and print the validation loss with the weights currently held by `model`.
# The model is switched to eval mode and moved to the CPU (and stays there).
# torch.no_grad() keeps autograd from recording a graph over the whole validation
# set, which the loader serves as a single batch.
# NOTE: the loop rebinds x_valid_torch / y_valid_torch to the tensors of that batch.
model.eval()
model.cpu()
with torch.no_grad():
    for x_valid_torch, y_valid_torch in valid_generator:
        y_valid_pred = model(x_valid_torch)

        # Compute loss over all tasks
        loss = 0
        for i in range(len(all_tasks)):
            # (variable name kept from the test-set cell; these are validation labels)
            y_test_task = y_valid_torch[:,i]
            y_pred_task  = y_valid_pred[i][:,0]

            # compute loss for labels that are not NA
            indice_valid = y_test_task >= 0
            loss_task = criterion(y_pred_task[indice_valid], y_test_task[indice_valid]) / N_valid[i]

            loss += loss_task

print(loss.item())

In [None]:
results_valid = {}
# Collects performance metrics for all tasks on valid set in base model.
# Per task, results_valid[task] = [auc, acc, bacc, tn, tp, pr, rc, f1] where tn / tp
# are row-normalised rates and pr / rc are precision / recall from raw counts.
for i in range(len(all_tasks)):

    # keep only the examples that carry a label for this task (missing labels are -1)
    valid_datapoints = y_valid[:,i] >= 0
    y_valid_task = y_valid[valid_datapoints,i]
    y_valid_pred_task = y_valid_pred[i].detach().numpy()[valid_datapoints,0]
    # predicted probabilities rounded to hard 0/1 labels
    y_valid_label_task = np.round(y_valid_pred_task)


    acc = accuracy_score(y_valid_task, y_valid_label_task)
    print('Accuracy for deepnn on Morgan Fingerprint:', acc)

    bacc = sk.metrics.balanced_accuracy_score(y_valid_task, y_valid_label_task)

    f1 = f1_score(y_valid_task, y_valid_label_task, pos_label=1)
    print('F1 for deepnn on Morgan Fingerprint:', f1)

    cfm_counts = sk.metrics.confusion_matrix(y_valid_task, y_valid_label_task)

    # precision and recall come from the raw counts
    tn_count, fp_count, fn_count, tp_count = cfm_counts.ravel()
    pr = tp_count / (tp_count + fp_count)
    rc = tp_count / (tp_count + fn_count)

    # Normalise each row (true class) to sum to 1. keepdims=True makes the division
    # act per row; without it the row sums broadcast across columns and the
    # off-diagonal rates are wrong. (np.float no longer exists in current NumPy.)
    cfm = cfm_counts / cfm_counts.sum(axis=1, keepdims=True)

    print('Confusion Matrix for deepnn on Morgan Fingerprint:\n', cfm)

    tn, fp, fn, tp = cfm.ravel()
    print(' True Positive:', tp)
    print(' True Negative:', tn)
    print('False Positive:', fp)
    print('False Negative:', fn)


    auc = roc_auc_score(y_valid_task, y_valid_pred_task)
    print('Test ROC AUC ({}):'.format(all_tasks[i]), auc)

    results_valid[all_tasks[i]] = [auc, acc, bacc, tn, tp, pr, rc, f1]

    # ROC curve against the chance diagonal
    fpr, tpr, threshold = sk.metrics.roc_curve(y_valid_task, y_valid_pred_task)
    plt.plot(fpr, tpr, 'b', label = 'AUC')
    plt.legend(loc = 'lower right')
    plt.plot([0, 1], [0, 1],'r--')
    plt.xlim([0, 1])
    plt.ylim([0, 1])
    plt.ylabel('True Positive Rate')
    plt.xlabel('False Positive Rate')
    plt.show()

In [None]:
# Header row, then one row per task with its stored metrics rounded to 3 decimals
column_titles = ['  AUC ', ' ACC ', ' BACC ', ' TN  ', ' TP  ', ' PR  ', ' RC  ', ' F1  ']
print('Task'.ljust(35), '\t', *column_titles)
for task_name, task_metrics in results_valid.items():
    print(task_name.ljust(35), '\t', np.round(task_metrics, 3))

#### Save base model

In [None]:
# TODO: set base_modelpath to the directory for the saved base model. The file name is
# appended without a separator, so the value must end with a path separator.
base_modelpath = # define pathway for base model

In [None]:
# Create the base-model directory (and any missing parents). exist_ok=True makes
# this a no-op when the directory already exists and avoids the check-then-create race.
os.makedirs(base_modelpath, exist_ok=True)

In [None]:
# Save the PyTorch state_dict (parameters and buffers) of the trained base model.
# The file name is concatenated directly onto base_modelpath.
torch.save(model.state_dict(), base_modelpath + 'base_mtdnn_rtecs.pt')

In [None]:
# path of the saved base-model state_dict, reloaded for transfer learning
base = base_modelpath + 'base_mtdnn_rtecs.pt'

### <font color = "blue"> Transfer Learning </font>

#### Load ClinTox Data

In [None]:
# From here on all_tasks holds the single ClinTox task (it previously held the RTECS task list)
all_tasks = ['CT_TOX']
task = all_tasks[0]

In [None]:
# NOTE: data_ct is rebound to the saved ClinTox train/test/valid splits in the next code cell
data_ct = [clintox_data]

#### Load ClinTox split data

In [None]:
# Load the ClinTox train/test/valid splits saved for the current seed
clintox_data_path = f"../../data/datasets/clintox/split_data/seed_{seed_value}/"
train_data_ct, test_data_ct, valid_data_ct = (
    torch.load(clintox_data_path + split_name + '_data_clintox.pth')
    for split_name in ('train', 'test', 'valid')
)

# order used throughout: [train, test, valid]
data_ct = [train_data_ct, test_data_ct, valid_data_ct]

In [None]:
# Report the number of rows in each ClinTox split and in the three splits combined.
split_sizes_ct = [split.shape[0] for split in data_ct[:3]]
for split_name, n_examples in zip(("train", "test", "valid"), split_sizes_ct):
    print("Total number of examples, " + split_name + ": " + str(n_examples))
print("Total number of examples, train+test+valid: " + str(sum(split_sizes_ct)))

#### Construct FP for ClinTox

In [None]:
%%time
# Add an RDKit molecule, a Morgan bit-vector fingerprint and the fingerprint's
# bitInfo dict to every row of every ClinTox split (columns 'mol', 'morgan', 'bitInfo').
for split in data_ct:
    split['mol'] = [rd.Chem.MolFromSmiles(smiles) for smiles in split['smiles']]

    # one (initially empty) dict per molecule, passed to RDKit as bitInfo
    bit_infos = [{} for _ in range(len(split))]
    split['morgan'] = [
        AllChem.GetMorganFingerprintAsBitVect(mol, morgan_radius, nBits=morgan_bits, bitInfo=info)
        for mol, info in zip(split['mol'], bit_infos)
    ]
    split['bitInfo'] = bit_infos

##### Create train, test, valid sets for ClinTox

In [None]:
# Missing values (NA) become -1 in all three ClinTox splits; labels are later
# masked with `label >= 0` before the Binary Cross-Entropy loss is computed.
train_data_ct, test_data_ct, valid_data_ct = (
    split.fillna(-1) for split in (train_data_ct, test_data_ct, valid_data_ct)
)

# keep the list view in the usual [train, test, valid] order
data_ct = [train_data_ct, test_data_ct, valid_data_ct]

In [None]:
def _centered_fingerprint_matrix(fingerprints):
    """Stack RDKit explicit bit vectors into one numpy array and subtract 0.5 from every entry."""
    rows = []
    for bitvect in fingerprints:
        row = np.zeros((1,))  # destination array that ConvertToNumpyArray writes the bits into
        DataStructs.ConvertToNumpyArray(bitvect, row)
        rows.append(row)
    return np.array(rows) - 0.5


# features: shifted fingerprints; labels: one column per task
x_train_ct = _centered_fingerprint_matrix(train_data_ct['morgan'])
y_train_ct = train_data_ct[all_tasks].values

x_test_ct = _centered_fingerprint_matrix(test_data_ct['morgan'])
y_test_ct = test_data_ct[all_tasks].values

x_valid_ct = _centered_fingerprint_matrix(valid_data_ct['morgan'])
y_valid_ct = valid_data_ct[all_tasks].values

# per task (column): number of examples carrying a label, i.e. a value >= 0
N_train_ct = (y_train_ct >= 0).sum(axis=0)
N_test_ct = (y_test_ct >= 0).sum(axis=0)
N_valid_ct = (y_valid_ct >= 0).sum(axis=0)

In [None]:
# Cast ClinTox features and labels of every split to float32 before handing them to PyTorch
x_train_torch_ct, y_train_torch_ct = x_train_ct.astype(np.float32), y_train_ct.astype(np.float32)
x_test_torch_ct, y_test_torch_ct = x_test_ct.astype(np.float32), y_test_ct.astype(np.float32)
x_valid_torch_ct, y_valid_torch_ct = x_valid_ct.astype(np.float32), y_valid_ct.astype(np.float32)

In [None]:
# number of input features per ClinTox example; sizes the first Linear layer of MTDNN_base
input_shape = x_train_torch_ct.shape[1]
input_shape

In [None]:
# Wrap each ClinTox split in a dataset ...
training_set_ct = MTDNNData(x_train_torch_ct, y_train_torch_ct)
testing_set_ct = MTDNNData(x_test_torch_ct, y_test_torch_ct)
valid_set_ct = MTDNNData(x_valid_torch_ct, y_valid_torch_ct)

# ... and build the loaders: training batches are shuffled and of size `batch`;
# the test and validation sets are each served as one un-shuffled batch.
training_generator_ct = DataLoader(training_set_ct, shuffle=True, batch_size=batch)
testing_generator_ct = DataLoader(testing_set_ct, shuffle=False, batch_size=len(testing_set_ct))
valid_generator_ct = DataLoader(valid_set_ct, shuffle=False, batch_size=len(valid_set_ct))

#### Define base model, containing only the two shared layers of the MTDNN and its trained weights

In [None]:
# MTDNN base model - contains only the first two shared layers of the base model

class MTDNN_base(torch.nn.Module):
    """The two shared layers of the MTDNN, without any task head.

    Layer names and sizes match the shared part of MTDNN (shared_1, batchnorm_1,
    shared_2, batchnorm_2), so the corresponding state_dict entries are compatible.

    Args:
        input_shape: number of input features.
        all_tasks: accepted for signature parity with MTDNN; not used here.
    """

    def __init__(self, input_shape, all_tasks):
        super().__init__()
        nn = torch.nn

        # shared layers: input -> 2048 -> 1024
        self.shared_1 = nn.Linear(input_shape, 2048)
        self.batchnorm_1 = nn.BatchNorm1d(2048)

        self.shared_2 = nn.Linear(2048, 1024)
        self.batchnorm_2 = nn.BatchNorm1d(1024)

        # activation applied after each batch norm
        self.leakyReLU = nn.LeakyReLU(0.05)

    def forward(self, x):
        """Return the 1024-wide activation of the second shared layer."""
        first = self.leakyReLU(self.batchnorm_1(self.shared_1(x)))
        second = self.leakyReLU(self.batchnorm_2(self.shared_2(first)))
        return second
    
# instantiate the (still untrained) shared-layer model and move it to `device`
base_model_rtecs = MTDNN_base(input_shape, all_tasks).to(device)

In [None]:
# Load trained parameters of base model

pretrained_dict = torch.load(base, map_location=device)
base_model_rtecs_dict = base_model_rtecs.state_dict()
# keep only the pretrained entries whose keys exist in the truncated model,
# i.e. drop everything that belongs to the task-specific layers
import itertools
from collections import OrderedDict
loaded_dict = OrderedDict(
    (key, value) for key, value in pretrained_dict.items() if key in base_model_rtecs_dict
)
N = len(loaded_dict)  # number of transferred entries

# every entry of the truncated model must have a pretrained counterpart
missing_keys = [key for key in base_model_rtecs_dict if key not in loaded_dict]
if missing_keys:
    raise KeyError('pretrained base model lacks entries: {}'.format(missing_keys))

# now update the model dict with pretrained dict ...
base_model_rtecs_dict.update(loaded_dict)
# ... and copy it into the model: state_dict() returns a separate dict, so
# updating that dict alone leaves the model's own weights untouched
base_model_rtecs.load_state_dict(base_model_rtecs_dict)

In [None]:
## Freeze every parameter of the truncated base model (only the two shared layers of the MTDNN).
## Alternative kept for reference:
## base_model_rtecs = torch.nn.Sequential(*(list(base_model_rtecs.children())[0:5]))
for frozen_param in base_model_rtecs.parameters():
    frozen_param.requires_grad_(False)

In [None]:
base_model_rtecs

#### Add two layers to base model to train on ClinTox

In [None]:
# Layers to be trained on ClinTox on top of the frozen shared layers.
# The new batch norms get their own names (batchnorm_ct_*): assigning them to
# batchnorm_1 / batchnorm_2 would replace the 2048- and 1024-wide batch norms
# that forward() applies after shared_1 / shared_2.
base_model_rtecs.hidden_ct_1 = torch.nn.Linear(1024, 512)
base_model_rtecs.batchnorm_ct_1 = torch.nn.BatchNorm1d(512)

base_model_rtecs.hidden_ct_2 = torch.nn.Linear(512, 256)
base_model_rtecs.batchnorm_ct_2 = torch.nn.BatchNorm1d(256)
base_model_rtecs.output = torch.nn.Linear(256, 1)
base_model_rtecs.sigmoid = torch.nn.Sigmoid()

# NOTE(review): MTDNN_base.forward as defined in this notebook only runs the shared
# layers; the modules added here are registered on the instance but not called by
# forward() — confirm how the ClinTox layers are meant to be applied.

# move the newly added layers to `device` as well
base_model_rtecs = base_model_rtecs.to(device)

In [None]:
base_model_rtecs

#### Create checkpoints

In [None]:
# TODO: set `path` to the directory for the transfer-learning checkpoints; it is created
# if missing and the checkpoint file names are appended to it with a '/'.
path = # define path to save TL model 

In [None]:
# Create the TL checkpoint directory (and any missing parents). exist_ok=True makes
# this a no-op when the directory already exists and avoids the check-then-create race.
os.makedirs(path, exist_ok=True)

In [None]:
###### Files the TL training loop writes inside `path`
# checkpoint written at the end of every epoch
checkpoint_path = f'{path}/current_checkpoint.pt'

# copy made when train_epoch_loss <= train_loss_min
bestmodel_path = f'{path}/best_model.pt'

# copy made at the minimum validation loss
bestmodel_byvalid = f'{path}/best_model_by_valid.pt'

# copy made when train_epoch_loss >= val_epoch_loss
bestmodel_byvalid_crossed = f'{path}/best_model_by_valid-crossed.pt'

##### Train TL model with ClinTox - 1 epoch

In [None]:
##### Freeze the layers of the base model 
# Walk the parameter tensors in registration order and switch off gradients for the
# first five of them. Note that this counts parameter tensors (weights, biases),
# not layers.
count = 0
for count, param in enumerate(base_model_rtecs.parameters(), start=1):
    if count <= 5:
        param.requires_grad = False

In [None]:
# Define the loss
criterion = torch.nn.BCELoss()

# Optimizers require the parameters to optimize and a learning rate.
# Only parameters that still require gradients are handed to Adam, so frozen
# parameters are never updated.
optimizer_ct = torch.optim.Adam(filter(lambda p: p.requires_grad, base_model_rtecs.parameters()), lr = 0.001)

# Move any optimizer state tensors to the device used for the model. Using
# `.to(device)` instead of `.cuda()` keeps this cell usable on CPU-only machines.
for state in optimizer_ct.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.to(device)

In [None]:
# Only one epoch training with added ClinTox layers - to ensure not heavily trained on ClinTox
# This overrides the train_epoch value set in the "Settings" section.
train_epoch = 1

In [None]:
%%time
##################### With Tensorboard ######################
# Trains the TL model on ClinTox for `train_epoch` epochs. After each epoch the model
# is evaluated on the validation set, losses are written to TensorBoard through
# `writer`, and checkpoints are saved with `save_ckp` (both defined elsewhere in this
# notebook).

# Per-epoch history of loss and accuracy
loss_history = []
correct_history = []
val_loss_history = []
val_correct_history = []
# Lowest epoch losses seen so far (np.inf: the np.Inf alias was removed in NumPy 2.0)
train_loss_min = np.inf
val_loss_min = np.inf

# Training
for e in range(train_epoch):

    base_model_rtecs.train()
    # keep track of the loss over an epoch
    running_train_loss = 0
    running_valid_loss = 0
    # true labels and rounded predictions pooled over all batches and tasks
    y_train_true = []
    y_train_pred = []
    y_valid_true = []
    y_valid_pred = []
    for x_batch, y_batch in training_generator_ct:
        if torch.cuda.is_available():
            x_batch, y_batch = x_batch.cuda(), y_batch.cuda()

        # Forward pass: Compute predicted y by passing x to the model
        y_pred = base_model_rtecs(x_batch)

        # Compute loss over all tasks
        loss = 0
        for i in range(len(all_tasks)):
            y_batch_task = y_batch[:,i]
            # column 0 of the prediction is used for every task index
            y_pred_task  = y_pred[:,0]

            # compute loss for labels that are not NA (negative labels are excluded)
            indice_valid = y_batch_task >= 0
            loss_task = criterion(y_pred_task[indice_valid], y_batch_task[indice_valid]) / N_train_ct[i]

            loss += loss_task

            pred_train = np.round(y_pred_task[indice_valid].detach().cpu().numpy())
            target_train = y_batch_task[indice_valid].float()
            y_train_true.extend(target_train.tolist())
            y_train_pred.extend(pred_train.reshape(-1).tolist())

        # Zero gradients, perform a backward pass, and update the weights.
        optimizer_ct.zero_grad()
        loss.backward()
        optimizer_ct.step()

        # sum up the losses from each batch
        running_train_loss += loss.item()
        # logged after every batch under the same step `e` (running total so far)
        writer.add_scalar("Loss/train", running_train_loss, e)

    # Validation pass. The batch loop above has no `break`, so this always runs once
    # per epoch (it used to sit in a `for ... else` clause).
    with torch.no_grad():
        ## evaluation part
        base_model_rtecs.eval()
        for val_x_batch, val_y_batch in valid_generator_ct:

            if torch.cuda.is_available():
                val_x_batch, val_y_batch = val_x_batch.cuda(), val_y_batch.cuda()

            val_output = base_model_rtecs(val_x_batch)

            ## 2. loss calculation over all tasks
            val_loss = 0
            for i in range(len(all_tasks)):
                val_y_batch_task = val_y_batch[:,i]
                val_output_task  = val_output[:,0]

                # compute loss for labels that are not NA
                indice_valid = val_y_batch_task >= 0
                val_loss_task = criterion(val_output_task[indice_valid], val_y_batch_task[indice_valid]) / N_valid_ct[i]

                val_loss += val_loss_task

                pred_valid = np.round(val_output_task[indice_valid].detach().cpu().numpy())
                target_valid = val_y_batch_task[indice_valid].float()
                y_valid_true.extend(target_valid.tolist())
                y_valid_pred.extend(pred_valid.reshape(-1).tolist())

            running_valid_loss += val_loss.item()
            writer.add_scalar("Loss/valid", running_valid_loss, e)

    #epoch loss: the sum of the per-batch losses (np.mean of a scalar was a no-op)
    train_epoch_loss = float(running_train_loss)
    val_epoch_loss = float(running_valid_loss)

    #epoch accuracy
    train_epoch_acc = accuracy_score(y_train_true,y_train_pred)
    val_epoch_acc = accuracy_score(y_valid_true,y_valid_pred)

    #history
    loss_history.append(train_epoch_loss)
    correct_history.append(train_epoch_acc)
    val_loss_history.append(val_epoch_loss)
    val_correct_history.append(val_epoch_acc)

    print("Epoch:", e, "Training Loss:", train_epoch_loss, "Valid Loss:", val_epoch_loss)
    print("Training Acc:", train_epoch_acc, "Valid Acc:", val_epoch_acc)

    # create checkpoint variable and add important data
    checkpoint = {
        'epoch': e + 1,
        'train_loss_min': train_epoch_loss,
        'val_loss_min': val_epoch_loss,
        'state_dict': base_model_rtecs.state_dict(),
        'optimizer': optimizer_ct.state_dict(),
    }

    # save checkpoint
    save_ckp(checkpoint, False, checkpoint_path_ct, bestmodel_path_ct)

    # save the model if the training loss has not increased
    if train_epoch_loss <= train_loss_min:
        print('Training loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(train_loss_min,train_epoch_loss))
        # save checkpoint as best model
        save_ckp(checkpoint, True, checkpoint_path_ct, bestmodel_path_ct)
        train_loss_min = train_epoch_loss

    # save the model when the training loss is at or above the validation loss.
    # train_loss_min is deliberately not touched here: updating it in this branch
    # could raise the tracked minimum after an epoch that did not improve.
    if train_epoch_loss >= val_epoch_loss:
        print('Training loss greater than validation loss ({:.6f} --> {:.6f}).  Saving model ...'.format(train_epoch_loss,val_epoch_loss))
        # save checkpoint as best model
        save_ckp(checkpoint, True, checkpoint_path_ct, bestmodel_byvalid_crossed_ct)

    # save the model if the validation loss has not increased
    if val_epoch_loss <= val_loss_min:
        print('Validation loss decreased ({:.6f} --> {:.6f}). Saving model ...'.format(val_loss_min,val_epoch_loss))
        # save checkpoint as best model
        save_ckp(checkpoint, True, checkpoint_path_ct, bestmodel_byvalid_ct)
        val_loss_min = val_epoch_loss

In [None]:
# Load TL model with lowest valid loss 
# load_ckp is defined elsewhere in this notebook; it is given the checkpoint path,
# the model and the optimizer.
# NOTE(review): the order of the returned loss values (valid before train) is assumed
# from this call site -- confirm against load_ckp.
loaded_model_ct, optimizer_ct, start_epoch, valid_loss_min, train_loss_min = load_ckp(bestmodel_byvalid_ct, base_model_rtecs, optimizer_ct)

In [None]:
# Summary of what was restored from the checkpoint.
print("model = ", loaded_model_ct)
print("optimizer = ", optimizer_ct)
print("start_epoch = ", start_epoch)
# valid_loss_min is the value returned by load_ckp; val_loss_min would be the
# variable left over from the training loop, not the loaded one.
print("val_loss_min = {:.6f}".format(valid_loss_min))
print("train_loss_min = {:.6f}".format(train_loss_min))

##### Evaluate test performance of TL model on ClinTox

In [None]:
# print test loss
# The model is switched to eval mode and moved to the CPU once, before the loop.
# NOTE(review): assumes testing_generator_ct yields CPU tensors, as the batches are
# passed to the model without being moved.
loaded_model_ct.eval().cpu()

test_pred_batches = []
loss = 0
with torch.no_grad():  # inference only, no autograd graph needed
    for x_test_torch, y_test_torch in testing_generator_ct:
        batch_pred = loaded_model_ct(x_test_torch)
        test_pred_batches.append(batch_pred)

        # Compute loss over all tasks, accumulated over every batch
        for i in range(len(all_tasks)):
            y_test_task = y_test_torch[:,i]
            y_pred_task  = batch_pred[:,0]

            # compute loss for labels that are not NA
            indice_valid = y_test_task >= 0
            loss_task = criterion(y_pred_task[indice_valid], y_test_task[indice_valid]) / N_test_ct[i]

            loss += loss_task

# Predictions for the whole test set in generator order. With a single batch this
# equals that batch's predictions; with several batches, earlier ones are no longer
# overwritten.
y_test_pred = torch.cat(test_pred_batches, dim=0)

print(loss.item())

In [None]:
results_ct = {}
# Collects performance metrics on ClinTox for test set
# Note ClinTox is very skewed, for this test set there were no "toxic" samples. 
# Stored per task: [auc, acc, bacc, tn, tp, pr, rc, f1], where tn / tp are the
# row-normalised rates and pr / rc are computed from the raw counts.
for i in range(len(all_tasks)):

    # keep only data points whose label is not NA (negative labels are excluded)
    valid_datapoints = y_test_ct[:,i] >= 0
    y_test_task = y_test_ct[valid_datapoints,i]
    y_test_pred_task = y_test_pred.detach().numpy()[valid_datapoints,0]
    # predicted probabilities rounded to hard 0/1 labels
    y_test_label_task = np.round(y_test_pred_task)

    acc = accuracy_score(y_test_task, y_test_label_task)
    print('Accuracy for MTDNN on Morgan Fingerprint:', acc)

    bacc = sk.metrics.balanced_accuracy_score(y_test_task, y_test_label_task)

    f1 = f1_score(y_test_task, y_test_label_task, pos_label=1)
    print('F1 for MTDNN on Morgan Fingerprint:', f1)

    # labels=[0, 1] keeps the matrix 2x2 even when one class is absent
    cfm_counts = sk.metrics.confusion_matrix(y_test_task, y_test_label_task, labels=[0, 1])

    # precision / recall from the raw counts (nan when the denominator is zero)
    tn_count, fp_count, fn_count, tp_count = cfm_counts.ravel()
    pr = tp_count / (tp_count + fp_count)
    rc = tp_count / (tp_count + fn_count)

    # Normalise each row by its own total (keepdims makes the division row-wise);
    # rows without samples stay at zero. np.float was removed from NumPy, use float.
    row_totals = cfm_counts.sum(axis=1, keepdims=True).astype(float)
    cfm = np.divide(cfm_counts, row_totals,
                    out=np.zeros(cfm_counts.shape, dtype=float),
                    where=row_totals > 0)

    tn, fp, fn, tp = cfm.ravel()
    print(' True Positive:', tp)
    print(' True Negative:', tn)
    print('False Positive:', fp)
    print('False Negative:', fn)

    # ROC AUC is undefined when only one class is present in the labels
    if len(np.unique(y_test_task)) > 1:
        auc = roc_auc_score(y_test_task, y_test_pred_task)
    else:
        auc = float('nan')
    print('Test ROC AUC ({}):'.format(all_tasks[i]), auc)

    results_ct[all_tasks[i]] = [auc, acc, bacc, tn, tp, pr, rc, f1]

    fpr, tpr, threshold = sk.metrics.roc_curve(y_test_task, y_test_pred_task)
    plt.plot(fpr, tpr, 'b', label = 'AUC')
    plt.legend(loc = 'lower right')
    plt.plot([0, 1], [0, 1],'r--')
    plt.xlim([0, 1])
    plt.ylim([0, 1])
    plt.ylabel('True Positive Rate')
    plt.xlabel('False Positive Rate')
    plt.show()

In [None]:
# Print the test-set metrics as a table: one header line, then one row per task
# with the stored values rounded to three decimals.
metric_headers = ['  AUC ', ' ACC ', ' BACC ', ' TN  ', ' TP  ', ' PR  ', ' RC  ', ' F1  ']
print('Task'.ljust(10), '\t', *metric_headers)
for task in results_ct:
    print(task.ljust(10), '\t', np.round(results_ct[task], 3))

##### Evaluate valid performance of TL model on ClinTox

In [None]:
# print validation loss
# The model is switched to eval mode and moved to the CPU once, before the loop.
# NOTE(review): assumes valid_generator_ct yields CPU tensors, as the batches are
# passed to the model without being moved.
base_model_rtecs.eval().cpu()

valid_pred_batches = []
loss = 0
with torch.no_grad():  # inference only, no autograd graph needed
    for x_valid_torch, y_valid_torch in valid_generator_ct:
        batch_pred = base_model_rtecs(x_valid_torch)
        valid_pred_batches.append(batch_pred)

        # Compute loss over all tasks, accumulated over every batch
        for i in range(len(all_tasks)):
            y_valid_batch_task = y_valid_torch[:,i]
            y_pred_task  = batch_pred[:,0]

            # compute loss for labels that are not NA
            indice_valid = y_valid_batch_task >= 0
            loss_task = criterion(y_pred_task[indice_valid], y_valid_batch_task[indice_valid]) / N_valid_ct[i]

            loss += loss_task

# Predictions for the whole validation set in generator order. With a single batch
# this equals that batch's predictions; with several batches, earlier ones are no
# longer overwritten.
y_valid_pred = torch.cat(valid_pred_batches, dim=0)

print(loss.item())

In [None]:
results_valid_ct = {}
# Collects performance metrics on ClinTox for valid set
# Stored per task: [auc, acc, bacc, tn, tp, pr, rc, f1], where tn / tp are the
# row-normalised rates and pr / rc are computed from the raw counts.
for i in range(len(all_tasks)):

    # keep only data points whose label is not NA (negative labels are excluded)
    valid_datapoints = y_valid_ct[:,i] >= 0
    y_valid_task = y_valid_ct[valid_datapoints,i]
    y_valid_pred_task = y_valid_pred.detach().numpy()[valid_datapoints,0]
    # predicted probabilities rounded to hard 0/1 labels
    y_valid_label_task = np.round(y_valid_pred_task)

    acc = accuracy_score(y_valid_task, y_valid_label_task)
    print('Accuracy for deepnn on Morgan Fingerprint:', acc)

    bacc = sk.metrics.balanced_accuracy_score(y_valid_task, y_valid_label_task)

    f1 = f1_score(y_valid_task, y_valid_label_task, pos_label=1)
    print('F1 for deepnn on Morgan Fingerprint:', f1)

    # labels=[0, 1] keeps the matrix 2x2 even when one class is absent
    cfm_counts = sk.metrics.confusion_matrix(y_valid_task, y_valid_label_task, labels=[0, 1])

    # precision / recall from the raw counts (nan when the denominator is zero)
    tn_count, fp_count, fn_count, tp_count = cfm_counts.ravel()
    pr = tp_count / (tp_count + fp_count)
    rc = tp_count / (tp_count + fn_count)

    # Normalise each row by its own total (keepdims makes the division row-wise);
    # rows without samples stay at zero. np.float was removed from NumPy, use float.
    row_totals = cfm_counts.sum(axis=1, keepdims=True).astype(float)
    cfm = np.divide(cfm_counts, row_totals,
                    out=np.zeros(cfm_counts.shape, dtype=float),
                    where=row_totals > 0)

    print('Confusion Matrix for deepnn on Morgan Fingerprint:\n', cfm)

    tn, fp, fn, tp = cfm.ravel()
    print(' True Positive:', tp)
    print(' True Negative:', tn)
    print('False Positive:', fp)
    print('False Negative:', fn)

    # ROC AUC is undefined when only one class is present in the labels
    if len(np.unique(y_valid_task)) > 1:
        auc = roc_auc_score(y_valid_task, y_valid_pred_task)
    else:
        auc = float('nan')
    # this cell evaluates the validation set, so the label says "Valid"
    print('Valid ROC AUC ({}):'.format(all_tasks[i]), auc)

    results_valid_ct[all_tasks[i]] = [auc, acc, bacc, tn, tp, pr, rc, f1]

    fpr, tpr, threshold = sk.metrics.roc_curve(y_valid_task, y_valid_pred_task)
    plt.plot(fpr, tpr, 'b', label = 'AUC')
    plt.legend(loc = 'lower right')
    plt.plot([0, 1], [0, 1],'r--')
    plt.xlim([0, 1])
    plt.ylim([0, 1])
    plt.ylabel('True Positive Rate')
    plt.xlabel('False Positive Rate')
    plt.show()

In [None]:
# Print the validation-set metrics as a table: one header line, then one row per
# task with the stored values rounded to three decimals.
metric_headers = ['  AUC ', ' ACC ', ' BACC ', ' TN  ', ' TP  ', ' PR  ', ' RC  ', ' F1  ']
print('Task'.ljust(35), '\t', *metric_headers)
for task in results_valid_ct:
    print(task.ljust(35), '\t', np.round(results_valid_ct[task], 3))