# Training

# Import packages

In [83]:
import pandas as pd
from pathlib import Path
import numpy as np

import ast
from lightning import pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint
import torch

from chemprop import data, featurizers, models, nn
from chemprop.models import load_mixed_model
from chemprop.utils import make_mol

# Change data inputs here

In [84]:
chemprop_dir = Path.cwd().parent
input_path = chemprop_dir / "tests" / "data" / "mixed_regression_input.csv" # path to your data .csv file
num_workers = 0 # number of workers for dataloader. 0 means using main process for data loading
smiles_column = 'smiles' # name of the column containing SMILES strings
target_columns = ['molecule', 'atom'] # list of names of the columns containing targets

In [85]:
df_input = pd.read_csv(input_path)
df_input

Unnamed: 0,smiles,molecule,atom
0,CC,1.0,"[1,2]"
1,CCC,2.0,"[1,2,3]"
2,CCCO,3.0,"[1,2,3,5]"
3,CCOO,4.0,"[1,2,3,4]"
4,COO,5.0,"[1,3,5]"
5,CCOOO,6.0,"[1,7,3,4,5]"
6,COOO,7.0,"[2,5,3,2]"
7,CO,8.0,"[1,3]"
8,CCO,9.0,"[1,3,5]"
9,OO,10.0,"[5,7]"


## Get SMILES and targets

In [86]:
smis = df_input.loc[:, smiles_column].values
ys = df_input.loc[:, target_columns]

In [87]:
smis[:2] # show first 2 SMILES strings

array(['CC', 'CCC'], dtype=object)

In [88]:
ys[:5] # show first 5 molecule targets

Unnamed: 0,molecule,atom
0,1.0,"[1,2]"
1,2.0,"[1,2,3]"
2,3.0,"[1,2,3,5]"
3,4.0,"[1,2,3,4]"
4,5.0,"[1,3,5]"


In [89]:
flag = [] # mark which columns belong to which of (mol, atom, bond)
mol_Y, atom_Y, bond_Y = [], [], [] # target values for each type
for column in target_columns:
    index = 0
    column_type = df_input.iloc[index][column]
    if isinstance(column_type, float):
        for molecule in range(len(df_input)):
            mol_Y.append([df_input.iloc[molecule][column]])
        flag.append("mol")
    else:
        column_mol = make_mol(df_input.iloc[index][smiles_column], False, False)
        column_type = ast.literal_eval(column_type)
        while index < len(df_input) and column_mol.GetNumAtoms() == column_mol.GetNumBonds():
            index += 1
            column_mol = make_mol(df_input.iloc[index][smiles_column], False, False)
        column_type = ast.literal_eval(df_input.iloc[index][column])
        flag.append("atom") if len(column_type) == column_mol.GetNumAtoms() else flag.append("bond")

In [90]:
for molecule in range(len(df_input)):
    atom_list_props, bond_list_props = [], [] # list containing all target values, which we need to hstack for atom_Y and bond_Y
    for prop in range(len(target_columns)):
        if flag[prop] == "mol":
            continue
        np_prop = np.array(ast.literal_eval(df_input.iloc[molecule][target_columns[prop]]))
        np_prop = np.expand_dims(np_prop, axis=1)
        atom_list_props.append(np_prop) if flag[prop] == "atom" else bond_list_props.append(np_prop)
    if len(atom_list_props) > 0:
        atom_Y.append(np.hstack(atom_list_props))
    else:
        atom_Y = df_input[[]]
        atom_Y = atom_Y.to_numpy()
    if len(bond_list_props) > 0:
        bond_Y.append(np.hstack(bond_list_props))
    else:
        bond_Y = df_input[[]]
        bond_Y = bond_Y.to_numpy()

## Get molecule datapoints

In [91]:
mol_data = []
for smi, y in zip(smis, mol_Y):
    mol_data.append(data.MoleculeDatapoint.from_smi(smi, y, keep_h=True))

atom_data = []
for smi, y in zip(smis, atom_Y):
    atom_data.append(data.MoleculeDatapoint.from_smi(smi, y, keep_h=True))

bond_data = []
for smi, y in zip(smis, bond_Y):
    bond_data.append(data.MoleculeDatapoint.from_smi(smi, y, keep_h=True))

In [92]:
all_data = []
all_data.append(mol_data)
all_data.append(atom_data)
all_data.append(bond_data)

## Perform data splitting for training, validation, and testing

In [93]:
# available split types
list(data.SplitType.keys())

['SCAFFOLD_BALANCED',
 'RANDOM_WITH_REPEATED_SMILES',
 'RANDOM',
 'KENNARD_STONE',
 'KMEANS']

In [94]:
mol = [d.mol for d in all_data[0]]

train_indices, val_indices, test_indices = data.make_split_indices(mol, "random", (0.6, 0.2, 0.2))
train_data, val_data, test_data = data.split_data_by_indices(
    all_data, train_indices, val_indices, test_indices
)

## Get All Datasets

In [95]:
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()

train_dsets = []
train_dsets.append(data.MoleculeDataset(train_data[0][0], featurizer)) if mol_Y else train_dsets.append(data.MockDataset())
train_dsets.append(data.AtomDataset(train_data[0][1], featurizer)) if atom_Y else train_dsets.append(data.MockDataset())
train_dsets.append(data.BondDataset(train_data[0][2], featurizer)) if bond_Y else train_dsets.append(data.MockDataset())

mol_scaler = train_dsets[0].normalize_targets()
atom_scaler = train_dsets[1].normalize_targets()
bond_scaler = train_dsets[2].normalize_targets()
train_dset = data.MolAtomBondDataset(train_dsets[0], train_dsets[1], train_dsets[2])

  train_dsets.append(data.BondDataset(train_data[0][2], featurizer)) if bond_Y else train_dsets.append(data.MockDataset())


In [96]:
val_dsets = []
val_dsets.append(data.MoleculeDataset(val_data[0][0], featurizer)) if mol_Y else val_dsets.append(data.MockDataset())
val_dsets.append(data.AtomDataset(val_data[0][1], featurizer)) if atom_Y else val_dsets.append(data.MockDataset())
val_dsets.append(data.BondDataset(val_data[0][2], featurizer)) if bond_Y else val_dsets.append(data.MockDataset())
val_dsets[0].normalize_targets(mol_scaler)
val_dsets[1].normalize_targets(atom_scaler)
val_dsets[2].normalize_targets(bond_scaler)
val_dset = data.MolAtomBondDataset(val_dsets[0], val_dsets[1], val_dsets[2])

  val_dsets.append(data.BondDataset(val_data[0][2], featurizer)) if bond_Y else val_dsets.append(data.MockDataset())


In [97]:
test_dsets = []
test_dsets.append(data.MoleculeDataset(test_data[0][0], featurizer)) if mol_Y else test_dsets.append(data.MockDataset())
test_dsets.append(data.AtomDataset(test_data[0][1], featurizer)) if atom_Y else test_dsets.append(data.MockDataset())
test_dsets.append(data.BondDataset(test_data[0][2], featurizer)) if bond_Y else test_dsets.append(data.MockDataset())
test_dset = data.MolAtomBondDataset(test_dsets[0], test_dsets[1], test_dsets[2])

  test_dsets.append(data.BondDataset(test_data[0][2], featurizer)) if bond_Y else test_dsets.append(data.MockDataset())


# Get Atom/Bond Slices

In [98]:
all_dsets = []
all_dsets.append(data.MoleculeDataset(all_data[0], featurizer)) if mol_Y else all_dsets.append(data.MockDataset())
all_dsets.append(data.AtomDataset(all_data[1], featurizer)) if atom_Y else all_dsets.append(data.MockDataset())
atom_slices = all_dsets[1]._slices
all_dsets.append(data.BondDataset(all_data[2], featurizer)) if bond_Y else all_dsets.append(data.MockDataset())
bond_slices = all_dsets[2]._slices
all_dset = data.MolAtomBondDataset(all_dsets[0], all_dsets[1], all_dsets[2])

  all_dsets.append(data.BondDataset(all_data[2], featurizer)) if bond_Y else all_dsets.append(data.MockDataset())


## Get DataLoader

In [99]:
train_loader = data.build_dataloader(train_dset, num_workers=num_workers, shuffle=False)
val_loader = data.build_dataloader(val_dset, num_workers=num_workers, shuffle=False)
test_loader = data.build_dataloader(test_dset, num_workers=num_workers, shuffle=False)
all_loader = data.build_dataloader(all_dset, num_workers=num_workers, shuffle=False)

# Change Message-Passing Neural Network (MPNN) inputs here

## Message Passing
A `Message passing` constructs molecular graphs using message passing to learn node-level hidden representations.

Options are `mp = nn.BondMessagePassing()` or `mp = nn.AtomMessagePassing()`

In [100]:
mp = nn.MixedBondMessagePassing() #include why aggregation isn't used
mp

MixedBondMessagePassing(
  (W_i): Linear(in_features=86, out_features=300, bias=False)
  (W_h): Linear(in_features=300, out_features=300, bias=False)
  (W_o): Linear(in_features=372, out_features=300, bias=True)
  (W_o_b): Linear(in_features=314, out_features=300, bias=True)
  (dropout): Dropout(p=0.0, inplace=False)
  (tau): ReLU()
  (V_d_transform): Identity()
  (E_d_transform): Identity()
  (graph_transform): Identity()
)

## Feed-Forward Network (FFN)

A `FFN` takes the aggregated representations and make target predictions.

Available options can be found in `nn.PredictorRegistry`.

For regression:
- `ffn = nn.RegressionFFN()`
- `ffn = nn.MveFFN()`
- `ffn = nn.EvidentialFFN()`

For classification:
- `ffn = nn.BinaryClassificationFFN()`
- `ffn = nn.BinaryDirichletFFN()`
- `ffn = nn.MulticlassClassificationFFN()`
- `ffn = nn.MulticlassDirichletFFN()`

For spectral:
- `ffn = nn.SpectralFFN()` # will be available in future version

In [101]:
print(nn.PredictorRegistry)

ClassRegistry {
    'regression': <class 'chemprop.nn.predictors.RegressionFFN'>,
    'regression-mve': <class 'chemprop.nn.predictors.MveFFN'>,
    'regression-evidential': <class 'chemprop.nn.predictors.EvidentialFFN'>,
    'regression-quantile': <class 'chemprop.nn.predictors.QuantileFFN'>,
    'classification': <class 'chemprop.nn.predictors.BinaryClassificationFFN'>,
    'classification-dirichlet': <class 'chemprop.nn.predictors.BinaryDirichletFFN'>,
    'multiclass': <class 'chemprop.nn.predictors.MulticlassClassificationFFN'>,
    'multiclass-dirichlet': <class 'chemprop.nn.predictors.MulticlassDirichletFFN'>,
    'spectral': <class 'chemprop.nn.predictors.SpectralFFN'>
}


In [102]:
mol_output_transform = nn.UnscaleTransform.from_standard_scaler(mol_scaler)
atom_output_transform = nn.UnscaleTransform.from_standard_scaler(atom_scaler)
bond_output_transform = nn.UnscaleTransform.from_standard_scaler(bond_scaler)

In [103]:
mol_ffn = nn.RegressionFFN(output_transform=mol_output_transform, n_tasks=1) # one ffn for each type
atom_ffn = nn.RegressionFFN(output_transform=atom_output_transform, n_tasks=1)
bond_ffn = nn.RegressionFFN(output_transform=bond_output_transform, input_dim=600, n_tasks=1)

## Batch Norm
A `Batch Norm` normalizes the outputs of the aggregation by re-centering and re-scaling.

Whether to use batch norm

In [104]:
batch_norm = True

## Metrics
`Metrics` are the ways to evaluate the performance of model predictions.

Available options can be found in `metrics.MetricRegistry`, including

In [105]:
print(nn.metrics.MetricRegistry)

ClassRegistry {
    'mse': <class 'chemprop.nn.metrics.MSE'>,
    'mae': <class 'chemprop.nn.metrics.MAE'>,
    'rmse': <class 'chemprop.nn.metrics.RMSE'>,
    'bounded-mse': <class 'chemprop.nn.metrics.BoundedMSE'>,
    'bounded-mae': <class 'chemprop.nn.metrics.BoundedMAE'>,
    'bounded-rmse': <class 'chemprop.nn.metrics.BoundedRMSE'>,
    'r2': <class 'chemprop.nn.metrics.R2Score'>,
    'binary-mcc': <class 'chemprop.nn.metrics.BinaryMCCMetric'>,
    'multiclass-mcc': <class 'chemprop.nn.metrics.MulticlassMCCMetric'>,
    'roc': <class 'chemprop.nn.metrics.BinaryAUROC'>,
    'prc': <class 'chemprop.nn.metrics.BinaryAUPRC'>,
    'accuracy': <class 'chemprop.nn.metrics.BinaryAccuracy'>,
    'f1': <class 'chemprop.nn.metrics.BinaryF1Score'>
}


In [106]:
metric_list = [nn.metrics.RMSE(), nn.metrics.MAE()] # Only the first metric is used for training and early stopping

## Constructs MolAtomBondMPNN

In [107]:
agg = nn.MeanAggregation()
mol_atom_bond_mpnn = models.MolAtomBondMPNN(mp, agg, mol_ffn, atom_ffn, bond_ffn, batch_norm, metric_list)

mol_atom_bond_mpnn

MolAtomBondMPNN(
  (message_passing): MixedBondMessagePassing(
    (W_i): Linear(in_features=86, out_features=300, bias=False)
    (W_h): Linear(in_features=300, out_features=300, bias=False)
    (W_o): Linear(in_features=372, out_features=300, bias=True)
    (W_o_b): Linear(in_features=314, out_features=300, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
    (tau): ReLU()
    (V_d_transform): Identity()
    (E_d_transform): Identity()
    (graph_transform): Identity()
  )
  (agg): MeanAggregation()
  (bn): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (predictors): ModuleList(
    (0-1): 2 x RegressionFFN(
      (ffn): MLP(
        (0): Sequential(
          (0): Linear(in_features=300, out_features=300, bias=True)
        )
        (1): Sequential(
          (0): ReLU()
          (1): Dropout(p=0.0, inplace=False)
          (2): Linear(in_features=300, out_features=1, bias=True)
        )
      )
      (criterion): MSE(task_weights=[[1.0]

# Set up trainer

In [108]:
# Configure model checkpointing
checkpointing = ModelCheckpoint(
    "checkpoints",  # Directory where model checkpoints will be saved
    "best-{epoch}-{val_loss:.2f}",  # Filename format for checkpoints, including epoch and validation loss
    "val_loss",  # Metric used to select the best checkpoint (based on validation loss)
    mode="min",  # Save the checkpoint with the lowest validation loss (minimization objective)
    save_last=True,  # Always save the most recent checkpoint, even if it's not the best
)

trainer = pl.Trainer(
    logger=False,
    enable_checkpointing=True, # Use `True` if you want to save model checkpoints. The checkpoints will be saved in the `checkpoints` folder.
    enable_progress_bar=True,
    accelerator="cpu",
    devices=1,
    max_epochs=20, # number of epochs to train for
    callbacks=[checkpointing],
)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/opt/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.


# Start training

In [109]:
trainer.fit(mol_atom_bond_mpnn, train_loader, val_loader)

/opt/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /Users/brianli/Documents/chemprop/examples/checkpoints exists and is not empty.
Loading `train_dataloader` to estimate number of stepping batches.
/opt/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.

  | Name            | Type                    | Params | Mode 
--------------------------------------------------------------------
0 | message_passing | MixedBondMessagePassing | 322 K  | train
1 | agg             | MeanAggregation         | 0      | train
2 | bn              | BatchNorm1d             | 600    | train
3 | predictors      | ModuleList              | 361 K  | train
4 | X_d_transform

                                                                                

/opt/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Epoch 0: 100%|███████████████████████████████████| 1/1 [00:00<00:00, 111.72it/s]
Validation: |                                             | 0/? [00:00<?, ?it/s][A
Validation:   0%|                                         | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                            | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0: 100%|███████████████████| 1/1 [00:00<00:00, 140.16it/s][A
Epoch 1: 100%|███████████████████| 1/1 [00:00<00:00, 117.62it/s, val_loss=3.480][A
Validation: |                                             | 0/? [00:00<?, ?it/s][A
Validation:   0%|                                         | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                            | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0: 100%|███████████████████| 1/1 [00:00<00:00, 172.18it/s][A
Epoch 2: 100%|███████████████████| 1/1 [00:00<00:00, 121.70it/s, val_loss=3.390][A
Validation: |                                             | 0/? [00:00<?, ?it/s

`Trainer.fit` stopped: `max_epochs=20` reached.


Epoch 19: 100%|███████████████████| 1/1 [00:00<00:00, 25.85it/s, val_loss=2.180]


# Test results

In [110]:
results = trainer.test(mol_atom_bond_mpnn, test_loader)

/opt/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████████████████| 1/1 [00:00<00:00, 170.27it/s]


# Predictions

In [111]:
mol_individual_preds, atom_individual_preds, bond_individual_preds = [], [], []
model = load_mixed_model(checkpointing.best_model_path)

trainer = pl.Trainer(
    logger=False,
    enable_progress_bar=True,
    accelerator="cpu",
    devices=1,
)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [112]:
predss = trainer.predict(model, all_loader)

/opt/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Predicting DataLoader 0: 100%|███████████████████| 1/1 [00:00<00:00, 309.13it/s]


In [113]:
mol_individual_preds.append(torch.concat([predss[0][0]], 0))
atom_individual_preds.append(torch.concat([predss[0][1]], 0))
bond_individual_preds.append(torch.concat([predss[0][2]], 0))

mol_average_preds = torch.mean(torch.stack(mol_individual_preds).float(), dim=0)
atom_average_preds = torch.mean(torch.stack(atom_individual_preds).float(), dim=0)
bond_average_preds = torch.mean(torch.stack(bond_individual_preds).float(), dim=0)

In [114]:
test_path = chemprop_dir / "tests" / "data" / "mixed_regression_input.csv"
df_test = pd.read_csv(test_path, header="infer", index_col=False)

## Loaded Model

In [115]:
model

MolAtomBondMPNN(
  (message_passing): MixedBondMessagePassing(
    (W_i): Linear(in_features=86, out_features=300, bias=False)
    (W_h): Linear(in_features=300, out_features=300, bias=False)
    (W_o): Linear(in_features=372, out_features=300, bias=True)
    (W_o_b): Linear(in_features=314, out_features=300, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
    (tau): ReLU()
    (V_d_transform): Identity()
    (E_d_transform): Identity()
    (graph_transform): Identity()
  )
  (agg): MeanAggregation()
  (bn): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (predictors): ModuleList(
    (0-1): 2 x RegressionFFN(
      (ffn): MLP(
        (0): Sequential(
          (0): Linear(in_features=300, out_features=300, bias=True)
        )
        (1): Sequential(
          (0): ReLU()
          (1): Dropout(p=0.0, inplace=False)
          (2): Linear(in_features=300, out_features=1, bias=True)
        )
      )
      (criterion): MSE(task_weights=[[1.0]

# Output Predictions

In [116]:
target_cols = df_test.columns.tolist()
mol_cols, atom_cols, bond_cols = [], [], []
for i in range(1, len(target_cols)):
    index = 0
    column_type = df_test.iloc[index][target_cols[i]]
    if isinstance(column_type, float):
        mol_cols.append(i)
    else:
        column_mol = make_mol(df_test.iloc[index][target_cols[0]], False, False)
        column_type = ast.literal_eval(column_type)
        while index < len(df_test) and column_mol.GetNumAtoms() == column_mol.GetNumBonds():
            index += 1
            column_mol = make_mol(df_test.iloc[index][target_cols[0]])
        column_type = ast.literal_eval(df_test.iloc[index][target_cols[i]])
        atom_cols.append(i) if len(column_type) == column_mol.GetNumAtoms() else bond_cols.append(i)

In [117]:
df_test.iloc[:, mol_cols] = mol_average_preds.tolist()

for i in range(len(df_test)):
    if atom_slices is not None:
        first_atom = atom_slices.index(i)
        last_atom = first_atom + atom_slices.count(i)
        atom_preds = atom_average_preds[first_atom:last_atom]
        df_test.iloc[i, atom_cols] = [str(atom_preds[:,j].tolist()) for j in range(len(atom_cols))]

    if bond_slices is not None:
        first_bond = bond_slices.index(i)
        last_bond = first_bond + bond_slices.count(i)
        bond_preds = bond_average_preds[first_atom:last_atom]
        df_test.iloc[i, bond_cols] = [str(bond_preds[:,j].tolist()) for j in range(len(bond_cols))]

In [118]:
output_path = chemprop_dir / "tests" / "data" / "mixed_regression_output.csv"
if output_path.suffix == ".pkl":
    df_test = df_test.reset_index(drop=True)
    df_test.to_pickle(output_path)
else:
    df_test.to_csv(output_path, index=False)

df_test

Unnamed: 0,smiles,molecule,atom
0,CC,5.34214,"[2.569624900817871, 2.569624900817871]"
1,CCC,5.233238,"[2.4788389205932617, 2.5985231399536133, 2.478..."
2,CCCO,6.901553,"[2.4710915088653564, 2.505635976791382, 2.7169..."
3,CCOO,6.782914,"[2.4493448734283447, 2.6188316345214844, 3.430..."
4,COO,6.48114,"[2.5182647705078125, 3.493091344833374, 3.4651..."
5,CCOOO,6.913412,"[2.451606273651123, 2.626923084259033, 3.22619..."
6,COOO,6.779137,"[2.5419275760650635, 3.3292524814605713, 3.465..."
7,CO,6.977914,"[2.758082389831543, 3.42410945892334]"
8,CCO,6.894196,"[2.4541964530944824, 2.7911367416381836, 3.389..."
9,OO,7.73526,"[3.6392979621887207, 3.6392979621887207]"
