<a href="https://colab.research.google.com/github/aiemond/TabNet/blob/main/TabNetImputation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [5]:
import pandas as pd
import numpy as np
import torch
from sklearn.preprocessing import StandardScaler
from pytorch_tabnet.pretraining import TabNetPretrainer

# larger example dataset with missing values
np.random.seed(42)
num_samples = 5000
num_features = 25
data = {
    f'feature{i}': np.random.normal(size=num_samples) for i in range(1, num_features + 1)
}
data['subtype'] = np.random.choice(['A', 'B', 'C'], size=num_samples)

# Introduce random missing values
missing_mask = np.random.rand(num_samples, num_features) < 0.2
for col in data.keys():
    if col != 'subtype':
        data[col][missing_mask[:, int(col[-1]) - 1]] = np.nan

omicMiss = pd.DataFrame(data)

# Normalize the dataset
scaler = StandardScaler()
numeric_cols = [f'feature{i}' for i in range(1, num_features + 1)]
omicMiss[numeric_cols] = scaler.fit_transform(omicMiss[numeric_cols])

# Handle missing values by filling NaNs with a specific value
missing_value_placeholder = 0
omicMiss.fillna(missing_value_placeholder, inplace=True)

# Pretrain the TabNet model
pretrained_model = TabNetPretrainer(
    optimizer_fn=torch.optim.Adam,
    optimizer_params=dict(lr=2e-2),
    mask_type='entmax'
)

max_epochs = 10
pretrained_model.fit(
    omicMiss[numeric_cols].values,
    max_epochs=max_epochs
)

# Define the tabnet_recon function

def tabnet_recon(omicMiss, network, omicMissMean=0, omicMissSd=1):
    omicMissTrain = omicMiss.copy()
    omicMissTrain[numeric_cols] = scaler.transform(omicMissTrain[numeric_cols])

    # Convert input data to tensors for use in the TabNet network
    inputData = torch.tensor(omicMissTrain[numeric_cols].values, dtype=torch.float32)

    # Pass the input data through the TabNet network
    results = network.predict(inputData)

    # Handle potential tuple output
    if isinstance(results, tuple):
        results = results[0]  # Use the first element of the tuple

    # Denormalize the reconstructed data
    omicNa_tab = (results * omicMissSd) + omicMissMean

    # Combine reconstructed data with original categorical column
    omicNa_tab = pd.DataFrame(omicNa_tab, columns=numeric_cols)
    omicNa_tab['subtype'] = omicMiss['subtype']

    # Patch the reconstructed data into the original data with missing values
    omicRec_tab = omicMiss.copy()
    omicRec_tab.update(omicNa_tab)

    return omicRec_tab


# Extract true missing values before filling
true_missing_values = omicMiss[numeric_cols].values

# Reconstruct missing values using the pretrained model
reconstructed_data = tabnet_recon(omicMiss, network=pretrained_model)

# Extract imputed values
imputed_values = reconstructed_data[numeric_cols].values

# Calculate MAE, R-squared, and RMSE
mae = np.mean(np.abs(imputed_values - true_missing_values))
total_variation = np.sum((true_missing_values - np.mean(true_missing_values)) ** 2)
residual_variation = np.sum((true_missing_values - imputed_values) ** 2)
r_squared = 1 - (residual_variation / total_variation)
rmse = np.sqrt(np.mean((imputed_values - true_missing_values) ** 2))

print("MAE:", mae)
print("R-squared:", r_squared)
print("RMSE:", rmse)

# Print original and reconstructed data (only printing a subset)
print("Original Data:")
print(omicMiss.head())
print("\nReconstructed Data:")
print(reconstructed_data.head())




epoch 0  | loss: 3.71432 |  0:00:00s
epoch 1  | loss: 2.00397 |  0:00:00s
epoch 2  | loss: 1.38268 |  0:00:00s
epoch 3  | loss: 1.1753  |  0:00:00s
epoch 4  | loss: 1.08561 |  0:00:00s
epoch 5  | loss: 1.04796 |  0:00:01s
epoch 6  | loss: 1.02771 |  0:00:01s
epoch 7  | loss: 1.00867 |  0:00:01s
epoch 8  | loss: 1.01442 |  0:00:01s
epoch 9  | loss: 1.01794 |  0:00:01s
MAE: 0.6471541921315724
R-squared: -0.0024228582490715134
RMSE: 0.896584038719793
Original Data:
   feature1  feature2  feature3  feature4  feature5  feature6  feature7  \
0  0.000000 -0.401434 -0.683793 -0.161647  0.369118  0.175104 -1.919698   
1  0.000000 -0.430877 -0.311414  0.000000  0.302928  0.000000 -1.018288   
2  0.636453 -1.763517  0.000000  0.045774 -0.939967 -0.422877  0.000000   
3  1.514330 -0.308434  0.103817  0.927079  0.604787  0.002869  0.154902   
4 -0.247943  0.746892  0.000000  0.000000  0.000000  0.492921  1.006554   

   feature8  feature9  feature10  ...  feature17  feature18  feature19  \
0  0.929