## Library import & Function definition

In [None]:
import numpy as np
import pandas as pd
import dgl
import torch
from dgllife.utils import CanonicalAtomFeaturizer, CanonicalBondFeaturizer
from dgllife.utils import smiles_to_bigraph
from torch.utils.data import DataLoader
import copy
from model.main.DMPNN import * 
from model.main.utils import *
from model.main.models import *
from model.main.trainer import *

In [None]:
import warnings


def warn(*args, **kwargs):
    """Accept any arguments and do nothing; silent stand-in for ``warnings.warn``."""
    return None


# Rebind the module attribute so later ``warnings.warn(...)`` lookups hit the no-op.
warnings.warn = warn

In [None]:
def mtl_building(node_input_dim=74,
                 edge_input_dim=12,
                 node_hidden_dim=int(2 ** 7),
                 edge_hidden_dim=int(2 ** 7),
                 num_step_message_passing=4,
                 num_step_mha=1, tox21_task_num = 12, task_num = 3):
    """Construct the in-vivo multi-task model and its Tox21 companion model.

    Both networks receive the same node/edge input and hidden dimensions and
    the same message-passing / MHA step counts. ``tox21_task_num`` is passed
    to the in-vivo model as ``tox21_task_num`` and to the Tox21 model as
    ``task_num``; ``task_num`` is passed to the in-vivo model only.

    Returns:
        tuple: ``(model, model_tox21)`` -- an ``MTL_invivo_inference``
        instance followed by a ``Tox21_embed`` instance.
    """
    # Hyperparameters that the two constructors have in common.
    shared_kwargs = dict(
        node_input_dim=node_input_dim,
        edge_input_dim=edge_input_dim,
        node_hidden_dim=node_hidden_dim,
        edge_hidden_dim=edge_hidden_dim,
        num_step_message_passing=num_step_message_passing,
        num_step_mha=num_step_mha,
    )

    # The in-vivo model is built first, then the Tox21 model.
    model = MTL_invivo_inference(
        tox21_task_num=tox21_task_num,
        task_num=task_num,
        **shared_kwargs,
    )
    model_tox21 = Tox21_embed(task_num=tox21_task_num, **shared_kwargs)

    return model, model_tox21

In [None]:
def collate(sample):
    """Collate ``(graph, label)`` pairs into one batched graph and a label tensor.

    Args:
        sample: iterable of ``(graph, label)`` pairs handed over by the
            ``DataLoader``.

    Returns:
        tuple: the result of ``dgl.batch`` over the graphs (with zero
        initializers installed for node and edge features) and
        ``torch.tensor`` of the labels.
    """
    graphs, labels = (list(column) for column in zip(*sample))

    merged = dgl.batch(graphs)
    zero_init = dgl.init.zero_initializer
    merged.set_n_initializer(zero_init)
    merged.set_e_initializer(zero_init)

    label_tensor = torch.tensor(labels)
    return merged, label_tensor

def load_data(df, labels, atom_featurizer, bond_featurizer):
    """Convert a SMILES table into a list of ``(graph, label)`` samples.

    Args:
        df: DataFrame with a ``'smiles'`` column and the column(s) named by
            ``labels``.
        labels: column name (or list of names) used to select the targets.
        atom_featurizer: node featurizer passed to ``smiles_to_bigraph``.
        bond_featurizer: edge featurizer passed to ``smiles_to_bigraph``.

    Returns:
        list: ``(graph, label)`` tuples in the row order of ``df``.
    """
    print("---------------- Target loading --------------------")
    graphs = []
    for smiles in df['smiles']:
        graph = smiles_to_bigraph(
            smiles,
            node_featurizer=atom_featurizer,
            edge_featurizer=bond_featurizer,
        )
        graphs.append(graph)
    targets = df[labels].values.tolist()
    samples = list(zip(graphs, targets))
    print("---------------- Target loading complete --------------------")
    return samples

def dataloader_pred(train, labels, batch_size):
    """Build a non-shuffling ``DataLoader`` over the graphs derived from ``train``.

    Args:
        train: DataFrame forwarded to ``load_data``.
        labels: label column selector forwarded to ``load_data``.
        batch_size: number of samples per batch.

    Returns:
        DataLoader: iterates the samples in their original order, keeps the
        final partial batch, and batches with ``collate``.
    """
    atom_featurizer = CanonicalAtomFeaturizer()
    bond_featurizer = CanonicalBondFeaturizer()
    samples = load_data(train, labels, atom_featurizer, bond_featurizer)

    return DataLoader(
        samples,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate,
        drop_last=False,
    )

## Inference & Performance evaluation

In [None]:
# Index of the GPU to run on; a negative value forces CPU execution.
GPU_NUM = 0
# Select CUDA only when a GPU index was requested AND CUDA is actually usable,
# so the notebook also runs on CPU-only hosts.
device = torch.device(
    f'cuda:{GPU_NUM}' if GPU_NUM >= 0 and torch.cuda.is_available() else 'cpu'
)
# torch.cuda.set_device rejects non-CUDA devices, so only call it for a GPU.
if device.type == 'cuda':
    torch.cuda.set_device(device)

In [None]:
# CSV holding the external evaluation set; read into a DataFrame as-is.
data_path = './data/external_data/external_total.csv'
df_external = pd.read_csv(filepath_or_buffer=data_path)

In [None]:
# One frame per endpoint: rows with a missing SMILES or label are dropped, the
# endpoint column is renamed to 'label', and a 'data' column records the endpoint.
df_list = []

for col in ['carcino', 'dili', 'genotox']:
    df = (
        df_external[['smiles', col]]
        .dropna()
        .reset_index(drop=True)
        .rename(columns={col: 'label'})
        .assign(data=col)
    )
    df_list.append(df)

In [None]:
# Stack the per-endpoint frames row-wise and renumber the index from 0.
total = pd.concat(df_list, axis=0, ignore_index=True)

In [None]:
# Batch size used when iterating the external set.
batch_size = 64
ex_loader = dataloader_pred(train=total, labels='label', batch_size=batch_size)

In [None]:
# Index of the GPU to run on; a negative value forces CPU execution.
GPU_NUM = 0
# Use the requested GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device(
    f'cuda:{GPU_NUM}' if GPU_NUM >= 0 and torch.cuda.is_available() else 'cpu'
)
# torch.cuda.set_device rejects a CPU device, so skip it on the CPU fallback path.
if device.type == 'cuda':
    torch.cuda.set_device(device)

In [None]:
# Build the in-vivo model and the Tox21 model with mtl_building's default hyperparameters.
model, tox21_model = mtl_building()
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
tox21_state = torch.load("./model/pretrained_ckpts/tox21.pth", map_location=device)
# strict=False: checkpoint/model key mismatches are tolerated instead of raising.
tox21_model.load_state_dict(tox21_state['model_state_dict'], strict = False)
# Move both models to the selected device.
model.to(device)
tox21_model.to(device)

In [None]:
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
state = torch.load("./invivo_best.pth", map_location=device)
# strict=False: checkpoint/model key mismatches are tolerated instead of raising.
model.load_state_dict(state['model_state_dict'], strict = False)
# After loading the in-vivo checkpoint, overwrite the GNN weights with those of the Tox21 model.
model.gnn.load_state_dict(copy.deepcopy(tox21_model.gnn.state_dict()))
# Copy readout heads 0..11 from the Tox21 model.
# NOTE(review): 12 presumably mirrors mtl_building's default tox21_task_num -- keep in sync.
for i in range(12):
    model.readout[i].load_state_dict(copy.deepcopy(tox21_model.readout[i].state_dict()))

In [None]:
# Run inference over the external loader. invivo_inference comes from the project's
# star imports (definition not visible here); its result is indexed later with three
# indices as test_pred[:, k, 0].
test_pred = invivo_inference(model, ex_loader, device)

In [None]:
# Index k on the second axis of test_pred is stored as the prediction for
# endpoint k (0: carcino, 1: dili, 2: genotox). Every row receives all three
# predictions; only the one matching the row's 'data' value is scored below.
total['carcino_pred'] = test_pred[:, 0, 0]
total['dili_pred'] = test_pred[:, 1, 0]
total['genotox_pred'] = test_pred[:, 2, 0]

score_dict = {}

# Series.unique() yields the endpoints in order of first appearance, so the
# metrics come out in the same order on every run (iterating a set of strings
# is hash-order dependent and can change between interpreter sessions).
for n_data in total['data'].unique():
    df = total[total['data'] == n_data]
    score_dict[n_data] = score(df['label'], df[n_data + '_pred'])

In [None]:
# score_dict maps endpoint -> metric values; transpose so each endpoint becomes
# a row, then expose the endpoint name as the first column.
pred_metrics = pd.DataFrame(score_dict).transpose()
pred_metrics = pred_metrics.reset_index()
pred_metrics.columns = [
    'task',
    'loss', 'pre', 'sen', 'spe', 'acc', 'bac', 'f1', 'aupr', 'auc',
]

In [None]:
# Bare expression: shows the metrics table as the notebook cell output.
pred_metrics