# Polymer prediction with Chemprop

Sam A.J. Hillman Feb 2025.

# 1) Install Chemprop from GitHub

In [None]:
# Install chemprop from GitHub if running in Google Colab
import os

if os.getenv("COLAB_RELEASE_TAG"):
    try:
        import chemprop
    except ImportError:
        !git clone https://github.com/chemprop/chemprop.git
        %cd chemprop
        !pip install .

#Import packages
from pathlib import Path

import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint
import pandas as pd
import numpy as np

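from rdkit import Chem                  #rdkit is used to convert SMILES to mols
from rdkit.Chem import AllChem          #makes Chem.AllChem available for the partial-charge calculation later on
from rdkit.Chem import Draw             #(optional extra: PandasTools enables using RDKit molecules in columns of a Pandas dataframe)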
from rdkit.Chem.Draw import SimilarityMaps  #for drawing the partial charges
from chemprop import data, featurizers, models, nn    #chemprop is our GNN package

import matplotlib.pyplot as plt



# 2) Load, explore and process data

> Pair the acids and bromides together to make 9 * 682 = 6138 unique polymers.

> When the monomers react together, bonding occurs at the Br and B(OH)(OH) groups, so these groups are lost. One of the benefits of the mol representation is that it is easy to replicate this chemistry virtually with RDKit (this is more robust than editing SMILES strings).

> In this case, the reaction has been done for us: all polymers are in "polymer_dataset_alternating.csv".

> All molecules can be written with SMILES (Simplified Molecular Input Line Entry System).
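
The pairing count above can be checked with a short sketch. It assumes 9 monomers of one type and 682 of the other (the text does not say which is which), and the labels are placeholders rather than real SMILES strings.

```python
from itertools import product

# Placeholder monomer labels (hypothetical); the real dataset stores full SMILES strings.
monomers_a = [f"a_{i}" for i in range(9)]
monomers_b = [f"b_{j}" for j in range(682)]

# Pair every monomer of the first type with every monomer of the second type exactly once.
pairs = list(product(monomers_a, monomers_b))

print(len(pairs))  # number of unique pairs
print(pairs[0])    # first pair of labels
```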



In [None]:
#Get the polymer SMILES from GitHub.
csv_url = "https://raw.githubusercontent.com/S-AJ-H/Chemprop-Tutorial/main/polymer_dataset_alternating.csv"  #get data
df = pd.read_csv(csv_url) # Load into a DataFrame
display(df)

## (i) Isolated example: Drawing the pairs of monomers (the "polymers")

> This example is for demonstration purposes and isn't used in the model.

> We construct 'molecule' objects from the SMILES using RDKit. RDKit uses a collection of rules to calculate a complete set of molecule-defining chemical information from the SMILES. These molecule objects encode the atomic structure, bonds, spatial arrangement etc. of a molecule.

> We can still represent pairs of monomers as single 'molecule' objects. This approach has some problems (e.g. there is no explicit info on where or how the bonding takes place) but is OK for now.

> RDKit: https://www.rdkit.org/docs/index.html



In [None]:
# Let's look at a few selected mol objects representing pairs of monomers.

#Convert the selected SMILES to mol objects using MolFromSmiles (rows that are not selected are left empty in the new column):
random_mols_index = [0, 1000, 2000, 5000, 6137]
df['poly_MOL'] = df['poly_SMI'].iloc[random_mols_index].apply(Chem.MolFromSmiles)   #Chem.MolFromSmiles is the RDKit function

#Draw the pairs:
img = Draw.MolsToGridImage(list(df.poly_MOL.iloc[random_mols_index]), molsPerRow=5)
display(img)


RDKit gives us chemistry! As an example, let's calculate the partial charges for the first pair of monomers. The resulting map shows where electrons are localised in the molecules (blue = higher electron density, brown = lower).

In [None]:
#RDKit allows us to do loads of chemistry. Let's look at the charge localisation of the first pair of monomers

#calculate:
mol = df.poly_MOL.iloc[0]   #access first pair of monomers
Chem.AllChem.ComputeGasteigerCharges(mol) #calculate the partial charges
contribs = [mol.GetAtomWithIdx(i).GetDoubleProp('_GasteigerCharge') for i in range(mol.GetNumAtoms())]  #store in contribs with the atom index

#now draw:
d2d = Draw.MolDraw2DCairo(400, 400)
drawing = Draw.SimilarityMaps.GetSimilarityMapFromWeights(mol, contribs, draw2d=d2d, colorMap='jet', contourLines=10)
drawing.FinishDrawing()

import io
from PIL import Image
def show_png(data):
    bio = io.BytesIO(data)
    img = Image.open(bio)
    return img

show_png(drawing.GetDrawingText())  #Note how the bottom OH has a dotted line around it which is influenced by the bottom molecule

## 2a) Make molecule datapoints

> Molecule datapoints link each RDKit-generated molecule object to its target(s) y.

> They can also carry extra information which Chemprop can use later (for example, manually added extra properties).


In [None]:
#Get SMILES and targets
smiles = df.loc[:, 'poly_SMI'].values
targets = df.loc[:,['EA']].values

display(smiles[:2]) # show first 2 SMILES strings
display(targets[:2]) # show first 2 targets

#Use the SMILES to generate mol objects, pair the mol objects with the targets y
all_data = [data.MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(smiles, targets)]
display(all_data[:2])

#We can still access the molecules and extract chemical info if we like:
#all_data[0].mol.GetNumAtoms()   #Can use RDKit things on the mol, but not on the MoleculeDatapoint

## (ii) Isolated example: Featurization with RDKit

>To start the training, we need "MoleculeDatasets". Each one comprises a list of MoleculeDatapoints together with the choice of featurizer. We need one MoleculeDataset for each of the training, validation and testing splits.

> Here we use the built-in featurizer "SimpleMoleculeMolGraphFeaturizer", which is part of Chemprop and reads atom and bond properties from RDKit. It outputs a "MolGraph", which is the graph featurisation of the molecule, i.e. atom and bond features. These features are the starting point for the upcoming message passing.

>SimpleMoleculeMolGraphFeaturizer uses a multi-hot encoding to featurize individual atoms and bonds. In the example below for carbon monoxide (CO), you can see arrays for V (atom features) and E (edge/bond features), along with a mapping between atoms and bonds (edge_index and rev_edge_index).

>Atom and bond properties are generated by RDKit and cast to one-hot vectors (the mass is appended as a scaled number). Features include e.g. the mass, charge, number of bonded hydrogen atoms (for atoms); bond type, conjugation, whether it's in an aromatic ring (for bonds). These feature vectors are joined together into a single multi-hot feature vector.

>https://chemprop.readthedocs.io/en/latest/tutorial/python/featurizers/molgraph_molecule_featurizer.html
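
To make the idea of a multi-hot vector concrete, here is a minimal sketch in plain Python. The vocabularies, the extra "unknown" slot and the helper names `one_hot` and `featurize_atom` are assumptions for illustration and do not reproduce Chemprop's actual feature layout. The factor of 0.01 on the mass is chosen to match the normalised masses (0.12011 and 0.15999) seen in the carbon monoxide output.

```python
# Toy vocabularies (hypothetical; much smaller than a real featurizer's).
ELEMENTS = ["C", "N", "O"]
DEGREES = [0, 1, 2, 3, 4]


def one_hot(value, choices):
    """Return a one-hot list with one extra trailing slot for unknown values."""
    vec = [0] * (len(choices) + 1)
    idx = choices.index(value) if value in choices else len(choices)
    vec[idx] = 1
    return vec


def featurize_atom(element, degree, mass):
    """Concatenate the one-hot blocks and append the scaled atomic mass."""
    return one_hot(element, ELEMENTS) + one_hot(degree, DEGREES) + [0.01 * mass]


carbon = featurize_atom("C", 1, 12.011)
oxygen = featurize_atom("O", 1, 15.999)
print(carbon)
print(oxygen)
```

Each one-hot block contributes exactly one 1, so the concatenated vector has several 1s in total, which is why it is called multi-hot.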

In [None]:
#lets look at the featurizer:
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer() #chemprop module which outputs a MolGraph

#start with a simple two-atom example:
carbon_monoxide = Chem.MolFromSmiles("[C-]#[O+]")
display(Draw.MolToImage(carbon_monoxide))

display("Features of carbon monoxide:", featurizer(carbon_monoxide))  #see e.g. 0.12011 and 0.15999 at the end of the arrays, which are the (normalised) masses

#now let's have a look at the features of the first pair of monomers in the training dataset (run this after the split in 2b, once train_data exists):
#display("Features of the first polymer:", featurizer(train_data[0][0].mol))

## 2b) Perform data splitting for training, validation, and testing

Chemprop's `make_split_indices` function always returns a tuple of length two (if there is no validation set) or three (if there is one, as in this example). Each element of the tuple is a list of lists, and the inner lists contain the actual indices for splitting.

The type signature for this return type is `tuple[list[list[int]], ...]`.
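
The nested shape is easier to see with a toy stand-in. The sketch below is not Chemprop's implementation: `toy_random_split` is a hypothetical helper that shuffles integer indices and wraps each split in an outer list so that it mimics the return type described above.

```python
import random


def toy_random_split(n_items, sizes=(0.8, 0.1, 0.1), seed=0):
    """Shuffle indices 0..n_items-1 and cut them into train/val/test lists.

    Each list is wrapped in an outer list to mimic the nested return shape
    (a hypothetical stand-in, not Chemprop's implementation).
    """
    indices = list(range(n_items))
    random.Random(seed).shuffle(indices)
    n_train = int(sizes[0] * n_items)
    n_val = int(sizes[1] * n_items)
    train = indices[:n_train]
    val = indices[n_train:n_train + n_val]
    test = indices[n_train + n_val:]  # the remainder goes to the test split
    return [train], [val], [test]


train_idx, val_idx, test_idx = toy_random_split(6138)
print(len(train_idx[0]), len(val_idx[0]), len(test_idx[0]))
```

Because of the outer list, the indices of a split are reached with `[0]`, which is the same pattern used for `test_indices[0]` and `train_data[0]` in this notebook.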

In [None]:
# available split types - Kennard Stone, KMeans and scaffold balanced are all structure-based splits (similar structures go to the same split), based on https://jacksonburns.github.io/astartes/
#list(data.SplitType.keys())

mols = [d.mol for d in all_data]  # RDKit Mol objects are used for structure-based splits
train_indices, val_indices, test_indices = data.make_split_indices(mols, "random", (0.8, 0.1, 0.1))  # unpack the tuple into three separate lists. data is a Chemprop module

#display(test_indices[0]) #test indices
display("test_indices length =", len(test_indices[0]))     #roughly 6138/10. (Note that the list of test indices is nested)

train_data, val_data, test_data = data.split_data_by_indices(
    all_data, train_indices, val_indices, test_indices)     #Use the 3 lists of indices to split the data.

#display(len(test_data[0]))
#display(test_data[0][0])  #same format as in 2a

## 2c) Make the three MoleculeDatasets

In [None]:
train_dset = data.MoleculeDataset(train_data[0], featurizer)  #combine the list of MoleculeDatapoints with the featurizer to make the MoleculeDataset. (We use train_data[0] because of the nesting noted in 2b)
#The MoleculeDataset "train_dset" holds the list of MoleculeDatapoints together with the featurizer:
#display(len(train_dset))
#display(train_dset.data[:1])      #train_dset.data is the same as train_data, without nesting

#each MoleculeDatapoint now has MolGraph features which are accessed through indexing. Compare to Section 2a - it's the same but we have MolGraph features instead of mol objects.
#display(train_dset[0])

scaler = train_dset.normalize_targets() #define the normalisation using StandardScaler (subtract mean, scale to unit variance)


#Do the same for validation and test
val_dset = data.MoleculeDataset(val_data[0], featurizer)
val_dset.normalize_targets(scaler)            #normalise the validation dataset in the same way as the training

test_dset = data.MoleculeDataset(test_data[0], featurizer)      #no normalisation
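
The scaling logic used above can be written out by hand. This is a minimal sketch with made-up numbers, assuming the scaler is a standard scaler that subtracts the mean and divides by the population standard deviation. The helper names `scale` and `unscale` are hypothetical.

```python
from statistics import fmean, pstdev

# Made-up target values standing in for the EA column.
train_targets = [1.0, 2.0, 3.0, 4.0]
val_targets = [2.5, 5.0]

# Fit the scaling statistics on the training targets only.
mean = fmean(train_targets)
std = pstdev(train_targets)  # population standard deviation


def scale(values):
    """Map raw targets to zero mean and unit variance (training statistics)."""
    return [(v - mean) / std for v in values]


def unscale(values):
    """Map scaled values back to the original units."""
    return [v * std + mean for v in values]


scaled_train = scale(train_targets)
scaled_val = scale(val_targets)   # reuses the training mean and std
recovered = unscale(scaled_val)   # back in the original units
print(scaled_train, scaled_val, recovered)
```

Fitting the statistics on the training split alone keeps information about the validation targets out of the training procedure, which is why the same `scaler` is passed to the validation dataset rather than fitting a new one.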


# 3) Message-Passing Neural Network input parameters

>Now our data is ready, it's time to use Chemprop. There are 3 main steps: message passing, aggregation and the feed-forward NN.

> For more info on a step: https://chemprop.readthedocs.io/en/latest/tutorial/python/index.html

## 3a) Message Passing and Aggregation

`Message passing`: Constructs hidden node-level representations. Pass messages from bond to bond or atom to atom. Bond to bond allows for directed messages and is generally preferred. Options are `mp = nn.BondMessagePassing()` or `mp = nn.AtomMessagePassing()`.

`Aggregation`: The aggregation layer combines the node level representations into a graph level representation (usually atoms -> molecule). Options include - `agg = nn.MeanAggregation()`, `agg = nn.SumAggregation()`, `agg = nn.NormAggregation()`
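
The two steps can be illustrated with a toy example in plain Python. This is a deliberately simplified sketch: it passes messages between atoms on an undirected three-atom chain and has no learned weights or activation functions, so it is not Chemprop's directed bond-based scheme. The "norm" mode is assumed to mean a sum divided by a fixed constant, and the value 100 is an assumption.

```python
# Toy graph: a three-atom chain 0 - 1 - 2, with a 2-dimensional hidden vector per atom.
neighbours = {0: [1], 1: [0, 2], 2: [1]}
hidden = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [1.0, 1.0]}


def message_passing_step(hidden, neighbours):
    """Add the summed neighbour vectors to each atom's own vector."""
    updated = {}
    for atom, own in hidden.items():
        message = [sum(hidden[n][k] for n in neighbours[atom]) for k in range(len(own))]
        updated[atom] = [h + m for h, m in zip(own, message)]
    return updated


def aggregate(hidden, mode="mean", norm=100.0):
    """Collapse the per-atom vectors into one graph-level vector."""
    vectors = list(hidden.values())
    total = [sum(column) for column in zip(*vectors)]
    if mode == "sum":
        return total
    if mode == "mean":
        return [t / len(vectors) for t in total]
    if mode == "norm":  # sum divided by a fixed constant
        return [t / norm for t in total]
    raise ValueError(f"unknown mode: {mode}")


for _ in range(3):  # three rounds of message passing
    hidden = message_passing_step(hidden, neighbours)

print(aggregate(hidden, "sum"))
print(aggregate(hidden, "mean"))
```

After each round, an atom's vector contains information from atoms one bond further away. The aggregation then produces a vector whose length does not depend on the number of atoms, which is what the feed-forward network needs as input.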

In [None]:
#Message passing: construct node-level representations of the atoms
#Define message passing type. Can pass different activations, dropout, etc etc.
#Defaults: 300 hidden dimensions. 3 message passing iterations. ReLU.
mp = nn.BondMessagePassing()    #https://chemprop.readthedocs.io/en/latest/autoapi/chemprop/nn/index.html#chemprop.nn.BondMessagePassing

#Aggregation: node-level --> graph-level representation
agg = nn.MeanAggregation()    #average together all of the hidden node/edge representations to get a graph-level representation.
#Note this is the only place where the two monomers in each pair interact!
batch_norm = True #normalizes the outputs of the aggregation by re-centering and re-scaling. Helps keep the inputs to the FFN small and centered around zero.

## 3b) Feed-Forward Network (FFN)

An `FFN` takes the aggregated representations and makes target predictions.

Regression options include:
- `ffn = nn.RegressionFFN()`
- `ffn = nn.MveFFN()`
- `ffn = nn.EvidentialFFN()`

In [None]:
output_transform = nn.UnscaleTransform.from_standard_scaler(scaler) #unscale the data

#define the feed-forward network:
ffn = nn.RegressionFFN(output_transform=output_transform, n_tasks = 1) #n_tasks sets the number of targets. Can change number of layers etc etc.
#set the metrics:
metric_list = [nn.metrics.RMSE(), nn.metrics.MAE(), nn.metrics.R2Score()] # Only the first metric is used for training and early stopping
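
For reference, the three metrics named in `metric_list` can be written out in plain Python. This is a sketch of the standard textbook definitions with made-up numbers; it does not use Chemprop's metric classes, which may differ in details such as masking or weighting.

```python
from math import sqrt


def rmse(y_true, y_pred):
    """Root-mean-square error."""
    return sqrt(sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true))


def mae(y_true, y_pred):
    """Mean absolute error."""
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


def r2_score(y_true, y_pred):
    """Coefficient of determination: 1 - (residual sum of squares / total sum of squares)."""
    mean_true = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot


# Made-up values, for illustration only.
y_true = [1.0, 2.0, 3.0, 4.0]
y_pred = [1.0, 2.0, 3.0, 6.0]

print(rmse(y_true, y_pred), mae(y_true, y_pred), r2_score(y_true, y_pred))
```

RMSE penalises the single large error more heavily than MAE does, so comparing the two gives a rough sense of whether the error is dominated by a few outliers.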

## 3c) Construct the MPNN

In [None]:
mpnn = models.MPNN(mp, agg, ffn, batch_norm, metric_list) #can change learning rate, optimiser etc
mpnn
#Entire model, consisting of message passing, aggregation and FFN, is end-to-end trained.

#In the message passing NN:
#w_i = input weights, applied to the bond feature vectors. Length = sum of bond and atom features. Output hidden dimension is 300 by default.
#w_h = hidden weights, applied to the messages.
#w_o = output weights.

#In the FFN:
#Input dimension is the same as the MPNN hidden dimension (300)

# 4) Set up and start the trainer

In [None]:
#Get DataLoader

num_workers = 0 # number of subprocesses used to load data. 0 means the data is loaded in the main process.
#Each of these subprocesses retrieves a batch of data from your dataset and sends it to the main training process.

train_loader = data.build_dataloader(train_dset, num_workers=num_workers)
val_loader = data.build_dataloader(val_dset, num_workers=num_workers, shuffle=False)
test_loader = data.build_dataloader(test_dset, num_workers=num_workers, shuffle=False)
#---------------------------------------------------------------------------------------------------------------------------------
# Configure model checkpointing
checkpointing = ModelCheckpoint(
    "checkpoints",  # Directory where model checkpoints will be saved
    "best-{epoch}-{val_loss:.2f}",  # Filename format for checkpoints, including epoch and validation loss
    "val_loss",  # Metric used to select the best checkpoint (based on validation loss)
    mode="min",  # Save the checkpoint with the lowest validation loss (minimization objective)
    save_last=True,  # Always save the most recent checkpoint, even if it's not the best
)
#---------------------------------------------------------------------------------------------------------------------------------

trainer = pl.Trainer(
    logger=False,
    enable_checkpointing=True, # Use `True` if you want to save model checkpoints. The checkpoints will be saved in the `checkpoints` folder.
    enable_progress_bar=True,
    accelerator="auto",
    devices=1,
    max_epochs=10, # number of epochs to train for
    callbacks=[checkpointing], # Use the configured checkpoint callback
)


#start training
trainer.fit(mpnn, train_loader, val_loader)

# 5) Test results

In [None]:
results = trainer.test(dataloaders=test_loader)

In [None]:
full_dset = data.MoleculeDataset(all_data, featurizer=featurizer)
full_loader = data.build_dataloader(full_dset, shuffle=False)

predictions = trainer.predict(mpnn, full_loader)

#append the predicted values to the original array
predictions_array = np.concatenate(predictions, axis=0)
df[['pred_EA']] = predictions_array
df

Plot prediction vs actual

In [None]:
plt.figure(figsize=(10, 6))

# Plot pred vs true
plt.scatter(df['EA'], df['pred_EA'], label='Prediction', marker='x', color='red')

# Plot y=x
plt.plot(df['EA'], df['EA'], label='y=x', linestyle='-')

plt.xlabel('True EA')
plt.ylabel('Predicted EA')
plt.title('Predicted vs true EA')
plt.legend()
plt.grid(True)
plt.show()


# 6) Bonus: Transfer learning

Transfer learning (or pretraining) leverages knowledge from a pre-trained model on a related task to enhance performance on a new task. In Chemprop, we can use pre-trained model checkpoints to initialize a new model and freeze components of the new model during training.

The originally imported data contains a column "IP". By repeating the steps above, can you train a model to a similar level of accuracy to the one above using a smaller training dataset?

In [None]:
df

## 6a) Make the MoleculeDatasets by following the steps in Section 2.

> The previous model saved a checkpoint file with all of the weights etc. from the trained MPNN. We need to load it between 2b and 2c below. We also need to ensure that the scaling is the same as that used in the first model.
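
Rather than copying the checkpoint name by hand, one option is to look it up programmatically. The sketch below assumes the files follow the pattern `best-epoch=N-val_loss=X.ckpt` seen in this tutorial. `find_best_checkpoint` is a hypothetical helper, not part of Chemprop or Lightning.

```python
import re
from pathlib import Path

# Matches names such as "best-epoch=7-val_loss=0.03.ckpt".
CKPT_PATTERN = re.compile(r"best-epoch=(\d+)-val_loss=(-?\d+(?:\.\d+)?)\.ckpt$")


def find_best_checkpoint(checkpoint_dir):
    """Return the matching checkpoint with the lowest val_loss, or None."""
    best_path, best_loss = None, float("inf")
    for path in sorted(Path(checkpoint_dir).glob("*.ckpt")):
        match = CKPT_PATTERN.match(path.name)
        if match is None:
            continue  # skips files that do not follow the pattern, e.g. "last.ckpt"
        loss = float(match.group(2))
        if loss < best_loss:
            best_path, best_loss = path, loss
    return best_path
```

A call such as `find_best_checkpoint("checkpoints")` would then return the path to pass to `load_from_file`. The loss in the file name is rounded to two decimal places, so ties are resolved by file name order. If the `checkpointing` callback from Section 4 is still in memory, its `best_model_path` attribute should point to the same file.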

In [None]:
#2a as before

#2b as before

#new bit: the previous model saved a checkpoint file with all of the weights etc. from the trained MPNN. We can load it here:
checkpoint_path = "/content/chemprop/checkpoints/best-epoch=7-val_loss=0.03.ckpt" #replace with your checkpoint name (see the files on the left hand side)
mpnn_cls = models.MPNN
mpnn = mpnn_cls.load_from_file(checkpoint_path)
mpnn

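#new bit: scaling. Need to use the same scaling as was used for the pre-trained model.
pretraining_scaler = scaler   #note: this is the same object as the earlier `scaler`, so the two lines below overwrite its mean and scale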
pretraining_scaler.mean_ = mpnn.predictor.output_transform.mean.numpy()
pretraining_scaler.scale_ = mpnn.predictor.output_transform.scale.numpy()

#2c as before

## 6b) Freezing MPNN and FFN layers
Certain layers of a pre-trained model can be kept unchanged during further training on a new task.

In [None]:
#To freeze the message-passing block and the batch norm layer (i.e. the learned representation of the molecules); the FFN stays trainable
mpnn.message_passing.apply(lambda module: module.requires_grad_(False))
mpnn.message_passing.eval()
mpnn.bn.apply(lambda module: module.requires_grad_(False))
mpnn.bn.eval()  # Set batch norm layers to eval mode to freeze running mean and running var.

## 6c) Train and test by following the steps in Sections 4 and 5.

> Are the predictions for IP better or worse than for EA?

> Change the data split such that you are only training on 10% of the data (Try running for 50 epochs instead of 10).

In [None]:
#4 Train. Notice how there are fewer trainable parameters in the model now.

#5 Results.