# Toxicity Prediction with Discovery Workbench (DWb)

**Required libraries:** torch, pytorch-fast-transformers, pytorch-lightning, transformers, scipy, scikit-learn

In [1]:
from kbcomposer import DWB, KBComposer

import torch
from models.molformer_predict_tox import LightningModule
from fast_transformers.masking import LengthMask as LM
from helper import dotdict
from helper import convert_to_mgkg
from helper import convert_to_epa
import yaml
from yaml.loader import SafeLoader
from tokenizer import MolTranBertTokenizer
import pandas as pd
import numpy as np
import time
from pprint import pprint

In [2]:
# Configuring DWb's instance
kbc = KBComposer.get_kbc()
dwb = DWB.get_dwb()

In [3]:
# Function to load datasets from DWb using Pandas
def download_semantic_dataset_by_filename(dwb, filename):
    ds = [x for x in dwb.get_semantic_datasets() if x['label'] == filename][0]
    return dwb.download_semantic_dataset(ds['uri'])

## Upload a semantic dataset to DWb

In [13]:
# Select molecule concept
concepts = dwb.get_concepts()
concept = next(filter(lambda c: c['label'] == 'Molecule', concepts))
concept_uri = concept['uri']

# Upload semantic dataset
df = pd.read_csv('data/toxicity-prediction/toxicity-prediction_test_dwb.csv')
semantic_ds = dwb.upload_semantic_dataset(dataset_label='Toxicity prediction test',
                                          dataset_description='Test dataset for toxicity prediction.',
                                          member_type=concept_uri,
                                          dataframe=df,
                                          label_column='SMILES')
pprint(dwb.get_semantic_dataset_info(semantic_ds['uri']))

{'additionalAttributes': ['Unnamed: 0',
                          'BCUTi-1h',
                          'BCUTd-1h',
                          'TopoPSA(NO)',
                          'BCUTs-1l',
                          'VR3_D',
                          'SMR_VSA1',
                          'BCUTv-1h',
                          'SlogP_VSA2',
                          'BCUTd-1l',
                          'BCUTc-1h',
                          'SMR_VSA5',
                          'SdsssP',
                          'IC0',
                          'BCUTm-1l',
                          'Mv',
                          'BCUTi-1l',
                          'EState_VSA1',
                          'BCUTc-1l',
                          'TopoPSA',
                          'PEOE_VSA8',
                          'BCUTdv-1h',
                          'Xch-7d',
                          'MID_h',
                          'MDEC-23',
                          'MDEC-33',
                        

## Inference function

Model checkpoint can be downloaded from [box](https://ibm.ent.box.com/folder/201653400157).

In [8]:
def inference():
    
    print('Import model.')
    
    # Network loading and parameters importing
    with open('data/hparams.yaml') as f:
        data = yaml.load(f, Loader=SafeLoader)
        print(data)

    hparams = dotdict(data)
    tokenizer = MolTranBertTokenizer('data/bert_vocab.txt')
    model = LightningModule(hparams, tokenizer).load_from_checkpoint('data/last.ckpt',
                                                                     strict=False,
                                                                     config=hparams,
                                                                     tokenizer=tokenizer,
                                                                     vocab=len(tokenizer.vocab),
                                                                     map_location=torch.device('cpu'))
    model.eval()
    
    print('Retrieving semantic dataset from DWb.')
    
    # Importing data and transforming to the Network 
    #df = pd.read_csv('data/toxicity-prediction/toxicity-prediction_test.csv', nrows=20)
    df = download_semantic_dataset_by_filename(dwb, 'Toxicity prediction test')
    df = df.iloc[:20]
    
    print('Data transformation.')

    # Tokenizer - Creating tokens from SMILES
    tokens = model.tokenizer(df['SMILES'].tolist(), padding=True, truncation =True, add_special_tokens=True,return_tensors="pt" )
    idx = torch.tensor(tokens['input_ids'])
    mask = torch.tensor(tokens['attention_mask'])

    # Data transformation to feed the model
    token_embeddings = model.tok_emb(idx) # each index maps to a (learnable) vector
    x = model.drop(token_embeddings)
    x = model.blocks(x, length_mask=LM(mask.sum(-1)))
    token_embeddings = x

    input_mask_expanded = mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
    sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    loss_input = sum_embeddings / sum_mask

    outmap_min, _ = torch.min(loss_input, dim=1, keepdim=True)
    outmap_max, _ = torch.max(loss_input, dim=1, keepdim=True)
    outmap = (loss_input - outmap_min) / (outmap_max - outmap_min) # Broadcasting rules apply
    
    print('Predicting...')
    
    outputs = model.net.forward(outmap).squeeze()
    
    # Converting to Epa Categories
    pred_epa = list(convert_to_epa(outputs.squeeze(),df['SMILES']))
    # Converting to Mg/Kg  Units
    pred_epa_mgkg = list(convert_to_mgkg(outputs.squeeze(),df['SMILES']))

    print(list(zip(pred_epa, pred_epa_mgkg)))

## Execution

In [20]:
import time
start_time = time.time()

# Load
inference()

total_time = time.time() - start_time        
if(total_time > 60):
    print(f" Total time = {total_time:.1f} seconds. Total time = {(total_time/60):.1f} minutes.")
else: 
    print(f" Total time = {total_time:.1f} seconds.")

Import model.
{'aug': None, 'batch_size': 32, 'checkpoint_dir': './checkpoints_toxicity-causal-epa/ld50/models', 'checkpoint_every': 1000, 'checkpoint_root': './checkpoints_toxicity-causal-epa/ld50', 'checkpoints_folder': './checkpoints_toxicity-causal-epa', 'd_dropout': 0.1, 'data_root': '../data/toxicity-prediction-causal', 'dataset_name': 'toxicity-prediction', 'dataset_names': ['valid', 'test'], 'desc_skip_connection': False, 'device': 'cuda', 'dims': [768, 768, 768, 1], 'dropout': 0.1, 'eval_dataset_length': None, 'fc_h': 512, 'fold': 0, 'from_scratch': False, 'lr_multiplier': 1, 'lr_start': 3e-05, 'max_epochs': 2000, 'measure_name': 'ld50', 'mode': 'avg', 'n_batch': 512, 'n_embd': 768, 'n_head': 12, 'n_jobs': 1, 'n_layer': 12, 'num_classes': None, 'num_feats': 32, 'num_workers': 8, 'results_dir': './checkpoints_toxicity-causal-epa/ld50/results', 'run_name': 'toxicity-prediction_ld50_rot_0_avg_lr_3e-05_batch_32_drop_0.1_[768, 768, 768, 1]', 'seed': 12345, 'seed_path': '../data/che

Lightning automatically upgraded your loaded checkpoint from v1.1.5 to v2.0.2. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint --file data/last.ckpt`


Using Rotation Embedding
Using Rotation Embedding
Using Rotation Embedding
Using Rotation Embedding
Using Rotation Embedding
Using Rotation Embedding
Using Rotation Embedding
Using Rotation Embedding
Using Rotation Embedding
Using Rotation Embedding
Using Rotation Embedding
Using Rotation Embedding
dropout is 0.1
smiles_embed_dim:  768
Retrieving semantic dataset from DWb.


Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Data transformation.


  idx = torch.tensor(tokens['input_ids'])
  mask = torch.tensor(tokens['attention_mask'])
The boolean parameter 'some' has been replaced with a string parameter 'mode'.
Q, R = torch.qr(A, some)
should be replaced with
Q, R = torch.linalg.qr(A, 'reduced' if some else 'complete') (Triggered internally at ../aten/src/ATen/native/BatchLinearAlgebra.cpp:2425.)
  Q, _ = torch.qr(block)


Predicting...
[(2, 1439.8095763700335), (2, 2705.3795426443676), (2, 3617.1633299139435), (2, 1380.1931552040612), (1, 111.24934474181616), (2, 1453.3453255839745), (1, 423.8074093303452), (2, 3747.3138012803192), (1, 178.36611823361258), (2, 2885.6612386331763), (2, 650.359590992821), (2, 564.3009972254724), (1, 484.3250089075924), (1, 346.4622882608363), (1, 309.9638203285186), (2, 923.5172480555726), (2, 2098.4612331892195), (2, 1002.6421872697192), (2, 1373.438963361423), (3, 7929.463927319171)]
 Total time = 1259.7 seconds. Total time = 21.0 minutes.
