# Model training

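This notebook trains and compares a set of classifiers on each fingerprint's data split (with and without SMOTE oversampling) to identify the best-performing model. Models are ranked by their cross-validated Kappa on the training split. Note that the SMOTE training files already contain synthetic samples before cross-validation, so their CV scores may be optimistic; confirm any conclusion on the held-out test split.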

In [1]:
import os
from pycaret.classification import *
import pandas as pd

# Model training and comparison

In [2]:
def train_model_pipeline(
    data: pd.DataFrame,
    test_data: pd.DataFrame,
    metric: str,
    exp_name: str,
    session_id: "int | None" = None,
    save_path: "str | None" = None,
):
    """Run a PyCaret classification experiment and return the best model.

    Sets up a PyCaret experiment with ``data`` as the training set and
    ``test_data`` as the hold-out set. It then cross-validates a fixed list of
    candidate models and picks the one with the best ``metric`` score.

    Args:
        data: Training frame; must contain a ``label`` target column.
        test_data: Hold-out frame with the same columns as ``data``.
        metric: Leaderboard column used to rank the models (e.g. ``"Kappa"``).
        exp_name: Experiment name passed to PyCaret's ``setup``.
        session_id: Seed forwarded to ``setup``. ``None`` (the default) lets
            PyCaret pick a random seed, so results differ between runs.
        save_path: If given, the best model is written with PyCaret's
            ``save_model`` under this path. Nothing is written otherwise.

    Returns:
        The best model selected by ``compare_models``.
    """
    setup(
        data=data,
        target="label",
        test_data=test_data,
        index=False,
        # PyCaret ignores train_size when test_data is supplied: all of `data`
        # is used for training/CV and `test_data` is the hold-out set.
        train_size=0.8,
        low_variance_threshold=0,  # drop constant (zero-variance) features
        fold_strategy="stratifiedkfold",  # CV strategy
        fold=5,  # number of CV folds
        session_id=session_id,
        experiment_name=exp_name,
    )

    # Rank candidates by their cross-validated `metric` on the training data.
    best_model = compare_models(
        sort=metric,
        round=3,  # number of decimal places for metric
        include=["nb", "lr", "lightgbm", "dt", "rf", "xgboost"],
    )

    if save_path is not None:
        save_model(best_model, save_path)

    return best_model

In [3]:
# Train and compare models for every fingerprint, once on the plain training
# split and once on the SMOTE-oversampled split. Both variants are evaluated
# against the same held-out test split. The full loop takes about a day.
for fingerprint_name in ["ecfp4", "rdkit", "maccs", "mhfp6", "erg", "chemphys"]:
    split_prefix = f"../data/splits/{fingerprint_name}"
    train_data = pd.read_csv(f"{split_prefix}_train.csv")
    train_with_smote = pd.read_csv(f"{split_prefix}_smote_train.csv")
    test_data = pd.read_csv(f"{split_prefix}_test.csv")

    # (progress label, training frame, PyCaret experiment name)
    variants = [
        ("no SMOTE", train_data, fingerprint_name),
        ("SMOTE", train_with_smote, f"{fingerprint_name}_with_smote"),
    ]
    for variant_label, variant_train, variant_exp in variants:
        print(f"Training {fingerprint_name} model with {variant_label}")
        train_model_pipeline(
            data=variant_train,
            test_data=test_data,
            metric="Kappa",
            exp_name=variant_exp,
        )

Training ecfp4 model with no SMOTE


Unnamed: 0,Description,Value
0,Session id,3036
1,Target,label
2,Target type,Multiclass
3,Target mapping,"acid-fast: 0, gram-negative: 1, gram-positive: 2, unselective: 3"
4,Original data shape,"(54137, 1025)"
5,Transformed data shape,"(54137, 1025)"
6,Transformed train set shape,"(43309, 1025)"
7,Transformed test set shape,"(10828, 1025)"
8,Numeric features,1024
9,Preprocess,True


Unnamed: 0,Model,Accuracy,AUC,Recall,Prec.,F1,Kappa,MCC,TT (Sec)
rf,Random Forest Classifier,0.786,0.94,0.786,0.809,0.773,0.634,0.658,4.51
xgboost,Extreme Gradient Boosting,0.773,0.919,0.773,0.779,0.763,0.62,0.633,3.874
lightgbm,Light Gradient Boosting Machine,0.74,0.901,0.74,0.756,0.726,0.556,0.579,2.162
dt,Decision Tree Classifier,0.664,0.74,0.664,0.663,0.663,0.473,0.473,1.984
lr,Logistic Regression,0.653,0.808,0.653,0.644,0.643,0.429,0.434,4.236
nb,Naive Bayes,0.483,0.731,0.483,0.592,0.509,0.288,0.304,0.846


Training ecfp4 model with SMOTE


Unnamed: 0,Description,Value
0,Session id,4767
1,Target,label
2,Target type,Multiclass
3,Target mapping,"acid-fast: 0, gram-negative: 1, gram-positive: 2, unselective: 3"
4,Original data shape,"(99620, 1025)"
5,Transformed data shape,"(99620, 1025)"
6,Transformed train set shape,"(88792, 1025)"
7,Transformed test set shape,"(10828, 1025)"
8,Numeric features,1024
9,Preprocess,True


Unnamed: 0,Model,Accuracy,AUC,Recall,Prec.,F1,Kappa,MCC,TT (Sec)
rf,Random Forest Classifier,0.898,0.985,0.898,0.908,0.899,0.864,0.867,9.986
xgboost,Extreme Gradient Boosting,0.819,0.96,0.819,0.835,0.819,0.759,0.764,8.078
dt,Decision Tree Classifier,0.815,0.883,0.815,0.82,0.814,0.753,0.756,3.696
lightgbm,Light Gradient Boosting Machine,0.789,0.947,0.789,0.807,0.788,0.718,0.724,3.21
lr,Logistic Regression,0.686,0.89,0.686,0.705,0.682,0.581,0.588,10.162
nb,Naive Bayes,0.534,0.748,0.534,0.567,0.519,0.378,0.393,1.416


Training rdkit model with no SMOTE


Unnamed: 0,Description,Value
0,Session id,7834
1,Target,label
2,Target type,Multiclass
3,Target mapping,"acid-fast: 0, gram-negative: 1, gram-positive: 2, unselective: 3"
4,Original data shape,"(54137, 1025)"
5,Transformed data shape,"(54137, 1025)"
6,Transformed train set shape,"(43309, 1025)"
7,Transformed test set shape,"(10828, 1025)"
8,Numeric features,1024
9,Preprocess,True


Unnamed: 0,Model,Accuracy,AUC,Recall,Prec.,F1,Kappa,MCC,TT (Sec)
xgboost,Extreme Gradient Boosting,0.806,0.939,0.806,0.806,0.8,0.683,0.688,7.262
rf,Random Forest Classifier,0.803,0.943,0.803,0.813,0.794,0.67,0.683,4.314
lightgbm,Light Gradient Boosting Machine,0.765,0.916,0.765,0.775,0.754,0.603,0.619,4.704
dt,Decision Tree Classifier,0.695,0.771,0.695,0.697,0.696,0.524,0.524,2.018
lr,Logistic Regression,0.663,0.811,0.663,0.654,0.654,0.448,0.452,23.326
nb,Naive Bayes,0.282,0.59,0.282,0.534,0.293,0.111,0.137,0.724


Training rdkit model with SMOTE


Unnamed: 0,Description,Value
0,Session id,8964
1,Target,label
2,Target type,Multiclass
3,Target mapping,"acid-fast: 0, gram-negative: 1, gram-positive: 2, unselective: 3"
4,Original data shape,"(99620, 1025)"
5,Transformed data shape,"(99620, 1025)"
6,Transformed train set shape,"(88792, 1025)"
7,Transformed test set shape,"(10828, 1025)"
8,Numeric features,1024
9,Preprocess,True


Unnamed: 0,Model,Accuracy,AUC,Recall,Prec.,F1,Kappa,MCC,TT (Sec)
rf,Random Forest Classifier,0.934,0.991,0.934,0.935,0.934,0.912,0.912,7.376
xgboost,Extreme Gradient Boosting,0.902,0.984,0.902,0.904,0.902,0.87,0.87,13.144
dt,Decision Tree Classifier,0.843,0.897,0.843,0.844,0.842,0.791,0.792,3.69
lightgbm,Light Gradient Boosting Machine,0.829,0.96,0.829,0.831,0.829,0.772,0.773,8.174
lr,Logistic Regression,0.68,0.877,0.68,0.679,0.678,0.573,0.573,32.016
nb,Naive Bayes,0.388,0.63,0.388,0.408,0.362,0.184,0.195,1.758


Training maccs model with no SMOTE


Unnamed: 0,Description,Value
0,Session id,8838
1,Target,label
2,Target type,Multiclass
3,Target mapping,"acid-fast: 0, gram-negative: 1, gram-positive: 2, unselective: 3"
4,Original data shape,"(54137, 168)"
5,Transformed data shape,"(54137, 162)"
6,Transformed train set shape,"(43309, 162)"
7,Transformed test set shape,"(10828, 162)"
8,Numeric features,167
9,Preprocess,True


Unnamed: 0,Model,Accuracy,AUC,Recall,Prec.,F1,Kappa,MCC,TT (Sec)
rf,Random Forest Classifier,0.808,0.942,0.808,0.808,0.802,0.686,0.692,1.242
xgboost,Extreme Gradient Boosting,0.764,0.91,0.764,0.762,0.757,0.613,0.619,0.974
dt,Decision Tree Classifier,0.73,0.795,0.73,0.728,0.729,0.576,0.576,0.268
lightgbm,Light Gradient Boosting Machine,0.726,0.886,0.726,0.731,0.713,0.538,0.554,1.15
lr,Logistic Regression,0.57,0.722,0.57,0.547,0.53,0.243,0.263,3.364
nb,Naive Bayes,0.302,0.623,0.302,0.548,0.285,0.108,0.155,0.144


Training maccs model with SMOTE


Unnamed: 0,Description,Value
0,Session id,6541
1,Target,label
2,Target type,Multiclass
3,Target mapping,"acid-fast: 0, gram-negative: 1, gram-positive: 2, unselective: 3"
4,Original data shape,"(99620, 168)"
5,Transformed data shape,"(99620, 162)"
6,Transformed train set shape,"(88792, 162)"
7,Transformed test set shape,"(10828, 162)"
8,Numeric features,167
9,Preprocess,True


Unnamed: 0,Model,Accuracy,AUC,Recall,Prec.,F1,Kappa,MCC,TT (Sec)
rf,Random Forest Classifier,0.915,0.987,0.915,0.916,0.915,0.887,0.887,2.656
dt,Decision Tree Classifier,0.849,0.903,0.849,0.851,0.848,0.798,0.799,0.588
xgboost,Extreme Gradient Boosting,0.817,0.955,0.817,0.82,0.816,0.756,0.757,2.01
lightgbm,Light Gradient Boosting Machine,0.762,0.929,0.762,0.766,0.761,0.683,0.685,1.676
lr,Logistic Regression,0.544,0.792,0.544,0.542,0.539,0.392,0.394,8.254
nb,Naive Bayes,0.412,0.698,0.412,0.496,0.355,0.216,0.25,0.404


Training mhfp6 model with no SMOTE


Unnamed: 0,Description,Value
0,Session id,4265
1,Target,label
2,Target type,Multiclass
3,Target mapping,"acid-fast: 0, gram-negative: 1, gram-positive: 2, unselective: 3"
4,Original data shape,"(54137, 2049)"
5,Transformed data shape,"(54137, 2049)"
6,Transformed train set shape,"(43309, 2049)"
7,Transformed test set shape,"(10828, 2049)"
8,Numeric features,2048
9,Preprocess,True


Unnamed: 0,Model,Accuracy,AUC,Recall,Prec.,F1,Kappa,MCC,TT (Sec)
xgboost,Extreme Gradient Boosting,0.806,0.941,0.806,0.808,0.799,0.681,0.687,37.166
lightgbm,Light Gradient Boosting Machine,0.79,0.934,0.79,0.794,0.782,0.651,0.661,36.236
rf,Random Forest Classifier,0.787,0.945,0.787,0.811,0.774,0.636,0.66,17.616
dt,Decision Tree Classifier,0.665,0.743,0.665,0.665,0.665,0.476,0.476,16.778
lr,Logistic Regression,0.658,0.809,0.658,0.651,0.653,0.451,0.453,49.676
nb,Naive Bayes,0.195,0.526,0.195,0.443,0.152,0.033,0.051,2.846


Training mhfp6 model with SMOTE


Unnamed: 0,Description,Value
0,Session id,3481
1,Target,label
2,Target type,Multiclass
3,Target mapping,"acid-fast: 0, gram-negative: 1, gram-positive: 2, unselective: 3"
4,Original data shape,"(99620, 2049)"
5,Transformed data shape,"(99620, 2049)"
6,Transformed train set shape,"(88792, 2049)"
7,Transformed test set shape,"(10828, 2049)"
8,Numeric features,2048
9,Preprocess,True


Unnamed: 0,Model,Accuracy,AUC,Recall,Prec.,F1,Kappa,MCC,TT (Sec)
rf,Random Forest Classifier,0.933,0.992,0.933,0.935,0.933,0.911,0.911,47.698
xgboost,Extreme Gradient Boosting,0.901,0.987,0.901,0.911,0.902,0.868,0.871,88.078
lightgbm,Light Gradient Boosting Machine,0.86,0.977,0.86,0.876,0.861,0.814,0.819,58.08
dt,Decision Tree Classifier,0.782,0.855,0.782,0.783,0.781,0.709,0.709,47.268
lr,Logistic Regression,0.73,0.9,0.73,0.728,0.728,0.641,0.641,105.03
nb,Naive Bayes,0.386,0.63,0.386,0.384,0.34,0.181,0.203,8.62


Training erg model with no SMOTE


Unnamed: 0,Description,Value
0,Session id,7623
1,Target,label
2,Target type,Multiclass
3,Target mapping,"acid-fast: 0, gram-negative: 1, gram-positive: 2, unselective: 3"
4,Original data shape,"(52871, 316)"
5,Transformed data shape,"(52871, 316)"
6,Transformed train set shape,"(42296, 316)"
7,Transformed test set shape,"(10575, 316)"
8,Numeric features,315
9,Preprocess,True


Unnamed: 0,Model,Accuracy,AUC,Recall,Prec.,F1,Kappa,MCC,TT (Sec)
rf,Random Forest Classifier,0.782,0.925,0.782,0.781,0.773,0.642,0.65,1.942
dt,Decision Tree Classifier,0.717,0.794,0.717,0.714,0.716,0.558,0.558,0.546
xgboost,Extreme Gradient Boosting,0.736,0.895,0.736,0.74,0.721,0.557,0.572,1.444
lightgbm,Light Gradient Boosting Machine,0.714,0.88,0.714,0.728,0.694,0.51,0.536,1.496
lr,Logistic Regression,0.557,0.718,0.557,0.541,0.508,0.211,0.238,6.74
nb,Naive Bayes,0.226,0.579,0.226,0.505,0.252,0.077,0.114,0.354


Training erg model with SMOTE


Unnamed: 0,Description,Value
0,Session id,8307
1,Target,label
2,Target type,Multiclass
3,Target mapping,"acid-fast: 0, gram-negative: 1, gram-positive: 2, unselective: 3"
4,Original data shape,"(96079, 316)"
5,Transformed data shape,"(96079, 316)"
6,Transformed train set shape,"(85504, 316)"
7,Transformed test set shape,"(10575, 316)"
8,Numeric features,315
9,Preprocess,True


Unnamed: 0,Model,Accuracy,AUC,Recall,Prec.,F1,Kappa,MCC,TT (Sec)
rf,Random Forest Classifier,0.895,0.979,0.895,0.9,0.895,0.86,0.861,4.218
dt,Decision Tree Classifier,0.804,0.878,0.804,0.808,0.804,0.739,0.74,1.112
xgboost,Extreme Gradient Boosting,0.8,0.954,0.8,0.82,0.8,0.734,0.741,3.65
lightgbm,Light Gradient Boosting Machine,0.769,0.938,0.769,0.791,0.768,0.692,0.7,4.358
lr,Logistic Regression,0.532,0.771,0.532,0.533,0.524,0.376,0.381,13.838
nb,Naive Bayes,0.36,0.625,0.36,0.424,0.307,0.147,0.188,0.326


Training chemphys model with no SMOTE


Unnamed: 0,Description,Value
0,Session id,4928
1,Target,label
2,Target type,Multiclass
3,Target mapping,"acid-fast: 0, gram-negative: 1, gram-positive: 2, unselective: 3"
4,Original data shape,"(107695, 36)"
5,Transformed data shape,"(107695, 36)"
6,Transformed train set shape,"(86156, 36)"
7,Transformed test set shape,"(21539, 36)"
8,Numeric features,35
9,Preprocess,True


Unnamed: 0,Model,Accuracy,AUC,Recall,Prec.,F1,Kappa,MCC,TT (Sec)
rf,Random Forest Classifier,0.919,0.986,0.919,0.921,0.918,0.871,0.873,2.512
dt,Decision Tree Classifier,0.882,0.91,0.882,0.882,0.882,0.816,0.816,0.352
xgboost,Extreme Gradient Boosting,0.767,0.919,0.767,0.772,0.757,0.611,0.623,0.912
lightgbm,Light Gradient Boosting Machine,0.708,0.88,0.708,0.724,0.689,0.497,0.524,1.45
nb,Naive Bayes,0.192,0.565,0.192,0.441,0.18,0.044,0.058,0.268
lr,Logistic Regression,0.505,0.599,0.505,0.353,0.352,-0.001,-0.003,2.042


Training chemphys model with SMOTE


Unnamed: 0,Description,Value
0,Session id,168
1,Target,label
2,Target type,Multiclass
3,Target mapping,"acid-fast: 0, gram-negative: 1, gram-positive: 2, unselective: 3"
4,Original data shape,"(197435, 36)"
5,Transformed data shape,"(197435, 36)"
6,Transformed train set shape,"(175896, 36)"
7,Transformed test set shape,"(21539, 36)"
8,Numeric features,35
9,Preprocess,True


Unnamed: 0,Model,Accuracy,AUC,Recall,Prec.,F1,Kappa,MCC,TT (Sec)
rf,Random Forest Classifier,0.937,0.994,0.937,0.938,0.937,0.916,0.917,5.566
dt,Decision Tree Classifier,0.865,0.91,0.865,0.865,0.865,0.82,0.821,0.762
xgboost,Extreme Gradient Boosting,0.773,0.94,0.773,0.779,0.772,0.697,0.699,1.576
lightgbm,Light Gradient Boosting Machine,0.695,0.899,0.695,0.701,0.694,0.594,0.597,2.364
lr,Logistic Regression,0.414,0.666,0.414,0.41,0.4,0.218,0.223,3.866
nb,Naive Bayes,0.327,0.615,0.327,0.333,0.266,0.103,0.12,0.354
