In [38]:
from pytorch_tabnet.tab_model import TabNetClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, confusion_matrix
from sklearn.preprocessing import StandardScaler
import pandas as pd
import json
import tracemalloc
import time
from pytorch_tabnet.callbacks import Callback


In [39]:
class TimeMemoryCallback(Callback):
    def __init__(self):
        self.epoch_times = []
        self.epoch_memory_usage = []

    def on_epoch_begin(self, epoch_idx, logs=None):
        self.epoch_start_time = time.time()
        tracemalloc.start()

    def on_epoch_end(self, epoch_idx, logs=None):
        end_time = time.time()
        current, peak = tracemalloc.get_traced_memory()
        tracemalloc.stop()

        self.epoch_times.append(end_time - self.epoch_start_time)
        self.epoch_memory_usage.append({
            "current_MB": current / 1024 / 1024,
            "peak_MB": peak / 1024 / 1024
        })


def main_tabnet(dataset, epochs=10000):

    df = pd.read_csv(dataset)
    X = df.drop(columns=["Diabetes_012"])
    y = df["Diabetes_012"]

    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X)

    X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=42)
    model = TabNetClassifier()

    time_memory_callback = TimeMemoryCallback()
    model.fit(
        X_train=X_train, 
        y_train=y_train,
        eval_set=[(X_test, y_test)],
        eval_name=["val"],
        eval_metric=["accuracy"],
        max_epochs=epochs,
        patience=5,
        batch_size=1024,
        callbacks=[time_memory_callback]
    )
    
    y_pred = model.predict(X_test)
    print(f"Accuracy: {accuracy_score(y_test, y_pred) * 100:.2f}%")
    print(f"F1 Score: {f1_score(y_test, y_pred, average='macro') * 100:.2f}%")
    print(f"Precision: {precision_score(y_test, y_pred, average='macro') * 100:.2f}%")
    print(f"Recall: {recall_score(y_test, y_pred, average='macro') * 100:.2f}%")
    print(f"Confusion Matrix:\n{confusion_matrix(y_test, y_pred)}")

    # Convert confusion matrix to list of lists
    cm = confusion_matrix(y_test, y_pred)
    cm = [[int(x) for x in row] for row in cm]

    metrics = {
        "epochs": epochs,
        "confusion_matrix": cm,
        "epoch_times": time_memory_callback.epoch_times,
        "memory_usages": time_memory_callback.epoch_memory_usage,
        "accuracy": accuracy_score(y_test, y_pred),
        "f1_score": f1_score(y_test, y_pred, average="macro"),
        "precision": precision_score(y_test, y_pred, average="macro"),
        "recall": recall_score(y_test, y_pred, average="macro")
    }
    with open(f"./results/tabnet_metrics_{dataset.split("/")[-1].strip(".csv")}.json", "w") as f:
        json.dump(metrics, f)
    

In [40]:
main_tabnet("./processed_datasets/CL/dataset_dataset_without_da_pca_rfe_CL.csv")




epoch 0  | loss: 0.27954 | val_accuracy: 0.92133 |  0:00:31s
epoch 1  | loss: 0.24581 | val_accuracy: 0.92205 |  0:01:02s
epoch 2  | loss: 0.24702 | val_accuracy: 0.91955 |  0:01:33s
epoch 3  | loss: 0.24463 | val_accuracy: 0.91988 |  0:02:04s
epoch 4  | loss: 0.24397 | val_accuracy: 0.91986 |  0:02:35s
epoch 5  | loss: 0.24915 | val_accuracy: 0.92099 |  0:03:05s
epoch 6  | loss: 0.24525 | val_accuracy: 0.9204  |  0:03:37s

Early stopping occurred at epoch 6 with best_epoch = 1 and best_val_accuracy = 0.92205




Accuracy: 92.21%
F1 Score: 92.35%
Precision: 93.99%
Recall: 91.99%
Confusion Matrix:
[[36852   105     0]
 [ 1124 30476    40]
 [ 6610    58 26560]]
