# granite.materials.smi-TED - INFERENCE (Classification)

In [None]:
# Install extra packages for this notebook (xgboost provides the XGBClassifier used below).
# NOTE(review): versions are unpinned, so results may drift between installs — consider pinning them.
%pip install seaborn xgboost

In [1]:
import sys

# The `smi_ted_light` package lives under ../inference; put that directory on the
# import path so the loader can be imported in the next cell.
INFERENCE_DIR = '../inference'
sys.path.append(INFERENCE_DIR)

In [2]:
# materials.smi-ted model loader (importable because ../inference was appended to sys.path)
from smi_ted_light.load import load_smi_ted

# Data
import torch
import pandas as pd

# Chemistry
from rdkit import Chem
from rdkit.Chem import PandasTools
from rdkit.Chem import Descriptors  # NOTE(review): not used in the cells shown here
# RDKit PandasTools display setting: render molecule columns as structure images in DataFrame output.
PandasTools.RenderImagesInAllDataFrames(True)

In [3]:
# function to canonicalize SMILES
def normalize_smiles(smi, canonical=True, isomeric=False):
    """Return a normalized SMILES string, or ``None`` if ``smi`` cannot be processed.

    Parameters
    ----------
    smi : str
        Input SMILES string.
    canonical : bool, default True
        Forwarded to ``Chem.MolToSmiles(canonical=...)``.
    isomeric : bool, default False
        Forwarded to ``Chem.MolToSmiles(isomericSmiles=...)``; with the default
        ``False``, stereo markers such as ``[C@@H]`` are not written to the output.

    Returns
    -------
    str or None
        ``None`` marks a molecule RDKit could not parse or serialize, so callers
        can remove such rows with ``dropna``.
    """
    try:
        mol = Chem.MolFromSmiles(smi)
        if mol is None:
            # RDKit returns None (and logs the reason) for unparsable SMILES.
            return None
        return Chem.MolToSmiles(mol, canonical=canonical, isomericSmiles=isomeric)
    except Exception:
        # Best-effort: e.g. non-string input (NaN read from a CSV) makes RDKit raise.
        # Catching Exception (not a bare except) lets KeyboardInterrupt/SystemExit through.
        return None

### Load smi-ted

In [None]:
# Load the smi-ted Light model from the local inference folder.
SMI_TED_FOLDER = '../inference/smi_ted_light'
SMI_TED_CKPT = 'smi-ted-Light_40.pt'
model_smi_ted = load_smi_ted(folder=SMI_TED_FOLDER, ckpt_filename=SMI_TED_CKPT)

## BBBP Dataset

### Experiments - Data Loading

In [5]:
# BBBP train/test splits from the MoleculeNet fine-tuning data (train is read first, then test).
BBBP_DIR = "../finetune/moleculenet/bbbp"
df_train, df_test = (pd.read_csv(f"{BBBP_DIR}/{split}.csv") for split in ("train", "test"))

### SMILES canonicalization

In [6]:
# Canonicalize SMILES; normalize_smiles returns None for molecules RDKit cannot parse.
df_train['norm_smiles'] = df_train['smiles'].apply(normalize_smiles)
# Drop only rows that are unusable for training (failed SMILES or missing label).
# A bare dropna() would also discard valid molecules that merely lack an unrelated
# field such as `name`.
df_train_normalized = df_train.dropna(subset=['norm_smiles', 'p_np'])
print(df_train_normalized.shape)
df_train_normalized.head()

[10:35:10] Explicit valence for atom # 1 N, 4, is greater than permitted
[10:35:10] Explicit valence for atom # 6 N, 4, is greater than permitted
[10:35:10] Explicit valence for atom # 6 N, 4, is greater than permitted
[10:35:10] Explicit valence for atom # 11 N, 4, is greater than permitted
[10:35:10] Explicit valence for atom # 5 N, 4, is greater than permitted


Unnamed: 0,num,name,p_np,smiles,norm_smiles
0,1,Propanolol,1,[Cl].CC(C)NCC(O)COc1cccc2ccccc12,CC(C)NCC(O)COc1cccc2ccccc12.[Cl]
1,2,Terbutylchlorambucil,1,C(=O)(OC(C)(C)C)CCCc1ccc(cc1)N(CCCl)CCCl,CC(C)(C)OC(=O)CCCc1ccc(N(CCCl)CCCl)cc1
2,3,40730,1,c12c3c(N4CCN(C)CC4)c(F)cc1c(c(C(O)=O)cn2C(C)CO...,CC1COc2c(N3CCN(C)CC3)c(F)cc3c(=O)c(C(=O)O)cn1c23
3,4,24,1,C1CCN(CC1)Cc1cccc(c1)OCCCNC(=O)C,CC(=O)NCCCOc1cccc(CN2CCCCC2)c1
4,6,cefoperazone,1,CCN1CCN(C(=O)N[C@@H](C(=O)N[C@H]2[C@H]3SCC(=C(...,CCN1CCN(C(=O)NC(C(=O)NC2C(=O)N3C(C(=O)O)=C(CSc...


In [7]:
# Canonicalize SMILES; normalize_smiles returns None for molecules RDKit cannot parse.
df_test['norm_smiles'] = df_test['smiles'].apply(normalize_smiles)
# Drop only rows that are unusable for evaluation (failed SMILES or missing label).
# A bare dropna() would also discard valid molecules that merely lack an unrelated
# field such as `name`.
df_test_normalized = df_test.dropna(subset=['norm_smiles', 'p_np'])
print(df_test_normalized.shape)
df_test_normalized.head()

[10:35:10] Explicit valence for atom # 12 N, 4, is greater than permitted
[10:35:10] Explicit valence for atom # 5 N, 4, is greater than permitted


Unnamed: 0,num,name,p_np,smiles,norm_smiles
0,13,18,1,C(Cl)Cl,ClCCl
1,23,SKF-93619,0,c1cc2c(cc(CC3=CNC(=NC3=O)NCCSCc3oc(cc3)CN(C)C)...,CN(C)Cc1ccc(CSCCNc2nc(=O)c(Cc3ccc4ccccc4c3)c[n...
2,36,etomidate,1,CCOC(=O)c1cncn1C(C)c2ccccc2,CCOC(=O)c1cncn1C(C)c1ccccc1
3,37,11a,0,CN(C)c1cc(C2=NC(N)=NN2)ccn1,CN(C)c1cc(-c2nc(N)n[nH]2)ccn1
4,79,compound 45,1,N1(Cc2cc(OCCCNc3oc4ccccc4n3)ccc2)CCCCC1,c1cc(CN2CCCCC2)cc(OCCCNc2nc3ccccc3o2)c1


### Embedding extraction

#### smi-ted embeddings extraction

In [8]:
# Encode the canonical training SMILES with smi-ted; no gradients are needed for inference.
train_smiles = df_train_normalized['norm_smiles']
with torch.no_grad():
    df_embeddings_train = model_smi_ted.encode(train_smiles)
df_embeddings_train.head()

100%|██████████████████████████████████████████████████████████████████████████████████| 51/51 [01:42<00:00,  2.01s/it]


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,758,759,760,761,762,763,764,765,766,767
0,0.437216,-0.591722,0.064327,0.37402,0.530674,-0.644061,1.308132,0.089766,0.79052,0.208745,...,-1.325166,-0.083576,0.169542,0.359247,-0.652745,0.720491,-0.674187,0.692998,0.58614,-0.15964
1,0.344518,-0.417002,0.095745,0.355958,0.573043,-0.590275,1.069693,0.067722,0.788808,0.159196,...,-1.312417,-0.108733,0.217022,0.303689,-0.598965,0.647906,-0.665971,0.791802,0.62069,-0.107868
2,0.42919,-0.463548,0.056444,0.449927,0.536803,-0.749922,1.193833,0.082606,0.860289,0.162551,...,-1.304989,-0.148627,0.242042,0.344735,-0.70464,0.644775,-0.781006,0.737218,0.585378,-0.101726
3,0.43309,-0.523085,0.089723,0.410124,0.543406,-0.643019,1.203863,0.034183,0.769412,0.202444,...,-1.358917,-0.077457,0.228708,0.317883,-0.680223,0.531607,-0.709795,0.731389,0.567804,-0.087715
4,0.388439,-0.505907,0.072542,0.366502,0.533685,-0.701548,1.035558,0.038412,0.822911,0.163067,...,-1.271006,-0.176408,0.119732,0.294137,-0.677723,0.647654,-0.844426,0.756316,0.57051,-0.240002


In [9]:
# Encode the canonical test SMILES with smi-ted; no gradients are needed for inference.
test_smiles = df_test_normalized['norm_smiles']
with torch.no_grad():
    df_embeddings_test = model_smi_ted.encode(test_smiles)
df_embeddings_test.head()

100%|████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:18<00:00,  3.06s/it]


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,758,759,760,761,762,763,764,765,766,767
0,0.374249,-0.319258,-0.007041,0.44474,0.326733,-0.791473,1.121704,-0.082401,0.611456,0.289226,...,-1.462538,-0.302052,0.295552,-0.058292,-0.830317,0.545098,-0.460271,1.121116,0.685015,-0.452696
1,0.429165,-0.568106,0.11274,0.352434,0.512559,-0.604146,1.181835,0.067952,0.786974,0.128075,...,-1.226945,-0.078928,0.209471,0.266114,-0.762261,0.610678,-0.755716,0.734546,0.592978,-0.148244
2,0.411906,-0.510475,0.073013,0.346873,0.512772,-0.617251,1.191626,0.040101,0.722577,0.188637,...,-1.300556,-0.150738,0.148254,0.282793,-0.694715,0.556031,-0.660643,0.771227,0.559001,-0.000663
3,0.356794,-0.530958,0.050351,0.433593,0.592596,-0.573506,1.221863,0.025492,0.833165,0.214606,...,-1.406139,-0.107166,0.200126,0.289468,-0.770145,0.572746,-0.776744,0.855061,0.662799,-0.194416
4,0.422147,-0.490604,0.044331,0.367862,0.579012,-0.629399,1.139819,0.039814,0.728822,0.145328,...,-1.312775,-0.105048,0.175281,0.336174,-0.738811,0.530219,-0.763359,0.764997,0.583682,-0.109681


### Experiments - BBBP prediction using smi-ted latent spaces

#### XGBoost prediction using the whole latent space

In [10]:
from xgboost import XGBClassifier
from sklearn.metrics import roc_auc_score

In [11]:
# Gradient-boosted trees trained on every column of the smi-ted embedding.
XGB_PARAMS = {'n_estimators': 2000, 'learning_rate': 0.04, 'max_depth': 8}
xgb_predict = XGBClassifier(**XGB_PARAMS)
xgb_predict.fit(df_embeddings_train, df_train_normalized['p_np'])

In [12]:
# get XGBoost predictions
# get XGBoost predictions: keep column 1 of predict_proba (the second class, i.e. label 1 for 0/1 labels)
class_probs = xgb_predict.predict_proba(df_embeddings_test)
y_prob = class_probs[:, 1]

In [16]:
# ROC-AUC on the BBBP test split, scored with the predicted label-1 probabilities.
y_true_test = df_test_normalized["p_np"]
roc_auc = roc_auc_score(y_true=y_true_test, y_score=y_prob)
print(f"ROC-AUC Score: {roc_auc:.4f}")

ROC-AUC Score: 0.9135
