In [1]:
import pandas as pd
import numpy as np
import torch
from sklearn.preprocessing import MinMaxScaler  
from pytorch_tabnet.pretraining import TabNetPretrainer
from sklearn.metrics import mean_squared_error
import math

# Read the proteomics table from disk.
omicMiss = pd.read_csv('/Users/emondemoniac/Desktop/TabNet_PyTorch/jiaojiao data/48complete_proteomics.csv')

# Keep an untouched snapshot of the raw table; it is printed next to the
# reconstruction at the end of the script.
omicMiss_copy = omicMiss.copy()
print("Input Data with missing Values:")
print(omicMiss.head())

# Scale every column except the first one. The scaler is fitted once here and
# reused later to map the reconstruction back to the original units.
numeric_cols = omicMiss.columns[1:]
scaler = MinMaxScaler()
scaler.fit(omicMiss[numeric_cols])
omicMiss[numeric_cols] = scaler.transform(omicMiss[numeric_cols])

print('After Normalizing Input Data:')
print(omicMiss.head())

# Replace the remaining NaN entries with a fixed placeholder value.
# NOTE(review): with MinMaxScaler's default range each column's observed
# minimum presumably also scales to 0, so later `== placeholder` checks cannot
# tell a real minimum from a missing entry -- confirm and consider keeping an
# explicit NaN mask instead.
missing_value_placeholder = 0
omicMiss = omicMiss.fillna(missing_value_placeholder)

print("\nFilled Missing Values with Placeholder:")
print(omicMiss.head())

# Self-supervised masking: hide a random subset of the observed values.
def mask_data(data, mask_prob=0.25, placeholder=None):
    """Return a copy of ``data`` with a random subset of observed cells hidden.

    Each cell that does not already hold the placeholder is replaced by the
    placeholder with probability ``mask_prob``. Cells that already hold the
    placeholder are left as they are. The input array is not modified.

    The previous implementation only selected cells that were *already* equal
    to the placeholder, so it never changed anything.

    Parameters
    ----------
    data : numpy.ndarray
        Array of values to mask.
    mask_prob : float, default 0.25
        Per-cell probability of hiding an observed value.
    placeholder : scalar, optional
        Value written into hidden cells. Defaults to the module-level
        ``missing_value_placeholder``.

    Returns
    -------
    numpy.ndarray
        Masked copy of ``data`` with the same shape.
    """
    if placeholder is None:
        placeholder = missing_value_placeholder

    # Draw one uniform number per cell; only observed cells are eligible.
    selected = np.random.rand(*data.shape) < mask_prob
    observed = data != placeholder

    masked_data = data.copy()
    masked_data[selected & observed] = placeholder
    return masked_data

# Build the self-supervised TabNet pretrainer.
pretrained_model = TabNetPretrainer(
    optimizer_fn=torch.optim.Adam,
    optimizer_params=dict(lr=2e-2),
    mask_type='sparsemax',
    # n_d=16,  # width of the decision prediction layer
    # n_a=16,  # width of the attention embedding for each mask
)

# The scaled, placeholder-filled matrix does not change between epochs, so it
# is extracted once; mask_data works on a copy and leaves it untouched.
training_matrix = omicMiss[numeric_cols].values

max_epochs = 75
for epoch in range(max_epochs):
    # Draw a fresh random mask for every epoch.
    masked_data = mask_data(training_matrix)
    # warm_start=True keeps the weights learned in earlier iterations. Without
    # it each fit() call builds a new network, so only the final single-epoch
    # fit would survive the loop.
    pretrained_model.fit(masked_data, max_epochs=1, warm_start=True)

print("\nPretrained TabNet Model:")

def tabnet_recon(omicMiss, network, missing_value_placeholder=0, numeric_cols=None, scaler=None):
    """Fill placeholder entries of ``omicMiss`` with the network's reconstruction.

    ``omicMiss`` is expected to already be scaled and placeholder-filled, i.e.
    in the same representation the network was trained on. It is passed to the
    network as-is. The earlier version applied ``scaler.transform`` a second
    time here, which fed doubly scaled values to the network even though the
    caller had already scaled the frame and inverse-transforms the result.

    Parameters
    ----------
    omicMiss : pandas.DataFrame
        Scaled table in which missing entries hold ``missing_value_placeholder``.
        It is not modified.
    network : object
        Fitted model exposing ``predict(X)``. The return value may be an array,
        a ``torch.Tensor``, or a tuple whose first element is the
        reconstruction.
    missing_value_placeholder : scalar, default 0
        Value that marks a missing entry in ``omicMiss``.
    numeric_cols : sequence of column labels
        Columns to reconstruct. Required.
    scaler : object
        Required for backward compatibility with existing call sites. It is
        not applied inside this function; callers use it to map the returned
        frame back to the original units.

    Returns
    -------
    pandas.DataFrame
        Copy of ``omicMiss`` where, within ``numeric_cols``, placeholder
        entries are replaced by the reconstruction and all other entries are
        kept unchanged.

    Raises
    ------
    ValueError
        If ``numeric_cols`` or ``scaler`` is missing, or if the network output
        does not have the same shape as the selected columns.
    """
    if numeric_cols is None or scaler is None:
        raise ValueError("Please provide numeric_cols and scaler.")

    omicRec_tab = omicMiss.copy()
    observed = omicRec_tab[numeric_cols].to_numpy()

    # Positions that hold the placeholder are the ones to be imputed.
    missing_mask = observed == missing_value_placeholder

    inputData = torch.tensor(observed, dtype=torch.float32)
    results = network.predict(inputData)

    # Normalise the prediction to a NumPy array.
    if isinstance(results, tuple):
        results = results[0]  # first element carries the reconstruction
    if isinstance(results, torch.Tensor):
        # .cpu() keeps this working when the model runs on an accelerator.
        results = results.detach().cpu().numpy()
    results = np.asarray(results)

    if results.shape != observed.shape:
        raise ValueError(
            f"Network output shape {results.shape} does not match "
            f"input shape {observed.shape}."
        )

    # Take the reconstruction only where the input was missing.
    omicRec_tab[numeric_cols] = np.where(missing_mask, results, observed)

    return omicRec_tab

# Reference matrix for the error computation below.
# NOTE(review): this is read from omicMiss after the placeholder fill, so the
# entries that were missing hold the placeholder rather than a measured value;
# the resulting figure is not an error against ground truth -- confirm intent.
true_missing_values = omicMiss[numeric_cols].to_numpy()

print ('\n True missing values: ')
print (true_missing_values)

# Run the pretrained model over the scaled table to fill the placeholders.
reconstructed_data = tabnet_recon(
    omicMiss,
    network=pretrained_model,
    numeric_cols=numeric_cols,
    scaler=scaler,
)

print ('\n Reconstructed data: ')
print (reconstructed_data.head(10))

# Numeric part of the reconstruction, still in scaled units.
imputed_values = reconstructed_data[numeric_cols].to_numpy()

print ('\n Imputed values: ')
print (imputed_values)

# Root of the mean squared difference over all numeric cells.
squared_diff = np.square(imputed_values - true_missing_values)
rmse = np.sqrt(squared_diff.mean())

# Map the reconstruction back to the original measurement units.
reconstructed_data[numeric_cols] = scaler.inverse_transform(reconstructed_data[numeric_cols])

# Show the raw input next to the denormalized reconstruction.
print("Original Data:")
print(omicMiss_copy.head(10))
print("\nDenormalize Reconstructed Data:")
print(reconstructed_data.head(10))

print("RMSE:", rmse)


Input Data with missing Values:
   Unnamed: 0  Balm_3_1_U_IO_DDA_30min_G6_1_5228  \
0       ILVBL                                NaN   
1     NA;NBAS                           8.937573   
2  GTPBP10;NA                           8.672674   
3         PGP                                NaN   
4     NA;AHSG                          11.804774   

   Balm_3_2_T_IO_DDA_30min_H6_1_5230  Balm_3_3_U_IO_DDA_30min_A7_1_5232  \
0                           8.935561                                NaN   
1                           9.367173                                NaN   
2                           8.955487                                NaN   
3                           8.596817                                NaN   
4                          11.778822                          11.864427   

   Balm_3_4_T_IO_DDA_30min_B7_1_5234  DOHH_2_1_U_IO_DDA_30_C4_1_5188  \
0                                NaN                        9.220390   
1                                NaN                        