## This Jupyter notebook is used to predict the fingerprint of a given spec2vec embedding.

In [56]:
from sklearn.metrics import accuracy_score, f1_score, log_loss, precision_score, recall_score, jaccard_score, roc_auc_score, hamming_loss, label_ranking_loss, coverage_error
from sklearn.model_selection import KFold
from sklearn.multiclass import OneVsRestClassifier
from sklearn.multioutput import  ClassifierChain
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.dummy import DummyClassifier
from tqdm.notebook import tqdm
import numpy as np
import pandas as pd
import pickle
import os

RANDOM_STATE = 27082023
np.random.seed(RANDOM_STATE)

FINGERPRINTS_PATH = "./embeddings/tms_maccs_fingerprint.csv"
SPEC2VEC_PATH = "./embeddings/tms_spec2vec_embeddings.csv"

CLASSIFIER_NAME = "decision_tree_chain"
CLASSIFIER_OUTPUT_FOLDER = "./models/tms_final/"
os.makedirs(os.path.join(CLASSIFIER_OUTPUT_FOLDER, CLASSIFIER_NAME, "models"), exist_ok=False)

### Metrics definition and calculation

In [57]:
Y_PRED_SCORES = [accuracy_score, log_loss, hamming_loss] # input y predictions and y true
Y_PRED_SCORES_WITH_AVERAGING = [f1_score, precision_score, recall_score, jaccard_score] # input y predictions and y true and use one of the following: "micro", "macro", "weighted", "samples"
Y_PROB_SCORES = [roc_auc_score, label_ranking_loss, coverage_error] # input y probabilities and y true

In [58]:
METRICS = []
METRIC_NAMES = []
for metric in Y_PRED_SCORES:
    METRICS.append(lambda y_true, y_prob, y_pred, metric=metric: metric(y_true, y_pred))
    METRIC_NAMES.append(metric.__name__)
for metric in Y_PRED_SCORES_WITH_AVERAGING:
    for average in ["micro", "macro", "weighted", "samples"]:
        zero_division = 0 if metric.__name__ == "jaccard_score" else np.nan
        METRICS.append(lambda y_true, y_prob, y_pred, metric=metric, average=average: metric(y_true, y_pred, average=average, zero_division=zero_division))
        METRIC_NAMES.append(metric.__name__ + "__" + average)
for metric in Y_PROB_SCORES:
    METRICS.append(lambda y_true, y_prob, y_pred, metric=metric: metric(y_true, y_prob))
    METRIC_NAMES.append(metric.__name__)

In [59]:
# Test metrics
y_true = np.array([[0, 1, 1], [1, 0, 0], [0, 1, 0]])
y_pred = np.array([[0, 1, 1], [0, 0, 0], [0, 1, 1]])
y_prob = np.array([[0.1, 0.9, 0.8], [0.7, 0.1, 0.05], [0.2, 0.0, 0.05]])

for metric, metric_name in zip(METRICS, METRIC_NAMES):
    print(metric_name, end=": ")
    print(metric(y_true, y_prob, y_pred))

accuracy_score: 0.3333333333333333
log_loss: 1.0593512767826487
hamming_loss: 0.2222222222222222
f1_score__micro: 0.75
f1_score__macro: 0.5555555555555555
f1_score__weighted: 0.6666666666666666
f1_score__samples: 0.5555555555555555
precision_score__micro: 0.75
precision_score__macro: 0.5
precision_score__weighted: 0.625
precision_score__samples: 0.5
recall_score__micro: 0.75
recall_score__macro: 0.6666666666666666
recall_score__weighted: 0.75
recall_score__samples: 0.6666666666666666
jaccard_score__micro: 0.6
jaccard_score__macro: 0.5
jaccard_score__weighted: 0.625
jaccard_score__samples: 0.5
roc_auc_score: 0.8333333333333334
label_ranking_loss: 0.3333333333333333
coverage_error: 2.0




### Load and parse data

In [60]:
fingerprints = pd.read_csv(FINGERPRINTS_PATH)
fingerprints.rename(columns={"InChIKey": "inchikey", "Name": "name", "InChI": "inchi"}, inplace=True)
if "name" in fingerprints.columns:
    fingerprints.drop(columns=["name"], inplace=True)
if "inchi" in fingerprints.columns:
    fingerprints.drop(columns=["inchi"], inplace=True)
fingerprints.set_index("inchikey", inplace=True)
print("NaNs:", fingerprints.isna().sum().sum())
fingerprints = fingerprints.astype(bool)
print(fingerprints.shape)
fingerprints.head()

NaNs: 0
(104, 166)


Unnamed: 0_level_0,0,1,2,3,4,5,6,7,8,9,...,156,157,158,159,160,161,162,163,164,165
inchikey,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
FWZOFSHJDAIJQE-UHFFFAOYSA-N,False,False,False,False,False,False,False,False,False,False,...,True,False,True,True,False,True,True,True,True,False
VUNXPEWGQXFNOL-UHFFFAOYSA-N,False,False,False,False,False,False,False,False,False,False,...,True,False,True,True,False,True,True,True,True,False
JFPSLJJGWCHYOE-WOJBJXKFSA-N,False,False,False,False,False,False,False,False,False,False,...,True,False,True,True,False,True,True,True,True,False
VGYQPKLQPQJSQU-UHFFFAOYSA-N,False,False,False,False,False,False,False,False,False,False,...,True,False,True,True,False,True,True,True,True,False
NLUDHDUQAJYEEH-IZZNHLLZSA-N,False,False,False,False,False,False,False,False,False,False,...,True,False,True,True,False,True,True,True,True,False


In [61]:
spec2vec = pd.read_csv(SPEC2VEC_PATH)
spec2vec.rename(columns={"InChI Key": "inchikey", "Name": "name"}, inplace=True)
if "name" in spec2vec.columns:
    spec2vec = spec2vec.drop(columns=["name"])
spec2vec = spec2vec.set_index("inchikey")
spec2vec = spec2vec.astype(float)
print(spec2vec.shape)
spec2vec.head()

(3144, 300)


Unnamed: 0_level_0,0,1,2,3,4,5,6,7,8,9,...,290,291,292,293,294,295,296,297,298,299
inchikey,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
VFFKJOXNCSJSAQ-UHFFFAOYSA-N,15.387566,-0.466594,128.212979,-51.784879,50.710794,100.454237,37.555077,-105.826131,31.051446,62.338263,...,-17.074715,131.993879,-11.480879,-94.839212,-45.291035,87.989262,-229.266169,58.848309,-28.095784,15.842633
VFFKJOXNCSJSAQ-UHFFFAOYSA-N,113.82143,78.168829,-6.747713,81.789962,203.772493,109.101307,9.375376,14.774937,-70.287549,35.959517,...,13.594595,-115.2703,-111.846199,111.038387,15.181126,30.404949,-152.397945,20.853796,-23.316306,37.147708
VFFKJOXNCSJSAQ-UHFFFAOYSA-N,83.186291,-41.029504,31.763091,35.200038,-35.305679,101.72932,56.067705,-25.064136,40.739751,75.361117,...,1.274054,74.557059,32.747522,23.077897,85.837911,-11.749879,-120.201318,82.090422,35.842974,29.0582
VFFKJOXNCSJSAQ-UHFFFAOYSA-N,118.81771,-64.274214,38.892811,49.980126,33.930059,102.045817,0.513328,-3.982716,-42.755118,74.909037,...,97.707563,68.794085,-164.242781,160.243625,58.744467,-6.496885,-117.118548,-10.736083,-47.655912,204.336708
VFFKJOXNCSJSAQ-UHFFFAOYSA-N,91.464955,13.040805,84.813821,13.731519,134.206405,-26.45999,-189.242868,116.861201,19.808913,87.929782,...,13.920619,-35.539763,-58.136726,73.218996,4.029479,7.711222,20.721324,43.896813,79.700217,-22.953945


In [62]:
# Validate embeddings
print("Nan values: ", spec2vec.isna().sum().sum())

Nan values:  0


In [63]:
# For Both df in index repalce \xa0 with space and strip (remove leading and trailing spaces)
spec2vec.index = spec2vec.index.str.replace("\xa0", " ").str.strip()
fingerprints.index = fingerprints.index.str.replace("\xa0", " ").str.strip()

In [64]:
# Missing inchikeys in spec2vec
set(fingerprints.index.unique()) - (set(spec2vec.index.unique()))

{'AYONZGOWFAKCNA-UHFFFAOYSA-N', 'OIBARLCQMDCDSG-NSHDSACASA-N'}

In [65]:
# Missing inchikeys in fingerprints
set(spec2vec.index.unique()) - set(fingerprints.index.unique())

{'HBWAMRSFAPVOKZ-UHFFFAOYSA-N',
 'HGGWBFIRNWOJCL-CPDXTSBQSA-N',
 'JZGPZUIFYWMNKG-UHFFFAOYSA-N',
 'ORYOBNFVKJSNIY-UHFFFAOYSA-N'}

In [66]:
# Merge the dataframes to obtain X and y matrices (we add suffixes for later extraction)
merged = pd.merge(spec2vec.add_suffix("_x"), fingerprints.add_suffix("_y"), left_index=True, right_index=True, how="inner")
print(merged.shape)
merged.head()

(3052, 466)


Unnamed: 0_level_0,0_x,1_x,2_x,3_x,4_x,5_x,6_x,7_x,8_x,9_x,...,156_y,157_y,158_y,159_y,160_y,161_y,162_y,163_y,164_y,165_y
inchikey,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
AWZDROKRYZXWBO-UHFFFAOYSA-N,-8.020377,-19.590963,47.416324,14.734204,-37.178691,281.631005,133.814303,-130.229745,163.191635,89.101363,...,True,False,True,True,False,True,True,True,True,False
AWZDROKRYZXWBO-UHFFFAOYSA-N,-91.541454,-35.753494,38.966337,-86.308427,19.250948,148.185439,99.37363,-186.472006,217.512989,160.863203,...,True,False,True,True,False,True,True,True,True,False
AWZDROKRYZXWBO-UHFFFAOYSA-N,33.186816,-64.306986,122.437815,132.50053,-12.631758,231.384121,2.159783,196.827764,-9.579701,148.339587,...,True,False,True,True,False,True,True,True,True,False
AWZDROKRYZXWBO-UHFFFAOYSA-N,132.202141,-85.069309,-99.805639,-198.594718,78.260691,193.392056,178.165336,327.913047,365.212265,10.10689,...,True,False,True,True,False,True,True,True,True,False
AWZDROKRYZXWBO-UHFFFAOYSA-N,83.362389,-42.695541,37.22887,-10.914806,-108.479697,133.004002,24.000027,-0.53134,176.526312,105.578328,...,True,False,True,True,False,True,True,True,True,False


In [67]:
# X is data from merged with suffix _x
X = merged.filter(regex="_x$").to_numpy()
# y is data from merged with suffix _y
y = merged.filter(regex="_y$").to_numpy().astype(int)

In [68]:
X.shape, y.shape

((3052, 300), (3052, 166))

### Train

In [69]:
weights = [
    {0: (y[:, i] == 0).sum() / y.shape[0], 1: (y[:, i] == 1).sum() / y.shape[0]} for i in range(y.shape[1])
]

In [70]:
# multilabel_classifier = DummyClassifier(strategy="most_frequent", random_state=RANDOM_STATE)

In [71]:
# multilabel_classifier = DummyClassifier(strategy="prior", random_state=RANDOM_STATE)

In [72]:
# classifier = DecisionTreeClassifier(random_state=RANDOM_STATE, class_weight=weights)
# multilabel_classifier = OneVsRestClassifier(classifier, n_jobs=-1)

In [73]:
# multilabel_classifier = DecisionTreeClassifier(random_state=RANDOM_STATE, class_weight=weights)

In [74]:
# classifier = RandomForestClassifier(n_estimators=100, random_state=RANDOM_STATE, class_weight=weights)
# multilabel_classifier = OneVsRestClassifier(classifier, n_jobs=-1)

In [75]:
# multilabel_classifier = RandomForestClassifier(n_estimators=300, random_state=RANDOM_STATE, class_weight=weights, n_jobs=-1)

In [79]:
classifier = DecisionTreeClassifier(random_state=RANDOM_STATE, class_weight="balanced")
multilabel_classifier = ClassifierChain(classifier, random_state=RANDOM_STATE, order="random", cv=3)

In [80]:
metrics_df = pd.DataFrame(columns=["repeat", "fold", "data_path"] + METRIC_NAMES)

def evaluate(y_true, y_prob, y_pred):
    row = {}
    for metric, metric_name in zip(METRICS, METRIC_NAMES):
        try:
            row[metric_name] = metric(y_true, y_prob, y_pred)
        except Exception as e:
            print(e)
            row[metric_name] = np.nan
            
    return row

def log_to_df(row):
    global metrics_df
    metrics_df = pd.concat([metrics_df, pd.DataFrame(row, index=[0])], ignore_index=True)


def fix_probability(p, c):
    if len(c) == len(p[0]) == 2:
        return p[:, 1]
    
    if c[0] == 0:
        return p[:, 0]
    
    return 1 - p[:, 0]

def fix_probabilities(y_prob, classes):
    if isinstance(y_prob, list) or (isinstance(y_prob, np.ndarray) and y_prob.ndim == 3):
        l = [fix_probability(p, c) for p, c in zip(y_prob, classes)]
        y_prob = np.array(l).T
        
    return y_prob

In [81]:
REPEATS = 2
K = 5

classifiers = []
scores = []
train_test_indices = []
for i in tqdm(range(REPEATS), desc="Repeats"):
    kf = KFold(n_splits=K, shuffle=True, random_state=RANDOM_STATE + i)

    for fold, (train_index, test_index) in tqdm(enumerate(kf.split(X, y)), desc="Fold", total=K):
        X_train, X_test = X[train_index], X[test_index]
        y_train, y_test = y[train_index], y[test_index]
        multilabel_classifier.fit(X_train, y_train)

        y_pred = multilabel_classifier.predict(X_test)
        y_prob = multilabel_classifier.predict_proba(X_test)
        y_prob = fix_probabilities(y_prob, multilabel_classifier.classes_)
        
        evaluation = evaluate(y_test, y_prob, y_pred)

        evaluation["repeat"] = i
        evaluation["fold"] = fold

        classifier_data_path = os.path.join(CLASSIFIER_OUTPUT_FOLDER, CLASSIFIER_NAME, "models", f"repeat_{i}_fold_{fold}.pkl")
        classifier_data = {
            "classifier": multilabel_classifier,
            "X_train": X_train,
            "y_train": y_train,
            "X_test": X_test,
            "y_test": y_test,
        }
        with open(classifier_data_path, "wb") as f:
            pickle.dump(classifier_data, f)

        evaluation["data_path"] = classifier_data_path
        log_to_df(evaluation)

Repeats:   0%|          | 0/2 [00:00<?, ?it/s]

Fold:   0%|          | 0/5 [00:00<?, ?it/s]

IndexError: index 1 is out of bounds for axis 1 with size 1

In [None]:
y_test.shape, y_prob.shape, y_pred.shape

((610, 166), (610, 166), (610, 166))

In [None]:
dummy_metrics_df = pd.read_csv("./models/tms_final/dummy_most_frequent/metrics.csv")

In [None]:
dummy_metrics_df

Unnamed: 0,repeat,fold,data_path,accuracy_score,log_loss,hamming_loss,f1_score__micro,f1_score__macro,f1_score__weighted,f1_score__samples,...,recall_score__macro,recall_score__weighted,recall_score__samples,jaccard_score__micro,jaccard_score__macro,jaccard_score__weighted,jaccard_score__samples,roc_auc_score,label_ranking_loss,coverage_error
0,0,0,./models/tms_final/dummy_most_frequent\models\...,0.0,546.074631,0.084623,0.765882,0.143597,0.631889,0.778962,...,0.150602,0.656457,0.690915,0.620591,0.138416,0.612958,0.655084,,0.634645,166.0
1,0,1,./models/tms_final/dummy_most_frequent\models\...,0.0,528.077819,0.08135,0.77348,0.143859,0.642221,0.788067,...,0.150602,0.666052,0.703775,0.630629,0.138889,0.623929,0.668628,,0.626433,166.0
2,0,2,./models/tms_final/dummy_most_frequent\models\...,0.0,533.751732,0.083577,0.767097,0.143113,0.634627,0.781547,...,0.150602,0.660929,0.698879,0.622188,0.137636,0.614583,0.660529,,0.630763,166.0
3,0,3,./models/tms_final/dummy_most_frequent\models\...,0.0,545.282479,0.085848,0.761659,0.142826,0.6275,0.777098,...,0.150602,0.654479,0.694045,0.615064,0.137172,0.606966,0.654953,,0.634028,166.0
4,0,4,./models/tms_final/dummy_most_frequent\models\...,0.0,519.013298,0.080792,0.773712,0.143408,0.643473,0.785782,...,0.150602,0.669091,0.701682,0.630938,0.13812,0.62389,0.66414,,0.628547,166.0
5,1,0,./models/tms_final/dummy_most_frequent\models\...,0.0,527.433016,0.082296,0.770213,0.143353,0.638397,0.783335,...,0.150602,0.664561,0.699456,0.626298,0.137923,0.618076,0.662388,,0.62993,166.0
6,1,1,./models/tms_final/dummy_most_frequent\models\...,0.0,552.178176,0.086082,0.762169,0.143288,0.627303,0.777791,...,0.150602,0.65264,0.692586,0.61573,0.137933,0.607892,0.655221,,0.634237,166.0
7,1,2,./models/tms_final/dummy_most_frequent\models\...,0.0,513.029457,0.08013,0.774849,0.143273,0.645383,0.786997,...,0.150602,0.671637,0.704531,0.632452,0.137883,0.625325,0.666868,,0.626917,166.0
8,1,3,./models/tms_final/dummy_most_frequent\models\...,0.0,547.958501,0.084881,0.765375,0.14358,0.6313,0.779891,...,0.150602,0.655614,0.692982,0.619926,0.138446,0.612735,0.656502,,0.633635,166.0
9,1,4,./models/tms_final/dummy_most_frequent\models\...,0.0,531.591861,0.082797,0.769303,0.143319,0.637468,0.783452,...,0.150602,0.662763,0.699745,0.625095,0.138051,0.618326,0.662365,,0.629691,166.0


In [None]:
metrics_df

Unnamed: 0,repeat,fold,data_path,accuracy_score,log_loss,hamming_loss,f1_score__micro,f1_score__macro,f1_score__weighted,f1_score__samples,...,recall_score__macro,recall_score__weighted,recall_score__samples,jaccard_score__micro,jaccard_score__macro,jaccard_score__weighted,jaccard_score__samples,roc_auc_score,label_ranking_loss,coverage_error
0,0,0,./models/tms_final/decision_tree\models\repeat...,0.166939,431.570455,0.091061,0.777521,0.280197,0.757749,0.783695,...,0.269153,0.754653,0.778344,0.63602,0.22205,0.688471,0.670604,,0.575515,166.0
1,0,1,./models/tms_final/decision_tree\models\repeat...,0.152209,424.6027,0.092225,0.773785,0.271973,0.756514,0.781192,...,0.26483,0.756407,0.784758,0.631035,0.215825,0.690639,0.665785,,0.573686,166.0
2,0,2,./models/tms_final/decision_tree\models\repeat...,0.132787,436.456452,0.098894,0.758745,0.259246,0.740691,0.767226,...,0.256913,0.746763,0.776668,0.611273,0.206423,0.675868,0.646904,,0.581347,166.0
3,0,3,./models/tms_final/decision_tree\models\repeat...,0.139344,448.943327,0.100464,0.75505,0.274114,0.736894,0.765312,...,0.266662,0.738774,0.771567,0.606491,0.21604,0.670218,0.64626,,0.584586,166.0
4,0,4,./models/tms_final/decision_tree\models\repeat...,0.155738,412.993597,0.090332,0.776973,0.270336,0.763056,0.783518,...,0.265522,0.762235,0.78567,0.635287,0.215418,0.695007,0.669826,,0.574349,166.0
5,1,0,./models/tms_final/decision_tree\models\repeat...,0.170213,421.753376,0.094305,0.769212,0.277221,0.753681,0.777106,...,0.271224,0.757245,0.784669,0.624975,0.218124,0.686154,0.662016,,0.576606,166.0
6,1,1,./models/tms_final/decision_tree\models\repeat...,0.165303,440.733703,0.095991,0.767215,0.289251,0.749157,0.775308,...,0.282017,0.748461,0.778418,0.622343,0.225592,0.680154,0.661259,,0.575864,166.0
7,1,2,./models/tms_final/decision_tree\models\repeat...,0.159016,415.240198,0.090905,0.774038,0.264804,0.758484,0.784557,...,0.258662,0.758418,0.785155,0.631372,0.211385,0.693366,0.672652,,0.578796,166.0
8,1,3,./models/tms_final/decision_tree\models\repeat...,0.147541,441.123066,0.099516,0.760499,0.268476,0.74644,0.769554,...,0.265093,0.748211,0.777669,0.613553,0.213727,0.679586,0.652317,,0.579792,166.0
9,1,4,./models/tms_final/decision_tree\models\repeat...,0.119672,450.672163,0.098529,0.756451,0.245743,0.735202,0.766869,...,0.239202,0.734591,0.764761,0.6083,0.197839,0.672239,0.647811,,0.588061,166.0


In [None]:
metrics_df.to_csv(os.path.join(CLASSIFIER_OUTPUT_FOLDER, CLASSIFIER_NAME, "metrics.csv"), index=False)