# Chemprop v2 PySCF + Polaris Demo

This demo shows how to integrate Chemprop v2 with PySCF and Polaris to drive predictions using on-the-fly quantum chemistry calculations.
Check the associated files in this repo for examples of implementing your own featurizers and molecule generators!

Polaris makes it easy to access data - in this case we'll retrieve some Rat Plasma Protein Binding data and randomly select some of it to use for early stopping:

In [23]:
%%capture
import polaris as po

benchmark = po.load_benchmark("polaris/adme-fang-rppb-1")
train, test = benchmark.get_train_test_split()
train_df, test_df = train.as_dataframe(), test.as_dataframe()
val_df = train_df.sample(frac=0.2, random_state=42)
train_df = train_df[~train_df.index.isin(val_df.index)]
smiles_column = list(benchmark.input_cols)[0]
target_column = [list(benchmark.target_cols)[0]]

Now we want to convert these to `rdkit` molecules.
We can wrap the Chemprop function `make_mol` to handle on-the-fly execution of PySCF simulations, which we then access later on:

In [24]:
from make_hirshfeld_mol import make_hirshfeld_mol
from tqdm import tqdm

In [25]:
test_mols = []
for smi in tqdm(test_df[smiles_column]):
    test_mols.append(make_hirshfeld_mol(smi, use_gpu=True))
val_mols = []
for smi in tqdm(train_df[smiles_column]):
    val_mols.append(make_hirshfeld_mol(smi, use_gpu=True))
train_mols = []
for smi in tqdm(train_df[smiles_column]):
    train_mols.append(make_hirshfeld_mol(smi, use_gpu=True))

100%|██████████| 24/24 [00:00<00:00, 4519.93it/s]
100%|██████████| 89/89 [00:00<00:00, 4935.58it/s]
100%|██████████| 89/89 [00:00<00:00, 4617.45it/s]


In [26]:
import numpy as np

In [27]:
from lightning import pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping

from chemprop import data, featurizers, models, nn

Chemprop v2 makes it easy to add further features to your molecule.
In this case I show how to add Hirshfeld charges (and other properties) to the atoms both before and after message passing.
First, we'll add the atom features as `V_f` to each molecule so that they are added _after_ message passing:

In [28]:
train_data = [data.MoleculeDatapoint(
        mol=mol,
        y=y,
        V_f=np.array([
            [a.GetDoubleProp("charge_eff"),
             np.sqrt(a.GetDoubleProp("dipole_eff_x")**2 + a.GetDoubleProp("dipole_eff_y")**2 + a.GetDoubleProp("dipole_eff_z")**2), a.GetDoubleProp("V_eff"),
             a.GetDoubleProp("V_free"),
             a.GetDoubleProp("V_eff") / a.GetDoubleProp("V_free")]
             for a in mol.GetAtoms()]),
    )
    for mol, y in zip(train_mols, train_df[target_column].values)]
val_data = [data.MoleculeDatapoint(
        mol=mol,
        y=y,
        V_f=np.array([
            [a.GetDoubleProp("charge_eff"),
             np.sqrt(a.GetDoubleProp("dipole_eff_x")**2 + a.GetDoubleProp("dipole_eff_y")**2 + a.GetDoubleProp("dipole_eff_z")**2), a.GetDoubleProp("V_eff"),
             a.GetDoubleProp("V_free"),
             a.GetDoubleProp("V_eff") / a.GetDoubleProp("V_free")]
             for a in mol.GetAtoms()]),
    )
    for mol, y in zip(val_mols, val_df[target_column].values)]
test_data = [data.MoleculeDatapoint(
        mol=mol,
        V_f=np.array([
            [a.GetDoubleProp("charge_eff"),
             np.sqrt(a.GetDoubleProp("dipole_eff_x")**2 + a.GetDoubleProp("dipole_eff_y")**2 + a.GetDoubleProp("dipole_eff_z")**2), a.GetDoubleProp("V_eff"),
             a.GetDoubleProp("V_free"),
             a.GetDoubleProp("V_eff") / a.GetDoubleProp("V_free")]
             for a in mol.GetAtoms()]),
    ) for mol in test_mols]

And then we'll use our custom Atom Featurizer to include them _before_ messaging passing as well:

In [29]:
from hirshfeld_featurizer import HirshfeldAtomFeaturizer
from chemprop.featurizers.atom import MultiHotAtomFeaturizer
from chemprop.featurizers.bond import RIGRBondFeaturizer

from chemprop.nn.transforms import ScaleTransform, GraphTransform

In [30]:
atom_featurizer = HirshfeldAtomFeaturizer()  # can easily drop in MultiHotAtomFeaturizer.v2() here
bond_featurizer = RIGRBondFeaturizer()
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer(
    atom_featurizer=atom_featurizer,
    bond_featurizer=bond_featurizer,
    extra_atom_fdim=5,
    extra_bond_fdim=0,
)

train_dset = data.MoleculeDataset(train_data, featurizer)
val_dset = data.MoleculeDataset(val_data, featurizer)
test_dset = data.MoleculeDataset(test_data, featurizer)

target_scaler = train_dset.normalize_targets()
val_dset.normalize_targets(target_scaler)

output_transform = nn.UnscaleTransform.from_standard_scaler(target_scaler)

extra_atom_features_scaler = train_dset.normalize_inputs("V_f")
val_dset.normalize_inputs("V_f", extra_atom_features_scaler)

train_dset.cache = True
val_dset.cache = True

From here on out we can follow conventional Chemprop training!
Chemprop v2 uses `lightning`, so the handling of GPUs, parallelism, and early stopping are all fully automated and highly reliable.

In [31]:
train_loader = data.build_dataloader(train_dset, batch_size=32)
val_loader = data.build_dataloader(val_dset, shuffle=False)
test_loader = data.build_dataloader(test_dset, shuffle=False)

In [32]:
import torch

In [33]:
n_V_features = featurizer.atom_fdim - featurizer.extra_atom_fdim
n_E_features = featurizer.bond_fdim - featurizer.extra_bond_fdim

V_f_transform = nn.ScaleTransform.from_standard_scaler(extra_atom_features_scaler, pad=n_V_features)
graph_transform = nn.GraphTransform(V_f_transform, torch.nn.Identity())


In [34]:
mp = nn.BondMessagePassing(
    d_v=featurizer.atom_fdim,
    d_e=featurizer.bond_fdim,
    d_h=64,
    depth=3,
    activation="leakyrelu",
    dropout=0.50,
    graph_transform=graph_transform,
)

In [35]:
agg = nn.NormAggregation()

In [36]:
ffn = nn.RegressionFFN(input_dim=mp.output_dim, hidden_dim=mp.output_dim, activation="leakyrelu", output_transform=output_transform, n_layers=2, dropout=0.50)

In [37]:
metric_list = [nn.metrics.MSE(), nn.metrics.RMSE(), nn.metrics.MAE()] # Only the first metric is used for training and early stopping

In [38]:
mpnn = models.MPNN(mp, agg, ffn, True, metric_list)
mpnn

MPNN(
  (message_passing): BondMessagePassing(
    (W_i): Linear(in_features=15, out_features=64, bias=False)
    (W_h): Linear(in_features=64, out_features=64, bias=False)
    (W_o): Linear(in_features=77, out_features=64, bias=True)
    (dropout): Dropout(p=0.5, inplace=False)
    (tau): LeakyReLU(negative_slope=0.1)
    (V_d_transform): Identity()
    (graph_transform): GraphTransform(
      (V_transform): ScaleTransform()
      (E_transform): Identity()
    )
  )
  (agg): NormAggregation()
  (bn): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=64, out_features=64, bias=True)
      )
      (1): Sequential(
        (0): LeakyReLU(negative_slope=0.1)
        (1): Dropout(p=0.5, inplace=False)
        (2): Linear(in_features=64, out_features=64, bias=True)
      )
      (2): Sequential(
        (0): LeakyReLU(negative_slope=0.1)
        (1): Dropout(p=0.

In [39]:
# Configure model checkpointing
checkpointing = ModelCheckpoint(
    "checkpoints",  # Directory where model checkpoints will be saved
    "best-{epoch}-{val_loss:.2f}",  # Filename format for checkpoints, including epoch and validation loss
    "val_loss",  # Metric used to select the best checkpoint (based on validation loss)
    mode="min",  # Save the checkpoint with the lowest validation loss (minimization objective)
    save_last=True,  # Always save the most recent checkpoint, even if it's not the best
)
early_stopping = EarlyStopping(
    "val_loss",
    patience=10,
)
trainer = pl.Trainer(
    logger=False,
    enable_checkpointing=True, # Use `True` if you want to save model checkpoints. The checkpoints will be saved in the `checkpoints` folder.
    enable_progress_bar=True,
    accelerator="auto",
    devices=1,
    max_epochs=100, # number of epochs to train for
    callbacks=[checkpointing, early_stopping], # Use the configured checkpoint callback
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [40]:
trainer.fit(mpnn, train_loader, val_loader)

/home/jackson/miniforge3/envs/hirshfeld/lib/python3.12/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:751: Checkpoint directory /home/jackson/chemprop_polaris_pyscf_demo/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loading `train_dataloader` to estimate number of stepping batches.
/home/jackson/miniforge3/envs/hirshfeld/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.

  | Name            | Type               | Params | Mode 
---------------------------------------------------------------
0 | message_passing | BondMessagePassing | 10.0 K | train
1 | agg             | NormAggregation    | 0      | train
2 | bn              | BatchNorm1d        | 128    | train
3 | predictor       | RegressionFFN      

                                                                            

/home/jackson/miniforge3/envs/hirshfeld/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:433: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


Epoch 27: 100%|██████████| 3/3 [00:00<00:00, 94.57it/s, train_loss_step=0.848, val_loss=0.891, train_loss_epoch=0.814] 


In [41]:
mpnn = mpnn.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)

In [42]:
predictions = torch.cat(trainer.predict(mpnn, test_loader)).flatten()

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/jackson/miniforge3/envs/hirshfeld/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:433: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


Predicting DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 404.35it/s]


In [43]:
benchmark.evaluate(predictions).results

Unnamed: 0,Test set,Target label,Metric,Score
0,test,LOG_RPPB,spearmanr,0.673043
1,test,LOG_RPPB,pearsonr,0.643209
2,test,LOG_RPPB,r2,0.188435
3,test,LOG_RPPB,explained_var,0.188461
4,test,LOG_RPPB,mean_absolute_error,0.691382
5,test,LOG_RPPB,mean_squared_error,0.721061
