### Import Data

In [1]:
import duckdb
import pandas as pd
import numpy as np

In [2]:
# Training set location, relative to the notebook's working directory.
train_path = './data/train.parquet'

# Read every sEH row of the training parquet into pandas, in random order.
# The connection is used as a context manager so it is closed even when the
# query raises (the previous explicit close() was skipped on failure).
# NOTE(review): no seed is set for random(), so presumably the row order
# differs from run to run -- confirm whether that matters downstream.
with duckdb.connect() as con:
    df = con.query(f"""
                    (SELECT *
                    FROM parquet_scan('{train_path}')
                    WHERE protein_name = 'sEH'
                    ORDER BY random()
                    )""").df()

# Uncomment to also write the sEH subset to a CSV file.
#df.to_csv("sEH.csv")

FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

In [2]:
# Display the loaded training frame (the rendered output is abbreviated to the
# first and last rows).
# NOTE(review): the rendered output lists an 'Unnamed: 0' column, which usually
# comes from a CSV written with its index -- confirm this output was produced by
# the parquet query above and not by an earlier kernel state.
df

Unnamed: 0.1,Unnamed: 0,id,buildingblock1_smiles,buildingblock2_smiles,buildingblock3_smiles,molecule_smiles,protein_name,binds
0,0,183534197,O=C(Nc1ccc(C(=O)O)c(C(F)(F)F)c1)OCC1c2ccccc2-c...,Cl.Cl.NCc1cccc(-n2ccnn2)c1,Cc1nccn1-c1ncccc1CN,Cc1nccn1-c1ncccc1CNc1nc(NCc2cccc(-n3ccnn3)c2)n...,sEH,0
1,1,100275704,O=C(N[C@@H](Cc1ccc(I)cc1)C(=O)O)OCC1c2ccccc2-c...,CC(C)c1nnc([C@H]2C[C@H](CN)[C@H](O)C2)[nH]1,CC1CCCC(CN)O1,CC1CCCC(CNc2nc(NC[C@H]3C[C@H](c4nnc(C(C)C)[nH]...,sEH,0
2,2,294041408,O=C(O)[C@H]1Cc2ccccc2CN1C(=O)OCC1c2ccccc2-c2cc...,Nc1ccc2nccnc2c1,CCC1=NN(Cc2ccccc2C)C(=O)C1CCN,CCC1=NN(Cc2ccccc2C)C(=O)C1CCNc1nc(Nc2ccc3nccnc...,sEH,0
3,3,164133368,O=C(Nc1cc(C(=O)O)ccc1Br)OCC1c2ccccc2-c2ccccc21,N#Cc1cccnc1N,Cc1c([C@@H]2[C@@H](CN)CC(=O)N2C)cnn1C,Cc1c([C@@H]2[C@@H](CNc3nc(Nc4cc(C(=O)N[Dy])ccc...,sEH,0
4,4,275717072,O=C(O)C[C@H]1CCCN1C(=O)OCC1c2ccccc2-c2ccccc21,CC1CC(CN)C(C)O1,Nc1nc(Cl)c(C=O)c(Cl)n1,CC1CC(CNc2nc(Nc3nc(Cl)c(C=O)c(Cl)n3)nc(N3CCC[C...,sEH,0
...,...,...,...,...,...,...,...,...
98415605,98415605,198682535,O=C(Nc1ccc(Cl)cc1C(=O)O)OCC1c2ccccc2-c2ccccc21,Cc1ccc(N)nn1,Nc1ccc2nccnc2c1,Cc1ccc(Nc2nc(Nc3ccc4nccnc4c3)nc(Nc3ccc(Cl)cc3C...,sEH,0
98415606,98415606,68752757,N#Cc1ccc([C@H](CC(=O)O)NC(=O)OCC2c3ccccc3-c3cc...,CCOC(=O)c1ncccc1N,COC(=O)c1cc(N)cs1,CCOC(=O)c1ncccc1Nc1nc(Nc2csc(C(=O)OC)c2)nc(N[C...,sEH,0
98415607,98415607,2495852,C#CC[C@@](C)(NC(=O)OCC1c2ccccc2-c2ccccc21)C(=O)O,CS(=O)CC(O)CN.Cl,CC1CC(CN)C(C)O1,C#CC[C@@](C)(Nc1nc(NCC(O)CS(C)=O)nc(NCC2CC(C)O...,sEH,0
98415608,98415608,50927273,Cc1cc(C(=O)O)ccc1NC(=O)OCC1c2ccccc2-c2ccccc21,NCc1cccc(C(F)(F)F)n1,COc1cc(Br)ccc1N,COc1cc(Br)ccc1Nc1nc(NCc2cccc(C(F)(F)F)n2)nc(Nc...,sEH,0


### Load Data to ChemProp

In [3]:
from lightning import pytorch as pl

from chemprop import data, featurizers, nn
from chemprop.models import multi
from chemprop.models.utils import save_model, load_model

In [4]:
def loadData(df_input, smiles_cols, w, seed=42, num_workers=4):
    """Build train/validation/test dataloaders for a multicomponent model.

    Each column named in ``smiles_cols`` becomes one component. Only the first
    component carries the ``binds`` target and the per-datapoint weight; the
    remaining components are created from their SMILES alone.

    Args:
        df_input: DataFrame containing the ``smiles_cols`` columns and a
            ``binds`` column.
        smiles_cols: Names of the SMILES columns, one per model component.
            Must contain at least one name.
        w: Weight given to datapoints with ``binds == 0``. Datapoints with
            ``binds == 1`` get 0.99 (the formula is ``(0.99 - w) * y + w``).
        seed: Seed forwarded to ``data.make_split_indices``. Defaults to 42,
            the value that was previously hard-coded.
        num_workers: Worker count forwarded to every ``data.build_dataloader``
            call. Defaults to 4, the value that was previously hard-coded.

    Returns:
        ``[train_loader, val_loader, test_loader]``.

    Raises:
        ValueError: If ``smiles_cols`` is empty.
    """
    n_components = len(smiles_cols)
    if n_components == 0:
        raise ValueError("smiles_cols must name at least one SMILES column")

    target_cols = ['binds']

    smiss = df_input.loc[:, smiles_cols].values
    # One-column 2-D array: every `y` below is a one-element row, so the weight
    # expression also evaluates to a one-element array.
    ys = df_input.loc[:, target_cols].values

    # Component 0 holds targets and weights; the other components are SMILES-only.
    all_data = [[data.MoleculeDatapoint.from_smi(smis[0], y, weight=(0.99-w)*y+w) for smis, y in zip(smiss, ys)]]
    all_data += [[data.MoleculeDatapoint.from_smi(smis[i]) for smis in smiss] for i in range(1, n_components)]

    # Split indices are computed on component 0 and applied to every component,
    # which keeps the components row-aligned within each split.
    train_idx, val_idx, test_idx = data.make_split_indices(all_data[0], seed=seed)
    train_data, val_data, test_data = data.split_data_by_indices(all_data, train_idx, val_idx, test_idx)

    featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()

    def _to_multicomponent(split):
        # Wrap each component of one split in a MoleculeDataset, then bundle them.
        return data.MulticomponentDataset(
            [data.MoleculeDataset(split[i], featurizer) for i in range(n_components)]
        )

    train_mcdset = _to_multicomponent(train_data)
    val_mcdset = _to_multicomponent(val_data)
    test_mcdset = _to_multicomponent(test_data)

    # The training loader keeps build_dataloader's default shuffle setting;
    # validation and test loaders are explicitly unshuffled.
    train_loader = data.build_dataloader(train_mcdset, num_workers=num_workers)
    val_loader = data.build_dataloader(val_mcdset, num_workers=num_workers, shuffle=False)
    test_loader = data.build_dataloader(test_mcdset, num_workers=num_workers, shuffle=False)

    return [train_loader, val_loader, test_loader]

In [5]:
def loadModel(smiles_cols):
    """Assemble an untrained multicomponent MPNN for binary classification.

    One bond-message-passing block (depth 4, hidden size 600) is created per
    entry of ``smiles_cols``. The blocks feed a norm aggregation and a 4-layer
    binary-classification FFN with dropout 0.2; batch norm is enabled.

    Args:
        smiles_cols: SMILES column names; only their count is used here.

    Returns:
        The constructed ``multi.MulticomponentMPNN``.
    """
    n_components = len(smiles_cols)

    encoder_blocks = [nn.BondMessagePassing(depth=4, d_h=600) for _ in range(n_components)]
    message_passing = nn.MulticomponentMessagePassing(
        blocks=encoder_blocks,
        n_components=n_components,
    )

    aggregation = nn.NormAggregation()
    predictor = nn.BinaryClassificationFFN(
        input_dim=message_passing.output_dim,
        n_layers=4,
        dropout=0.2,
    )

    # Only the first metric is used for training and early stopping
    tracked_metrics = [nn.metrics.BCEMetric(), nn.metrics.BinaryAUPRCMetric()]

    return multi.MulticomponentMPNN(
        message_passing,
        aggregation,
        predictor,
        batch_norm=True,
        metrics=tracked_metrics,
    )

In [6]:
def downSampler(df, batch=(500000, 100000), random_state=None):
    """Draw a fixed number of negative and positive rows from ``df``.

    Args:
        df: DataFrame with a ``binds`` column holding 0 (negative) or 1 (positive).
        batch: Two-item sequence ``(n_negative, n_positive)`` giving how many
            rows to sample from each class, without replacement.
        random_state: Forwarded to ``DataFrame.sample`` for both draws. The
            default ``None`` keeps the previous unseeded behaviour; pass an
            int to make the draw reproducible.

    Returns:
        A DataFrame with the sampled negatives followed by the sampled
        positives. Original index labels are kept.

    Raises:
        ValueError: Raised by pandas when a class has fewer rows than requested.
    """
    n_negative, n_positive = batch[0], batch[1]
    negatives = df.loc[df["binds"] == 0, :].sample(n=n_negative, random_state=random_state)
    positives = df.loc[df["binds"] == 1, :].sample(n=n_positive, random_state=random_state)
    return pd.concat([negatives, positives])

### Train the Model

In [7]:
# Class balance of the target column: number of rows per `binds` value.
df["binds"].value_counts()

binds
0    97691078
1      724532
Name: count, dtype: int64

In [8]:
import torch
# Select the 'medium' float32 matmul precision mode in PyTorch.
# NOTE(review): presumably chosen to trade some precision for speed on GPU --
# confirm against the torch.set_float32_matmul_precision documentation.
torch.set_float32_matmul_precision('medium')

# SMILES columns used as model inputs: the three building blocks plus the full
# molecule. loadModel creates one message-passing block per entry.
smiles_cols = ['buildingblock1_smiles', 'buildingblock2_smiles', 'buildingblock3_smiles', 'molecule_smiles']
# Alternative single-input setup that uses only the full molecule:
#smiles_cols = ['molecule_smiles']

In [None]:
# Rows to draw as [negatives, positives]. 724532 equals the binds == 1 count
# reported by df["binds"].value_counts(), so all positives are drawn, together
# with five times as many negatives.
draw=[724532*5, 724532]
# w = positives / negatives (0.2 here) is the weight loadData assigns to each
# negative datapoint.
loader = loadData(downSampler(df, batch=draw), smiles_cols, w=draw[1]/draw[0])
mcmpnn = loadModel(smiles_cols)

trainer = pl.Trainer(
    logger=False,
    enable_checkpointing=True,
    enable_progress_bar=True,
    accelerator="auto",
    devices=4,
    max_epochs=5,
)
# loader is [train, validation, test] as returned by loadData.
trainer.fit(mcmpnn, loader[0], loader[1])
results = trainer.test(mcmpnn, loader[2])

model = mcmpnn
# Write the checkpoint once, with chemprop's own helper. The file was previously
# written a second time to the same path with a hand-built
# {"hyper_parameters", "state_dict"} dict via torch.save, which overwrote this
# output; that duplicate write has been removed.
save_model("sEH_full_blk.pt", model)

### Test the Model

In [3]:
# Test set location, relative to the notebook's working directory.
test_path = './data/test.parquet'

# Read every sEH row of the test parquet into pandas, in file order.
# The connection is used as a context manager so it is closed even when the
# query raises (the previous explicit close() was skipped on failure).
with duckdb.connect() as con:
    df_test = con.query(f"""
                    (SELECT *
                    FROM parquet_scan('{test_path}')
                    WHERE protein_name = 'sEH'
                    )""").df()

In [4]:
def loadTest(df_input, smiles_cols, num_workers=4):
    """Build an unshuffled dataloader over unlabeled rows for prediction.

    Each column named in ``smiles_cols`` becomes one component; no targets or
    weights are attached to the datapoints.

    Args:
        df_input: DataFrame containing the ``smiles_cols`` columns.
        smiles_cols: Names of the SMILES columns, one per model component.
            Must contain at least one name.
        num_workers: Worker count forwarded to ``data.build_dataloader``.
            Defaults to 4, the value that was previously hard-coded.

    Returns:
        The dataloader built over the multicomponent dataset, with
        ``shuffle=False`` so predictions stay in row order.

    Raises:
        ValueError: If ``smiles_cols`` is empty.
    """
    n_components = len(smiles_cols)
    if n_components == 0:
        raise ValueError("smiles_cols must name at least one SMILES column")

    smiss = df_input.loc[:, smiles_cols].values

    # Component i holds the i-th SMILES of every row, so components stay row-aligned.
    test_data = [[data.MoleculeDatapoint.from_smi(smis[i]) for smis in smiss] for i in range(n_components)]

    featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()

    test_datasets = [data.MoleculeDataset(component, featurizer) for component in test_data]
    test_mcdset = data.MulticomponentDataset(test_datasets)
    test_loader = data.build_dataloader(test_mcdset, num_workers=num_workers, shuffle=False)

    return test_loader

In [9]:
# Reload the trained model from the checkpoint written by the training cell.
model = multi.MulticomponentMPNN.load_from_file("sEH_full_blk.pt")

# Run prediction on a single device with gradient tracking disabled.
with torch.inference_mode():
    trainer = pl.Trainer(
        logger=None,
        enable_progress_bar=True,
        accelerator="auto",
        devices=1
    )
    # loadTest builds an unshuffled loader, so predictions follow df_test's row order.
    test_preds = trainer.predict(model, loadTest(df_test, smiles_cols))
    
# trainer.predict returns one entry per batch; stack them along the row axis.
# NOTE(review): assumes the stacked result has exactly one value per df_test row
# (e.g. shape (n_rows, 1)) -- confirm before relying on the column assignment.
test_preds = np.concatenate(test_preds, axis=0)
# Adds the predictions to df_test in place as a new 'binds' column.
df_test['binds'] = test_preds
# Write only the id and the prediction to the output CSV.
df_test[['id', 'binds']].to_csv("pred_sEH.csv", index=False)
df_test

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Predicting: |          | 0/? [00:00<?, ?it/s]

Unnamed: 0,id,buildingblock1_smiles,buildingblock2_smiles,buildingblock3_smiles,molecule_smiles,protein_name,binds
0,295246832,C#CCCC[C@H](NC(=O)OCC1c2ccccc2-c2ccccc21)C(=O)O,C=Cc1ccc(N)cc1,C=Cc1ccc(N)cc1,C#CCCC[C@H](Nc1nc(Nc2ccc(C=C)cc2)nc(Nc2ccc(C=C...,sEH,4.422666e-10
1,295246835,C#CCCC[C@H](NC(=O)OCC1c2ccccc2-c2ccccc21)C(=O)O,C=Cc1ccc(N)cc1,CC(O)Cn1cnc2c(N)ncnc21,C#CCCC[C@H](Nc1nc(Nc2ccc(C=C)cc2)nc(Nc2ncnc3c2...,sEH,3.337481e-04
2,295246838,C#CCCC[C@H](NC(=O)OCC1c2ccccc2-c2ccccc21)C(=O)O,C=Cc1ccc(N)cc1,CC1(C)CCCC1(O)CN,C#CCCC[C@H](Nc1nc(NCC2(O)CCCC2(C)C)nc(Nc2ccc(C...,sEH,3.343873e-03
3,295246841,C#CCCC[C@H](NC(=O)OCC1c2ccccc2-c2ccccc21)C(=O)O,C=Cc1ccc(N)cc1,COC(=O)c1cc(Cl)sc1N,C#CCCC[C@H](Nc1nc(Nc2ccc(C=C)cc2)nc(Nc2sc(Cl)c...,sEH,8.783300e-03
4,295246844,C#CCCC[C@H](NC(=O)OCC1c2ccccc2-c2ccccc21)C(=O)O,C=Cc1ccc(N)cc1,CSC1CCC(CN)CC1,C#CCCC[C@H](Nc1nc(NCC2CCC(SC)CC2)nc(Nc2ccc(C=C...,sEH,8.777015e-04
...,...,...,...,...,...,...,...
558137,296921713,[N-]=[N+]=NCCC[C@H](NC(=O)OCC1c2ccccc2-c2ccccc...,Nc1nncs1,Cn1ncc2cc(N)ccc21,Cn1ncc2cc(Nc3nc(Nc4nncs4)nc(N[C@@H](CCCN=[N+]=...,sEH,5.136990e-05
558138,296921716,[N-]=[N+]=NCCC[C@H](NC(=O)OCC1c2ccccc2-c2ccccc...,Nc1nncs1,NCC1CCC2CC2C1,[N-]=[N+]=NCCC[C@H](Nc1nc(NCC2CCC3CC3C2)nc(Nc2...,sEH,2.350343e-06
558139,296921719,[N-]=[N+]=NCCC[C@H](NC(=O)OCC1c2ccccc2-c2ccccc...,Nc1noc2ccc(F)cc12,COC(=O)c1ccnc(N)c1,COC(=O)c1ccnc(Nc2nc(Nc3noc4ccc(F)cc34)nc(N[C@@...,sEH,4.358083e-07
558140,296921722,[N-]=[N+]=NCCC[C@H](NC(=O)OCC1c2ccccc2-c2ccccc...,Nc1noc2ccc(F)cc12,COC1CCC(CCN)CC1,COC1CCC(CCNc2nc(Nc3noc4ccc(F)cc34)nc(N[C@@H](C...,sEH,2.130598e-04
