In [1]:
import re
import os
import pickle 
import numpy as np 
import pandas as pd
import sys
import matplotlib.pyplot as plt 

from sklearn import metrics
from sklearn.neural_network import MLPClassifier
from sklearn.model_selection import train_test_split
from sklearn.model_selection import StratifiedGroupKFold 
from sklearn.model_selection import cross_val_score

from sklearn.metrics import roc_auc_score, make_scorer
from sklearn.metrics import mean_squared_error
from sklearn.metrics import accuracy_score
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score
from sklearn.metrics import f1_score
from sklearn.dummy import DummyClassifier

import xgboost as xgb
from xgboost import XGBClassifier



from google.colab import drive
drive.mount("/content/drive")


Mounted at /content/drive


In [2]:
# load the tab used for embedding, only the training set of course

with open("drive/MyDrive/OrlyPred/Homomer_embeds/results/embeds_Mar_22/train_set.pkl", 'rb') as f:
  overall_train_set = pickle.load(f)

# index reset is important for the stratified splitting and the saving to lists
overall_train_set.reset_index(drop=True, inplace=True)

In [3]:
# define the input, using the codes since this is convenient to later extract rows from the general table. Actually the input is the embeddings
# the labls, y, are the predifined nsub (number of subunits annotated to the relevant pdb code)
# groups - the cluster representatives, used in order to jave all the sequences from the same cluster in the same set (train/validation)

X = overall_train_set["code"]
y = overall_train_set["nsub"]
groups = overall_train_set["representative"]
X

0        5ahz_1
1        3q6m_1
2        1luq_1
3        3t6f_1
4        1srf_1
          ...  
28823    4zt1_1
28824    4a56_1
28825    5hap_1
28826    4s2l_1
28827    5faq_1
Name: code, Length: 28828, dtype: object

In [4]:
# generate groups for k-fold cross validation, used in the next few cells
# this is used when one run is carried out, for the cross validation there is a different code below 

cv = StratifiedGroupKFold(n_splits=10, shuffle=True, random_state=1)
train_lst = []
test_lst = []
for train_idxs, test_idxs in cv.split(X, y, groups):
    train_lst.append(X[train_idxs].tolist())
    test_lst.append(X[test_idxs].tolist())
    print("train_lst", train_lst)
    print("test_lst", test_lst)

IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)



In [5]:
train_idx_df = pd.DataFrame(train_lst).transpose()
train_idx_df.rename(columns={0:"train_0", 1:"train_1", 2:"train_2", 3:"train_3", 4:"train_4", 5:"train_5", 6:"train_6", 7:"train_7", 8:"train_8", 9:"train_9"}, inplace=True)
print(train_idx_df)
test_idx_df = pd.DataFrame(test_lst).transpose()
test_idx_df.rename(columns={0:"test_0", 1:"test_1", 2:"test_2", 3:"test_3", 4:"test_4", 5:"test_5", 6:"test_6", 7:"test_7", 8:"test_8", 9:"test_9"}, inplace=True)
print(test_idx_df)
merged_train_test = pd.concat([train_idx_df, test_idx_df], axis=1, join="outer")


      train_0 train_1 train_2 train_3 train_4 train_5 train_6 train_7 train_8  \
0      5ahz_1  5ahz_1  5ahz_1  5ahz_1  5ahz_1  5ahz_1  5ahz_1  5ahz_1  3q6m_1   
1      3q6m_1  3q6m_1  3q6m_1  1luq_1  3q6m_1  3q6m_1  3q6m_1  3q6m_1  1luq_1   
2      1luq_1  1luq_1  1luq_1  3t6f_1  1luq_1  1luq_1  4wog_1  1luq_1  3t6f_1   
3      3t6f_1  3t6f_1  3t6f_1  1srf_1  3t6f_1  3t6f_1  4gda_1  3t6f_1  1srf_1   
4      1srf_1  1srf_1  1srf_1  1vwa_1  1srf_1  1srf_1  1ort_1  1srf_1  1vwa_1   
...       ...     ...     ...     ...     ...     ...     ...     ...     ...   
26191    None    None    None    None  4zt1_1    None    None    None    None   
26192    None    None    None    None  4a56_1    None    None    None    None   
26193    None    None    None    None  5hap_1    None    None    None    None   
26194    None    None    None    None  4s2l_1    None    None    None    None   
26195    None    None    None    None  5faq_1    None    None    None    None   

      train_9  
0      5ahz

In [6]:
train_set = overall_train_set[overall_train_set["code"].isin(merged_train_test["train_0"])]
test_set = overall_train_set[overall_train_set["code"].isin(merged_train_test["test_0"])]

In [7]:
from sklearn.neural_network import MLPClassifier

X_train = train_set['embeddings'].tolist()
y_train = train_set['nsub']

X_test = test_set['embeddings'].tolist()
y_test = test_set['nsub']



In [None]:
 # the basic plain vanilla MLP trained on one fold, for a baseline/initial model
clf = MLPClassifier(solver='adam', random_state=1, learning_rate_init=0.001)
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)




In [4]:
 # train an MLP with k-fold cross valdation (k=10)
from sklearn.neural_network import MLPClassifier
from sklearn.model_selection import cross_val_score


X = overall_train_set["embeddings"]
y = overall_train_set["nsub"]
groups = overall_train_set["representative"]
cv = StratifiedGroupKFold(n_splits=10)



clf = MLPClassifier(solver='adam', random_state=1, learning_rate_init=0.001)


for train_idxs, test_idxs in cv.split(X, y, groups):
    precision_lst, recall_lst, f1_lst, adj_balanced_accuracy = [], [], [], []
    clf.fit(np.vstack(X[train_idxs]), y[train_idxs])
    # print(clf.score(np.vstack(X[test_idxs]), y[test_idxs]))
    clf.fit(np.vstack(X[train_idxs]), y[train_idxs])
    y_pred = clf.predict(np.vstack(X[test_idxs]))
    precision_lst.append(precision_score(y[test_idxs], y_pred, average='weighted', zero_division=0))
    recall_lst.append(recall_score(y[test_idxs], y_pred, average='weighted', zero_division=0))
    f1_lst.append(f1_score(y[test_idxs], y_pred, average='weighted', zero_division=0))
    adj_balanced_accuracy.append(metrics.balanced_accuracy_score(y[test_idxs], y_pred, adjusted=True))
    print("Adjusted Balanced accuracy: %.3f" % metrics.balanced_accuracy_score(y[test_idxs], y_pred, adjusted=True))
    print('F-measure: %.3f' % f1_score(y[test_idxs], y_pred, average='weighted', zero_division=0))
    print(metrics.classification_report(y[test_idxs],y_pred, zero_division=0))
    print(metrics.confusion_matrix(y[test_idxs],y_pred))

    # scores = cross_val_score(clf, X, y, cv=cv)

    # print("TRAIN:", X[train_idxs])
    # print("      ", y[train_idxs])
    # print(" TEST:", X[test_idxs])
    # print("      ", y[test_idxs])



Adjusted Balanced accuracy: 0.263
F-measure: 0.635
              precision    recall  f1-score   support

         1.0       0.73      0.79      0.76      1315
         2.0       0.58      0.60      0.59       988
         3.0       0.46      0.36      0.41       118
         4.0       0.45      0.33      0.38       265
         5.0       1.00      0.93      0.96        29
         6.0       0.45      0.32      0.37        78
         7.0       0.00      0.00      0.00         2
         8.0       0.79      0.50      0.61        30
         9.0       0.00      0.00      0.00         1
        10.0       0.50      0.14      0.22         7
        12.0       0.30      0.18      0.22        17
        14.0       0.00      0.00      0.00         1
        24.0       0.00      0.00      0.00         5
        60.0       0.00      0.00      0.00         0

    accuracy                           0.64      2856
   macro avg       0.38      0.30      0.32      2856
weighted avg       0.63      



Adjusted Balanced accuracy: 0.220
F-measure: 0.612
              precision    recall  f1-score   support

         1.0       0.75      0.76      0.76      1329
         2.0       0.53      0.55      0.54       990
         3.0       0.49      0.30      0.37       118
         4.0       0.46      0.54      0.50       265
         5.0       0.75      0.27      0.40        11
         6.0       0.10      0.09      0.09        78
         7.0       0.00      0.00      0.00         2
         8.0       0.86      0.20      0.32        30
         9.0       0.00      0.00      0.00         1
        10.0       0.00      0.00      0.00         7
        12.0       0.55      0.35      0.43        17
        13.0       0.00      0.00      0.00         1
        14.0       0.00      0.00      0.00         1
        24.0       0.80      0.80      0.80         5

    accuracy                           0.62      2855
   macro avg       0.38      0.28      0.30      2855
weighted avg       0.62      



Adjusted Balanced accuracy: 0.259
F-measure: 0.636
              precision    recall  f1-score   support

         1.0       0.78      0.69      0.73      1329
         2.0       0.56      0.71      0.63       989
         3.0       0.42      0.26      0.32       118
         4.0       0.52      0.49      0.51       265
         5.0       0.00      0.00      0.00        11
         6.0       0.42      0.44      0.43        78
         7.0       0.00      0.00      0.00         5
         8.0       0.16      0.10      0.12        29
        10.0       0.00      0.00      0.00         8
        12.0       0.64      0.41      0.50        17
        14.0       0.00      0.00      0.00         1
        24.0       1.00      1.00      1.00         5
        60.0       0.00      0.00      0.00         4

    accuracy                           0.64      2859
   macro avg       0.35      0.32      0.33      2859
weighted avg       0.64      0.64      0.64      2859

[[913 363  24  18   1   7   



Adjusted Balanced accuracy: 0.346
F-measure: 0.631
              precision    recall  f1-score   support

         1.0       0.74      0.77      0.76      1328
         2.0       0.55      0.58      0.57       989
         3.0       0.55      0.49      0.52       118
         4.0       0.54      0.37      0.44       264
         5.0       0.20      0.18      0.19        11
         6.0       0.34      0.38      0.36        78
         7.0       0.00      0.00      0.00         1
         8.0       0.60      0.62      0.61        29
        10.0       0.00      0.00      0.00         7
        11.0       0.00      0.00      0.00         1
        12.0       0.59      0.59      0.59        17
        14.0       0.30      0.16      0.21        19
        24.0       0.80      1.00      0.89         4

    accuracy                           0.64      2866
   macro avg       0.40      0.40      0.39      2866
weighted avg       0.63      0.64      0.63      2866

[[1025  273    6    7    2  



Adjusted Balanced accuracy: 0.021
F-measure: 0.457
              precision    recall  f1-score   support

         1.0       0.60      0.73      0.66      1329
         2.0       0.39      0.50      0.44       989
         3.0       0.00      0.00      0.00       118
         4.0       0.00      0.00      0.00       264
         5.0       0.00      0.00      0.00        11
         6.0       0.00      0.00      0.00        78
         7.0       0.00      0.00      0.00         1
         8.0       0.00      0.00      0.00        29
        10.0       0.00      0.00      0.00         7
        12.0       0.00      0.00      0.00        17
        14.0       0.00      0.00      0.00         1
        24.0       0.00      0.00      0.00        22

    accuracy                           0.51      2866
   macro avg       0.08      0.10      0.09      2866
weighted avg       0.41      0.51      0.46      2866

[[971 358   0   0   0   0   0   0   0   0   0   0]
 [499 490   0   0   0   0   0  

KeyboardInterrupt: ignored