In [None]:
# --- Standard library ---
import json
import os  # needed by the sys.path setup below (was missing -> NameError on a fresh kernel)
import sys

# --- Third-party ---
import pandas as pd
import torch
from lightning import pytorch as pl
from chemprop import data, models, nn, featurizers
from lightning.pytorch.callbacks import ModelCheckpoint
from torch.utils.data import IterableDataset
import rdkit
from rdkit import Chem
from sklearn.preprocessing import StandardScaler

# --- Project-local ---
# Put ../lrp_chemprop on the module search path so the two helper modules
# below can be imported by their bare names. The path is resolved relative
# to the current working directory of the kernel.
sys.path.append(os.path.abspath('../lrp_chemprop/'))
from Data_Preprocessor import Data_Preprocessor
from IterableMolDatapoints import IterableMolDatapoints

First, we prepare the data using `IterableMolDatapoints`. See its source code and a usage walkthrough here:

https://github.com/DinhLongHuynh/lrp_chemprop/blob/main/lrp_chemprop/IterableMolDatapoints.py

https://medium.com/@dinhlong240600/large-dataset-on-8gb-ram-let-iterabledataset-handle-442bb4764c7a


In [None]:
# ---- Run configuration -------------------------------------------------
# Input file and the names of the columns read from it.
data_path = '../DRD2_diverse_data.csv'
smiles_column = 'smiles'
target_column = 'docking_score'
weight_column = 'weight_lowscores'
split_column = 'split'

# Training settings consumed by the dataloaders and the trainer further down.
epochs = 50
batch_size = 64

# ---- Load the table and carve out the splits ---------------------------
df = pd.read_csv(data_path)
df_train = df.loc[df[split_column] == 'train']
df_val = df.loc[df[split_column] == 'val']

# The target scaler is fitted on the training rows only; the same fitted
# object is handed to both datasets and to the model's output transform.
scaler = StandardScaler()
scaler.fit(df_train[[target_column]])


def _stream_and_load(frame, shuffle_rows):
    """Wrap one split in an IterableMolDatapoints stream and a chemprop dataloader.

    Args:
        frame: DataFrame holding the rows of a single split.
        shuffle_rows: value forwarded to the dataset's ``shuffle`` argument.

    Returns:
        A ``(dataset, dataloader)`` tuple.
    """
    stream = IterableMolDatapoints(
        df=frame,
        smiles_column=smiles_column,
        target_column=target_column,
        weight_column=weight_column,
        scaler=scaler,
        shuffle=shuffle_rows,
        size_at_time=640,
    )
    # The loader itself is always built with shuffle=False; shuffling, when
    # wanted, is requested from the dataset through its own `shuffle` flag.
    loader = data.build_dataloader(stream, batch_size=batch_size, shuffle=False)
    return stream, loader


# Training split: dataset-level shuffling enabled.
train_streaming_dataset, train_loader = _stream_and_load(df_train, True)

# Validation split: no shuffling.
val_streaming_dataset, val_loader = _stream_and_load(df_val, False)



  df = pd.read_csv(data_path)


Next, we define our model. The hyperparameters can be set by hand or loaded from a `.toml` or `.json` file.

In [None]:
# Build the model. Hyperparameters are hard-coded here; they can be edited
# by hand or filled in from a .toml / .json file.

# Message passing: d_v / d_e are the input feature sizes and d_h the hidden
# size (NOTE(review): d_v and d_e presumably have to match the featurizer
# used inside IterableMolDatapoints -- confirm).
mp = nn.BondMessagePassing(
    d_v=74,
    d_e=14,
    d_h=300,
    dropout=0.3,
    depth=5,
)

# Aggregation step with a fixed normalisation constant.
agg = nn.NormAggregation(norm=199)

# Built from the same scaler that was fitted on the training targets.
output_transform = nn.UnscaleTransform.from_standard_scaler(scaler)

# Feed-forward regression head; input_dim is the same value as d_h above.
ffn = nn.RegressionFFN(
    n_layers=2,
    dropout=0.3,
    input_dim=300,
    hidden_dim=2200,
    output_transform=output_transform,
)

# Metrics handed to the model.
metric_list = [
    nn.metrics.RMSE(),
    nn.metrics.MAE(),
    nn.metrics.R2Score(),
]

# Assemble the full network together with its learning-rate settings.
mpnn = models.MPNN(
    message_passing=mp,
    agg=agg,
    predictor=ffn,
    batch_norm=False,
    metrics=metric_list,
    warmup_epochs=1,
    init_lr=1.477783789959149e-06,
    max_lr=0.00012044152141486488,
    final_lr=0.00011724292252282861,
)

Finally, we train our model, saving the checkpoint with the lowest validation loss as well as the most recent one.

In [None]:
# Checkpoint callback: keeps the checkpoint with the lowest "val_loss" and,
# because save_last=True, also the most recent one.
checkpointing = ModelCheckpoint(
    dirpath="../hyperparam_optim_7/model_7/checkpoints",  # output directory
    filename="best-{epoch}-{val_loss:.2f}",  # epoch and validation loss go into the file name
    monitor="val_loss",  # quantity used to rank checkpoints
    mode="min",  # lower is better
    save_last=True,
)

# Single-device trainer with no logger; the number of epochs comes from the
# `epochs` setting defined with the data configuration.
trainer = pl.Trainer(
    accelerator="auto",
    devices=1,
    max_epochs=epochs,
    logger=False,
    enable_checkpointing=True,
    enable_progress_bar=True,
    callbacks=[checkpointing],
)

# Run training, validating on the validation loader.
trainer.fit(
    mpnn,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)
