## Library imports & function definitions

In [None]:
import numpy as np
import pandas as pd
import copy
import dgl
import torch
from tqdm import tqdm
from dgllife.utils import CanonicalAtomFeaturizer, CanonicalBondFeaturizer
from dgllife.utils import smiles_to_bigraph
from torch.utils.data import DataLoader
from model.main.DMPNN import * 
from model.main.utils import *
from model.main.scheduler import NoamLR
from model.main.models import *
from model.main.trainer import *

In [None]:
def model_building(node_input_dim=74,
                   edge_input_dim=12,
                   node_hidden_dim=int(2 ** 7),
                   edge_hidden_dim=int(2 ** 7),
                   num_step_message_passing=4,
                   num_step_mha=1, tox21_task_num = 12, task_num = 3):
    """Build the ChEMBL source model and the Tox21 target model with shared sizes.

    Both models receive identical graph-encoder hyper-parameters; only the
    Tox21 model additionally gets ``task_num=tox21_task_num``.

    Args:
        node_input_dim / edge_input_dim: input feature sizes for atoms / bonds.
        node_hidden_dim / edge_hidden_dim: hidden sizes for atoms / bonds.
        num_step_message_passing: forwarded to both models.
        num_step_mha: forwarded to both models.
        tox21_task_num: forwarded to ``Tox21_fullmodel`` as ``task_num``.
        task_num: accepted for backward compatibility but not used here.

    Returns:
        ``(model_chembl, model_tox21)``.
    """
    # Architecture settings shared verbatim by the two models.
    shared_kwargs = dict(
        node_input_dim=node_input_dim,
        edge_input_dim=edge_input_dim,
        node_hidden_dim=node_hidden_dim,
        edge_hidden_dim=edge_hidden_dim,
        num_step_message_passing=num_step_message_passing,
        num_step_mha=num_step_mha,
    )
    chembl_net = ChEMBL_fullmodel(**shared_kwargs)
    tox21_net = Tox21_fullmodel(**shared_kwargs, task_num=tox21_task_num)
    return chembl_net, tox21_net

In [None]:
def collate(sample):
    """Collate ``(graph, labels, mask)`` triples into one mini-batch.

    Args:
        sample: sequence of ``(graph, labels, mask)`` triples, as built by
            ``load_data``.

    Returns:
        ``(batched_graph, labels_tensor, mask_tensor)``. The batched graph has
        DGL's zero initializer registered for its node and edge features.
    """
    graphs, labels, mask = (list(column) for column in zip(*sample))
    merged = dgl.batch(graphs)
    for register_initializer in (merged.set_n_initializer, merged.set_e_initializer):
        register_initializer(dgl.init.zero_initializer)
    return merged, torch.tensor(labels), torch.tensor(mask)

def load_data(df, labels, atom_featurizer, bond_featurizer):
    """Featurize a DataFrame of molecules into ``(graph, labels, mask)`` triples.

    Args:
        df: DataFrame with a ``'smiles'`` column and one column per entry of
            ``labels``.
        labels: names of the label columns to read from ``df``.
        atom_featurizer: passed to ``smiles_to_bigraph`` as ``node_featurizer``.
        bond_featurizer: passed to ``smiles_to_bigraph`` as ``edge_featurizer``.

    Returns:
        A list with one ``(graph, label_row, mask_row)`` tuple per row of
        ``df``, in row order. ``label_row`` holds the raw values (missing
        entries stay NaN); ``mask_row`` is 1 where a label is present and 0
        where it is missing.

    Raises:
        ValueError: if any SMILES string could not be converted to a graph.
    """
    print("---------------- Target loading --------------------")
    graphs = [smiles_to_bigraph(smi, node_featurizer=atom_featurizer, edge_featurizer=bond_featurizer) for smi in df['smiles']]
    # dgllife's smiles_to_bigraph yields None for SMILES it cannot parse; such an
    # entry would otherwise only fail later, opaquely, inside dgl.batch in collate.
    invalid = [smi for smi, graph in zip(df['smiles'], graphs) if graph is None]
    if invalid:
        raise ValueError(
            f"{len(invalid)} SMILES could not be converted to graphs, e.g. {invalid[:5]}"
        )
    label_rows = df[labels].values.tolist()
    # 1 where a label is present, 0 where it is missing (NaN).
    mask_rows = np.array(df[labels].notna()).astype('int').tolist()
    dataset = list(zip(graphs, label_rows, mask_rows))
    print("---------------- Target loading complete --------------------")
    return dataset

def dataloader_tox21(train, valid, labels, batch_size):
    """Build the training and validation DataLoaders for Tox21.

    Each split is featurized with fresh canonical atom and bond featurizers.
    The training loader shuffles and the validation loader does not. Both keep
    the final partial batch and use ``collate``.

    Returns:
        ``(train_loader, valid_loader)``.
    """
    def _featurize(frame):
        # New featurizer instances per split, as in the original construction.
        return load_data(frame, labels, CanonicalAtomFeaturizer(), CanonicalBondFeaturizer())

    train_set = _featurize(train)
    valid_set = _featurize(valid)
    common = dict(batch_size=batch_size, collate_fn=collate, drop_last=False)
    train_loader = DataLoader(train_set, shuffle=True, **common)
    valid_loader = DataLoader(valid_set, shuffle=False, **common)
    return train_loader, valid_loader

## Model configuration & implementation

In [None]:
GPU_NUM = 0  # CUDA device index; a negative value selects the CPU
# Fall back to the CPU when CUDA is unavailable. torch.cuda.set_device rejects
# a CPU device, so it is only called when a CUDA device was actually selected.
use_cuda = GPU_NUM >= 0 and torch.cuda.is_available()
device = torch.device(f'cuda:{GPU_NUM}' if use_cuda else 'cpu')
if device.type == 'cuda':
    torch.cuda.set_device(device)

In [None]:
data_path = './data/internal_data/Tox21/'
df_train = pd.read_csv(data_path + 'tox21_train.csv')
df_valid = pd.read_csv(data_path + 'tox21_valid.csv')
# Every column after the first is treated as an assay label, which assumes the
# first column is the 'smiles' column that load_data reads. Fail loudly if the
# layout differs, rather than silently training on the wrong label set.
if df_train.columns[0] != 'smiles':
    raise ValueError(
        f"Expected 'smiles' as the first column of tox21_train.csv, got {df_train.columns[0]!r}"
    )
assay_list = list(df_train.columns[1:])

In [None]:
# Training hyper-parameters.
batch_size, n_epochs = 128, 60
lr, weight_decay = 5e-4, 1e-6
# Epoch counts handed to NoamLR below: warm-up length and total schedule length.
warmup_epoch, decay_step = 5, 20

num_task = len(assay_list)  # one readout head is initialised per assay column

In [None]:
mode = 'tox21'
# One loss per assay column, built from the training labels by
# weight_crossentropy (presumably class-weighted; see model.main.utils).
loss_list = [weight_crossentropy(df_train, col, mode, device) for col in assay_list]

# Size the Tox21 model from the data instead of the default tox21_task_num=12,
# so its heads match the num_task heads initialised below and the losses above.
src_model, tox21_model = model_building(tox21_task_num=num_task)

In [None]:
# Load the ChEMBL-pretrained checkpoint into the source model. strict=False
# tolerates key mismatches, so report them rather than ignoring them silently.
state = torch.load("./model/pretrained_ckpts/chembl.pth", map_location=device)
src_model = src_model.to(device)
incompatible = src_model.load_state_dict(state['model_state_dict'], strict=False)
if incompatible.missing_keys:
    print(f"ChEMBL checkpoint: missing keys not loaded: {incompatible.missing_keys}")
if incompatible.unexpected_keys:
    print(f"ChEMBL checkpoint: unexpected keys ignored: {incompatible.unexpected_keys}")

In [None]:
# Put both models on the selected device (source first, then target).
src_model, tox21_model = src_model.to(device), tox21_model.to(device)

In [None]:
# Transfer the pretrained weights: the source featurizer initialises the Tox21
# gnn, and the single source readout initialises every per-task readout head.
tox21_model.gnn.load_state_dict(copy.deepcopy(src_model.featurizer.state_dict()))
src_readout_state = src_model.readout.state_dict()
for head_idx in range(num_task):
    tox21_model.readout[head_idx].load_state_dict(copy.deepcopy(src_readout_state))

tr_loader, vr_loader = dataloader_tox21(df_train, df_valid, assay_list, batch_size)

In [None]:
# Output directory handed to tox21_train (presumably where checkpoints are written — confirm in model.main.trainer).
model_path = './Tox21_ckpts/'

In [None]:
model_optimizer = torch.optim.Adam(
    tox21_model.parameters(),
    lr=lr,
    weight_decay=weight_decay,
)
# NoamLR receives one-element lists here (presumably one entry per optimizer
# param group — confirm in model.main.scheduler). The schedule starts and ends
# at 1e-5 and peaks at lr.
model_scheduler = NoamLR(
    optimizer=model_optimizer,
    warmup_epochs=[warmup_epoch],
    total_epochs=[decay_step],
    steps_per_epoch=len(tr_loader),
    init_lr=[1e-5],
    max_lr=[lr],
    final_lr=[1e-5],
)

In [None]:
# Fine-tune the Tox21 model; tox21_train returns a metrics dict and
# (presumably) the best-scoring epoch.
metric_dict, top_epoch = tox21_train(
    tox21_model, model_path, tr_loader, vr_loader,
    model_optimizer, model_scheduler, loss_list, device,
    epochs=n_epochs,
)

## Performance evaluation

In [None]:
# NOTE(review): this restores the released checkpoint from ./model/pretrained_ckpts/
# and overwrites the weights fine-tuned in the previous cell. If the goal is to
# score the freshly trained model, load the checkpoint tox21_train presumably
# writes under model_path instead — confirm which is intended.
state = torch.load("./model/pretrained_ckpts/tox21.pth", map_location=device)
tox21_model.load_state_dict(state['model_state_dict'])
# NOTE(review): predictions are made on vr_loader, i.e. the validation split.
test_pred = tox21_model_prediction(tox21_model, vr_loader, device)


In [None]:
# NOTE(review): `score` is not defined in this notebook; it presumably comes
# from one of the `from model.main... import *` lines — confirm.
evaluate_results(
    test_pred, df_valid, assay_list, score,
)