In [1]:
import os
import sys
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.model_selection import StratifiedKFold, GridSearchCV
from sklearn.metrics import classification_report, make_scorer, recall_score, f1_score
from sklearn.preprocessing import OneHotEncoder
from pytorch_tabnet.tab_model import TabNetClassifier
from sklearn.utils import class_weight

In [2]:
t = TabNetClassifier()



In [3]:
t.get_params()

{'cat_dims': [],
 'cat_emb_dim': [],
 'cat_idxs': [],
 'clip_value': 1,
 'device_name': 'auto',
 'epsilon': 1e-15,
 'gamma': 1.3,
 'grouped_features': [],
 'input_dim': None,
 'lambda_sparse': 0.001,
 'mask_type': 'sparsemax',
 'momentum': 0.02,
 'n_a': 8,
 'n_d': 8,
 'n_indep_decoder': 1,
 'n_independent': 2,
 'n_shared': 2,
 'n_shared_decoder': 1,
 'n_steps': 3,
 'optimizer_fn': torch.optim.adam.Adam,
 'optimizer_params': {'lr': 0.02},
 'output_dim': None,
 'scheduler_fn': None,
 'scheduler_params': {},
 'seed': 0,
 'verbose': 1}

In [21]:
data = np.load("../../stroke_data/features")
pvals = np.load('../../stroke_data/features_p_values')
targets = np.array(np.load("../../stroke_data/targets", allow_pickle=True)) - 1 # (1, 2) -> (0, 1)
data = data[:, np.argsort(pvals)[:30000]]

In [22]:
res = {}
skf = StratifiedKFold(n_splits=5)

for i, (train_index, test_index) in enumerate(skf.split(data, targets)):
    print(f"training on fold {i}")
    X_train, y_train = data[train_index, :], targets[train_index]
    X_test, y_test = data[test_index, :], targets[test_index]

    unique_classes = np.unique(y_train)
    weights = class_weight.compute_class_weight('balanced', classes=unique_classes, y=y_train)
    class_weights = dict(zip(unique_classes, weights))

    cat_count = X_train.shape[1]
    tb = TabNetClassifier(cat_idxs=list(range(cat_count)), cat_dims = [3]*cat_count, cat_emb_dim = [4]*cat_count)
    tb.fit(X_train, y_train, batch_size=256, max_epochs=50, weights=class_weights) # weights = 1
    y_pred = tb.predict(X_test)
    res[f"fold {i}"] = classification_report(y_test, y_pred, output_dict=True)

training on fold 0




epoch 0  | loss: 0.74469 |  0:00:33s
epoch 1  | loss: 0.69518 |  0:01:04s
epoch 2  | loss: 0.6787  |  0:01:36s
epoch 3  | loss: 0.63119 |  0:02:09s
epoch 4  | loss: 0.53234 |  0:02:40s
epoch 5  | loss: 0.38425 |  0:03:10s
epoch 6  | loss: 0.24496 |  0:03:42s
epoch 7  | loss: 0.14068 |  0:04:13s
epoch 8  | loss: 0.10659 |  0:04:43s
epoch 9  | loss: 0.07503 |  0:05:14s
epoch 10 | loss: 0.05997 |  0:05:44s
epoch 11 | loss: 0.06684 |  0:06:16s
epoch 12 | loss: 0.04494 |  0:06:46s
epoch 13 | loss: 0.05505 |  0:07:18s
epoch 14 | loss: 0.03892 |  0:07:50s
epoch 15 | loss: 0.03766 |  0:08:20s
epoch 16 | loss: 0.0249  |  0:08:50s
epoch 17 | loss: 0.03377 |  0:10:09s
epoch 18 | loss: 0.04402 |  0:11:05s
epoch 19 | loss: 0.02688 |  0:11:36s
epoch 20 | loss: 0.03035 |  0:12:08s
epoch 21 | loss: 0.05566 |  0:12:38s
epoch 22 | loss: 0.04064 |  0:13:06s
epoch 23 | loss: 0.02718 |  0:13:34s
epoch 24 | loss: 0.02172 |  0:14:03s
epoch 25 | loss: 0.03321 |  0:14:31s
epoch 26 | loss: 0.02824 |  0:15:00s
e



epoch 0  | loss: 0.73745 |  0:00:28s
epoch 1  | loss: 0.69068 |  0:00:56s
epoch 2  | loss: 0.67269 |  0:01:24s
epoch 3  | loss: 0.63371 |  0:01:53s
epoch 4  | loss: 0.57954 |  0:02:21s
epoch 5  | loss: 0.47059 |  0:02:49s
epoch 6  | loss: 0.35685 |  0:03:17s
epoch 7  | loss: 0.26809 |  0:03:47s
epoch 8  | loss: 0.16937 |  0:04:16s
epoch 9  | loss: 0.11413 |  0:04:44s
epoch 10 | loss: 0.07654 |  0:05:12s
epoch 11 | loss: 0.05712 |  0:05:40s
epoch 12 | loss: 0.03965 |  0:06:08s
epoch 13 | loss: 0.03605 |  0:06:36s
epoch 14 | loss: 0.04543 |  0:07:05s
epoch 15 | loss: 0.04054 |  0:07:34s
epoch 16 | loss: 0.0333  |  0:08:02s
epoch 17 | loss: 0.0365  |  0:08:30s
epoch 18 | loss: 0.02571 |  0:08:58s
epoch 19 | loss: 0.01554 |  0:09:26s
epoch 20 | loss: 0.01503 |  0:09:54s
epoch 21 | loss: 0.02046 |  0:10:22s
epoch 22 | loss: 0.02128 |  0:10:50s
epoch 23 | loss: 0.02843 |  0:11:18s
epoch 24 | loss: 0.03401 |  0:11:46s
epoch 25 | loss: 0.0233  |  0:12:14s
epoch 26 | loss: 0.02117 |  0:12:42s
e



epoch 0  | loss: 0.73792 |  0:00:28s
epoch 1  | loss: 0.68778 |  0:00:56s
epoch 2  | loss: 0.67899 |  0:01:23s
epoch 3  | loss: 0.64966 |  0:01:51s
epoch 4  | loss: 0.57742 |  0:02:20s
epoch 5  | loss: 0.503   |  0:02:48s
epoch 6  | loss: 0.3633  |  0:03:16s
epoch 7  | loss: 0.24714 |  0:03:44s
epoch 8  | loss: 0.18893 |  0:04:12s
epoch 9  | loss: 0.14759 |  0:04:40s
epoch 10 | loss: 0.08355 |  0:05:07s
epoch 11 | loss: 0.07511 |  0:05:35s
epoch 12 | loss: 0.07    |  0:06:03s
epoch 13 | loss: 0.05424 |  0:06:31s
epoch 14 | loss: 0.03535 |  0:06:59s
epoch 15 | loss: 0.0386  |  0:07:27s
epoch 16 | loss: 0.03992 |  0:07:55s
epoch 17 | loss: 0.04342 |  0:08:23s
epoch 18 | loss: 0.03329 |  0:08:50s
epoch 19 | loss: 0.04105 |  0:09:18s
epoch 20 | loss: 0.02377 |  0:09:46s
epoch 21 | loss: 0.02201 |  0:10:14s
epoch 22 | loss: 0.02125 |  0:10:42s
epoch 23 | loss: 0.01793 |  0:11:09s
epoch 24 | loss: 0.01024 |  0:11:37s
epoch 25 | loss: 0.01367 |  0:12:05s
epoch 26 | loss: 0.01326 |  0:12:33s
e



epoch 0  | loss: 0.74542 |  0:00:28s
epoch 1  | loss: 0.68902 |  0:00:56s
epoch 2  | loss: 0.67199 |  0:01:24s
epoch 3  | loss: 0.62883 |  0:01:52s
epoch 4  | loss: 0.53365 |  0:02:20s
epoch 5  | loss: 0.36975 |  0:02:48s
epoch 6  | loss: 0.25928 |  0:03:16s
epoch 7  | loss: 0.16626 |  0:03:44s
epoch 8  | loss: 0.11491 |  0:04:12s
epoch 9  | loss: 0.07215 |  0:04:40s
epoch 10 | loss: 0.0492  |  0:05:09s
epoch 11 | loss: 0.04091 |  0:05:37s
epoch 12 | loss: 0.02559 |  0:06:05s
epoch 13 | loss: 0.02401 |  0:06:33s
epoch 14 | loss: 0.02829 |  0:07:02s
epoch 15 | loss: 0.02889 |  0:07:30s
epoch 16 | loss: 0.02411 |  0:07:59s
epoch 17 | loss: 0.01863 |  0:08:27s
epoch 18 | loss: 0.02599 |  0:08:55s
epoch 19 | loss: 0.02987 |  0:09:24s
epoch 20 | loss: 0.02249 |  0:09:52s
epoch 21 | loss: 0.02271 |  0:10:20s
epoch 22 | loss: 0.02539 |  0:10:48s
epoch 23 | loss: 0.02076 |  0:11:16s
epoch 24 | loss: 0.01803 |  0:11:44s
epoch 25 | loss: 0.01645 |  0:12:13s
epoch 26 | loss: 0.02133 |  0:12:41s
e



epoch 0  | loss: 0.7382  |  0:00:28s
epoch 1  | loss: 0.68766 |  0:00:56s
epoch 2  | loss: 0.67431 |  0:01:24s
epoch 3  | loss: 0.64444 |  0:01:52s
epoch 4  | loss: 0.58384 |  0:02:20s
epoch 5  | loss: 0.53231 |  0:02:48s
epoch 6  | loss: 0.39749 |  0:03:16s
epoch 7  | loss: 0.25287 |  0:03:44s
epoch 8  | loss: 0.17177 |  0:04:12s
epoch 9  | loss: 0.08906 |  0:04:40s
epoch 10 | loss: 0.07127 |  0:05:08s
epoch 11 | loss: 0.04981 |  0:05:36s
epoch 12 | loss: 0.03561 |  0:06:04s
epoch 13 | loss: 0.04294 |  0:06:32s
epoch 14 | loss: 0.02536 |  0:07:01s
epoch 15 | loss: 0.03045 |  0:07:29s
epoch 16 | loss: 0.02834 |  0:07:58s
epoch 17 | loss: 0.03176 |  0:08:26s
epoch 18 | loss: 0.04467 |  0:08:54s
epoch 19 | loss: 0.03857 |  0:09:23s
epoch 20 | loss: 0.0336  |  0:09:51s
epoch 21 | loss: 0.05582 |  0:10:19s
epoch 22 | loss: 0.03037 |  0:10:47s
epoch 23 | loss: 0.03312 |  0:11:15s
epoch 24 | loss: 0.03543 |  0:11:43s
epoch 25 | loss: 0.03257 |  0:12:11s
epoch 26 | loss: 0.01949 |  0:12:39s
e

In [23]:
# 0 - controls
# 1 - cases
aggregated = {"0": {"precision": [], "recall": [], "f1-score": [], "accuracy": []},
              "1": {"precision": [], "recall": [], "f1-score": [], "accuracy": []}}
for _, fold in res.items():
    for key in ["0", "1"]:
        aggregated[key]["precision"].append(fold[key]["precision"])
        aggregated[key]["recall"].append(fold[key]["recall"])
        aggregated[key]["f1-score"].append(fold[key]["f1-score"])
        aggregated[key]["accuracy"].append(fold["accuracy"])

df = pd.DataFrame(aggregated["0"])
mean_row = pd.DataFrame([df.mean()], index=['mean'])
std_row = pd.DataFrame([df.std()], index=['std'])
df = pd.concat([df, mean_row, std_row])
df

Unnamed: 0,precision,recall,f1-score,accuracy
0,1.0,0.538462,0.7,0.946092
1,0.917647,0.604651,0.728972,0.947842
2,0.923077,0.369231,0.527473,0.922662
3,0.183333,0.084615,0.115789,0.848921
4,0.588235,0.076923,0.136054,0.885791
mean,0.722459,0.334776,0.441658,0.910261
std,0.340605,0.247273,0.298428,0.04246


In [24]:
df = pd.DataFrame(aggregated["1"])
mean_row = pd.DataFrame([df.mean()], index=['mean'])
std_row = pd.DataFrame([df.std()], index=['std'])
df = pd.concat([df, mean_row, std_row])
df

Unnamed: 0,precision,recall,f1-score,accuracy
0,0.942474,1.0,0.970385,0.946092
1,0.950341,0.992879,0.971144,0.947842
2,0.922642,0.995927,0.957884,0.922662
3,0.886882,0.950102,0.917404,0.848921
4,0.890411,0.992872,0.938854,0.885791
mean,0.91855,0.986356,0.951134,0.910261
std,0.029131,0.020476,0.022941,0.04246
