In [None]:
import pandas as pd
from pytorch_tabnet.pretraining import TabNetPretrainer
from pytorch_tabnet.tab_model import TabNetClassifier
import torch
from sklearn.metrics import precision_score, recall_score, f1_score

In [None]:
# Raw inputs for this experiment (paths relative to the notebook's working dir).
# NOTE(review): the validation file name says "resampled" — if its class balance
# was altered, early-stopping AUC below is measured on a non-natural distribution.
TRAIN_CSV = 'train_data_MEWS.csv'
VAL_CSV = 'val_resampled_data_MEWS.csv'

train_data, val_data = (pd.read_csv(csv_path) for csv_path in (TRAIN_CSV, VAL_CSV))

In [None]:
# Split each frame into its feature matrix and the 'diagnosis' target column.
X_train, y_train = train_data.drop(columns=['diagnosis']), train_data['diagnosis']
X_val, y_val = val_data.drop(columns=['diagnosis']), val_data['diagnosis']

In [None]:
# Self-supervised pretraining: TabNet hides a random subset of input features and
# learns to reconstruct them from the visible ones. The resulting encoder weights
# are reused by the classifier below via `from_unsupervised`.
#
# `pretraining_ratio` is the fraction of features masked per sample. It used to be
# 1.0, which masks *every* feature: the encoder then only ever sees an all-zero
# input and cannot learn any relationship between features, so pretraining
# contributed nothing. 0.5 is the library's default for `TabNetPretrainer.fit`.
PRETRAINING_RATIO = 0.5

unsupervised_model = TabNetPretrainer(
    optimizer_fn=torch.optim.Adam,
    optimizer_params=dict(lr=0.05),
    mask_type='sparsemax',
)

unsupervised_model.fit(
    X_train=X_train.values,
    # The validation features are only evaluated (reconstruction loss), not trained on.
    eval_set=[X_val.values],
    pretraining_ratio=PRETRAINING_RATIO,
)

In [None]:
# Supervised classifier that will be fine-tuned from the pretrained encoder.
# `mask_type` matches the pretrainer configured above.
#
# NOTE(review): with gamma=1.0, MultiStepLR multiplies the learning rate by 1.0 at
# each milestone, so the LR stays at 0.05 for the entire run and the milestones
# have no effect. Use a gamma < 1 (e.g. 0.1) if decay was actually intended.
clf = TabNetClassifier(
    mask_type='sparsemax',
    optimizer_fn=torch.optim.AdamW,
    optimizer_params={'lr': 0.05},
    scheduler_fn=torch.optim.lr_scheduler.MultiStepLR,
    scheduler_params={'milestones': [50, 60, 70], 'gamma': 1.0},
)

# Fine-tune starting from the pretrained encoder weights. Validation AUC drives
# early stopping: training halts after `patience` epochs without improvement,
# and never runs longer than `max_epochs`.
clf.fit(
    X_train=X_train.values,
    y_train=y_train.values,
    from_unsupervised=unsupervised_model,
    eval_set=[(X_val.values, y_val.values)],
    eval_metric=['auc'],
    patience=50,
    max_epochs=200,
)

In [None]:
# Persist per-feature importances next to their column names, most important first.
# The previous f-string dump of the raw ndarray dropped the feature names. It also
# let NumPy abbreviate long arrays with "..." once they exceed its print threshold.
# The classifier was fit on X_train.values, so importances follow X_train's column order.
feature_importances = (
    pd.Series(clf.feature_importances_, index=X_train.columns, name='importance')
    .sort_values(ascending=False)
)
with open('fimp.txt', 'w', encoding='utf-8') as f:
    f.write(feature_importances.to_string())

# `save_model` returns the path of the archive it wrote (the given path plus ".zip").
saving_path_name = "./tabnet_model_test_1"
saved_filepath = clf.save_model(saving_path_name)

In [None]:
# Reload the saved model from its on-disk archive. The path is spelled out rather
# than taken from `saved_filepath`, so this cell also works in a fresh session
# without retraining.
MODEL_ARCHIVE = "./tabnet_model_test_1.zip"
saving_path_name = MODEL_ARCHIVE

loaded_clf = TabNetClassifier()
loaded_clf.load_model(MODEL_ARCHIVE)

In [None]:
# Held-out test split, separated into features and the 'diagnosis' target.
# NOTE(review): the file name says "resampled". If the test set's class balance was
# altered (over/under-sampling), the metrics below will not reflect the real-world
# prevalence. Consider also evaluating on the untouched test split.
test_data = pd.read_csv('test_resampled_data_MEWS.csv')
X_test, y_test = test_data.drop(columns=['diagnosis']), test_data['diagnosis']

In [None]:
# Score the reloaded model on the test split. precision/recall/F1 use sklearn's
# default average='binary', so the labels are assumed to be binary with 1 as the
# positive class. TODO confirm against the dataset.
y_true = y_test.values
result = loaded_clf.predict(X_test.values)

accuracy = (result == y_true).mean()
precision, recall, f1 = (
    score_fn(y_true, result) for score_fn in (precision_score, recall_score, f1_score)
)

print(f'Accuracy: {accuracy}\n')
print(f'Precision: {precision}\n')
print(f'Recall: {recall}\n')
print(f'F1-score: {f1}\n')