### Introduction
In this notebook, we demonstrate how to use Thompson Sampling with a machine learning (ML) model as the objective function. The notebook has two parts:
1. Build a classification model for the MAPK1 dataset from [LIT-PCBA](https://pubs.acs.org/doi/10.1021/acs.jcim.0c00155) and save it to disk.
2. Use the ML model as an objective in Thompson Sampling.
  
Note that you don't have to run Part 1 to run Part 2.  Part 2 uses a stored model, which is provided in the repository.

In [6]:
import pandas as pd
import useful_rdkit_utils as uru
from lightgbm import LGBMClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
import numpy as np
import joblib

### Part 1: Build an ML model for the MAPK1 dataset
Read the data

In [7]:
df = pd.read_csv("data/MAPK1.csv")

Add a fingerprint for each SMILES to the dataframe (one NumPy array per molecule, computed by `uru.smi2numpy_fp`)

In [8]:
df['fp'] = df.SMILES.apply(uru.smi2numpy_fp)

Split into training and test sets (by default, `train_test_split` holds out 25% of the rows for testing)

In [9]:
train, test = train_test_split(df)
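
LIT-PCBA targets are heavily imbalanced (far more inactives than actives), so a purely random split can leave the test set with very few actives. A minimal sketch, assuming a 0/1 `active` column as used above, of passing `stratify` (plus a fixed `random_state` for reproducibility) to `train_test_split`. The data here is synthetic, and the `toy_*` names are chosen so they do not overwrite `train` and `test`:

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

# Synthetic stand-in for the MAPK1 dataframe: 1000 molecules, 2% active
toy = pd.DataFrame({"active": np.r_[np.ones(20, dtype=int), np.zeros(980, dtype=int)]})

# stratify keeps the active/inactive ratio the same in both splits;
# random_state makes the split reproducible between runs
toy_train, toy_test = train_test_split(
    toy, test_size=0.25, stratify=toy["active"], random_state=42
)

print(len(toy_train), len(toy_test))                          # 750 250
print(toy_train["active"].sum(), toy_test["active"].sum())    # 15 5
```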

Instantiate a classifier

In [10]:
cls = LGBMClassifier()

Train the classifier

In [11]:
cls.fit(np.stack(train.fp), train.active)
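
`cls.fit` expects a 2-D feature matrix, but the `fp` column holds one 1-D NumPy array per row. `np.stack` turns that column into an array of shape `(n_molecules, n_bits)`. A small sketch with made-up 8-bit vectors (the real fingerprints are longer; their length depends on the defaults of `uru.smi2numpy_fp`):

```python
import numpy as np
import pandas as pd

# Hypothetical 8-bit fingerprints standing in for the output of uru.smi2numpy_fp
toy_fp = pd.Series([
    np.array([0, 1, 0, 0, 1, 0, 0, 1]),
    np.array([1, 0, 0, 1, 0, 0, 1, 0]),
    np.array([0, 0, 1, 0, 0, 1, 0, 0]),
])

X_toy = np.stack(toy_fp)   # one row per molecule, one column per bit
print(X_toy.shape)         # (3, 8)
```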

Predict the test set

In [None]:
prob = cls.predict_proba(np.stack(test.fp))
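
`predict_proba` returns one column per class, ordered as in the fitted model's `classes_` attribute, so `prob[:,1]` is the probability of being active only when the classes sort as `[0, 1]` (or `[False, True]`). A sketch that looks the column up explicitly; `LogisticRegression` stands in for `LGBMClassifier` here, since both follow the scikit-learn classifier API:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

# Toy data; LogisticRegression stands in for LGBMClassifier (same classes_/predict_proba API)
X_toy = np.array([[0.0], [1.0], [2.0], [3.0]])
y_toy = np.array([0, 0, 1, 1])
toy_model = LogisticRegression().fit(X_toy, y_toy)

toy_prob = toy_model.predict_proba(X_toy)      # shape (n_samples, n_classes)
pos_col = list(toy_model.classes_).index(1)    # column that holds P(active)
p_active = toy_prob[:, pos_col]

print(toy_model.classes_)   # [0 1]
print(pos_col)              # 1
```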

Calculate the ROC AUC

In [None]:
roc_auc_score(test.active, prob[:,1])
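
ROC AUC is the probability that a randomly chosen active is scored higher than a randomly chosen inactive, with ties counting as one half. A brute-force sketch of that definition, checked against `roc_auc_score` on a four-point toy example (it compares every active/inactive pair, so it is meant for illustration rather than for very large sets):

```python
import numpy as np
from sklearn.metrics import roc_auc_score

def pairwise_auc(y_true, y_score):
    """Fraction of (active, inactive) pairs ranked correctly; ties count 0.5."""
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score)
    pos = y_score[y_true == 1]
    neg = y_score[y_true == 0]
    # Compare every active score with every inactive score via broadcasting
    diff = pos[:, None] - neg[None, :]
    return ((diff > 0).sum() + 0.5 * (diff == 0).sum()) / diff.size

y_toy = [0, 0, 1, 1]
s_toy = [0.1, 0.4, 0.35, 0.8]
print(pairwise_auc(y_toy, s_toy))   # 0.75
print(roc_auc_score(y_toy, s_toy))  # 0.75
```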

Save the model

In [None]:
joblib.dump(cls, 'mapk1_modl.pkl')

Read the model from disk

In [None]:
cls_pickle = joblib.load('mapk1_modl.pkl')

Predict the test set with the reloaded model

In [None]:
pred_pickle = cls_pickle.predict_proba(np.stack(test.fp))

Calculate the ROC AUC for the reloaded model; it should match the value computed above

In [None]:
roc_auc_score(test.active, pred_pickle[:,1])
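
A stricter check than comparing AUCs is to confirm that the reloaded model returns exactly the same probabilities as the original. A self-contained sketch of that round trip, using `LogisticRegression` as a stand-in for `LGBMClassifier` and a temporary file instead of `mapk1_modl.pkl`. As with any pickle, only load model files from trusted sources, and expect to need compatible library versions to load them:

```python
import os
import tempfile

import joblib
import numpy as np
from sklearn.linear_model import LogisticRegression

X_toy = np.array([[0.0], [1.0], [2.0], [3.0]])
y_toy = np.array([0, 0, 1, 1])
model = LogisticRegression().fit(X_toy, y_toy)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.pkl")
    joblib.dump(model, path)      # serialize the fitted model
    reloaded = joblib.load(path)  # read it back

# Identical predictions confirm nothing was lost in serialization
same = np.array_equal(model.predict_proba(X_toy), reloaded.predict_proba(X_toy))
print(same)  # True
```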

### Part 2: Use the ML model as an objective in Thompson Sampling

In [1]:
from ts_main import read_input, run_ts, parse_input_dict
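
Conceptually, using a classifier as the Thompson Sampling objective means that each enumerated product is scored by the model's predicted probability of being active. A minimal sketch of such a scoring wrapper; the class `ProbabilityScorer`, its `evaluate` method, the toy model and the placeholder featurizer are all hypothetical and are not the repository's evaluator API:

```python
import numpy as np


class ProbabilityScorer:
    """Hypothetical wrapper: score = P(active) from a fitted classifier."""

    def __init__(self, model, featurize):
        self.model = model          # any object with predict_proba (e.g. a loaded LGBMClassifier)
        self.featurize = featurize  # callable: SMILES -> 1-D feature vector

    def evaluate(self, smiles):
        fp = self.featurize(smiles)
        # predict_proba expects a 2-D array, so wrap the single vector in a batch of one
        return float(self.model.predict_proba(fp.reshape(1, -1))[0, 1])


class ConstantModel:
    """Toy stand-in model that always predicts P(active) = 0.8."""

    def predict_proba(self, X):
        return np.tile([0.2, 0.8], (len(X), 1))


def toy_featurize(smiles):
    # Placeholder featurizer; a real one would compute a fingerprint from the SMILES
    return np.zeros(4)


scorer = ProbabilityScorer(ConstantModel(), toy_featurize)
print(scorer.evaluate("c1ccccc1"))  # 0.8
```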

Read and process the input JSON file 

In [15]:
ts_input_dict = read_input('examples/quinazoline_classification_model.json')

Run Thompson Sampling

In [16]:
score_df = run_ts(ts_input_dict)

         score                                             SMILES  \
4404  0.860390  CNS(=O)(=O)CCn1c([C@@H]2CN=C(N)NC2)nc2nc3ccccc...   
20    0.763437  NC1=NC[C@@H](c2nc3nc4ccccc4cc3c(=O)n2-c2ncn(CC...   
8212  0.733015  NC1=N[C@@H](c2nc3nc4ccccc4cc3c(=O)n2C23CNC(C(=...   
8213  0.733015  NC1=N[C@H](c2nc3nc4ccccc4cc3c(=O)n2C23CNC(C(=O...   
25    0.711902  NC1=N[C@H](c2nc3nc4ccccc4cc3c(=O)n2-c2ncn(CC(=...   
24    0.711902  NC1=N[C@@H](c2nc3nc4ccccc4cc3c(=O)n2-c2ncn(CC(...   
8211  0.709294  NC1=NC[C@@H](c2nc3nc4ccccc4cc3c(=O)n2C23CNC(C(...   
4405  0.684755  CNS(=O)(=O)CCn1c([C@@H]2N=C(N)NN2)nc2nc3ccccc3...   
4406  0.684755  CNS(=O)(=O)CCn1c([C@H]2N=C(N)NN2)nc2nc3ccccc3c...   
8214  0.652122  CN1C(=O)CC[C@@H]1c1nc2nc3ccccc3cc2c(=O)n1C12CN...   

                            Name  
4404   1723442_12416971_38415755  
20      1723442_2537440_38415755  
8212   1723442_575421981_6862752  
8213  1723442_575421981_16137394  
25      1723442_2537440_16137394  
24       1723442_2537440_68627
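
`run_ts` performs the search itself. As a rough mental model only, here is a simplified sketch of Thompson Sampling over a two-component reagent library: each reagent keeps a Gaussian belief about its contribution to the score, a warm-up phase scores a few random products per reagent, and each cycle draws one sample per reagent, combines the best draw from each component, scores that product, and updates the beliefs of the reagents it used. The additive toy objective, the unit-noise Gaussian update and every name below are assumptions for illustration; the repository's implementation may differ in its priors, update rule and bookkeeping.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy two-component library; the hidden "true" product score is additive
true_a = rng.normal(size=20)   # contributions of component-A reagents
true_b = rng.normal(size=30)   # contributions of component-B reagents

def objective(i, j):
    """Stand-in for the ML model score of product (A_i, B_j)."""
    return true_a[i] + true_b[j]

sizes = [len(true_a), len(true_b)]
sums = [np.zeros(n) for n in sizes]     # running sum of scores per reagent
counts = [np.zeros(n) for n in sizes]   # number of scored products per reagent
scored = {}                             # (i, j) -> score for every product evaluated

def score_and_update(i, j):
    s = objective(i, j)
    scored[(i, j)] = s
    for comp, reagent in enumerate((i, j)):
        sums[comp][reagent] += s
        counts[comp][reagent] += 1

# Warm-up: pair every reagent with 3 random partners so each has a belief
for comp, n in enumerate(sizes):
    for r in range(n):
        for _ in range(3):
            i = r if comp == 0 else int(rng.integers(sizes[0]))
            j = r if comp == 1 else int(rng.integers(sizes[1]))
            score_and_update(i, j)

# Search: sample each reagent's mean from its Gaussian posterior
# (assumed unit noise, so std = 1/sqrt(count)) and take the best draw per component
for cycle in range(500):
    i, j = (int(np.argmax(rng.normal(sums[c] / counts[c], 1.0 / np.sqrt(counts[c]))))
            for c in range(2))
    score_and_update(i, j)

best = max(scored, key=scored.get)
print(best, scored[best])
```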

In [14]:
score_df.sort_values("score", ascending=False).head(10)

| | score | SMILES | Name |
|---|---|---|---|
| 7360 | 0.78741 | O=C(NC1NCCN1)n1c(-c2n[nH]c(O)n2)nc2c(cnc3ccccc... | 14982158_38864359_4343380 |
| 7357 | 0.785511 | O=C(NC1NCCN1)n1c(-c2nn[nH]n2)nc2c(cnc3ccccc32)... | 14982158_38864359_6505186 |
| 7359 | 0.751079 | Nc1cnc(-c2nc3c(cnc4ccccc43)c(=O)n2C(=O)NC2NCCN... | 14982158_38864359_26896534 |
| 4005 | 0.746542 | O=C1CCC(C(=O)n2c(-c3n[nH]c(O)n3)nc3c(cnc4ccccc... | 14982158_32016232_4343380 |
| 4003 | 0.744397 | O=C1CCC(C(=O)n2c(-c3nn[nH]n3)nc3c(cnc4ccccc43)... | 14982158_32016232_6505186 |
| 4004 | 0.705836 | Nc1cnc(-c2nc3c(cnc4ccccc43)c(=O)n2C(=O)C2=NNC(... | 14982158_32016232_26896534 |
| 2478 | 0.66122 | O=C(n1c(-c2ncc(O)cn2)nc2c(cnc3ccccc32)c1=O)[C@... | 14982158_96034624_39052256 |
| 5065 | 0.66122 | O=C(n1c(-c2ncc(O)cn2)nc2c(cnc3ccccc32)c1=O)C1(... | 14982158_140060494_39052256 |
| 4548 | 0.66122 | O=C(n1c(-c2ncc(O)cn2)nc2c(cnc3ccccc32)c1=O)[C@... | 14982158_72417683_39052256 |
| 332 | 0.66122 | O=C(n1c(-c2ncc(O)cn2)nc2c(cnc3ccccc32)c1=O)[C@... | 14982158_96034625_39052256 |
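
If different reagent combinations can produce the same product SMILES, the top of the ranking may contain repeats. A short sketch of keeping only the best-scoring row per unique SMILES before taking the top N; the column names mirror `score_df`, but the values are made up:

```python
import pandas as pd

# Made-up scores; column names mirror score_df
toy_scores = pd.DataFrame({
    "score": [0.9, 0.7, 0.9, 0.4, 0.8],
    "SMILES": ["CCO", "CCN", "CCO", "CCC", "c1ccccc1"],
    "Name": ["r1_a", "r1_b", "r2_a", "r2_b", "r3_a"],
})

# Sort best-first, keep the first (best) row for each SMILES, then take the top 3
top = (toy_scores.sort_values("score", ascending=False)
                 .drop_duplicates("SMILES")
                 .head(3))
print(top["SMILES"].tolist())  # ['CCO', 'c1ccccc1', 'CCN']
```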
