In [47]:
import pandas as pd
import numpy as np
import torch
from sklearn.datasets import load_iris
from sklearn.preprocessing import MinMaxScaler
from pytorch_tabnet.pretraining import TabNetPretrainer
from sklearn.metrics import mean_squared_error
import math

# Seed for the artificial-missingness mask so that "Restart Kernel -> Run All"
# removes the same cells every time and the reported RMSE is comparable
# between runs. (Previously the mask came from the unseeded global NumPy RNG.)
SEED = 42
rng = np.random.default_rng(SEED)

# Load Iris dataset
iris = load_iris()
data = iris.data
columns = iris.feature_names

# Ground-truth frame; it is never modified, so it can be used for evaluation later.
original_data = pd.DataFrame(data, columns=columns)
print("Original Data (Before Missing Values):")
print(original_data.head())

# Introduce artificial missing values.
# `mask` is True where a cell is removed; it stays the authoritative record of
# which cells are missing. `DataFrame.mask` returns a new frame with NaN at the
# True positions instead of mutating a shared frame in place.
missing_fraction = 0.2
mask = rng.random(data.shape) < missing_fraction
df = original_data.mask(mask)

# Print the input dataset with missing values
print("\nInput Data with Missing Values:")
print(df.head())

# Normalize the dataset to [0, 1] per column.
# MinMaxScaler disregards NaNs when fitting and keeps them as NaN when
# transforming, so the scaling statistics come from the observed cells only.
scaler = MinMaxScaler()
numeric_cols = df.columns
df[numeric_cols] = scaler.fit_transform(df[numeric_cols])

print("\nAfter Normalizing Input Data:")
print(df.head())

# Fill missing values with zero.
# NOTE: after min-max scaling, 0 is also the scaled value of each column's
# observed minimum, so the placeholder alone does not uniquely identify the
# missing cells -- use `mask` when exact positions are needed.
df[numeric_cols] = df[numeric_cols].fillna(0)

print("\nFilled Missing Values with Zero Placeholder:")
print(df.head())

# Define and pretrain the TabNet model
pretrained_model = TabNetPretrainer(
    optimizer_fn=torch.optim.Adam,
    optimizer_params=dict(lr=2e-2),
    mask_type='sparsemax',
    #n_d=32,  # Optional: width of the decision prediction layer
    #n_a=32   # Optional: width of the attention embedding
)

max_epochs = 40

# The library's default batch size (1024) is larger than this dataset.
# Use a batch size that fits the data and keep the final partial batch so that
# every row takes part in each epoch (with a batch larger than the dataset and
# the incomplete batch dropped, an epoch would contain no training step).
# NOTE(review): confirm the drop_last default of the installed pytorch-tabnet.
batch_size = 32
virtual_batch_size = 16  # ghost batch-norm chunk size; kept <= batch_size

pretrained_model.fit(
    df[numeric_cols].values,
    max_epochs=max_epochs,
    batch_size=batch_size,
    virtual_batch_size=virtual_batch_size,
    drop_last=False,
)

print("\nPretrained TabNet Model:")

def tabnet_recon(df, network, missing_value_placeholder=0, numeric_cols=None, scaler=None,
                 missing_mask=None):
    """Replace the missing cells of ``df`` with the network's reconstruction.

    ``df`` must already be in the space the network was trained on (i.e. the
    normalized, placeholder-filled frame). The values are fed to the network
    as-is and the returned frame is in that same space, so the caller can apply
    ``scaler.inverse_transform`` afterwards. The scaler is deliberately NOT
    re-applied here: transforming an already-normalized frame a second time
    would hand the network inputs on a different scale than it was trained on.

    Parameters
    ----------
    df : pandas.DataFrame
        Normalized input frame with missing cells set to the placeholder.
        It is not modified.
    network : object
        Fitted model exposing ``predict``. A tuple result is reduced to its
        first element (presumably the reconstruction for a TabNet pretrainer --
        verify against the installed pytorch-tabnet version).
    missing_value_placeholder : scalar, default 0
        Value marking a missing cell when ``missing_mask`` is not supplied.
    numeric_cols : sequence of column labels
        Columns to reconstruct. Required.
    scaler : object
        Scaler that produced the normalized ``df``. Still required so existing
        calls and the ValueError contract are unchanged; it is not applied.
    missing_mask : array-like of bool, optional
        Explicit missing-cell mask with the same shape as ``df[numeric_cols]``.
        Prefer this when a genuine value can equal the placeholder (after
        min-max scaling, each column's minimum is exactly 0).

    Returns
    -------
    pandas.DataFrame
        Copy of ``df`` where masked cells hold reconstructed values and all
        other cells are unchanged.

    Raises
    ------
    ValueError
        If ``numeric_cols`` or ``scaler`` is missing, or if ``missing_mask``
        does not match the shape of ``df[numeric_cols]``.
    """
    # Ensure numeric_cols and scaler are provided
    if numeric_cols is None or scaler is None:
        raise ValueError("Please provide numeric_cols and scaler.")

    # Work on a copy so the caller's frame is left untouched
    df_tab = df.copy()
    observed = df_tab[numeric_cols].to_numpy()

    # Decide which cells get replaced by the reconstruction
    if missing_mask is None:
        missing_mask = (df_tab[numeric_cols] == missing_value_placeholder).to_numpy()
    else:
        missing_mask = np.asarray(missing_mask, dtype=bool)
        if missing_mask.shape != observed.shape:
            raise ValueError(
                "missing_mask shape %s does not match data shape %s"
                % (missing_mask.shape, observed.shape)
            )

    # Feed the (already normalized) data to the network
    inputData = torch.tensor(observed, dtype=torch.float32)
    results = network.predict(inputData)

    # Reduce the result to a NumPy array
    if isinstance(results, tuple):
        results = results[0]  # Use the first element of the tuple
    if isinstance(results, torch.Tensor):
        results = results.detach().cpu().numpy()
    results = np.asarray(results)

    # Apply the reconstructed values only to the missing positions
    df_tab[numeric_cols] = np.where(missing_mask, results, observed)

    return df_tab

# Ground-truth values (original units) of the cells that were artificially
# removed. `df` is already normalized and zero-filled at this point, so the
# true values have to come from `original_data` at the `mask` positions.
true_missing_values = original_data[numeric_cols].values[mask]

print('\nTrue Missing Values (first 10 of %d):' % true_missing_values.size)
print(true_missing_values[:10])


# Reconstruct missing values using the pretrained model.
# The result is in the same normalized space as `df`.
reconstructed_data = tabnet_recon(
    df,
    network=pretrained_model,
    numeric_cols=numeric_cols,  # Provide the numeric_cols
    scaler=scaler  # Provide the scaler
)

print('\nReconstructed Data:')
print(reconstructed_data.head())

# Denormalize the reconstructed data back to the original units
reconstructed_data[numeric_cols] = scaler.inverse_transform(reconstructed_data[numeric_cols])

# Imputed values (original units) at the artificially removed cells, aligned
# element-for-element with `true_missing_values`.
imputed_values = reconstructed_data[numeric_cols].values[mask]

print('\nImputed Values (first 10 of %d):' % imputed_values.size)
print(imputed_values[:10])

# Print the original data (before missing values) and reconstructed data
print("\nOriginal Data (Before Missing Values):")
print(original_data.head())

print("\nDenormalized Reconstructed Data:")
print(reconstructed_data.head())

# RMSE over every cell. Observed cells are copied through unchanged, so this
# figure is diluted by cells that were never imputed.
rmse = np.sqrt(np.mean((original_data[numeric_cols].values - reconstructed_data[numeric_cols].values) ** 2))

# RMSE restricted to the artificially removed cells: the actual imputation error.
if mask.any():
    rmse_imputed = np.sqrt(np.mean((true_missing_values - imputed_values) ** 2))
else:
    rmse_imputed = float("nan")  # nothing was removed, so there is nothing to score

# Print the RMSE
print("\nRMSE between Original Data and Denormalized Reconstructed Data:", rmse)
print("RMSE on artificially removed cells only:", rmse_imputed)


Original Data (Before Missing Values):
   sepal length (cm)  sepal width (cm)  petal length (cm)  petal width (cm)
0                5.1               3.5                1.4               0.2
1                4.9               3.0                1.4               0.2
2                4.7               3.2                1.3               0.2
3                4.6               3.1                1.5               0.2
4                5.0               3.6                1.4               0.2

Input Data with Missing Values:
   sepal length (cm)  sepal width (cm)  petal length (cm)  petal width (cm)
0                5.1               3.5                1.4               0.2
1                4.9               3.0                1.4               0.2
2                4.7               NaN                1.3               0.2
3                NaN               3.1                1.5               0.2
4                NaN               3.6                1.4               NaN

After Normalizi