In [1]:
from pytorch_tabnet.tab_model import TabNetClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, confusion_matrix
from sklearn.preprocessing import StandardScaler
import pandas as pd
import json
import tracemalloc
import time
from pytorch_tabnet.callbacks import Callback


In [2]:
class TimeMemoryCallback(Callback):
    def __init__(self):
        self.epoch_times = []
        self.epoch_memory_usage = []

    def on_epoch_begin(self, epoch_idx, logs=None):
        self.epoch_start_time = time.time()
        tracemalloc.start()

    def on_epoch_end(self, epoch_idx, logs=None):
        end_time = time.time()
        current, peak = tracemalloc.get_traced_memory()
        tracemalloc.stop()

        self.epoch_times.append(end_time - self.epoch_start_time)
        self.epoch_memory_usage.append({
            "current_MB": current / 1024 / 1024,
            "peak_MB": peak / 1024 / 1024
        })


def main_tabnet(dataset, epochs=10000):
    target = "Gestational Diabetes"
    df = pd.read_csv(dataset)
    X = df.drop(columns=[target])
    y = df[target]

    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X)

    X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=42)
    model = TabNetClassifier()

    time_memory_callback = TimeMemoryCallback()
    model.fit(
        X_train=X_train, 
        y_train=y_train,
        eval_set=[(X_test, y_test)],
        eval_name=["val"],
        eval_metric=["accuracy"],
        max_epochs=epochs,
        patience=5,
        batch_size=1024,
        callbacks=[time_memory_callback]
    )
    
    y_pred = model.predict(X_test)
    print(f"Accuracy: {accuracy_score(y_test, y_pred) * 100:.2f}%")
    print(f"F1 Score: {f1_score(y_test, y_pred, average='macro') * 100:.2f}%")
    print(f"Precision: {precision_score(y_test, y_pred, average='macro') * 100:.2f}%")
    print(f"Recall: {recall_score(y_test, y_pred, average='macro') * 100:.2f}%")
    print(f"Confusion Matrix:\n{confusion_matrix(y_test, y_pred)}")

    # Convert confusion matrix to list of lists
    cm = confusion_matrix(y_test, y_pred)
    cm = [[int(x) for x in row] for row in cm]

    metrics = {
        "epochs": epochs,
        "confusion_matrix": cm,
        "epoch_times": time_memory_callback.epoch_times,
        "memory_usages": time_memory_callback.epoch_memory_usage,
        "accuracy": accuracy_score(y_test, y_pred),
        "f1_score": f1_score(y_test, y_pred, average="macro"),
        "precision": precision_score(y_test, y_pred, average="macro"),
        "recall": recall_score(y_test, y_pred, average="macro")
    }
    with open(f"./results/tabnet_metrics_{dataset.split("/")[-1].strip(".csv")}.json", "w") as f:
        json.dump(metrics, f)
    

In [3]:
main_tabnet("./processed_datasets/dataset_preprocessed_smote.csv")




epoch 0  | loss: 0.68701 | val_accuracy: 0.60385 |  0:00:02s
epoch 1  | loss: 0.65865 | val_accuracy: 0.61718 |  0:00:04s
epoch 2  | loss: 0.65101 | val_accuracy: 0.62564 |  0:00:06s
epoch 3  | loss: 0.64817 | val_accuracy: 0.62726 |  0:00:08s
epoch 4  | loss: 0.64821 | val_accuracy: 0.62286 |  0:00:10s
epoch 5  | loss: 0.64184 | val_accuracy: 0.62772 |  0:00:12s
epoch 6  | loss: 0.63955 | val_accuracy: 0.63387 |  0:00:14s
epoch 7  | loss: 0.63725 | val_accuracy: 0.63584 |  0:00:17s
epoch 8  | loss: 0.63687 | val_accuracy: 0.63317 |  0:00:19s
epoch 9  | loss: 0.6322  | val_accuracy: 0.6363  |  0:00:21s
epoch 10 | loss: 0.63101 | val_accuracy: 0.63758 |  0:00:23s
epoch 11 | loss: 0.62579 | val_accuracy: 0.6465  |  0:00:25s
epoch 12 | loss: 0.61997 | val_accuracy: 0.65299 |  0:00:27s
epoch 13 | loss: 0.61726 | val_accuracy: 0.65148 |  0:00:29s
epoch 14 | loss: 0.61449 | val_accuracy: 0.65705 |  0:00:31s
epoch 15 | loss: 0.61542 | val_accuracy: 0.65844 |  0:00:33s
epoch 16 | loss: 0.6141 



Accuracy: 67.77%
F1 Score: 67.47%
Precision: 68.55%
Recall: 67.84%
Confusion Matrix:
[[2514 1830]
 [ 951 3333]]
