# Sobolev Demo

This notebook demonstrates how to set up Sobolev training for a single-molecule property prediction task.
The actual `fastsolv` model takes two molecules at a time, but the core idea of Sobolev training is the same.

## Preparing the Data
Your data should be in a format like the one shown below, which can be loaded into a `pandas.DataFrame`.
For this example, we are calculating the gas phase heat capacity for various small molecules.

In [1]:
import pandas as pd
from io import StringIO

In [2]:
data_csv = '''
Name of compounds,Formula,SMILES,CAS,Cp (J/mol.K),T (K)
Formaldehyde,CH2O,C=O,50-00-0,33.5,200.0
Formaldehyde,CH2O,C=O,50-00-0,34.7,273.15
Formaldehyde,CH2O,C=O,50-00-0,35.44,300.0
Formaldehyde,CH2O,C=O,50-00-0,39.24,400.0
Formic acid,CH2O2,O=CO,64-18-6,37.83,200.0
Formic acid,CH2O2,O=CO,64-18-6,43.54,273.15
Formic acid,CH2O2,O=CO,64-18-6,45.84,300.0
Formic acid,CH2O2,O=CO,64-18-6,54.52,400.0
'''

In [3]:
df = pd.read_csv(StringIO(data_csv))
df

Unnamed: 0,Name of compounds,Formula,SMILES,CAS,Cp (J/mol.K),T (K)
0,Formaldehyde,CH2O,C=O,50-00-0,33.5,200.0
1,Formaldehyde,CH2O,C=O,50-00-0,34.7,273.15
2,Formaldehyde,CH2O,C=O,50-00-0,35.44,300.0
3,Formaldehyde,CH2O,C=O,50-00-0,39.24,400.0
4,Formic acid,CH2O2,O=CO,64-18-6,37.83,200.0
5,Formic acid,CH2O2,O=CO,64-18-6,43.54,273.15
6,Formic acid,CH2O2,O=CO,64-18-6,45.84,300.0
7,Formic acid,CH2O2,O=CO,64-18-6,54.52,400.0


From here, we can calculate the features needed for our models during training.
This is adapted from the code in `data/utils.py`.

In [4]:
df = df.rename(columns={"SMILES":"smiles", "Cp (J/mol.K)" :"target", "T (K)":"independent_variable"})

In [5]:
import numpy as np
from fastprop.defaults import ALL_2D
from fastprop.descriptors import get_descriptors
from rdkit import Chem


def get_descs(src_df: pd.DataFrame):
    """Calculates features for molecules

    Args:
        src_df (pd.DataFrame): DataFrame with 'smiles', 'target', 'independent_variable'.
                               Other columns will be ignored.
    """
    unique_smiles: np.ndarray = pd.unique(src_df["smiles"])
    descs: np.ndarray = get_descriptors(False, ALL_2D, list(Chem.MolFromSmiles(i) for i in unique_smiles)).to_numpy(dtype=np.float32)
    # assemble the data into the format expected in fastprop
    # map smiles -> descriptors
    smiles_to_descs: dict = {smiles: desc for smiles, desc in zip(unique_smiles, descs)}
    fastprop_data: pd.DataFrame = src_df[["smiles", "target", "independent_variable"]]
    fastprop_data: pd.DataFrame = fastprop_data.reindex(columns=fastprop_data.columns.tolist() + ALL_2D)
    fastprop_data[ALL_2D] = [smiles_to_descs[smi] for smi in fastprop_data["smiles"]]
    return fastprop_data

In [6]:
data: pd.DataFrame = get_descs(df)
data


100%|██████████| 2/2 [00:00<00:00, 41.29it/s]


Unnamed: 0,smiles,target,independent_variable,ABC,ABCGG,nAcid,nBase,SpAbs_A,SpMax_A,SpDiam_A,...,SRW10,TSRW10,MW,AMW,WPath,WPol,Zagreb1,Zagreb2,mZagreb1,mZagreb2
0,C=O,33.5,200.0,0.0,0.0,0.0,0.0,2.0,1.0,2.0,...,1.098612,7.493062,30.010565,7.502641,1.0,0.0,2.0,1.0,2.0,1.0
1,C=O,34.7,273.15,0.0,0.0,0.0,0.0,2.0,1.0,2.0,...,1.098612,7.493062,30.010565,7.502641,1.0,0.0,2.0,1.0,2.0,1.0
2,C=O,35.44,300.0,0.0,0.0,0.0,0.0,2.0,1.0,2.0,...,1.098612,7.493062,30.010565,7.502641,1.0,0.0,2.0,1.0,2.0,1.0
3,C=O,39.24,400.0,0.0,0.0,0.0,0.0,2.0,1.0,2.0,...,1.098612,7.493062,30.010565,7.502641,1.0,0.0,2.0,1.0,2.0,1.0
4,O=CO,37.83,200.0,1.414214,1.414214,1.0,0.0,2.828427,1.414214,2.828427,...,4.174387,17.31077,46.005478,9.201096,4.0,0.0,6.0,4.0,2.25,1.0
5,O=CO,43.54,273.15,1.414214,1.414214,1.0,0.0,2.828427,1.414214,2.828427,...,4.174387,17.31077,46.005478,9.201096,4.0,0.0,6.0,4.0,2.25,1.0
6,O=CO,45.84,300.0,1.414214,1.414214,1.0,0.0,2.828427,1.414214,2.828427,...,4.174387,17.31077,46.005478,9.201096,4.0,0.0,6.0,4.0,2.25,1.0
7,O=CO,54.52,400.0,1.414214,1.414214,1.0,0.0,2.828427,1.414214,2.828427,...,4.174387,17.31077,46.005478,9.201096,4.0,0.0,6.0,4.0,2.25,1.0


Now we need to calculate the gradients of our target property with respect to the independent variable.
To do so, we will first write a helper function that calculates the gradients:

In [7]:
def _f(r):
    if len(r["scaled_target"]) == 1:
        return [np.nan]
    sorted_idxs = np.argsort(r["scaled_independent_variable"])
    unsort_idxs = np.argsort(sorted_idxs)
    # replace non-finite gradients (nan/inf) with nan so they can be masked out later
    grads = [
        i if np.isfinite(i) else np.nan
        for i in np.gradient(
            [r["scaled_target"][i] for i in sorted_idxs],
            [r["scaled_independent_variable"][i] for i in sorted_idxs],
        )
    ]
    return [grads[i] for i in unsort_idxs]
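
To see what the sort and un-sort bookkeeping in the helper above does, here is a minimal sketch on a single hypothetical molecule whose rows are deliberately out of temperature order. The values and variable names are made up for illustration and are not taken from the dataset.

```python
import numpy as np

# hypothetical measurements for one molecule, deliberately out of order
temperature = np.array([300.0, 200.0, 400.0])
heat_capacity = np.array([35.0, 33.0, 39.0])

order = np.argsort(temperature)  # indices that sort the rows by temperature
restore = np.argsort(order)      # indices that undo that sort

# np.gradient takes the sample coordinates as its second argument,
# so the spacing between temperatures does not need to be uniform
sorted_grads = np.gradient(heat_capacity[order], temperature[order])
grads = sorted_grads[restore]    # back in the original row order
print(grads)
```

The two end points get one-sided differences and the interior point gets a central difference, which is why every gradient depends on at least one neighboring measurement.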

Typically you would need to decide which points to use for validation and which for training - for this simple demo, we will train without a validation set.
Keep in mind that because we are using gradients, you must include all rows of your data for a given molecule at every temperature in either training _or_ validation to avoid leaking data.
For example, in this case we would not want to include formaldehyde at 200 K in training and then 273 K in validation.
This is because when we calculate the gradient at 200 K we use information about the point at 273 K, which would leak data if they are not kept together!
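
One way to keep every row of a molecule on the same side of the split is to split the list of unique molecules rather than the rows. The following is a minimal sketch using only the standard library; `split_by_molecule` and its parameters are hypothetical names introduced here, not part of `fastprop` or `chemprop`.

```python
import random


def split_by_molecule(smiles, validation_fraction=0.5, seed=0):
    """Return (train_rows, validation_rows) with each molecule kept whole."""
    molecules = sorted(set(smiles))
    random.Random(seed).shuffle(molecules)
    # always hold out at least one molecule
    n_validation = max(1, round(len(molecules) * validation_fraction))
    validation_molecules = set(molecules[:n_validation])
    train_rows = [i for i, smi in enumerate(smiles) if smi not in validation_molecules]
    validation_rows = [i for i, smi in enumerate(smiles) if smi in validation_molecules]
    return train_rows, validation_rows


# same layout as the demo data: four temperatures for each of two molecules
smiles_column = ["C=O"] * 4 + ["O=CO"] * 4
train_rows, validation_rows = split_by_molecule(smiles_column)
```

The returned row positions could then be used with something like `data.iloc[train_rows]` to build the two subsets.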

Start by re-scaling the data.

In [8]:
from sklearn.preprocessing import StandardScaler

In [9]:
target_scaler = StandardScaler().fit(data[["target"]])
scaled_target = target_scaler.transform(data[["target"]]).ravel()
independent_variable_scaler = StandardScaler().fit(data[["independent_variable"]])
scaled_independent_variable = independent_variable_scaler.transform(data[["independent_variable"]]).ravel()

From here, we just use pandas to calculate the gradients - the inline comments provide greater detail on how it actually works (it's rather involved).

In [10]:

grads = pd.concat(
    (
        df,
        pd.DataFrame(
            {
                "source_index": np.arange(len(df["independent_variable"])),
                "scaled_independent_variable": scaled_independent_variable,
                "scaled_target": scaled_target,
            }
        ),
    ),
    axis=1,
)
# group the data by molecule
grads = grads.groupby(["smiles"])[["scaled_target", "scaled_independent_variable", "source_index"]].aggregate(list)
# calculate the gradient at each measurement of target wrt independent_variable
grads["grad"] = grads.apply(_f, axis=1)
# get them in the same order as the source data
grads = grads.explode(["grad", "source_index"]).sort_values(by="source_index")
# convert and mask
grads = grads["grad"].to_numpy(dtype=np.float32)
_mask = np.isnan(grads)
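
Note that these gradients are in scaled units, because both the target and the independent variable were standardized first. By the chain rule, a gradient in physical units is the scaled gradient multiplied by the ratio of the two standard deviations. The sketch below checks this with NumPy on the four formaldehyde rows; for brevity it computes the statistics from those four rows only, which is an assumption of the sketch (the notebook fits its scalers on the whole table).

```python
import numpy as np

# formaldehyde rows from the table at the top of the notebook
temperature = np.array([200.0, 273.15, 300.0, 400.0])
heat_capacity = np.array([33.5, 34.7, 35.44, 39.24])

# population statistics (ddof=0), the same convention StandardScaler uses
t_mean, t_std = temperature.mean(), temperature.std()
cp_mean, cp_std = heat_capacity.mean(), heat_capacity.std()

scaled_t = (temperature - t_mean) / t_std
scaled_cp = (heat_capacity - cp_mean) / cp_std

scaled_grad = np.gradient(scaled_cp, scaled_t)
# chain rule: dCp/dT = (std of Cp / std of T) * d(scaled Cp)/d(scaled T)
physical_grad = scaled_grad * cp_std / t_std
print(physical_grad)
```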

The data is ready!
We can re-use this same input file for both a `fastprop`- and `chemprop`-based model.
For large, real-world datasets, it is highly suggested to save the features to a file (like a `.csv` with `df.to_csv`) to allow re-loading it without recalculating in the future.
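
A minimal sketch of that caching round trip, using a small stand-in table (two descriptor columns copied from the output above) and a hypothetical temporary file location:

```python
import tempfile
from pathlib import Path

import pandas as pd

# stand-in for the full feature table
features = pd.DataFrame(
    {
        "smiles": ["C=O", "O=CO"],
        "MW": [30.010565, 46.005478],
        "Zagreb1": [2.0, 6.0],
    }
)

cache_path = Path(tempfile.mkdtemp()) / "features.csv"  # hypothetical location
# index=False keeps the row index out of the file, so no extra
# unnamed column appears when the file is read back
features.to_csv(cache_path, index=False)
reloaded = pd.read_csv(cache_path)
```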

Later on in this notebook we will actually pack these values into data classes that `torch` knows how to use.

## Setting up the Models

The model definitions for the original study are in `models/fastprop` and `models/chemprop`.
Each expects two molecules, so we will adapt that code here to work for our one molecule case.

For the sake of demonstration, the variable below can be changed to enable or disable the Sobolev loss.

In [11]:
ENABLE_SOBOLEV_LOSS = True

### `fastprop`-based Model

First, we'll define the model code for a `fastprop`-based model.
Most of the model functionality is handled by the `fastprop` model definition - you can find the source for `fastprop` [here](https://github.com/JacksonBurns/fastprop).

In [12]:
from fastprop.model import fastprop as _fastprop
from fastprop.data import standard_scale, inverse_standard_scale
import torch

This concatenation class will just be used to combine the input temperature with the molecule features.

In [13]:
class Concatenation(torch.nn.Module):
    def forward(self, batch):
        return torch.cat(batch, dim=1)

In [14]:
from collections import OrderedDict

In [15]:
class fastpropSobolev(_fastprop):
    def __init__(
        self,
        num_layers: int = 2,
        hidden_size: int = 1_800,
        num_features: int = 1613,
        learning_rate: float = 0.0001,
        target_means: torch.Tensor = None,
        target_vars: torch.Tensor = None,
        feature_means: torch.Tensor = None,
        feature_vars: torch.Tensor = None,
        independent_variable_means: torch.Tensor = None,
        independent_variable_vars: torch.Tensor = None,
    ):
        super().__init__(
            input_size=num_features + 1,
            hidden_size=hidden_size,
            fnn_layers=num_layers,
            readout_size=1,
            num_tasks=1,
            learning_rate=learning_rate,
            problem_type="regression",
            target_names=[],
            target_means=target_means,
            target_vars=target_vars,
            feature_means=feature_means,
            feature_vars=feature_vars,
            clamp_input=True,
        )

        # for later predicting
        self.register_buffer("independent_variable_means", independent_variable_means)
        self.register_buffer("independent_variable_vars", independent_variable_vars)

        # add concatenation of the temperature to the input features
        _modules = OrderedDict()
        _modules['concatenate'] = Concatenation()
        for name, module in self.fnn.named_children():
            _modules[name] = module
        self.fnn = torch.nn.Sequential(_modules)

        self.save_hyperparameters()

    def predict_step(self, batch):
        err_msg = ""
        for stat_obj, stat_name in zip(
            (
                self.feature_means,
                self.feature_vars,
                self.independent_variable_means,
                self.independent_variable_vars,
                self.target_means,
                self.target_vars,
            ),
            (
                "feature_means",
                "feature_vars",
                "independent_variable_means",
                "independent_variable_vars",
                "target_means",
                "target_vars",
            ),
        ):
            if stat_obj is None:
                err_msg += f"{stat_name} is None!\n"
        if err_msg:
            raise RuntimeError("Missing scaler statistics!\n" + err_msg)

        features, independent_variables = batch[0]  # batch[0] holds the inputs; batch[1] (the target) is unused here
        features = standard_scale(features, self.feature_means, self.feature_vars)
        independent_variables = standard_scale(independent_variables, self.independent_variable_means, self.independent_variable_vars)
        with torch.inference_mode():
            logits = self.forward((features, independent_variables))
        return inverse_standard_scale(logits, self.target_means, self.target_vars)

    @torch.enable_grad()
    def _custom_loss(self, batch: tuple[tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor], name: str):
        (_features, independent_variable), y, y_grad = batch
        independent_variable.requires_grad_()
        y_hat: torch.Tensor = self.forward((_features, independent_variable))
        y_loss = torch.nn.functional.mse_loss(y_hat, y, reduction="mean")
        (y_grad_hat,) = torch.autograd.grad(
            y_hat,
            independent_variable,
            grad_outputs=torch.ones_like(y_hat),
            retain_graph=True,
        )
        _scale_factor = 10.0
        y_grad_loss = _scale_factor * (y_grad_hat - y_grad).pow(2).nanmean()  # MSE ignoring nan
        loss = y_loss + y_grad_loss
        self.log(f"{name}/{self.training_metric}_scaled_loss", loss)
        self.log(f"{name}/logS_scaled_loss", y_loss)
        self.log(f"{name}/dlogSdT_scaled_loss", y_grad_loss)
        return loss, y_hat

    def _plain_loss(self, batch: tuple[tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor], name: str):
        (_features, independent_variable), y, _ = batch
        y_hat: torch.Tensor = self.forward((_features, independent_variable))
        loss = torch.nn.functional.mse_loss(y_hat, y, reduction="mean")
        self.log(f"{name}/{self.training_metric}_scaled_loss", loss)
        return loss, y_hat

    def _loss(self, batch: tuple[tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor], name: str):
        if ENABLE_SOBOLEV_LOSS:
            return self._custom_loss(batch, name)
        else:
            return self._plain_loss(batch, name)

    def training_step(self, batch, batch_idx):
        return self._loss(batch, "train")[0]

    def validation_step(self, batch, batch_idx):
        loss, y_hat = self._loss(batch, "validation")
        self._human_loss(y_hat, batch, "validation")
        return loss

    def test_step(self, batch, batch_idx):
        loss, y_hat = self._loss(batch, "test")
        self._human_loss(y_hat, batch, "test")
        return loss
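
The idea behind a value-plus-derivative loss can be tried in isolation on toy tensors. The following is a minimal sketch, assuming a small hypothetical `torch.nn.Sequential` network and random data in place of the `fastprop` model. Two details are choices of this sketch rather than anything taken from the class above: `create_graph=True`, which makes the derivative term itself differentiable with respect to the weights, and boolean masking of missing reference gradients in place of `nanmean`.

```python
import torch

torch.manual_seed(0)

# toy stand-ins: 4 samples, 3 molecular features plus 1 independent variable
features = torch.randn(4, 3)
temperature = torch.randn(4, 1, requires_grad=True)
target = torch.randn(4, 1)
# nan marks a sample with no reference gradient
target_grad = torch.tensor([[0.5], [float("nan")], [-0.2], [0.1]])

model = torch.nn.Sequential(
    torch.nn.Linear(4, 8),
    torch.nn.Tanh(),
    torch.nn.Linear(8, 1),
)

prediction = model(torch.cat((features, temperature), dim=1))
value_loss = torch.nn.functional.mse_loss(prediction, target)

# per-sample derivative of the prediction with respect to the temperature input
(prediction_grad,) = torch.autograd.grad(
    prediction,
    temperature,
    grad_outputs=torch.ones_like(prediction),
    create_graph=True,
)

# compare only the samples that have a reference gradient
mask = ~torch.isnan(target_grad)
grad_loss = (prediction_grad[mask] - target_grad[mask]).pow(2).mean()

loss = value_loss + grad_loss
loss.backward()
```

Passing `grad_outputs=torch.ones_like(prediction)` works here because each prediction depends only on its own row of `temperature`, so the summed derivative separates cleanly into one value per sample.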

### `chemprop`-based Model

The one 'trick' for the `chemprop`-derived model is the use of a custom metric.
This code specifically works for `chemprop` version 2.0 - it would need some (small) changes to work with 2.1 or newer, which changed the way that metrics work.

In [16]:
from chemprop import models
from chemprop.nn import metrics

In [17]:
class CustomMSEMetric(metrics.MSEMetric):
    def forward(self, preds, targets, mask, weights, lt_mask, gt_mask):
        return torch.nn.functional.mse_loss(preds, targets[:, 0, None], reduction="mean")


class SobolevMPNN(models.MPNN):
    def training_step(self, batch, batch_idx):
        return self._sobolev_loss(batch, "train")

    def validation_step(self, batch, batch_idx):
        return self._sobolev_loss(batch, "val")

    @torch.enable_grad()
    def _sobolev_loss(self, batch, name):
        bmg, V_d, X_d, targets, *_ = batch
        # track grad for temperature
        X_d.requires_grad_()
        Z = self.fingerprint(bmg, V_d, X_d)
        y_hat = self.predictor.train_step(Z)
        y_loss = torch.nn.functional.mse_loss(y_hat, targets[:, 0, None], reduction="mean")
        (y_grad_hat,) = torch.autograd.grad(
            y_hat,
            X_d,
            grad_outputs=torch.ones_like(y_hat),
            retain_graph=True,
        )
        _scale_factor = 1.0
        y_grad_loss = _scale_factor * (y_grad_hat - targets[:, 1, None]).pow(2).nanmean()  # MSE ignoring nan
        loss = y_loss + y_grad_loss
        self.log(f"{name}/sobolev_loss", loss, batch_size=len(batch[0]))
        self.log(f"{name}/logs_loss", y_loss, batch_size=len(batch[0]))
        self.log(f"{name}/grad_loss", y_grad_loss, batch_size=len(batch[0]))
        self.log(f"{name}_loss", loss, prog_bar=True, batch_size=len(batch[0]))
        return loss


## Working with `torch`

Now that we have our data and models we can start setting up our code to interact with `torch`.

### `fastprop` Data

We'll start by defining a new dataset class to help interoperate our data with `fastprop`.

In [18]:
from torch.utils.data import Dataset as TorchDataset

In [19]:
class SolubilityDataset(TorchDataset):
    def __init__(
        self,
        features: torch.Tensor,
        independent_variable: torch.Tensor,
        target: torch.Tensor,
        target_gradient: torch.Tensor,
    ):
        self.features = features
        self.independent_variable = independent_variable
        self.target = target
        self.target_gradient = target_gradient
        self.length = features.shape[0]

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        return (
            (
                self.features[index],
                self.independent_variable[index],
            ),
            self.target[index],
            self.target_gradient[index],
        )

Rescale the input features for `fastprop` (we already rescaled the target and independent variable).
If we had a validation set, we would want to be very careful to apply the scaler fit on just the _training_ data to the _validation_ data, again avoiding data leaks!

In [20]:
features, feature_means, feature_vars = standard_scale(torch.tensor(data[ALL_2D].to_numpy(dtype=np.float32)))
tens = lambda d: torch.tensor(d, dtype=torch.float32)
target_means, target_vars = tens(target_scaler.mean_), tens(target_scaler.var_)
independent_variable_means, independent_variable_vars = tens(independent_variable_scaler.mean_), tens(independent_variable_scaler.var_)
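
If a validation set were present, the leak-free pattern would be to compute the statistics from the training rows and reuse them unchanged on the validation rows. A minimal sketch with plain NumPy and made-up feature values (the array names are hypothetical, and it does not use `standard_scale`):

```python
import numpy as np

# hypothetical split: two descriptor columns for training and validation molecules
train_features = np.array([[0.0, 30.0], [1.4, 46.0], [2.0, 60.0]])
validation_features = np.array([[1.0, 44.0]])

# statistics come from the training rows only
means = train_features.mean(axis=0)
stds = train_features.std(axis=0)

scaled_train = (train_features - means) / stds
# reuse the training statistics here; never refit on validation data
scaled_validation = (validation_features - means) / stds
```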

We use the `[:, None]` syntax to insert a second dimension into all of our 1D tensors - this makes concatenating things easier later on.

In [21]:
fastprop_dataset = SolubilityDataset(
    features=features,
    independent_variable=tens(scaled_independent_variable)[:, None],
    target=tens(scaled_target)[:, None],
    target_gradient=tens(grads)[:, None],
)
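
The extra dimension also guards against silent broadcasting when a column-shaped model output is compared with a 1D array. A small NumPy illustration with made-up numbers (the same broadcasting rules apply to `torch` tensors):

```python
import numpy as np

predicted = np.array([[0.1], [0.2], [0.3]])  # shape (3, 1), like a model output
reference = np.array([0.1, 0.2, 0.3])        # shape (3,), a plain 1D column

mismatched = predicted - reference           # broadcasts to shape (3, 3)
aligned = predicted - reference[:, None]     # stays at shape (3, 1)
print(mismatched.shape, aligned.shape)       # (3, 3) (3, 1)
```

Taking a mean of the squared `mismatched` array would average over every pairing of rows instead of the three matching pairs, without raising any error.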

### `chemprop` Data

We can use the existing `chemprop` classes right out-of-the-box!

In [22]:
from chemprop.data import MoleculeDatapoint, MoleculeDataset
from chemprop.featurizers import SimpleMoleculeMolGraphFeaturizer

In [23]:
chemprop_data = [
    MoleculeDatapoint.from_smi(smi, [tgt, grd], x_d=np.array([ind]))
    for smi, tgt, grd, ind in zip(data["smiles"], scaled_target, grads, data["independent_variable"])
]
featurizer = SimpleMoleculeMolGraphFeaturizer()
chemprop_dataset = MoleculeDataset(chemprop_data, featurizer)
chemprop_dataset.normalize_inputs("X_d", independent_variable_scaler)
chemprop_dataset.cache = True
# if we had a validation set, we would want to set it up here.
# chemprop 2.0 docs say to do this, but it actually gets scaled during training
# this is fixed in chemprop 2.1 (but we wrote this code in 2.0) so stick to 2.0 for now
# val_dataset = MoleculeDataset(chemprop_validation_data, featurizer)
# val_dataset.normalize_inputs("X_d", independent_variable_scaler)
# val_dataset.cache = True



## Training

Now we can finally train the model!
This is just the typical pytorch lightning training setup from here on out.
This notebook shows how to run both the `fastprop`- and `chemprop`-based model.

The last small difficulty is that `chemprop` and `fastprop` need slightly different setups for `Trainer` and the related imports, because `lightning` changed its syntax between the times the two packages were developed.

In [24]:
from torch.utils.data import DataLoader

from chemprop.data.dataloader import build_dataloader
from chemprop import nn
from lightning.pytorch.callbacks import ModelCheckpoint as NewModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger as NewTensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint as OldModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger as OldTensorBoardLogger

# from lightning.pytorch.callbacks.early_stopping import EarlyStopping  <-- uncomment if you have a validation set for early stopping (recommended!!)

In [25]:
fastprop_dataloader = DataLoader(fastprop_dataset, batch_size=4)
chemprop_loader = build_dataloader(chemprop_dataset, batch_size=4)

In [26]:
fastprop_model = fastpropSobolev(
    num_layers=2,
    hidden_size=400,
    num_features=len(ALL_2D),
    target_means=target_means,
    target_vars=target_vars,
    feature_means=feature_means,
    feature_vars=feature_vars,
    independent_variable_means=independent_variable_means,
    independent_variable_vars=independent_variable_vars,
)
fastprop_model

fastpropSobolev(
  (fnn): Sequential(
    (concatenate): Concatenation()
    (clamp): ClampN(n=3)
    (lin1): Linear(in_features=1614, out_features=400, bias=True)
    (act1): ReLU()
    (lin2): Linear(in_features=400, out_features=400, bias=True)
  )
  (readout): Linear(in_features=400, out_features=1, bias=True)
)

In [27]:
mp = nn.BondMessagePassing()
agg = nn.MeanAggregation()
output_transform = nn.UnscaleTransform.from_standard_scaler(target_scaler)
ffn = nn.RegressionFFN(
    input_dim=mp.output_dim + 1,  # temperature
    n_layers=2,
    criterion=CustomMSEMetric(),
    output_transform=output_transform,
)
X_d_transform = nn.ScaleTransform.from_standard_scaler(independent_variable_scaler)
metric_list = [CustomMSEMetric()]
chemprop_model = SobolevMPNN(
    mp,
    agg,
    ffn,
    batch_norm=True,
    metrics=metric_list,
    X_d_transform=X_d_transform,
)
chemprop_model

SobolevMPNN(
  (message_passing): BondMessagePassing(
    (W_i): Linear(in_features=86, out_features=300, bias=False)
    (W_h): Linear(in_features=300, out_features=300, bias=False)
    (W_o): Linear(in_features=372, out_features=300, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
    (tau): ReLU()
    (V_d_transform): Identity()
    (graph_transform): Identity()
  )
  (agg): MeanAggregation()
  (bn): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=301, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=300, bias=True)
      )
      (2): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=1, bias=True)
      )
    )
    (criterion): CustomMSEMetric(task

Normally we would want to use early stopping - for this demo notebook we don't have a validation set, but the code needed to implement early stopping is shown in the comments.

In [28]:
from pathlib import Path

In [29]:
_outdir = Path("demo_output")

Finally, the actual training call!

In [30]:
from lightning.pytorch import Trainer as NewTrainer
from pytorch_lightning import Trainer as OldTrainer

In [31]:
tensorboard_logger = OldTensorBoardLogger(
    _outdir,
    name="tensorboard_logs",
    default_hp_metric=False,
)
callbacks = [
    # EarlyStopping(
    #     monitor="val_loss",
    #     mode="min",
    #     verbose=False,
    #     patience=15,
    # ),
    OldModelCheckpoint(
        # monitor="val_loss",  <-- uncomment these lines if early stopping
        # save_top_k=1,
        # mode="min",
        dirpath=_outdir / "checkpoints",
    ),
]
trainer = OldTrainer(
    max_epochs=1,
    logger=tensorboard_logger,
    log_every_n_steps=1,
    enable_checkpointing=True,
    check_val_every_n_epoch=1,
    callbacks=callbacks,
    # REQUIRED!! to enable sobolev loss during validation
    inference_mode=False,
)
trainer.fit(fastprop_model, fastprop_dataloader)  # , val_loader)  <-- include this if early stopping

# the rest of this shows how to reload the best model from early stopping and make predictions
# ckpt_path = trainer.checkpoint_callback.best_model_path
# print(f"Reloading best model from checkpoint file: {ckpt_path}")
# fastprop_model = fastprop_model.__class__.load_from_checkpoint(ckpt_path)
# val_results = trainer.validate(fastprop_model, val_loader)
# predictions = trainer.predict(fastprop_model, predict_dataloader)  <-- use this to make predictions

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/jackson/miniconda3/envs/fastsolv/lib/python3.11/site-packages/pytorch_lightning/trainer/configuration_validator.py:70: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type       | Params | Mode 
-----------------------------------------------
0 | fnn     | Sequential | 806 K  | train
1 | readout | Linear     | 401    | train
-----------------------------------------------
806 K     Trainable params
0         Non-trainable params
806 K     Total params
3.227     Total estimated model params size (MB)
7         Modules in train mode
0         Modules in eval mode


Epoch 0: 100%|██████████| 2/2 [00:00<00:00, 16.28it/s, v_num=0]

`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 2/2 [00:00<00:00, 13.64it/s, v_num=0]


In [32]:
tensorboard_logger = NewTensorBoardLogger(
    _outdir,
    name="tensorboard_logs",
    default_hp_metric=False,
)
callbacks = [
    # EarlyStopping(
    #     monitor="val_loss",
    #     mode="min",
    #     verbose=False,
    #     patience=15,
    # ),
    NewModelCheckpoint(
        # monitor="val_loss",  <-- uncomment these lines if early stopping
        # save_top_k=1,
        # mode="min",
        dirpath=_outdir / "checkpoints",
    ),
]
trainer = NewTrainer(
    max_epochs=1,
    logger=tensorboard_logger,
    log_every_n_steps=1,
    enable_checkpointing=True,
    check_val_every_n_epoch=1,
    callbacks=callbacks,
    # REQUIRED!! to enable sobolev loss during validation
    inference_mode=False,
)
trainer.fit(chemprop_model, chemprop_loader)

# ... same logic as above applies for loading the chemprop model and running inference

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/jackson/miniconda3/envs/fastsolv/lib/python3.11/site-packages/lightning/pytorch/trainer/configuration_validator.py:70: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
/home/jackson/miniconda3/envs/fastsolv/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /home/jackson/fastsolv/paper/demo_output/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loading `train_dataloader` to estimate number of stepping batches.

  | Name            | Type               | Params | Mode 
---------------------------------------------------------------
0 | message_passing | BondMessagePassing | 227 K  | train
1 | agg             | MeanAggregation    | 0      | train
2 | bn              | BatchNorm1d        | 600    | train
3 | predictor       | RegressionFFN      | 181 K  | 

Epoch 0: 100%|██████████| 2/2 [00:00<00:00, 12.20it/s, v_num=1, train_loss=0.708]

`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 2/2 [00:00<00:00, 11.28it/s, v_num=1, train_loss=0.708]
