# Training Classification

In [None]:
# Install chemprop from GitHub if running in Google Colab.
import os

# The COLAB_RELEASE_TAG environment variable is used as the Colab indicator:
# nothing below runs unless it is set to a non-empty value.
if os.getenv("COLAB_RELEASE_TAG"):
    try:
        # Skip the installation when chemprop is already importable.
        import chemprop
    except ImportError:
        # Clone the repository, install it from source with pip, then move
        # into the `examples` directory of the clone.
        !git clone https://github.com/chemprop/chemprop.git
        %cd chemprop
        !pip install .
        %cd examples

# Import packages

In [None]:
import pandas as pd
from pathlib import Path

from lightning import pytorch as pl

from chemprop import data, featurizers, models, nn

# Change data inputs here

In [None]:
# Path to the input .csv file.
input_path = "/content/s_aureus_43300_MASSA_split_final.csv"

# Number of workers for the dataloader; 0 means the main process loads the data.
num_workers = 0

# Name of the column containing the SMILES strings.
smiles_column = "smiles"

# Column(s) holding the activity class (either 0 or 1).
target_column = ["label"]

## Load data

In [None]:
# Read the full input table from `input_path`, then show it as the cell output.
df_input = pd.read_csv(filepath_or_buffer=input_path)
df_input

## Get SMILES and targets

The data were split into training, test, and validation sets in advance using the MASSA Algorithm: https://github.com/gcverissimo/MASSA_Algorithm

In [None]:
# Count how many rows carry each label in the "split" column.
df_input.loc[:, "split"].value_counts()

In [None]:
# The split labels are Portuguese: "Treino" (training), "Teste" (test),
# "Validacao" (validation). Select the rows of each subset by label.
df_train = df_input.loc[df_input["split"].eq("Treino")]
df_test = df_input.loc[df_input["split"].eq("Teste")]
df_val = df_input.loc[df_input["split"].eq("Validacao")]

In [None]:
# Training subset: SMILES strings (one column) and targets (list of columns,
# hence a 2-D array). The last line previews the first five of each.
smis_train = df_train[smiles_column].values
ys_train = df_train[target_column].values
(smis_train[:5], ys_train[:5])

In [None]:
# Test subset: SMILES strings (one column) and targets (list of columns,
# hence a 2-D array). The last line previews the first five of each.
smis_test = df_test[smiles_column].values
ys_test = df_test[target_column].values
(smis_test[:5], ys_test[:5])

In [None]:
# Validation subset: SMILES strings (one column) and targets (list of columns,
# hence a 2-D array). The last line previews the first five of each.
smis_val = df_val[smiles_column].values
ys_val = df_val[target_column].values
(smis_val[:5], ys_val[:5])

## Get molecule datapoints

In [None]:
# Build one MoleculeDatapoint per (SMILES, target row) pair for each subset.
# `map` walks the two arrays in lockstep, exactly like zipping them.
data_train = list(map(data.MoleculeDatapoint.from_smi, smis_train, ys_train))
data_test = list(map(data.MoleculeDatapoint.from_smi, smis_test, ys_test))
data_val = list(map(data.MoleculeDatapoint.from_smi, smis_val, ys_val))

## Get MoleculeDataset

In [None]:
# A single featurizer instance is shared by the three datasets.
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()

# Wrap each list of datapoints in a MoleculeDataset.
train_dset = data.MoleculeDataset(data=data_train, featurizer=featurizer)
test_dset = data.MoleculeDataset(data=data_test, featurizer=featurizer)
val_dset = data.MoleculeDataset(data=data_val, featurizer=featurizer)

## Get DataLoader

In [None]:
# The training loader keeps build_dataloader's default shuffling behaviour;
# the test and validation loaders explicitly disable shuffling.
train_loader = data.build_dataloader(
    dataset=train_dset,
    num_workers=num_workers,
)
test_loader = data.build_dataloader(
    dataset=test_dset,
    num_workers=num_workers,
    shuffle=False,
)
val_loader = data.build_dataloader(
    dataset=val_dset,
    num_workers=num_workers,
    shuffle=False,
)

# Change Message-Passing Neural Network (MPNN) inputs here

## Message Passing
A `MessagePassing` layer learns node-level hidden representations by passing messages over the molecular graph.

Options are `mp = nn.BondMessagePassing()` or `mp = nn.AtomMessagePassing()`

In [None]:
# Atom-based message passing is selected here; nn.BondMessagePassing() is the alternative.
mp = nn.AtomMessagePassing()

## Aggregation
An `Aggregation` is responsible for constructing a graph-level representation from the set of node-level representations after message passing.

Available options can be found in `nn.agg.AggregationRegistry`, including
- `agg = nn.MeanAggregation()`
- `agg = nn.SumAggregation()`
- `agg = nn.NormAggregation()`

In [None]:
# Print the aggregation registry to see the available options.
print(nn.agg.AggregationRegistry)

In [None]:
# Use mean aggregation to build the graph-level representation.
agg = nn.MeanAggregation()

## Feed-Forward Network (FFN)

An `FFN` takes the aggregated representations and makes target predictions.

Available options can be found in `nn.PredictorRegistry`.

For regression:
- `ffn = nn.RegressionFFN()`
- `ffn = nn.MveFFN()`
- `ffn = nn.EvidentialFFN()`

For classification:
- `ffn = nn.BinaryClassificationFFN()`
- `ffn = nn.BinaryDirichletFFN()`
- `ffn = nn.MulticlassClassificationFFN()`
- `ffn = nn.MulticlassDirichletFFN()`

For spectral:
- `ffn = nn.SpectralFFN()` # will be available in future version

In [None]:
# Print the predictor registry to see the available FFN options.
print(nn.PredictorRegistry)

In [None]:
# Binary classification head with a single task (the one target column).
ffn = nn.BinaryClassificationFFN(n_tasks=1)

## Batch Norm
A `Batch Norm` normalizes the outputs of the aggregation by re-centering and re-scaling.

Set `batch_norm` below to choose whether batch norm is used.

In [None]:
# Batch norm is turned off; this flag is passed to models.MPNN when the model is built.
batch_norm = False

## Metrics
`Metrics` are the ways to evaluate the performance of model predictions.

Available options can be found in `nn.metrics.MetricRegistry`:

In [None]:
# Print the metric registry to see the available options.
print(nn.metrics.MetricRegistry)

In [None]:
# Metrics handed to the MPNN: Matthews correlation coefficient and AUROC.
# NOTE(review): chemprop presumably falls back to AUROC for binary
# classification when no metrics are given — confirm against the chemprop docs.
metric_list = [
    nn.metrics.BinaryMCCMetric(),
    nn.metrics.BinaryAUROC(),
]

## Construct MPNN

In [None]:
# Assemble the model from the message passing, aggregation, predictor,
# batch-norm flag and metrics selected in the preceding cells.
mpnn = models.MPNN(
    message_passing=mp,
    agg=agg,
    predictor=ffn,
    batch_norm=batch_norm,
    metrics=metric_list,
)

# Show the model as the cell output.
mpnn

# Set up trainer

In [None]:
trainer = pl.Trainer(
    # Hardware: a single CPU device.
    accelerator="cpu",
    devices=1,
    # Number of epochs to train for.
    max_epochs=20,
    # No logger is attached.
    logger=False,
    # Use `True` if you want to save model checkpoints. The checkpoints will be
    # saved in the `checkpoints` folder.
    enable_checkpointing=True,
    enable_progress_bar=True,
)

# Start training

In [None]:
# NOTE(review): the "Teste" subset (test_loader) serves as the validation
# loader during training, while the "Validacao" subset is kept for the final
# evaluation — confirm this assignment of the two subsets is intended.
trainer.fit(
    model=mpnn,
    train_dataloaders=train_loader,
    val_dataloaders=test_loader,
)

# Test results

In [None]:
# Final evaluation runs on the "Validacao" subset (val_loader).
# NOTE(review): the "Teste" subset was used as the validation loader in
# trainer.fit — confirm this assignment of the two subsets is intended.
results = trainer.test(model=mpnn, dataloaders=val_loader)