## Library imports & function definitions

In [None]:
import random
import os
import numpy as np
import pandas as pd
import copy
import dgl
import torch
from dgllife.utils import CanonicalAtomFeaturizer, CanonicalBondFeaturizer
from dgllife.utils import smiles_to_bigraph
import torch.nn as nn
from torch.utils.data import DataLoader, WeightedRandomSampler
from tqdm import tqdm
from torch.autograd import Variable
from model.main.DMPNN import * 
from model.main.utils import *
from model.main.scheduler import NoamLR
from model.main.models import *
from model.main.min_norm_solvers import MinNormSolver, gradient_normalizers
from model.main.trainer import *

In [None]:
import warnings


def warn(*args, **kwargs):
    """Accept any arguments of :func:`warnings.warn` and do nothing.

    Installed below as a replacement for ``warnings.warn`` to silence warning
    output in this notebook.
    """
    return None


# Monkey-patch the module attribute. Only callers that look up
# ``warnings.warn`` at call time are affected.
# NOTE(review): code that bound ``from warnings import warn`` earlier, or
# C-level warnings, are not routed through this function.
warnings.warn = warn

In [None]:
def mtl_building(node_input_dim=74,
                 edge_input_dim=12,
                 node_hidden_dim=int(2 ** 7),
                 edge_hidden_dim=int(2 ** 7),
                 num_step_message_passing=4,
                 num_step_mha=1, tox21_task_num = 12, task_num = 3):
    """Instantiate the ChEMBL, Tox21 and in-vivo MTL networks.

    All three networks receive the same graph-encoder hyper-parameters.

    Args:
        node_input_dim: Atom feature size passed to every network.
        edge_input_dim: Bond feature size passed to every network.
        node_hidden_dim: Hidden node size passed to every network.
        edge_hidden_dim: Hidden edge size passed to every network.
        num_step_message_passing: Message-passing steps passed to every network.
        num_step_mha: Multi-head-attention steps passed to every network.
        tox21_task_num: ``task_num`` for ``Tox21_embed`` and
            ``tox21_task_num`` for ``MTL_invivo``.
        task_num: ``task_num`` for ``MTL_invivo``.

    Returns:
        Tuple ``(model_chembl, model_tox21, model)``: a ``ChEMBL_fullmodel``,
        a ``Tox21_embed`` and an ``MTL_invivo`` instance, in that order.
    """
    # Encoder settings shared by every network.
    encoder_kwargs = dict(
        node_input_dim=node_input_dim,
        edge_input_dim=edge_input_dim,
        node_hidden_dim=node_hidden_dim,
        edge_hidden_dim=edge_hidden_dim,
        num_step_message_passing=num_step_message_passing,
        num_step_mha=num_step_mha,
    )

    # Constructed in the original order: ChEMBL, then Tox21, then MTL.
    model_chembl = ChEMBL_fullmodel(**encoder_kwargs)
    model_tox21 = Tox21_embed(task_num=tox21_task_num, **encoder_kwargs)
    model = MTL_invivo(tox21_task_num=tox21_task_num,
                       task_num=task_num,
                       **encoder_kwargs)

    return model_chembl, model_tox21, model

In [None]:
def tox21_collate(sample):
    """Collate ``(graph, mask)`` pairs into one batched DGL graph.

    Args:
        sample: Iterable of ``(graph, mask)`` pairs, as produced by
            ``tox21_load_data``.

    Returns:
        ``(batched_graph, mask)``. ``mask`` is a plain Python list of the
        per-sample masks; it is not converted to a tensor.
    """
    graph_column, mask_column = (list(column) for column in zip(*sample))
    merged = dgl.batch(graph_column)
    # Features of nodes/edges added later are zero-initialised.
    merged.set_n_initializer(dgl.init.zero_initializer)
    merged.set_e_initializer(dgl.init.zero_initializer)
    return merged, mask_column

def tox21_load_data(df, labels, atom_featurizer, bond_featurizer):
    """Convert each SMILES in ``df['smiles']`` to a DGL graph and pair it with its label mask.

    Args:
        df: DataFrame with a ``'smiles'`` column and the ``labels`` columns.
        labels: Column name(s) whose non-missing entries define the mask.
        atom_featurizer: Node featurizer passed to ``smiles_to_bigraph``.
        bond_featurizer: Edge featurizer passed to ``smiles_to_bigraph``.

    Returns:
        List of ``(graph, mask)`` tuples, where ``mask`` is 1 where the label
        is present and 0 where it is NaN.
    """
    print("---------------- Target loading --------------------")
    graphs = []
    for smi in df['smiles']:
        graphs.append(
            smiles_to_bigraph(smi,
                              node_featurizer=atom_featurizer,
                              edge_featurizer=bond_featurizer)
        )
    presence = df[labels].notna().to_numpy().astype('int').tolist()
    print("---------------- Target loading complete --------------------")
    return list(zip(graphs, presence))

def dataloader_tox21(train, labels, batch_size):
    """Build an ordered, non-dropping DataLoader for Tox21-embedding inference.

    Args:
        train: DataFrame to featurize. Any split may be passed, despite the name.
        labels: Label column name(s), used only to build per-sample masks.
        batch_size: Number of graphs per batch.

    Returns:
        ``DataLoader`` with ``shuffle=False`` and ``drop_last=False``, collated
        by ``tox21_collate``. Output order therefore matches the row order of
        ``train``.
    """
    samples = tox21_load_data(train, labels,
                              CanonicalAtomFeaturizer(),
                              CanonicalBondFeaturizer())
    return DataLoader(samples,
                      batch_size=batch_size,
                      shuffle=False,
                      collate_fn=tox21_collate,
                      drop_last=False)

def collate(sample):
    """Collate ``(graph, labels, embed, mask)`` samples into one batch.

    Returns:
        ``(batched_graph, labels_tensor, embeds_list, mask_tensor)``. Labels
        and masks are converted with ``torch.tensor``. The embeddings are
        returned as a plain Python list.
    """
    graph_col, label_col, embed_col, mask_col = (list(col) for col in zip(*sample))
    merged = dgl.batch(graph_col)
    # Features of nodes/edges added later are zero-initialised.
    merged.set_n_initializer(dgl.init.zero_initializer)
    merged.set_e_initializer(dgl.init.zero_initializer)
    return merged, torch.tensor(label_col), embed_col, torch.tensor(mask_col)

def load_data(df, labels, embed, atom_featurizer, bond_featurizer):
    """Build ``(graph, labels, embed_row, mask)`` samples, one per row of ``df``.

    Args:
        df: DataFrame with a ``'smiles'`` column and the ``labels`` columns.
        labels: Label column name(s). Missing labels stay NaN in the targets.
        embed: Indexable with three indices; row ``i`` (``embed[i, :, :]``)
            belongs to ``df`` row ``i``.
            NOTE(review): assumes ``embed`` rows are aligned with ``df`` order.
        atom_featurizer: Node featurizer passed to ``smiles_to_bigraph``.
        bond_featurizer: Edge featurizer passed to ``smiles_to_bigraph``.

    Returns:
        List of ``(graph, label_list, embed_slice, mask_list)`` tuples, where
        ``mask`` is 1 for present labels and 0 for NaN.
    """
    print("---------------- Target loading --------------------")
    graphs = [
        smiles_to_bigraph(smi,
                          node_featurizer=atom_featurizer,
                          edge_featurizer=bond_featurizer)
        for smi in df['smiles']
    ]
    label_frame = df[labels]
    targets = label_frame.values.tolist()
    presence = np.array(label_frame.notna()).astype('int').tolist()
    embed_rows = [embed[row, :, :] for row in range(df.shape[0])]
    print("---------------- Target loading complete --------------------")
    return list(zip(graphs, targets, embed_rows, presence))

def dataloader_train(train, valid, labels, tr_embeds, va_embeds, batch_size, sampler):
    """Build the training and validation DataLoaders.

    Args:
        train, valid: DataFrames for the two splits.
        labels: Label column name(s).
        tr_embeds, va_embeds: Per-row auxiliary embeddings for each split.
        batch_size: Graphs per batch.
        sampler: Optional sampler for the training set. When given, batches
            are drawn from it with ``drop_last=True``. Otherwise the training
            set is shuffled and the last partial batch is kept.

    Returns:
        ``(train_loader, valid_loader)``. The validation loader is never
        shuffled and never drops data.
    """
    featurizers = (CanonicalAtomFeaturizer(), CanonicalBondFeaturizer())
    train_samples = load_data(train, labels, tr_embeds, *featurizers)
    valid_samples = load_data(valid, labels, va_embeds, *featurizers)

    if sampler is None:
        tr_loader = DataLoader(train_samples, batch_size=batch_size, shuffle=True,
                               collate_fn=collate, drop_last=False)
    else:
        # shuffle must be False when an explicit sampler is supplied.
        tr_loader = DataLoader(train_samples, batch_size=batch_size, shuffle=False,
                               collate_fn=collate, drop_last=True, sampler=sampler)

    vr_loader = DataLoader(valid_samples, batch_size=batch_size, shuffle=False,
                           collate_fn=collate, drop_last=False)
    return tr_loader, vr_loader

def dataloader_test(train, labels, embeds, batch_size):
    """Build an ordered, non-dropping DataLoader for evaluation.

    Args:
        train: DataFrame to evaluate. Any split may be passed, despite the name.
        labels: Label column name(s).
        embeds: Per-row auxiliary embeddings aligned with ``train``.
        batch_size: Graphs per batch.

    Returns:
        ``DataLoader`` with ``shuffle=False`` and ``drop_last=False``, collated
        by ``collate``.
    """
    samples = load_data(train, labels, embeds,
                        CanonicalAtomFeaturizer(), CanonicalBondFeaturizer())
    return DataLoader(samples,
                      batch_size=batch_size,
                      shuffle=False,
                      collate_fn=collate,
                      drop_last=False)

In [None]:
def tox21_embed_calculate(model_pred, tr_loader, device):
    """Run ``model_pred`` over every batch of ``tr_loader`` and stack the outputs.

    The model is switched to eval mode (and left there), and gradients are
    disabled during the pass.

    Args:
        model_pred: Module called as ``model_pred.forward(graph, atom_feats, bond_feats)``.
        tr_loader: Iterable of ``(graph, mask)`` batches. The mask is ignored.
        device: Device the graph and features are moved to.

    Returns:
        CPU tensor built from ``np.vstack`` of all per-batch outputs, in loader
        order.
    """
    model_pred.eval()
    per_batch = []
    with torch.no_grad():
        for batch_graph, _unused_mask in tr_loader:
            batch_graph = batch_graph.to(device)
            # Pop the features so they are not carried inside the graph object.
            atom_feats = batch_graph.ndata.pop('h').to(device)
            bond_feats = batch_graph.edata.pop('e').to(device)
            batch_out = model_pred.forward(batch_graph, atom_feats, bond_feats)
            per_batch.append(batch_out.detach().to('cpu').numpy())

    return torch.tensor(np.vstack(per_batch))

## Model configuration & Implementation

In [None]:
GPU_NUM = 0
# Use the requested GPU only when CUDA is actually available; otherwise run on
# CPU. torch.cuda.set_device raises for a CPU device, so call it only for CUDA.
use_cuda = GPU_NUM >= 0 and torch.cuda.is_available()
device = torch.device(f'cuda:{GPU_NUM}' if use_cuda else 'cpu')
if device.type == 'cuda':
    torch.cuda.set_device(device)

In [None]:
data_path = './data/internal_data/In_vivo/'
# Read the train / valid / test splits, in that order.
df_train, df_valid, df_test = (
    pd.read_csv(data_path + csv_name)
    for csv_name in ('invivo_train.csv', 'invivo_valid.csv', 'invivo_test.csv')
)
# Every column after the first is treated as an assay (task) label.
# NOTE(review): assumes column 0 holds the SMILES strings ('smiles' is read by
# load_data) — confirm against the CSV schema.
assay_list = list(df_train.columns[1:])

In [None]:
def set_seed(seed: int = 42) -> None:
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    ``PYTHONHASHSEED`` is also exported. It only influences child
    interpreters started afterwards; hash randomisation of the running
    process is fixed at startup.

    Args:
        seed: Value used for every generator.
    """
    # Python and NumPy global generators.
    random.seed(seed)
    np.random.seed(seed)
    # PyTorch CPU and CUDA generators.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for reproducible kernels.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ["PYTHONHASHSEED"] = str(seed)
    print(f"Random seed set as {seed}")


seed = 109
set_seed(seed)

In [None]:
# Training hyper-parameters.
batch_size: int = 64        # graphs per training batch
n_epochs: int = 60          # epochs passed to invivo_model_train
lr: float = 1e-4            # Adam lr and NoamLR max_lr
warmup_epoch: int = 3       # NoamLR warm-up length, in epochs
decay_step: int = 13        # passed to NoamLR as total_epochs
weight_decay: float = 1e-6  # Adam weight decay

In [None]:
num_task = len(assay_list)
mode = 'invivo'
# One weighted cross-entropy loss per assay column, in assay_list order.
loss_list = [weight_crossentropy(df_train, col, mode, device) for col in assay_list]

# Size the MTL model to the number of assays so its heads agree with
# loss_list. Previously the default task_num=3 was used regardless of
# len(assay_list).
src_model, tox21_model, model = mtl_building(task_num=num_task)

# Pretrained weights are loaded with strict=False, so mismatched or missing
# keys are ignored rather than raising.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
state = torch.load("./model/pretrained_ckpts/chembl.pth", map_location=device)
src_model = src_model.to(device)
src_model.load_state_dict(state['model_state_dict'], strict = False)

state = torch.load("./model/pretrained_ckpts/tox21.pth", map_location=device)
tox21_model = tox21_model.to(device)
tox21_model.load_state_dict(state['model_state_dict'], strict = False)

In [None]:
model.to(device)
# Initialise the in-vivo encoder from the ChEMBL-pretrained featurizer.
model.invivo_gnn.load_state_dict(copy.deepcopy(src_model.featurizer.state_dict()))
# Warm-start every in-vivo readout head from the ChEMBL readout. Iterate the
# heads themselves instead of assuming exactly three of them.
for readout in model.invivo_readout:
    readout.load_state_dict(copy.deepcopy(src_model.readout.state_dict()))

# Per-sample weights derived from the label-presence matrix; the sampler draws
# len(sample_weights) samples per epoch (with replacement, the default).
sample_weights = calculate_sample_weights(np.array(df_train[assay_list].notna()).astype('int'))
sampler = WeightedRandomSampler(sample_weights.type('torch.DoubleTensor'), num_samples = len(sample_weights))

tr_tox21_loader = dataloader_tox21(df_train, assay_list, 100)
va_tox21_loader = dataloader_tox21(df_valid, assay_list, 100)
ts_tox21_loader = dataloader_tox21(df_test, assay_list, 100)

In [None]:
# Tox21 auxiliary embeddings for each split, computed in train/valid/test order.
tr_tox21_embed, va_tox21_embed, ts_tox21_embed = (
    tox21_embed_calculate(tox21_model, split_loader, device)
    for split_loader in (tr_tox21_loader, va_tox21_loader, ts_tox21_loader)
)

In [None]:
tr_loader, va_loader = dataloader_train(
    df_train, df_valid, assay_list,
    tr_tox21_embed, va_tox21_embed,
    batch_size, sampler=sampler,
)
# Number of training batches per epoch (used as NoamLR steps_per_epoch).
data_NT = len(tr_loader)

In [None]:
model_path = './invivo_ckpts/'

model_optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
# NoamLR receives one-element lists for its per-group settings.
model_scheduler = NoamLR(
    optimizer=model_optimizer,
    warmup_epochs=[warmup_epoch],
    total_epochs=[decay_step],
    steps_per_epoch=data_NT,
    init_lr=[1e-5],
    max_lr=[lr],
    final_lr=[3e-5],
)

In [None]:
metric_dict, top_epoch = invivo_model_train(
    model, model_path, tr_loader, va_loader,
    model_optimizer, model_scheduler, loss_list, device, seed,
    epochs=n_epochs,
)
print(f"Finished at :{top_epoch}")

## Performance evaluation

In [None]:
ts_loader = dataloader_test(df_test, assay_list, ts_tox21_embed, 100)

# Build the checkpoint path from the same model_path / seed used for training
# instead of hard-coding './invivo_ckpts/seed109', so changing either above
# stays consistent.
# NOTE(review): the epoch is still fixed at 36; `top_epoch` returned by
# invivo_model_train presumably identifies the best checkpoint — confirm the
# file naming convention before switching to it.
best_epoch = 36
ckpt_path = f"{model_path}seed{seed}/epoch_{best_epoch}.pth"
state = torch.load(ckpt_path, map_location=device)
model.load_state_dict(state['model_state_dict'])
test_pred = invivo_model_test(model, ts_loader, device)

In [None]:
# Score the test-set predictions against the test labels. Previously df_valid
# was passed here, which misaligns predictions and ground truth.
# NOTE(review): `score` is not defined in this notebook's visible cells;
# presumably it comes from a star import — confirm.
evaluate_results(test_pred, df_test, assay_list, score)