# Inference

This notebook shows the different models available in fishbAIT and how to use them efficiently for inference.

In [1]:
import torch
import pandas as pd
import numpy as np
from inference_utils.fishbAIT_for_inference import fishbAIT_for_inference

Specify the model version (model type and species group) and load the fine-tuned model.

In [2]:
MODEL_TYPE = 'EC50EC10'
SPECIES_GROUP = 'invertebrates'
MODEL_VERSION = f'{MODEL_TYPE}_{SPECIES_GROUP}'
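
The version string is the model type and the species group joined by an underscore. The `__help__` output further down lists nine models, split by endpoint type and species group. The helper below is a hypothetical sketch, not part of fishbAIT, that builds and validates such a string before loading. The identifiers `'fish'` and `'invertebrates'` come from this notebook; `'algae'` is an assumption.

```python
# Hypothetical helper, not part of fishbAIT: builds a model version string
# such as 'EC50EC10_invertebrates' and rejects unknown combinations.
MODEL_TYPES = ("EC50", "EC10", "EC50EC10")
# 'fish' and 'invertebrates' appear in this notebook; 'algae' is an assumption.
SPECIES_GROUPS = ("fish", "algae", "invertebrates")


def build_model_version(model_type: str, species_group: str) -> str:
    """Return '<model_type>_<species_group>' after validating both parts."""
    if model_type not in MODEL_TYPES:
        raise ValueError(f"Unknown model type {model_type!r}; expected one of {MODEL_TYPES}")
    if species_group not in SPECIES_GROUPS:
        raise ValueError(f"Unknown species group {species_group!r}; expected one of {SPECIES_GROUPS}")
    return f"{model_type}_{species_group}"


# 3 model types x 3 species groups = 9 version strings.
ALL_VERSIONS = [build_model_version(t, g) for g in SPECIES_GROUPS for t in MODEL_TYPES]

print(build_model_version("EC50EC10", "invertebrates"))  # EC50EC10_invertebrates
print(len(ALL_VERSIONS))  # 9
```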

In [3]:
fishbait = fishbAIT_for_inference(model_version=MODEL_VERSION)
fishbait.load_fine_tuned_model()

Load the SMILES of the chemicals you want to predict.

In [4]:
data = pd.read_excel('../data/tutorials/Inference_example_2.xlsx')
data

| Unnamed: 0 | SMILES | cmpdname |
|---|---|---|
| 0 | `CC(=O)Oc1ccccc1C(O)=O` | Aspirin |
| 1 | `[Cr]` | Chromium |
| 2 | `[H+].[Cl-].CNCCC(Oc1ccc(cc1)C(F)(F)F)c2ccccc2` | Fluoxetine hydrochloride |
| 3 | `Clc1ccc(cc1)C(c2ccc(Cl)cc2)C(Cl)(Cl)Cl` | Clofenotane |
| 4 | `[Cu]` | Copper |
| ... | ... | ... |
| 995 | `[Pb++].[O-]c1c(cc(c([O-])c1[N+]([O-])=O)[N+]([...` | Lead styphnate |
| 996 | `CC(C)(C)C(O)(CCc1ccc(Cl)cc1)Cn2cncn2` | Tebuconazole |
| 997 | `[Na+].[Na+].[Na+].[Na+].OCCN(CCO)c1nc(Nc2ccc(c...` | OpticalBrightenerBbu220 |
| 998 | `CNC.OC(=O)COc1ccc(Cl)cc1Cl` | 2,4-D dimethylamine salt |


Specify the endpoint, effect and exposure duration you want to predict, then make the prediction.

In [5]:
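PREDICTION_ENDPOINT = 'EC10'    # toxicity endpoint, e.g. 'EC50' or 'EC10'
PREDICTION_EFFECT = 'MOR'       # effect code, e.g. 'MOR' (mortality)
EXPOSURE_DURATION = 96          # exposure duration, presumably in hours
SMILES_COLUMN_NAME = 'SMILES'   # column in `data` that holds the SMILES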
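
The `__help__` text further down states that the EC10 models are trained on pooled EC10/NOEC data and lists the effect codes MOR, ITX, DVP, REP, MPH, GRO and POP, while the prediction log reports `Renamed NOEC *EC10*`. The sketch below is a hypothetical pre-check that mirrors this by mapping NOEC to EC10 and rejecting unknown codes. The function name and rules are assumptions, not fishbAIT code, and it does not check which effects each model type actually supports.

```python
# Hypothetical pre-check, not part of fishbAIT: normalizes the endpoint and
# effect strings before they are passed to predict_toxicity.

# Effect abbreviations as listed in the __help__() output.
EFFECT_CODES = {
    "MOR": "mortality",
    "ITX": "intoxication",
    "DVP": "development",
    "REP": "reproduction",
    "MPH": "morphology",
    "GRO": "growth",
    "POP": "population",
}
ENDPOINTS = {"EC50", "EC10"}
# Assumption: NOEC is pooled with EC10, as the help text and the
# "Renamed NOEC" log line suggest.
ENDPOINT_ALIASES = {"NOEC": "EC10"}


def normalize_request(endpoint: str, effect: str) -> tuple[str, str]:
    """Upper-case both codes, map endpoint aliases and validate the result."""
    endpoint = endpoint.upper()
    endpoint = ENDPOINT_ALIASES.get(endpoint, endpoint)
    effect = effect.upper()
    if endpoint not in ENDPOINTS:
        raise ValueError(f"Unsupported endpoint {endpoint!r}; expected one of {sorted(ENDPOINTS)}")
    if effect not in EFFECT_CODES:
        raise ValueError(f"Unsupported effect {effect!r}; expected one of {sorted(EFFECT_CODES)}")
    return endpoint, effect


print(normalize_request("NOEC", "mor"))  # ('EC10', 'MOR')
```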

In [6]:
results = fishbait.predict_toxicity(
    SMILES=data[SMILES_COLUMN_NAME].iloc[0:10].tolist(),  # first 10 chemicals only
    exposure_duration=EXPOSURE_DURATION,
    endpoint=PREDICTION_ENDPOINT,
    effect=PREDICTION_EFFECT,
    return_cls_embeddings=True,  # needed for the similarity checks and the PCA plot
)
results

Renamed NOEC *EC10* in 0 positions


You're using a RobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
100%|██████████| 2/2 [00:00<00:00,  9.26it/s]


| Unnamed: 0 | SMILES | exposure_duration | endpoint | effect | SMILES_Canonical_RDKit | predictions log10(mg/L) | predictions (mg/L) | CLS_embeddings |
|---|---|---|---|---|---|---|---|---|
| 0 | `CC(=O)Oc1ccccc1C(O)=O` | 1.982271 | EC10 | MOR | `CC(=O)Oc1ccccc1C(=O)O` | 1.588708 | 38.788929 | [0.9782297015190125, -0.9878295660018921, 0.33... |
| 1 | `[Cr]` | 1.982271 | EC10 | MOR | `[Cr]` | 0.470526 | 2.954784 | [0.02584332972764969, -2.3276708126068115, 0.0... |
| 2 | `[H+].[Cl-].CNCCC(Oc1ccc(cc1)C(F)(F)F)c2ccccc2` | 1.982271 | EC10 | MOR | `CNCCC(Oc1ccc(C(F)(F)F)cc1)c1ccccc1.[Cl-].[H+]` | -1.28614 | 0.051744 | [-0.3345903158187866, -1.1471542119979858, 0.7... |
| 3 | `Clc1ccc(cc1)C(c2ccc(Cl)cc2)C(Cl)(Cl)Cl` | 1.982271 | EC10 | MOR | `Clc1ccc(C(c2ccc(Cl)cc2)C(Cl)(Cl)Cl)cc1` | -2.561309 | 0.002746 | [-1.8162728548049927, -0.1933540254831314, 1.6... |
| 4 | `[Cu]` | 1.982271 | EC10 | MOR | `[Cu]` | -1.399814 | 0.039828 | [-2.431044578552246, -0.8788691759109497, -0.1... |
| 5 | `CCNc1nc(Cl)nc(NC(C)C)n1` | 1.982271 | EC10 | MOR | `CCNc1nc(Cl)nc(NC(C)C)n1` | -0.075029 | 0.841339 | [0.38325634598731995, -1.7486392259597778, -0.... |
| 6 | `CN(C)C1=NC(=O)N(C2CCCCC2)C(=O)N1C` | 1.982271 | EC10 | MOR | `CN(C)c1nc(=O)n(C2CCCCC2)c(=O)n1C` | 0.572498 | 3.736786 | [0.08783436566591263, -1.0063296556472778, -0.... |
| 7 | `CC(Br)(CO)[N+]([O-])=O` | 1.982271 | EC10 | MOR | `CC(Br)(CO)[N+](=O)[O-]` | 1.217041 | 16.483162 | [0.16042600572109222, -1.1170353889465332, -0.... |
| 8 | `c1ccc2c(c1)c3cccc4cccc2c34` | 1.982271 | EC10 | MOR | `c1ccc2c(c1)-c1cccc3cccc-2c13` | -1.436595 | 0.036594 | [-2.0834171772003174, -0.95207279920578, 0.297... |
| 9 | `[Cl-].[Cl-].[Zn++]` | 1.982271 | EC10 | MOR | `[Cl-].[Cl-].[Zn+2]` | -0.267941 | 0.539584 | [-1.3194429874420166, -1.9545468091964722, 0.2... |
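
Two columns in this table are on a log10 scale. `exposure_duration` reads 1.982271 in every row, which equals log10(96), so the duration input appears to be log10-transformed before prediction; likewise, `predictions (mg/L)` is 10 raised to `predictions log10(mg/L)`. The short check below only does arithmetic on values copied from the table to confirm both relationships.

```python
import math

# Values copied from the results table above (rows 0 and 1).
exposure_input = 96            # EXPOSURE_DURATION passed to predict_toxicity
exposure_column = 1.982271     # value in the exposure_duration column
aspirin_log10, aspirin_mg_l = 1.588708, 38.788929
chromium_log10, chromium_mg_l = 0.470526, 2.954784

# exposure_duration matches log10 of the input duration.
print(round(math.log10(exposure_input), 6))  # 1.982271

# predictions (mg/L) is 10 ** predictions log10(mg/L), up to display rounding.
for log_value, mg_l in [(aspirin_log10, aspirin_mg_l), (chromium_log10, chromium_mg_l)]:
    back_transformed = 10 ** log_value
    print(f"{back_transformed:.4f} vs {mg_l:.4f}")
    assert math.isclose(back_transformed, mg_l, rel_tol=1e-5)
```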


If you run into trouble using the model, call the `__help__` method:

In [7]:
fishbait.__help__()


        This is a python class used to load and use the fine-tuned deep-learning model `fishbAIT` for environmental toxicity predictions in fish algae and aquatic invertebrates.
        The models have been trained on a large corpus of SMILES (chemical representations) on data collected from various sources.

        Currently there are nine models available for use. The models are divided by toxicity endpoint and by species group. The models are the following:
        **Fish**
        - `EC50_fish` The EC50 model is trained on EC50 mortality (MOR) data and is thus suitable for the prediction of said endpoints.
        - `EC10_fish` The EC10 model is trained on EC10/NOEC data with various effects (mortality, intoxication, development, reproduction, morphology, growth and population) ab. (MOR, ITX, DVP, REP, MPH, GRO, POP)
        - `EC50EC10_fish` The EC50EC10 model is trained on EC50, EC10 and NOEC data with various effects (mortality, intoxication, development, reproduction, morphol

# Check the predictions against our training sets

In [8]:
from inference_utils.plots_for_space import PlotPCA_CLSProjection
from inference_utils.pytorch_data_utils import check_closest_chemical, check_training_data

Check whether the chemicals are present in the model's training data. The function adds two indicator columns (1 = present, 0 = absent):
- `endpoint match`: the chemical was used to train this model for this species group and endpoint.
- `effect match`: the chemical was used to train this model for this species group, endpoint and effect.
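
Conceptually, both flags are look-ups of a chemical in the training records: the endpoint flag ignores the effect, while the effect flag requires it to match as well. The sketch below is a hypothetical illustration of that logic on a toy list of `(canonical SMILES, endpoint, effect)` records. It is not how `check_training_data` is implemented, and matching on canonical SMILES is an assumption.

```python
# Hypothetical sketch of the two training-data flags; not the fishbAIT code.
# Toy training records for one model and species group:
# (canonical SMILES, endpoint, effect).
TRAINING_RECORDS = [
    ("CC(=O)Oc1ccccc1C(=O)O", "EC10", "MOR"),
    ("CCO", "EC10", "GRO"),
    ("CCO", "EC50", "MOR"),
]


def training_flags(canonical_smiles: str, endpoint: str, effect: str) -> tuple[int, int]:
    """Return (endpoint match, effect match) as 1/0 integers."""
    endpoint_match = any(
        smiles == canonical_smiles and ep == endpoint
        for smiles, ep, _ in TRAINING_RECORDS
    )
    effect_match = any(
        smiles == canonical_smiles and ep == endpoint and eff == effect
        for smiles, ep, eff in TRAINING_RECORDS
    )
    return int(endpoint_match), int(effect_match)


print(training_flags("CC(=O)Oc1ccccc1C(=O)O", "EC10", "MOR"))   # (1, 1)
print(training_flags("CCO", "EC10", "MOR"))                     # (1, 0): EC10 data, but not for MOR
print(training_flags("CC(Br)(CO)[N+](=O)[O-]", "EC10", "MOR"))  # (0, 0)
```

By construction, an effect match implies an endpoint match.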

In [9]:
results = check_training_data(results, model_type=MODEL_TYPE, species_group=SPECIES_GROUP, endpoint=PREDICTION_ENDPOINT, effect=PREDICTION_EFFECT)
results

| Unnamed: 0 | SMILES | exposure_duration | endpoint | effect | SMILES_Canonical_RDKit | predictions log10(mg/L) | predictions (mg/L) | CLS_embeddings | endpoint match | effect match |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | `CC(=O)Oc1ccccc1C(O)=O` | 1.982271 | EC10 | MOR | `CC(=O)Oc1ccccc1C(=O)O` | 1.588708 | 38.788929 | [0.9782297015190125, -0.9878295660018921, 0.33... | 1 | 1 |
| 1 | `[Cr]` | 1.982271 | EC10 | MOR | `[Cr]` | 0.470526 | 2.954784 | [0.02584332972764969, -2.3276708126068115, 0.0... | 1 | 1 |
| 2 | `[H+].[Cl-].CNCCC(Oc1ccc(cc1)C(F)(F)F)c2ccccc2` | 1.982271 | EC10 | MOR | `CNCCC(Oc1ccc(C(F)(F)F)cc1)c1ccccc1.[Cl-].[H+]` | -1.28614 | 0.051744 | [-0.3345903158187866, -1.1471542119979858, 0.7... | 1 | 1 |
| 3 | `Clc1ccc(cc1)C(c2ccc(Cl)cc2)C(Cl)(Cl)Cl` | 1.982271 | EC10 | MOR | `Clc1ccc(C(c2ccc(Cl)cc2)C(Cl)(Cl)Cl)cc1` | -2.561309 | 0.002746 | [-1.8162728548049927, -0.1933540254831314, 1.6... | 1 | 1 |
| 4 | `[Cu]` | 1.982271 | EC10 | MOR | `[Cu]` | -1.399814 | 0.039828 | [-2.431044578552246, -0.8788691759109497, -0.1... | 1 | 1 |
| 5 | `CCNc1nc(Cl)nc(NC(C)C)n1` | 1.982271 | EC10 | MOR | `CCNc1nc(Cl)nc(NC(C)C)n1` | -0.075029 | 0.841339 | [0.38325634598731995, -1.7486392259597778, -0.... | 1 | 1 |
| 6 | `CN(C)C1=NC(=O)N(C2CCCCC2)C(=O)N1C` | 1.982271 | EC10 | MOR | `CN(C)c1nc(=O)n(C2CCCCC2)c(=O)n1C` | 0.572498 | 3.736786 | [0.08783436566591263, -1.0063296556472778, -0.... | 1 | 1 |
| 7 | `CC(Br)(CO)[N+]([O-])=O` | 1.982271 | EC10 | MOR | `CC(Br)(CO)[N+](=O)[O-]` | 1.217041 | 16.483162 | [0.16042600572109222, -1.1170353889465332, -0.... | 0 | 0 |
| 8 | `c1ccc2c(c1)c3cccc4cccc2c34` | 1.982271 | EC10 | MOR | `c1ccc2c(c1)-c1cccc3cccc-2c13` | -1.436595 | 0.036594 | [-2.0834171772003174, -0.95207279920578, 0.297... | 1 | 1 |
| 9 | `[Cl-].[Cl-].[Zn++]` | 1.982271 | EC10 | MOR | `[Cl-].[Cl-].[Zn+2]` | -0.267941 | 0.539584 | [-1.3194429874420166, -1.9545468091964722, 0.2... | 1 | 1 |


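Next, we find the training-set chemical closest to each predicted chemical by comparing its CLS embedding with the CLS embeddings of the training set using cosine similarity. The result is added as the columns `most similar chemical` and `cosine similarity`:
- cosine similarity = 1 --> identical embeddings, e.g. identical structures
- cosine similarity = -1 --> embeddings pointing in opposite directions, i.e. completely opposite in terms of toxicity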
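
A minimal sketch of the nearest-neighbour search, assuming each chemical is represented by one CLS embedding vector and the training set is a mapping from a name to its embedding. The function names and toy vectors are illustrative and not taken from `check_closest_chemical`.

```python
import math

# Minimal sketch, not fishbAIT code: nearest training chemical by cosine
# similarity of CLS embeddings. Zero vectors are not handled.


def cosine_similarity(a: list[float], b: list[float]) -> float:
    """cos(a, b) = a . b / (|a| |b|), a value between -1 and 1."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def most_similar(query: list[float], training: dict[str, list[float]]) -> tuple[str, float]:
    """Return (name, similarity) of the training embedding closest to query."""
    scored = ((name, cosine_similarity(query, emb)) for name, emb in training.items())
    return max(scored, key=lambda pair: pair[1])


# Toy 3-dimensional embeddings with placeholder names.
training = {
    "chem_A": [1.0, 0.5, 0.0],
    "chem_B": [0.0, 1.0, 1.0],
    "chem_C": [-1.0, -0.5, 0.0],
}
best, score = most_similar([2.0, 1.0, 0.0], training)
print(best, round(score, 6))  # chem_A 1.0
```

Because cosine similarity ignores vector length, the query `[2.0, 1.0, 0.0]` scores 1.0 against `chem_A` although the two vectors differ in magnitude, while `chem_C`, which points the opposite way, scores -1.0.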

In [10]:
results = check_closest_chemical(results=results, MODELTYPE=MODEL_TYPE, PREDICTION_SPECIES=SPECIES_GROUP, PREDICTION_ENDPOINT=PREDICTION_ENDPOINT, PREDICTION_EFFECT=PREDICTION_EFFECT)
results

| Unnamed: 0 | SMILES | exposure_duration | endpoint | effect | SMILES_Canonical_RDKit | predictions log10(mg/L) | predictions (mg/L) | CLS_embeddings | endpoint match | effect match | most similar chemical | cosine similarity |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | `CC(=O)Oc1ccccc1C(O)=O` | 1.982271 | EC10 | MOR | `CC(=O)Oc1ccccc1C(=O)O` | 1.588708 | 38.788929 | [0.9782297015190125, -0.9878295660018921, 0.33... | 1 | 1 | `CC(=O)Oc1ccccc1C(=O)O` | 0.999999 |
| 1 | `[Cr]` | 1.982271 | EC10 | MOR | `[Cr]` | 0.470526 | 2.954784 | [0.02584332972764969, -2.3276708126068115, 0.0... | 1 | 1 | `[Cr]` | 1.0 |
| 2 | `[H+].[Cl-].CNCCC(Oc1ccc(cc1)C(F)(F)F)c2ccccc2` | 1.982271 | EC10 | MOR | `CNCCC(Oc1ccc(C(F)(F)F)cc1)c1ccccc1.[Cl-].[H+]` | -1.28614 | 0.051744 | [-0.3345903158187866, -1.1471542119979858, 0.7... | 1 | 1 | `CNCCC(Oc1ccc(C(F)(F)F)cc1)c1ccccc1.[Cl-].[H+]` | 1.0 |
| 3 | `Clc1ccc(cc1)C(c2ccc(Cl)cc2)C(Cl)(Cl)Cl` | 1.982271 | EC10 | MOR | `Clc1ccc(C(c2ccc(Cl)cc2)C(Cl)(Cl)Cl)cc1` | -2.561309 | 0.002746 | [-1.8162728548049927, -0.1933540254831314, 1.6... | 1 | 1 | `Clc1ccc(C(c2ccc(Cl)cc2)C(Cl)(Cl)Cl)cc1` | 1.0 |
| 4 | `[Cu]` | 1.982271 | EC10 | MOR | `[Cu]` | -1.399814 | 0.039828 | [-2.431044578552246, -0.8788691759109497, -0.1... | 1 | 1 | `[Cu]` | 1.0 |
| 5 | `CCNc1nc(Cl)nc(NC(C)C)n1` | 1.982271 | EC10 | MOR | `CCNc1nc(Cl)nc(NC(C)C)n1` | -0.075029 | 0.841339 | [0.38325634598731995, -1.7486392259597778, -0.... | 1 | 1 | `CCNc1nc(Cl)nc(NC(C)C)n1` | 1.0 |
| 6 | `CN(C)C1=NC(=O)N(C2CCCCC2)C(=O)N1C` | 1.982271 | EC10 | MOR | `CN(C)c1nc(=O)n(C2CCCCC2)c(=O)n1C` | 0.572498 | 3.736786 | [0.08783436566591263, -1.0063296556472778, -0.... | 1 | 1 | `CN(C)c1nc(=O)n(C2CCCCC2)c(=O)n1C` | 1.0 |
| 7 | `CC(Br)(CO)[N+]([O-])=O` | 1.982271 | EC10 | MOR | `CC(Br)(CO)[N+](=O)[O-]` | 1.217041 | 16.483162 | [0.16042600572109222, -1.1170353889465332, -0.... | 0 | 0 | `[In+3].[In+3].[O-2].[O-2].[O-2]` | 0.903284 |
| 8 | `c1ccc2c(c1)c3cccc4cccc2c34` | 1.982271 | EC10 | MOR | `c1ccc2c(c1)-c1cccc3cccc-2c13` | -1.436595 | 0.036594 | [-2.0834171772003174, -0.95207279920578, 0.297... | 1 | 1 | `c1ccc2c(c1)-c1cccc3cccc-2c13` | 1.0 |
| 9 | `[Cl-].[Cl-].[Zn++]` | 1.982271 | EC10 | MOR | `[Cl-].[Cl-].[Zn+2]` | -0.267941 | 0.539584 | [-1.3194429874420166, -1.9545468091964722, 0.2... | 1 | 1 | `[Cl-].[Cl-].[Zn+2]` | 1.0 |


# Plot PCA for all chemicals

Finally, we can plot the chemical space learned by the Transformer module of the model during training. The space is built from the CLS embeddings of the model's training set, and new chemicals can be projected onto it. This example uses `show_all_predictions=True`, which additionally plots SMILES that are not part of the training data to aid interpretation. We also pass `inference_df=results` to project the chemicals predicted above into the space; set it to `None` if this is not desired.
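
The projection idea can be sketched with plain NumPy: fit principal components on the training-set CLS embeddings only, then map new embeddings into that fixed space. This is a generic PCA-via-SVD sketch on random toy data, assuming the space is a two-component PCA as the name `PlotPCA_CLSProjection` suggests; it is not the implementation behind that function.

```python
import numpy as np


def fit_pca(train_embeddings: np.ndarray, n_components: int = 2) -> tuple[np.ndarray, np.ndarray]:
    """Fit PCA on training embeddings; return (mean, components)."""
    mean = train_embeddings.mean(axis=0)
    centered = train_embeddings - mean
    # Rows of vt are the principal axes, ordered by decreasing singular value.
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return mean, vt[:n_components]


def project(embeddings: np.ndarray, mean: np.ndarray, components: np.ndarray) -> np.ndarray:
    """Project embeddings onto the fitted axes (one column per component)."""
    return (embeddings - mean) @ components.T


rng = np.random.default_rng(0)
train = rng.normal(size=(200, 16))  # toy stand-in for training CLS embeddings
new = rng.normal(size=(10, 16))     # toy stand-in for embeddings from inference

mean, components = fit_pca(train)
train_2d = project(train, mean, components)
new_2d = project(new, mean, components)  # new chemicals land in the same space
print(train_2d.shape, new_2d.shape)  # (200, 2) (10, 2)
```

Fitting on the training embeddings only keeps the axes fixed, so projecting new chemicals does not move the training chemicals in the plot.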

The plot can be saved as an interactive HTML file with `fig.write_html('figurename.html')`.

The hover text of each point includes the L1Error obtained for that chemical when the model was evaluated in our 10x10-fold cross-validation.
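
Because the model predicts log10(mg/L), the L1 error is presumably measured in log10 units, so it translates directly into a fold difference between predicted and observed concentrations. The sketch below assumes the per-chemical value is the mean absolute error across the cross-validation repetitions in which that chemical was held out; the helper names and numbers are illustrative toy values.

```python
# Illustrative only: per-chemical L1 error in log10 units and its
# fold-difference interpretation.


def mean_l1_error(predictions_log10: list[float], observed_log10: float) -> float:
    """Mean absolute error between repeated held-out predictions and one observation."""
    return sum(abs(p - observed_log10) for p in predictions_log10) / len(predictions_log10)


def fold_difference(l1_error_log10: float) -> float:
    """An error of e log10 units means the concentrations differ by a factor of 10**e."""
    return 10 ** l1_error_log10


# Toy example: three held-out predictions for a chemical observed at log10 = 1.0.
err = mean_l1_error([1.2, 0.9, 1.3], 1.0)
print(round(err, 3), round(fold_difference(err), 2))  # 0.2 1.58
```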

In [11]:
from inference_utils.plots_for_space import PlotPCA_CLSProjection

In [12]:
PlotPCA_CLSProjection(
    model_type=MODEL_TYPE,
    endpoint=PREDICTION_ENDPOINT,
    effect=PREDICTION_EFFECT,
    species_group=SPECIES_GROUP,
    show_all_predictions=True,
    inference_df=results,
)

