In [1]:
import pandas as pd
import time
import numpy as np
import xgboost as xgb
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, balanced_accuracy_score, roc_auc_score, make_scorer, f1_score, matthews_corrcoef
from sklearn.model_selection import RandomizedSearchCV,  GridSearchCV
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.metrics import plot_confusion_matrix

In [2]:
embeds = ["emb/sc_ppi_emb_d32_e1_l120_w20_k10_p1", "emb/sc_ppi_emb_d32_e3_l120_w10_k10_p2",
          "emb/sc_ppi_emb_d32_e3_l80_w20_k10_p0.5", "emb/sc_ppi_emb_d32_e3_l80_w20_k10_p1", 
          "emb/sc_ppi_emb_d64_e1_l120_w20_k10_p1", "emb/sc_ppi_emb_d64_e3_l120_w10_k10_p1",
          "emb/sc_ppi_emb_d64_e3_l120_w20_k20_p2", "emb/sc_ppi_emb_d64_e3_l80_w10_k10_p2",
          "emb/sc_ppi_emb_d64_e3_l80_w20_k10_p0.5", "emb/sc_ppi_emb_d64_e3_l80_w20_k20_p0.5"]
fill = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

In [3]:
#print(embeds[0]+"_out.csv")
#s = embeds[0]+"_out.csv"
#m = pd.read_csv(s)
#m.head()
data = {'Embeddings': embeds,
        'Accuracy': fill,
        'Balanced Accuracy Score': fill,
        'F1 Score': fill,
        'Matthews Correlation Coefficient': fill
        }
results = pd.DataFrame(data, columns= ['Embeddings', 'Accuracy', 'Balanced Accuracy Score', 'F1 Score', 'Matthews Correlation Coefficient'])
results.head()

Unnamed: 0,Embeddings,Accuracy,Balanced Accuracy Score,F1 Score,Matthews Correlation Coefficient
0,emb/sc_ppi_emb_d32_e1_l120_w20_k10_p1,0,0,0,0
1,emb/sc_ppi_emb_d32_e3_l120_w10_k10_p2,0,0,0,0
2,emb/sc_ppi_emb_d32_e3_l80_w20_k10_p0.5,0,0,0,0
3,emb/sc_ppi_emb_d32_e3_l80_w20_k10_p1,0,0,0,0
4,emb/sc_ppi_emb_d64_e1_l120_w20_k10_p1,0,0,0,0


In [4]:
df = pd.DataFrame(columns= ['Embeddings', 'Accuracy', 'Balanced Accuracy Score', 'F1 Score', 'Matthews Correlation Coefficient', 'Time Taken'])
df.head()

Unnamed: 0,Embeddings,Accuracy,Balanced Accuracy Score,F1 Score,Matthews Correlation Coefficient,Time Taken


In [6]:
for x in embeds:
    X = pd.read_csv(x+".emb.csv")
    X.drop(columns=X.columns[0], axis=1, inplace=True)
    X.head()
    y = pd.read_csv(x+".emb_out.csv")
    start = time.time()
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42, stratify=y)
    
    param_grid = {
         'max_depth': [5, 6, 7],
         'learning_rate': [0.2, 0.15, 0.1],
         'min_child_weight' : [1, 3, 5],
         'gamma': [1.0, 2.0, 3.0],
         'reg_lambda': [10.0, 20.0, 100.0],
         'scale_pos_weight': [1]
    }

    clf = GridSearchCV(estimator = xgb.XGBClassifier(objective='binary:logistic',
                                                              seed=42,
                                                              subsample=0.9,
                                                              colsample_bytree=0.5
                                                               ),
                                                              param_grid = param_grid,
                                                              scoring = 'roc_auc',
                                                              verbose = 2,
                                                              n_jobs = 10,
                                                              cv = 4)

    clf.fit(X_train,
            y_train,
            early_stopping_rounds=10,
            eval_metric='auc',
            eval_set=[(X_test, y_test)],
            verbose=True)

    time_taken = time.time() - start
    print(clf.best_estimator_)
    print(clf.best_params_)
    
    predictions = clf.predict(X_test)
    acc = accuracy_score(y_test, predictions)
    balanced_acc = balanced_accuracy_score(y_test, predictions)
    f1 = f1_score(y_test, predictions)
    matt = matthews_corrcoef(y_test, predictions)
    row = [x, acc, balanced_acc, f1, matt, time_taken]
    df.loc[x] = row
    
    

Fitting 4 folds for each of 243 candidates, totalling 972 fits


  y = column_or_1d(y, warn=True)
  y = column_or_1d(y, warn=True)


[0]	validation_0-auc:0.56412
[1]	validation_0-auc:0.75632
[2]	validation_0-auc:0.79377
[3]	validation_0-auc:0.83099
[4]	validation_0-auc:0.83860
[5]	validation_0-auc:0.83884
[6]	validation_0-auc:0.84338
[7]	validation_0-auc:0.85500
[8]	validation_0-auc:0.85481
[9]	validation_0-auc:0.85919
[10]	validation_0-auc:0.87116
[11]	validation_0-auc:0.87691
[12]	validation_0-auc:0.88190
[13]	validation_0-auc:0.88573
[14]	validation_0-auc:0.88741
[15]	validation_0-auc:0.88899
[16]	validation_0-auc:0.89110
[17]	validation_0-auc:0.89219
[18]	validation_0-auc:0.89496
[19]	validation_0-auc:0.89728
[20]	validation_0-auc:0.90045
[21]	validation_0-auc:0.90204
[22]	validation_0-auc:0.90187
[23]	validation_0-auc:0.90538
[24]	validation_0-auc:0.90702
[25]	validation_0-auc:0.90881
[26]	validation_0-auc:0.91048
[27]	validation_0-auc:0.91090
[28]	validation_0-auc:0.91221
[29]	validation_0-auc:0.91272
[30]	validation_0-auc:0.91392
[31]	validation_0-auc:0.91492
[32]	validation_0-auc:0.91619
[33]	validation_0-au

  y = column_or_1d(y, warn=True)
  y = column_or_1d(y, warn=True)


[18]	validation_0-auc:0.90058
[19]	validation_0-auc:0.90154
[20]	validation_0-auc:0.90256
[21]	validation_0-auc:0.90430
[22]	validation_0-auc:0.90509
[23]	validation_0-auc:0.90533
[24]	validation_0-auc:0.90620
[25]	validation_0-auc:0.90654
[26]	validation_0-auc:0.90704
[27]	validation_0-auc:0.90766
[28]	validation_0-auc:0.90692
[29]	validation_0-auc:0.90804
[30]	validation_0-auc:0.90936
[31]	validation_0-auc:0.90888
[32]	validation_0-auc:0.90851
[33]	validation_0-auc:0.90858
[34]	validation_0-auc:0.90990
[35]	validation_0-auc:0.91025
[36]	validation_0-auc:0.91058
[37]	validation_0-auc:0.91114
[38]	validation_0-auc:0.91221
[39]	validation_0-auc:0.91253
[40]	validation_0-auc:0.91273
[41]	validation_0-auc:0.91399
[42]	validation_0-auc:0.91514
[43]	validation_0-auc:0.91589
[44]	validation_0-auc:0.91582
[45]	validation_0-auc:0.91545
[46]	validation_0-auc:0.91512
[47]	validation_0-auc:0.91586
[48]	validation_0-auc:0.91597
[49]	validation_0-auc:0.91657
[50]	validation_0-auc:0.91731
[51]	valid

  y = column_or_1d(y, warn=True)
  y = column_or_1d(y, warn=True)


[20]	validation_0-auc:0.91194
[21]	validation_0-auc:0.91421
[22]	validation_0-auc:0.91593
[23]	validation_0-auc:0.91715
[24]	validation_0-auc:0.91758
[25]	validation_0-auc:0.91972
[26]	validation_0-auc:0.91960
[27]	validation_0-auc:0.92037
[28]	validation_0-auc:0.92110
[29]	validation_0-auc:0.92197
[30]	validation_0-auc:0.92230
[31]	validation_0-auc:0.92306
[32]	validation_0-auc:0.92380
[33]	validation_0-auc:0.92377
[34]	validation_0-auc:0.92366
[35]	validation_0-auc:0.92516
[36]	validation_0-auc:0.92546
[37]	validation_0-auc:0.92639
[38]	validation_0-auc:0.92613
[39]	validation_0-auc:0.92716
[40]	validation_0-auc:0.92804
[41]	validation_0-auc:0.92879
[42]	validation_0-auc:0.92914
[43]	validation_0-auc:0.92944
[44]	validation_0-auc:0.92993
[45]	validation_0-auc:0.93059
[46]	validation_0-auc:0.93123
[47]	validation_0-auc:0.93157
[48]	validation_0-auc:0.93197
[49]	validation_0-auc:0.93197
[50]	validation_0-auc:0.93251
[51]	validation_0-auc:0.93330
[52]	validation_0-auc:0.93402
[53]	valid

  y = column_or_1d(y, warn=True)
  y = column_or_1d(y, warn=True)


[19]	validation_0-auc:0.91185
[20]	validation_0-auc:0.91356
[21]	validation_0-auc:0.91294
[22]	validation_0-auc:0.91483
[23]	validation_0-auc:0.91663
[24]	validation_0-auc:0.91724
[25]	validation_0-auc:0.91826
[26]	validation_0-auc:0.91921
[27]	validation_0-auc:0.91971
[28]	validation_0-auc:0.92132
[29]	validation_0-auc:0.92207
[30]	validation_0-auc:0.92266
[31]	validation_0-auc:0.92317
[32]	validation_0-auc:0.92277
[33]	validation_0-auc:0.92376
[34]	validation_0-auc:0.92392
[35]	validation_0-auc:0.92462
[36]	validation_0-auc:0.92487
[37]	validation_0-auc:0.92583
[38]	validation_0-auc:0.92583
[39]	validation_0-auc:0.92683
[40]	validation_0-auc:0.92739
[41]	validation_0-auc:0.92782
[42]	validation_0-auc:0.92798
[43]	validation_0-auc:0.92868
[44]	validation_0-auc:0.92895
[45]	validation_0-auc:0.92960
[46]	validation_0-auc:0.92996
[47]	validation_0-auc:0.93022
[48]	validation_0-auc:0.93081
[49]	validation_0-auc:0.93129
[50]	validation_0-auc:0.93134
[51]	validation_0-auc:0.93163
[52]	valid

  y = column_or_1d(y, warn=True)
  y = column_or_1d(y, warn=True)


[12]	validation_0-auc:0.89524
[13]	validation_0-auc:0.89936
[14]	validation_0-auc:0.90294
[15]	validation_0-auc:0.90442
[16]	validation_0-auc:0.90522
[17]	validation_0-auc:0.90468
[18]	validation_0-auc:0.90487
[19]	validation_0-auc:0.90658
[20]	validation_0-auc:0.90641
[21]	validation_0-auc:0.91013
[22]	validation_0-auc:0.91223
[23]	validation_0-auc:0.91249
[24]	validation_0-auc:0.91378
[25]	validation_0-auc:0.91421
[26]	validation_0-auc:0.91376
[27]	validation_0-auc:0.91587
[28]	validation_0-auc:0.91680
[29]	validation_0-auc:0.91742
[30]	validation_0-auc:0.91766
[31]	validation_0-auc:0.91740
[32]	validation_0-auc:0.91687
[33]	validation_0-auc:0.91893
[34]	validation_0-auc:0.91965
[35]	validation_0-auc:0.91991
[36]	validation_0-auc:0.92048
[37]	validation_0-auc:0.92077
[38]	validation_0-auc:0.92157
[39]	validation_0-auc:0.92178
[40]	validation_0-auc:0.92203
[41]	validation_0-auc:0.92257
[42]	validation_0-auc:0.92258
[43]	validation_0-auc:0.92304
[44]	validation_0-auc:0.92425
[45]	valid

  y = column_or_1d(y, warn=True)
  y = column_or_1d(y, warn=True)


[11]	validation_0-auc:0.88861
[12]	validation_0-auc:0.89121
[13]	validation_0-auc:0.89160
[14]	validation_0-auc:0.89233
[15]	validation_0-auc:0.89435
[16]	validation_0-auc:0.89573
[17]	validation_0-auc:0.89806
[18]	validation_0-auc:0.89913
[19]	validation_0-auc:0.90083
[20]	validation_0-auc:0.89908
[21]	validation_0-auc:0.90071
[22]	validation_0-auc:0.90118
[23]	validation_0-auc:0.90228
[24]	validation_0-auc:0.90314
[25]	validation_0-auc:0.90409
[26]	validation_0-auc:0.90521
[27]	validation_0-auc:0.90564
[28]	validation_0-auc:0.90706
[29]	validation_0-auc:0.90877
[30]	validation_0-auc:0.90925
[31]	validation_0-auc:0.91020
[32]	validation_0-auc:0.91103
[33]	validation_0-auc:0.91231
[34]	validation_0-auc:0.91268
[35]	validation_0-auc:0.91316
[36]	validation_0-auc:0.91399
[37]	validation_0-auc:0.91430
[38]	validation_0-auc:0.91416
[39]	validation_0-auc:0.91489
[40]	validation_0-auc:0.91529
[41]	validation_0-auc:0.91536
[42]	validation_0-auc:0.91557
[43]	validation_0-auc:0.91643
[44]	valid

  y = column_or_1d(y, warn=True)
  y = column_or_1d(y, warn=True)


[13]	validation_0-auc:0.89822
[14]	validation_0-auc:0.89895
[15]	validation_0-auc:0.90045
[16]	validation_0-auc:0.90187
[17]	validation_0-auc:0.90303
[18]	validation_0-auc:0.90584
[19]	validation_0-auc:0.90625
[20]	validation_0-auc:0.90550
[21]	validation_0-auc:0.90719
[22]	validation_0-auc:0.90754
[23]	validation_0-auc:0.90737
[24]	validation_0-auc:0.90778
[25]	validation_0-auc:0.90894
[26]	validation_0-auc:0.90903
[27]	validation_0-auc:0.91072
[28]	validation_0-auc:0.91032
[29]	validation_0-auc:0.91067
[30]	validation_0-auc:0.91141
[31]	validation_0-auc:0.91190
[32]	validation_0-auc:0.91303
[33]	validation_0-auc:0.91241
[34]	validation_0-auc:0.91239
[35]	validation_0-auc:0.91253
[36]	validation_0-auc:0.91355
[37]	validation_0-auc:0.91379
[38]	validation_0-auc:0.91358
[39]	validation_0-auc:0.91406
[40]	validation_0-auc:0.91429
[41]	validation_0-auc:0.91446
[42]	validation_0-auc:0.91489
[43]	validation_0-auc:0.91535
[44]	validation_0-auc:0.91680
[45]	validation_0-auc:0.91712
[46]	valid

  y = column_or_1d(y, warn=True)
  y = column_or_1d(y, warn=True)


[11]	validation_0-auc:0.88869
[12]	validation_0-auc:0.88929
[13]	validation_0-auc:0.89243
[14]	validation_0-auc:0.89576
[15]	validation_0-auc:0.89769
[16]	validation_0-auc:0.90078
[17]	validation_0-auc:0.90205
[18]	validation_0-auc:0.90466
[19]	validation_0-auc:0.90399
[20]	validation_0-auc:0.90406
[21]	validation_0-auc:0.90522
[22]	validation_0-auc:0.90564
[23]	validation_0-auc:0.90751
[24]	validation_0-auc:0.90684
[25]	validation_0-auc:0.90857
[26]	validation_0-auc:0.90920
[27]	validation_0-auc:0.91066
[28]	validation_0-auc:0.91046
[29]	validation_0-auc:0.91117
[30]	validation_0-auc:0.91357
[31]	validation_0-auc:0.91461
[32]	validation_0-auc:0.91643
[33]	validation_0-auc:0.91751
[34]	validation_0-auc:0.91816
[35]	validation_0-auc:0.91942
[36]	validation_0-auc:0.92048
[37]	validation_0-auc:0.92160
[38]	validation_0-auc:0.92189
[39]	validation_0-auc:0.92165
[40]	validation_0-auc:0.92130
[41]	validation_0-auc:0.92138
[42]	validation_0-auc:0.92181
[43]	validation_0-auc:0.92139
[44]	valid

  y = column_or_1d(y, warn=True)
  y = column_or_1d(y, warn=True)


[13]	validation_0-auc:0.89266
[14]	validation_0-auc:0.89912
[15]	validation_0-auc:0.90297
[16]	validation_0-auc:0.90502
[17]	validation_0-auc:0.90721
[18]	validation_0-auc:0.90872
[19]	validation_0-auc:0.90933
[20]	validation_0-auc:0.90931
[21]	validation_0-auc:0.91045
[22]	validation_0-auc:0.91126
[23]	validation_0-auc:0.91278
[24]	validation_0-auc:0.91446
[25]	validation_0-auc:0.91707
[26]	validation_0-auc:0.91784
[27]	validation_0-auc:0.91861
[28]	validation_0-auc:0.91962
[29]	validation_0-auc:0.92030
[30]	validation_0-auc:0.92106
[31]	validation_0-auc:0.92067
[32]	validation_0-auc:0.92162
[33]	validation_0-auc:0.92252
[34]	validation_0-auc:0.92376
[35]	validation_0-auc:0.92419
[36]	validation_0-auc:0.92485
[37]	validation_0-auc:0.92645
[38]	validation_0-auc:0.92796
[39]	validation_0-auc:0.92819
[40]	validation_0-auc:0.92761
[41]	validation_0-auc:0.92765
[42]	validation_0-auc:0.92800
[43]	validation_0-auc:0.92859
[44]	validation_0-auc:0.92892
[45]	validation_0-auc:0.92932
[46]	valid

  y = column_or_1d(y, warn=True)
  y = column_or_1d(y, warn=True)


[15]	validation_0-auc:0.88402
[16]	validation_0-auc:0.88401
[17]	validation_0-auc:0.88542
[18]	validation_0-auc:0.88696
[19]	validation_0-auc:0.88862
[20]	validation_0-auc:0.88977
[21]	validation_0-auc:0.89187
[22]	validation_0-auc:0.89116
[23]	validation_0-auc:0.89248
[24]	validation_0-auc:0.89407
[25]	validation_0-auc:0.89634
[26]	validation_0-auc:0.89853
[27]	validation_0-auc:0.89978
[28]	validation_0-auc:0.90310
[29]	validation_0-auc:0.90432
[30]	validation_0-auc:0.90716
[31]	validation_0-auc:0.90839
[32]	validation_0-auc:0.90875
[33]	validation_0-auc:0.90960
[34]	validation_0-auc:0.91093
[35]	validation_0-auc:0.91255
[36]	validation_0-auc:0.91315
[37]	validation_0-auc:0.91447
[38]	validation_0-auc:0.91463
[39]	validation_0-auc:0.91545
[40]	validation_0-auc:0.91643
[41]	validation_0-auc:0.91699
[42]	validation_0-auc:0.91732
[43]	validation_0-auc:0.91741
[44]	validation_0-auc:0.91750
[45]	validation_0-auc:0.91762
[46]	validation_0-auc:0.91806
[47]	validation_0-auc:0.91813
[48]	valid

In [7]:
df.head()
np.savetxt("last_results.txt", df, fmt='%s', header='embed acc bal_acc f1 matt_corr time_taken', comments='')