In [1]:
from pytorch_tabnet.tab_model import TabNetClassifier
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import pandas as pd
import numpy as np
import torch

In [2]:
df=pd.read_csv('cleaned.csv',index_col=0)
train_copy=df.drop(columns=['bmi','avg_glucose_level','age'])
X = train_copy.drop('smoking_status', axis=1)
y = train_copy['smoking_status']


In [3]:
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

X_train, X_test, y_train, y_test = train_test_split(
    X_scaled, y, test_size=0.2, random_state=42
)


X_train_np = X_train.astype(np.float32)
X_test_np = X_test.astype(np.float32)
y_train_np = y_train.values.astype(int)
y_test_np = y_test.values.astype(int)


tabnet_model = TabNetClassifier(
    optimizer_fn=torch.optim.Adam,
    optimizer_params=dict(lr=0.001),
    scheduler_params={"step_size":10, "gamma":0.9},
    scheduler_fn=torch.optim.lr_scheduler.StepLR,
    verbose=10,
)



In [4]:
tabnet_model.fit(
    X_train=X_train_np, y_train=y_train_np,
    eval_set=[(X_test_np, y_test_np)],
    eval_metric=['accuracy'],
    max_epochs=50,
    patience=10,
    batch_size=90,
    virtual_batch_size=128,
)


y_pred = tabnet_model.predict(X_test_np)
acc = accuracy_score(y_test_np, y_pred)
print(f"\n TabNet Test Accuracy: {acc:.4f}")
print(classification_report(y_test_np, y_pred))

epoch 0  | loss: 0.7569  | val_0_accuracy: 0.54227 |  0:00:00s
epoch 10 | loss: 0.68843 | val_0_accuracy: 0.5656  |  0:00:01s
epoch 20 | loss: 0.67919 | val_0_accuracy: 0.5656  |  0:00:02s

Early stopping occurred at epoch 25 with best_epoch = 15 and best_val_0_accuracy = 0.56997

 TabNet Test Accuracy: 0.5700
              precision    recall  f1-score   support

           0       0.58      0.71      0.64       371
           1       0.54      0.40      0.46       315

    accuracy                           0.57       686
   macro avg       0.56      0.56      0.55       686
weighted avg       0.57      0.57      0.56       686



