In [1]:
# Mount Google Drive into the Colab filesystem at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import os
# Print the working directory before and after the switch so the change shows in the cell output.
print(os.getcwd())
# Move into the Drive folder; the relative file names used by later cells resolve against it.
os.chdir('/content/drive/My Drive/1006')
print(os.getcwd())

/content
/content/drive/My Drive/1006


In [3]:
import pandas as pd
import numpy as np
import ast
from sklearn.naive_bayes import MultinomialNB
from sklearn.metrics import roc_auc_score, accuracy_score, f1_score
from sklearn.model_selection import GridSearchCV
from collections import defaultdict

In [4]:
# Prefix used when later cells build the data and index file names.
dataset_name = 'eo'

# Embedding variants; later cells open one index file per entry.
embed_types = [
    'cvec_pca16', 'cvec_nmf16', 'cvec_umap16', 'cvec_tsne16',
    'bert', 'roberta', 'distil', 'glove6B', 'universal',
]

# Selection-method suffixes that are combined with the embedding names to form result keys.
selection_types = [
    'taddy_0', 'taddy_1', 'taddy_2', 'taddy_3',
    'kmeans_0', 'kmeans_1', 'kmeans_2', 'kmeans_3',
    'kld', 'ks', 'cos', 'recon',
]

# Training-subset sizes: cells that consume a ranked index list train on its
# first `c` entries for every `c` listed here.
counts = [
    100, 200, 300, 400, 700, 1000,
    1400, 1800, 2400, 3000, 3600, 4200,
]

## Results Dict



In [13]:
# Result keys: the random and topics_taddy keys first, followed by one
# "<embedding>_<selection>" key for every embedding/selection combination.
key_list = [
    'random_noembed_0', 'random_noembed_1', 'random_noembed_2', 'random_noembed_3',
    'topics_taddy_0', 'topics_taddy_1', 'topics_taddy_2', 'topics_taddy_3',
]
key_list.extend(
    str(embed) + '_' + str(selection)
    for embed in embed_types
    for selection in selection_types
)

# One container per metric. Unknown keys start out as empty lists, so the
# experiment cells can append scores without initialising anything.
acc_dict, f1_dict, roc_dict = defaultdict(list), defaultdict(list), defaultdict(list)

## Complete dataset

In [6]:
# Training split: the 'label' column is the target, every other column is a feature.
train_frame = pd.read_csv(dataset_name + '_cvec_train.csv', index_col=0)
X_train = train_frame.drop(columns=['label']).to_numpy()
y_train = train_frame['label'].to_numpy()
del train_frame  # only the NumPy arrays are needed from here on

# Evaluation split, read with the same layout as the training file.
test_frame = pd.read_csv(dataset_name + '_cvec_test.csv', index_col=0)
X_test = test_frame.drop(columns=['label']).to_numpy()
y_test = test_frame['label'].to_numpy()
del test_frame

In [11]:
# Candidate values for MultinomialNB's `alpha` (additive smoothing) parameter.
# The subset experiments in later cells search this same grid.
parameters = {
    'alpha': [
        1e-7, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1,
        1, 2, 4, 6, 8, 10, 12, 14, 16, 18,
    ]
}

# Choose alpha by cross-validated ROC AUC on the complete training set.
gscv = GridSearchCV(MultinomialNB(), parameters, verbose=1, scoring='roc_auc')
gscv.fit(X_train, y_train)
mnb = gscv.best_estimator_
print(gscv.best_params_)

# Hard predictions feed both accuracy and F1, so compute them once.
y_pred = mnb.predict(X_test)
print(accuracy_score(y_test, y_pred))
print(f1_score(y_test, y_pred))
# ROC AUC is scored on the probability of the second class (column 1).
print(roc_auc_score(y_test, mnb.predict_proba(X_test)[:, 1]))

Fitting 5 folds for each of 17 candidates, totalling 85 fits


[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done  85 out of  85 | elapsed:  1.1min finished


{'alpha': 16}
0.681042654028436
0.38313473877176907
0.6580137957158743


## Random Pick

In [14]:
# Evaluate the four random-selection runs. The file name is built from
# `dataset_name`, like every other index file in this notebook, instead of a
# hard-coded dataset prefix.
for i in range(4):
    key = 'random_noembed_' + str(i)
    # Every line of the file is a Python literal used to index training rows.
    with open('indices_' + dataset_name + '_random_' + str(i) + '.txt') as fh:
        indices_list = [ast.literal_eval(line) for line in fh]
    for lst in indices_list:
        # Re-tune alpha on each subset; scoring always uses the full test split.
        gscv = GridSearchCV(MultinomialNB(), parameters, verbose=0, scoring='roc_auc')
        gscv.fit(X_train[lst], y_train[lst])
        mnb = gscv.best_estimator_
        print(gscv.best_params_)
        # Accuracy and F1 share the same hard predictions; predict only once.
        y_pred = mnb.predict(X_test)
        acc_dict[key].append(accuracy_score(y_test, y_pred))
        f1_dict[key].append(f1_score(y_test, y_pred))
        roc_dict[key].append(roc_auc_score(y_test, mnb.predict_proba(X_test)[:, 1]))

{'alpha': 0.01}
{'alpha': 0.1}
{'alpha': 2}
{'alpha': 4}
{'alpha': 2}
{'alpha': 2}
{'alpha': 2}
{'alpha': 0.01}
{'alpha': 0.1}
{'alpha': 12}
{'alpha': 1e-07}
{'alpha': 0.1}
{'alpha': 1}
{'alpha': 2}
{'alpha': 1}
{'alpha': 2}
{'alpha': 2}
{'alpha': 4}
{'alpha': 6}
{'alpha': 0.1}
{'alpha': 4}
{'alpha': 4}
{'alpha': 8}
{'alpha': 0.001}
{'alpha': 18}
{'alpha': 1}
{'alpha': 1}
{'alpha': 0.01}
{'alpha': 2}
{'alpha': 0.1}
{'alpha': 1e-05}
{'alpha': 4}
{'alpha': 2}
{'alpha': 0.01}
{'alpha': 0.01}
{'alpha': 1e-07}
{'alpha': 12}
{'alpha': 0.01}
{'alpha': 1}
{'alpha': 2}
{'alpha': 4}
{'alpha': 2}
{'alpha': 0.1}
{'alpha': 4}
{'alpha': 2}
{'alpha': 1}
{'alpha': 1}
{'alpha': 6}


## K-means Clustering

In [15]:
# Evaluate the k-means selections: four runs for every embedding type.
for i in range(4):
    for embed in embed_types:
        key = embed + '_kmeans_' + str(i)
        # Every line of the file is a Python literal used to index training rows.
        with open('indices_' + dataset_name + '_' + embed + '_kmeans_' + str(i) + '.txt') as fh:
            indices_list = [ast.literal_eval(line) for line in fh]
        for lst in indices_list:
            # Re-tune alpha on each subset; scoring always uses the full test split.
            gscv = GridSearchCV(MultinomialNB(), parameters, verbose=0, scoring='roc_auc')
            gscv.fit(X_train[lst], y_train[lst])
            mnb = gscv.best_estimator_
            # Accuracy and F1 share the same hard predictions; predict only once.
            y_pred = mnb.predict(X_test)
            acc_dict[key].append(accuracy_score(y_test, y_pred))
            f1_dict[key].append(f1_score(y_test, y_pred))
            roc_dict[key].append(roc_auc_score(y_test, mnb.predict_proba(X_test)[:, 1]))

## Greedy farthest points based on KL Divergence

In [16]:
# Collect the KL-divergence index lists, visiting the files in `embed_types`
# order. Every line of every file contributes one parsed entry.
indices_list = []
for embed in embed_types:
    with open('indices_' + dataset_name + '_' + embed + '_kld.txt') as fh:
        indices_list.extend(ast.literal_eval(row) for row in fh)

In [17]:
# Results are labelled by position: the i-th entry of `indices_list` is scored
# under embed_types[i]. That is only correct when there is exactly one index
# list per embedding, so check before starting the long grid searches.
if len(indices_list) != len(embed_types):
    raise ValueError(
        'expected one index list per embedding type, got %d lists for %d embeddings'
        % (len(indices_list), len(embed_types))
    )

for embed, lst in zip(embed_types, indices_list):
    key = embed + '_kld'
    for c in counts:
        # Train on the first `c` selected rows only.
        subset = lst[:c]
        gscv = GridSearchCV(MultinomialNB(), parameters, verbose=0, scoring='roc_auc')
        gscv.fit(X_train[subset], y_train[subset])
        mnb = gscv.best_estimator_
        # Accuracy and F1 share the same hard predictions; predict only once.
        y_pred = mnb.predict(X_test)
        acc_dict[key].append(accuracy_score(y_test, y_pred))
        f1_dict[key].append(f1_score(y_test, y_pred))
        roc_dict[key].append(roc_auc_score(y_test, mnb.predict_proba(X_test)[:, 1]))

## Greedy farthest points based on the Kolmogorov-Smirnov statistic

In [18]:
# Collect the Kolmogorov-Smirnov index lists, visiting the files in
# `embed_types` order. Every line of every file contributes one parsed entry.
indices_list = []
for embed in embed_types:
    with open('indices_' + dataset_name + '_' + embed + '_ks.txt') as fh:
        indices_list.extend(ast.literal_eval(row) for row in fh)

In [19]:
# Results are labelled by position: the i-th entry of `indices_list` is scored
# under embed_types[i]. That is only correct when there is exactly one index
# list per embedding, so check before starting the long grid searches.
if len(indices_list) != len(embed_types):
    raise ValueError(
        'expected one index list per embedding type, got %d lists for %d embeddings'
        % (len(indices_list), len(embed_types))
    )

for embed, lst in zip(embed_types, indices_list):
    key = embed + '_ks'
    for c in counts:
        # Train on the first `c` selected rows only.
        subset = lst[:c]
        gscv = GridSearchCV(MultinomialNB(), parameters, verbose=0, scoring='roc_auc')
        gscv.fit(X_train[subset], y_train[subset])
        mnb = gscv.best_estimator_
        # Accuracy and F1 share the same hard predictions; predict only once.
        y_pred = mnb.predict(X_test)
        acc_dict[key].append(accuracy_score(y_test, y_pred))
        f1_dict[key].append(f1_score(y_test, y_pred))
        roc_dict[key].append(roc_auc_score(y_test, mnb.predict_proba(X_test)[:, 1]))

## Greedy farthest points based on cosine distance

In [20]:
# Collect the cosine-distance index lists, visiting the files in `embed_types`
# order. Every line of every file contributes one parsed entry.
indices_list = []
for embed in embed_types:
    with open('indices_' + dataset_name + '_' + embed + '_cos.txt') as fh:
        indices_list.extend(ast.literal_eval(row) for row in fh)

In [21]:
# Results are labelled by position: the i-th entry of `indices_list` is scored
# under embed_types[i]. That is only correct when there is exactly one index
# list per embedding, so check before starting the long grid searches.
if len(indices_list) != len(embed_types):
    raise ValueError(
        'expected one index list per embedding type, got %d lists for %d embeddings'
        % (len(indices_list), len(embed_types))
    )

for embed, lst in zip(embed_types, indices_list):
    key = embed + '_cos'
    for c in counts:
        # Train on the first `c` selected rows only.
        subset = lst[:c]
        gscv = GridSearchCV(MultinomialNB(), parameters, verbose=0, scoring='roc_auc')
        gscv.fit(X_train[subset], y_train[subset])
        mnb = gscv.best_estimator_
        # Accuracy and F1 share the same hard predictions; predict only once.
        y_pred = mnb.predict(X_test)
        acc_dict[key].append(accuracy_score(y_test, y_pred))
        f1_dict[key].append(f1_score(y_test, y_pred))
        roc_dict[key].append(roc_auc_score(y_test, mnb.predict_proba(X_test)[:, 1]))

## Reconstruction Loss Minimization

In [22]:
# Load one saved index array per embedding type (in `embed_types` order) and
# keep each as a plain Python list.
indices_list = []
for embed in embed_types:
    recon_path = 'indices_' + dataset_name + '_' + embed + '_recon.npy'
    indices_list.append(list(np.load(recon_path)))

In [23]:
# Results are labelled by position: the i-th entry of `indices_list` is scored
# under embed_types[i]. Guard against a stale or incomplete `indices_list`
# (e.g. left over from another cell) before starting the long grid searches.
if len(indices_list) != len(embed_types):
    raise ValueError(
        'expected one index list per embedding type, got %d lists for %d embeddings'
        % (len(indices_list), len(embed_types))
    )

for embed, lst in zip(embed_types, indices_list):
    key = embed + '_recon'
    for c in counts:
        # Train on the first `c` selected rows only.
        subset = lst[:c]
        gscv = GridSearchCV(MultinomialNB(), parameters, verbose=0, scoring='roc_auc')
        gscv.fit(X_train[subset], y_train[subset])
        mnb = gscv.best_estimator_
        # Accuracy and F1 share the same hard predictions; predict only once.
        y_pred = mnb.predict(X_test)
        acc_dict[key].append(accuracy_score(y_test, y_pred))
        f1_dict[key].append(f1_score(y_test, y_pred))
        roc_dict[key].append(roc_auc_score(y_test, mnb.predict_proba(X_test)[:, 1]))

## Taddy

In [57]:
# Evaluate the Taddy selections: four runs for every embedding type.
for k in range(4):
    for embed in embed_types:
        key = embed + '_taddy_' + str(k)
        with open('indices_' + dataset_name + '_' + embed + '_taddy_' + str(k) + '.txt') as fh:
            indices_list = [ast.literal_eval(line) for line in fh]
        # Only the first parsed line is used; each of its elements indexes the
        # training rows for one fit.
        for subset in indices_list[0]:
            gscv = GridSearchCV(MultinomialNB(), parameters, verbose=0, scoring='roc_auc')
            gscv.fit(X_train[subset], y_train[subset])
            mnb = gscv.best_estimator_
            # Accuracy and F1 share the same hard predictions; predict only once.
            y_pred = mnb.predict(X_test)
            acc_dict[key].append(accuracy_score(y_test, y_pred))
            f1_dict[key].append(f1_score(y_test, y_pred))
            roc_dict[key].append(roc_auc_score(y_test, mnb.predict_proba(X_test)[:, 1]))

## Taddy Topics

In [68]:
# Evaluate the four topic-based Taddy selections (no embedding involved).
for k in range(4):
    key = 'topics_taddy_' + str(k)
    with open('indices_' + dataset_name + '_topics_taddy_' + str(k) + '.txt') as fh:
        indices_list = [ast.literal_eval(line) for line in fh]
    # Only the first parsed line is used; each of its elements indexes the
    # training rows for one fit.
    for subset in indices_list[0]:
        gscv = GridSearchCV(MultinomialNB(), parameters, verbose=0, scoring='roc_auc')
        gscv.fit(X_train[subset], y_train[subset])
        mnb = gscv.best_estimator_
        # Accuracy and F1 share the same hard predictions; predict only once.
        y_pred = mnb.predict(X_test)
        acc_dict[key].append(accuracy_score(y_test, y_pred))
        f1_dict[key].append(f1_score(y_test, y_pred))
        roc_dict[key].append(roc_auc_score(y_test, mnb.predict_proba(X_test)[:, 1]))

In [70]:
# Display the ROC AUC scores as a table: one column per result key, one row per
# position in that key's score list.
pd.DataFrame.from_dict(roc_dict, orient='index').T

Unnamed: 0,random_noembed_0,random_noembed_1,random_noembed_2,random_noembed_3,cvec_pca16_kmeans_0,cvec_nmf16_kmeans_0,cvec_umap16_kmeans_0,cvec_tsne16_kmeans_0,bert_kmeans_0,roberta_kmeans_0,distil_kmeans_0,glove6B_kmeans_0,universal_kmeans_0,cvec_pca16_kmeans_1,cvec_nmf16_kmeans_1,cvec_umap16_kmeans_1,cvec_tsne16_kmeans_1,bert_kmeans_1,roberta_kmeans_1,distil_kmeans_1,glove6B_kmeans_1,universal_kmeans_1,cvec_pca16_kmeans_2,cvec_nmf16_kmeans_2,cvec_umap16_kmeans_2,cvec_tsne16_kmeans_2,bert_kmeans_2,roberta_kmeans_2,distil_kmeans_2,glove6B_kmeans_2,universal_kmeans_2,cvec_pca16_kmeans_3,cvec_nmf16_kmeans_3,cvec_umap16_kmeans_3,cvec_tsne16_kmeans_3,bert_kmeans_3,roberta_kmeans_3,distil_kmeans_3,glove6B_kmeans_3,universal_kmeans_3,...,cvec_pca16_taddy_0,cvec_nmf16_taddy_0,cvec_umap16_taddy_0,cvec_tsne16_taddy_0,bert_taddy_0,roberta_taddy_0,distil_taddy_0,glove6B_taddy_0,universal_taddy_0,cvec_pca16_taddy_1,cvec_nmf16_taddy_1,cvec_umap16_taddy_1,cvec_tsne16_taddy_1,bert_taddy_1,roberta_taddy_1,distil_taddy_1,glove6B_taddy_1,universal_taddy_1,cvec_pca16_taddy_2,cvec_nmf16_taddy_2,cvec_umap16_taddy_2,cvec_tsne16_taddy_2,bert_taddy_2,roberta_taddy_2,distil_taddy_2,glove6B_taddy_2,universal_taddy_2,cvec_pca16_taddy_3,cvec_nmf16_taddy_3,cvec_umap16_taddy_3,cvec_tsne16_taddy_3,bert_taddy_3,roberta_taddy_3,distil_taddy_3,glove6B_taddy_3,universal_taddy_3,topics_taddy_0,topics_taddy_1,topics_taddy_2,topics_taddy_3
0,0.565183,0.536345,0.532701,0.518372,0.524512,0.511829,0.540173,0.541613,0.55935,0.577143,0.591148,0.585124,0.518755,0.558481,0.597719,0.567513,0.502694,0.527444,0.600282,0.560546,0.593036,0.586988,0.497565,0.53928,0.582689,0.489051,0.562689,0.557821,0.558895,0.600469,0.527074,0.584392,0.606555,0.551708,0.588458,0.481823,0.535507,0.545239,0.517032,0.576083,...,0.586407,0.562914,0.438996,0.550648,0.559541,0.576119,0.51946,0.549146,0.590099,0.590388,0.563829,0.474972,0.500224,0.480603,0.603232,0.521142,0.547662,0.519331,0.562811,0.565269,0.481084,0.519816,0.573588,0.497968,0.492268,0.536399,0.589829,0.590862,0.558471,0.477121,0.593102,0.593532,0.505199,0.599818,0.561339,0.57347,0.55221,0.558492,0.55326,0.516108
1,0.563593,0.614749,0.52132,0.623877,0.580283,0.568702,0.597014,0.58597,0.617988,0.529856,0.507213,0.550492,0.541706,0.615791,0.571312,0.575906,0.594621,0.581187,0.575307,0.573154,0.593418,0.602914,0.57799,0.545504,0.577457,0.531563,0.566366,0.531806,0.626497,0.611455,0.603992,0.593978,0.547227,0.598164,0.620625,0.550789,0.633744,0.540944,0.604064,0.526541,...,0.579755,0.568065,0.532433,0.588658,0.556468,0.588265,0.544278,0.550132,0.570911,0.578421,0.566468,0.521153,0.594449,0.521119,0.608249,0.578422,0.501916,0.585404,0.574959,0.565127,0.576587,0.590974,0.554476,0.52583,0.54167,0.587258,0.586327,0.578176,0.566479,0.514649,0.594321,0.628544,0.562001,0.570077,0.562104,0.592035,0.507689,0.515044,0.526928,0.526428
2,0.596559,0.548057,0.622774,0.598999,0.590877,0.616259,0.596248,0.630045,0.595872,0.614715,0.587593,0.60791,0.589584,0.606526,0.58544,0.595353,0.605629,0.611956,0.559677,0.583918,0.620643,0.58987,0.612836,0.573875,0.610291,0.617412,0.561441,0.553533,0.601569,0.608552,0.61111,0.60675,0.607471,0.53676,0.62976,0.569241,0.529391,0.617054,0.565979,0.591945,...,0.554302,0.57523,0.530377,0.576128,0.617651,0.537598,0.574878,0.530935,0.578527,0.556529,0.576649,0.530954,0.590396,0.62292,0.621291,0.594013,0.566611,0.61035,0.555518,0.573614,0.52585,0.597827,0.54741,0.596566,0.548865,0.607796,0.526421,0.555642,0.575145,0.575045,0.565824,0.583242,0.538228,0.587762,0.596384,0.570661,0.513305,0.506155,0.549689,0.513699
3,0.609832,0.583116,0.584104,0.59688,0.591788,0.599731,0.536426,0.614315,0.613372,0.626276,0.570886,0.603536,0.618418,0.633536,0.605144,0.568868,0.624693,0.583113,0.569482,0.617639,0.624047,0.626834,0.631401,0.580647,0.623563,0.581508,0.641419,0.621614,0.55898,0.604829,0.583084,0.606902,0.610603,0.610939,0.582395,0.572634,0.626566,0.621549,0.613249,0.628475,...,0.564641,0.570084,0.539841,0.6062,0.633138,0.588942,0.607843,0.5006,0.599972,0.585824,0.569939,0.550332,0.592858,0.606051,0.560976,0.547098,0.597713,0.599516,0.568273,0.569451,0.541043,0.611368,0.535447,0.551782,0.644931,0.614349,0.60902,0.567627,0.569195,0.591466,0.592333,0.614106,0.641347,0.554651,0.540549,0.597422,0.529438,0.541159,0.529517,0.519372
4,0.627989,0.614252,0.584687,0.650099,0.598072,0.632415,0.645072,0.604598,0.627189,0.602281,0.597327,0.608754,0.619728,0.628956,0.628531,0.619832,0.590714,0.640132,0.624401,0.642236,0.627596,0.566984,0.618197,0.614141,0.556629,0.631916,0.596434,0.5911,0.574639,0.630412,0.630385,0.612731,0.623102,0.623977,0.626308,0.621698,0.619807,0.629743,0.627145,0.622385,...,0.580505,0.561093,0.604065,0.596235,0.612034,0.63069,0.603298,0.609089,0.605079,0.575071,0.551945,0.596783,0.597372,0.648507,0.627846,0.598571,0.607367,0.61426,0.584801,0.549455,0.594188,0.596749,0.647054,0.619054,0.583668,0.611171,0.607668,0.581637,0.551533,0.590586,0.61815,0.632589,0.621634,0.644983,0.601176,0.60162,0.604229,0.609466,0.614814,0.597768
5,0.617488,0.625793,0.615902,0.612136,0.64286,0.627253,0.611149,0.644387,0.621916,0.634909,0.627314,0.636007,0.633421,0.643123,0.625206,0.624995,0.622726,0.614041,0.644789,0.658642,0.614591,0.648159,0.638703,0.635236,0.629427,0.590954,0.638687,0.64588,0.643529,0.639379,0.638837,0.622162,0.620611,0.629457,0.633382,0.635174,0.628844,0.646391,0.625732,0.633748,...,0.582187,0.576936,0.617006,0.631124,0.646188,0.640343,0.646249,0.609107,0.609644,0.582166,0.578692,0.618822,0.613818,0.632384,0.631228,0.57234,0.619742,0.599225,0.581022,0.576192,0.615531,0.627291,0.649304,0.649707,0.661504,0.620136,0.604024,0.58237,0.576836,0.617549,0.61403,0.657494,0.627625,0.634293,0.621632,0.617148,0.620335,0.615797,0.610191,0.616797
6,0.653137,0.655341,0.596235,0.617024,0.647977,0.669077,0.644818,0.622197,0.638099,0.639048,0.638738,0.641923,0.630298,0.639994,0.634478,0.632769,0.639663,0.637727,0.647541,0.630842,0.632248,0.63983,0.637634,0.635787,0.632142,0.654452,0.614133,0.648979,0.655641,0.646454,0.634593,0.650278,0.630907,0.654197,0.636645,0.628881,0.650831,0.62656,0.634487,0.62926,...,0.615131,0.600984,0.613915,0.641458,0.656132,0.655555,0.639365,0.629152,0.614542,0.613215,0.600984,0.613885,0.642062,0.651928,0.622301,0.605646,0.62762,0.625048,0.611947,0.601062,0.61479,0.640853,0.655463,0.66189,0.596431,0.625936,0.610105,0.610099,0.59918,0.612066,0.639656,0.648476,0.644809,0.575446,0.625773,0.615801,0.59855,0.631236,0.628297,0.62758
7,0.614468,0.646044,0.64625,0.642046,0.64552,0.654525,0.662963,0.640333,0.645773,0.63512,0.640348,0.648499,0.638791,0.621963,0.637151,0.619233,0.654276,0.653166,0.653044,0.651004,0.648053,0.649107,0.63051,0.64675,0.655291,0.628344,0.661427,0.6378,0.641825,0.646421,0.6453,0.640184,0.635086,0.649311,0.655006,0.650818,0.646519,0.645177,0.64421,0.636034,...,0.632164,0.63189,0.621801,0.638586,0.658372,0.650356,0.614289,0.635512,0.614302,0.632018,0.631797,0.620997,0.635892,0.630942,0.649787,0.613101,0.632632,0.613797,0.632057,0.624976,0.625806,0.634598,0.643894,0.641932,0.583579,0.632815,0.61341,0.631859,0.630966,0.627397,0.636673,0.653273,0.653939,0.623166,0.634816,0.603723,0.636005,0.635819,0.638272,0.642692
8,0.628412,0.649459,0.635273,0.650689,0.632804,0.652674,0.647317,0.636886,0.666052,0.654805,0.647624,0.645909,0.662905,0.65198,0.654036,0.654688,0.650962,0.656152,0.649414,0.64384,0.634268,0.641835,0.645131,0.637514,0.655881,0.646247,0.637796,0.660723,0.655879,0.642325,0.651085,0.643858,0.638883,0.654174,0.655293,0.642124,0.654511,0.6309,0.640915,0.655203,...,0.629418,0.638877,0.631098,0.650779,0.645761,0.655755,0.645192,0.639629,0.615667,0.629653,0.639061,0.627351,0.65092,0.648399,0.652195,0.643371,0.644149,0.621084,0.629587,0.638771,0.630977,0.650224,0.65018,0.650005,0.634762,0.644149,0.615176,0.631411,0.638881,0.632936,0.649865,0.647826,0.65127,0.645747,0.639749,0.617357,0.642909,0.643192,0.643542,0.643519
9,0.632054,0.648455,0.640639,0.645923,0.639246,0.644853,0.651266,0.657087,0.647995,0.649929,0.646623,0.639486,0.656292,0.653686,0.661698,0.652392,0.654092,0.659211,0.649895,0.654271,0.64385,0.652059,0.646695,0.649328,0.641327,0.639314,0.65922,0.640676,0.669427,0.645617,0.650467,0.655585,0.65512,0.639299,0.650618,0.661324,0.653529,0.660681,0.651369,0.645795,...,0.653333,0.65709,0.632264,0.644443,0.642149,0.65918,0.647409,0.638528,0.614784,0.653005,0.653388,0.632045,0.643936,0.652525,0.656487,0.647649,0.638195,0.629826,0.652448,0.657772,0.629776,0.643319,0.647792,0.653602,0.649407,0.6416,0.633997,0.652753,0.653241,0.629388,0.641431,0.641522,0.65337,0.646859,0.637444,0.628011,0.655634,0.649662,0.649675,0.649786


## Save dicts to csv

In [71]:
# Write one CSV per metric. The file prefix comes from `dataset_name` so the
# outputs follow the configured dataset instead of a hard-coded name.
for metric_name, metric_dict in (('acc', acc_dict), ('f1', f1_dict), ('roc', roc_dict)):
    # orient='index' followed by a transpose gives one column per result key
    # even when the score lists differ in length.
    metric_table = pd.DataFrame.from_dict(metric_dict, orient='index').transpose()
    metric_table.to_csv(dataset_name + '_' + metric_name + '_mnb.csv')