In [1]:
import optuna
import pandas as pd
import numpy as np
from pytorch_tabnet.tab_model import TabNetClassifier
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, classification_report
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import numpy as np
from sklearn.preprocessing import StandardScaler
import shap 

In [2]:
data = pd.read_csv('encoded_training.csv')
data

Unnamed: 0,ID,PATID,UCX_abnormal,ua_bacteria,ua_bili,ua_blood,ua_clarity,ua_color,ua_epi,ua_glucose,...,ua_spec_grav,ua_urobili,ua_wbc,age,abxUTI,ethnicity_Hispanic or Latino,ethnicity_Non-Hispanic,ethnicity_Patient Refused,ethnicity_Unknown,Female
0,25960,17080,0,0,0,0,1,1,0,0,...,1.020,0,1,69,0,False,True,False,False,False
1,25961,17080,1,1,0,2,0,1,1,0,...,1.014,0,1,73,1,False,True,False,False,False
2,25964,17080,0,1,0,3,1,1,1,0,...,1.022,1,1,24,1,False,True,False,False,True
3,25969,17084,0,1,0,0,1,1,2,0,...,1.010,1,1,42,0,False,True,False,False,True
4,25971,17086,0,0,0,1,1,1,1,0,...,1.023,0,2,86,1,False,True,False,False,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
21358,80367,55350,1,0,0,0,0,1,0,0,...,1.012,0,2,61,1,False,True,False,False,False
21359,80373,55355,0,0,0,1,0,1,1,0,...,1.017,0,2,69,0,False,True,False,False,False
21360,80374,55356,1,1,0,3,1,1,1,0,...,1.014,0,3,81,1,False,True,False,False,False
21361,80379,55360,1,4,0,1,1,1,0,0,...,1.019,0,3,64,1,True,False,False,False,False


In [3]:
# Define features and target
X = data.drop(columns=['UCX_abnormal', 'ID', 'PATID'])  # Dropping ID columns and target
y = data['UCX_abnormal']

In [4]:
# Optional: Standardize features if necessary
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

In [5]:
# Encode the labels as TabNet requires numerical labels
y_encoded = LabelEncoder().fit_transform(y)

In [6]:
# Define the Optuna objective function for TabNet
def objective(trial):
    # Suggest hyperparameters for the TabNet model
    n_d = trial.suggest_int('n_d', 8, 64)  # Dimensionality of the decision layer
    n_a = trial.suggest_int('n_a', 8, 64)  # Dimensionality of the attention layer
    n_steps = trial.suggest_int('n_steps', 3, 10)  # Number of steps in the architecture
    gamma = trial.suggest_float('gamma', 1.0, 2.0)  # Relaxation parameter
    lambda_sparse = trial.suggest_float('lambda_sparse', 1e-5, 1e-3, log=True)  # Sparse regularization strength
    n_independent = trial.suggest_int('n_independent', 1, 5)  # Number of independent Gated Linear Units
    n_shared = trial.suggest_int('n_shared', 1, 5)  # Number of shared Gated Linear Units

    # Define the TabNet model
    model = TabNetClassifier(
        n_d=n_d,
        n_a=n_a,
        n_steps=n_steps,
        gamma=gamma,
        lambda_sparse=lambda_sparse,
        n_independent=n_independent,
        n_shared=n_shared,
        verbose=0
    )
    
    # Split the data into training and validation sets
    X_train, X_val, y_train, y_val = train_test_split(X_scaled, y_encoded, test_size=0.2, random_state=42)
    
    # Train the model
    model.fit(
        X_train, y_train,
        eval_set=[(X_val, y_val)],
        eval_metric=['accuracy'],
        max_epochs=100,
        patience=10,  # Early stopping
        batch_size=1024,
        virtual_batch_size=128
    )
    
    # Make predictions on the validation set
    y_pred_val = model.predict(X_val)
    
    # Calculate validation accuracy
    accuracy = accuracy_score(y_val, y_pred_val)
    
    # Return the negative accuracy (Optuna minimizes)
    return -accuracy

In [7]:

# Create a study object for TabNet
study = optuna.create_study(direction='minimize')  # Minimize the negative accuracy

[I 2024-09-23 00:13:50,448] A new study created in memory with name: no-name-5f0310d1-38bb-46bd-bc6d-c3883e407e7d


In [8]:
# Optimize the study
study.optimize(objective, n_trials=30)  # Run the optimization for 30 trials



Early stopping occurred at epoch 36 with best_epoch = 26 and best_val_0_accuracy = 0.84086


[I 2024-09-23 00:16:17,685] Trial 0 finished with value: -0.8408612216241517 and parameters: {'n_d': 44, 'n_a': 39, 'n_steps': 9, 'gamma': 1.4696909236719968, 'lambda_sparse': 2.2586211145484574e-05, 'n_independent': 1, 'n_shared': 4}. Best is trial 0 with value: -0.8408612216241517.



Early stopping occurred at epoch 22 with best_epoch = 12 and best_val_0_accuracy = 0.84414


[I 2024-09-23 00:17:55,100] Trial 1 finished with value: -0.8441376082377721 and parameters: {'n_d': 13, 'n_a': 56, 'n_steps': 7, 'gamma': 1.0082362766412425, 'lambda_sparse': 0.00016413708141444004, 'n_independent': 5, 'n_shared': 4}. Best is trial 1 with value: -0.8441376082377721.



Early stopping occurred at epoch 18 with best_epoch = 8 and best_val_0_accuracy = 0.82331


[I 2024-09-23 00:18:46,376] Trial 2 finished with value: -0.8233091504797566 and parameters: {'n_d': 30, 'n_a': 13, 'n_steps': 7, 'gamma': 1.8522304733591266, 'lambda_sparse': 9.131130204736656e-05, 'n_independent': 4, 'n_shared': 2}. Best is trial 1 with value: -0.8441376082377721.



Early stopping occurred at epoch 38 with best_epoch = 28 and best_val_0_accuracy = 0.84414


[I 2024-09-23 00:21:09,200] Trial 3 finished with value: -0.8441376082377721 and parameters: {'n_d': 57, 'n_a': 35, 'n_steps': 5, 'gamma': 1.6146905504006712, 'lambda_sparse': 1.2030710635907865e-05, 'n_independent': 5, 'n_shared': 5}. Best is trial 1 with value: -0.8441376082377721.



Early stopping occurred at epoch 41 with best_epoch = 31 and best_val_0_accuracy = 0.83524


[I 2024-09-23 00:23:20,029] Trial 4 finished with value: -0.8352445588579452 and parameters: {'n_d': 29, 'n_a': 18, 'n_steps': 9, 'gamma': 1.4157509315878591, 'lambda_sparse': 2.762499037904247e-05, 'n_independent': 4, 'n_shared': 2}. Best is trial 1 with value: -0.8441376082377721.


Stop training because you reached max_epochs = 100 with best_epoch = 93 and best_val_0_accuracy = 0.85139


[I 2024-09-23 00:33:17,310] Trial 5 finished with value: -0.8513924643107886 and parameters: {'n_d': 44, 'n_a': 37, 'n_steps': 10, 'gamma': 1.6832373496991175, 'lambda_sparse': 2.435101716333242e-05, 'n_independent': 5, 'n_shared': 5}. Best is trial 5 with value: -0.8513924643107886.



Early stopping occurred at epoch 54 with best_epoch = 44 and best_val_0_accuracy = 0.85046


[I 2024-09-23 00:34:43,492] Trial 6 finished with value: -0.8504563538497543 and parameters: {'n_d': 23, 'n_a': 31, 'n_steps': 6, 'gamma': 1.0669123100577829, 'lambda_sparse': 0.0007996159777000336, 'n_independent': 3, 'n_shared': 1}. Best is trial 5 with value: -0.8513924643107886.



Early stopping occurred at epoch 46 with best_epoch = 36 and best_val_0_accuracy = 0.85139


[I 2024-09-23 00:35:45,868] Trial 7 finished with value: -0.8513924643107886 and parameters: {'n_d': 16, 'n_a': 9, 'n_steps': 5, 'gamma': 1.0285781229376538, 'lambda_sparse': 5.190066396978215e-05, 'n_independent': 3, 'n_shared': 1}. Best is trial 5 with value: -0.8513924643107886.



Early stopping occurred at epoch 23 with best_epoch = 13 and best_val_0_accuracy = 0.82752


[I 2024-09-23 00:36:38,314] Trial 8 finished with value: -0.8275216475544114 and parameters: {'n_d': 24, 'n_a': 54, 'n_steps': 8, 'gamma': 1.523428124053581, 'lambda_sparse': 0.0006095981913500478, 'n_independent': 1, 'n_shared': 3}. Best is trial 5 with value: -0.8513924643107886.



Early stopping occurred at epoch 25 with best_epoch = 15 and best_val_0_accuracy = 0.84554


[I 2024-09-23 00:37:49,531] Trial 9 finished with value: -0.8455417739293236 and parameters: {'n_d': 19, 'n_a': 55, 'n_steps': 4, 'gamma': 1.8489971619315795, 'lambda_sparse': 0.00044231427026272975, 'n_independent': 5, 'n_shared': 5}. Best is trial 5 with value: -0.8513924643107886.



Early stopping occurred at epoch 71 with best_epoch = 61 and best_val_0_accuracy = 0.85139


[I 2024-09-23 00:42:42,206] Trial 10 finished with value: -0.8513924643107886 and parameters: {'n_d': 45, 'n_a': 46, 'n_steps': 10, 'gamma': 1.6829933100425087, 'lambda_sparse': 1.074012401926491e-05, 'n_independent': 2, 'n_shared': 4}. Best is trial 5 with value: -0.8513924643107886.



Early stopping occurred at epoch 24 with best_epoch = 14 and best_val_0_accuracy = 0.84929


[I 2024-09-23 00:43:05,745] Trial 11 finished with value: -0.8492862157734613 and parameters: {'n_d': 41, 'n_a': 24, 'n_steps': 3, 'gamma': 1.2638698362058818, 'lambda_sparse': 6.094060486165478e-05, 'n_independent': 3, 'n_shared': 1}. Best is trial 5 with value: -0.8513924643107886.



Early stopping occurred at epoch 17 with best_epoch = 7 and best_val_0_accuracy = 0.82448


[I 2024-09-23 00:43:31,856] Trial 12 finished with value: -0.8244792885560496 and parameters: {'n_d': 8, 'n_a': 24, 'n_steps': 5, 'gamma': 1.9919423895700388, 'lambda_sparse': 4.2757466408874844e-05, 'n_independent': 4, 'n_shared': 2}. Best is trial 5 with value: -0.8513924643107886.



Early stopping occurred at epoch 36 with best_epoch = 26 and best_val_0_accuracy = 0.84461


[I 2024-09-23 00:44:43,225] Trial 13 finished with value: -0.8446056634682892 and parameters: {'n_d': 56, 'n_a': 10, 'n_steps': 6, 'gamma': 1.2586726662543135, 'lambda_sparse': 0.0001482821364848797, 'n_independent': 2, 'n_shared': 3}. Best is trial 5 with value: -0.8513924643107886.



Early stopping occurred at epoch 51 with best_epoch = 41 and best_val_0_accuracy = 0.84812


[I 2024-09-23 00:47:04,882] Trial 14 finished with value: -0.8481160776971682 and parameters: {'n_d': 50, 'n_a': 64, 'n_steps': 10, 'gamma': 1.2501856394937325, 'lambda_sparse': 2.584826538684608e-05, 'n_independent': 2, 'n_shared': 1}. Best is trial 5 with value: -0.8513924643107886.



Early stopping occurred at epoch 41 with best_epoch = 31 and best_val_0_accuracy = 0.85209


[I 2024-09-23 00:48:30,302] Trial 15 finished with value: -0.8520945471565645 and parameters: {'n_d': 64, 'n_a': 43, 'n_steps': 3, 'gamma': 1.7459622510311343, 'lambda_sparse': 0.0002585445271326861, 'n_independent': 3, 'n_shared': 5}. Best is trial 15 with value: -0.8520945471565645.



Early stopping occurred at epoch 31 with best_epoch = 21 and best_val_0_accuracy = 0.84929


[I 2024-09-23 00:49:49,271] Trial 16 finished with value: -0.8492862157734613 and parameters: {'n_d': 63, 'n_a': 44, 'n_steps': 3, 'gamma': 1.7156302814740965, 'lambda_sparse': 0.0002681360258765007, 'n_independent': 4, 'n_shared': 5}. Best is trial 15 with value: -0.8520945471565645.



Early stopping occurred at epoch 39 with best_epoch = 29 and best_val_0_accuracy = 0.84297


[I 2024-09-23 00:52:57,807] Trial 17 finished with value: -0.8429674701614791 and parameters: {'n_d': 64, 'n_a': 46, 'n_steps': 8, 'gamma': 1.8122347148289621, 'lambda_sparse': 0.0002789216102178035, 'n_independent': 2, 'n_shared': 5}. Best is trial 15 with value: -0.8520945471565645.



Early stopping occurred at epoch 45 with best_epoch = 35 and best_val_0_accuracy = 0.84203


[I 2024-09-23 00:54:36,429] Trial 18 finished with value: -0.8420313597004446 and parameters: {'n_d': 34, 'n_a': 30, 'n_steps': 4, 'gamma': 1.9677316256286437, 'lambda_sparse': 9.714998831922168e-05, 'n_independent': 3, 'n_shared': 4}. Best is trial 15 with value: -0.8520945471565645.



Early stopping occurred at epoch 73 with best_epoch = 63 and best_val_0_accuracy = 0.84905


[I 2024-09-23 01:01:05,551] Trial 19 finished with value: -0.8490521881582027 and parameters: {'n_d': 52, 'n_a': 38, 'n_steps': 8, 'gamma': 1.5830415079524538, 'lambda_sparse': 0.00026281862325315157, 'n_independent': 4, 'n_shared': 5}. Best is trial 15 with value: -0.8520945471565645.



Early stopping occurred at epoch 76 with best_epoch = 66 and best_val_0_accuracy = 0.84788


[I 2024-09-23 01:07:51,515] Trial 20 finished with value: -0.8478820500819096 and parameters: {'n_d': 39, 'n_a': 42, 'n_steps': 9, 'gamma': 1.743612528371246, 'lambda_sparse': 0.00015414513062341203, 'n_independent': 5, 'n_shared': 4}. Best is trial 15 with value: -0.8520945471565645.



Early stopping occurred at epoch 31 with best_epoch = 21 and best_val_0_accuracy = 0.85022


[I 2024-09-23 01:08:55,103] Trial 21 finished with value: -0.8502223262344957 and parameters: {'n_d': 34, 'n_a': 50, 'n_steps': 4, 'gamma': 1.3577911012020136, 'lambda_sparse': 5.282894029612608e-05, 'n_independent': 3, 'n_shared': 3}. Best is trial 15 with value: -0.8520945471565645.



Early stopping occurred at epoch 17 with best_epoch = 7 and best_val_0_accuracy = 0.83852


[I 2024-09-23 01:09:32,085] Trial 22 finished with value: -0.8385209454715656 and parameters: {'n_d': 15, 'n_a': 33, 'n_steps': 5, 'gamma': 1.6283493625793328, 'lambda_sparse': 3.1625818919776436e-05, 'n_independent': 3, 'n_shared': 3}. Best is trial 15 with value: -0.8520945471565645.



Early stopping occurred at epoch 14 with best_epoch = 4 and best_val_0_accuracy = 0.84812


[I 2024-09-23 01:09:57,752] Trial 23 finished with value: -0.8481160776971682 and parameters: {'n_d': 50, 'n_a': 25, 'n_steps': 3, 'gamma': 1.1221207243496232, 'lambda_sparse': 1.7019932534738236e-05, 'n_independent': 3, 'n_shared': 5}. Best is trial 15 with value: -0.8520945471565645.



Early stopping occurred at epoch 20 with best_epoch = 10 and best_val_0_accuracy = 0.84344


[I 2024-09-23 01:10:21,560] Trial 24 finished with value: -0.8434355253919963 and parameters: {'n_d': 28, 'n_a': 17, 'n_steps': 4, 'gamma': 1.7761446581877833, 'lambda_sparse': 7.072855860245378e-05, 'n_independent': 2, 'n_shared': 2}. Best is trial 15 with value: -0.8520945471565645.



Early stopping occurred at epoch 29 with best_epoch = 19 and best_val_0_accuracy = 0.83361


[I 2024-09-23 01:11:57,626] Trial 25 finished with value: -0.833606365551135 and parameters: {'n_d': 8, 'n_a': 49, 'n_steps': 6, 'gamma': 1.9021725446592357, 'lambda_sparse': 1.7500353614224557e-05, 'n_independent': 4, 'n_shared': 4}. Best is trial 15 with value: -0.8520945471565645.



Early stopping occurred at epoch 25 with best_epoch = 15 and best_val_0_accuracy = 0.83595


[I 2024-09-23 01:13:16,646] Trial 26 finished with value: -0.835946641703721 and parameters: {'n_d': 59, 'n_a': 28, 'n_steps': 5, 'gamma': 1.5374334470868427, 'lambda_sparse': 4.146494261478529e-05, 'n_independent': 3, 'n_shared': 5}. Best is trial 15 with value: -0.8520945471565645.



Early stopping occurred at epoch 38 with best_epoch = 28 and best_val_0_accuracy = 0.84999


[I 2024-09-23 01:14:55,097] Trial 27 finished with value: -0.8499882986192371 and parameters: {'n_d': 45, 'n_a': 40, 'n_steps': 7, 'gamma': 1.1622745129688292, 'lambda_sparse': 0.00036727239967718607, 'n_independent': 3, 'n_shared': 1}. Best is trial 15 with value: -0.8520945471565645.



Early stopping occurred at epoch 32 with best_epoch = 22 and best_val_0_accuracy = 0.84999


[I 2024-09-23 01:15:47,532] Trial 28 finished with value: -0.8499882986192371 and parameters: {'n_d': 39, 'n_a': 18, 'n_steps': 3, 'gamma': 1.3706980052591464, 'lambda_sparse': 0.00012881902190835305, 'n_independent': 4, 'n_shared': 3}. Best is trial 15 with value: -0.8520945471565645.



Early stopping occurred at epoch 35 with best_epoch = 25 and best_val_0_accuracy = 0.84788


[I 2024-09-23 01:16:52,346] Trial 29 finished with value: -0.8478820500819096 and parameters: {'n_d': 54, 'n_a': 37, 'n_steps': 4, 'gamma': 1.458840789432847, 'lambda_sparse': 1.938474428827938e-05, 'n_independent': 1, 'n_shared': 4}. Best is trial 15 with value: -0.8520945471565645.


In [9]:
# Get the best parameters
print(f"Best parameters: {study.best_params}")
print(f"Best score: {-study.best_value}")

Best parameters: {'n_d': 64, 'n_a': 43, 'n_steps': 3, 'gamma': 1.7459622510311343, 'lambda_sparse': 0.0002585445271326861, 'n_independent': 3, 'n_shared': 5}
Best score: 0.8520945471565645


In [10]:
# Train the final model using the best parameters
best_params = study.best_params
final_model = TabNetClassifier(
    n_d=best_params['n_d'],
    n_a=best_params['n_a'],
    n_steps=best_params['n_steps'],
    gamma=best_params['gamma'],
    lambda_sparse=best_params['lambda_sparse'],
    n_independent=best_params['n_independent'],
    n_shared=best_params['n_shared'],
    verbose=0
)

# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y_encoded, test_size=0.2, random_state=42)

# Train the model on the training set
final_model.fit(
    X_train, y_train,
    eval_set=[(X_test, y_test)],
    eval_metric=['accuracy'],
    max_epochs=100,
    patience=10,  # Early stopping
    batch_size=1024,
    virtual_batch_size=128
)

# Make predictions on the test set
y_pred = final_model.predict(X_test)


Early stopping occurred at epoch 41 with best_epoch = 31 and best_val_0_accuracy = 0.85209




In [11]:
# Calculate metrics
accuracy = accuracy_score(y_test, y_pred)
precision = precision_score(y_test, y_pred)
recall = recall_score(y_test, y_pred)
f1 = f1_score(y_test, y_pred)
conf_matrix = confusion_matrix(y_test, y_pred)
class_report = classification_report(y_test, y_pred)

In [12]:
# Print the metrics
print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1:.4f}")
print("Confusion Matrix:")
print(conf_matrix)
print("Classification Report:")
print(class_report)

Accuracy: 0.8521
Precision: 0.8279
Recall: 0.6534
F1 Score: 0.7304
Confusion Matrix:
[[2785  178]
 [ 454  856]]
Classification Report:
              precision    recall  f1-score   support

           0       0.86      0.94      0.90      2963
           1       0.83      0.65      0.73      1310

    accuracy                           0.85      4273
   macro avg       0.84      0.80      0.81      4273
weighted avg       0.85      0.85      0.85      4273



In [13]:
# Extract feature importances
feature_importances = final_model.feature_importances_

# Create a DataFrame to display feature importances
feature_importance_df = pd.DataFrame({
    'Importance': feature_importances
})

# Sort and display the top 10 features
top_10_features = feature_importance_df.sort_values(by='Importance', ascending=False).head(10)
print(top_10_features)


    Importance
15    0.177899
0     0.149297
17    0.135926
8     0.122604
9     0.120186
22    0.056613
7     0.046655
5     0.036365
6     0.030592
18    0.026730


In [14]:
original_feature_names = X.columns
feature_importance_df['Feature'] = [original_feature_names[i] for i in range(len(original_feature_names))]
top_10_features = feature_importance_df.sort_values(by='Importance', ascending=False).head(10)

print(top_10_features)

    Importance                       Feature
15    0.177899                        ua_wbc
0     0.149297                   ua_bacteria
17    0.135926                        abxUTI
8     0.122604                       ua_leuk
9     0.120186                    ua_nitrite
22    0.056613                        Female
7     0.046655                    ua_ketones
5     0.036365                        ua_epi
6     0.030592                    ua_glucose
18    0.026730  ethnicity_Hispanic or Latino
