# Training ML models 

## Prerequisites

Before starting this tutorial, you should have already worked through the tutorials on [Interfacing with databases and systems](https://asapdiscovery.readthedocs.io/en/latest/tutorials/index.html#interfacing-with-databases-and-systems) and [Docking and scoring](https://asapdiscovery.readthedocs.io/en/latest/tutorials/index.html#docking-and-scoring).

We will use publicly available test files as the starting point for this tutorial, but feel free to substitute with your own data.

```python
# First download the needed files
from asapdiscovery.data.testing.test_resources import fetch_test_file

docked_files = fetch_test_file(
    [
        "ml_testing/docked/AAR-POS-5507155c-1_Mpro-P0018_0A_0_bound_best.pdb",
        "ml_testing/docked/AAR-RCN-390aeb1f-1_Mpro-P3074_0A_0_bound_best.pdb",
        "ml_testing/docked/AAR-RCN-67438d21-1_Mpro-P3074_0A_0_bound_best.pdb",
        "ml_testing/docked/AAR-POS-d2a4d1df-27_Mpro-P0238_0A_0_bound_best.pdb",
        "ml_testing/docked/AAR-POS-8a4e0f60-7_Mpro-P0053_0A_0_bound_best.pdb",
        "ml_testing/docked/AAR-RCN-28a8122f-1_Mpro-P2005_0A_0_bound_best.pdb",
        "ml_testing/docked/AAR-RCN-37d0aa00-1_Mpro-P3074_0A_0_bound_best.pdb",
        "ml_testing/docked/AAR-RCN-748c104b-1_Mpro-P3074_0A_0_bound_best.pdb",
        "ml_testing/docked/AAR-RCN-521d1733-1_Mpro-P2005_0A_0_bound_best.pdb",
        "ml_testing/docked/AAR-RCN-845f9611-1_Mpro-P2005_0A_0_bound_best.pdb",
    ]
)
docked_dir = docked_files[0].parent
unfilt_csv_file = fetch_test_file("ml_testing/cdd_unfiltered.csv")
```
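
Each docked file name above encodes two identifiers: a compound ID (e.g. `AAR-POS-5507155c-1`) and an `Mpro-…` crystal structure label. Later in this tutorial, `DatasetConfig.from_str_files` uses regular expressions to pull these IDs out of the file names. The standalone sketch below shows the general idea. `XTAL_REGEX`, `CPD_REGEX`, and `parse_ids` are our own simplified, hypothetical versions. They are **not** the library's `MPRO_ID_REGEX` / `MOONSHOT_CDD_ID_REGEX` or its parsing code.

```python
import re
from pathlib import Path

# Simplified, illustrative patterns (assumptions, NOT the library's
# MPRO_ID_REGEX / MOONSHOT_CDD_ID_REGEX values)
XTAL_REGEX = re.compile(r"Mpro-P\d{4}_\d[A-Z]")
CPD_REGEX = re.compile(r"[A-Z]{3}-[A-Z]{3}-[0-9a-f]{8}-\d+")


def parse_ids(pdb_path):
    """Return (xtal_id, compound_id) parsed from a docked PDB file name."""
    name = Path(pdb_path).name
    xtal_match = XTAL_REGEX.search(name)
    cpd_match = CPD_REGEX.search(name)
    if xtal_match is None or cpd_match is None:
        raise ValueError(f"Could not parse IDs from {name}")
    return xtal_match.group(0), cpd_match.group(0)


print(parse_ids("AAR-POS-5507155c-1_Mpro-P0018_0A_0_bound_best.pdb"))
# ('Mpro-P0018_0A', 'AAR-POS-5507155c-1')
```

If your own files follow a different naming scheme, this is the part you would adapt.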

## Intro

In this guide, we will start with a CSV file downloaded from CDD and a directory of docked protein-ligand complex PDB files. These are the two outputs from the guides mentioned in the Prerequisites section, so be sure to complete those if you haven't already.

## Preparing the experimental data

Before using the data for training, we will filter and process the experimental data to make sure everything is in the correct format. We will use the default values for all column names, which come from the [COVID Moonshot project](https://www.science.org/doi/10.1126/science.abo7201). See the docstrings of the individual functions for how to tune these values for your use case.

```python
from asapdiscovery.data.util.utils import (
    cdd_to_schema,
    filter_molecules_dataframe,
    parse_fluorescence_data_cdd,
)
from pathlib import Path

import pandas

# If you're using your own data, replace unfilt_csv_file with the path to your CDD download
mol_df = pandas.read_csv(unfilt_csv_file)

"""
In this example, we will ultimately use this data to train both 2D and structure-based
models, so we will keep all achiral and enantiopure molecules, including any molecules
with semiquantitative fluorescence values.
"""
mol_df_filt = filter_molecules_dataframe(
    mol_df,
    retain_achiral=True,
    retain_enantiopure=True,
    retain_semiquantitative_data=True,
)

"""
In addition to being appropriately filtered, mol_df_filt now contains some identifying
columns with standardized names.

The parse_fluorescence_data_cdd function standardizes the fluorescence assay results,
adding a number of columns to the data frame. In addition to IC50 and pIC50 values, it
also calculates deltaG in kcal/mol and in kT units. If you know your fluorescence
assay conditions and they were consistent across all molecules, you can supply them
to the cp_values arg as a tuple of (substrate_concentration, Km), which the function
will use in the Cheng-Prusoff equation to calculate the deltaG values. If you don't
have these values, the function will use a less accurate approximation. We will omit
cp_values in this example for simplicity.

More details on the columns that the function expects in its input, and on the columns
it adds to the output, can be found in the function's docstring.
"""
mol_df_filt_processed = parse_fluorescence_data_cdd(mol_df_filt)

# Save the processed data
mol_df_filt_processed.to_csv("cdd_filtered_processed.csv", index=False)

"""
The last step is to convert the processed data into the format that the ML pipeline
expects. cdd_to_schema does this, taking the CSV file we just saved as input and
writing a JSON file that we will load later.
"""
_ = cdd_to_schema(
    cdd_csv="cdd_filtered_processed.csv", out_json="cdd_filtered_processed.json"
)
```

```text
Replacing unicode character with - in UNK-CYC-68f84b31−70
Wrote cdd_filtered_processed.json
```
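
The free-energy columns added above follow from standard relationships. These are pIC50 = -log10(IC50 / 1 M), the Cheng-Prusoff equation for a competitive inhibitor, Ki = IC50 / (1 + [S]/Km), and ΔG = kT ln(Ki / 1 M). The sketch below illustrates these textbook conversions and is not `parse_fluorescence_data_cdd`'s actual code. Three things in it are assumptions. The 298.15 K temperature is assumed. The no-`cp_values` fallback (treating IC50 as Ki) is only a crude stand-in; the library's own approximation may differ. The `ic50_to_affinities` helper is hypothetical.

```python
import math

R_KCAL = 1.98720425864083e-3  # gas constant in kcal/(mol*K)
TEMPERATURE_K = 298.15  # assumed assay temperature
KT_KCAL = R_KCAL * TEMPERATURE_K  # ~0.5925 kcal/mol


def ic50_to_affinities(ic50_molar, cp_values=None):
    """Convert an IC50 (in M) to pIC50 and binding free energies.

    cp_values: optional (substrate_concentration, Km) tuple, both in the same units.
    """
    pic50 = -math.log10(ic50_molar)
    if cp_values is None:
        # Crude stand-in (assumption): treat the IC50 itself as Ki
        ki = ic50_molar
    else:
        substrate_conc, km = cp_values
        # Cheng-Prusoff equation for a competitive inhibitor
        ki = ic50_molar / (1.0 + substrate_conc / km)
    dg_kt = math.log(ki)  # deltaG / kT = ln(Ki / 1 M)
    return {"pIC50": pic50, "dG_kT": dg_kt, "dG_kcal_mol": dg_kt * KT_KCAL}


print(ic50_to_affinities(1e-6))
print(ic50_to_affinities(1e-6, cp_values=(0.5, 0.5)))  # [S] = Km halves Ki
```

Note how supplying the assay conditions makes Ki smaller than the raw IC50 and ΔG correspondingly more negative. This is why ignoring the Cheng-Prusoff correction gives a less accurate estimate.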


## Building the ML `Trainer` object

Training in the `asapdiscovery-ml` repo is controlled by the `asapdiscovery.ml.trainer.Trainer` class, which is itself composed of several `Pydantic` config schema objects. These configs are defined in `asapdiscovery.ml.config`. The exception is the model config, which is defined in our ML backend [`mtenn`](https://github.com/choderalab/mtenn/) (see the [`mtenn.config`](https://mtenn.readthedocs.io/en/latest/_autosummary/mtenn.config.html#module-mtenn.config) docs for more information).
The required configs are:

* `OptimizerConfig`: Config describing the optimizer to use in training
* `ModelConfigBase`: Config describing the model to train (defined in `mtenn.config`)
* `EarlyStoppingConfig`: Config describing the early stopping check to use
* `DatasetConfig`: Config describing the dataset object to train on
* `DatasetSplitterConfig`: Config describing how to split the dataset into train, val, and test splits
* `LossFunctionConfig`: Config describing the loss function for training
* `DataAugConfig`: Config describing data augmentations to be applied to each pose
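
As an example of what a `DatasetSplitterConfig` describes: the training below uses a temporal 80:10:10 train:val:test split. In general, a temporal split puts the earliest compounds in the training set and the most recent ones in the test set, which mimics predicting on "future" chemistry. The following is a minimal sketch of that idea, not asapdiscovery's implementation. The `(compound_id, date)` record format, the truncating rounding, and the `temporal_split` helper are all hypothetical.

```python
from datetime import date


def temporal_split(records, train_frac=0.8, val_frac=0.1):
    """Split (compound_id, date) records chronologically into train/val/test.

    The earliest compounds go to train, the next ones to val, and the most
    recent ones to test. Fractions are truncated to whole records (assumption).
    """
    ordered = sorted(records, key=lambda rec: rec[1])
    n_train = int(len(ordered) * train_frac)
    n_val = int(len(ordered) * val_frac)
    return (
        ordered[:n_train],
        ordered[n_train : n_train + n_val],
        ordered[n_train + n_val :],
    )


# Ten hypothetical compounds, listed in reverse chronological order
records = [(f"CPD-{i}", date(2021, 1, 10 - i)) for i in range(10)]
train, val, test = temporal_split(records)
print(len(train), len(val), len(test))  # 8 1 1
print(max(d for _, d in train) <= min(d for _, d in test))  # True
```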

Note that below we will define each of these configs separately, but this can also be done automatically by passing a `dict` defining each config to the `Trainer` instead of the actual config object.
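
Passing a `dict` in place of a config object relies on the kind of coercion Pydantic performs when a dict is supplied for a field annotated with a model type: the dict is validated into an instance of that model. The stdlib-only sketch below mimics that behavior with dataclasses so that it runs without Pydantic. `ToyOptimizerConfig`, `ToyTrainer`, and their fields are hypothetical. They are not the real `OptimizerConfig` or `Trainer` schemas.

```python
from dataclasses import dataclass


@dataclass
class ToyOptimizerConfig:
    # Hypothetical fields, not the real OptimizerConfig schema
    optimizer_type: str = "adam"
    lr: float = 1e-4


@dataclass
class ToyTrainer:
    optimizer_config: ToyOptimizerConfig

    def __post_init__(self):
        # Accept either a config object or a plain dict describing one,
        # mimicking how Pydantic validates a dict into the annotated type
        if isinstance(self.optimizer_config, dict):
            self.optimizer_config = ToyOptimizerConfig(**self.optimizer_config)


t1 = ToyTrainer(optimizer_config=ToyOptimizerConfig(lr=1e-3))
t2 = ToyTrainer(optimizer_config={"lr": 1e-3})
print(t1.optimizer_config == t2.optimizer_config)  # True
```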

In the example below, we'll build a separate `Trainer`, each with its own model config, for each of the model types we're demoing:

* Graph attention ([GAT](https://arxiv.org/abs/1710.10903)): a topology-only, ligand-only GNN
* [SchNet](https://arxiv.org/abs/1706.08566): an E(3)-invariant, structure-based model
* [e3nn](https://arxiv.org/abs/2207.09453): an E(3)-equivariant, structure-based model
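
The difference between E(3)-invariant and E(3)-equivariant inputs can be seen on a toy example. Interatomic distances (the kind of geometric input SchNet is built on) do not change when a structure is rotated and translated. A displacement vector between two atoms instead rotates along with the structure, which is the kind of directional quantity that e3nn's higher-order (l ≥ 1) irreps can carry. The sketch below only checks these symmetry properties for a rotation about the z axis plus a translation (reflections are not covered). It says nothing about how either network computes its features, and the coordinates and helper names are made up.

```python
import math


def rotate_z(point, theta):
    """Rotate a 3D point by angle theta (radians) about the z axis."""
    x, y, z = point
    c, s = math.cos(theta), math.sin(theta)
    return (c * x - s * y, s * x + c * y, z)


def rigid_transform(points, theta, shift):
    """Rotate every point about z, then translate it by `shift`."""
    return [
        tuple(coord + delta for coord, delta in zip(rotate_z(p, theta), shift))
        for p in points
    ]


def pairwise_distances(points):
    """Invariant features: all interatomic distances."""
    return [math.dist(p, q) for i, p in enumerate(points) for q in points[i + 1 :]]


def displacement(points, i, j):
    """Equivariant feature: the vector pointing from atom i to atom j."""
    return tuple(b - a for a, b in zip(points[i], points[j]))


# Toy 4-atom "molecule" with arbitrary coordinates
atoms = [(0.0, 0.0, 0.0), (1.5, 0.0, 0.0), (1.5, 1.2, 0.3), (-0.4, 0.9, 1.1)]
theta, shift = 0.7, (3.0, -2.0, 5.0)
moved = rigid_transform(atoms, theta, shift)

# Distances are unchanged by the rigid motion (invariance)
distances_preserved = all(
    math.isclose(d0, d1)
    for d0, d1 in zip(pairwise_distances(atoms), pairwise_distances(moved))
)
# The displacement rotates with the input and the translation cancels (equivariance)
expected = rotate_z(displacement(atoms, 0, 2), theta)
vector_rotated = all(
    math.isclose(a, b, abs_tol=1e-9)
    for a, b in zip(displacement(moved, 0, 2), expected)
)
print(distances_preserved, vector_rotated)  # True True
```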

Because we use a config-based setup, we can share configs between the different `Trainer`s as the configs will be used to build the actual Python objects rather than doing the work themselves. The only exception to this sharing is, as previously mentioned, the model configs, which are model-specific and will need to be different for each `Trainer`.
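
The reason sharing is safe is that a config only *describes* an object, and a fresh object is built from that description each time. In this tutorial, building happens when `Trainer.initialize()` is called. The sketch below shows that pattern in isolation. `ToyLossConfig`, its `build` method, and `ToyLoss` are hypothetical stand-ins, not asapdiscovery classes.

```python
from dataclasses import dataclass


class ToyLoss:
    """Stand-in for a built Python object that carries mutable state."""

    def __init__(self, loss_type):
        self.loss_type = loss_type
        self.history = []  # per-object state, e.g. recorded loss values


@dataclass(frozen=True)
class ToyLossConfig:
    """Hypothetical config: it only describes the object, it doesn't do the work."""

    loss_type: str = "mse"

    def build(self):
        # Every call constructs a brand-new object from the description
        return ToyLoss(self.loss_type)


shared_config = ToyLossConfig(loss_type="mse_step")
loss_for_gat = shared_config.build()
loss_for_schnet = shared_config.build()

loss_for_gat.history.append(7.2)
print(loss_for_schnet.history)  # []
print(loss_for_gat is loss_for_schnet)  # False
```

Making the config `frozen` is one way to guarantee that a shared description can't be mutated by one consumer behind another's back.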

```python
from asapdiscovery.data.util.utils import MOONSHOT_CDD_ID_REGEX, MPRO_ID_REGEX
from asapdiscovery.ml.config import (
    DataAugConfig,
    DatasetConfig,
    DatasetSplitterConfig,
    EarlyStoppingConfig,
    LossFunctionConfig,
    OptimizerConfig,
)
from asapdiscovery.ml.trainer import Trainer
from mtenn.config import E3NNModelConfig, GATModelConfig, SchNetModelConfig

# We will use the Adam optimizer (default for OptimizerConfig) as well as all the
#  default parameters, so no need to pass anything here
optimizer_config = OptimizerConfig()

"""
We will make 3 different model configs, one for each architecture. In this example we
will build deliberately small models so that training can run in our docs runner, but
in a real training run you should of course perform your own hyperparameter
optimization. Any hyperparameter that is not specified will use the default of the
underlying architecture.

For the SchNet and e3nn models we will use a DeltaStrategy (default) and PIC50Readout
(see the mtenn docs for more information).
"""
gat_model_config = GATModelConfig(num_layers=1)
schnet_model_config = SchNetModelConfig(pred_readout="pic50", num_interactions=1)
# We recommend using Irreps of at least l=1, but will stick with l=0 for this example
e3nn_model_config = E3NNModelConfig(pred_readout="pic50", irreps_hidden={"0": 5})

# We will need these configs to rebuild the models for inference, so we'll serialize
#  them to JSON files. This isn't strictly necessary, as the same information is
#  stored in the Trainer JSON files that we'll save later, but it saves us from
#  having to load that whole file
Path("gat_model_config.json").write_text(gat_model_config.json())
Path("schnet_model_config.json").write_text(schnet_model_config.json())
Path("e3nn_model_config.json").write_text(e3nn_model_config.json())

# We will use the default of no early stopping, but this can be configured as desired.
#  See the docs for asapdiscovery.ml.es and EarlyStoppingConfig for more details
es_config = EarlyStoppingConfig()

"""
The DatasetConfig requires a bit more explanation than the other configs, as it
involves some data processing. Constructing a DatasetConfig object directly
requires passing a list of asapdiscovery.data.schema.ligand.Ligand (for 2D models)
or a list of asapdiscovery.data.schema.complex.Complex (for structure-based models).
To save you from parsing the data from files yourself, we offer the convenience
functions DatasetConfig.from_exp_file (for 2D models) and DatasetConfig.from_str_files
(for structure-based models). We will also need to create 3 of these configs: a 2D one
for the GAT model and one for each of the structure-based models, as the models expect
slightly different inputs (this is all handled within the constructor methods).

The from_exp_file method requires only an experimental data file, which is the JSON
file we generated in the previous step. The from_str_files method only requires this
file if the dataset is being used for training, as this is where the method pulls data
labels from. We will also need to pass in a glob or directory containing complex
structure PDB files, as well as regular expressions defining what the crystal structure
and compound IDs look like. We provide these regexes for the Moonshot data in
asapdiscovery.data.util.utils as MPRO_ID_REGEX and MOONSHOT_CDD_ID_REGEX respectively.
If your structure files are named differently, you will need to modify these
expressions to match your data.

One other important consideration with DatasetConfig is caching. The object that this
config ultimately builds is one of the classes in asapdiscovery.ml.dataset. These
objects can take a while to create from files (especially the structure-based ones), so
we offer the ability to cache the built object using pickle. Once the object has been
cached at the Path given by cache_file, later calls to the build method will load it
directly instead of regenerating it.
"""
gat_ds_config = DatasetConfig.from_exp_file(
    Path("cdd_filtered_processed.json"), cache_file=Path("gat_ds_config.pkl")
)
# The complex structure files are in docked_dir (downloaded at the start of this
#  tutorial) and follow the Moonshot naming scheme for crystal structures and compound IDs
schnet_ds_config = DatasetConfig.from_str_files(
    structures=f"{docked_dir}/*.pdb",
    xtal_regex=MPRO_ID_REGEX,
    cpd_regex=MOONSHOT_CDD_ID_REGEX,
    for_training=True,
    exp_file=Path("cdd_filtered_processed.json"),
    cache_file=Path("schnet_ds_config.pkl"),
)
e3nn_ds_config = DatasetConfig.from_str_files(
    structures=f"{docked_dir}/*.pdb",
    xtal_regex=MPRO_ID_REGEX,
    cpd_regex=MOONSHOT_CDD_ID_REGEX,
    for_training=True,
    exp_file=Path("cdd_filtered_processed.json"),
    cache_file=Path("e3nn_ds_config.pkl"),
    for_e3nn=True,
)

# Split our molecules temporally, using an 80:10:10 split for train:val:test (default)
ds_splitter_config = DatasetSplitterConfig(split_type="temporal")

# Use a semi-quantitative MSE loss function
#  (see asapdiscovery.ml.loss docs for more information)
loss_config = LossFunctionConfig(loss_type="mse_step")

"""
Finally, we are ready to build our Trainers and start training. We will set a couple
of other options here, including logging to Weights & Biases (W&B). This is optional
and can be turned off by simply not setting use_wandb=True, but we find W&B to be a
useful way to track experiments. Note that you will first need to set up W&B on your
machine (see their docs for how to get started). Apart from the configs, the only
option that must be set is output_dir.

We are training here for 1 epoch (for docs purposes), with a mini-batch size of 25.
The training will be done on the CPU (also for docs purposes), and will be saved to
./<model>_training/. We will log each training run to a W&B project named tutorial,
with each run named after its model.
"""
t_gat = Trainer(
    optimizer_config=optimizer_config,
    model_config=gat_model_config,
    es_config=es_config,
    ds_config=gat_ds_config,
    ds_splitter_config=ds_splitter_config,
    loss_config=loss_config,
    n_epochs=1,
    batch_size=25,
    device="cpu",
    output_dir="./gat_training/",
    use_wandb=True,
    wandb_project="tutorial",
    wandb_name="gat",
)
t_schnet = Trainer(
    optimizer_config=optimizer_config,
    model_config=schnet_model_config,
    es_config=es_config,
    ds_config=schnet_ds_config,
    ds_splitter_config=ds_splitter_config,
    loss_config=loss_config,
    n_epochs=1,
    batch_size=25,
    device="cpu",
    output_dir="./schnet_training/",
    use_wandb=True,
    wandb_project="tutorial",
    wandb_name="schnet",
)
t_e3nn = Trainer(
    optimizer_config=optimizer_config,
    model_config=e3nn_model_config,
    es_config=es_config,
    ds_config=e3nn_ds_config,
    ds_splitter_config=ds_splitter_config,
    loss_config=loss_config,
    n_epochs=1,
    batch_size=25,
    device="cpu",
    output_dir="./e3nn_training/",
    use_wandb=True,
    wandb_project="tutorial",
    wandb_name="e3nn",
)

# If desired we can save each of these Trainers as a JSON file, which will let us skip
#  all of the above steps next time we want to re-run this training or something
#  similar to it
Path("gat_trainer.json").write_text(t_gat.json())
Path("schnet_trainer.json").write_text(t_schnet.json())
Path("e3nn_trainer.json").write_text(t_e3nn.json())

# Finally, we initialize each Trainer and start training. The initialization step
#  handles building all the underlying Python objects, as well as syncing with W&B.
t_gat.initialize()
t_gat.train()
t_schnet.initialize()
t_schnet.train()
t_e3nn.initialize()
t_e3nn.train()
```

```text
Filtering 0 structures that we don't have experimental data for.
10 10
Filtering 0 structures that we don't have experimental data for.
10 10
loading from cache
ds lengths 991 127 93
Epoch 0/1
```

W&B run summary (GAT):

| Metric | Value |
|---|---|
| epoch | 0.0 |
| epoch_time | 3.7519 |
| test_loss | 1.42651 |
| train_loss | 7.22673 |
| val_loss | 2.74728 |

```text
loading from cache
ds lengths 2 2 6
Epoch 0/1
```

W&B run summary (SchNet):

| Metric | Value |
|---|---|
| epoch | 0.0 |
| epoch_time | 2.08681 |
| test_loss | 0.0 |
| train_loss | 11.28308 |
| val_loss | 12.07008 |

```text
loading from cache
ds lengths 2 2 6
Epoch 0/1
```

W&B run summary (e3nn):

| Metric | Value |
|---|---|
| epoch | 0.0 |
| epoch_time | 5.29716 |
| test_loss | 0.0 |
| train_loss | 9.31578 |
| val_loss | 10.75333 |
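
All three runs above were trained with `loss_type="mse_step"`. This matters because `retain_semiquantitative_data=True` kept molecules whose fluorescence values are only semiquantitative. Such values are typically reported as bounds, for example "below the assay's measurable range", rather than exact numbers. The sketch below shows the general idea of a semi-quantitative ("step") MSE as we understand it: an exact label gets the usual squared error, while a bounded label is not penalized as long as the prediction lies on the correct side of the bound. The `"="`/`"<"`/`">"` qualifier format and the `step_mse` helper are hypothetical. See the `asapdiscovery.ml.loss` docs for what `mse_step` actually computes.

```python
def step_mse(preds, targets, qualifiers):
    """Semi-quantitative MSE over pIC50-like values (illustrative sketch).

    qualifiers[i] is "=" for an in-range measurement, "<" if the true value is
    only known to lie below targets[i], and ">" if it lies above it.
    """
    if not (len(preds) == len(targets) == len(qualifiers)) or not preds:
        raise ValueError("need equal-length, non-empty inputs")
    total = 0.0
    for pred, target, qual in zip(preds, targets, qualifiers):
        if (qual == "<" and pred <= target) or (qual == ">" and pred >= target):
            continue  # prediction already agrees with the bound: no penalty
        total += (pred - target) ** 2  # otherwise ordinary squared error
    return total / len(preds)


# One exact label plus two bounded labels whose bounds are respected
print(step_mse([6.0, 4.0, 7.5], [5.5, 4.5, 7.0], ["=", "<", ">"]))
# A bounded label whose bound is violated is penalized like a normal point
print(step_mse([5.0], [4.5], ["<"]))  # 0.25
```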
