<a href="https://colab.research.google.com/github/S-AJ-H/AIMS26/blob/main/3_Chemprop_representations.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 3. Chemprop for predicting polymer properties

This workbook introduces Chemprop for predicting chemical properties. Chemprop is a graph neural network package that uses "message passing" to learn hidden representations of molecules. These representations are then passed through a feed-forward neural network to predict chemical properties.

We again choose to predict the "electron affinity" of polymer photocatalysts.

You will:
   
*   Train and evaluate a Chemprop model
*   Compare this model to the "fixed representations" model from Notebook 1.


Resources:
>RDKit:   
>https://rdkit.org/docs/index.html

>Chemprop:  
>https://pubs.acs.org/doi/10.1021/acs.jcim.9b00237  
>https://pubs.acs.org/doi/10.1021/acs.jcim.3c01250  
>https://chemprop.readthedocs.io/en/latest/

>Data from:  
>https://pubs.acs.org/doi/full/10.1021/jacs.9b03591


##0. Install Chemprop and import packages

In [None]:
# Chemprop (~1min)
!pip install chemprop -qq
import chemprop
print("Imported Chemprop version", chemprop.__version__)

from rdkit import Chem                                                  # rdkit is used to convert SMILES to molecular graphs ("mols")
from rdkit.Chem import Draw                                             # Lets us draw molecules
from chemprop import data, featurizers, models, nn                      # chemprop is our GNN package

# ML
import lightning.pytorch as pl                                          # lightning has built-in functions for lots of the basics (metric tracking etc); Chemprop is built on this.
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks import EarlyStopping                   # import from lightning.pytorch (not pytorch_lightning) so the callback matches pl.Trainer
from lightning.pytorch.loggers import CSVLogger                         # Configure CSV logger for tracking losses
import logging
logging.getLogger("lightning.pytorch").setLevel(logging.ERROR)
from sklearn.model_selection import train_test_split, KFold, PredefinedSplit
from sklearn.metrics import r2_score, mean_absolute_error

# Misc
import os
from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
version = 0                                                             # used for save files

##1. Load data



In [None]:
#Get the polymer SMILES from GitHub.
csv_url = "https://raw.githubusercontent.com/S-AJ-H/AIMS26/25478252292fe3bde0e4fb06977ea21c7e05545a/dataset.csv"
df_data = pd.read_csv(csv_url)
display(df_data)

##2. Prepare data for machine learning

##### 2.1 Extract and split features and targets

In [None]:
#1. Extract smiles and targets (EA)
smiles = df_data.loc[:, 'poly_SMI'].values
targets = df_data.loc[:, 'EA'].values

#instead of using generated fingerprints, we use the SMILES to generate mol objects, then pair the mol objects with the targets y to make "MoleculeDatapoints"
all_data = [data.MoleculeDatapoint.from_smi(smi, [y]) for smi, y in zip(smiles, targets)]
display(all_data[:2])

#set up data splitting. Later code is written for easy modification for cross-validation (folds and n_splits variables), but we ignore it in this workbook
fold=1
n_splits=1                                                                                                  # set only 1 fold
train_idx, valid_idx = train_test_split(range(len(smiles)), test_size=0.1, random_state=31, shuffle=True)   # get indices

##### 2.2 Define dataloader

In [None]:
def dataloader(all_data, train_idx, valid_idx, featurizer, batch_size):
  # split and featurise the data:
  train_data, val_data, _ = data.split_data_by_indices(data=all_data, train_indices=[train_idx], val_indices=[valid_idx])           # Use the 2 lists of indices to split the data.

  train_dset = data.MoleculeDataset(train_data[0], featurizer)                                                                      # MoleculeDataset is a Chemprop class that featurises the inputs. We use train_data[0] because split_data_by_indices returns a nested list
  val_dset = data.MoleculeDataset(val_data[0], featurizer)

  # scale
  scaler = train_dset.normalize_targets()                                                                                           # normalise the targets using StandardScaler (subtract mean, scale to unit variance)
  val_dset.normalize_targets(scaler)

  # make loaders
  train_loader = data.build_dataloader(train_dset, batch_size=batch_size)
  val_loader = data.build_dataloader(val_dset, batch_size=batch_size, shuffle=False)

  return train_loader, val_loader, scaler

##3. Define, train and validate model

>Now that our data is ready, it's time to use Chemprop. There are 3 main steps:

1.   Message passing: Constructs hidden atom representations by passing messages from bond to bond (or atom to atom).
2.   Aggregation: Combines the atom representations into a graph level representation (usually atoms -> molecule).
3.   Feed-Forward Network (FFN): Takes the aggregated representations and makes target predictions.
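
The three steps can be sketched on a toy graph in plain Python. This is a simplified illustration, not Chemprop's implementation: it passes messages between atoms rather than bonds, uses no learned weights in the message passing, and replaces the FFN with a single linear layer. The starting features, the weights and the function names are all hypothetical.

```python
# Toy molecule: a 3-atom chain (0 - 1 - 2), e.g. the heavy atoms of ethanol.
neighbours = {0: [1], 1: [0, 2], 2: [1]}

# Hypothetical 2-dimensional starting features for each atom.
hidden = {0: [1.0, 0.0], 1: [1.0, 0.0], 2: [0.0, 1.0]}


def message_passing_step(hidden, neighbours):
    """Update every atom by adding the sum of its neighbours' vectors, then apply ReLU."""
    updated = {}
    for atom, vec in hidden.items():
        message = [sum(hidden[n][i] for n in neighbours[atom]) for i in range(len(vec))]
        updated[atom] = [max(0.0, h + m) for h, m in zip(vec, message)]
    return updated


def mean_aggregate(hidden):
    """Average the atom vectors into one molecule-level vector."""
    n_atoms = len(hidden)
    dim = len(next(iter(hidden.values())))
    return [sum(vec[i] for vec in hidden.values()) / n_atoms for i in range(dim)]


def linear_readout(mol_vec, weights, bias):
    """Stand-in for the FFN: a single linear layer with fixed (not learned) weights."""
    return sum(w * x for w, x in zip(weights, mol_vec)) + bias


# 1. Message passing: two steps, matching mp_steps = 2 in this notebook.
for _ in range(2):
    hidden = message_passing_step(hidden, neighbours)

# 2. Aggregation: atoms -> molecule.
mol_vec = mean_aggregate(hidden)

# 3. Readout: molecule vector -> predicted property.
prediction = linear_readout(mol_vec, weights=[0.5, -1.0], bias=0.1)
print(hidden)
print(mol_vec, prediction)
```

After one step, atom 0 only carries information from its direct neighbour (atom 1). After the second step its second component becomes non-zero, because information from atom 2 has arrived via atom 1. In general, each extra message passing step lets an atom "see" one bond further.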

####3. Questions

> (a) There is only one place in this model where the pairs of monomers interact with each other - where is it?

##### 3.1 Train and validate model

In [None]:
# hyperparameters
max_epochs = 50                                                                 # default 100
batch_size = 256                                                                # default 512
dim_h = 100                                                                     # default 300 hidden dimension for graph-level representations. This is the output of the aggregation step and the input of the FFN.
ffn_dim = 100                                                                   # default 200 hidden dimension for the FFN.
mp_steps = 2                                                                    # default 2
mp_drop = 0.0                                                                   # default 0.1
ffn_drop = 0.0                                                                  # default 0.1
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()
metric_list = [nn.metrics.RMSE(), nn.metrics.MAE(), nn.metrics.R2Score()]       # Only the first metric is used for training and early stopping
patience = 10

results = []                                                                    #for storing all predictions and targets, for all folds
fold_r2_scores = []                                                             #for storing all per-fold scores
fold_mae_scores = []

# train
for fold, (train_idx, valid_idx) in enumerate([(train_idx, valid_idx)], 1):
  # setup
  #=============================================================================
  # dataloader
  train_loader, val_loader, scaler = dataloader(all_data, train_idx, valid_idx, featurizer, batch_size)

  # define (and reset) the model:
  mp = nn.BondMessagePassing(d_h=dim_h, bias=False, depth=mp_steps, dropout=mp_drop, activation=nn.utils.Activation.RELU, undirected=False)   # message passing
  agg = nn.MeanAggregation()                                                                                                                  # average together all of the hidden node/edge representations to get a graph-level representation.
  batch_norm = True                                                                                                                           # normalizes the outputs of the aggregation by re-centering and re-scaling (for better FFN inputs).
  output_transform = nn.UnscaleTransform.from_standard_scaler(scaler)                                                                         # define the unscaling function for trainer.predict later
  ffn = nn.RegressionFFN(input_dim=dim_h, hidden_dim=ffn_dim, dropout=ffn_drop, n_layers=2, n_tasks=1, output_transform=output_transform)

  model = models.MPNN(mp, agg, ffn, batch_norm, metric_list, init_lr = 1e-5, max_lr = 1e-4)                                                   # Defines the complete model structure. Can also change learning rate, optimiser etc. End-to-end trained.

  # Configure model checkpointing, early stopping and logging
  early_stop = EarlyStopping(monitor="val_loss", patience=patience, mode="min")                                                               # define early stopping conditions
  checkpointing = ModelCheckpoint(                                                                                                            # saves trained weights for the model with the lowest validation loss
      dirpath="checkpoints",
      filename=f"fold={fold}-best-{{epoch}}-{{val_loss:.2f}}",
      monitor="val_loss",
      mode="min",
  )
  logger = CSVLogger("logs", name=f"fold_{fold}", version=0)                                                                                  # saves metrics and losses
  # train and validate
  #=============================================================================
  # train
  trainer = pl.Trainer(logger=logger, accelerator="auto", devices=1, max_epochs=max_epochs, callbacks=[checkpointing, early_stop])            # lightning trainer
  print(f"\nFold {fold}:\n{'=' * 100}")
  trainer.fit(model, train_loader, val_loader)                                                                                                # model.train() etc is in here
  print(f"\nFold {fold} complete!\n")

  # get predictions vs targets
  preds = trainer.predict(model, dataloaders=val_loader, ckpt_path="best", weights_only=False)                                                # "ckpt_path = best" to load weights with lowest loss
  preds_array = np.concatenate([p.cpu().numpy() for p in preds], axis=0).reshape(-1)                                                          # preds from .predict() is already unscaled

  # also get the targets for this fold out, so we can calculate per-fold metrics
  targets_scaled = np.concatenate([batch.Y.cpu().numpy() for batch in val_loader], axis=0).reshape(-1,1)
  targets_array = scaler.inverse_transform(targets_scaled).reshape(-1)                                                                        # need to un-scale targets and flatten

  # store predictions and metrics
  #=============================================================================
  # store predictions
  df_fold = pd.DataFrame({
      "original_index": valid_idx,
      "targets": targets_array,
      "preds": preds_array,
      "fold": fold
  })
  results.append(df_fold)

  #Compute metrics per fold
  r2 = r2_score(df_fold["targets"], df_fold["preds"])
  mae = mean_absolute_error(df_fold["targets"], df_fold["preds"])
  fold_r2_scores.append(r2)                                                                                                                    # collect the per-fold r2 into a single list
  fold_mae_scores.append(mae)
  print(f"Fold {fold} metrics: \nR²={r2:.3f}, MAE={mae:.3f}\n")

#After all folds:
#==========================================================
# save all predictions in a dataframe
df_results = pd.concat(results, ignore_index=True)
df_all = df_data.merge(df_results, left_index=True, right_on="original_index")

# Aggregate across folds
mean_r2 = np.mean(fold_r2_scores)
mean_mae = np.mean(fold_mae_scores)
std_r2 = np.std(fold_r2_scores)
std_mae = np.std(fold_mae_scores)
print("\nFull results:")
print(f"Mean R²: {mean_r2:.3f} ± {std_r2:.3f}")
print(f"Mean MAE: {mean_mae:.3f} ± {std_mae:.3f}")

##4. Analyse results

Now we have:

1.   `df_all`: a dataframe containing our initial info (`df_data`) and our results (`df_results`).
2.   Per-fold MAE and R² metrics in the lists `fold_mae_scores` and `fold_r2_scores`.
3.   Overall mean metrics and standard deviations (when using multiple folds) in `mean_r2`, `std_r2` etc.
4.   Per-fold losses per epoch in the file *metrics.csv* in *logs/fold_x/version_0*.
5.   Per-fold checkpoint files containing the best weights for each fold in the *checkpoints* folder.
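
The layout of *metrics.csv* can be surprising the first time you open it. The sketch below assumes that the logger writes training and validation metrics on separate rows, so each row leaves the other column empty. The excerpt is hypothetical (the numbers are made up and the real file may contain more columns); only the column names `train_loss_epoch` and `val_loss` are taken from this notebook. It uses the standard library only, to show the idea without any dependencies.

```python
import csv
import io

# Hypothetical excerpt of metrics.csv; the values are invented for illustration.
EXAMPLE_CSV = """epoch,step,train_loss_epoch,val_loss
0,9,,0.91
0,9,0.98,
1,19,,0.62
1,19,0.71,
"""


def read_losses(csv_text):
    """Return (train_losses, val_losses) as lists of floats, skipping empty cells."""
    train_losses, val_losses = [], []
    for row in csv.DictReader(io.StringIO(csv_text)):
        if row["train_loss_epoch"]:          # empty string means "not logged on this row"
            train_losses.append(float(row["train_loss_epoch"]))
        if row["val_loss"]:
            val_losses.append(float(row["val_loss"]))
    return train_losses, val_losses


train_losses, val_losses = read_losses(EXAMPLE_CSV)
print(train_losses)
print(val_losses)
```

With pandas, the empty cells are read as NaN, which is why the plotting code calls `.dropna()` on each column before plotting one value per epoch.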

##### 4.1 Print results

In [None]:
# Print df_all and the MAE:
##############################

##### 4.2 Plot results

In [None]:
# extract losses per epoch per fold:
train_losses_dict = {}                                                          # store the per-fold, per-epoch losses with the key = fold number
valid_losses_dict = {}

for fold in range(1, n_splits+1):
  # extract losses
  metrics_path = f"logs/fold_{fold}/version_{version}/metrics.csv"
  df_metrics = pd.read_csv(metrics_path)
  train_losses = df_metrics["train_loss_epoch"].dropna().values
  valid_losses = df_metrics["val_loss"].dropna().values

  # save to dictionaries
  train_losses_dict[fold] = train_losses
  valid_losses_dict[fold] = valid_losses

# Create subplots
n_plots = n_splits + 2  # train/val plots + scatter + bar
fig, axes = plt.subplots(3, 4, figsize=(15, 8))
axes = axes.flatten()

# Training vs Validation Loss per fold
for fold in range(1, n_splits+1):
    ax = axes[fold-1]
    ax.plot(train_losses_dict[fold], label=f"Train Fold {fold}", alpha=0.7)
    ax.plot(valid_losses_dict[fold], label=f"Valid Fold {fold}", alpha=0.7)
    ax.set_xlabel("Epoch")
    ax.set_ylabel("MSE Loss")
    ax.set_title(f"Fold {fold} Losses")
    ax.set_yscale("log")
    ax.legend()

# Scatter plot: True vs Predicted
ax2 = axes[n_splits]
for fold in range(1, n_splits+1):
    fold_data = df_all[df_all["fold"] == fold]
    ax2.scatter(
        fold_data["targets"],
        fold_data["preds"],
        alpha=0.8,
        s=5,
        label=f"Fold {fold}"
    )

ax2.set_xlabel("True EA")
ax2.set_ylabel("Predicted EA")
ax2.set_title("True vs Predicted EA")
lims = [
    min(df_all["targets"].min(), df_all["preds"].min()),
    max(df_all["targets"].max(), df_all["preds"].max())
]
ax2.plot(lims, lims, "r--")  # y=x reference
ax2.text(0.05, 0.95, f"R²: {mean_r2:.3f} ± {std_r2:.3f}",
         transform=ax2.transAxes, fontsize=10, verticalalignment='top')
ax2.text(0.05, 0.85, f"MAE: {mean_mae:.3f} ± {std_mae:.3f}",
         transform=ax2.transAxes, fontsize=10, verticalalignment='top')
ax2.legend(title="Folds", markerscale=2, loc="center left", bbox_to_anchor=(1.05, 0.5))

# Bar chart: MAE per fold
ax3 = axes[n_splits+1]
cmap = plt.get_cmap("tab10")  # plt.cm.get_cmap was removed in recent Matplotlib releases
colors = [cmap(fold-1) for fold in range(1, n_splits+1)]
ax3.bar(range(1, n_splits+1), fold_mae_scores, color=colors, edgecolor="black")
ax3.set_xlabel("Fold")
ax3.set_ylabel("MAE")
ax3.set_title("MAE per Fold")
ax3.set_xticks(range(1, n_splits+1))

# Delete any unused subplots
for i in range(n_splits+2, len(axes)):
    fig.delaxes(axes[i])

plt.tight_layout()
plt.show()

##5. Questions


1.   Compare the MAE of this model with the "Fixed representations" model. What can you conclude?
2.   Make the hidden atom representation vector (`dim_h`) 10 times smaller and/or 10 times larger. How does this affect the model's MAE?
3.   Change the number of message passing steps to 1, and to 5. Explain your observations. How could changing the input SMILES make better use of message passing?
4. We split our training and validation data using random splitting. In the context of this polymer SMILES dataset, explain why this may lead to over-optimistic validation performance and poor generalisation to unseen polymers.




##6. Extension Questions


1.   Calculate the percentage of polymers for which the model predicts the EA to within 0.025 eV. Using a web search, find out why this quantity might be a good metric for evaluating model performance. Change the training metric to your new metric and re-train the model.
2.   Load the CheMeleon pre-trained weights as seen in the message passing notebook. Use these (a) as an initial starting set of weights which are then fine-tuned and (b) to predict the EAs directly. Does this improve the MAE?

