In [4]:
import pandas as pd
import numpy as np
import torch
from sklearn.preprocessing import StandardScaler
from pytorch_tabnet.pretraining import TabNetPretrainer
from sklearn.impute import KNNImputer
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
import math

# Generate random data from a dedicated, seeded generator so that the synthetic
# matrix and the missingness mask below are identical on every fresh run.
SEED = 42  # Change to draw a different dataset / mask
rng = np.random.default_rng(SEED)
n_samples = 3000  # Adjust the number of samples as needed
n_features = 50   # Adjust the number of features as needed
data = rng.random((n_samples, n_features))  # uniform values in [0, 1)
df = pd.DataFrame(data, columns=[f'feature_{i}' for i in range(n_features)])

# Print the ground values before introducing missing values
print("Ground Values (Before Introducing Missing Values):")
print(df.head())

# Store a copy of the ground values; `df` itself is overwritten stage by stage
# below (missing values -> normalized -> KNN-filled).
ground_values = df.copy()

# Introduce random missing values. `mask` is True exactly where a cell is
# blanked out.
missing_fraction = 0.3  # Adjust the fraction of missing values as needed
mask = rng.random((n_samples, n_features)) < missing_fraction
df[mask] = np.nan

# Print the input dataset with missing values
print("\nInput Data with Missing Values:")
print(df.head())

# Normalize the dataset. `scaler` is kept at module level because later steps
# reuse it to map values back to the original scale.
scaler = StandardScaler()
numeric_cols = df.columns
df[numeric_cols] = scaler.fit_transform(df[numeric_cols])

# Store the normalized data (taken before the KNN fill below) in a separate variable
df_normalized = df.copy()

print('After Normalizing Input Data:')
print(df.head())

# Use KNN imputer to fill missing values (in normalized space)
knn_imputer = KNNImputer(n_neighbors=50)  # You can adjust the number of neighbors as needed
df[numeric_cols] = knn_imputer.fit_transform(df[numeric_cols])
print("\nFilled Missing Values with KNN:")
print(df.head())

# Pretrain the TabNet model on the normalized, KNN-filled matrix
pretrained_model = TabNetPretrainer(
    optimizer_fn=torch.optim.Adam,
    optimizer_params=dict(lr=2e-2),
    mask_type='entmax',
    n_d=32,  # Width of the decision prediction layer
    n_a=32   # Width of the attention embedding for each mask
)

max_epochs = 25
pretrained_model.fit(
    df[numeric_cols].values,
    max_epochs=max_epochs
)
print("\nPretrained TabNet Model:")
# Define the tabnet_recon function
def tabnet_recon(df, network, df_mean=0, df_std=1):
    """Reconstruct the feature columns of ``df`` with a fitted network.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame holding the ``numeric_cols`` columns (module-level name), already
        in the scale the network was trained on. It is not modified.
    network : object
        Fitted model exposing ``predict(X)``. When ``predict`` returns a tuple,
        its first element is used as the reconstruction.
    df_mean, df_std : scalar or array-like, default 0 and 1
        Shift and scale applied to the reconstruction
        (``reconstruction * df_std + df_mean``). The defaults leave it in the
        scale the network produced.

    Returns
    -------
    pandas.DataFrame
        A copy of ``df`` whose ``numeric_cols`` cells are overwritten by the
        reconstruction wherever the reconstruction is not NaN.
    """
    # Feed the frame as-is. In this notebook the frame is standardized before
    # the network is trained on it, so pushing it through the fitted scaler a
    # second time here would hand the network inputs on a scale it never saw.
    # A NumPy array is passed, matching what `fit` receives.
    model_input = df[numeric_cols].to_numpy(dtype=np.float32)

    # Pass the input data through the network
    results = network.predict(model_input)

    # Handle potential tuple output
    if isinstance(results, tuple):
        results = results[0]  # Use the first element of the tuple

    # Denormalize and re-label the reconstruction. DataFrame.update aligns on
    # index/column labels: a bare array would be auto-labelled 0..n-1, match
    # none of the feature columns, and the patch would silently do nothing.
    df_na_tab = pd.DataFrame(
        np.asarray(results) * df_std + df_mean,
        index=df.index,
        columns=numeric_cols,
    )

    # Patch the reconstructed data into a copy of the input frame
    df_rec_tab = df.copy()
    df_rec_tab.update(df_na_tab)

    return df_rec_tab

# Ground-truth values of the cells that were blanked out (original scale).
# They are read from `ground_values`, not from `df`: by this point `df` holds
# standardized, KNN-filled numbers rather than the true ones.
true_missing_values = ground_values[numeric_cols].values[mask]

print('\nTrue missing values:')
print(true_missing_values)

# Run the standardized, KNN-filled frame through tabnet_recon
reconstructed_data = tabnet_recon(df, network=pretrained_model)

print('\nReconstructed data:')
print(reconstructed_data.head())

# Denormalize the reconstructed data so it is comparable with the ground values
reconstructed_data[numeric_cols] = scaler.inverse_transform(reconstructed_data[numeric_cols])

# Extract imputed values: the same cells as `true_missing_values`, original scale
imputed_values = reconstructed_data[numeric_cols].values[mask]

print('\nImputed values:')
print(imputed_values)

# Compare the denormalized reconstructed data with the original input data
print("\nDenormalized Reconstructed Data:")
print(reconstructed_data.head())

# Print the ground values again (before processing) at the end
print("\nGround Values (Before Processing, At the End):")
print(ground_values.head())

# Calculate RMSE over the blanked-out cells only. Averaging over every cell
# would mix in the cells that were never missing and understate the
# imputation error.
if true_missing_values.size:
    rmse = math.sqrt(mean_squared_error(true_missing_values, imputed_values))
else:
    # No cell was blanked out (e.g. missing_fraction = 0): nothing to score.
    rmse = float('nan')
print("\nRoot Mean Squared Error (RMSE):", rmse)

Ground Values (Before Introducing Missing Values):
   feature_0  feature_1  feature_2  feature_3  feature_4  feature_5  \
0   0.103919   0.466684   0.987541   0.269519   0.072754   0.120624   
1   0.857396   0.239468   0.980458   0.418978   0.660958   0.968176   
2   0.463284   0.188713   0.183854   0.662724   0.246240   0.123679   
3   0.831639   0.798794   0.524871   0.132035   0.129027   0.757240   
4   0.957893   0.279951   0.587905   0.142577   0.811062   0.931662   

   feature_6  feature_7  feature_8  feature_9  ...  feature_40  feature_41  \
0   0.775693   0.341257   0.196782   0.031011  ...    0.537373    0.472912   
1   0.262432   0.015159   0.689803   0.625879  ...    0.185452    0.848876   
2   0.766383   0.492859   0.262677   0.987888  ...    0.002743    0.223723   
3   0.938280   0.479514   0.321694   0.360252  ...    0.331237    0.252659   
4   0.381321   0.041035   0.077837   0.475790  ...    0.289496    0.157729   

   feature_42  feature_43  feature_44  feature_45  fe