# Combining ChemProp, `polaris`, and `nepare`

This notebook demonstrates using ChemProp as a learnable embedding for Neural Pairwise Regression (via `nepare`), evaluated with the `polaris` benchmarking library.

## Requirements
Python 3.10+ (originally run on 3.12)
 - polaris-lib
 - pandas
 - rdkit
 - lightning
 - torch
 - ipywidgets
 - chemprop

You will also need to run `pip install .` in the repository's root directory to install `nepare`.

## `polaris` Setup

After running `polaris login` on the command line, we can import everything (checking that the version is recent enough) and then download the benchmark data.

In [None]:
import polaris as po
import pandas as pd

In [None]:
from packaging.version import Version
assert Version(po.__version__) >= Version("0.11.6"), "test.as_dataframe does not work in earlier versions of Polaris, please upgrade"

`polaris` makes it really easy to run different benchmarks quickly: just change the name inside `load_benchmark` to try something else.
I'm using this same notebook for a few different benchmarks, all from the Fang et al. Biogen ADME paper (https://pubs.acs.org/doi/abs/10.1021/acs.jcim.3c00160), which have been made conveniently available on `polaris`.

In [None]:
%%capture
# https://polarishub.io/benchmarks/polaris/adme-fang-rppb-1
# benchmark = po.load_benchmark("polaris/adme-fang-RPPB-1")
# https://polarishub.io/benchmarks/polaris/adme-fang-solu-1
benchmark = po.load_benchmark("polaris/adme-fang-SOLU-1")

In [None]:
train, test = benchmark.get_train_test_split()
test_df: pd.DataFrame = test.as_dataframe()
train_df: pd.DataFrame = train.as_dataframe()

We'll shuffle the data just for good measure.

In [None]:
train_df = train_df.sample(frac=1.0, random_state=1701)  # shuffle the training data

In [None]:
train_df

## Learn an Embedding with ChemProp
ChemProp uses Message Passing Graph Neural Networks to learn a molecular representation tailored for the problem at hand.
We can 'plug it in' to `nepare` to take advantage of that, with the additional benefit for ChemProp that it will have more training data to learn its representation.
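
To make the "more training data" point concrete, here is a minimal sketch of pairwise augmentation on toy data. It assumes every ordered pair of training points is used and that each pair is labeled with the difference of the two targets; whether `nepare` includes self-pairs and which sign convention it uses are assumptions here, and the inputs and targets are made up for illustration.

```python
from itertools import product

# Toy stand-ins: three "molecules" with known target values (hypothetical data).
inputs = ["A", "B", "C"]
targets = [1.0, 2.5, 4.0]

# Assumed pairing scheme: every ordered pair (i, j), labeled with the target difference.
pairs = [
    (inputs[i], inputs[j], targets[i] - targets[j])
    for i, j in product(range(len(inputs)), repeat=2)
]

print(len(pairs))  # 9 pairs from 3 molecules
print(pairs[1])    # ('A', 'B', -1.5)
```

Under this scheme the number of training examples grows quadratically with the number of molecules, which is what gives the embedding more examples to learn from.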

In [None]:
val_idx = 150  # hold out the first 150 (shuffled) training rows for validation

We'll first write a function that converts our SMILES into their ChemProp input (a `MolGraph`).

In [None]:
from rdkit.Chem import MolFromSmiles
from chemprop.featurizers import MolGraphCache, SimpleMoleculeMolGraphFeaturizer

def smiles2molgraphcache(smiles: list[str]):
    mols = list(map(MolFromSmiles, smiles))
    featurizer = SimpleMoleculeMolGraphFeaturizer()
    mgc = MolGraphCache(mols, [None] * len(mols), [None] * len(mols), featurizer)
    return mgc

In [None]:
train_mgc = smiles2molgraphcache(train_df["smiles"][val_idx:])
train_targets = train_df["LOG_SOLUBILITY"][val_idx:].to_numpy()
val_mgc = smiles2molgraphcache(train_df["smiles"][:val_idx])
val_targets = train_df["LOG_SOLUBILITY"][:val_idx].to_numpy()
test_mgc = smiles2molgraphcache(test_df["smiles"])

In [None]:
from nepare.data import PairwiseAugmentedDataset, PairwiseAnchoredDataset, PairwiseInferenceDataset

In [None]:
train_dataset = PairwiseAugmentedDataset(train_mgc, train_targets)
val_dataset = PairwiseAnchoredDataset(train_mgc, train_targets, val_mgc, val_targets)
test_dataset = PairwiseInferenceDataset(train_mgc, train_targets, test_mgc)

Next, we need to write a function to collate our `MolGraph`s and target values - ChemProp has a class for batches of `MolGraph` aptly named `BatchMolGraph`.

In [None]:
from typing import Iterable

import torch
from chemprop.data.molgraph import MolGraph
from chemprop.data.collate import BatchMolGraph

In [None]:
def _collate(batch: Iterable[tuple[MolGraph, MolGraph, float]]):
    mgs_1, mgs_2, ys = zip(*batch)  #  now need to convert y back into a tensor
    return BatchMolGraph(mgs_1), BatchMolGraph(mgs_2), torch.tensor(ys, dtype=torch.float32).reshape(-1, 1)
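
The transposition that `zip(*batch)` performs in `_collate` is easier to see on stand-in data. The sketch below uses plain strings in place of `MolGraph` objects and a nested list in place of a tensor; the item layout (first input, second input, target difference) is assumed from the type hint above.

```python
# Each dataset item is assumed to be (first input, second input, target difference);
# strings stand in for MolGraph objects here.
batch = [("mg_a", "mg_b", -1.5), ("mg_b", "mg_c", -1.5), ("mg_c", "mg_a", 3.0)]

# zip(*batch) transposes a list of triples into three parallel tuples.
firsts, seconds, ys = zip(*batch)

# Turn the flat targets into a column, mirroring reshape(-1, 1) on a tensor.
ys_column = [[y] for y in ys]

print(firsts)     # ('mg_a', 'mg_b', 'mg_c')
print(ys_column)  # [[-1.5], [-1.5], [3.0]]
```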

In [None]:
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True, collate_fn=_collate)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=64, collate_fn=_collate)
predict_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, collate_fn=_collate)

Finally, we just need to define a class to take our collated batches and convert them into their learned representations.
This class can then be passed to the `nepare` class `LearnedEmbeddingNeuralPairwiseRegressor`, which will call our class on the two inputs for each batch.

In [None]:
from chemprop.conf import DEFAULT_HIDDEN_DIM
from chemprop.nn.agg import MeanAggregation
from chemprop.nn.message_passing import BondMessagePassing

from nepare.nn import LearnedEmbeddingNeuralPairwiseRegressor

In [None]:
class ChemPropEmbedder(torch.nn.Module):
    def __init__(self, mp, agg):
        super().__init__()
        self.mp = mp
        self.agg = agg

    def forward(self, bmg):
        H = self.mp(bmg)
        Z = self.agg(H, bmg.batch)
        return Z
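
Conceptually, the aggregation step collapses per-atom hidden states into one vector per molecule by averaging the rows that share a molecule index. The pure-Python sketch below illustrates that idea on toy numbers; `mean_aggregate` is a hypothetical helper written for this example, not ChemProp's implementation, and it assumes molecule indices are contiguous and start at 0.

```python
def mean_aggregate(H, batch):
    """Average the rows of H that share the same molecule index in batch."""
    n_mols = max(batch) + 1
    dim = len(H[0])
    sums = [[0.0] * dim for _ in range(n_mols)]
    counts = [0] * n_mols
    for row, mol_idx in zip(H, batch):
        counts[mol_idx] += 1
        for k, value in enumerate(row):
            sums[mol_idx][k] += value
    return [[s / c for s in row] for row, c in zip(sums, counts)]

# Five atoms with 2-dimensional hidden states:
# atoms 0-1 belong to molecule 0, atoms 2-4 to molecule 1.
H = [[1.0, 2.0], [3.0, 4.0], [0.0, 0.0], [3.0, 6.0], [6.0, 3.0]]
batch = [0, 0, 1, 1, 1]

Z = mean_aggregate(H, batch)
print(Z)  # [[2.0, 3.0], [3.0, 3.0]]
```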

In [None]:
mp = BondMessagePassing()
agg = MeanAggregation()
embedder = ChemPropEmbedder(mp, agg)

In [None]:
npr = LearnedEmbeddingNeuralPairwiseRegressor(embedder, DEFAULT_HIDDEN_DIM, 100, 2)

## Training and Predicting

From here on out we follow a very standard `lightning` training workflow - see `demo.ipynb` for a slightly more in-depth explanation of what's going on.

In [None]:
import lightning
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint

from nepare.inference import predict

In [None]:
early_stopping = EarlyStopping(monitor="validation/loss", patience=3)
model_checkpoint = ModelCheckpoint(monitor="validation/loss")

In [None]:
trainer = lightning.Trainer(max_epochs=50, log_every_n_steps=1, callbacks=[early_stopping, model_checkpoint])
trainer.fit(npr, train_loader, val_loader)  # the validation loader is needed for the "validation/loss" monitors

In [None]:
npr = LearnedEmbeddingNeuralPairwiseRegressor.load_from_checkpoint(model_checkpoint.best_model_path)  # reload best model based on early stopping

In [None]:
y_pred, y_stdev = predict(npr, predict_loader, how="all")
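
One way to read the two return values: each test molecule is compared against many training anchors, every predicted difference is converted back to an absolute estimate using the anchor's known target, and the estimates are summarized by their mean and spread. The sketch below shows that aggregation on toy numbers. The sign convention, the use of the population standard deviation, and the helper name `aggregate_predictions` are all assumptions for illustration and may not match what `predict(..., how="all")` does internally.

```python
from statistics import mean, pstdev

def aggregate_predictions(anchor_targets, predicted_diffs):
    """Turn per-anchor difference predictions into one estimate and its spread."""
    # Assumed convention: predicted_diffs[k] estimates y_query - anchor_targets[k].
    estimates = [y + d for y, d in zip(anchor_targets, predicted_diffs)]
    return mean(estimates), pstdev(estimates)

# Hypothetical anchors and model outputs for a single query molecule.
anchor_targets = [1.0, 2.0, 4.0]
predicted_diffs = [1.0, 1.0, 0.0]

y_hat, spread = aggregate_predictions(anchor_targets, predicted_diffs)
print(y_hat)  # 3.0
```

A larger spread means the anchors disagree more about the query molecule, which can serve as a rough uncertainty signal.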

In [None]:
results = benchmark.evaluate(y_pred)

In [None]:
results.name = "nepare"
results.description = "Neural Pairwise Regression with a ChemProp Learned Embedding"
results.github_url = "https://github.com/JacksonBurns/neural-pairwise-regression/blob/main/notebooks/polaris_chemprop_nepare.ipynb"

In [None]:
results

This last line is commented out because it will fail (unless you are me). To upload your own results, replace the `owner` with your own name and update the link, name, and description above.

In [None]:
# results.upload_to_hub(owner="jacksonburns", access="public")