In [1]:
import pandas as pd
import numpy as np
import json
import torch
from torch.autograd import Variable
from rdkit.Chem import PandasTools
from rdkit import Chem
from rdkit.Chem.MolStandardize.standardize import Standardizer
from models.model_mtnn import MultiTaskNN

In [2]:
# Endpoint list
ASSAYS_TASKS = ['BSEPi', 'BSEPs', 'PGPi', 'PGPs', 'MRP4i', 'MRP3i', 'MRP3s', 'MRP2i', 'MRP2s', 'BCRPi', 'BCRPs', 'OATP1B1i', 
               'OATP1B3i', 'NRF2', 'LXR', 'AHR', 'PPARa', 'PPARg', 'PXR', 'FXR', 'MTX_MP', 'MTX_RC', 'MTX_FOM', 'PLD', 'PLD_HTS', 'HTX', 'ERS', 'ARE']
DILI_TASKS = ['DILI_majority', 'DILI_sensitive', 'DILI_secure']

# Helper functions

## SMILES preparation and descriptor calculation

In [3]:
standardizer = Standardizer(max_tautomers=10)
include_stereoinfo = False
def standardizeMol(mol):
    """
    Standardizer inspired by MELLODDY consortium
    :param mol: rdkit molecule object
    :return: cleaned rdkit molecule object
    """
    if mol is None:
        return mol
    try:
        mol = standardizer.charge_parent(mol)
        mol = standardizer.isotope_parent(mol)
        if include_stereoinfo is False:
            mol = standardizer.stereo_parent(mol)
        mol = standardizer.tautomer_parent(mol)
        return standardizer.standardize(mol)
    except Chem.rdchem.AtomValenceException:
        return None
    
def prepare_structures(input_df, smiles_col='smiles'):
    '''
    Standardize compounds, get canonical smiles and deduplicate structures based on canonical smiles
    '''
    df = input_df.copy()
    PandasTools.AddMoleculeColumnToFrame(df, smiles_col,'molecule', includeFingerprints=False)
    len_1 = len(df)
    df.dropna(subset=['molecule'], axis=0, inplace=True)
    len_2 = len(df)
    print(f'No. of missing molecules: {len_1-len_2}')

    # Standardize mols and get canonical smiles
    for idx in df.index:
        try:
            stand_mol = standardizeMol(df.loc[idx, 'molecule'])
            df.loc[idx, 'canonical_smiles'] = Chem.MolToSmiles(stand_mol, canonical=True)
        except:
            print(f'Failed to standardize {df.loc[idx, smiles_col]}')
            df.loc[idx, 'canonical_smiles'] = None
    
    df.drop(['molecule'], axis=1, inplace=True)
    df.dropna(subset=['canonical_smiles'], inplace=True)
    return df

## Inference functions

In [4]:
def get_mean_prediction_ensemble(pred_task, tasks):
    ### Average predictions and get STD for uncertainty estimation
    mean_pred = pd.DataFrame()
    for task in tasks:
        pred_task_mean = np.mean(pred_task[task], axis=0)
        pred_task_std = np.std(pred_task[task], axis=0)
        mean_pred[f'{task}_mean'] = pred_task_mean
        mean_pred[f'{task}_std'] = pred_task_std
        # binary class label
        mean_pred[f'{task}_class'] = (pred_task_mean >= 0.5).astype(int)
    return mean_pred

def make_predictions_with_ensemble(Xtest):
    ### apply ensemble model on ONTOX tasks to test set
    tasks = DILI_TASKS + ASSAYS_TASKS
    # read HPs from json file
    with open('models/hyperparameters.json', 'r') as openfile:
        hparam_dict = json.load(openfile)

        
    pred_task = {}
    for split_id in range(5): # ensemble with 5 models (one per split)
        print(f'Split: {split_id}')
        # load models and make predictions
        model_file = f'models/model_mtnn_all_tasks_cddd_{split_id}.pth'
        hparam = hparam_dict[str(split_id)]

        device='cpu'
        model = MultiTaskNN(input_size=512, params=hparam, n_tasks=len(tasks))
        model.load_state_dict(torch.load(model_file, map_location=device))
        model.to(device)
        model.eval()

        # get predictions from single model
        test_data = Variable(torch.from_numpy(Xtest).float())
        preds_list = model(test_data)
        predictions = [torch.Tensor.cpu(p).detach().numpy() for p in preds_list]
        predictions = np.array(predictions).squeeze().T

        # calculate predictions for all tasks
        for i, task in enumerate(tasks):
            if split_id == 0:
                pred_task[task] = predictions[:,i]
            else:
                pred_task[task] = np.vstack((pred_task[task], predictions[:,i]))
        display(pred_task[task].shape)
        
    # get mean prediction and STD from ensemble
    mean_pred = get_mean_prediction_ensemble(pred_task, tasks)
    return mean_pred, pred_task

# Make predictions

## Load test set and prepare SMILES

In [5]:
test_df = pd.read_csv('example_data_with_CDDD.csv')
test_df

Unnamed: 0,compoundName,compoundCas,ID,SMILES,cddd_0,cddd_1,cddd_2,cddd_3,cddd_4,cddd_5,...,cddd_502,cddd_503,cddd_504,cddd_505,cddd_506,cddd_507,cddd_508,cddd_509,cddd_510,cddd_511
0,cefadroxil,50370-12-2,193,O=C(C(c1ccc(cc1)O)N)NC1C(=O)N2C1SCC(=C2C(=O)O)C,-0.345609,0.435677,0.443579,-0.230919,-0.15522,0.435055,...,-0.012822,-0.649319,-0.259375,0.154243,-0.998582,-0.656491,-0.047186,-0.139226,0.290642,0.249692
1,cefalexin,15686-71-2,194,NC(c1ccccc1)C(=O)NC1C(=O)N2C1SCC(=C2C(=O)O)C,-0.206731,0.223643,0.343376,-0.291107,-0.530366,0.471886,...,0.337684,-0.777834,-0.483351,0.02282,-0.997673,-0.656915,-0.076998,0.225208,0.507705,0.675936
2,cefroxadine,51762-05-1,195,COC1=C(C(=O)O)N2C(SC1)C(C2=O)NC(=O)C(C1=CCC=CC1)N,0.280529,0.379214,-0.035095,-0.207573,-0.135618,0.873147,...,0.612471,-0.52737,-0.742913,0.255554,-0.999407,0.026243,-0.243761,0.219092,0.144513,0.274512
3,cefaclor,53994-73-3,196,NC(c1ccccc1)C(=O)NC1C(=O)N2C1SCC(=C2C(=O)O)Cl,-0.463277,0.523447,0.274103,-0.004757,0.219428,-0.110411,...,0.632502,-0.643733,-0.514364,0.199848,-0.998535,-0.372263,-0.219406,0.443599,0.693524,0.61474
4,cidofovir,113852-37-2,197,OCC(Cn1ccc(nc1=O)N)OCP(=O)(O)O,0.827905,0.151661,-0.634442,-0.624787,-0.288666,-0.424812,...,0.073548,-0.340514,0.626049,-0.161234,-0.997134,0.118878,-0.430494,-0.616421,-0.406417,0.623281
5,adefovir,106941-25-7,198,Nc1ncnc2c1ncn2CCOCP(=O)(O)O,0.878535,0.225458,-0.857043,-0.529709,-0.536128,-0.627331,...,0.28931,0.145662,0.324091,0.730698,-0.99445,0.76464,-0.707223,-0.717028,-0.842849,0.807021
6,tenidap,120210-48-2,199,Clc1ccc2c(c1)c(C(=O)c1cccs1)c(n2C(=O)N)O,0.546082,0.275444,0.562035,-0.475908,0.142944,0.121655,...,-0.527007,0.749456,0.639992,-0.27546,-0.990277,-0.758119,0.439286,0.140462,0.040768,0.497276
7,vinorelbine,71486-22-1,200,CCC1=CC2CN(C1)Cc1c3ccccc3[nH]c1C(C2)(C(=O)OC)c...,0.565227,-0.215023,0.411864,0.035846,-0.514231,0.32349,...,0.437771,-0.251952,-0.03995,0.711807,-0.997629,-0.760012,0.284355,0.895465,0.377786,-0.2757
8,2-(phosphonomethyl)pentanedioic acid,173039-10-6,201,OC(=O)CCC(C(=O)O)CP(=O)(O)O,-0.267367,0.143161,-0.960844,0.161823,-0.61905,-0.629508,...,0.007566,-0.600597,-0.514969,-0.079095,-0.958287,-0.522038,-0.22399,-0.554607,0.661822,0.335066
9,sulfadiazine,68-35-9,202,Nc1ccc(cc1)S(=O)(=O)Nc1ncccn1,-0.411361,-0.235456,-0.189376,-0.08984,0.002908,-0.29624,...,-0.280758,-0.214666,-0.109768,-0.58141,-0.935882,-0.124828,0.327465,-0.676434,-0.438199,0.238666


In [6]:
test_df = prepare_structures(test_df, smiles_col='SMILES')
print(len(test_df))

No. of missing molecules: 0
16


### Note: in this example the CDDD descriptors used as input for the model are already calculated. For new compounds, the CDDD descriptors need to be calculated at this step as described here: https://github.com/jrwnter/cddd

In [7]:
Xtest = test_df[[c for c in test_df if 'cddd' in c]].values

## Make predictions for all tasks with the 10 models in the ensemble

In [8]:
mean_pred, pred_task = make_predictions_with_ensemble(Xtest)
test_pred = test_df.merge(mean_pred, right_index=True, left_index=True)
test_pred

Split: 0


(16,)

Split: 1


(2, 16)

Split: 2


(3, 16)

Split: 3


(4, 16)

Split: 4


(5, 16)

Unnamed: 0,compoundName,compoundCas,ID,SMILES,cddd_0,cddd_1,cddd_2,cddd_3,cddd_4,cddd_5,...,PLD_HTS_class,HTX_mean,HTX_std,HTX_class,ERS_mean,ERS_std,ERS_class,ARE_mean,ARE_std,ARE_class
0,cefadroxil,50370-12-2,193,O=C(C(c1ccc(cc1)O)N)NC1C(=O)N2C1SCC(=C2C(=O)O)C,-0.345609,0.435677,0.443579,-0.230919,-0.15522,0.435055,...,0,0.00036,0.000606,0,0.000211,0.000422,0,0.002386,0.00365,0
1,cefalexin,15686-71-2,194,NC(c1ccccc1)C(=O)NC1C(=O)N2C1SCC(=C2C(=O)O)C,-0.206731,0.223643,0.343376,-0.291107,-0.530366,0.471886,...,0,0.000211,0.000363,0,2.2e-05,4.4e-05,0,0.000227,0.000303,0
2,cefroxadine,51762-05-1,195,COC1=C(C(=O)O)N2C(SC1)C(C2=O)NC(=O)C(C1=CCC=CC1)N,0.280529,0.379214,-0.035095,-0.207573,-0.135618,0.873147,...,0,0.000293,0.000506,0,2.4e-05,4.7e-05,0,0.002713,0.003312,0
3,cefaclor,53994-73-3,196,NC(c1ccccc1)C(=O)NC1C(=O)N2C1SCC(=C2C(=O)O)Cl,-0.463277,0.523447,0.274103,-0.004757,0.219428,-0.110411,...,0,0.001413,0.002644,0,3e-05,6e-05,0,0.00667,0.011748,0
4,cidofovir,113852-37-2,197,OCC(Cn1ccc(nc1=O)N)OCP(=O)(O)O,0.827905,0.151661,-0.634442,-0.624787,-0.288666,-0.424812,...,0,0.000269,0.000501,0,0.000194,0.000387,0,0.034363,0.030649,0
5,adefovir,106941-25-7,198,Nc1ncnc2c1ncn2CCOCP(=O)(O)O,0.878535,0.225458,-0.857043,-0.529709,-0.536128,-0.627331,...,0,0.000548,0.000889,0,0.000445,0.000888,0,0.031498,0.018437,0
6,tenidap,120210-48-2,199,Clc1ccc2c(c1)c(C(=O)c1cccs1)c(n2C(=O)N)O,0.546082,0.275444,0.562035,-0.475908,0.142944,0.121655,...,0,0.057262,0.09989,0,0.015813,0.017492,0,0.38261,0.305057,0
7,vinorelbine,71486-22-1,200,CCC1=CC2CN(C1)Cc1c3ccccc3[nH]c1C(C2)(C(=O)OC)c...,0.565227,-0.215023,0.411864,0.035846,-0.514231,0.32349,...,1,0.262759,0.304819,0,0.013343,0.015159,0,0.19587,0.280803,0
8,2-(phosphonomethyl)pentanedioic acid,173039-10-6,201,OC(=O)CCC(C(=O)O)CP(=O)(O)O,-0.267367,0.143161,-0.960844,0.161823,-0.61905,-0.629508,...,0,0.000258,0.000516,0,0.000343,0.000683,0,0.047065,0.071224,0
9,sulfadiazine,68-35-9,202,Nc1ccc(cc1)S(=O)(=O)Nc1ncccn1,-0.411361,-0.235456,-0.189376,-0.08984,0.002908,-0.29624,...,0,0.000591,0.000728,0,3.6e-05,7.1e-05,0,0.008832,0.008288,0
