# Chemprop 2.x Pure‑Python Workflow on Colab 🚀

*GPU‑Accelerated Graph Neural‑Network Regression Without the CLI*

## Learning Objectives
1. Use **Google Colab GPU** resources to train a Chemprop ≥ 2.0 model entirely from **Python code (no CLI flags)**.
2. Make **test‑set predictions** and create a **parity plot** with `matplotlib` / `seaborn`.

## 0 — Runtime & GPU Check
Make sure your Colab session uses a GPU (**Runtime → Change runtime type → GPU**).

In [1]:
import torch, os, platform, sys
print('PyTorch :', torch.__version__)
print('CUDA?   :', torch.cuda.is_available())
if torch.cuda.is_available():
    print('Device  :', torch.cuda.get_device_name(0))
else:
    print('⚠️ Training will be CPU‑only (slow).')

PyTorch : 2.2.2
CUDA?   : False
⚠️ Training will be CPU‑only (slow).


### Install Chemprop 2.x

In [2]:
# !pip install --quiet 'chemprop>=2.0'

## 1 — Load the Delaney Solubility Dataset

In [3]:
import pandas as pd, requests, io, os, math, numpy as np
url = 'https://raw.githubusercontent.com/deepchem/deepchem/master/datasets/delaney-processed.csv'
df = pd.read_csv(io.StringIO(requests.get(url).text))
df = df.rename(columns={'measured log solubility in mols per litre': 'solubility'})[['smiles','solubility']]
data_path = 'delaney.csv'
df.to_csv(data_path, index=False)
df.head()

|   | smiles | solubility |
|---|--------|------------|
| 0 | OCC3OC(OCC2OC(OC(C#N)c1ccccc1)C(O)C(O)C2O)C(O)... | -0.77 |
| 1 | Cc1occc1C(=O)Nc2ccccc2 | -3.3 |
| 2 | CC(C)=CCCC(C)=CC(=O) | -2.06 |
| 3 | c1ccc2c(c1)ccc3c2ccc4c5ccccc5ccc43 | -7.87 |
| 4 | c1ccsc1 | -1.33 |


## 2 — Scaffold‑Balanced Train/Val/Test Split
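Scaffold splitting keeps every molecule that shares a Bemis-Murcko scaffold in the same split, so the test set contains ring systems the model never saw during training. The dependency-free sketch below is in the spirit of the "balanced" heuristic from Chemprop 1.x. It is not the routine Chemprop 2.x runs, and its exact assignments will differ. `scaffold_balanced_split` is a hypothetical helper, and the scaffold keys are assumed to be precomputed (for example with RDKit's `MurckoScaffold.MurckoScaffoldSmiles`).

```python
import random
from collections import defaultdict


def scaffold_balanced_split(scaffolds, sizes=(0.8, 0.1, 0.1), seed=42):
    """Hypothetical helper: assign whole scaffold groups to train/val/test.

    scaffolds[i] is a scaffold key for molecule i (e.g. a Murcko scaffold SMILES).
    Returns three lists of molecule indices; no scaffold appears in two splits.
    """
    n = len(scaffolds)
    train_size, val_size, test_size = (f * n for f in sizes)

    # Group molecule indices by scaffold key
    groups = defaultdict(list)
    for i, key in enumerate(scaffolds):
        groups[key].append(i)

    # "Big" groups (more than half the val or test budget) are placed first so
    # they tend to fill train; small groups are shuffled to balance val/test
    def is_big(g):
        return len(g) > val_size / 2 or len(g) > test_size / 2

    rng = random.Random(seed)
    big = [g for g in groups.values() if is_big(g)]
    small = [g for g in groups.values() if not is_big(g)]
    rng.shuffle(big)
    rng.shuffle(small)

    # Greedily fill train, then val; whatever does not fit goes to test
    train, val, test = [], [], []
    for g in big + small:
        if len(train) + len(g) <= train_size:
            train.extend(g)
        elif len(val) + len(g) <= val_size:
            val.extend(g)
        else:
            test.extend(g)
    return train, val, test


# Toy scaffold keys: 20 molecules spread over 7 scaffolds
keys = ['A'] * 10 + ['B'] * 3 + ['C'] * 2 + ['D'] * 2 + ['E', 'F', 'G']
train, val, test = scaffold_balanced_split(keys)
print(len(train), len(val), len(test))
```

Because whole groups move together, the split fractions are only approximate. Scaffold splits therefore rarely hit 80/10/10 exactly.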

In [5]:
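import numpy as np
from chemprop import data

# ---- Build one MoleculeDatapoint per molecule ----
# Chemprop 2.x creates datapoints from SMILES with the `from_smi` classmethod;
# the target `y` is a 1-D array with one entry per task.
all_data = [
    data.MoleculeDatapoint.from_smi(s, np.array([v], dtype=float))
    for s, v in zip(df.smiles, df.solubility)
]

# ---- Generate scaffold-balanced indices ----
mols = [d.mol for d in all_data]                 # RDKit molecules parsed by Chemprop
train_idx, val_idx, test_idx = data.make_split_indices(
    mols,
    split='scaffold_balanced',
    sizes=(0.8, 0.1, 0.1),
    seed=42,
)

# Depending on the 2.x release, each result is either a flat index list or one
# list per replicate; keep only the first replicate in the latter case.
def first_replicate(idx):
    idx = list(idx)
    return list(idx[0]) if idx and not np.isscalar(idx[0]) else idx

train_idx, val_idx, test_idx = (first_replicate(i) for i in (train_idx, val_idx, test_idx))

# ---- Slice the (not yet featurized) datapoints ----
train_data = [all_data[i] for i in train_idx]
val_data   = [all_data[i] for i in val_idx]
test_data  = [all_data[i] for i in test_idx]

print(len(train_data), len(val_data), len(test_data))

## 3 — Train Chemprop (Python API Only)

In [None]:
from lightning import pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint
from chemprop import featurizers, models, nn

# ---- Featurize the splits into molecular graphs ----
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()
train_dset = data.MoleculeDataset(train_data, featurizer)
val_dset   = data.MoleculeDataset(val_data, featurizer)

# Standardize targets with training-set statistics only, then reuse that scaler
scaler = train_dset.normalize_targets()
val_dset.normalize_targets(scaler)

train_loader = data.build_dataloader(train_dset, batch_size=32, num_workers=0)
val_loader   = data.build_dataloader(val_dset, batch_size=32, num_workers=0, shuffle=False)

# ---- Message passing -> aggregation -> regression FFN ----
mp  = nn.BondMessagePassing()    # hyperparameters live here, e.g. d_h=300, depth=3, dropout=0.0
agg = nn.MeanAggregation()
ffn = nn.RegressionFFN(output_transform=nn.UnscaleTransform.from_standard_scaler(scaler))
mpnn = models.MPNN(mp, agg, ffn, batch_norm=True)

# ---- Train with Lightning; accelerator='auto' uses the GPU when one is visible ----
checkpointing = ModelCheckpoint(
    dirpath='chemprop_model_py', filename='best-{epoch}-{val_loss:.2f}',
    monitor='val_loss', mode='min',
)
trainer = pl.Trainer(
    logger=False,
    callbacks=[checkpointing],
    accelerator='auto',
    devices=1,
    max_epochs=30,
)
trainer.fit(mpnn, train_loader, val_loader)

## 4 — Predict on Held‑out Test Set

In [None]:
# Reload the checkpoint with the lowest validation loss
mpnn = models.MPNN.load_from_checkpoint(checkpointing.best_model_path)

# Featurize the test split; shuffle=False keeps predictions aligned with test_idx
test_dset = data.MoleculeDataset(test_data, featurizer)
test_loader = data.build_dataloader(test_dset, batch_size=32, num_workers=0, shuffle=False)

# trainer.predict returns one tensor per batch; the UnscaleTransform in the FFN
# maps outputs back to log S units at inference time
batch_preds = trainer.predict(mpnn, test_loader)
preds = torch.cat(batch_preds).cpu().numpy().ravel()

# Combine predictions with truth (rows in the same order as test_idx)
test_df = df.iloc[test_idx].reset_index(drop=True)
test_df['pred'] = preds
test_df.head()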
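Chemprop 2.x regression models are usually trained on standardized targets (`MoleculeDataset.normalize_targets` returns a fitted scikit-learn `StandardScaler`), and an `nn.UnscaleTransform` on the FFN maps outputs back to the original units at inference. The arithmetic is z = (y - mean) / std on the way in and y = z * std + mean on the way out. The pure-Python sketch below reproduces that round trip on the five log S values in the preview table. The helper names are hypothetical, and the population standard deviation is assumed because that is what `StandardScaler` uses.

```python
import statistics


def fit_standardizer(y_train):
    """Return (mean, std) of training targets, using the population std."""
    mu = statistics.fmean(y_train)
    sigma = statistics.pstdev(y_train)
    if sigma == 0:
        # scikit-learn's StandardScaler falls back to a scale of 1 here instead
        raise ValueError("targets have zero variance")
    return mu, sigma


def scale(y, mu, sigma):
    """Map raw targets to z-scores (what the model sees during training)."""
    return [(v - mu) / sigma for v in y]


def unscale(z, mu, sigma):
    """Map z-scores back to raw units (what the unscale step does at inference)."""
    return [v * sigma + mu for v in z]


y_train = [-0.77, -3.3, -2.06, -7.87, -1.33]   # log S values from the preview table
mu, sigma = fit_standardizer(y_train)
z = scale(y_train, mu, sigma)
print(statistics.fmean(z), statistics.pstdev(z))   # ~0 and ~1 up to rounding
print(unscale(z, mu, sigma))
```

Fitting the scaler on the training split alone, then reusing it for validation and test, keeps test-set statistics from leaking into training.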

## 5 — Visualise Parity Plot

In [None]:
import matplotlib.pyplot as plt, seaborn as sns
sns.set(style='ticks')

plt.figure(figsize=(5,5))
sns.scatterplot(x='solubility', y='pred', data=test_df)
lims = [test_df[['solubility','pred']].min().min(), test_df[['solubility','pred']].max().max()]
plt.plot(lims, lims, '--k')
plt.xlabel('True log S')
plt.ylabel('Predicted log S')
plt.title('Chemprop 2.x Parity Plot (Python API)')
plt.show()

rmse = math.sqrt(((test_df.solubility - test_df.pred)**2).mean())
print(f'RMSE : {rmse:.3f}')
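RMSE alone does not show whether errors are dominated by a few outliers or how much of the target variance the model explains. MAE and the coefficient of determination R² are commonly reported next to it. A dependency-free sketch of all three is below. It assumes two equal-length sequences of true and predicted log S values, and `regression_metrics` is a hypothetical helper name.

```python
import math


def regression_metrics(y_true, y_pred):
    """Return (RMSE, MAE, R^2) for two equal-length, non-empty sequences."""
    if len(y_true) != len(y_pred) or len(y_true) == 0:
        raise ValueError("need two non-empty sequences of equal length")
    n = len(y_true)
    residuals = [t - p for t, p in zip(y_true, y_pred)]

    ss_res = sum(r * r for r in residuals)          # residual sum of squares
    rmse = math.sqrt(ss_res / n)
    mae = sum(abs(r) for r in residuals) / n

    mean_t = sum(y_true) / n
    ss_tot = sum((t - mean_t) ** 2 for t in y_true)  # total sum of squares
    r2 = 1.0 - ss_res / ss_tot if ss_tot > 0 else float("nan")
    return rmse, mae, r2


# With the notebook's DataFrame this would be:
# regression_metrics(test_df.solubility.tolist(), test_df.pred.tolist())
print(regression_metrics([1.0, 2.0, 3.0], [1.0, 2.0, 4.0]))
```

R² is undefined when all true values are identical, so the sketch returns NaN in that case rather than dividing by zero.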

## Your Turn 📝
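1. Swap in **your own dataset** (a `smiles` column plus one numeric target column).
2. Experiment with `d_h`, `depth` and `dropout` in `nn.BondMessagePassing`.
3. Train five models with different random seeds (e.g. via Lightning's `seed_everything`) and average their test predictions to build an ensemble.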
4. Create residual plots or error vs. molecular weight.
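For the ensemble exercise, the combination step only needs each member's test-set predictions as a list aligned to the same rows. The per-row mean is the ensemble prediction, and the spread across members is a rough uncertainty signal. The sketch below assumes that layout. `combine_ensemble` is a hypothetical helper, and the numbers in the usage line are made-up toy values, not results.

```python
import statistics


def combine_ensemble(member_preds):
    """member_preds: one prediction list per model, all aligned to the same rows.

    Returns (means, stds): per-row ensemble mean and population std across models.
    """
    if not member_preds or len({len(p) for p in member_preds}) != 1:
        raise ValueError("need at least one model and equal-length prediction lists")
    means, stds = [], []
    for row in zip(*member_preds):          # row = all models' predictions for one molecule
        means.append(statistics.fmean(row))
        stds.append(statistics.pstdev(row))
    return means, stds


# Three hypothetical models predicting log S for two molecules (toy numbers)
means, stds = combine_ensemble([[-2.0, -5.0], [-2.2, -4.6], [-1.8, -5.4]])
print(means, stds)
```

Rows where members disagree most (large std) are natural candidates for the residual and error-vs-molecular-weight plots in exercise 4.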