# This is Step 3 in the Pipeline - Training ML Prediction Model
With this notebook we can train various ML classifiers to tackle multi-lable prediction problem. We are predicting Spec2Vec embeddings from molecular fingerprints.

### Imports

In [1]:
from sklearn.metrics import accuracy_score, f1_score, log_loss, precision_score, recall_score, jaccard_score, roc_auc_score, hamming_loss, label_ranking_loss, coverage_error
from sklearn.model_selection import KFold
from sklearn.multiclass import OneVsRestClassifier
from sklearn.multioutput import  ClassifierChain
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.dummy import DummyClassifier
from sklearn.linear_model import LogisticRegression
from mass_spectra.similarity_voting import SimilarityVoting
from wrappers.nn import NN
from wrappers.catboost import CatBoost
from tqdm.notebook import tqdm
import numpy as np
import pandas as pd
import pickle
from random import shuffle, seed
from math import ceil
import os
from torch.nn import BCEWithLogitsLoss

### Parameters

In [2]:
RANDOM_STATE = 27082023
seed(RANDOM_STATE)
np.random.seed(RANDOM_STATE)

# path to merged fingerprint and embedding data (fingerprint columns should be prefixed with 'fingerprint_' and embedding columns should be prefixed with 'embedding_').
MERGED_PATH = './source/embedding/all_positive_tms_maccs/merged.csv'
MODEL_OUTPUT_FOLDER = "./source/model/all_positive_tms_maccs/"

In [3]:
assert os.path.isfile(MERGED_PATH)
assert os.path.isdir(MODEL_OUTPUT_FOLDER)
assert MERGED_PATH.endswith('.csv')

In [4]:
ESTIMATOR = None

In [5]:
MODEL = CatBoost(iterations=100, allow_const_label=True)

In [6]:
MODEL_CLASS = MODEL.__class__.__name__
ESTIMATOR_CLASS = ESTIMATOR.__class__.__name__ if ESTIMATOR is not None else 'Multioutput'
MODEL_OUTPUT_FOLDER = f'{MODEL_OUTPUT_FOLDER}{MODEL_CLASS}_{ESTIMATOR_CLASS}'
os.makedirs(f'{MODEL_OUTPUT_FOLDER}/models', exist_ok=False)
os.makedirs(f'{MODEL_OUTPUT_FOLDER}/unseen_inchi_keys_models', exist_ok=False)

### Metrics Definition
Creates metrics which can be called with (y_true, y_prob, y_pred) for easier use. It also creates multiple combinations of metrics for different averaging methods.

In [7]:
Y_PRED_SCORES = [accuracy_score, log_loss, hamming_loss] # input y predictions and y true
Y_PRED_SCORES_WITH_AVERAGING = [f1_score, precision_score, recall_score, jaccard_score] # input y predictions and y true and use one of the following: "micro", "macro", "weighted", "samples"
Y_PROB_SCORES = [roc_auc_score, label_ranking_loss, coverage_error] # input y probabilities and y true

In [8]:
METRICS = []
METRIC_NAMES = []
for metric in Y_PRED_SCORES:
    METRICS.append(lambda y_true, y_prob, y_pred, metric=metric: metric(y_true, y_pred))
    METRIC_NAMES.append(metric.__name__)
for metric in Y_PRED_SCORES_WITH_AVERAGING:
    for average in ["micro", "macro", "weighted", "samples"]:
        zero_division = 0 if metric.__name__ == "jaccard_score" else np.nan
        METRICS.append(lambda y_true, y_prob, y_pred, metric=metric, average=average: metric(y_true, y_pred, average=average, zero_division=zero_division))
        METRIC_NAMES.append(metric.__name__ + "__" + average)
for metric in Y_PROB_SCORES:
    METRICS.append(lambda y_true, y_prob, y_pred, metric=metric: metric(y_true, y_prob))
    METRIC_NAMES.append(metric.__name__)

In [9]:
class Metrics:
    def __init__(self, metrics, metric_names, repeats=2, folds=5):
        self.metrics = metrics
        self.metric_names = metric_names
        
        self.repeats = repeats
        self.folds = folds
        self.i = 0

        self.results = pd.DataFrame(columns=['repeat', 'fold', 'model_training_data_path'] + self.metric_names)
    
    def evaluate(self, y_true, y_prob, y_pred, model_training_data_path=None):
        entry = {
            'repeat': self.i // self.folds,
            'fold': self.i % self.folds,
            'model_training_data_path': model_training_data_path
        }
        for metric, metric_name in zip(self.metrics, self.metric_names):
            try:
                entry[metric_name] = metric(y_true, y_prob, y_pred)
            except ValueError as e:
                print("Warning: ", e)
                entry[metric_name] = np.nan
        
        self.results = pd.concat([self.results, pd.DataFrame(entry, index=[0])], ignore_index=True)
        self.i += 1
    
    def store(self, filename):
        self.results.to_csv(filename, index=False)

    def current(self, metric_name):
        return self.results[metric_name].iloc[-1]

### Load Data

In [10]:
merged_df = pd.read_csv(MERGED_PATH)
merged_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 3025 entries, 0 to 3024
Columns: 467 entries, inchi_key to embedding_299
dtypes: float64(466), object(1)
memory usage: 10.8+ MB


In [11]:
f'Number of NaNs: {merged_df.isna().sum().sum()}' # should be 0

'Number of NaNs: 0'

In [12]:
X = merged_df.filter(regex='^embedding_')
y = merged_df.filter(regex='^fingerprint_')
X.shape, y.shape

((3025, 300), (3025, 166))

In [13]:
X = X.to_numpy()
y = y.to_numpy()

### Train- K-fold Cross Validation

In [14]:
REPEATS = 2
K = 5
metrics = Metrics(METRICS, METRIC_NAMES, REPEATS, K)

for i in tqdm(range(REPEATS), desc="Repeats"):
    kf = KFold(n_splits=K, shuffle=True, random_state=RANDOM_STATE + i)

    for fold, (train_index, test_index) in tqdm(enumerate(kf.split(X, y)), desc="Fold", total=K):
        # train
        X_train, X_test = X[train_index], X[test_index]
        y_train, y_test = y[train_index], y[test_index]
        MODEL.fit(X_train, y_train)

        # predict
        y_pred = MODEL.predict(X_test)
        y_prob = MODEL.predict_proba(X_test)

        # store train data
        model_training_data_path = f'{MODEL_OUTPUT_FOLDER}/models/{i}_{fold}.pkl'
        with open(model_training_data_path, "wb") as f:
            pickle.dump({
                "model": MODEL,
                "X_train": X_train,
                "y_train": y_train,
                "X_test": X_test,
                "y_test": y_test,
            }, f)

        # evaluate
        metrics.evaluate(y_test, y_prob, y_pred, model_training_data_path=model_training_data_path)

        # display current results
        print('Label ranking loss: ', metrics.current('label_ranking_loss'))
        print('F1 Weighted: ', metrics.current('f1_score__weighted'))
        
metrics.store(f'{MODEL_OUTPUT_FOLDER}/metrics.csv')

Repeats:   0%|          | 0/2 [00:00<?, ?it/s]

Fold:   0%|          | 0/5 [00:00<?, ?it/s]

Learning rate set to 0.124115
0:	learn: 0.5119835	total: 6.76s	remaining: 11m 9s
1:	learn: 0.3996997	total: 12.9s	remaining: 10m 30s
2:	learn: 0.3281198	total: 19.3s	remaining: 10m 25s
3:	learn: 0.2819886	total: 25.9s	remaining: 10m 20s
4:	learn: 0.2467844	total: 32.3s	remaining: 10m 13s
5:	learn: 0.2220733	total: 38.4s	remaining: 10m 1s
6:	learn: 0.2058443	total: 44.7s	remaining: 9m 53s
7:	learn: 0.1935400	total: 51.5s	remaining: 9m 52s
8:	learn: 0.1828827	total: 58.4s	remaining: 9m 50s
9:	learn: 0.1733197	total: 1m 7s	remaining: 10m 5s
10:	learn: 0.1666126	total: 1m 17s	remaining: 10m 27s
11:	learn: 0.1605774	total: 1m 25s	remaining: 10m 28s
12:	learn: 0.1558873	total: 1m 36s	remaining: 10m 42s
13:	learn: 0.1516093	total: 1m 44s	remaining: 10m 44s
14:	learn: 0.1468923	total: 1m 52s	remaining: 10m 36s
15:	learn: 0.1437709	total: 1m 59s	remaining: 10m 28s
16:	learn: 0.1399004	total: 2m 6s	remaining: 10m 19s
17:	learn: 0.1365069	total: 2m 13s	remaining: 10m 8s
18:	learn: 0.1328161	total



Label ranking loss:  0.009107227832156792
F1 Weighted:  0.8649028425403413
Learning rate set to 0.124115
0:	learn: 0.5131097	total: 6.91s	remaining: 11m 24s
1:	learn: 0.4046390	total: 13.7s	remaining: 11m 11s
2:	learn: 0.3324120	total: 21s	remaining: 11m 19s
3:	learn: 0.2863485	total: 28.6s	remaining: 11m 25s
4:	learn: 0.2525065	total: 36.5s	remaining: 11m 33s
5:	learn: 0.2268130	total: 44.2s	remaining: 11m 32s
6:	learn: 0.2098868	total: 51.8s	remaining: 11m 27s
7:	learn: 0.1975591	total: 59.1s	remaining: 11m 20s
8:	learn: 0.1869558	total: 1m 6s	remaining: 11m 13s
9:	learn: 0.1775633	total: 1m 13s	remaining: 11m 2s
10:	learn: 0.1700447	total: 1m 20s	remaining: 10m 55s
11:	learn: 0.1644592	total: 1m 27s	remaining: 10m 45s
12:	learn: 0.1596181	total: 1m 34s	remaining: 10m 34s
13:	learn: 0.1550791	total: 1m 41s	remaining: 10m 24s
14:	learn: 0.1511314	total: 1m 48s	remaining: 10m 14s
15:	learn: 0.1474118	total: 1m 55s	remaining: 10m 5s
16:	learn: 0.1437536	total: 2m 2s	remaining: 9m 56s
17



Label ranking loss:  0.007989597192506437
F1 Weighted:  0.885984749680835
Learning rate set to 0.124115
0:	learn: 0.5123744	total: 7.04s	remaining: 11m 37s
1:	learn: 0.4015340	total: 14.6s	remaining: 11m 56s
2:	learn: 0.3294900	total: 22.1s	remaining: 11m 55s
3:	learn: 0.2829959	total: 29.6s	remaining: 11m 50s
4:	learn: 0.2487727	total: 36.5s	remaining: 11m 33s
5:	learn: 0.2230835	total: 43.4s	remaining: 11m 20s
6:	learn: 0.2066590	total: 50.3s	remaining: 11m 7s
7:	learn: 0.1945422	total: 57.4s	remaining: 11m
8:	learn: 0.1845809	total: 1m 5s	remaining: 10m 57s
9:	learn: 0.1749276	total: 1m 12s	remaining: 10m 50s
10:	learn: 0.1674732	total: 1m 18s	remaining: 10m 37s
11:	learn: 0.1616008	total: 1m 25s	remaining: 10m 26s
12:	learn: 0.1566391	total: 1m 32s	remaining: 10m 15s
13:	learn: 0.1526669	total: 1m 38s	remaining: 10m 5s
14:	learn: 0.1483820	total: 1m 45s	remaining: 9m 55s
15:	learn: 0.1447996	total: 1m 51s	remaining: 9m 47s
16:	learn: 0.1409475	total: 1m 59s	remaining: 9m 44s
17:	le



Label ranking loss:  0.009291293454118643
F1 Weighted:  0.8701072455608838
Learning rate set to 0.124115
0:	learn: 0.5115950	total: 6.67s	remaining: 11m
1:	learn: 0.4004870	total: 13.3s	remaining: 10m 50s
2:	learn: 0.3289179	total: 19.9s	remaining: 10m 44s
3:	learn: 0.2824204	total: 26.7s	remaining: 10m 41s
4:	learn: 0.2477224	total: 33.7s	remaining: 10m 39s
5:	learn: 0.2220317	total: 40.4s	remaining: 10m 32s
6:	learn: 0.2051360	total: 47s	remaining: 10m 24s
7:	learn: 0.1930851	total: 53.8s	remaining: 10m 18s
8:	learn: 0.1825066	total: 1m	remaining: 10m 15s
9:	learn: 0.1730303	total: 1m 8s	remaining: 10m 12s
10:	learn: 0.1658904	total: 1m 14s	remaining: 10m 6s
11:	learn: 0.1601777	total: 1m 21s	remaining: 9m 59s
12:	learn: 0.1554621	total: 1m 28s	remaining: 9m 51s
13:	learn: 0.1508236	total: 1m 36s	remaining: 9m 52s
14:	learn: 0.1468019	total: 1m 43s	remaining: 9m 48s
15:	learn: 0.1435831	total: 1m 50s	remaining: 9m 42s
16:	learn: 0.1395134	total: 1m 58s	remaining: 9m 38s
17:	learn: 0.



Label ranking loss:  0.009640741090173028
F1 Weighted:  0.8648791851841343
Learning rate set to 0.124115
0:	learn: 0.5131423	total: 8.3s	remaining: 13m 41s
1:	learn: 0.4025137	total: 16.2s	remaining: 13m 12s
2:	learn: 0.3307471	total: 23.1s	remaining: 12m 25s
3:	learn: 0.2846802	total: 29.8s	remaining: 11m 56s
4:	learn: 0.2507740	total: 36.4s	remaining: 11m 31s
5:	learn: 0.2258420	total: 43.9s	remaining: 11m 27s
6:	learn: 0.2091139	total: 50.7s	remaining: 11m 14s
7:	learn: 0.1962521	total: 57.6s	remaining: 11m 2s
8:	learn: 0.1865607	total: 1m 4s	remaining: 10m 50s
9:	learn: 0.1775282	total: 1m 12s	remaining: 10m 52s
10:	learn: 0.1700900	total: 1m 19s	remaining: 10m 41s
11:	learn: 0.1639991	total: 1m 26s	remaining: 10m 31s
12:	learn: 0.1588539	total: 1m 32s	remaining: 10m 22s
13:	learn: 0.1548327	total: 1m 40s	remaining: 10m 15s
14:	learn: 0.1506815	total: 1m 46s	remaining: 10m 6s
15:	learn: 0.1467186	total: 1m 53s	remaining: 9m 58s
16:	learn: 0.1418420	total: 2m 1s	remaining: 9m 50s
17



Label ranking loss:  0.00802343408325835
F1 Weighted:  0.885696415034728


Fold:   0%|          | 0/5 [00:00<?, ?it/s]

Learning rate set to 0.124115
0:	learn: 0.5118227	total: 7.09s	remaining: 11m 42s
1:	learn: 0.3963365	total: 14s	remaining: 11m 27s
2:	learn: 0.3261911	total: 21s	remaining: 11m 18s
3:	learn: 0.2807094	total: 28.1s	remaining: 11m 14s
4:	learn: 0.2466846	total: 35s	remaining: 11m 5s
5:	learn: 0.2216719	total: 42s	remaining: 10m 57s
6:	learn: 0.2049878	total: 49s	remaining: 10m 51s
7:	learn: 0.1930698	total: 56s	remaining: 10m 43s
8:	learn: 0.1823530	total: 1m 3s	remaining: 10m 37s
9:	learn: 0.1725193	total: 1m 10s	remaining: 10m 30s
10:	learn: 0.1652357	total: 1m 17s	remaining: 10m 23s
11:	learn: 0.1595661	total: 1m 23s	remaining: 10m 15s
12:	learn: 0.1548811	total: 1m 30s	remaining: 10m 8s
13:	learn: 0.1504276	total: 1m 37s	remaining: 10m 1s
14:	learn: 0.1462721	total: 1m 45s	remaining: 10m
15:	learn: 0.1426459	total: 1m 53s	remaining: 9m 56s
16:	learn: 0.1389903	total: 2m 1s	remaining: 9m 53s
17:	learn: 0.1354725	total: 2m 9s	remaining: 9m 49s
18:	learn: 0.1326030	total: 2m 16s	remain



Label ranking loss:  0.010312487189982351
F1 Weighted:  0.8653728766213252
Learning rate set to 0.124115
0:	learn: 0.5134153	total: 7.26s	remaining: 11m 58s
1:	learn: 0.4031706	total: 14.4s	remaining: 11m 44s
2:	learn: 0.3317071	total: 21.6s	remaining: 11m 38s
3:	learn: 0.2851470	total: 29.1s	remaining: 11m 37s
4:	learn: 0.2515907	total: 36s	remaining: 11m 23s
5:	learn: 0.2255442	total: 42.9s	remaining: 11m 11s
6:	learn: 0.2091624	total: 49.9s	remaining: 11m 2s
7:	learn: 0.1963311	total: 56.8s	remaining: 10m 53s
8:	learn: 0.1852613	total: 1m 3s	remaining: 10m 45s
9:	learn: 0.1753920	total: 1m 10s	remaining: 10m 37s
10:	learn: 0.1678123	total: 1m 17s	remaining: 10m 30s
11:	learn: 0.1620226	total: 1m 24s	remaining: 10m 21s
12:	learn: 0.1575459	total: 1m 31s	remaining: 10m 14s
13:	learn: 0.1533327	total: 1m 38s	remaining: 10m 6s
14:	learn: 0.1492133	total: 1m 45s	remaining: 9m 59s
15:	learn: 0.1458765	total: 1m 52s	remaining: 9m 51s
16:	learn: 0.1424874	total: 1m 59s	remaining: 9m 43s
17:



Label ranking loss:  0.0077330557569079
F1 Weighted:  0.8878411787723414
Learning rate set to 0.124115
0:	learn: 0.5134745	total: 6.92s	remaining: 11m 25s
1:	learn: 0.4025342	total: 13.8s	remaining: 11m 17s
2:	learn: 0.3309599	total: 20.8s	remaining: 11m 11s
3:	learn: 0.2848743	total: 27.8s	remaining: 11m 7s
4:	learn: 0.2500226	total: 34.8s	remaining: 11m
5:	learn: 0.2246165	total: 41.6s	remaining: 10m 51s
6:	learn: 0.2082770	total: 48.5s	remaining: 10m 44s
7:	learn: 0.1962178	total: 55.5s	remaining: 10m 38s
8:	learn: 0.1857203	total: 1m 2s	remaining: 10m 31s
9:	learn: 0.1767835	total: 1m 9s	remaining: 10m 24s
10:	learn: 0.1693410	total: 1m 16s	remaining: 10m 18s
11:	learn: 0.1637000	total: 1m 23s	remaining: 10m 10s
12:	learn: 0.1589201	total: 1m 30s	remaining: 10m 4s
13:	learn: 0.1545948	total: 1m 37s	remaining: 9m 57s
14:	learn: 0.1509185	total: 1m 44s	remaining: 9m 50s
15:	learn: 0.1482628	total: 1m 51s	remaining: 9m 43s
16:	learn: 0.1437587	total: 1m 58s	remaining: 9m 36s
17:	learn



Label ranking loss:  0.007716336452823894
F1 Weighted:  0.8880707819100604
Learning rate set to 0.124115
0:	learn: 0.5127000	total: 6.89s	remaining: 11m 22s
1:	learn: 0.4027229	total: 13.8s	remaining: 11m 18s
2:	learn: 0.3295689	total: 20.8s	remaining: 11m 13s
3:	learn: 0.2826225	total: 27.7s	remaining: 11m 5s
4:	learn: 0.2492463	total: 34.7s	remaining: 10m 58s
5:	learn: 0.2242387	total: 41.5s	remaining: 10m 50s
6:	learn: 0.2081083	total: 48.5s	remaining: 10m 43s
7:	learn: 0.1956405	total: 55.5s	remaining: 10m 38s
8:	learn: 0.1858071	total: 1m 2s	remaining: 10m 31s
9:	learn: 0.1757784	total: 1m 9s	remaining: 10m 26s
10:	learn: 0.1687124	total: 1m 16s	remaining: 10m 18s
11:	learn: 0.1626991	total: 1m 23s	remaining: 10m 11s
12:	learn: 0.1575309	total: 1m 30s	remaining: 10m 5s
13:	learn: 0.1529525	total: 1m 37s	remaining: 9m 58s
14:	learn: 0.1486449	total: 1m 44s	remaining: 9m 51s
15:	learn: 0.1453339	total: 1m 51s	remaining: 9m 44s
16:	learn: 0.1410365	total: 1m 58s	remaining: 9m 37s
17:



Label ranking loss:  0.009009547001448164
F1 Weighted:  0.8679273060920888
Learning rate set to 0.124115
0:	learn: 0.5128421	total: 6.62s	remaining: 10m 55s
1:	learn: 0.4013589	total: 13.5s	remaining: 11m 2s
2:	learn: 0.3296065	total: 20.6s	remaining: 11m 4s
3:	learn: 0.2832820	total: 27.4s	remaining: 10m 58s
4:	learn: 0.2489850	total: 34s	remaining: 10m 46s
5:	learn: 0.2233727	total: 40.7s	remaining: 10m 38s
6:	learn: 0.2068102	total: 47.4s	remaining: 10m 29s
7:	learn: 0.1947329	total: 54.2s	remaining: 10m 22s
8:	learn: 0.1838874	total: 1m 1s	remaining: 10m 19s
9:	learn: 0.1751730	total: 1m 8s	remaining: 10m 13s
10:	learn: 0.1677149	total: 1m 15s	remaining: 10m 7s
11:	learn: 0.1616625	total: 1m 21s	remaining: 10m
12:	learn: 0.1568865	total: 1m 28s	remaining: 9m 54s
13:	learn: 0.1525818	total: 1m 35s	remaining: 9m 46s
14:	learn: 0.1486434	total: 1m 43s	remaining: 9m 46s
15:	learn: 0.1448497	total: 1m 50s	remaining: 9m 40s
16:	learn: 0.1410470	total: 1m 57s	remaining: 9m 35s
17:	learn: 



Label ranking loss:  0.008721217112792551
F1 Weighted:  0.8736348278195384


In [15]:
metrics.results.describe()

Unnamed: 0,accuracy_score,log_loss,hamming_loss,f1_score__micro,f1_score__macro,f1_score__weighted,f1_score__samples,precision_score__micro,precision_score__macro,precision_score__weighted,...,recall_score__macro,recall_score__weighted,recall_score__samples,jaccard_score__micro,jaccard_score__macro,jaccard_score__weighted,jaccard_score__samples,roc_auc_score,label_ranking_loss,coverage_error
count,10.0,10.0,10.0,10.0,10.0,10.0,10.0,10.0,10.0,10.0,...,10.0,10.0,10.0,10.0,10.0,10.0,10.0,0.0,10.0,10.0
mean,0.184463,305.358345,0.040671,0.897284,0.457059,0.875442,0.902902,0.948659,0.676562,0.944401,...,0.392693,0.851247,0.87075,0.81377,0.375821,0.816391,0.83832,,0.008754,42.268926
std,0.014393,17.704378,0.002974,0.007059,0.018273,0.010229,0.005904,0.003101,0.020782,0.004278,...,0.015736,0.011615,0.008499,0.011624,0.015243,0.011865,0.008376,,0.000878,0.714485
min,0.163636,280.698852,0.036642,0.889022,0.436872,0.864879,0.894734,0.944679,0.640796,0.937597,...,0.374136,0.837489,0.857865,0.800215,0.356777,0.803718,0.82541,,0.007716,40.895868
25%,0.171074,287.591319,0.037992,0.891701,0.442535,0.866011,0.898353,0.946205,0.662173,0.942145,...,0.380577,0.841476,0.865153,0.804566,0.365026,0.806265,0.832112,,0.007998,41.816529
50%,0.187603,312.447564,0.041681,0.895067,0.45205,0.871871,0.901274,0.948329,0.6819,0.943894,...,0.388188,0.846964,0.868051,0.810072,0.371722,0.812583,0.836903,,0.008865,42.522314
75%,0.192149,317.704652,0.04305,0.904396,0.471277,0.885913,0.908528,0.951029,0.688169,0.94824,...,0.404864,0.863637,0.87868,0.825477,0.38777,0.828416,0.845351,,0.009245,42.705785
max,0.204959,327.60249,0.044339,0.906218,0.489024,0.888071,0.91059,0.953146,0.711774,0.949795,...,0.42216,0.866728,0.882025,0.828518,0.402686,0.831708,0.849862,,0.010312,43.219835


### Train With Unseen InChI Keys

In [16]:
def split_dataset(X, y, test_inchi_keys=[]):
    # get index from merged_df
    test_index = merged_df[merged_df['inchi_key'].isin(test_inchi_keys)].index
    train_index = merged_df[~merged_df['inchi_key'].isin(test_inchi_keys)].index

    # split X and y
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]

    return X_train, X_test, y_train, y_test

In [17]:
all_inchi_keys = list(merged_df['inchi_key'].unique())
shuffle(all_inchi_keys)

In [18]:
hidden_inchi_keys = 10

REPEATS = 1
K = ceil(len(all_inchi_keys) / hidden_inchi_keys)
metrics = Metrics(METRICS, METRIC_NAMES, REPEATS, K)

for i in tqdm(range(REPEATS), desc="Repeats"):
    # Reshuffle
    shuffle(all_inchi_keys)

    for end_i in tqdm(range(hidden_inchi_keys, len(all_inchi_keys), hidden_inchi_keys), desc="Fold", total=K):
        start_i = end_i - hidden_inchi_keys
        if end_i + hidden_inchi_keys > len(all_inchi_keys):
            end_i = len(all_inchi_keys)

        # train
        test_inchi_keys = all_inchi_keys[start_i:end_i]
        X_train, X_test, y_train, y_test = split_dataset(X, y, test_inchi_keys)

        MODEL.fit(X_train, y_train)

        # predict
        y_pred = MODEL.predict(X_test)
        y_prob = MODEL.predict_proba(X_test)

        # store train data
        model_training_data_path = f'{MODEL_OUTPUT_FOLDER}/unseen_inchi_keys_models/{start_i}_{end_i}.pkl'
        with open(model_training_data_path, "wb") as f:
            pickle.dump({
                "model": MODEL,
                "X_train": X_train,
                "y_train": y_train,
                "X_test": X_test,
                "y_test": y_test,
            }, f)

        # evaluate
        metrics.evaluate(y_test, y_prob, y_pred, model_training_data_path=model_training_data_path)

        # display current results
        print('Label ranking loss: ', metrics.current('label_ranking_loss'))
        print('F1 Weighted: ', metrics.current('f1_score__weighted'))

metrics.store(f'{MODEL_OUTPUT_FOLDER}/unseen_inchi_keys_metrics.csv')

Repeats:   0%|          | 0/1 [00:00<?, ?it/s]

Fold:   0%|          | 0/11 [00:00<?, ?it/s]

Learning rate set to 0.130527
0:	learn: 0.5047740	total: 6.77s	remaining: 11m 10s
1:	learn: 0.3913612	total: 13.5s	remaining: 11m 1s
2:	learn: 0.3206733	total: 20.3s	remaining: 10m 55s
3:	learn: 0.2751381	total: 27.1s	remaining: 10m 50s
4:	learn: 0.2428783	total: 33.8s	remaining: 10m 42s
5:	learn: 0.2190464	total: 40.5s	remaining: 10m 35s
6:	learn: 0.2038517	total: 47.3s	remaining: 10m 28s
7:	learn: 0.1918911	total: 54.1s	remaining: 10m 22s
8:	learn: 0.1817520	total: 1m	remaining: 10m 15s
9:	learn: 0.1716021	total: 1m 7s	remaining: 10m 9s
10:	learn: 0.1646468	total: 1m 14s	remaining: 10m 3s
11:	learn: 0.1591259	total: 1m 21s	remaining: 9m 56s
12:	learn: 0.1540795	total: 1m 28s	remaining: 9m 49s
13:	learn: 0.1506712	total: 1m 34s	remaining: 9m 42s
14:	learn: 0.1461850	total: 1m 41s	remaining: 9m 35s
15:	learn: 0.1431173	total: 1m 48s	remaining: 9m 28s
16:	learn: 0.1390475	total: 1m 55s	remaining: 9m 21s
17:	learn: 0.1358993	total: 2m 1s	remaining: 9m 15s
18:	learn: 0.1328775	total: 2m 8



Label ranking loss:  0.037511012344650756
F1 Weighted:  0.8198620511094892
Learning rate set to 0.130424
0:	learn: 0.5029493	total: 7.45s	remaining: 12m 17s
1:	learn: 0.3860843	total: 14.4s	remaining: 11m 44s
2:	learn: 0.3157405	total: 21.4s	remaining: 11m 32s
3:	learn: 0.2710956	total: 28.6s	remaining: 11m 25s
4:	learn: 0.2389025	total: 35.5s	remaining: 11m 14s
5:	learn: 0.2150085	total: 43.9s	remaining: 11m 27s
6:	learn: 0.1997160	total: 54.1s	remaining: 11m 58s
7:	learn: 0.1882768	total: 1m 1s	remaining: 11m 43s
8:	learn: 0.1780475	total: 1m 8s	remaining: 11m 28s
9:	learn: 0.1690485	total: 1m 15s	remaining: 11m 15s
10:	learn: 0.1618025	total: 1m 21s	remaining: 11m 2s
11:	learn: 0.1563518	total: 1m 28s	remaining: 10m 49s
12:	learn: 0.1512829	total: 1m 35s	remaining: 10m 38s
13:	learn: 0.1471149	total: 1m 42s	remaining: 10m 27s
14:	learn: 0.1429220	total: 1m 49s	remaining: 10m 19s
15:	learn: 0.1391029	total: 1m 58s	remaining: 10m 23s
16:	learn: 0.1351726	total: 2m 5s	remaining: 10m 14



Label ranking loss:  0.04191598890058593
F1 Weighted:  0.7589765125067598
Learning rate set to 0.130527
0:	learn: 0.5017598	total: 6.67s	remaining: 11m
1:	learn: 0.3882914	total: 13.5s	remaining: 11m 1s
2:	learn: 0.3169918	total: 20.3s	remaining: 10m 54s
3:	learn: 0.2718211	total: 27.1s	remaining: 10m 51s
4:	learn: 0.2363335	total: 33.9s	remaining: 10m 43s
5:	learn: 0.2134148	total: 40.7s	remaining: 10m 37s
6:	learn: 0.1979955	total: 47.5s	remaining: 10m 31s
7:	learn: 0.1865533	total: 54.4s	remaining: 10m 25s
8:	learn: 0.1754390	total: 1m 1s	remaining: 10m 19s
9:	learn: 0.1672395	total: 1m 8s	remaining: 10m 13s
10:	learn: 0.1602774	total: 1m 14s	remaining: 10m 5s
11:	learn: 0.1549389	total: 1m 21s	remaining: 9m 58s
12:	learn: 0.1503809	total: 1m 28s	remaining: 9m 51s
13:	learn: 0.1465031	total: 1m 35s	remaining: 9m 44s
14:	learn: 0.1423939	total: 1m 41s	remaining: 9m 37s
15:	learn: 0.1389147	total: 1m 48s	remaining: 9m 30s
16:	learn: 0.1351711	total: 1m 55s	remaining: 9m 23s
17:	learn:



Label ranking loss:  0.0389392300523989
F1 Weighted:  0.7444327533930131
Learning rate set to 0.130854
0:	learn: 0.4986968	total: 7.34s	remaining: 12m 6s
1:	learn: 0.3837463	total: 14.6s	remaining: 11m 54s
2:	learn: 0.3115726	total: 23.2s	remaining: 12m 29s
3:	learn: 0.2669422	total: 30.8s	remaining: 12m 18s
4:	learn: 0.2333088	total: 38.5s	remaining: 12m 11s
5:	learn: 0.2098788	total: 45.9s	remaining: 11m 59s
6:	learn: 0.1933695	total: 53.5s	remaining: 11m 51s
7:	learn: 0.1824565	total: 1m	remaining: 11m 37s
8:	learn: 0.1723570	total: 1m 7s	remaining: 11m 23s
9:	learn: 0.1626821	total: 1m 14s	remaining: 11m 11s
10:	learn: 0.1555613	total: 1m 21s	remaining: 10m 58s
11:	learn: 0.1505468	total: 1m 28s	remaining: 10m 47s
12:	learn: 0.1459248	total: 1m 35s	remaining: 10m 36s
13:	learn: 0.1423675	total: 1m 42s	remaining: 10m 26s
14:	learn: 0.1384778	total: 1m 48s	remaining: 10m 15s
15:	learn: 0.1354506	total: 1m 55s	remaining: 10m 5s
16:	learn: 0.1314711	total: 2m 2s	remaining: 9m 56s
17:	l



Label ranking loss:  0.05927675347459425
F1 Weighted:  0.6967202747138178
Learning rate set to 0.131423
0:	learn: 0.5035272	total: 6.77s	remaining: 11m 9s
1:	learn: 0.3906330	total: 13.7s	remaining: 11m 10s
2:	learn: 0.3178679	total: 20.6s	remaining: 11m 5s
3:	learn: 0.2736718	total: 27.6s	remaining: 11m 1s
4:	learn: 0.2414570	total: 34.4s	remaining: 10m 53s
5:	learn: 0.2174530	total: 41.2s	remaining: 10m 46s
6:	learn: 0.2022816	total: 48.1s	remaining: 10m 39s
7:	learn: 0.1907593	total: 55s	remaining: 10m 32s
8:	learn: 0.1806926	total: 1m 1s	remaining: 10m 24s
9:	learn: 0.1712317	total: 1m 8s	remaining: 10m 18s
10:	learn: 0.1637407	total: 1m 15s	remaining: 10m 12s
11:	learn: 0.1585130	total: 1m 22s	remaining: 10m 5s
12:	learn: 0.1539832	total: 1m 29s	remaining: 9m 58s
13:	learn: 0.1497491	total: 1m 36s	remaining: 9m 54s
14:	learn: 0.1458219	total: 1m 43s	remaining: 9m 48s
15:	learn: 0.1423578	total: 1m 50s	remaining: 9m 40s
16:	learn: 0.1386032	total: 1m 57s	remaining: 9m 33s
17:	learn



Label ranking loss:  0.023639371044146044
F1 Weighted:  0.8123179974685273
Learning rate set to 0.130568
0:	learn: 0.4966647	total: 6.91s	remaining: 11m 24s
1:	learn: 0.3835503	total: 13.7s	remaining: 11m 11s
2:	learn: 0.3117812	total: 20.5s	remaining: 11m 3s
3:	learn: 0.2663638	total: 27.4s	remaining: 10m 57s
4:	learn: 0.2330950	total: 34.2s	remaining: 10m 49s
5:	learn: 0.2091900	total: 40.9s	remaining: 10m 40s
6:	learn: 0.1932425	total: 47.8s	remaining: 10m 34s
7:	learn: 0.1821584	total: 54.6s	remaining: 10m 27s
8:	learn: 0.1725551	total: 1m 1s	remaining: 10m 20s
9:	learn: 0.1635189	total: 1m 8s	remaining: 10m 14s
10:	learn: 0.1569192	total: 1m 15s	remaining: 10m 6s
11:	learn: 0.1513616	total: 1m 22s	remaining: 10m 1s
12:	learn: 0.1468622	total: 1m 28s	remaining: 9m 54s
13:	learn: 0.1430455	total: 1m 35s	remaining: 9m 47s
14:	learn: 0.1393464	total: 1m 42s	remaining: 9m 40s
15:	learn: 0.1363586	total: 1m 49s	remaining: 9m 33s
16:	learn: 0.1328369	total: 1m 56s	remaining: 9m 26s
17:	l



Label ranking loss:  0.05816933061178188
F1 Weighted:  0.7032928245189347
Learning rate set to 0.130956
0:	learn: 0.4987861	total: 7.07s	remaining: 11m 40s
1:	learn: 0.3794463	total: 13.9s	remaining: 11m 19s
2:	learn: 0.3094015	total: 20.8s	remaining: 11m 11s
3:	learn: 0.2650395	total: 27.7s	remaining: 11m 5s
4:	learn: 0.2323290	total: 34.5s	remaining: 10m 56s
5:	learn: 0.2086071	total: 41.3s	remaining: 10m 47s
6:	learn: 0.1934555	total: 48.7s	remaining: 10m 46s
7:	learn: 0.1817548	total: 56.3s	remaining: 10m 47s
8:	learn: 0.1717795	total: 1m 3s	remaining: 10m 46s
9:	learn: 0.1624459	total: 1m 13s	remaining: 10m 59s
10:	learn: 0.1556624	total: 1m 20s	remaining: 10m 53s
11:	learn: 0.1501078	total: 1m 29s	remaining: 10m 59s
12:	learn: 0.1460457	total: 1m 37s	remaining: 10m 54s
13:	learn: 0.1417051	total: 1m 45s	remaining: 10m 47s
14:	learn: 0.1376910	total: 1m 53s	remaining: 10m 45s
15:	learn: 0.1347543	total: 2m 1s	remaining: 10m 36s
16:	learn: 0.1311573	total: 2m 8s	remaining: 10m 28s


KeyboardInterrupt: 

In [19]:
metrics.results.describe()

Unnamed: 0,accuracy_score,log_loss,hamming_loss,f1_score__micro,f1_score__macro,f1_score__weighted,f1_score__samples,precision_score__micro,precision_score__macro,precision_score__weighted,...,recall_score__macro,recall_score__weighted,recall_score__samples,jaccard_score__micro,jaccard_score__macro,jaccard_score__weighted,jaccard_score__samples,roc_auc_score,label_ranking_loss,coverage_error
count,6.0,6.0,6.0,6.0,6.0,6.0,6.0,6.0,6.0,6.0,...,6.0,6.0,6.0,6.0,6.0,6.0,6.0,0.0,6.0,6.0
mean,0.033491,426.598018,0.074979,0.806446,0.247045,0.755934,0.812042,0.870754,0.304216,0.8103,...,0.237586,0.752615,0.7716,0.677001,0.211662,0.697851,0.703763,,0.043242,72.565061
std,0.031992,95.871002,0.016596,0.036732,0.013885,0.052316,0.034382,0.023611,0.037554,0.045172,...,0.014607,0.057674,0.053199,0.052005,0.013495,0.054956,0.047988,,0.013545,8.173327
min,0.0,307.650957,0.055945,0.767443,0.224523,0.69672,0.772295,0.842862,0.256751,0.750484,...,0.222172,0.684381,0.699187,0.622643,0.193877,0.638442,0.645287,,0.023639,58.74031
25%,0.006715,352.63903,0.061487,0.777293,0.241947,0.713578,0.78862,0.851319,0.277501,0.785116,...,0.225016,0.712182,0.742136,0.635906,0.202705,0.652776,0.674843,,0.037868,70.57617
50%,0.031642,425.280883,0.07495,0.800676,0.248618,0.751705,0.804197,0.87178,0.302393,0.807925,...,0.236053,0.742938,0.763655,0.667607,0.210791,0.690422,0.691723,,0.040428,74.189656
75%,0.062933,504.381293,0.08767,0.837234,0.254165,0.798983,0.839447,0.888815,0.334139,0.831494,...,0.250278,0.793886,0.814371,0.720572,0.222537,0.744007,0.742109,,0.054106,75.061716
max,0.065891,542.202213,0.09512,0.85096,0.264776,0.819862,0.856939,0.899107,0.349833,0.878601,...,0.254885,0.832636,0.836861,0.740583,0.22805,0.765723,0.765722,,0.059277,83.546667
