# Running hyperparameter optimization on a Chemprop model using Ray Tune

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/chemprop/chemprop/blob/main/examples/hpopting.ipynb)

In [1]:
# Install chemprop from GitHub if running in Google Colab
import os

if os.getenv("COLAB_RELEASE_TAG"):
    try:
        import chemprop
    except ImportError:
        !git clone https://github.com/chemprop/chemprop.git
        %cd chemprop
        !pip install ".[hpopt]"
        %cd examples

## Import packages

In [2]:
from pathlib import Path
import random
import sys

import pandas as pd
from lightning import pytorch as pl
import numpy as np
import ray
from ray import tune
from ray.train import CheckpointConfig, RunConfig, ScalingConfig
from ray.train.lightning import (RayDDPStrategy, RayLightningEnvironment,
                                 RayTrainReportCallback, prepare_trainer)
from ray.train.torch import TorchTrainer
from ray.tune.search.hyperopt import HyperOptSearch
from ray.tune.search.optuna import OptunaSearch
from ray.tune.schedulers import FIFOScheduler
import torch

from chemprop import data, featurizers, models, nn

# sys.path.insert(0, '../agenticadmet')
# from raytune_extra import RayTrainReportCallback

In [3]:
RANDOM_SEED = 42

# Seed every RNG used in this (notebook/driver) process: Python, NumPy and torch (CPU + CUDA).
# NOTE: Ray Tune runs each trial in a separate worker process, which does not inherit
# this RNG state — trials must re-seed themselves inside the training function.
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed_all(RANDOM_SEED)

# Prefer reproducible cuDNN kernels over autotuned (potentially faster) ones.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [4]:
input_path = Path('../data/asap/datasets/rnd_splits/split_0.csv')  # CSV with SMILES, target columns and a `split` column
num_workers = 0  # number of workers for dataloader. 0 means using main process for data loading
smiles_column = 'smiles_std'  # name of the column containing SMILES strings
target_columns = ['LogHLM', 'LogMLM', 'LogD', 'LogKSOL', 'LogMDR1-MDCKII']  # list of names of the columns containing targets

hpopt_save_dir = Path('../output/asap/rnd_splits/chemprop/run_0/split_0/hpopt')  # directory to save hyperopt results
# parents=True: the output tree is several levels deep and may not exist yet on a fresh
# checkout; without it mkdir raises FileNotFoundError.
hpopt_save_dir.mkdir(parents=True, exist_ok=True)

## Load data

In [5]:
df_input = pd.read_csv(input_path)

# Fail fast with a clear message if the CSV lacks any column used in the cells below.
missing_columns = {smiles_column, 'split', *target_columns} - set(df_input.columns)
assert not missing_columns, f"input CSV is missing columns: {sorted(missing_columns)}"

df_input

Unnamed: 0,smiles,HLM,KSOL,LogD,MLM,MDR1-MDCKII,smiles_std,cxsmiles_std,mol_idx,smiles_ext,LogHLM,LogMLM,LogKSOL,LogMDR1-MDCKII,split
0,COC1=CC=CC(Cl)=C1NC(=O)N1CCC[C@H](C(N)=O)C1 |a...,,,0.3,,2.0,COc1cccc(Cl)c1NC(=O)N1CCC[C@H](C(N)=O)C1,COc1cccc(Cl)c1NC(=O)N1CCC[C@H](C(N)=O)C1 |a:16|,191,|a:16|,,,,0.477121,val
1,O=C(NCC(F)F)[C@H](NC1=CC2=C(C=C1Br)CNC2)C1=CC(...,,333.0,2.9,,0.2,O=C(NCC(F)F)[C@H](Nc1cc2c(cc1Br)CNC2)c1cc(Cl)c...,O=C(NCC(F)F)[C@H](Nc1cc2c(cc1Br)CNC2)c1cc(Cl)c...,335,|&1:7|,,,2.523746,0.079181,train
2,O=C(NCC(F)F)[C@H](NC1=CC=C2CNCC2=C1)C1=CC(Br)=...,,,0.4,,0.5,O=C(NCC(F)F)[C@H](Nc1ccc2c(c1)CNC2)c1cc(Br)cc2...,O=C(NCC(F)F)[C@H](Nc1ccc2c(c1)CNC2)c1cc(Br)cc2...,336,|&1:7|,,,,0.176091,train
3,NC(=O)[C@H]1CCCN(C(=O)CC2=CC=CC3=C2C=CO3)C1 |&...,,376.0,1.0,,8.5,NC(=O)[C@H]1CCCN(C(=O)Cc2cccc3occc23)C1,NC(=O)[C@H]1CCCN(C(=O)Cc2cccc3occc23)C1 |&1:3|,300,|&1:3|,,,2.576341,0.977724,train
4,CC1=CC(CC(=O)N2CCC[C@H](C(N)=O)C2)=CC=N1 |&1:11|,,375.0,-0.3,,0.9,Cc1cc(CC(=O)N2CCC[C@H](C(N)=O)C2)ccn1,Cc1cc(CC(=O)N2CCC[C@H](C(N)=O)C2)ccn1 |&1:11|,249,|&1:11|,,,2.575188,0.278754,train
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
399,CC(C)NC[C@H](O)COC1=CC=CC2=CC=CC=C12 |&1:5|,25.5,,,63.0,,CC(C)NC[C@H](O)COc1cccc2ccccc12,CC(C)NC[C@H](O)COc1cccc2ccccc12 |&1:5|,22,|&1:5|,1.423246,1.806180,,,val
400,O=C(O)CC1=CC=CC=C1NC1=C(Cl)C=CC=C1Cl,216.0,,,386.0,,O=C(O)Cc1ccccc1Nc1c(Cl)cccc1Cl,O=C(O)Cc1ccccc1Nc1c(Cl)cccc1Cl,380,,2.336460,2.587711,,,val
401,NCC1=CC(Cl)=CC(C(=O)NC2=CC=C3CNCC3=C2)=C1,,,2.0,,,NCc1cc(Cl)cc(C(=O)Nc2ccc3c(c2)CNC3)c1,NCc1cc(Cl)cc(C(=O)Nc2ccc3c(c2)CNC3)c1,303,,,,,,train
402,COC(=O)NC1=NC2=CC=C(C(=O)C3=CC=CC=C3)C=C2N1,,,2.9,,,COC(=O)Nc1nc2ccc(C(=O)c3ccccc3)cc2[nH]1,COC(=O)Nc1nc2ccc(C(=O)c3ccccc3)cc2[nH]1,166,,,,,,train


## Make data points and datasets (using the precomputed train/val split)

In [6]:
# Create one chemprop datapoint per row and route it to train or validation using the
# precomputed `split` column; rows with any other split label are skipped.
# Targets are cast to float explicitly (missing labels become NaN): selecting the target
# cells of a single row of this mixed-type frame would otherwise give an object-dtype array.
targets_all = df_input[target_columns].to_numpy(dtype=float)

train_data, val_data = [], []
for smi, y, split in zip(df_input[smiles_column], targets_all, df_input['split']):
    dp = data.MoleculeDatapoint.from_smi(smi, y)
    if split == 'train':
        train_data.append(dp)
    elif split == 'val':
        val_data.append(dp)

assert train_data and val_data, "both the 'train' and 'val' splits must be non-empty"

In [7]:
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()

# Wrap the train and validation datapoints into chemprop datasets sharing one featurizer.
# NOTE(review): target normalization (`normalize_targets`) is disabled here, so every task
# is trained on its raw scale; with an unweighted multi-task loss the tasks with the widest
# value ranges may dominate — consider re-enabling it together with an output
# UnscaleTransform in the model.
train_dset, val_dset = (
    data.MoleculeDataset(datapoints, featurizer)
    for datapoints in (train_data, val_data)
)

## Define a helper function to train the model

In [8]:
def train_model(config, train_dset, val_dset, num_workers, n_tasks=None, seed=RANDOM_SEED):
    """Train a Chemprop MPNN with the hyperparameters of a single Ray Tune trial.

    Intended to run inside a Ray Train worker: ``RayTrainReportCallback`` forwards the
    metrics logged by Lightning (e.g. ``val_loss``, ``val/mae``) to Ray Tune.

    Parameters
    ----------
    config : dict
        Hyperparameters sampled for this trial. ``depth``, ``ffn_hidden_dim``,
        ``ffn_num_layers``, ``message_hidden_dim`` and ``dropout`` are required;
        ``max_epochs`` is optional (default 200).
    train_dset, val_dset : chemprop.data.MoleculeDataset
        Training and validation datasets.
    num_workers : int
        Number of DataLoader worker processes (0 = load in the calling process).
    n_tasks : int, optional
        Number of regression targets; defaults to ``len(target_columns)``.
    seed : int or None, optional
        Seed applied in the calling (worker) process before the model is built.
        Defaults to ``RANDOM_SEED``; pass ``None`` to skip seeding.
    """
    # Ray runs each trial in its own worker process, so the seeds set earlier in the
    # notebook process do not apply here; re-seed so trials are repeatable.
    if seed is not None:
        pl.seed_everything(seed, workers=True)

    depth = int(config["depth"])
    ffn_hidden_dim = int(config["ffn_hidden_dim"])
    ffn_num_layers = int(config["ffn_num_layers"])
    message_hidden_dim = int(config["message_hidden_dim"])
    dropout = float(config["dropout"])
    max_epochs = int(config.get("max_epochs", 200))  # number of epochs to train for
    if n_tasks is None:
        n_tasks = len(target_columns)

    train_loader = data.build_dataloader(train_dset, num_workers=num_workers, shuffle=True)
    val_loader = data.build_dataloader(val_dset, num_workers=num_workers, shuffle=False)

    mp = nn.BondMessagePassing(d_h=message_hidden_dim, depth=depth, dropout=dropout)
    agg = nn.MeanAggregation()
    # Targets are not normalized, so no output (un)scaling transform is applied.
    # The FFN input width matches the message-passing hidden size (d_h) because no
    # extra molecule-level descriptors are concatenated to the aggregated encoding.
    ffn = nn.RegressionFFN(
        n_tasks=n_tasks,
        output_transform=None,
        input_dim=message_hidden_dim,
        hidden_dim=ffn_hidden_dim,
        n_layers=ffn_num_layers,
        dropout=dropout,
    )
    metric_list = [nn.metrics.MAE(), nn.metrics.R2Score()]
    model = models.MPNN(mp, agg, ffn, batch_norm=True, metrics=metric_list)

    trainer = pl.Trainer(
        accelerator="auto",
        devices=1,
        max_epochs=max_epochs,
        # below are needed for Ray and Lightning integration
        strategy=RayDDPStrategy(),
        callbacks=[RayTrainReportCallback()],
        plugins=[RayLightningEnvironment()],
        enable_progress_bar=False,
        # Lightning's own ModelCheckpoint is disabled; checkpoints are reported to Ray
        # by RayTrainReportCallback instead.
        enable_checkpointing=False,
    )

    trainer = prepare_trainer(trainer)
    trainer.fit(model, train_loader, val_loader)

## Define the hyperparameter search space

In [9]:
# Hyperparameter search space sampled for each trial. `qrandint(lower, upper, q)` draws
# integers between lower and upper (both inclusive) rounded to multiples of q.
search_space = dict(
    depth=tune.qrandint(2, 6, 1),                      # message-passing steps
    ffn_hidden_dim=tune.qrandint(300, 2400, 100),      # width of the FFN hidden layers
    ffn_num_layers=tune.qrandint(1, 3, 1),             # number of FFN layers
    message_hidden_dim=tune.qrandint(300, 2400, 100),  # message-passing hidden size (d_h)
    dropout=tune.uniform(0.0, 0.2),                    # shared by message passing and FFN
)

In [10]:
# Start a fresh local Ray instance (shutting down any instance left from a previous run).
ray.shutdown()
ray.init(include_dashboard=False)

# FIFOScheduler: trials run in submission order, with no early stopping.
scheduler = FIFOScheduler()

# Scaling config controls the resources used by Ray for each trial
scaling_config = ScalingConfig(
    num_workers=1,  # one training worker per trial
    # Request a GPU only when one is present: on a CPU-only machine a GPU request
    # cannot be satisfied, so the trials could never be scheduled.
    use_gpu=torch.cuda.is_available(),
)

# Checkpoint config controls which of the reported checkpoints Ray keeps
checkpoint_config = CheckpointConfig(
    num_to_keep=1,  # number of checkpoints to keep
    checkpoint_score_attribute="val_loss",  # rank checkpoints by this reported metric
    checkpoint_score_order="min",  # keep the checkpoint with the lowest metric value
    # NOTE: checkpoint_frequency only applies to trainables without a custom training
    # loop. Here checkpoints are reported by RayTrainReportCallback during training (the
    # results contain e.g. `checkpoint_000199`), and num_to_keep prunes them.
    checkpoint_frequency=0,
)

run_config = RunConfig(
    checkpoint_config=checkpoint_config,
    storage_path=(hpopt_save_dir / "ray_results").absolute(),  # directory to save the results
)

def _trial_loop(config):
    """Per-worker training loop: forward the sampled trial config to `train_model`."""
    return train_model(config, train_dset, val_dset, num_workers)


# Each Tune trial runs `_trial_loop` through a TorchTrainer using the resources in
# `scaling_config` and the storage/checkpoint settings in `run_config`.
ray_trainer = TorchTrainer(
    _trial_loop,
    scaling_config=scaling_config,
    run_config=run_config,
)

NUM_SAMPLES = 30       # total number of trials to run
N_INITIAL_POINTS = 10  # random trials before HyperOpt switches to tree Parzen estimators

# Tree-structured Parzen Estimator search, seeded for reproducible suggestions.
# (OptunaSearch() is another search algorithm that can be used instead.)
search_alg = HyperOptSearch(n_initial_points=N_INITIAL_POINTS, random_state_seed=RANDOM_SEED)

tune_config = tune.TuneConfig(
    metric="val_loss",
    mode="min",
    num_samples=NUM_SAMPLES,
    scheduler=scheduler,
    search_alg=search_alg,
    # Name each trial directory by its trial id only, keeping file paths short.
    trial_dirname_creator=lambda trial: str(trial.trial_id),
)

# The search space is nested under "train_loop_config" so every sampled configuration
# is handed to the trainer's training loop as its `config` argument.
tuner = tune.Tuner(
    ray_trainer,
    param_space=dict(train_loop_config=search_space),
    tune_config=tune_config,
)

# Start the hyperparameter search
results = tuner.fit()

0,1
Current time:,2025-03-11 23:38:49
Running for:,02:05:23.30
Memory:,15.3/58.9 GiB

Trial name,status,loc,train_loop_config/depth,train_loop_config/dropout,train_loop_config/ffn_hidden_dim,train_loop_config/ffn_num_layers,train_loop_config/message_hidden_dim,iter,total time (s),train_loss,train_loss_step,val/mae
TorchTrainer_fc14306f,TERMINATED,10.128.0.3:524465,4,0.018611,2000,2,500,200,145.428,0.0535116,0.164948,0.358812
TorchTrainer_e07c1499,TERMINATED,10.128.0.3:527951,2,0.173402,1800,2,1800,200,246.244,0.155736,0.263326,0.37329
TorchTrainer_959b378a,TERMINATED,10.128.0.3:533265,3,0.0257109,400,1,400,200,89.1663,0.065592,0.185106,0.336174
TorchTrainer_5b99c550,TERMINATED,10.128.0.3:535442,3,0.0786423,700,1,2200,200,307.261,0.110752,0.304377,0.386981
TorchTrainer_fa17f5a1,TERMINATED,10.128.0.3:541831,2,0.156513,2200,1,2000,200,224.036,0.152651,0.24385,0.406797
TorchTrainer_3d223470,TERMINATED,10.128.0.3:547132,3,0.0341362,800,1,1300,200,188.97,0.0735224,0.162442,0.371128
TorchTrainer_c2ab6226,TERMINATED,10.128.0.3:551331,3,0.114586,700,2,1600,200,233.583,0.124914,0.191656,0.365408
TorchTrainer_7d8c0ad6,TERMINATED,10.128.0.3:556778,5,0.0826331,2100,1,1800,200,387.38,0.0815,0.130444,0.373345
TorchTrainer_3fa10baa,TERMINATED,10.128.0.3:565116,2,0.0430375,2400,2,900,200,192.933,0.0947737,0.280282,0.402029
TorchTrainer_c1c21db2,TERMINATED,10.128.0.3:569455,3,0.0493471,1300,1,1200,200,178.831,0.0794652,0.122593,0.353521


[33m(raylet)[0m [2025-03-11 21:33:34,190 E 523488 523522] (raylet) file_system_monitor.cc:116: /var/tmp/ray/session_2025-03-11_21-33-22_789751_523270 is over 95% full, available space: 13.203 GB; capacity: 295.046 GB. Object creation will fail if spilling is required.
[36m(TorchTrainer pid=524465)[0m Started distributed worker processes: 
[36m(TorchTrainer pid=524465)[0m - (node_id=bacb29b8d065787f0e01b2822b0783aff180964574869eeacae5f33c, ip=10.128.0.3, pid=524734) world_rank=0, local_rank=0, node_rank=0
[36m(RayTrainWorker pid=524734)[0m Setting up process group for: env:// [rank=0, world_size=1]
[36m(RayTrainWorker pid=524734)[0m GPU available: True (cuda), used: True
[36m(RayTrainWorker pid=524734)[0m TPU available: False, using: 0 TPU cores
[36m(RayTrainWorker pid=524734)[0m HPU available: False, using: 0 HPUs
[36m(RayTrainWorker pid=524734)[0m /opt/conda/envs/admet/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.p

## Hyperparameter optimization results

In [11]:
# results of all trials
result_df = results.get_dataframe()
results_df = result_df.sort_values('val/mae')
results_df

Unnamed: 0,train_loss,train_loss_step,val/mae,val/r2,val_loss,train_loss_epoch,epoch,step,timestamp,checkpoint_dir_name,...,hostname,node_ip,time_since_restore,iterations_since_restore,config/train_loop_config/depth,config/train_loop_config/ffn_hidden_dim,config/train_loop_config/ffn_num_layers,config/train_loop_config/message_hidden_dim,config/train_loop_config/dropout,logdir
14,0.079665,0.099432,0.333156,0.778627,0.205901,0.079665,199,1200,1741731889,checkpoint_000199,...,dl-vladvin-1,10.128.0.3,195.133757,200,4,1000,2,1100,0.067358,38b6ac4b
12,0.096029,0.292531,0.333983,0.787656,0.197503,0.096029,199,1200,1741731513,checkpoint_000199,...,dl-vladvin-1,10.128.0.3,96.997631,200,4,300,2,300,0.067396,1493b484
19,0.088536,0.118046,0.334023,0.790624,0.194742,0.088536,199,1200,1741733177,checkpoint_000199,...,dl-vladvin-1,10.128.0.3,246.797507,200,4,500,2,1500,0.062947,22003d01
2,0.065592,0.185106,0.336174,0.771428,0.212597,0.065592,199,1200,1741729324,checkpoint_000199,...,dl-vladvin-1,10.128.0.3,89.166317,200,3,400,1,400,0.025711,959b378a
17,0.078883,0.247472,0.336263,0.784005,0.200899,0.078883,199,1200,1741732716,checkpoint_000199,...,dl-vladvin-1,10.128.0.3,175.015653,200,4,1000,2,1000,0.065428,8be5f6f9
20,0.052391,0.14633,0.351267,0.757447,0.2256,0.052391,199,1200,1741733322,checkpoint_000199,...,dl-vladvin-1,10.128.0.3,134.439033,200,4,1000,2,600,0.000733,7b2f677d
11,0.095477,0.182505,0.351294,0.761386,0.221937,0.095477,199,1200,1741731406,checkpoint_000199,...,dl-vladvin-1,10.128.0.3,143.93182,200,4,1200,1,800,0.134314,6efeb906
9,0.079465,0.122593,0.353521,0.749816,0.232698,0.079465,199,1200,1741731132,checkpoint_000199,...,dl-vladvin-1,10.128.0.3,178.831403,200,3,1300,1,1200,0.049347,c1c21db2
10,0.059563,0.220698,0.353864,0.770414,0.213539,0.059563,199,1200,1741731251,checkpoint_000199,...,dl-vladvin-1,10.128.0.3,108.737776,200,5,400,1,400,0.001732,a6090c98
18,0.082483,0.2067,0.358206,0.751251,0.231363,0.082483,199,1200,1741732918,checkpoint_000199,...,dl-vladvin-1,10.128.0.3,189.780985,200,4,1000,2,1100,0.091567,28b4fb56


In [12]:
results_df.to_csv(path_or_buf=hpopt_save_dir / 'results.csv', index=False)  # ranked trial table; row index not written

In [13]:
ray.shutdown()  # stop the local Ray instance started for the hyperparameter search