<a href="https://colab.research.google.com/github/S-AJ-H/AIMS26/blob/main/4b_Project_B_Structure_based_splitting_with_Chemprop.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 4b. Project B: Structure splitting with Chemprop

Thus far, we have used random splitting to hold out our validation and test datasets. In most chemistry problems, however, we want to extrapolate beyond the training data and predict structure-property relationships for materials we know nothing about. A random split tends to place close structural analogues of the training molecules in the held-out sets, so it can overestimate how well a model extrapolates. It is therefore particularly important to consider how we choose our validation and test datasets.

In this project, you will be predicting electron affinities of polymer photocatalysts using the same dataset as in the workshop. Your goal is to design a method of splitting the polymers into test and validation datasets on the basis of the similarity of their chemical structures. Using your new splitting method, optimise and evaluate a Chemprop model.  

---

## Project Tasks:
   
1.   Split the data into 9 folds, with each fold comprising a single Monomer A and a variable Monomer B. Analyse the distribution of electron affinities (EAs) in each fold. Train a Chemprop model using these folds and compare it to a model trained with random splitting.
2.   Use your new splitting to train a model using fixed representations and compare it with the model from Task 1.

Evaluate the performances of these models against each other and draw conclusions on how structure-based splitting affects the architecture, hyperparameters, training and performance of your model. You might need to utilise some of the resources given below.
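
As a starting point for Task 1, the fold assignment could be sketched as below. This is a minimal sketch, not a reference solution. It assumes that the integer before the underscore in `poly_ID` identifies Monomer A and that the first component of `poly_SMI` is Monomer A; both are inferred from the dataset preview and should be checked against the full data. `df_demo`, `fold`, `ea_summary` and `splits` are hypothetical names, and the six rows are copied from the dataset preview in this notebook.

```python
import pandas as pd

# Six rows copied from the dataset preview; a tiny stand-in for the full df_data.
df_demo = pd.DataFrame({
    "poly_ID": ["0_0", "0_1", "0_2", "8_678", "8_679", "8_680"],
    "poly_SMI": [
        "Fc1ccc(F)cc1.Oc1cc(O)cc(O)c1",
        "Fc1ccc(F)cc1.N#Cc1ccc(F)cc1",
        "Fc1ccc(F)cc1.Cc1cc(N)ccc1Cl",
        "O=S1(=O)c2ccccc2-c2ccccc21.O=C(O)c1cccs1",
        "O=S1(=O)c2ccccc2-c2ccccc21.O=[N+]([O-])c1ccccn1",
        "O=S1(=O)c2ccccc2-c2ccccc21.CC(Oc1ccccc1)C(=O)NN",
    ],
    "EA": [-3.406210, -2.472685, -3.459588, -2.010124, -1.718059, -2.740288],
})

# ASSUMPTION: the integer before "_" in poly_ID indexes Monomer A.
df_demo["fold"] = df_demo["poly_ID"].str.split("_").str[0].astype(int)

# Sanity check of that assumption: within a fold, the first SMILES
# component (assumed to be Monomer A) should always be the same string.
monomer_a = df_demo["poly_SMI"].str.split(".").str[0]
assert (monomer_a.groupby(df_demo["fold"]).nunique() == 1).all()

# Summarise the EA distribution in each fold.
ea_summary = df_demo.groupby("fold")["EA"].agg(["count", "mean", "std", "min", "max"])
print(ea_summary)

# Leave-one-fold-out: each Monomer A is held out once for validation.
splits = []
for held_out in sorted(df_demo["fold"].unique()):
    val_idx = df_demo.index[df_demo["fold"] == held_out].tolist()
    train_idx = df_demo.index[df_demo["fold"] != held_out].tolist()
    splits.append((held_out, train_idx, val_idx))
```

On the full dataset the same steps would be applied to `df_data`. The index lists in `splits` can then be used to select the training and validation datapoints for each fold.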

---

## In your presentations, you are expected to:

1.   Define the project problem and discuss its real-world applications.
2.   Explain the model architecture and the reasons for using it for the specific problem, with a focus on how it is different from the models from the workshop earlier in the week.
3.   Describe the training process and show training loss curves. How does the splitting methodology affect training?
4.   Discuss the impact of your hyperparameter optimisations. Explain your reasoning for hyperparameter selection and tuning. Present your best hyperparameters.
5.   Present key performance metrics from your best model. Present any notable failures. How does the splitting methodology affect performance? Compare against the fixed representations model.
6.   What is limiting the model and how could it be further improved?

---

## Extension Tasks:

* Create multi-task models, in which EA and ionisation potential (IP) are predicted simultaneously. Does model performance improve? Why/why not?
* Evaluate the performance of these models as a function of dataset size.
* Use Morgan fingerprint similarity to split the training and validation datasets. Train a new model and compare it to the two previous approaches. Comment on any advantages or drawbacks.
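
One possible way to approach the Morgan fingerprint extension is to cluster the polymers by Tanimoto distance and assign whole clusters to the validation set. The sketch below uses RDKit's Butina clustering for this; that choice is an assumption, not a method prescribed by this project. The function name `morgan_cluster_split`, the `cutoff` and `val_fraction` values, and the decision to fingerprint the full `poly_SMI` string are all hypothetical. The example SMILES are copied from the dataset preview in this notebook.

```python
from rdkit import Chem, DataStructs
from rdkit.Chem import rdFingerprintGenerator
from rdkit.ML.Cluster import Butina


def morgan_cluster_split(smiles, cutoff=0.4, val_fraction=0.25):
    """Cluster by Tanimoto distance, then hold out whole clusters for validation."""
    gen = rdFingerprintGenerator.GetMorganGenerator(radius=2, fpSize=2048)
    fps = [gen.GetFingerprint(Chem.MolFromSmiles(s)) for s in smiles]
    n = len(fps)

    # Butina expects the lower triangle of the distance matrix as a flat list.
    dists = []
    for i in range(1, n):
        sims = DataStructs.BulkTanimotoSimilarity(fps[i], fps[:i])
        dists.extend(1.0 - s for s in sims)
    clusters = Butina.ClusterData(dists, n, cutoff, isDistData=True)

    # Fill the validation set with the smallest clusters first, stopping once
    # the target size is reached and never emptying the training set.
    val_idx = []
    for cluster in sorted(clusters, key=len):
        if len(val_idx) >= val_fraction * n:
            break
        if len(val_idx) + len(cluster) >= n:
            break
        val_idx.extend(cluster)
    held_out = set(val_idx)
    train_idx = [i for i in range(n) if i not in held_out]
    return train_idx, sorted(held_out), clusters


example_smiles = [
    "Fc1ccc(F)cc1.Oc1cc(O)cc(O)c1",
    "Fc1ccc(F)cc1.N#Cc1ccc(F)cc1",
    "Fc1ccc(F)cc1.Cc1cc(N)ccc1Cl",
    "Fc1ccc(F)cc1.O=C(Cl)COc1ccccc1",
    "Fc1ccc(F)cc1.COC(=O)c1cccc(N)n1",
    "O=S1(=O)c2ccccc2-c2ccccc21.O=C(O)c1cccs1",
    "O=S1(=O)c2ccccc2-c2ccccc21.O=[N+]([O-])c1ccccn1",
    "O=S1(=O)c2ccccc2-c2ccccc21.CC(Oc1ccccc1)C(=O)NN",
]
train_idx, val_idx, clusters = morgan_cluster_split(example_smiles)
print("train:", train_idx, "val:", val_idx)
```

Because whole clusters are held out, no validation polymer has a close fingerprint neighbour (within `cutoff`) of its own cluster left in the training set. The cluster sizes depend strongly on `cutoff`, so it is worth inspecting `clusters` before training.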

---

## Resources:

>RDKit:   
>https://rdkit.org/docs/index.html

>Chemprop:  
>https://pubs.acs.org/doi/10.1021/acs.jcim.9b00237  
>https://pubs.acs.org/doi/10.1021/acs.jcim.3c01250  
>https://chemprop.readthedocs.io/en/latest/

>Data from:  
>https://pubs.acs.org/doi/full/10.1021/jacs.9b03591


## 0. Install and import Chemprop

In [None]:
# Chemprop (~1min)
!pip install chemprop -qq
import chemprop
print("Imported Chemprop version", chemprop.__version__)

from rdkit import Chem                                                  # rdkit is used to convert SMILES to molecular graphs ("mols")
from rdkit.Chem import Draw                                             # Lets us draw molecules
from chemprop import data, featurizers, models, nn                      # chemprop is our GNN package

# ML
import lightning.pytorch as pl                                          # lightning has built-in functions for lots of the basics (metric tracking etc); Chemprop is built on this.
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
from lightning.pytorch.loggers import CSVLogger                         # Configure CSV logger for tracking losses
import logging
logging.getLogger("lightning.pytorch").setLevel(logging.ERROR)
from sklearn.model_selection import train_test_split, KFold, PredefinedSplit
from sklearn.metrics import r2_score, mean_absolute_error

# Misc
import os
from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
version = 0                                                             # used for save files


## 1. Load data



In [None]:
# Load the polymer dataset (SMILES, EA and IP) from GitHub.
csv_url = "https://raw.githubusercontent.com/S-AJ-H/AIMS26/25478252292fe3bde0e4fb06977ea21c7e05545a/dataset.csv"
df_data = pd.read_csv(csv_url)
display(df_data)

| Unnamed: 0 | poly_ID | poly_SMI | EA | IP |
|---|---|---|---|---|
| 0 | 0_0 | `Fc1ccc(F)cc1.Oc1cc(O)cc(O)c1` | -3.406210 | 1.808017 |
| 1 | 0_1 | `Fc1ccc(F)cc1.N#Cc1ccc(F)cc1` | -2.472685 | 2.635116 |
| 2 | 0_2 | `Fc1ccc(F)cc1.Cc1cc(N)ccc1Cl` | -3.459588 | 1.454940 |
| 3 | 0_3 | `Fc1ccc(F)cc1.O=C(Cl)COc1ccccc1` | -2.842112 | 2.066126 |
| 4 | 0_4 | `Fc1ccc(F)cc1.COC(=O)c1cccc(N)n1` | -2.724312 | 1.757542 |
| ... | ... | ... | ... | ... |
| 6133 | 8_677 | `O=S1(=O)c2ccccc2-c2ccccc21.O=[N+]([O-])c1ccc(F...` | -1.704934 | 2.202737 |
| 6134 | 8_678 | `O=S1(=O)c2ccccc2-c2ccccc21.O=C(O)c1cccs1` | -2.010124 | 1.951355 |
| 6135 | 8_679 | `O=S1(=O)c2ccccc2-c2ccccc21.O=[N+]([O-])c1ccccn1` | -1.718059 | 2.356684 |
| 6136 | 8_680 | `O=S1(=O)c2ccccc2-c2ccccc21.CC(Oc1ccccc1)C(=O)NN` | -2.740288 | 1.583449 |


## 2. Prepare data for machine learning

## 3. Define, train and validate model

## 4. Analyse results