## This Jupyter notebook is used to predict the fingerprint of a given spec2vec embedding.

In [31]:
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import KFold
from sklearn.multiclass import OneVsRestClassifier
from sklearn.multioutput import MultiOutputClassifier
from sklearn.metrics import label_ranking_loss
import pickle
RANDOM_STATE = 25082023

In [4]:
fingerprints = pd.read_csv("./embeddings/tms_maccs_fingerprint.csv")
fingerprints = fingerprints.drop(columns=["Name", "InChI"])
fingerprints = fingerprints.set_index("InChI Key")
fingerprints.index.name = "inchikey"
fingerprints = fingerprints.astype(bool)
print(fingerprints.shape)
fingerprints.head()

(105, 192)


Unnamed: 0_level_0,0,1,2,3,4,5,6,7,8,9,...,182,183,184,185,186,187,188,189,190,191
inchikey,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
FWZOFSHJDAIJQE-UHFFFAOYSA-N,False,False,False,False,False,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False
VUNXPEWGQXFNOL-UHFFFAOYSA-N,False,False,False,False,False,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False
JFPSLJJGWCHYOE-WOJBJXKFSA-N,False,False,False,False,False,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False
VGYQPKLQPQJSQU-UHFFFAOYSA-N,False,False,False,False,False,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False
NLUDHDUQAJYEEH-IZZNHLLZSA-N,False,False,False,False,False,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False


In [5]:
# Simple indicator analysis
print("Mean true indicators: ", fingerprints.sum(axis=1).mean())
print("Std true indicators: ", fingerprints.sum(axis=1).std())
print("Min true indicators: ", fingerprints.sum(axis=1).min())
print("Max true indicators: ", fingerprints.sum(axis=1).max())

Mean true indicators:  36.23809523809524
Std true indicators:  17.869464613349937
Min true indicators:  21
Max true indicators:  192


In [6]:
true_class_weight = 1 - fingerprints.sum(axis=1).mean() / fingerprints.shape[1]
false_class_weight = 1 - true_class_weight
true_class_weight, false_class_weight

(0.8112599206349206, 0.18874007936507942)

In [7]:
# Validate fingerprint
print("Nan values: ", fingerprints.isna().sum().sum())

Nan values:  0


In [8]:
spec2vec = pd.read_csv("./embeddings/tms_spec2vec_embeddings.csv")
spec2vec = spec2vec.drop(columns=["name"])
spec2vec = spec2vec.set_index("inchikey")
spec2vec = spec2vec.astype(float)
print(spec2vec.shape)
spec2vec.head()

(3144, 300)


Unnamed: 0_level_0,0,1,2,3,4,5,6,7,8,9,...,290,291,292,293,294,295,296,297,298,299
inchikey,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
VFFKJOXNCSJSAQ-UHFFFAOYSA-N,15.387566,-0.466594,128.212979,-51.784879,50.710794,100.454237,37.555077,-105.826131,31.051446,62.338263,...,-17.074715,131.993879,-11.480879,-94.839212,-45.291035,87.989262,-229.266169,58.848309,-28.095784,15.842633
VFFKJOXNCSJSAQ-UHFFFAOYSA-N,113.82143,78.168829,-6.747713,81.789962,203.772493,109.101307,9.375376,14.774937,-70.287549,35.959517,...,13.594595,-115.2703,-111.846199,111.038387,15.181126,30.404949,-152.397945,20.853796,-23.316306,37.147708
VFFKJOXNCSJSAQ-UHFFFAOYSA-N,83.186291,-41.029504,31.763091,35.200038,-35.305679,101.72932,56.067705,-25.064136,40.739751,75.361117,...,1.274054,74.557059,32.747522,23.077897,85.837911,-11.749879,-120.201318,82.090422,35.842974,29.0582
VFFKJOXNCSJSAQ-UHFFFAOYSA-N,118.81771,-64.274214,38.892811,49.980126,33.930059,102.045817,0.513328,-3.982716,-42.755118,74.909037,...,97.707563,68.794085,-164.242781,160.243625,58.744467,-6.496885,-117.118548,-10.736083,-47.655912,204.336708
VFFKJOXNCSJSAQ-UHFFFAOYSA-N,91.464955,13.040805,84.813821,13.731519,134.206405,-26.45999,-189.242868,116.861201,19.808913,87.929782,...,13.920619,-35.539763,-58.136726,73.218996,4.029479,7.711222,20.721324,43.896813,79.700217,-22.953945


In [9]:
# Validate embeddings
print("Nan values: ", spec2vec.isna().sum().sum())

Nan values:  0


In [10]:
# For Both df in index repalce \xa0 with space and strip (remove leading and trailing spaces)
spec2vec.index = spec2vec.index.str.replace("\xa0", " ").str.strip()
fingerprints.index = fingerprints.index.str.replace("\xa0", " ").str.strip()

In [11]:
# Missing inchikeys in spec2vec
set(fingerprints.index.unique()) - (set(spec2vec.index.unique()))

{'AYONZGOWFAKCNA-UHFFFAOYSA-N', 'OIBARLCQMDCDSG-NSHDSACASA-N'}

In [12]:
# Missing inchikeys in fingerprints
set(spec2vec.index.unique()) - set(fingerprints.index.unique())

{'HGGWBFIRNWOJCL-CPDXTSBQSA-N',
 'JZGPZUIFYWMNKG-UHFFFAOYSA-N',
 'ORYOBNFVKJSNIY-UHFFFAOYSA-N'}

In [13]:
# Merge the dataframes to obtain X and y matrices (we add suffixes for later extraction)
merged = pd.merge(spec2vec.add_suffix("_x"), fingerprints.add_suffix("_y"), left_index=True, right_index=True, how="inner")
print(merged.shape)
merged.head()

(3082, 492)


Unnamed: 0_level_0,0_x,1_x,2_x,3_x,4_x,5_x,6_x,7_x,8_x,9_x,...,182_y,183_y,184_y,185_y,186_y,187_y,188_y,189_y,190_y,191_y
inchikey,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
AWZDROKRYZXWBO-UHFFFAOYSA-N,-8.020377,-19.590963,47.416324,14.734204,-37.178691,281.631005,133.814303,-130.229745,163.191635,89.101363,...,False,False,False,False,False,False,False,False,False,False
AWZDROKRYZXWBO-UHFFFAOYSA-N,-91.541454,-35.753494,38.966337,-86.308427,19.250948,148.185439,99.37363,-186.472006,217.512989,160.863203,...,False,False,False,False,False,False,False,False,False,False
AWZDROKRYZXWBO-UHFFFAOYSA-N,33.186816,-64.306986,122.437815,132.50053,-12.631758,231.384121,2.159783,196.827764,-9.579701,148.339587,...,False,False,False,False,False,False,False,False,False,False
AWZDROKRYZXWBO-UHFFFAOYSA-N,132.202141,-85.069309,-99.805639,-198.594718,78.260691,193.392056,178.165336,327.913047,365.212265,10.10689,...,False,False,False,False,False,False,False,False,False,False
AWZDROKRYZXWBO-UHFFFAOYSA-N,83.362389,-42.695541,37.22887,-10.914806,-108.479697,133.004002,24.000027,-0.53134,176.526312,105.578328,...,False,False,False,False,False,False,False,False,False,False


In [14]:
# X is data from merged with suffix _x
X = merged.filter(regex="_x$").to_numpy()
# y is data from merged with suffix _y
y = merged.filter(regex="_y$").to_numpy().astype(int)

In [15]:
X.shape, y.shape

((3082, 300), (3082, 192))

In [16]:
classifier = RandomForestClassifier(n_estimators=100, random_state=RANDOM_STATE, class_weight={0: false_class_weight, 1: true_class_weight})

In [19]:
one_vs_rest_classifier = OneVsRestClassifier(classifier, n_jobs=-1)

In [20]:
REPEATS = 2
K = 5

best_one_vs_rest_classifier = None
best_accuracy = 0
for i in range(REPEATS):
    kf = KFold(n_splits=K, shuffle=True, random_state=RANDOM_STATE)

    for fold, (train_index, test_index) in enumerate(kf.split(X)):
        print("Fold: ", fold)
        X_train, X_test = X[train_index], X[test_index]
        y_train, y_test = y[train_index], y[test_index]
        one_vs_rest_classifier.fit(X_train, y_train)
        y_pred = one_vs_rest_classifier.predict(X_test)
        print("Accuracy: ", (y_pred == y_test).mean())

        if (y_pred == y_test).mean() > best_accuracy:
            best_accuracy = (y_pred == y_test).mean()
            best_one_vs_rest_classifier = one_vs_rest_classifier
print("Best accuracy: ", best_accuracy)

Fold:  0
Accuracy:  0.9437635062128579
Fold:  1
Accuracy:  0.9429278092922745
Fold:  2
Accuracy:  0.9386160714285714
Fold:  3
Accuracy:  0.9407382981601732
Fold:  4
Accuracy:  0.941896645021645
Fold:  0
Accuracy:  0.9437635062128579
Fold:  1
Accuracy:  0.9429278092922745
Fold:  2
Accuracy:  0.9386160714285714
Fold:  3
Accuracy:  0.9407382981601732
Fold:  4
Accuracy:  0.941896645021645
Best accuracy:  0.9437635062128579


In [32]:
y_pred = best_one_vs_rest_classifier.predict(X)
label_ranking_loss(y, y_pred)

0.04147317383942748

### Evaluation

In [21]:
instance = merged.iloc[0]
instance_inchikey = instance.name
instance_inchikey

'AWZDROKRYZXWBO-UHFFFAOYSA-N'

In [22]:
instance_X = instance.filter(regex="_x$").to_numpy()
instance_y = instance.filter(regex="_y$").to_numpy().astype(int)

In [23]:
instance_y

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
       0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0,
       0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0,
       0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0,
       0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0,
       0, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

In [25]:
one_vs_rest_classifier.score(instance_X.reshape(1, -1), instance_y.reshape(1, -1))

1.0

In [24]:
one_vs_rest_classifier.predict_proba(instance_X.reshape(1, -1))

array([[0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.  , 0.  , 0.  , 0.  , 0.  , 0.01, 0.  , 0.  , 1.  , 0.8 , 0.  ,
        0.  , 0.  , 0.  , 0.05, 0.  , 0.  , 0.  , 1.  , 0.  , 0.02, 0.02,
        0.02, 0.  , 0.  , 0.02, 0.  , 0.  , 0.  , 0.  , 0.05, 0.02, 0.  ,
        0.  , 0.  , 0.  , 0.  , 0.  , 0.75, 0.  , 0.  , 0.  , 0.  , 0.  ,
        0.01, 0.04, 0.  , 0.  , 0.  , 0.  , 0.03, 0.01, 0.  , 0.01, 0.21,
        0.  , 0.  , 0.  , 0.01, 0.01, 0.04, 0.  , 1.  , 0.01, 0.88, 0.02,
        0.  , 0.  , 0.02, 0.  , 0.  , 0.  , 0.  , 0.01, 1.  , 0.08, 0.  ,
        0.02, 0.01, 0.  , 0.02, 1.  , 0.  , 0.  , 0.03, 0.01, 0.02, 0.8 ,
        0.02, 0.04, 1.  , 0.83, 0.  , 0.12, 0.01, 0.87, 0.11, 0.02, 0.  ,
        0.01, 1.  , 0.96, 0.08, 0.09, 0.16, 0.01, 0.06, 0.  , 0.01, 0.01,
        0.03, 0.02, 1.  , 0.88, 1.  , 1.  , 0.04, 0.16, 1.  , 0.01, 0.08,
        0.  , 0.87, 0.01, 0.02, 0.03, 0.  , 0.03, 0.01, 1.  , 0.01, 1.  ,
        0.16, 0.98, 0.1 , 0.13, 1.  , 

### Save classifier

In [27]:
pickle.dump(one_vs_rest_classifier, open("./models/tms/tms_one_vs_rest_classifier.pkl", "wb"))