In [94]:
from functools import partial
import json
import sys

import numpy as np
import pandas as pd
from rdkit import RDLogger
from sklearn.neighbors import KNeighborsRegressor

sys.path.insert(0, '../agenticadmet')
from eval import extract_preds, extract_refs, eval_admet
from utils import ECFP_from_smiles, tanimoto_similarity

In [95]:
# Raise RDKit's logger threshold to CRITICAL so lower-severity RDKit messages
# (presumably SMILES parse/sanitization warnings emitted while featurizing
# below — TODO confirm) do not flood the cell outputs.
logger = RDLogger.logger()
logger.setLevel(RDLogger.CRITICAL)

In [96]:
# Run-level configuration.
RANDOM_SEED: int = 42  # NOTE(review): not referenced by any cell visible here
SPLIT: int = 0         # selects ../data/asap/datasets/rnd_splits/split_{SPLIT}.csv
TOPK: int = 8          # neighbours retrieved per molecule == width of the KNN feature vectors

In [97]:
# Raw assay columns and their log-scaled counterparts, listed in matching order.
# Judging by the data (e.g. MLM 63.0 -> LogMLM 1.806180) the Log* columns look
# like log10(value + 1); LogD is already log-scaled and so appears in both lists.
TARGET_COLUMNS: list = [
    "HLM",
    "MLM",
    "LogD",
    "KSOL",
    "MDR1-MDCKII",
]
PROPERTIES: list = [
    'LogHLM',
    'LogMLM',
    'LogD',
    'LogKSOL',
    'LogMDR1-MDCKII',
]
# The single column modelled by this notebook (a column of the split CSV).
PROPERTY: str = 'LogD'

In [98]:
# Load the chosen random split; its `split` column assigns each row to train/val.
data = pd.read_csv(
    '../data/asap/datasets/rnd_splits/'
    f'split_{SPLIT}.csv'
)
data

Unnamed: 0,smiles,HLM,KSOL,LogD,MLM,MDR1-MDCKII,smiles_std,cxsmiles_std,mol_idx,smiles_ext,LogHLM,LogMLM,LogKSOL,LogMDR1-MDCKII,split
0,COC1=CC=CC(Cl)=C1NC(=O)N1CCC[C@H](C(N)=O)C1 |a...,,,0.3,,2.0,COc1cccc(Cl)c1NC(=O)N1CCC[C@H](C(N)=O)C1,COc1cccc(Cl)c1NC(=O)N1CCC[C@H](C(N)=O)C1 |a:16|,191,|a:16|,,,,0.477121,val
1,O=C(NCC(F)F)[C@H](NC1=CC2=C(C=C1Br)CNC2)C1=CC(...,,333.0,2.9,,0.2,O=C(NCC(F)F)[C@H](Nc1cc2c(cc1Br)CNC2)c1cc(Cl)c...,O=C(NCC(F)F)[C@H](Nc1cc2c(cc1Br)CNC2)c1cc(Cl)c...,335,|&1:7|,,,2.523746,0.079181,train
2,O=C(NCC(F)F)[C@H](NC1=CC=C2CNCC2=C1)C1=CC(Br)=...,,,0.4,,0.5,O=C(NCC(F)F)[C@H](Nc1ccc2c(c1)CNC2)c1cc(Br)cc2...,O=C(NCC(F)F)[C@H](Nc1ccc2c(c1)CNC2)c1cc(Br)cc2...,336,|&1:7|,,,,0.176091,train
3,NC(=O)[C@H]1CCCN(C(=O)CC2=CC=CC3=C2C=CO3)C1 |&...,,376.0,1.0,,8.5,NC(=O)[C@H]1CCCN(C(=O)Cc2cccc3occc23)C1,NC(=O)[C@H]1CCCN(C(=O)Cc2cccc3occc23)C1 |&1:3|,300,|&1:3|,,,2.576341,0.977724,train
4,CC1=CC(CC(=O)N2CCC[C@H](C(N)=O)C2)=CC=N1 |&1:11|,,375.0,-0.3,,0.9,Cc1cc(CC(=O)N2CCC[C@H](C(N)=O)C2)ccn1,Cc1cc(CC(=O)N2CCC[C@H](C(N)=O)C2)ccn1 |&1:11|,249,|&1:11|,,,2.575188,0.278754,train
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
399,CC(C)NC[C@H](O)COC1=CC=CC2=CC=CC=C12 |&1:5|,25.5,,,63.0,,CC(C)NC[C@H](O)COc1cccc2ccccc12,CC(C)NC[C@H](O)COc1cccc2ccccc12 |&1:5|,22,|&1:5|,1.423246,1.806180,,,val
400,O=C(O)CC1=CC=CC=C1NC1=C(Cl)C=CC=C1Cl,216.0,,,386.0,,O=C(O)Cc1ccccc1Nc1c(Cl)cccc1Cl,O=C(O)Cc1ccccc1Nc1c(Cl)cccc1Cl,380,,2.336460,2.587711,,,val
401,NCC1=CC(Cl)=CC(C(=O)NC2=CC=C3CNCC3=C2)=C1,,,2.0,,,NCc1cc(Cl)cc(C(=O)Nc2ccc3c(c2)CNC3)c1,NCc1cc(Cl)cc(C(=O)Nc2ccc3c(c2)CNC3)c1,303,,,,,,train
402,COC(=O)NC1=NC2=CC=C(C(=O)C3=CC=CC=C3)C=C2N1,,,2.9,,,COC(=O)Nc1nc2ccc(C(=O)c3ccccc3)cc2[nH]1,COC(=O)Nc1nc2ccc(C(=O)c3ccccc3)cc2[nH]1,166,,,,,,train


In [99]:
# Train rows double as neighbours and KNN targets, so only those with a
# measured PROPERTY are kept; every val row is kept for evaluation.
train = data.loc[
    (data['split'] == 'train') & data[PROPERTY].notna()
].reset_index(drop=True)
val = data.loc[data['split'] == 'val'].reset_index(drop=True)

# Chirality-aware ECFP fingerprints for both splits, one row per molecule.
train_ecfp, val_ecfp = (
    np.array(frame['smiles_std'].apply(partial(ECFP_from_smiles, use_chirality=True)).tolist())
    for frame in (train, val)
)

# NOTE: despite the `_dist` suffix these hold Tanimoto *similarities*
# (higher = closer); the names are kept because later cells rely on them.
train2train_dist = tanimoto_similarity(train_ecfp, train_ecfp)
val2train_dist = tanimoto_similarity(val_ecfp, train_ecfp)
train2train_dist.shape, val2train_dist.shape

((257, 257), (81, 257))

In [100]:
def get_all_topk_smiles_with_properties(ref_data, query_data, query2ref_dist, topk=TOPK, property=PROPERTY):
    """Yield the `topk` most similar reference molecules for every query molecule.

    Args:
        ref_data: DataFrame of reference molecules with 'cxsmiles_std' and
            `property` columns.
        query_data: DataFrame of query molecules with the same columns.
        query2ref_dist: 2-D array of shape (len(query_data), len(ref_data))
            holding query-to-reference *similarities* (higher = closer).
        topk: maximum number of neighbours returned per query.
        property: column whose values are returned for neighbours and query.

    Yields:
        (topk_smiles, topk_properties, query_smiles, query_property) for each
        query row, in row order. Neighbours are sorted by descending
        similarity. References whose similarity is ~1.0 are skipped, so a
        molecule is never its own neighbour (exact fingerprint duplicates are
        dropped as well). Consequently fewer than `topk` neighbours can be
        returned.

    Raises:
        ValueError: if `query2ref_dist` does not have shape
            (len(query_data), len(ref_data)).
    """
    query2ref_dist = np.asarray(query2ref_dist)
    expected_shape = (len(query_data), len(ref_data))
    if query2ref_dist.shape != expected_shape:
        # A mismatched matrix would otherwise silently skip queries or fail
        # deep inside .iloc with an unhelpful IndexError.
        raise ValueError(
            f"query2ref_dist has shape {query2ref_dist.shape}, expected "
            f"{expected_shape} == (len(query_data), len(ref_data))"
        )

    # Column lookups are loop-invariant; slicing a single column per query is
    # cheaper than slicing the whole reference frame.
    ref_smiles = ref_data['cxsmiles_std']
    ref_values = ref_data[property]

    for i in range(query2ref_dist.shape[0]):
        query_row = query_data.iloc[i]

        dist = query2ref_dist[i]
        order = np.argsort(dist)[::-1]  # indices by descending similarity
        ordered_dist = dist[order]
        order = order[~np.isclose(ordered_dist, 1.0)]  # remove self-similarity including duplicates
        topk_idx = order[:topk]

        yield (
            ref_smiles.iloc[topk_idx].tolist(),
            ref_values.iloc[topk_idx].tolist(),
            query_row['cxsmiles_std'],
            query_row[property],
        )

In [101]:
def _stack_neighbour_properties(ref_data, query_data, query2ref_dist, topk=TOPK, property=PROPERTY):
    """Return a (len(query_data), topk) array of the neighbours' `property` values.

    Row i holds the `property` values of the `topk` reference molecules most
    similar to query i, ordered by descending similarity.

    Raises:
        ValueError: if some query has fewer than `topk` usable neighbours (after
            similarity-1.0 matches are dropped), which would make the feature
            matrix ragged.
    """
    rows = [
        topk_properties
        for _, topk_properties, _, _ in get_all_topk_smiles_with_properties(
            ref_data=ref_data,
            query_data=query_data,
            query2ref_dist=query2ref_dist,
            topk=topk,
            property=property,
        )
    ]
    if not rows:
        # np.stack refuses an empty sequence; an empty query set is still valid.
        return np.empty((0, topk))
    short = [i for i, row in enumerate(rows) if len(row) < topk]
    if short:
        raise ValueError(
            f"{len(short)} queries have fewer than {topk} neighbours after "
            f"dropping similarity-1.0 matches (first query index: {short[0]}); "
            "cannot build a fixed-width feature matrix"
        )
    return np.stack(rows)


# Features: each molecule is described by the PROPERTY values of its TOPK most
# similar *training* molecules (train queries never see themselves).
train_dataset = _stack_neighbour_properties(train, train, train2train_dist, topk=TOPK, property=PROPERTY)
val_dataset = _stack_neighbour_properties(train, val, val2train_dist, topk=TOPK, property=PROPERTY)
train_dataset.shape, val_dataset.shape

((257, 8), (81, 8))

In [102]:
# Inspect the validation split. Unlike `train`, val rows with a missing
# PROPERTY value are kept; presumably extract_refs/eval_admet skip them when
# scoring — TODO confirm.
val

Unnamed: 0,smiles,HLM,KSOL,LogD,MLM,MDR1-MDCKII,smiles_std,cxsmiles_std,mol_idx,smiles_ext,LogHLM,LogMLM,LogKSOL,LogMDR1-MDCKII,split
0,COC1=CC=CC(Cl)=C1NC(=O)N1CCC[C@H](C(N)=O)C1 |a...,,,0.30,,2.0,COc1cccc(Cl)c1NC(=O)N1CCC[C@H](C(N)=O)C1,COc1cccc(Cl)c1NC(=O)N1CCC[C@H](C(N)=O)C1 |a:16|,191,|a:16|,,,,0.477121,val
1,O=C(NCC(F)F)[C@H](NC1=CN=C2CNCC2=C1)C1=CC(Cl)=...,,362.0,1.50,,0.8,O=C(NCC(F)F)[C@H](Nc1cnc2c(c1)CNC2)c1cc(Cl)cc(...,O=C(NCC(F)F)[C@H](Nc1cnc2c(c1)CNC2)c1cc(Cl)cc(...,341,|&1:7|,,,2.559907,0.255273,val
2,CC(C)NC(=O)[C@H](NC1=CC=C2CNCC2=C1)C1=CC(Cl)=C...,,134.0,2.80,11.0,0.2,CC(C)NC(=O)[C@H](Nc1ccc2c(c1)CNC2)c1cc(Cl)cc2[...,CC(C)NC(=O)[C@H](Nc1ccc2c(c1)CNC2)c1cc(Cl)cc2[...,19,|o1:6|,,1.079181,2.130334,0.079181,val
3,O=C(NC1=CC=C2CNCC2=C1)C1=CC(F)=CC2=C1N=C(C1=CC...,,6.0,2.90,36.8,0.1,O=C(Nc1ccc2c(c1)CNC2)c1cc(F)cc2[nH]c(-c3ccc(F)...,O=C(Nc1ccc2c(c1)CNC2)c1cc(F)cc2[nH]c(-c3ccc(F)...,369,,,1.577492,0.845098,0.041393,val
4,O=C(NC1=CC=C2CNCC2=C1)C1=CC(Cl)=CC2=C1C=NN2C1CCC1,,172.0,2.00,13.4,1.0,O=C(Nc1ccc2c(c1)CNC2)c1cc(Cl)cc2c1cnn2C1CCC1,O=C(Nc1ccc2c(c1)CNC2)c1cc(Cl)cc2c1cnn2C1CCC1,365,,,1.158362,2.238046,0.301030,val
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
76,COC1=CC=C([C@H](CC(=O)O)NC(=O)C2=NC=NC3=C2C=CN...,7.0,,-0.38,2.0,1.7,COc1ccc([C@H](CC(=O)O)NC(=O)c2ncnc3[nH]ccc23)cc1,COc1ccc([C@H](CC(=O)O)NC(=O)c2ncnc3[nH]ccc23)c...,184,|&1:6|,0.903090,0.477121,,0.431364,val
77,CNC(=O)CN1C[C@]2(CCN(C3=CN=CC4=CC=C(OC[C@H](O)...,8.0,383.0,0.19,2.0,1.4,CNC(=O)CN1C[C@]2(CCN(c3cncc4ccc(OC[C@H](O)CN(C...,CNC(=O)CN1C[C@]2(CCN(c3cncc4ccc(OC[C@H](O)CN(C...,132,"|&1:7,&2:21|",0.954243,0.477121,2.584331,0.380211,val
78,C=CC(=O)NC1=CC=CC(N(CC2=CC=CC(Cl)=C2)C(=O)CC2=...,1070.0,24.7,3.80,2380.0,8.0,C=CC(=O)Nc1cccc(N(Cc2cccc(Cl)c2)C(=O)Cc2cncc3c...,C=CC(=O)Nc1cccc(N(Cc2cccc(Cl)c2)C(=O)Cc2cncc3c...,6,,3.029789,3.376759,1.409933,0.954243,val
79,CC(C)NC[C@H](O)COC1=CC=CC2=CC=CC=C12 |&1:5|,25.5,,,63.0,,CC(C)NC[C@H](O)COc1cccc2ccccc12,CC(C)NC[C@H](O)COc1cccc2ccccc12 |&1:5|,22,|&1:5|,1.423246,1.806180,,,val


In [103]:
# Second stage: a distance-weighted KNN regressor maps each molecule's vector of
# neighbour PROPERTY values (built above) to the molecule's own PROPERTY value.
# NOTE(review): n_neighbors reuses TOPK (the feature width); the two are
# independent hyper-parameters and could be tuned separately.
assert train_dataset.shape[0] == len(train), "feature rows must align with train rows"
knn = KNeighborsRegressor(n_neighbors=TOPK, weights='distance', metric='euclidean')
# Fit on PROPERTY rather than a literal column name, so the target always
# matches the property the features and the evaluation are built for.
knn.fit(train_dataset, train[PROPERTY])
predictions = knn.predict(val_dataset)
val[f'pred_{PROPERTY}'] = predictions
val_preds = extract_preds(val, target_columns=[PROPERTY])
val_refs = extract_refs(val, target_columns=[PROPERTY])
metrics = eval_admet(val_preds, val_refs, target_columns=[PROPERTY])
print(json.dumps(metrics, indent=2))

{
  "LogD": {
    "mean_absolute_error": 0.6224259958062731,
    "r2": 0.5486611529443856
  },
  "aggregated": {
    "macro_mean_absolute_error": 0.6224259958062731,
    "macro_r2": 0.5486611529443856
  }
}
