# `fastprop` Classification with `polaris` Demo
This notebook demonstrates training `fastprop` on a binary classification dataset using the `polaris` benchmarking library.

Requires:
 - fastprop
 - polaris-lib

## Retrieving the Data

After running `polaris login` on the command line, we can use this code to access the dataset (follow [this link](https://polarishub.io/benchmarks/polaris/pkis2-egfr-wt-c-1) to learn more about this task):

In [1]:
%%capture
import polaris as po

benchmark = po.load_benchmark("polaris/pkis2-egfr-wt-c-1")
train, test = benchmark.get_train_test_split()

In [2]:
train_df, test_df = train.as_dataframe(), test.as_dataframe()

`polaris` returns rows in a non-deterministic order.
To make this notebook reproducible, we first sort by SMILES so that the dataframe is always in the same order.
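Sorting matters because a seeded draw such as `train_df.sample(frac=0.2, random_state=42)` selects rows by position, so the same seed can pick different molecules if the rows arrive in a different order. A minimal pure-Python sketch of the idea (the `seeded_pick` helper and the example SMILES are hypothetical, not part of `polaris` or `pandas`):

```python
import random


def seeded_pick(rows, k, seed=42):
    """Pick k rows by position using a fixed seed (hypothetical helper)."""
    rng = random.Random(seed)
    positions = rng.sample(range(len(rows)), k)
    return [rows[i] for i in positions]


smiles = ["CCO", "c1ccccc1", "CC(=O)O", "CCN", "CCCl"]
shuffled = ["CCN", "CCO", "CCCl", "c1ccccc1", "CC(=O)O"]  # same data, different order

# Without sorting, the same seed can select different molecules.
unsorted_a = seeded_pick(smiles, 2)
unsorted_b = seeded_pick(shuffled, 2)

# After sorting, the selection depends only on the data, not on the arrival order.
sorted_a = seeded_pick(sorted(smiles), 2)
sorted_b = seeded_pick(sorted(shuffled), 2)
```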

In [3]:
train_df = train_df.sort_values("smiles")
train_df.head(6)

|     | smiles | CLASS_EGFR |
|-----|--------|------------|
| 72  | C=CC(=O)Nc1cc2c(Nc3ccc(F)c(Cl)c3)ncnc2cc1OCCCN... | 1.0 |
| 171 | CC(=O)N1CCN(c2ccc(Nc3nccc(-c4sc(C(C)C)nc4-c4cc... | 1.0 |
| 265 | CC(=O)Nc1ccc(-c2cc3ncnc(SCC(=O)O)c3s2)cc1 | 0.0 |
| 251 | CC(=O)Nc1ccc(COc2ccc(Nc3ccnc4cc(-c5ccccn5)ccc3... | 0.0 |
| 131 | CC(=O)Nc1cn2nc(Oc3cccc(NC(=O)c4cccc(C(F)(F)F)c... | 0.0 |
| 67  | CC(=O)Nc1n[nH]c2ncc(-c3ccccc3)cc12 | 0.0 |


We will use 20% of this data for early stopping, which we can select like this:

In [4]:
val_df = train_df.sample(frac=0.2, random_state=42)
train_df = train_df.drop(val_df.index)
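The `sample` + `drop` pair above is a seeded holdout split: a fixed fraction of rows goes to validation and everything else stays in training, with no overlap. A pure-Python sketch of the same idea, assuming the validation size is `round(frac * n)` (the `holdout_split` helper is hypothetical and does not reproduce pandas' exact sampling):

```python
import random


def holdout_split(indices, frac=0.2, seed=42):
    """Split indices into (train, val) with val holding round(frac * n) items."""
    rng = random.Random(seed)
    n_val = round(frac * len(indices))
    val = set(rng.sample(indices, n_val))
    # everything not drawn for validation remains in training, in the original order
    train = [i for i in indices if i not in val]
    return train, sorted(val)


train_idx, val_idx = holdout_split(list(range(100)))
print(len(train_idx), len(val_idx))  # 80 20
```

Because the seed is fixed, calling `holdout_split` again on the same (sorted) input returns the same split.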

## Calculating Molecular Descriptors
Now, we need to calculate the molecular descriptors for each of these species.
We will save these to a cache file so that subsequent runs are faster!

`fastprop` uses [`mordredcommunity`](https://github.com/JacksonBurns/mordred-community) to calculate molecular descriptors. If you want to use a different set of descriptors (e.g. padel, osmordred, etc.), you can replace this code with that calculator.

In [5]:
import os
from rdkit.Chem import MolFromSmiles
from fastprop.descriptors import get_descriptors
from fastprop.defaults import ALL_2D
from fastprop.io import load_saved_descriptors

# calculate descriptors only for splits that do not already have a cache file
for name, df in (("train", train_df), ("val", val_df), ("test", test_df)):
    cache_file = f"cached_{name}_descriptors.csv"
    if not os.path.exists(cache_file):
        # get_descriptors writes its results to cache_file, so the return value is not needed here
        get_descriptors(
            cache_file,
            ALL_2D,
            list(map(MolFromSmiles, df["smiles"])),
        )
# read the cached descriptors back in for every split
train_descriptors = load_saved_descriptors("cached_train_descriptors.csv")
val_descriptors = load_saved_descriptors("cached_val_descriptors.csv")
test_descriptors = load_saved_descriptors("cached_test_descriptors.csv")
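The cell above follows a compute-once, load-many caching pattern: the expensive descriptor calculation runs only when the cache file is missing, and every run reads from disk. A generic sketch of that pattern with the standard `csv` module; the `cached_table` and `fake_descriptors` helpers are hypothetical and do not reproduce the file layout that `fastprop` writes:

```python
import csv
import os
import tempfile


def cached_table(path, compute):
    """Return numeric rows from `path`, computing and writing them first if the file is missing.

    `compute` is a hypothetical zero-argument function returning (header, rows).
    """
    if not os.path.exists(path):
        header, rows = compute()
        with open(path, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(header)
            writer.writerows(rows)
    with open(path, newline="") as f:
        reader = csv.reader(f)
        next(reader)  # skip the header row
        return [[float(x) for x in row] for row in reader]


calls = []


def fake_descriptors():
    calls.append(1)  # record each time the "expensive" step actually runs
    return ["desc_a", "desc_b"], [[1.0, 2.0], [3.0, 4.0]]


path = os.path.join(tempfile.mkdtemp(), "cached_demo.csv")
first = cached_table(path, fake_descriptors)   # computes and writes the file
second = cached_table(path, fake_descriptors)  # served from disk, no recomputation
```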

## Training

Now that we have descriptors, we can set up the code for training.

Molecular descriptors are prone to outliers as well as infinite and otherwise invalid values, so `fastprop` includes utilities to automatically impute, rescale, and Winsorize them.
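To illustrate the three steps, here is a minimal pure-Python sketch that works on a list of feature columns. The mean imputation, population variance, and `clip=3.0` bound are assumptions for illustration and may not match what `fastprop`'s `standard_scale` and `clamp_input` actually do. As in the cells below, the statistics are fitted on the training data only and then reused for new data.

```python
import math


def fit_scaler(columns):
    """Compute per-feature mean and variance, ignoring NaN and infinite entries."""
    means, variances = [], []
    for col in columns:
        finite = [x for x in col if math.isfinite(x)]
        mu = sum(finite) / len(finite) if finite else 0.0
        var = sum((x - mu) ** 2 for x in finite) / len(finite) if finite else 0.0
        means.append(mu)
        variances.append(var)
    return means, variances


def transform(columns, means, variances, clip=3.0):
    """Impute invalid values with the mean, standardize, then clamp to [-clip, clip]."""
    out = []
    for col, mu, var in zip(columns, means, variances):
        sd = math.sqrt(var) or 1.0  # guard against constant features (zero variance)
        scaled = []
        for x in col:
            z = (x - mu) / sd if math.isfinite(x) else 0.0  # imputing the mean gives z = 0
            scaled.append(max(-clip, min(clip, z)))  # winsorize extreme values
        out.append(scaled)
    return out


train_cols = [[1.0, 2.0, 3.0, float("nan")], [10.0, 10.0, 10.0, 10.0]]
test_cols = [[100.0, float("inf")], [10.0, 11.0]]
means, variances = fit_scaler(train_cols)             # statistics come from training data only
scaled_test = transform(test_cols, means, variances)  # reuse them for unseen data
```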

First, we cast everything to `torch.Tensor`:

In [6]:
import torch

train_descriptors = torch.tensor(train_descriptors, dtype=torch.float32)
val_descriptors = torch.tensor(val_descriptors, dtype=torch.float32)
test_descriptors = torch.tensor(test_descriptors, dtype=torch.float32)
# [:, None] adds a trailing axis so the targets are 2D, with shape (n_samples, 1)
train_targets = torch.tensor(train_df["CLASS_EGFR"].to_numpy(), dtype=torch.float32)[:, None]
val_targets = torch.tensor(val_df["CLASS_EGFR"].to_numpy(), dtype=torch.float32)[:, None]

Next, we rescale the features (and impute missing/invalid ones) and then prepare the dataloaders and model itself:

In [7]:
from fastprop.model import fastprop
from fastprop.data import fastpropDataLoader, standard_scale
from torch.utils.data import TensorDataset


train_descriptors, feature_means, feature_vars = standard_scale(train_descriptors)
val_descriptors = standard_scale(val_descriptors, feature_means, feature_vars)
# don't rescale the test_descriptors - fastprop will do this automatically during inference

train_dataloader = fastpropDataLoader(TensorDataset(train_descriptors, train_targets), shuffle=True, batch_size=16)
val_dataloader = fastpropDataLoader(TensorDataset(val_descriptors, val_targets), batch_size=1024)
test_dataloader = fastpropDataLoader(TensorDataset(test_descriptors), batch_size=1024)

model = fastprop(
    problem_type="binary",
    target_names=list(benchmark.target_cols),
    clamp_input=True,  # winsorization
    fnn_layers=2,
    hidden_size=1_800,
    feature_means=feature_means,
    feature_vars=feature_vars,
    learning_rate=0.00001,
)

The remaining cells are standard PyTorch Lightning training and inference:

In [8]:
from pathlib import Path

outdir = Path("demo_output")

In [9]:
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint


tensorboard_logger = TensorBoardLogger(
    outdir,
    name="tensorboard_logs",
    default_hp_metric=False,
)
callbacks = [
    EarlyStopping(
        monitor="validation_binary_auroc",
        mode="max",
        verbose=False,
        patience=5,
    ),
    ModelCheckpoint(
        monitor="validation_binary_auroc",
        save_top_k=1,
        mode="max",
        dirpath=outdir / "checkpoints",
    ),
]
trainer = Trainer(
    max_epochs=50,
    logger=tensorboard_logger,
    log_every_n_steps=1,
    enable_checkpointing=True,
    check_val_every_n_epoch=1,
    callbacks=callbacks,
)
trainer.fit(model, train_dataloader, val_dataloader)
ckpt_path = trainer.checkpoint_callback.best_model_path
print(f"Reloading best model from checkpoint file: {ckpt_path}")
model = model.__class__.load_from_checkpoint(ckpt_path)

Trainer will use only 1 of 8 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=8)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
/home/jwburns/.conda/envs/fastprop_dev/lib/python3.11/site-packages/lightning_fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/jwburns/.conda/envs/fastprop_dev/lib/python3.1 ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/jwburns/.conda/envs/fastprop_dev/lib/python3.11/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:654: Checkpoint directory /home/jwburns/fastprop/examples/demo_output/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISI

Epoch 31: 100%|██████████| 25/25 [00:00<00:00, 134.13it/s, v_num=9]        
Reloading best model from checkpoint file: /home/jwburns/fastprop/examples/demo_output/checkpoints/epoch=26-step=675.ckpt
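The `EarlyStopping` callback above (`monitor="validation_binary_auroc"`, `mode="max"`, `patience=5`) ends training once five consecutive validation checks fail to improve on the best score. A simplified sketch of that rule, assuming one validation check per epoch and no `min_delta` (the `early_stopping` helper is hypothetical, not Lightning's implementation):

```python
def early_stopping(scores, patience=5):
    """Return (best_epoch, last_epoch) for a metric that should be maximized.

    Training stops once `patience` consecutive epochs pass without a new best score.
    """
    best_score, best_epoch = float("-inf"), -1
    for epoch, score in enumerate(scores):
        if score > best_score:
            best_score, best_epoch = score, epoch
        elif epoch - best_epoch >= patience:
            return best_epoch, epoch
    # ran out of epochs before patience was exhausted
    return best_epoch, len(scores) - 1


scores = [0.5, 0.6, 0.7, 0.8, 0.79, 0.78, 0.8, 0.77, 0.76, 0.9]
print(early_stopping(scores))  # (3, 8): the late 0.9 is never reached
```

This agrees with the log above: the best checkpoint is `epoch=26`, and the progress bar ends at epoch 31, five epochs later.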


In [10]:
# trainer.predict returns one tensor per batch; concatenate them before flattening
predictions = torch.cat(trainer.predict(model, test_dataloader)).flatten().numpy()

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]


Predicting DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 20.67it/s]


## Results

We can again use the handy `polaris` library to look at the results of our predictions:
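For reference, a plain-Python sketch of several of the reported metrics, assuming hard labels come from the same 0.5 threshold used in the next cell and ROC AUC is computed from the raw probabilities. This is not the `polaris` implementation, and `pr_auc` and `cohen_kappa` are omitted for brevity; the toy labels and probabilities are made up.

```python
import math


def binary_metrics(y_true, y_prob, threshold=0.5):
    """Accuracy, F1, and MCC from thresholded probabilities; ROC AUC from the raw scores."""
    y_pred = [int(p > threshold) for p in y_prob]
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    tn = sum(1 for t, p in pairs if t == 0 and p == 0)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)
    accuracy = (tp + tn) / len(y_true)
    f1 = 2 * tp / (2 * tp + fp + fn) if tp else 0.0
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    mcc = (tp * tn - fp * fn) / denom if denom else 0.0
    # ROC AUC = probability that a random positive outranks a random negative (ties count 1/2)
    pos = [p for t, p in zip(y_true, y_prob) if t == 1]
    neg = [p for t, p in zip(y_true, y_prob) if t == 0]
    wins = sum(1.0 if a > b else 0.5 if a == b else 0.0 for a in pos for b in neg)
    roc_auc = wins / (len(pos) * len(neg))
    return {"accuracy": accuracy, "f1": f1, "mcc": mcc, "roc_auc": roc_auc}


m = binary_metrics([0, 0, 0, 1, 1], [0.1, 0.4, 0.6, 0.7, 0.3])
# accuracy 0.6, f1 0.5, mcc 1/6, roc_auc 2/3
```

The threshold-based scores depend on the 0.5 cutoff, while ROC AUC only depends on how the probabilities rank positives against negatives.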

In [11]:
# first argument: hard 0/1 labels (threshold 0.5); second: the predicted probabilities
results = benchmark.evaluate(predictions > 0.5, predictions)
results.name = "fastprop"
results.github_url = "https://github.com/JacksonBurns/fastprop/blob/main/examples/fastprop_polaris_classification_demo.ipynb"
results.paper_url = "https://github.com/JacksonBurns/fastprop/blob/main/paper/paper.pdf"
results.description = "fastprop-based FNN model"
results.tags = ["mordred", "mordredcommunity", "fastprop", "fnn"]
results.user_attributes = {"Framework": "fastprop"}
results

| test_set | target_label | metric | score |
|----------|--------------|--------|-------|
| test | CLASS_EGFR | accuracy | 0.9236111111111112 |
| test | CLASS_EGFR | pr_auc | 0.7750935064383808 |
| test | CLASS_EGFR | f1 | 0.5217391304347826 |
| test | CLASS_EGFR | cohen_kappa | 0.4870466321243523 |
| test | CLASS_EGFR | roc_auc | 0.9521484375 |
| test | CLASS_EGFR | mcc | 0.536591218301113 |

| field | value |
|-------|-------|
| benchmark_artifact_id | polaris/pkis2-egfr-wt-c-1 |
| benchmark_name | |
| benchmark_owner | |
| user_attributes | Framework: fastprop |


Looks pretty good!
Let's upload the results to the `polaris` website for everyone to see (the next cell is commented out because it will fail unless you are logged in to `polaris`):

In [None]:
# results.upload_to_hub(owner="jacksonburns", access="public")