# This is Step 3 in the Pipeline - Training ML Prediction Model
With this notebook we can train various ML classifiers to tackle multi-lable prediction problem. We are predicting Spec2Vec embeddings from molecular fingerprints.

### Imports

In [64]:
from sklearn.metrics import accuracy_score, f1_score, log_loss, precision_score, recall_score, jaccard_score, roc_auc_score, hamming_loss, label_ranking_loss, coverage_error
from sklearn.model_selection import KFold
from sklearn.multiclass import OneVsRestClassifier
from sklearn.multioutput import  ClassifierChain
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.dummy import DummyClassifier
from tqdm.notebook import tqdm
import numpy as np
import pandas as pd
import pickle
import os

### Parameters

In [65]:
RANDOM_STATE = 27082023
np.random.seed(RANDOM_STATE)

# path to merged fingerprint and embedding data (fingerprint columns should be prefixed with 'fingerprint_' and embedding columns should be prefixed with 'embedding_').
MERGED_PATH = './source/embedding/tms_maccs/merged.csv'
MODEL_OUTPUT_FOLDER = "./source/model/tms_maccs/"

In [66]:
assert os.path.isfile(MERGED_PATH)
assert os.path.isdir(MODEL_OUTPUT_FOLDER)
assert MERGED_PATH.endswith('.csv')

In [67]:
ESTIMATOR = DecisionTreeClassifier(random_state=RANDOM_STATE, class_weight='balanced')

In [68]:
MODEL = OneVsRestClassifier(ESTIMATOR, n_jobs=-1)

In [69]:
MODEL_OUTPUT_FOLDER = f'{MODEL_OUTPUT_FOLDER}{MODEL.__class__.__name__}_{ESTIMATOR.__class__.__name__}'
os.makedirs(f'{MODEL_OUTPUT_FOLDER}/models', exist_ok=False)

### Metrics Definition
Creates metrics which can be called with (y_true, y_prob, y_pred) for easier use. It also creates multiple combinations of metrics for different averaging methods.

In [70]:
Y_PRED_SCORES = [accuracy_score, log_loss, hamming_loss] # input y predictions and y true
Y_PRED_SCORES_WITH_AVERAGING = [f1_score, precision_score, recall_score, jaccard_score] # input y predictions and y true and use one of the following: "micro", "macro", "weighted", "samples"
Y_PROB_SCORES = [roc_auc_score, label_ranking_loss, coverage_error] # input y probabilities and y true

In [71]:
METRICS = []
METRIC_NAMES = []
for metric in Y_PRED_SCORES:
    METRICS.append(lambda y_true, y_prob, y_pred, metric=metric: metric(y_true, y_pred))
    METRIC_NAMES.append(metric.__name__)
for metric in Y_PRED_SCORES_WITH_AVERAGING:
    for average in ["micro", "macro", "weighted", "samples"]:
        zero_division = 0 if metric.__name__ == "jaccard_score" else np.nan
        METRICS.append(lambda y_true, y_prob, y_pred, metric=metric, average=average: metric(y_true, y_pred, average=average, zero_division=zero_division))
        METRIC_NAMES.append(metric.__name__ + "__" + average)
for metric in Y_PROB_SCORES:
    METRICS.append(lambda y_true, y_prob, y_pred, metric=metric: metric(y_true, y_prob))
    METRIC_NAMES.append(metric.__name__)

In [80]:
class Metrics:
    def __init__(self, metrics, metric_names, repeats=2, folds=5):
        self.metrics = metrics
        self.metric_names = metric_names
        
        self.repeats = repeats
        self.folds = folds
        self.i = 0

        self.results = pd.DataFrame(columns=['repeat', 'fold', 'model_training_data_path'] + self.metric_names)
    
    def evaluate(self, y_true, y_prob, y_pred, model_training_data_path=None):
        entry = {
            'repeat': self.i // self.folds,
            'fold': self.i % self.folds,
            'model_training_data_path': model_training_data_path
        }
        for metric, metric_name in zip(self.metrics, self.metric_names):
            try:
                entry[metric_name] = metric(y_true, y_prob, y_pred)
            except ValueError as e:
                print("Warning: ", e)
                entry[metric_name] = np.nan
        
        self.results = pd.concat([self.results, pd.DataFrame(entry, index=[0])], ignore_index=True)
        self.i += 1
    
    def store(self, filename):
        self.results.to_csv(filename, index=False)

### Load Data

In [73]:
merged_df = pd.read_csv(MERGED_PATH)
merged_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 3025 entries, 0 to 3024
Columns: 467 entries, inchi_key to embedding_299
dtypes: float64(466), object(1)
memory usage: 10.8+ MB


In [74]:
f'Number of NaNs: {merged_df.isna().sum().sum()}' # should be 0

'Number of NaNs: 0'

In [75]:
X = merged_df.filter(regex='^embedding_')
y = merged_df.filter(regex='^fingerprint_')
X.shape, y.shape

((3025, 300), (3025, 166))

In [76]:
X = X.to_numpy()
y = y.to_numpy()

### Train- K-fold Cross Validation

In [81]:
REPEATS = 2
K = 5
metrics = Metrics(METRICS, METRIC_NAMES, REPEATS, K)

for i in tqdm(range(REPEATS), desc="Repeats"):
    kf = KFold(n_splits=K, shuffle=True, random_state=RANDOM_STATE + i)

    for fold, (train_index, test_index) in tqdm(enumerate(kf.split(X, y)), desc="Fold", total=K):
        # train
        X_train, X_test = X[train_index], X[test_index]
        y_train, y_test = y[train_index], y[test_index]
        MODEL.fit(X_train, y_train)

        # predict
        y_pred = MODEL.predict(X_test)
        y_prob = MODEL.predict_proba(X_test)

        # store train data
        model_training_data_path = f'{MODEL_OUTPUT_FOLDER}/models/{i}_{fold}.pkl'
        with open(model_training_data_path, "wb") as f:
            pickle.dump({
                "model": MODEL,
                "X_train": X_train,
                "y_train": y_train,
                "X_test": X_test,
                "y_test": y_test,
            }, f)

        # evaluate
        metrics.evaluate(y_test, y_prob, y_pred, model_training_data_path=model_training_data_path)

metrics.store(f'{MODEL_OUTPUT_FOLDER}/metrics.csv')

Repeats:   0%|          | 0/2 [00:00<?, ?it/s]

Fold:   0%|          | 0/5 [00:00<?, ?it/s]





















Fold:   0%|          | 0/5 [00:00<?, ?it/s]





















In [82]:
metrics.results.describe()

Unnamed: 0,accuracy_score,log_loss,hamming_loss,f1_score__micro,f1_score__macro,f1_score__weighted,f1_score__samples,precision_score__micro,precision_score__macro,precision_score__weighted,...,recall_score__macro,recall_score__weighted,recall_score__samples,jaccard_score__micro,jaccard_score__macro,jaccard_score__weighted,jaccard_score__samples,roc_auc_score,label_ranking_loss,coverage_error
count,10.0,10.0,10.0,10.0,10.0,10.0,10.0,10.0,10.0,10.0,...,10.0,10.0,10.0,10.0,10.0,10.0,10.0,0.0,10.0,10.0
mean,0.028264,333.493094,0.068375,0.835371,0.451258,0.833027,0.837153,0.839406,0.46655,0.83646,...,0.443956,0.831408,0.847445,0.717324,0.353614,0.759142,0.732798,,0.186376,152.843967
std,0.008263,14.354913,0.002665,0.005879,0.008207,0.007276,0.004822,0.004782,0.006909,0.006218,...,0.012817,0.008846,0.006129,0.008656,0.008087,0.008627,0.006806,,0.006446,2.013518
min,0.018182,308.850216,0.064084,0.825381,0.437841,0.821534,0.828538,0.832017,0.455485,0.826457,...,0.42861,0.818849,0.838118,0.702679,0.341186,0.746104,0.720734,,0.175845,148.642975
25%,0.021074,324.645224,0.066407,0.830928,0.444927,0.827255,0.833796,0.837176,0.460771,0.833271,...,0.431412,0.82497,0.844306,0.710759,0.346586,0.752445,0.727833,,0.182241,152.410331
50%,0.02562,339.099101,0.068142,0.836687,0.453002,0.834195,0.837584,0.838961,0.468439,0.837324,...,0.443882,0.830085,0.846514,0.719227,0.355581,0.759516,0.733053,,0.186275,152.87686
75%,0.03595,342.837591,0.070577,0.8395,0.457315,0.838893,0.840724,0.843662,0.472367,0.841809,...,0.4562,0.838189,0.851182,0.723396,0.359529,0.766304,0.738431,,0.190058,154.098347
max,0.041322,349.291111,0.072369,0.843177,0.461857,0.843495,0.843734,0.845563,0.473733,0.843662,...,0.459083,0.845031,0.857603,0.728874,0.364069,0.77237,0.741174,,0.196705,155.87438
