In [1]:
import pandas as pd
import numpy as np
import torch
from sklearn.preprocessing import StandardScaler
from pytorch_tabnet.pretraining import TabNetPretrainer
from sklearn.impute import KNNImputer
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
import math

# Generate a synthetic, fully observed dataset. Seeding the generator makes the
# data, the missingness mask and therefore every downstream number reproducible
# across "Restart & Run All".
SEED = 42
rng = np.random.default_rng(SEED)

n_samples = 5000  # Adjust the number of samples as needed
n_features = 60   # Adjust the number of features as needed
data = rng.random((n_samples, n_features))  # uniform draws on [0, 1)
df = pd.DataFrame(data, columns=[f'feature_{i}' for i in range(n_features)])

# Print the ground values before introducing missing values
print("Ground Values (Before Introducing Missing Values):")
print(df.head())

# Keep an untouched copy as the ground truth for evaluating the imputation
ground_values = df.copy()

# Blank out values completely at random: each cell is dropped independently with
# probability `missing_fraction`. `mask` is True exactly where a value was removed.
missing_fraction = 0.3  # Adjust the fraction of missing values as needed
mask = rng.random((n_samples, n_features)) < missing_fraction
df = df.mask(mask)  # new frame with NaN where mask is True; ground_values is untouched

# Print the input dataset with missing values
print("\nInput Data with Missing Values:")
print(df.head())

# Standardize every column to zero mean / unit variance. StandardScaler ignores
# NaNs when fitting and keeps them as NaN when transforming, so the missing
# cells are still NaN in `df_normalized`.
numeric_cols = df.columns
scaler = StandardScaler()
df_normalized = pd.DataFrame(
    scaler.fit_transform(df[numeric_cols]),
    index=df.index,
    columns=numeric_cols,
)

print('After Normalizing Input Data:')
print(df_normalized.head())

# Network input: replace the gaps with 0, which after standardization is each
# column's (observed) mean -- i.e. a mean-imputation placeholder.
df = df_normalized.fillna(0)
print("\nFilled Missing Values with Zero Placeholder:")
print(df.head())

# Self-supervised pretraining: the model is fit on the standardized, zero-filled
# matrix, so anything later passed to its `predict` must be in that same space.
pretrained_model = TabNetPretrainer(
    optimizer_fn=torch.optim.Adam,
    optimizer_params=dict(lr=2e-2),
    mask_type='entmax',  # attention-mask function ('sparsemax' is the library default)
    n_d=32,  # width of the decision prediction layer (n_steps, not n_d, sets the number of steps)
    n_a=32,  # width of the attention embedding for each mask
)

max_epochs = 50
pretrained_model.fit(
    df[numeric_cols].values,
    max_epochs=max_epochs
)

# Report the last recorded pretraining loss so the header is not empty.
print("\nPretrained TabNet Model:")
print("final pretraining loss:", pretrained_model.history['loss'][-1])

# Define the tabnet_recon function
def tabnet_recon(df, network, df_mean=0, df_std=1):
    df_train = df.copy()
    df_train[numeric_cols] = scaler.transform(df_train[numeric_cols])
    
    # Convert input data to tensors for use in the TabNet network
    input_data = torch.tensor(df_train[numeric_cols].values, dtype=torch.float32)
    
    # Pass the input data through the TabNet network
    results = network.predict(input_data)
    
    # Handle potential tuple output
    if isinstance(results, tuple):
        results = results[0]  # Use the first element of the tuple
    
    # Denormalize the reconstructed data
    df_na_tab = (results * df_std) + df_mean
    
    # Patch the reconstructed data into the original data with missing values
    df_rec_tab = df.copy()
    df_rec_tab.update(df_na_tab)
    
    return df_rec_tab

# Ground truth for exactly the cells that were blanked out (True in `mask`), on
# the original scale. Boolean indexing flattens row-major, so this lines up
# element-for-element with `imputed_values` below.
true_missing_values = ground_values[numeric_cols].to_numpy()[mask]

print('\nTrue missing values:')
print(true_missing_values)

# Feed the network the same standardized, zero-filled matrix it was pretrained on.
reconstructed_data = tabnet_recon(df, network=pretrained_model)

# Keep every observed cell exactly as it was and take the model's output only
# where a value was missing, so the evaluation measures imputation rather than
# how well the network reproduces values it was given.
reconstructed_data = reconstructed_data[numeric_cols].where(mask, df_normalized[numeric_cols])

print('\nReconstructed data:')
print(reconstructed_data.head())

# Denormalize back to the original feature scale
reconstructed_data[numeric_cols] = scaler.inverse_transform(reconstructed_data[numeric_cols])

print("\nDenormalized Reconstructed Data:")
print(reconstructed_data.head())

# Imputed values on the original scale, aligned with `true_missing_values`
imputed_values = reconstructed_data[numeric_cols].to_numpy()[mask]

print('\nImputed values:')
print(imputed_values)

# Print the ground values again (before processing) at the end
print("\nGround Values (Before Processing, At the End):")
print(ground_values.head())

# RMSE over the imputed cells only: observed cells are copied through unchanged
# and would only dilute the score.
rmse = math.sqrt(mean_squared_error(true_missing_values, imputed_values))
print("\nRoot Mean Squared Error (RMSE) on imputed cells:", rmse)

Ground Values (Before Introducing Missing Values):
   feature_0  feature_1  feature_2  feature_3  feature_4  feature_5  \
0   0.348235   0.349890   0.458102   0.703647   0.824123   0.322928   
1   0.498819   0.163697   0.842276   0.537371   0.273011   0.488537   
2   0.893514   0.887175   0.188144   0.986549   0.855154   0.825018   
3   0.984123   0.161337   0.370182   0.842160   0.227979   0.974128   
4   0.035229   0.112280   0.937106   0.048855   0.697233   0.791711   

   feature_6  feature_7  feature_8  feature_9  ...  feature_50  feature_51  \
0   0.470692   0.728581   0.886758   0.755525  ...    0.065345    0.324274   
1   0.842686   0.222656   0.693628   0.084708  ...    0.312724    0.791988   
2   0.608383   0.644597   0.925009   0.649155  ...    0.605175    0.060008   
3   0.517608   0.610857   0.461404   0.655599  ...    0.381657    0.349899   
4   0.467091   0.664637   0.513832   0.520772  ...    0.179873    0.845256   

   feature_52  feature_53  feature_54  feature_55  fe