# Train a Simple Regression Model

The process of training a machine learning (ML) model can be thought of as fitting a
highly parameterized function to map inputs to outputs. An ML algorithm needs to learn from
numerous examples of input and output pairs to accurately map an input to an output,
i.e., to make a prediction. After training, the result is referred to as a trained ML model or an artifact.

This tutorial will detail how we can use [AMPL](https://github.com/ATOMScience-org/AMPL) tools to train a regression model to predict 
how much a compound will inhibit the KCNA3 protein as measured by pIC50. 
We will train a random forest model using the following inputs:

1. The curated KCNA3 dataset from **tutorial 2**.
2. The split file generated in **tutorial 3**.
3. [RDKit](https://github.com/rdkit/rdkit) features calculated by the [AMPL](https://github.com/ATOMScience-org/AMPL) pipeline.

We will explain the use of descriptors, how to evaluate model performance,
and where the model is saved as a .tar.gz file.

> **Note** *Training a random forest model and splitting the dataset are non-deterministic. 
You will obtain a slightly different random forest model by running this tutorial each time.*

## Model Training (using already split data)

We will use the curated dataset created in **tutorial 2** and the split file 
created in **tutorial 3** to build a JSON-style parameter dictionary for training. We set `"previously_split": "True"`
and set `split_uuid` to the UUID of the existing split. 
Here, we will use `"split_uuid": "8daa5687-c2ee-45e4-b385-36164246c419"`, 
the UUID of the scaffold split created in **tutorial 3**.
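
The same settings can also be kept in a JSON file on disk. The sketch below uses only the Python standard library to write a subset of the parameters to a file and read them back. The file name `rf_config.json` and the helper names are hypothetical choices for illustration, not something AMPL requires.

```python
import json
import tempfile
from pathlib import Path

# Subset of the training parameters described above (values are strings, as in this tutorial).
split_params = {
    "dataset_key": "dataset/curated_kcna3_ic50.csv",
    "previously_split": "True",
    "split_uuid": "8daa5687-c2ee-45e4-b385-36164246c419",
}


def save_config(config: dict, path: Path) -> None:
    """Write a parameter dictionary to a JSON file."""
    path.write_text(json.dumps(config, indent=4))


def load_config(path: Path) -> dict:
    """Read a parameter dictionary back from a JSON file."""
    return json.loads(path.read_text())


# Hypothetical file name; a temporary directory keeps the example free of side effects.
config_path = Path(tempfile.mkdtemp()) / "rf_config.json"
save_config(split_params, config_path)
loaded = load_config(config_path)
print(loaded["split_uuid"])
```

Keeping the configuration in a file makes it easier to record exactly which split a model was trained on.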

[AMPL](https://github.com/ATOMScience-org/AMPL) provides an extensive featurization module that can generate a 
variety of molecular feature types, given SMILES strings as input. 
For demonstration purposes, we choose to use RDKit features in this tutorial.

If a featurized dataset has not already been saved for `curated_kcna3_ic50`, 
[AMPL](https://github.com/ATOMScience-org/AMPL) will create one and save it as a CSV file in a folder called `scaled_descriptors`, 
e.g. `dataset/scaled_descriptors/curated_kcna3_ic50_with_rdkit_raw_descriptors.csv`.
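
A quick way to see whether featurization will be triggered is to look for that file before training. The helper below is hypothetical (it is not part of AMPL) and assumes that the naming pattern in the example path above, `<dataset name>_with_<descriptor_type>_descriptors.csv`, also holds for other datasets.

```python
from pathlib import Path


def expected_descriptor_file(dataset_file: str, descriptor_type: str) -> Path:
    """Build the path where the featurized CSV is expected to be saved.

    Assumption: the file lives in a `scaled_descriptors` folder next to the
    dataset and follows the naming pattern of the example in this tutorial.
    """
    dataset_path = Path(dataset_file)
    file_name = f"{dataset_path.stem}_with_{descriptor_type}_descriptors.csv"
    return dataset_path.parent / "scaled_descriptors" / file_name


featurized = expected_descriptor_file("dataset/curated_kcna3_ic50.csv", "rdkit_raw")
status = "already exists" if featurized.exists() else "not found; it would be created during training"
print(f"{featurized}: {status}")
```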

In [1]:
# importing relevant libraries
import pandas as pd
from atomsci.ddm.pipeline import model_pipeline as mp
from atomsci.ddm.pipeline import parameter_parser as parse

# Set up
dataset_file = 'dataset/curated_kcna3_ic50.csv'
odir='dataset'

response_col = "avg_pIC50"
compound_id = "compound_id"
smiles_col = "base_rdkit_smiles"
split_uuid = "8daa5687-c2ee-45e4-b385-36164246c419"

params = {
    "verbose": "True",
    "system": "LC",
    "datastore": "False",
    "save_results": "False",
    "prediction_type": "regression",
    "dataset_key": dataset_file,
    "id_col": compound_id,
    "smiles_col": smiles_col,
    "response_cols": response_col,
    # reuse the scaffold split created in tutorial 3
    "previously_split": "True",
    "split_uuid": split_uuid,
    "split_only": "False",
    # RDKit descriptors computed by the AMPL pipeline
    "featurizer": "computed_descriptors",
    "descriptor_type": "rdkit_raw",
    # random forest model
    "model_type": "RF",
    "transformers": "True",
    "rerun": "False",
    "result_dir": odir,
}

ampl_param = parse.wrapper(params)
pl = mp.ModelPipeline(ampl_param)
pl.train_model()

Skipped loading modules with pytorch-geometric dependency, missing a dependency. No module named 'torch_geometric'
Skipped loading modules with pytorch-geometric dependency, missing a dependency. cannot import name 'DMPNN' from 'deepchem.models.torch_models' (/usr/WS2/kmelough/ampl16_env/lib/python3.9/site-packages/deepchem/models/torch_models/__init__.py)
Skipped loading modules with pytorch-lightning dependency, missing a dependency. No module named 'pytorch_lightning'
Skipped loading some Jax models, missing a dependency. jax requires jaxlib to be installed. See https://github.com/google/jax#installation for installation instructions.
  X = np.nan_to_num((X - self.X_means) * X_weight / self.X_stds)


## Model Training (Split data and train)

It is possible to split the data and train a model in one step. 
Here, we set `"previously_split": "False"` and do not specify a `split_uuid` parameter. 
[AMPL](https://github.com/ATOMScience-org/AMPL) splits the data using the method given in the `splitter` parameter 
(`scaffold` in this example) and writes the split file to
`dataset/curated_kcna3_ic50_train_valid_test_scaffold_{split_uuid}.csv`. 
After training, [AMPL](https://github.com/ATOMScience-org/AMPL) saves the model and all of its parameters as a tarball in `result_dir`.
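
Because no `split_uuid` is supplied in this mode, the newly generated UUID has to be read from the split file name if the split is to be reused later. The sketch below assumes the file name follows the pattern shown above. `parse_split_file_name` and `find_split_uuids` are hypothetical helpers, not AMPL functions, and the example file name simply plugs the tutorial 3 UUID into that pattern.

```python
import re
from pathlib import Path

# Standard 8-4-4-4-12 lowercase hexadecimal UUID layout.
UUID_RE = r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"
SPLIT_FILE_RE = re.compile(
    rf"^(?P<dataset>.+)_train_valid_test_(?P<splitter>[a-z_]+)_(?P<uuid>{UUID_RE})\.csv$"
)


def parse_split_file_name(file_name: str):
    """Return the dataset name, splitter and UUID encoded in a split file name, or None."""
    match = SPLIT_FILE_RE.match(file_name)
    return match.groupdict() if match else None


def find_split_uuids(directory: str) -> list:
    """Collect the UUIDs of all split files found in a directory."""
    uuids = []
    for path in sorted(Path(directory).glob("*_train_valid_test_*.csv")):
        info = parse_split_file_name(path.name)
        if info is not None:
            uuids.append(info["uuid"])
    return uuids


example = parse_split_file_name(
    "curated_kcna3_ic50_train_valid_test_scaffold_8daa5687-c2ee-45e4-b385-36164246c419.csv"
)
print(example)
print(find_split_uuids("dataset"))  # empty list if no split files are present
```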

In [2]:
response_col = "avg_pIC50"
compound_id = "compound_id"
smiles_col = "base_rdkit_smiles"

params = {
    "verbose": "True",
    "system": "LC",
    "datastore": "False",
    "save_results": "False",
    "prediction_type": "regression",
    "dataset_key": dataset_file,
    "id_col": compound_id,
    "smiles_col": smiles_col,
    "response_cols": response_col,
    # create a new scaffold split as part of this run
    "previously_split": "False",
    "split_only": "False",
    "splitter": "scaffold",
    "split_valid_frac": "0.15",
    "split_test_frac": "0.15",
    # RDKit descriptors computed by the AMPL pipeline
    "featurizer": "computed_descriptors",
    "descriptor_type": "rdkit_raw",
    # random forest model
    "model_type": "RF",
    "transformers": "True",
    "rerun": "False",
    "result_dir": odir,
}

ampl_param = parse.wrapper(params)
pl = mp.ModelPipeline(ampl_param)
pl.train_model()

  X = np.nan_to_num((X - self.X_means) * X_weight / self.X_stds)


## Performance of the Model
We evaluate model performance by measuring how accurate 
the model's predictions are on the validation and test sets. 
The validation set is used while optimizing the model and for choosing the best
parameter settings. Performance on the test set is then the final judge of
model performance.

AMPL provides several popular metrics to evaluate regression models: 
Mean Absolute Error (MAE), Mean Squared Error (MSE), Root Mean Squared Error (RMSE) and R² (R-squared).
In our tutorials, we will use the R² metric to compare our models. The best model will have the highest
R² score.
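
To make the four metrics concrete, the sketch below computes them from scratch with the Python standard library on a handful of made-up pIC50 values (illustrative numbers, not taken from the KCNA3 dataset). It follows the textbook definitions and does not call AMPL's own metric code.

```python
import math


def regression_metrics(y_true, y_pred):
    """Compute MAE, MSE, RMSE and R² from paired true and predicted values."""
    n = len(y_true)
    errors = [t - p for t, p in zip(y_true, y_pred)]
    mae = sum(abs(e) for e in errors) / n
    mse = sum(e * e for e in errors) / n
    rmse = math.sqrt(mse)
    # R² compares the residual error with the variance of the true values.
    mean_true = sum(y_true) / n
    ss_res = sum(e * e for e in errors)
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)
    r2 = 1.0 - ss_res / ss_tot
    return {"mae": mae, "mse": mse, "rmse": rmse, "r2": r2}


# Made-up values for illustration only.
y_true = [5.0, 6.0, 7.0, 8.0]
y_pred = [5.5, 6.0, 6.5, 8.0]
metrics = regression_metrics(y_true, y_pred)
for name, value in metrics.items():
    print(f"{name}: {value:.3f}")
```

An R² of 1 means perfect predictions. A model that always predicts the mean of the observed values scores 0, and worse models can score below 0.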

> **Note** *The model tracker client will not be supported in your environment.*

In [3]:
# Model Performance
from atomsci.ddm.pipeline import compare_models as cm
pred_df = cm.get_filesystem_perf_results(odir, pred_type='regression')



Found data for 4 models under dataset


The `pred_df` dataframe contains details about each model, such as `model_uuid`, `model_path`, `ampl_version`, `model_type`, `features` and `splitter`, along with the values of popular metrics that help evaluate performance. Let us view the contents of the `pred_df` dataframe.

In [4]:
pred_df.to_csv('./dataset/pred_df.csv')

In [5]:
# View the pred_df dataframe
pred_df

| | model_uuid | model_path | ampl_version | model_type | dataset_key | features | splitter | split_strategy | model_score_type | feature_transform_type | ... | dropouts | xgb_gamma | xgb_learning_rate | xgb_max_depth | xgb_colsample_bytree | xgb_subsample | xgb_n_estimators | xgb_min_child_weight | model_parameters_dict | feat_parameters_dict |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1 | 18d7ef2d-7192-41df-b3ad-ec92a7fd0e72 | dataset/curated_kcna3_ic50_model_18d7ef2d-7192... | 1.6.0 | RF | /usr/WS2/kmelough/git3/AMPL/atomsci/ddm/exampl... | rdkit_raw | scaffold | train_valid_test | r2 | normalization | ... | | | | | | | | | {"rf_estimators": 500, "rf_max_depth": null, "... | {} |
| 3 | 2eb869ed-3581-4ea7-a601-90bb2f7e19a5 | dataset/curated_kcna3_ic50_model_2eb869ed-3581... | 1.6.0 | RF | /usr/WS2/kmelough/git3/AMPL/atomsci/ddm/exampl... | rdkit_raw | scaffold | train_valid_test | r2 | normalization | ... | | | | | | | | | {"rf_estimators": 500, "rf_max_depth": null, "... | {} |
| 0 | 2e54d840-2145-4973-b975-9aad8edb2f35 | dataset/curated_kcna3_ic50_model_2e54d840-2145... | 1.6.0 | RF | /usr/WS2/kmelough/git3/AMPL/atomsci/ddm/exampl... | rdkit_raw | scaffold | train_valid_test | r2 | normalization | ... | | | | | | | | | {"rf_estimators": 500, "rf_max_depth": null, "... | {} |
| 2 | 06bced83-b064-4163-91d2-8e56bb7d237f | dataset/curated_kcna3_ic50_model_06bced83-b064... | 1.6.0 | RF | /usr/WS2/kmelough/git3/AMPL/atomsci/ddm/exampl... | rdkit_raw | scaffold | train_valid_test | r2 | normalization | ... | | | | | | | | | {"rf_estimators": 500, "rf_max_depth": null, "... | {} |


In [7]:
pred_df[['model_uuid', 'best_valid_r2_score', 'best_test_r2_score', 'best_train_num_compounds']]

| | model_uuid | best_valid_r2_score | best_test_r2_score | best_train_num_compounds |
|---|---|---|---|---|
| 1 | 18d7ef2d-7192-41df-b3ad-ec92a7fd0e72 | 0.372506 | 0.277066 | 259 |
| 3 | 2eb869ed-3581-4ea7-a601-90bb2f7e19a5 | 0.370694 | 0.291152 | 259 |
| 0 | 2e54d840-2145-4973-b975-9aad8edb2f35 | 0.365745 | 0.27258 | 259 |
| 2 | 06bced83-b064-4163-91d2-8e56bb7d237f | 0.360854 | 0.272374 | 259 |
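
The four models listed share the same model type, features and splitter, so the spread of their scores gives a rough sense of the run-to-run variation mentioned in the note near the top of this tutorial. The sketch below summarizes the scores copied from the table above using the standard library; your own numbers will differ.

```python
from statistics import mean, stdev

# R² scores copied from the results table in this tutorial.
scores = {
    "valid": [0.372506, 0.370694, 0.365745, 0.360854],
    "test": [0.277066, 0.291152, 0.27258, 0.272374],
}

# Mean and sample standard deviation for each subset.
summary = {subset: (mean(values), stdev(values)) for subset, values in scores.items()}

for subset, (avg, spread) in summary.items():
    print(f"{subset}: mean R2 = {avg:.3f}, stdev = {spread:.3f}")
```

In these runs the validation scores are consistently higher than the test scores, which is worth keeping in mind when the validation score is used to select a model.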


## Top Performing Model
To pick the top performing model, we sort the dataframe by the `best_valid_r2_score` column in descending order and take the first row.

In [8]:
# Top performing model
top_model=pred_df.sort_values(by="best_valid_r2_score", ascending=False).iloc[0,:]
top_model

model_uuid                               18d7ef2d-7192-41df-b3ad-ec92a7fd0e72
model_path                  dataset/curated_kcna3_ic50_model_18d7ef2d-7192...
ampl_version                                                            1.6.0
model_type                                                                 RF
dataset_key                 /usr/WS2/kmelough/git3/AMPL/atomsci/ddm/exampl...
features                                                            rdkit_raw
splitter                                                             scaffold
split_strategy                                               train_valid_test
model_score_type                                                           r2
feature_transform_type                                          normalization
model_choice_score                                                   0.372506
best_train_r2_score                                                  0.925609
best_train_rms_score                                            

## Model Tarball 
The location of the tarball in which the top performing model is saved is stored in `top_model.model_path`.

In [9]:
# Top performing model path
top_model.model_path

'dataset/curated_kcna3_ic50_model_18d7ef2d-7192-41df-b3ad-ec92a7fd0e72.tar.gz'

We will need this path in the next tutorial in which we use the trained model to make predictions on a new dataset.
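
Before moving on, it can be useful to confirm that the tarball exists and to see what it contains. The sketch below uses only the standard-library `tarfile` module and makes no assumption about the archive's internal layout. `list_tarball_members` is a hypothetical helper name, and the path is the one printed above, so substitute your own `top_model.model_path`.

```python
import tarfile
from pathlib import Path


def list_tarball_members(tar_path):
    """Return the names of all files stored in a gzip-compressed tarball."""
    with tarfile.open(tar_path, "r:gz") as tar:
        return tar.getnames()


# Path taken from the output above; yours will contain a different model_uuid.
model_path = Path("dataset/curated_kcna3_ic50_model_18d7ef2d-7192-41df-b3ad-ec92a7fd0e72.tar.gz")

if model_path.exists():
    for name in list_tarball_members(model_path):
        print(name)
else:
    print(f"{model_path} not found; run the training cells first")
```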