# 1) Install Chemprop from GitHub

In [None]:
# Install chemprop from GitHub if running in Google Colab
import os

# The COLAB_RELEASE_TAG environment variable is used as the "running in Google Colab" signal;
# outside Colab nothing is installed and chemprop is expected to be importable already.
if os.getenv("COLAB_RELEASE_TAG"):
    try:
        import chemprop  # already available in this runtime: skip the install
    except ImportError:
        # IPython shell (!) and magic (%) commands: clone the repository and pip-install it from source.
        # NOTE(review): this installs whatever is on the default branch (unpinned); the cells below rely on
        # the chemprop Python API (data / featurizers / models / nn) — consider pinning a release.
        # %cd leaves the working directory inside the clone, so later relative paths such as the
        # "checkpoints" directory are created under it (cf. the /content/chemprop/... paths used later).
        !git clone https://github.com/chemprop/chemprop.git
        %cd chemprop
        !pip install -qq .

#Import packages
from pathlib import Path

import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint
import pandas as pd
import numpy as np

from rdkit import Chem                  #rdkit is used to convert SMILES to mols
from rdkit.Chem import AllChem          #reaction SMARTS used to strip the monomer end groups
from rdkit.Chem import Draw#, PandasTools  #PandasTools enables using RDKit molecules in columns of a Pandas dataframe
from rdkit.Chem.Draw import SimilarityMaps  #for drawing the partial charges
from chemprop import data, featurizers, models, nn    #chemprop is our GNN package

import matplotlib.pyplot as plt

# 2) Load, process and explore data

> In this part, we will use an experimental dataset containing ~ 70 polymers.
Data: https://pubs.acs.org/doi/full/10.1021/jacs.9b03591

We are interested in predicting the hydrogen evolution rate, or "HER". But we also have data on the optical band gap "Eg" and the transmittance "T" of each polymer when dispersed in water.

We will explore:

How data skew is a problem in this dataset. This is a common materials problem: we typically want to focus on the rarer, higher-performing materials, but these are scarce. We examine different approaches to handling this, including:
- data augmentation through "repeat measurements"
- using square-root and log functions to reduce skew in the data
- using a multi-task model that predicts multiple targets at once, each with a different skew.




In [None]:

# Get the list of sulfone polymers with experimental HERs.
HER_polymers_url = "https://raw.githubusercontent.com/S-AJ-H/Chemprop-Tutorial/main/1.2/HERs_sulfones.csv"  # get data

df_HER = pd.read_csv(HER_polymers_url)  # load into a DataFrame
df_HER = df_HER[['MonA', 'MonB', 'Eg', 'HER', 'T']].copy()

# Eg may contain non-numeric entries: coerce them to NaN and drop those rows.
df_HER['Eg'] = pd.to_numeric(df_HER['Eg'], errors='coerce')
df_HER = df_HER.dropna(subset=['Eg'])

# Remove rows with HER = 0 or T = 0.
# .copy() so the derived columns below are added to an independent frame rather than
# to a filtered slice (avoids pandas' SettingWithCopyWarning).
df_HER = df_HER[(df_HER['HER'] != 0) & (df_HER['T'] != 0)].copy()

# Skew-reducing transforms, used later as alternative targets (see the histograms below).
# log1p(x) == log(x + 1): the +1 keeps the log well-behaved near 0; invert with np.expm1.
df_HER['logHER'] = np.log1p(df_HER['HER'])
df_HER['sqrtHER'] = np.sqrt(df_HER['HER'])
df_HER['sqrtT'] = np.sqrt(df_HER['T'])
df_HER['logT'] = np.log1p(df_HER['T'])

df_HER = df_HER.dropna(subset=['MonA']).reset_index(drop=True)
#display(df_HER)



Let's look at the distributions of the three targets:

In [None]:
# Side-by-side histograms of the raw targets HER, Eg and T.
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 3, figsize=(10, 3))
for ax, column in zip(axes, ['HER', 'Eg', 'T']):
    df_HER[column].plot(kind='hist', bins=20, ax=ax, title=column)
    ax.spines[['top', 'right']].set_visible(False)

plt.tight_layout()
plt.show()

In [None]:
# Same panel layout, with HER and T replaced by their log(x + 1) versions.
fig, axes = plt.subplots(1, 3, figsize=(10, 3))
for ax, column in zip(axes, ['logHER', 'Eg', 'logT']):
    df_HER[column].plot(kind='hist', bins=20, ax=ax, title=column)
    ax.spines[['top', 'right']].set_visible(False)

plt.tight_layout()
plt.show()

In [None]:
# Same panel layout, with HER and T replaced by their square roots.
fig, axes = plt.subplots(1, 3, figsize=(10, 3))
for ax, column in zip(axes, ['sqrtHER', 'Eg', 'sqrtT']):
    df_HER[column].plot(kind='hist', bins=20, ax=ax, title=column)
    ax.spines[['top', 'right']].set_visible(False)

plt.tight_layout()
plt.show()

Let's first try "augmenting" the data. The HER values have significant error attached — perhaps 10–30% — so it is reasonable to create extra data points by duplicating HER values with a small amount of noise. The code below preferentially adds extra rows for the highest-HER polymers: each extra row has the same SMILES but an HER value with some noise added. This is akin to re-measuring the sample in the lab.

What are the pros and cons of this approach?

In [None]:
def augment_data(df, noise_percentage=0.01, top_n_percentage=0.2, rng=None):
    """Flatten the HER distribution by jittering HER and adding noisy copies of the top-HER rows.

    Every row gets Gaussian noise on 'HER' with standard deviation
    ``noise_percentage * |HER|`` (clipped so HER stays non-negative). In addition,
    the ``top_n_percentage`` fraction of rows with the highest (original) HER are
    appended a second time, each copy with its own *independent* noise draw, so the
    copies behave like repeat measurements rather than exact duplicates.
    Derived columns 'sqrtHER' and 'logHER' (log(HER + 1)) are recomputed from
    the noisy HER when present, so they stay consistent with 'HER'.

    Args:
        df: input DataFrame with a 'HER' column (not modified).
        noise_percentage: noise standard deviation as a fraction of each HER value.
        top_n_percentage: fraction of rows (highest HER first) to duplicate.
        rng: optional seed or ``numpy.random.Generator`` for reproducible noise.
            ``None`` uses the legacy global ``np.random`` state (so ``np.random.seed``
            still applies).

    Returns:
        A new DataFrame (index reset) with the jittered original rows followed by
        the noisy copies of the top rows.
    """
    normal = np.random.normal if rng is None else np.random.default_rng(rng).normal

    def _jitter(frame):
        # Independent noise for each row of `frame`; returns a new frame.
        noisy = frame.copy()
        her = noisy['HER'].to_numpy(dtype=float)
        noise = normal(0.0, np.abs(her) * noise_percentage, size=len(her))
        noisy['HER'] = np.clip(her + noise, 0, None)  # HER must remain non-negative
        for column, transform in (('sqrtHER', np.sqrt), ('logHER', np.log1p)):
            if column in noisy.columns:
                noisy[column] = transform(noisy['HER'])
        return noisy

    df_augmented = _jitter(df)

    # Pick the top rows by their original HER (highest first) and add noisy "re-measurements".
    top_n = int(len(df) * top_n_percentage)
    top_positions = np.argsort(-df['HER'].to_numpy(dtype=float), kind='stable')[:top_n]
    duplicates = _jitter(df.iloc[top_positions])

    return pd.concat([df_augmented, duplicates], ignore_index=True)

# Augment the data: ~3% noise on HER, and the top half of the HER distribution duplicated.
# NOTE(review): this happens before the train/val/test split, so rows sharing the same
# monomers can end up in both the training and the test set (data leakage) — keep in mind
# when reading the test metrics.
df_HER = augment_data(df_HER, noise_percentage=0.03, top_n_percentage=0.5)

# Histogram of the augmented HER values
fig, ax = plt.subplots(1, 1, figsize=(5, 3))
df_HER['HER'].plot(kind='hist', bins=20, ax=ax, title='Augmented HER Distribution')
ax.spines[['top', 'right']].set_visible(False)
plt.tight_layout()
plt.show()

In [None]:
# Every HER value (original and augmented rows) for MonB = Oc1cc(O)c(Br)c(O)c1Br, against row index.
import matplotlib.pyplot as plt

monb_mask = df_HER['MonB'] == 'Oc1cc(O)c(Br)c(O)c1Br'
df_MonB = df_HER.loc[monb_mask].copy()

fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(df_MonB.index, df_MonB['HER'], marker='o', linestyle='-')
ax.set_title('HER values for MonB = Oc1cc(O)c(Br)c(O)c1Br by Index')
ax.set_xlabel('Index')
ax.set_ylabel('HER')
ax.grid(True)
plt.show()

In [None]:
#define functions for doing chemistry with RDkit

#function that deletes duplicates by converting from mols to *canonical* smiles and back to moles
def rm_duplicate_mols(mols):
    """Return the unique molecules of `mols`, compared by canonical SMILES.

    Each molecule is written to canonical SMILES and parsed back, so the returned
    mols are freshly sanitized copies. Order of first appearance is kept, which makes
    ``products[0]`` in the callers reproducible; a ``set`` of strings would iterate in
    an order that changes between interpreter runs (str hash randomisation).

    Args:
        mols: iterable of RDKit ``Mol`` objects.

    Returns:
        List of unique RDKit ``Mol`` objects. SMILES that RDKit cannot parse back
        (``MolFromSmiles`` returns ``None``) are dropped.
    """
    unique_smiles = dict.fromkeys(Chem.MolToSmiles(mol, canonical=True) for mol in mols)
    parsed = (Chem.MolFromSmiles(smi) for smi in unique_smiles)
    return [mol for mol in parsed if mol is not None]

#Replaces Br with [At] in places where there is a cCBr (not just a cBr) - this "protects" cCBr
def protect_CBr(m):
    """Hide C–Br bonds on aliphatic carbons next to aromatic rings by swapping Br for [At].

    While the molecule still matches ``cCBr`` (aromatic C – aliphatic C – Br), the
    reaction ``[*:1]CBr>>[*:1]C[At]`` is applied and the first de-duplicated product
    is kept. The reaction template matches any aliphatic C–Br, so other alkyl bromides
    may be protected too before the loop ends. ``deprotect_CBr`` reverses this.

    Args:
        m: RDKit ``Mol``.

    Returns:
        RDKit ``Mol`` with no ``cCBr`` match left (the input itself if there was none).

    Raises:
        ValueError: if the reaction yields no usable product, so the loop cannot progress.
    """
    # Compiled once: both are loop-invariant.
    query = Chem.MolFromSmarts('cCBr')
    rxn = AllChem.ReactionFromSmarts("[*:1]CBr>>[*:1]C[At]")
    while m.HasSubstructMatch(query):
        products = rm_duplicate_mols([prods[0] for prods in rxn.RunReactants((m,))])
        if not products:
            raise ValueError(f"protect_CBr: reaction gave no product for {Chem.MolToSmiles(m)}")
        m = products[0]
    return m

#reverses the protection once we're finished protecting
def deprotect_CBr(m):
    """Undo ``protect_CBr``: turn every aliphatic C–[At] back into C–Br.

    While the molecule matches ``C[At]``, the reaction ``[*:1]C[At]>>[*:1]CBr`` is
    applied and the first de-duplicated product is kept.

    Args:
        m: RDKit ``Mol``.

    Returns:
        RDKit ``Mol`` with no ``C[At]`` match left (the input itself if there was none).

    Raises:
        ValueError: if the reaction yields no usable product, so the loop cannot progress.
    """
    # Compiled once: both are loop-invariant.
    query = Chem.MolFromSmarts('C[At]')
    rxn = AllChem.ReactionFromSmarts("[*:1]C[At]>>[*:1]CBr")
    while m.HasSubstructMatch(query):
        products = rm_duplicate_mols([prods[0] for prods in rxn.RunReactants((m,))])
        if not products:
            raise ValueError(f"deprotect_CBr: reaction gave no product for {Chem.MolToSmiles(m)}")
        m = products[0]
    return m

#the important one: remove the terminal groups
def rm_termini(m):
    """Strip the polymerisation end groups from a monomer.

    Steps:
      1. ``protect_CBr`` hides C–Br bonds on aliphatic carbons next to aromatic rings.
      2. Bromine atoms are removed one at a time (``[*:1]Br>>[*:1]``) while an aromatic
         C–Br (``cBr``) remains.
      3. ``deprotect_CBr`` restores the protected C–Br bonds.
      4. ``B(-O)(-O)`` groups (e.g. boronic acids/esters) are removed one at a time.

    Args:
        m: RDKit ``Mol`` of the monomer.

    Returns:
        RDKit ``Mol`` without aryl bromides or ``B(-O)(-O)`` groups.

    Raises:
        ValueError: if a removal reaction yields no usable product.
    """
    def _apply_until_absent(mol, query_smarts, rxn_smarts):
        # Apply `rxn_smarts` (one site per pass) until `query_smarts` no longer matches.
        query = Chem.MolFromSmarts(query_smarts)
        rxn = AllChem.ReactionFromSmarts(rxn_smarts)
        while mol.HasSubstructMatch(query):
            # rm_duplicate_mols canonicalises and de-duplicates the reaction products
            products = rm_duplicate_mols([prods[0] for prods in rxn.RunReactants((mol,))])
            if not products:
                raise ValueError(f"rm_termini: {rxn_smarts!r} gave no product for {Chem.MolToSmiles(mol)}")
            mol = products[0]
        return mol

    # remove all aryl Br (protecting the benzylic-type C–Br bonds meanwhile)
    m = protect_CBr(m)
    m = _apply_until_absent(m, 'cBr', "[*:1]Br>>[*:1]")
    m = deprotect_CBr(m)

    # remove all B(-O)(-O)
    m = _apply_until_absent(m, '[B](-O)(-O)', "[*:1]([B](-O)(-O))>>[*:1]")

    return m

In [None]:
from rdkit import Chem
from rdkit.Chem import AllChem

# Monomer SMILES as Series and as plain Python lists.
smiA_series, smiB_series = df_HER['MonA'], df_HER['MonB']
smiA, smiB = (series.tolist() for series in (smiA_series, smiB_series))

#Function that does the polymerisation
def make_master_chemprop_input(smiA, smiB):
    """Build the Chemprop input SMILES for one co-polymer ("polymerisation").

    Both monomers are parsed, their end groups are stripped with ``rm_termini``,
    and the canonical SMILES of the two fragments are joined with '.' so Chemprop
    receives them as a single multi-fragment input.

    Args:
        smiA: SMILES of monomer A.
        smiB: SMILES of monomer B.

    Returns:
        ``'<canonical A>.<canonical B>'``.

    Raises:
        ValueError: if either SMILES is missing (e.g. NaN) or cannot be parsed by RDKit,
            instead of an opaque failure inside ``rm_termini``.
    """
    fragments = []
    for label, smi in (('MonA', smiA), ('MonB', smiB)):
        mol = Chem.MolFromSmiles(smi) if isinstance(smi, str) else None
        if mol is None:
            raise ValueError(f"Could not parse {label} SMILES: {smi!r}")
        mol = rm_termini(mol)  # remove end groups
        fragments.append(Chem.MolToSmiles(mol, canonical=True))
    return '.'.join(fragments)

df_HER['Polymers'] = [make_master_chemprop_input(mon_a, mon_b) for mon_a, mon_b in zip(df_HER['MonA'], df_HER['MonB'])]  # one "A.B" SMILES per row
#display(df_HER.head())

## 2a) Make molecule datapoints

As in part 1.1, we define our SMILES and our targets - but this time, we define 3 targets simultaneously.


In [None]:
# Get SMILES and targets.
# Use the 'Polymers' column: both monomers, end groups removed, joined as "A.B".
# A single monomer column (e.g. 'MonB') would ignore the other monomer and keep the Br/B(-O)(-O) termini.
smiles = df_HER.loc[:, 'Polymers'].values
target_columns = ["HER", "T", "Eg"]       # here we consider the 3 targets
#target_columns = ["HER"]                 # uncomment this line for single target mode

targets = df_HER.loc[:, target_columns].values

display(smiles[:2])   # show first 2 SMILES strings
display(targets[:2])  # show first 2 target rows

# Pair each molecule with its targets y. A MoleculeDatapoint holds the RDKit mol *and* the targets,
# plus extra customisation options that we are not going to use.
all_data = [data.MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(smiles, targets)]
display(all_data[:2])

# The RDKit mol is still accessible, e.g.:
#all_data[0].mol.GetNumAtoms()

## 2b) Perform data splitting for training, validation, and testing

Chemprop's `make_split_indices` function always returns a tuple of length two (without validation) or three (with validation, as in this example). Each element is a list of index lists — one inner list per replicate — containing the actual indices used for splitting.

The type signature for this return type is `tuple[list[list[int]], ...]`.

In [None]:
# Available split types: list(data.SplitType.keys()). Kennard-Stone, k-means and scaffold-balanced
# splits are structure-based (similar structures go to the same split),
# see https://jacksonburns.github.io/astartes/
mols = [datapoint.mol for datapoint in all_data]  # RDKit Mol objects drive the structure-based split

# One list of index lists per split (the outer level is the replicate).
split_indices = data.make_split_indices(mols, "KENNARD_STONE", (0.8, 0.1, 0.1))
train_indices, val_indices, test_indices = split_indices

display("test_indices length =", len(test_indices[0]))

# Use the three index lists to split the datapoints themselves.
train_data, val_data, test_data = data.split_data_by_indices(all_data, train_indices, val_indices, test_indices)

display(test_indices)

## 2c) Make the three MoleculeDatasets

In [None]:
# The featurizer converts each RDKit mol into a MolGraph (atom and bond feature matrices).
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()

# [0] selects the first (only) replicate of each nested split.
train_dset = data.MoleculeDataset(train_data[0], featurizer)
# Fit a StandardScaler on the training targets (subtract mean, scale to unit variance)
# and apply it to the training set.
scaler = train_dset.normalize_targets()

# Validation targets are scaled with the *training* scaler.
val_dset = data.MoleculeDataset(val_data[0], featurizer)
val_dset.normalize_targets(scaler)

# Test targets are left unscaled (no normalisation).
test_dset = data.MoleculeDataset(test_data[0], featurizer)


In [None]:
# Histograms of the training targets as the model sees them. normalize_targets() in the
# previous cell rescaled them with a StandardScaler, so these values are presumably
# standardised (zero mean, unit variance), not raw units.
# Looping over target_columns keeps this cell working in single-target mode too.
train_y = np.array([datum.y for datum in train_dset])  # one row per datapoint, one column per target

for task_index, column in enumerate(target_columns):
    fig, ax = plt.subplots(figsize=(5, 3))
    ax.hist(train_y[:, task_index], bins=20, edgecolor='black')
    ax.set_title(f'Distribution of train_dset y values ({column})')
    ax.set_xlabel(f'{column} (scaled)')
    ax.set_ylabel('Frequency')
    ax.grid(axis='y', alpha=0.75)
    plt.show()


# 3) Message-Passing Neural Network input parameters

> Now that our data is ready, it's time to use Chemprop. There are three main steps: message passing, aggregation, and the feed-forward neural network (FFN).

> For more info on a step: https://chemprop.readthedocs.io/en/latest/tutorial/python/index.html

## 3a) Message Passing and Aggregation

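The number of tasks is set with the `n_tasks` argument of the feed-forward network (`nn.RegressionFFN`); setting it above one makes a multi-task predictor. It must equal the number of target columns.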

Try changing the learning rates!

In [None]:
# Message passing: construct node-level representations of the atoms.
# Different activations, dropout, etc. can be passed.
# Defaults: 300 hidden dimensions, 3 message-passing iterations, ReLU.
dim = 300
mp = nn.BondMessagePassing(d_h=dim, dropout=0.1)    # https://chemprop.readthedocs.io/en/latest/autoapi/chemprop/nn/index.html#chemprop.nn.BondMessagePassing

# Aggregation: node-level --> graph-level representation by averaging the hidden node/edge representations.
# Note this is the only place where the two monomers in each pair interact!
agg = nn.MeanAggregation()
batch_norm = True  # re-centres and re-scales the aggregation output so the FFN inputs stay small and centred

output_transform = nn.UnscaleTransform.from_standard_scaler(scaler)  # "un-scale" predictions back to target units

# Regression heads include nn.RegressionFFN, nn.MveFFN and nn.EvidentialFFN.
# n_tasks must equal the number of target columns the scaler was fitted on; a hard-coded
# value disagrees with the targets / unscale transform whenever target_columns changes.
ffn = nn.RegressionFFN(
    output_transform=output_transform,
    input_dim=dim,
    hidden_dim=300,
    dropout=0.1,
    n_tasks=len(target_columns),
)

metric_list = [nn.metrics.RMSE(), nn.metrics.MAE(), nn.metrics.R2Score()]  # Only the first metric is used for training and early stopping

# Message passing, aggregation and FFN are trained end-to-end; learning rates / optimiser can be changed here.
mpnn = models.MPNN(mp, agg, ffn, batch_norm, metric_list, init_lr=1e-5, max_lr=5e-5)
#mpnn

# In the message passing NN:
# w_i = input weights, applied to the bond feature vectors. Length = sum of bond and atom features. Output hidden dimension is 300 by default.
# w_h = hidden weights, applied to the messages.
# w_o = output weights.

# In the FFN:
# Input dimension is the same as the MPNN hidden dimension (300)

# 4) Set up and start the trainer

In [None]:
# DataLoaders. num_workers=0 loads every batch in the main process (no worker subprocesses).
num_workers = 0

train_loader = data.build_dataloader(train_dset, num_workers=num_workers)  # shuffled
val_loader, test_loader = (
    data.build_dataloader(dset, num_workers=num_workers, shuffle=False)
    for dset in (val_dset, test_dset)
)

# Checkpointing: keep the checkpoint with the lowest val_loss, plus the most recent one.
checkpointing = ModelCheckpoint(
    "checkpoints",                  # directory, relative to the current working directory
    "best-{epoch}-{val_loss:.2f}",  # filename template with epoch and validation loss
    "val_loss",                     # metric used to pick the best checkpoint
    mode="min",
    save_last=True,
)

from lightning.pytorch.loggers import CSVLogger

# CSV logger: each run writes /content/chemprop/losses/version_N/metrics.csv
logger = CSVLogger("/content/chemprop/", name="losses")

trainer = pl.Trainer(
    logger=logger,
    enable_checkpointing=True,  # checkpoints go to the "checkpoints" folder configured above
    enable_progress_bar=True,
    accelerator="auto",
    devices=1,
    max_epochs=500,
    callbacks=[checkpointing],
)

# Start training.
trainer.fit(mpnn, train_loader, val_loader)

# 5) Test results

In [None]:
# Loss curves of the run that just finished. logger.log_dir is that run's version_N directory,
# so the path no longer has to be edited by hand after every Colab restart.
df = pd.read_csv(Path(logger.log_dir) / "metrics.csv")

# val_loss and train_loss_epoch can land on different rows of the same epoch (plus occasional
# train_loss_step rows), so a fixed row shift is fragile; averaging per epoch aligns them.
df = (
    df.groupby("epoch", as_index=False)[["val_loss", "train_loss_epoch"]]
    .mean()
    .dropna(subset=["train_loss_epoch"])
)

fig, ax = plt.subplots()
ax.plot(df["epoch"], df["val_loss"], label="val_loss")
ax.plot(df["epoch"], df["train_loss_epoch"], label="train_loss_epoch")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("Training and Validation Loss")
ax.set_xlim(0, trainer.max_epochs)
ax.legend()
ax.grid(True)
plt.show()

# Evaluate on the held-out test set.
results = trainer.test(dataloaders=test_loader)

In [None]:
# Predict every polymer (train + val + test); shuffle=False keeps the row order of df_HER.
# NOTE(review): passing mpnn uses its current (last-epoch) weights, whereas trainer.test above
# may have evaluated the best checkpoint — confirm which one you want to report.
full_dset = data.MoleculeDataset(all_data, featurizer=featurizer)
full_loader = data.build_dataloader(full_dset, shuffle=False)

predictions = trainer.predict(mpnn, full_loader)
predictions_array = np.concatenate(predictions, axis=0).reshape(len(all_data), -1)

# One prediction column per target (pred_HER, pred_T, pred_Eg, ...): works in both the
# multi-target and the single-target configuration of target_columns.
pred_columns = [f'pred_{column}' for column in target_columns]
df_HER[pred_columns] = predictions_array

Plot prediction vs actual

In [None]:
# Parity plot for HER: all rows (red crosses) with the test-set rows highlighted (green circles).
fig, ax = plt.subplots(figsize=(5, 3))

ax.scatter(df_HER['HER'], df_HER['pred_HER'], label='Prediction', marker='x', color='red')
ax.plot(df_HER['HER'], df_HER['HER'], label='y=x', linestyle='-')

# test_indices are positions in all_data, which follows df_HER's row order -> iloc
test_indices_flat = test_indices[0]
df_test = df_HER.iloc[test_indices_flat]
ax.scatter(df_test['HER'], df_test['pred_HER'], label='Test Prediction', marker='o', color='green')

ax.set_xlabel('true HER')
ax.set_ylabel('Predicted HER')
ax.set_title('Pred vs true HER')
ax.legend()
ax.grid(True)
plt.show()




In [None]:
# Parity plot for T; test-set rows are highlighted the same way as in the HER plot,
# so generalisation can be judged rather than training fit alone.
fig, ax = plt.subplots(figsize=(5, 3))

ax.scatter(df_HER['T'], df_HER['pred_T'], label='Prediction', marker='x', color='red')
ax.plot(df_HER['T'], df_HER['T'], label='y=x', linestyle='-')

df_test_T = df_HER.iloc[test_indices[0]]  # positional: all_data follows df_HER's row order
ax.scatter(df_test_T['T'], df_test_T['pred_T'], label='Test Prediction', marker='o', color='green')

ax.set_xlabel('true T')
ax.set_ylabel('Predicted T')
ax.set_title('Pred vs true T')
ax.legend()
ax.grid(True)
plt.show()


In [None]:
# Parity plot for Eg; test-set rows are highlighted the same way as in the HER plot.
fig, ax = plt.subplots(figsize=(5, 3))

ax.scatter(df_HER['Eg'], df_HER['pred_Eg'], label='Prediction', marker='x', color='red')
ax.plot(df_HER['Eg'], df_HER['Eg'], label='y=x', linestyle='-')

df_test_Eg = df_HER.iloc[test_indices[0]]  # positional: all_data follows df_HER's row order
ax.scatter(df_test_Eg['Eg'], df_test_Eg['pred_Eg'], label='Test Prediction', marker='o', color='green')

ax.set_xlabel('true Eg')
ax.set_ylabel('Predicted Eg')
ax.set_title('Pred vs true Eg')
ax.legend()
ax.grid(True)
plt.show()

# 6) Bonus: Transfer learning

Transfer learning (or pretraining) leverages knowledge from a pre-trained model on a related task to enhance performance on a new task. In Chemprop, we can use pre-trained model checkpoints to initialize a new model and freeze components of the new model during training.

By repeating the steps in 1.1, pre-train a GNN using the EA and/or IP data. Then use the data in 1.2 and the info below to fine-tune using the experimental HER data.

Can you train a model to a similar level of accuracy to the one above? Why and/or why not?

## 6a) Make the MoleculeDatasets by following the steps in Section 2.

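> The previously-trained model was trained on a scaled dataset. The scaler is saved as part of the model (in its output transform) and used during prediction. When fine-tuning on the *same* target, the fine-tuning data should be scaled with that same target scaler. Here the target changes to sqrtHER, so the code below fits a new scaler on the HER training data and replaces the model's output transform with it.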

In [None]:
#2a
# Get SMILES and targets. 'ND' placeholders become missing values and rows with a missing
# value in *any* column are dropped.
df_HER = df_HER.replace('ND', pd.NA)
df_HER = df_HER.dropna()

display(df_HER.head())

# 'Polymers' holds the end-group-free "A.B" SMILES built in section 2.
smiles_HER = df_HER.loc[:, 'Polymers'].values
targets_HER = df_HER.loc[:, ['sqrtHER']].values  # sqrtHER is less skewed than HER

# Pair each molecule with its target y
HER_data = [data.MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(smiles_HER, targets_HER)]

#2b
# NOTE(review): only 10% train / 10% val / 80% test — presumably to mimic low-data fine-tuning; confirm.
mols_HER = [d.mol for d in HER_data]
train_indices_HER, val_indices_HER, test_indices_HER = data.make_split_indices(mols_HER, "random", (0.1, 0.1, 0.8))

train_data_HER, val_data_HER, test_data_HER = data.split_data_by_indices(
    HER_data, train_indices_HER, val_indices_HER, test_indices_HER)

# Load the pre-trained model: learned message-passing and FFN parameters from the earlier training.
# NOTE(review): its FFN must predict a single task to match the single sqrtHER target here.
checkpoint_path = "/content/chemprop/checkpoints/best-epoch=8-val_loss=0.03.ckpt"  # put the name of your best checkpoint here!
#checkpoint_path = "https://raw.githubusercontent.com/S-AJ-H/Chemprop-Tutorial/main/best-epoch=9-val_loss=0.04.ckpt"
mpnn_cls = models.MPNN
mpnn = mpnn_cls.load_from_file(checkpoint_path)

# Alternative scaling: reuse the pre-training scaler (appropriate when fine-tuning on the same target)
#pretraining_scaler = scaler
#pretraining_scaler.mean_ = mpnn.predictor.output_transform.mean.numpy()
#pretraining_scaler.scale_ = mpnn.predictor.output_transform.scale.numpy()

#2c
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()

train_dset_HER = data.MoleculeDataset(train_data_HER[0], featurizer)
#train_dset_HER.normalize_targets(pretraining_scaler)
scaler_HER = train_dset_HER.normalize_targets()

val_dset_HER = data.MoleculeDataset(val_data_HER[0], featurizer)
#val_dset_HER.normalize_targets(pretraining_scaler)
val_dset_HER.normalize_targets(scaler_HER)
test_dset_HER = data.MoleculeDataset(test_data_HER[0], featurizer)

# The target changed, so the model must un-scale its outputs with the new scaler.
output_transform_HER = nn.UnscaleTransform.from_standard_scaler(scaler_HER)
mpnn.predictor.output_transform = output_transform_HER

# Standard-scaled sqrtHER, used for the distribution plots in the next cell.
from sklearn.preprocessing import StandardScaler
df_HER['sqrtHER_sscale'] = StandardScaler().fit_transform(df_HER[['sqrtHER']])


In [None]:
# Compare the HER distribution before and after the sqrt transform and standard scaling.
import seaborn as sns

panels = [
    ('HER', 'Original Skewed HER'),
    ('sqrtHER', 'After Sqrt Transformation'),
    ('sqrtHER_sscale', 'After Sqrt + StandardScaler'),
]

fig, axes = plt.subplots(1, 3, figsize=(12, 3))
for ax, (column, title) in zip(axes, panels):
    sns.histplot(df_HER[column], bins=30, kde=True, ax=ax)
    ax.set_title(title)

plt.tight_layout()
plt.show()

## 6b) Freezing MPNN and FFN layers
Certain layers of a pre-trained model can be kept unchanged during further training on a new task.

In [None]:
# Freeze the message passing (the learned molecular representation) and the batch norm after aggregation:
# requires_grad_(False) keeps their weights fixed; eval() switches off dropout and freezes the
# batch-norm running mean / running variance.
for frozen_part in (mpnn.message_passing, mpnn.bn):
    frozen_part.apply(lambda module: module.requires_grad_(False))
    frozen_part.eval()

# To also freeze the first FFN layer(s):
#frzn_ffn_layers = 1  # the number of consecutive FFN layers to freeze.
#for idx in range(frzn_ffn_layers):
#   mpnn.predictor.ffn[idx].requires_grad_(False)
#   mpnn.predictor.ffn[idx + 1].eval()

## 6c) Train and test by following the steps in Sections 4 and 5.

> Are the predictions for HER better or worse than for EA?

In [None]:
#4
# DataLoaders for the fine-tuning data (num_workers=0: batches are loaded in the main process).
num_workers = 0

train_loader_HER = data.build_dataloader(train_dset_HER, num_workers=num_workers)
val_loader_HER = data.build_dataloader(val_dset_HER, num_workers=num_workers, shuffle=False)
test_loader_HER = data.build_dataloader(test_dset_HER, num_workers=num_workers, shuffle=False)

# Keep the fine-tuned checkpoint with the lowest val_loss, plus the most recent one.
checkpointing = ModelCheckpoint(
    "checkpoints",                      # directory where model checkpoints will be saved
    "best-HER-{epoch}-{val_loss:.2f}",  # filename template with epoch and validation loss
    "val_loss",                         # metric used to pick the best checkpoint
    mode="min",
    save_last=True,
)

trainer = pl.Trainer(
    logger=False,
    enable_checkpointing=True,
    enable_progress_bar=True,
    accelerator="auto",
    devices=1,
    max_epochs=200,
    callbacks=[checkpointing],
)

# Start training.
mpnn.train()
# mpnn.train() switches *every* submodule to train mode, which undoes the .eval() calls of the
# freezing step: batch-norm running statistics would be updated and dropout would run inside the
# frozen encoder. Put submodules whose parameters are all frozen back into eval mode.
for submodule in (mpnn.message_passing, mpnn.bn):
    if not any(param.requires_grad for param in submodule.parameters()):
        submodule.eval()
trainer.fit(mpnn, train_loader_HER, val_loader_HER)

#5
results = trainer.test(dataloaders=test_loader_HER)  # notice how there are fewer trainable parameters

In [None]:
# Predict sqrt(HER) for every row; shuffle=False keeps df_HER's row order.
full_dset_HER = data.MoleculeDataset(HER_data, featurizer=featurizer)
full_loader_HER = data.build_dataloader(full_dset_HER, shuffle=False)

predictions_HER = trainer.predict(mpnn, full_loader_HER)

# Append the predicted values to the DataFrame.
predictions_array_HER = np.concatenate(predictions_HER, axis=0).reshape(len(HER_data), -1)
df_HER['sqrt_pred_HER'] = predictions_array_HER[:, 0]
# The model predicts sqrt(HER). A negative prediction means "HER ~ 0"; squaring it directly
# would turn it into a spurious positive HER, so clip at 0 before undoing the sqrt.
df_HER['pred_HER'] = np.square(df_HER['sqrt_pred_HER'].clip(lower=0))
df_HER.head()


In [None]:
# Parity plot for the fine-tuned model (predictions back-transformed from sqrt space).
# Test-set rows are highlighted as in the section-5 HER plot, since most rows here are test data.
axis_max = 10000  # fixed axis range; raise it if points fall outside

fig, ax = plt.subplots(figsize=(10, 6))

ax.scatter(df_HER['HER'], df_HER['pred_HER'], label='Prediction', marker='x', color='red')
ax.plot(df_HER['HER'], df_HER['HER'], label='y=x', linestyle='-')

# test_indices_HER are positions in HER_data, which follows df_HER's row order -> iloc
df_test_HER = df_HER.iloc[test_indices_HER[0]]
ax.scatter(df_test_HER['HER'], df_test_HER['pred_HER'], label='Test Prediction', marker='o', color='green')

ax.set_xlabel('true HER')
ax.set_ylabel('Predicted HER')
ax.set_title('Pred vs true HER, transferEAandIP, freeze_MPNN, sqrt_transform')
ax.set_xlim(0, axis_max)
ax.set_ylim(0, axis_max)
ax.legend()
ax.grid(True)
plt.show()