In [37]:
import torch
import pandas as pd
import numpy as np
import sklearn

from sklearn.covariance import EllipticEnvelope
from sklearn.metrics import f1_score, confusion_matrix, classification_report

import matplotlib.pyplot as plt

import warnings
warnings.filterwarnings(action='ignore')

In [38]:
pd.options.display.float_format = '{:.2f}'.format

## Data load

In [39]:
train = pd.read_csv('../dataset/train.csv')
val = pd.read_csv('../dataset/val.csv')
test = pd.read_csv('../dataset/test.csv')

In [40]:
val_normal, val_abnormal = val.Class.value_counts()
val_ratio = val_abnormal / val_normal
print(val_normal, val_abnormal, val_ratio)


28432 30 0.0010551491277433877


In [41]:
train_x = train.drop(columns=['ID'])

In [42]:
model = EllipticEnvelope(support_fraction=0.994, contamination=val_ratio, random_state=42)
model.fit(train_x)

In [43]:
val_x = val.drop(columns=['ID', 'Class'])
val_y = val['Class']

In [57]:
prob = model.score_samples(val_x)
prob = torch.tensor(prob, dtype=torch.float)
topk_indices = torch.topk(prob, k=30, largest=False).indices

In [62]:
prob[topk_indices]

tensor([-1.6300e+08, -1.6231e+08, -1.6151e+08, -1.6070e+08, -1.6070e+08,
        -1.6055e+08, -1.6034e+08, -1.4058e+08, -1.4057e+08, -1.3992e+08,
        -1.1355e+08, -1.1321e+08, -1.1306e+08, -5.6672e+07, -5.2002e+07,
        -3.1518e+07, -1.9844e+07, -1.8336e+07, -1.0434e+07, -1.0420e+07,
        -9.4871e+06, -9.4738e+06, -9.4321e+06, -7.0460e+06, -6.9968e+06,
        -6.9921e+06, -4.7260e+06, -4.7208e+06, -3.5430e+06, -3.5013e+06])

In [61]:
prob[topk_indices][-1]

tensor(-3501284.7500)

In [60]:
-3.5013e+06

-3501300.0

In [45]:
torch.topk(prob, k=29, largest=False)

torch.return_types.topk(
values=tensor([-1.6300e+08, -1.6231e+08, -1.6151e+08, -1.6070e+08, -1.6070e+08,
        -1.6055e+08, -1.6034e+08, -1.4058e+08, -1.4057e+08, -1.3992e+08,
        -1.1355e+08, -1.1321e+08, -1.1306e+08, -5.6672e+07, -5.2002e+07,
        -3.1518e+07, -1.9844e+07, -1.8336e+07, -1.0434e+07, -1.0420e+07,
        -9.4871e+06, -9.4738e+06, -9.4321e+06, -7.0460e+06, -6.9968e+06,
        -6.9921e+06, -4.7260e+06, -4.7208e+06, -3.5430e+06]),
indices=tensor([15054, 15345,  4396, 15030, 15029, 15027,  4267,  1196,  1201,   836,
         1547,  1210,  1047,  3055, 15425, 24358, 25042,  7702, 24742, 15306,
          677, 12797,   641,  7000, 27998, 25504, 12377, 13706, 24110]))

In [10]:
pred = torch.zeros(len(val_x), dtype=torch.long)
pred[topk_indices] = 1

In [11]:
def get_pred(model, x, k) :
    prob = model.score_samples(x)
    prob = torch.tensor(prob, dtype=torch.float)
    topk_indices = torch.topk(prob, k=k, largest=False).indices
    
    pred = torch.zeros(len(x), dtype=torch.int8)
    pred[topk_indices] = 1
    
    return pred.tolist(), prob.tolist()

In [12]:
val_pred, val_prob = get_pred(model, val_x, 29)

In [13]:
val_score = f1_score(val_y, val_pred, average='macro')
val_score

0.9236496787663914

In [14]:
print(classification_report(val_y, val_pred))

              precision    recall  f1-score   support

           0       1.00      1.00      1.00     28432
           1       0.86      0.83      0.85        30

    accuracy                           1.00     28462
   macro avg       0.93      0.92      0.92     28462
weighted avg       1.00      1.00      1.00     28462



In [15]:
tn, fp, fn, tp = confusion_matrix(val_y, val_pred).ravel()
print('tp : ', tp, ', fp : ', fp, ', tn : ', tn, ', fn : ', fn)

tp :  25 , fp :  4 , tn :  28428 , fn :  5


In [16]:
confusion_matrix(val_y, val_pred)

array([[28428,     4],
       [    5,    25]])

In [17]:
25/30

0.8333333333333334

In [18]:
24/28

0.8571428571428571

In [19]:
answer = np.where(np.array(val_y) != np.array(val_pred))[0]

In [20]:
answer.shape

(9,)

In [21]:
answer

array([   71,  1047,  1210,  4039,  7000,  9326, 14221, 15306, 28146])

In [22]:
val.iloc[answer, :]

Unnamed: 0,ID,V1,V2,V3,V4,V5,V6,V7,V8,V9,...,V22,V23,V24,V25,V26,V27,V28,V29,V30,Class
71,624,-3.04,-3.16,1.09,2.29,1.36,-1.06,0.33,-0.07,-0.27,...,0.44,1.38,-0.29,0.28,-0.15,-0.25,0.04,7.08,-0.99,1
1047,10457,-1.69,4.08,-7.77,5.57,-3.7,-3.12,-7.85,1.75,-3.27,...,-0.54,-0.32,-0.02,1.18,-0.21,1.91,0.95,-0.29,-0.79,0
1210,12157,-5.91,6.04,-11.09,5.8,-6.19,-3.32,-9.27,4.34,-3.25,...,-0.43,0.21,-0.03,0.24,-0.29,2.02,0.63,0.95,-0.75,0
4039,40526,1.16,2.84,-4.05,4.78,2.95,-2.01,1.74,-0.41,-2.45,...,-0.43,-0.53,-0.6,1.34,0.55,0.01,0.16,-0.29,-0.52,1
7000,70037,-7.45,-3.48,-5.39,-0.06,-4.01,-0.78,0.95,0.53,-0.04,...,-1.36,-1.83,-0.34,-0.3,-0.46,1.19,-0.85,8.94,-0.36,0
9326,93789,1.08,0.96,-0.28,2.74,0.41,-0.32,0.04,0.18,-0.97,...,-0.06,-0.05,-0.03,0.4,0.07,0.03,0.06,-0.31,-0.24,1
14221,142558,-1.43,-0.8,1.12,0.39,-0.28,-0.06,1.33,0.2,-0.55,...,0.01,0.84,0.11,0.16,-0.62,-0.12,0.04,4.64,0.0,1
15306,153458,-3.14,2.01,-0.61,5.95,-1.87,0.76,-3.31,0.23,-0.94,...,0.8,0.51,-0.05,-0.51,0.44,-0.17,0.68,0.33,0.17,0
28146,281675,1.99,0.16,-2.58,0.41,1.15,-0.1,0.22,-0.07,0.58,...,-0.3,-0.07,-0.45,0.31,-0.29,0.0,-0.02,0.29,1.01,1


In [23]:
val.iloc[answer, :]['Class']

71       1
1047     0
1210     0
4039     1
7000     0
9326     1
14221    1
15306    0
28146    1
Name: Class, dtype: int64

In [24]:
# corr_df = val[val['Class'] == 1].drop(columns=['ID', 'Class']).transpose().corr().round(2)
# corr_df

In [25]:
# corr_df.style.background_gradient(cmap='coolwarm')

In [26]:
# tmp_df = pd.concat([val.iloc[[71,4039, 9326, 14221, 28146], :], val.iloc[:25, :]])

In [27]:
# tmp_df.drop(columns=['ID', 'Class']).transpose().corr().style.background_gradient(cmap='coolwarm')

In [28]:
tmp_df = pd.concat([val.iloc[answer, :], val[val['Class'] == 1]])
# tmp_df.drop_duplicates()

In [29]:
tmp_df.drop_duplicates().drop(columns=['Class', 'ID']).transpose().corr().style.background_gradient(cmap='coolwarm')

Unnamed: 0,71,1047,1210,4039,7000,9326,14221,15306,28146,641,677,836,1196,1201,1547,3055,4267,4396,7702,12377,12797,13706,15027,15029,15030,15054,15345,15425,24110,24358,24742,25042,25504,27998
71,1.0,0.088775,0.113378,0.156827,0.623358,0.049562,0.709394,0.143056,-0.039229,0.144502,0.132161,0.078777,0.142229,0.138155,0.083301,0.087708,0.09433,0.130778,0.370575,0.166524,0.227951,0.195633,0.136623,0.127766,0.127766,0.086529,0.092447,0.080694,0.285927,0.209679,0.174253,0.147183,0.109802,0.002241
1047,0.088775,1.0,0.975476,0.183556,0.40896,0.229687,0.0534,0.82494,0.151044,0.711428,0.698985,0.989796,0.811057,0.800891,0.721848,0.937042,0.84832,0.745004,0.854127,0.735322,0.840319,0.817054,0.912849,0.872392,0.872392,0.552177,0.653555,0.923234,0.671674,0.840365,0.684072,0.796931,0.734022,0.87668
1210,0.113378,0.975476,1.0,0.17119,0.491602,0.196206,0.076364,0.838108,0.127721,0.690682,0.700491,0.975756,0.897961,0.892145,0.85185,0.950257,0.924339,0.847275,0.878753,0.730012,0.863415,0.804069,0.970814,0.950709,0.950709,0.66714,0.739686,0.945065,0.679032,0.838266,0.784413,0.803283,0.711639,0.846763
4039,0.156827,0.183556,0.17119,1.0,-0.105652,0.845705,-0.048136,0.054926,0.587642,0.794204,0.768324,0.14061,0.302756,0.298319,0.114497,0.021988,-0.040007,0.017632,0.118556,0.694622,0.178166,0.519135,0.194797,0.193201,0.193201,0.129336,0.144226,0.029978,0.592903,0.505729,0.030573,0.673746,0.741824,0.457197
7000,0.623358,0.40896,0.491602,-0.105652,1.0,-0.206579,0.621784,0.383963,-0.045287,0.150647,0.156409,0.412278,0.482221,0.48184,0.537959,0.433848,0.542605,0.53996,0.68926,0.218306,0.547126,0.254047,0.520995,0.543183,0.543183,0.384871,0.362901,0.426361,0.265201,0.337363,0.62903,0.245739,0.148482,0.28677
9326,0.049562,0.229687,0.196206,0.845705,-0.206579,1.0,0.051132,0.179607,0.381411,0.723143,0.692095,0.190681,0.284763,0.27916,0.107978,0.085431,-0.007927,0.045887,0.109425,0.649773,0.135413,0.48365,0.183483,0.175746,0.175746,0.118841,0.159668,0.126512,0.461803,0.461615,0.088804,0.607578,0.70828,0.389607
14221,0.709394,0.0534,0.076364,-0.048136,0.621784,0.051132,1.0,0.094003,-0.238759,-0.007133,-0.071754,0.019369,0.02444,0.022399,0.057221,0.02802,0.052132,0.080119,0.384488,0.093135,0.180068,0.052962,0.042003,0.064284,0.064284,0.032735,0.02111,0.031423,-0.035505,-0.007702,0.314317,-0.061965,-0.036544,-0.084154
15306,0.143056,0.82494,0.838108,0.054926,0.383963,0.179607,0.094003,1.0,-0.148765,0.549473,0.576895,0.834714,0.766925,0.760182,0.671038,0.856933,0.783267,0.738601,0.707888,0.516887,0.672539,0.711481,0.819586,0.782456,0.782456,0.592417,0.69597,0.902403,0.538109,0.772087,0.742195,0.62894,0.555768,0.58689
28146,-0.039229,0.151044,0.127721,0.587642,-0.045287,0.381411,-0.238759,-0.148765,1.0,0.499643,0.529133,0.14145,0.186711,0.181995,0.051495,-0.02609,-0.021673,-0.0024,0.101343,0.444886,0.048894,0.261151,0.152753,0.151396,0.151396,0.071561,0.071159,0.018916,0.404669,0.322096,-0.068268,0.505935,0.587922,0.552554
641,0.144502,0.711428,0.690682,0.794204,0.150647,0.723143,-0.007133,0.549473,0.499643,1.0,0.971303,0.670815,0.693363,0.684402,0.495218,0.545132,0.455384,0.445009,0.559543,0.905204,0.587956,0.820914,0.666025,0.644727,0.644727,0.427819,0.493628,0.580534,0.794207,0.823547,0.419697,0.927906,0.941741,0.814083
