# This is Step 3 in the Pipeline - Training ML Prediction Model
With this notebook we can train various ML classifiers to tackle multi-lable prediction problem. We are predicting Spec2Vec embeddings from molecular fingerprints.

### Imports

In [21]:
from sklearn.metrics import accuracy_score, f1_score, log_loss, precision_score, recall_score, jaccard_score, roc_auc_score, hamming_loss, label_ranking_loss, coverage_error
from sklearn.model_selection import KFold
from sklearn.multiclass import OneVsRestClassifier
from sklearn.multioutput import  ClassifierChain
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.dummy import DummyClassifier
from sklearn.linear_model import LogisticRegression
from mass_spectra.similarity_voting import SimilarityVoting
from wrappers.nn import NN
from wrappers.catboost import CatBoost
from tqdm.notebook import tqdm
import numpy as np
import pandas as pd
import pickle
from random import shuffle, seed
from math import ceil
import os
from torch.nn import BCEWithLogitsLoss

### Parameters

In [22]:
RANDOM_STATE = 27082023
seed(RANDOM_STATE)
np.random.seed(RANDOM_STATE)

# path to merged fingerprint and embedding data (fingerprint columns should be prefixed with 'fingerprint_' and embedding columns should be prefixed with 'embedding_').
MERGED_PATH = './source/embedding/all_positive_all_fingerprints/merged.csv'
MODEL_OUTPUT_FOLDER = "./source/model/all_positive_all_fingerprints/"

In [23]:
assert os.path.isfile(MERGED_PATH)
assert os.path.isdir(MODEL_OUTPUT_FOLDER)
assert MERGED_PATH.endswith('.csv')

In [24]:
ESTIMATOR = RandomForestClassifier(n_estimators=100, random_state=RANDOM_STATE)

In [25]:
# MODEL = CatBoost(num_trees=500, learning_rate =0.001, random_seed=RANDOM_STATE, allow_const_label=True, verbose=False, loss_function='MultiLogloss')
MODEL = OneVsRestClassifier(ESTIMATOR, n_jobs=-1)

In [26]:
MODEL_CLASS = MODEL.__class__.__name__
ESTIMATOR_CLASS = ESTIMATOR.__class__.__name__ if ESTIMATOR is not None else 'Multioutput'
MODEL_OUTPUT_FOLDER = f'{MODEL_OUTPUT_FOLDER}{MODEL_CLASS}_{ESTIMATOR_CLASS}'
os.makedirs(f'{MODEL_OUTPUT_FOLDER}/models', exist_ok=False)
os.makedirs(f'{MODEL_OUTPUT_FOLDER}/unseen_inchi_keys_models', exist_ok=False)

### Metrics Definition
Creates metrics which can be called with (y_true, y_prob, y_pred) for easier use. It also creates multiple combinations of metrics for different averaging methods.

In [27]:
Y_PRED_SCORES = [accuracy_score, log_loss, hamming_loss] # input y predictions and y true
Y_PRED_SCORES_WITH_AVERAGING = [f1_score, precision_score, recall_score, jaccard_score] # input y predictions and y true and use one of the following: "micro", "macro", "weighted", "samples"
Y_PROB_SCORES = [roc_auc_score, label_ranking_loss, coverage_error] # input y probabilities and y true

In [28]:
METRICS = []
METRIC_NAMES = []
for metric in Y_PRED_SCORES:
    METRICS.append(lambda y_true, y_prob, y_pred, metric=metric: metric(y_true, y_pred))
    METRIC_NAMES.append(metric.__name__)
for metric in Y_PRED_SCORES_WITH_AVERAGING:
    for average in ["micro", "macro", "weighted", "samples"]:
        zero_division = 0 if metric.__name__ == "jaccard_score" else np.nan
        METRICS.append(lambda y_true, y_prob, y_pred, metric=metric, average=average: metric(y_true, y_pred, average=average, zero_division=zero_division))
        METRIC_NAMES.append(metric.__name__ + "__" + average)
for metric in Y_PROB_SCORES:
    METRICS.append(lambda y_true, y_prob, y_pred, metric=metric: metric(y_true, y_prob))
    METRIC_NAMES.append(metric.__name__)

In [29]:
class Metrics:
    def __init__(self, metrics, metric_names, repeats=2, folds=5):
        self.metrics = metrics
        self.metric_names = metric_names
        
        self.repeats = repeats
        self.folds = folds
        self.i = 0

        self.results = pd.DataFrame(columns=['repeat', 'fold', 'model_training_data_path'] + self.metric_names)
    
    def evaluate(self, y_true, y_prob, y_pred, model_training_data_path=None):
        entry = {
            'repeat': self.i // self.folds,
            'fold': self.i % self.folds,
            'model_training_data_path': model_training_data_path
        }
        for metric, metric_name in zip(self.metrics, self.metric_names):
            try:
                entry[metric_name] = metric(y_true, y_prob, y_pred)
            except ValueError as e:
                print("Warning: ", e)
                entry[metric_name] = np.nan
        
        self.results = pd.concat([self.results, pd.DataFrame(entry, index=[0])], ignore_index=True)
        self.i += 1
    
    def store(self, filename):
        self.results.to_csv(filename, index=False)

    def current(self, metric_name):
        return self.results[metric_name].iloc[-1]

### Load Data

In [30]:
merged_df = pd.read_csv(MERGED_PATH)
merged_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 3052 entries, 0 to 3051
Columns: 605 entries, inchi_key to embedding_299
dtypes: float64(300), int64(302), object(3)
memory usage: 14.1+ MB


In [31]:
f'Number of NaNs: {merged_df.isna().sum().sum()}' # should be 0

'Number of NaNs: 0'

In [32]:
X = merged_df.filter(regex='^embedding_')
y = merged_df.filter(regex='^fingerprint_')
X.shape, y.shape

((3052, 300), (3052, 302))

In [33]:
X = X.to_numpy()
y = y.to_numpy()

### Train- K-fold Cross Validation

In [34]:
REPEATS = 2
K = 5
metrics = Metrics(METRICS, METRIC_NAMES, REPEATS, K)

for i in tqdm(range(REPEATS), desc="Repeats"):
    kf = KFold(n_splits=K, shuffle=True, random_state=RANDOM_STATE + i)

    for fold, (train_index, test_index) in tqdm(enumerate(kf.split(X, y)), desc="Fold", total=K):
        # train
        X_train, X_test = X[train_index], X[test_index]
        y_train, y_test = y[train_index], y[test_index]
        MODEL.fit(X_train, y_train)

        # predict
        y_pred = MODEL.predict(X_test)
        y_prob = MODEL.predict_proba(X_test)

        # store train data
        model_training_data_path = f'{MODEL_OUTPUT_FOLDER}/models/{i}_{fold}.pkl'
        with open(model_training_data_path, "wb") as f:
            pickle.dump({
                "model": MODEL,
                "X_train": X_train,
                "y_train": y_train,
                "X_test": X_test,
                "y_test": y_test,
            }, f)

        # evaluate
        metrics.evaluate(y_test, y_prob, y_pred, model_training_data_path=model_training_data_path)

        # display current results
        print('Label ranking loss: ', metrics.current('label_ranking_loss'))
        print('F1 Weighted: ', metrics.current('f1_score__weighted'))
        
metrics.store(f'{MODEL_OUTPUT_FOLDER}/metrics.csv')

Repeats:   0%|          | 0/2 [00:00<?, ?it/s]

Fold:   0%|          | 0/5 [00:00<?, ?it/s]



Label ranking loss:  0.026175160353849754
F1 Weighted:  0.7896648213123029




Label ranking loss:  0.026880317620413856
F1 Weighted:  0.7938960405467362




Label ranking loss:  0.028065997086972935
F1 Weighted:  0.781986985310107




Label ranking loss:  0.029446930962815775
F1 Weighted:  0.7722656806742612




Label ranking loss:  0.02794301886284499
F1 Weighted:  0.7844468723459775


Fold:   0%|          | 0/5 [00:00<?, ?it/s]



Label ranking loss:  0.02810112942551758
F1 Weighted:  0.7850888587863772




Label ranking loss:  0.030332323709172392
F1 Weighted:  0.7757286839576808




Label ranking loss:  0.02748394843916167
F1 Weighted:  0.7768668306226526




Label ranking loss:  0.02694477230083883
F1 Weighted:  0.7878886629074949




Label ranking loss:  0.026739617263915254
F1 Weighted:  0.7886796425034566


In [35]:
metrics.results.describe()

Unnamed: 0,accuracy_score,log_loss,hamming_loss,f1_score__micro,f1_score__macro,f1_score__weighted,f1_score__samples,precision_score__micro,precision_score__macro,precision_score__weighted,...,recall_score__macro,recall_score__weighted,recall_score__samples,jaccard_score__micro,jaccard_score__macro,jaccard_score__weighted,jaccard_score__samples,roc_auc_score,label_ranking_loss,coverage_error
count,10.0,10.0,10.0,10.0,10.0,10.0,10.0,10.0,10.0,10.0,...,10.0,10.0,10.0,10.0,10.0,10.0,10.0,0.0,10.0,10.0
mean,0.057663,959.839086,0.078278,0.829298,0.476334,0.783651,0.824041,0.909509,0.760452,0.902192,...,0.410009,0.762103,0.773223,0.708403,0.385049,0.700764,0.726037,,0.027811,120.599675
std,0.011842,20.212051,0.002141,0.00489,0.011774,0.006898,0.005388,0.004758,0.014127,0.005733,...,0.010948,0.006019,0.006131,0.007126,0.011182,0.007755,0.007437,,0.00128,1.183363
min,0.039344,929.760977,0.075053,0.820929,0.452812,0.772266,0.814058,0.904351,0.73891,0.892998,...,0.387391,0.751304,0.761102,0.69625,0.362855,0.687403,0.712257,,0.026175,118.166939
25%,0.04959,943.656402,0.077647,0.82616,0.468639,0.778147,0.821719,0.904903,0.748556,0.89842,...,0.403359,0.758132,0.770066,0.703812,0.377279,0.696005,0.722166,,0.026896,120.045104
50%,0.05892,959.005387,0.077832,0.829837,0.47837,0.784768,0.823454,0.90873,0.760963,0.902851,...,0.413239,0.763476,0.77416,0.709164,0.388071,0.701085,0.726125,,0.027713,120.728689
75%,0.065083,971.252413,0.078712,0.83305,0.485408,0.788482,0.828309,0.913981,0.772056,0.907411,...,0.417883,0.765063,0.776547,0.71387,0.392936,0.70674,0.730787,,0.028092,121.166794
max,0.07541,989.254358,0.081837,0.836302,0.492345,0.793896,0.831907,0.916289,0.777547,0.909168,...,0.424183,0.771812,0.783405,0.718659,0.39982,0.711735,0.737755,,0.030332,122.618033


### Train With Unseen InChI Keys

In [36]:
def split_dataset(X, y, test_inchi_keys=[]):
    # get index from merged_df
    test_index = merged_df[merged_df['inchi_key'].isin(test_inchi_keys)].index
    train_index = merged_df[~merged_df['inchi_key'].isin(test_inchi_keys)].index

    # split X and y
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]

    return X_train, X_test, y_train, y_test

In [37]:
all_inchi_keys = list(merged_df['inchi_key'].unique())
shuffle(all_inchi_keys)

In [38]:
hidden_inchi_keys = 10

REPEATS = 1
K = ceil(len(all_inchi_keys) / hidden_inchi_keys)
metrics = Metrics(METRICS, METRIC_NAMES, REPEATS, K)

for i in tqdm(range(REPEATS), desc="Repeats"):
    # Reshuffle
    shuffle(all_inchi_keys)

    for end_i in tqdm(range(hidden_inchi_keys, len(all_inchi_keys), hidden_inchi_keys), desc="Fold", total=K):
        start_i = end_i - hidden_inchi_keys
        if end_i + hidden_inchi_keys > len(all_inchi_keys):
            end_i = len(all_inchi_keys)

        # train
        test_inchi_keys = all_inchi_keys[start_i:end_i]
        X_train, X_test, y_train, y_test = split_dataset(X, y, test_inchi_keys)

        MODEL.fit(X_train, y_train)

        # predict
        y_pred = MODEL.predict(X_test)
        y_prob = MODEL.predict_proba(X_test)

        # store train data
        model_training_data_path = f'{MODEL_OUTPUT_FOLDER}/unseen_inchi_keys_models/{start_i}_{end_i}.pkl'
        with open(model_training_data_path, "wb") as f:
            pickle.dump({
                "model": MODEL,
                "X_train": X_train,
                "y_train": y_train,
                "X_test": X_test,
                "y_test": y_test,
            }, f)

        # evaluate
        metrics.evaluate(y_test, y_prob, y_pred, model_training_data_path=model_training_data_path)

        # display current results
        print('Label ranking loss: ', metrics.current('label_ranking_loss'))
        print('F1 Weighted: ', metrics.current('f1_score__weighted'))

metrics.store(f'{MODEL_OUTPUT_FOLDER}/unseen_inchi_keys_metrics.csv')

Repeats:   0%|          | 0/1 [00:00<?, ?it/s]

Fold:   0%|          | 0/11 [00:00<?, ?it/s]



Label ranking loss:  0.10723528356252408
F1 Weighted:  0.6176084993136549


In [None]:
metrics.results.describe()

Unnamed: 0,accuracy_score,log_loss,hamming_loss,f1_score__micro,f1_score__macro,f1_score__weighted,f1_score__samples,precision_score__micro,precision_score__macro,precision_score__weighted,...,recall_score__macro,recall_score__weighted,recall_score__samples,jaccard_score__micro,jaccard_score__macro,jaccard_score__weighted,jaccard_score__samples,roc_auc_score,label_ranking_loss,coverage_error
count,6.0,6.0,6.0,6.0,6.0,6.0,6.0,6.0,6.0,6.0,...,6.0,6.0,6.0,6.0,6.0,6.0,6.0,0.0,6.0,6.0
mean,0.033491,426.598018,0.074979,0.806446,0.247045,0.755934,0.812042,0.870754,0.304216,0.8103,...,0.237586,0.752615,0.7716,0.677001,0.211662,0.697851,0.703763,,0.043242,72.565061
std,0.031992,95.871002,0.016596,0.036732,0.013885,0.052316,0.034382,0.023611,0.037554,0.045172,...,0.014607,0.057674,0.053199,0.052005,0.013495,0.054956,0.047988,,0.013545,8.173327
min,0.0,307.650957,0.055945,0.767443,0.224523,0.69672,0.772295,0.842862,0.256751,0.750484,...,0.222172,0.684381,0.699187,0.622643,0.193877,0.638442,0.645287,,0.023639,58.74031
25%,0.006715,352.63903,0.061487,0.777293,0.241947,0.713578,0.78862,0.851319,0.277501,0.785116,...,0.225016,0.712182,0.742136,0.635906,0.202705,0.652776,0.674843,,0.037868,70.57617
50%,0.031642,425.280883,0.07495,0.800676,0.248618,0.751705,0.804197,0.87178,0.302393,0.807925,...,0.236053,0.742938,0.763655,0.667607,0.210791,0.690422,0.691723,,0.040428,74.189656
75%,0.062933,504.381293,0.08767,0.837234,0.254165,0.798983,0.839447,0.888815,0.334139,0.831494,...,0.250278,0.793886,0.814371,0.720572,0.222537,0.744007,0.742109,,0.054106,75.061716
max,0.065891,542.202213,0.09512,0.85096,0.264776,0.819862,0.856939,0.899107,0.349833,0.878601,...,0.254885,0.832636,0.836861,0.740583,0.22805,0.765723,0.765722,,0.059277,83.546667
