## Setup

First, let's import the necessary libraries and set up our environment.

In [1]:
import sys
import os

# Add the parent directory to the path so we can import our notmiwae_pytorch module
sys.path.insert(0, os.path.dirname(os.path.abspath('')))

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt

# Import our not-MIWAE implementation
from notmiwae_pytorch import NotMIWAE, MIWAE, Trainer
from notmiwae_pytorch.utils import (
    set_seed, 
    imputation_rmse, 
    introduce_mnar_missing,
    standardize
)

# Seed the random number generators through the project's helper so the row
# shuffle, the MNAR mask, weight initialisation and batch shuffling repeat
# across runs (assumes set_seed covers both NumPy and torch — TODO confirm).
set_seed(42)

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda') if use_cuda else torch.device('cpu')
print(f"Using device: {device}")

ModuleNotFoundError: No module named 'notmiwae_pytorch'

## 1. Load and Prepare Data

We use the white-wine subset of the UCI Wine Quality dataset, following the UCI experimental setup of the not-MIWAE paper (Ipsen, Mattei & Frellsen, ICLR 2021).

In [None]:
# Download the white-wine portion of the UCI Wine Quality dataset.
# The file is a semicolon-separated CSV, hence sep=';'.
url = (
    "https://archive.ics.uci.edu/ml/machine-learning-databases/"
    "wine-quality/winequality-white.csv"
)
data = pd.read_csv(url, sep=';')

print(f"Dataset shape: {data.shape}")
print(f"\nColumn names: {list(data.columns)}")
print("\nFirst few rows:")
data.head()

In [None]:
# Keep only the input features: 'quality' is the prediction target and is not
# part of the matrix we impute.
X = data.drop('quality', axis=1).values.astype(np.float32)
N, D = X.shape
print(f"Data shape: N={N} samples, D={D} features")

# Column-wise standardization; the helper also hands back the statistics it
# used (presumably per-feature mean and std).
X_std, mean, std = standardize(X)

# Shuffle the rows once so that the later head/tail train/validation split is
# random. This draws from NumPy's global RNG, which set_seed presumably seeded.
X_std = X_std[np.random.permutation(N)]

print("\nData statistics after standardization:")
print(f"Mean: {X_std.mean(axis=0).round(4)}")
print(f"Std: {X_std.std(axis=0).round(4)}")

## 2. Introduce Missing Values (MNAR Mechanism)

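Following the paper's experimental setup, we introduce **self-masking MNAR** missingness:
- In the first ⌊D/2⌋ features, every value above that feature's mean is removed.
- Because whether a value is missing depends on the (unobserved) value itself, the data are missing not at random (MNAR). Methods that assume data are missing at random (MAR) do not model this, which makes the scenario challenging.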

In [None]:
# Apply the self-masking MNAR mechanism. The helper returns:
#   X_nan    - the data with missing entries set to NaN,
#   X_filled - the data with missing entries filled in (fill rule defined by the helper),
#   mask     - 1 where a value is observed, 0 where it is missing.
X_nan, X_filled, mask = introduce_mnar_missing(X_std)

observed_frac_per_feature = mask.mean(axis=0)

print("Missing value statistics:")
print(f"Total missing rate: {(1 - mask.mean()):.2%}")
print("Missing rate per feature:")
for feat, observed_frac in enumerate(observed_frac_per_feature):
    print(f"  Feature {feat}: {1 - observed_frac:.2%}")

In [None]:
# Two-panel overview of the simulated missingness:
#   left  - observed/missing map for the first rows,
#   right - missing rate of every feature against the overall average.
fig, (ax_pattern, ax_rates) = plt.subplots(1, 2, figsize=(14, 5))

# Left panel: at most the first 100 rows of the mask.
n_show = min(100, mask.shape[0])
im = ax_pattern.imshow(mask[:n_show], aspect='auto', cmap='RdYlGn', interpolation='nearest')
ax_pattern.set_xlabel('Feature')
ax_pattern.set_ylabel('Sample')
ax_pattern.set_title(f'Missing Pattern (1=observed, 0=missing)\nFirst {n_show} samples')
fig.colorbar(im, ax=ax_pattern)

# Right panel: per-feature missing rate, with the average as a dashed line.
missing_rates = 1 - mask.mean(axis=0)
avg_rate = missing_rates.mean()
ax_rates.bar(range(len(missing_rates)), missing_rates)
ax_rates.set_xlabel('Feature Index')
ax_rates.set_ylabel('Missing Rate')
ax_rates.set_title('Missing Rate by Feature')
ax_rates.axhline(y=avg_rate, color='r', linestyle='--', label=f'Mean: {avg_rate:.2%}')
ax_rates.legend()

plt.tight_layout()
plt.show()

## 3. Create Datasets and DataLoaders

We split the (already shuffled) rows 80/20 into training and validation sets.

In [None]:
# The rows were shuffled above, so a head/tail cut gives a random split:
# the first 80% of rows train, the remaining rows validate.
train_ratio = 0.8
n_train = int(N * train_ratio)
train_rows = slice(None, n_train)
val_rows = slice(n_train, None)

# Each split keeps three aligned views: filled inputs, observation mask,
# and the complete (pre-masking) standardized values.
X_train_filled, mask_train, X_train_original = X_filled[train_rows], mask[train_rows], X_std[train_rows]
X_val_filled, mask_val, X_val_original = X_filled[val_rows], mask[val_rows], X_std[val_rows]

print(f"Training samples: {n_train}")
print(f"Validation samples: {N - n_train}")

In [None]:
def _as_float_tensor(array):
    """Return a new float32 torch tensor holding a copy of ``array``."""
    return torch.tensor(array, dtype=torch.float32)


# Tensors per split, in (filled inputs, observation mask, ground truth) order.
X_train_filled_t, mask_train_t, X_train_original_t = (
    _as_float_tensor(arr) for arr in (X_train_filled, mask_train, X_train_original)
)
X_val_filled_t, mask_val_t, X_val_original_t = (
    _as_float_tensor(arr) for arr in (X_val_filled, mask_val, X_val_original)
)

# Every batch yields an (x_filled, mask, x_original) triple.
train_dataset = TensorDataset(X_train_filled_t, mask_train_t, X_train_original_t)
val_dataset = TensorDataset(X_val_filled_t, mask_val_t, X_val_original_t)

# Reshuffle training batches each epoch; keep validation order fixed.
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

for split, loader in (("training", train_loader), ("validation", val_loader)):
    print(f"Number of {split} batches: {len(loader)}")

## 4. Create and Train the not-MIWAE Model

We train not-MIWAE with the `selfmasking_known` missing process, which encodes prior knowledge of the mechanism: each feature's missingness depends only on its own value, and higher values are more likely to be missing.

In [None]:
# Hyperparameters shared by the not-MIWAE model and the MIWAE baseline.
latent_dim = D - 1  # Following the paper's setup
hidden_dim = 128
n_samples = 20      # Number of importance samples
missing_process = 'selfmasking_known'  # Key: we use the known self-masking mechanism

notmiwae = NotMIWAE(
    input_dim=D,
    latent_dim=latent_dim,
    hidden_dim=hidden_dim,
    n_samples=n_samples,
    out_dist='gauss',
    missing_process=missing_process,
)

n_params = sum(param.numel() for param in notmiwae.parameters())
summary_lines = [
    "not-MIWAE Model:",
    f"  Input dimension: {D}",
    f"  Latent dimension: {latent_dim}",
    f"  Hidden dimension: {hidden_dim}",
    f"  Number of importance samples: {n_samples}",
    f"  Missing process: {missing_process}",
    f"\nTotal parameters: {n_params:,}",
]
print("\n".join(summary_lines))

In [None]:
# Trainer for not-MIWAE. TensorBoard logs go to ./runs and checkpoints to
# ./checkpoints. original_data_available=True asks the trainer to track
# imputation RMSE every epoch.
trainer_kwargs = dict(
    lr=1e-3,
    device=device,
    log_dir='./runs',
    checkpoint_dir='./checkpoints',
    original_data_available=True,  # Track imputation RMSE each epoch
    rmse_n_samples=50,             # Fewer samples for speed during training
)
trainer_notmiwae = Trainer(model=notmiwae, **trainer_kwargs)

print(f"Trainer created. Device: {trainer_notmiwae.device}")
print(f"TensorBoard logs will be saved to: {trainer_notmiwae.log_dir}")
print(f"Tracking imputation RMSE during training: {trainer_notmiwae.original_data_available}")

In [None]:
# Trace the model once on a real training batch and write the graph to
# TensorBoard, where it appears under the "Graphs" tab.
# The batch inputs (x_sample, s_sample) are reused later for the MIWAE graph;
# the ground-truth element of the triple is not needed here.
x_sample, s_sample, _ = next(iter(train_loader))
trainer_notmiwae.log_model_graph(x_sample, s_sample)


In [None]:
# Train not-MIWAE for up to 200 epochs. save_best / early_stopping_patience
# presumably keep the best-validation checkpoint and stop after 30 epochs
# without improvement (semantics defined by Trainer — TODO confirm).
banner = "=" * 60
print("Training not-MIWAE...")
print(banner)

fit_config = dict(
    n_epochs=200,
    log_interval=20,
    save_best=True,
    early_stopping_patience=30,
)
history_notmiwae = trainer_notmiwae.train(
    train_loader=train_loader,
    val_loader=val_loader,
    checkpoint_name='notmiwae_best.pt',
    **fit_config,
)

In [None]:
# Plot the not-MIWAE training history: loss, ELBO and (if tracked) imputation RMSE.


def _plot_train_val(ax, history, key, ylabel, title):
    """Draw ``history['train_<key>']`` on ``ax``, plus ``history['val_<key>']``
    when that entry exists and is non-empty, then label, title and grid the axes.

    Args:
        ax: Matplotlib axes to draw on.
        history: dict returned by ``Trainer.train``.
        key: metric suffix, e.g. 'loss', 'elbo', 'rmse'.
        ylabel: y-axis label.
        title: axes title.
    """
    ax.plot(history[f'train_{key}'], label='Train')
    if history.get(f'val_{key}'):
        ax.plot(history[f'val_{key}'], label='Validation')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)


fig, axes = plt.subplots(1, 3, figsize=(15, 4))

_plot_train_val(axes[0], history_notmiwae, 'loss', 'Loss (-ELBO)', 'not-MIWAE Loss')
_plot_train_val(axes[1], history_notmiwae, 'elbo', 'ELBO', 'not-MIWAE ELBO')

if history_notmiwae.get('train_rmse'):
    _plot_train_val(axes[2], history_notmiwae, 'rmse', 'RMSE', 'not-MIWAE Imputation RMSE')
else:
    # No RMSE was recorded: hide the panel instead of leaving an empty frame.
    axes[2].set_visible(False)

plt.tight_layout()
plt.savefig('notmiwae_training.png', dpi=150)
plt.show()

# Print final RMSE
if history_notmiwae.get('val_rmse'):
    print(f"\nFinal Validation RMSE: {history_notmiwae['val_rmse'][-1]:.5f}")

## 5. Train the Standard MIWAE for Comparison

For comparison, we also train standard MIWAE with the same hyperparameters. It does not model the missing process, so it implicitly treats the data as missing at random.

In [None]:
# MIWAE baseline: the same hyperparameters as not-MIWAE, but no
# missing-process model.
miwae = MIWAE(
    input_dim=D,
    latent_dim=latent_dim,
    hidden_dim=hidden_dim,
    n_samples=n_samples,
    out_dist='gauss',
)

n_params_miwae = sum(param.numel() for param in miwae.parameters())
print("MIWAE Model (no missing process):")
print(f"  Total parameters: {n_params_miwae:,}")

In [None]:
# Trainer for MIWAE with the same settings used for not-MIWAE, including
# per-epoch imputation-RMSE tracking.
trainer_miwae = Trainer(
    model=miwae,
    lr=1e-3,
    device=device,
    log_dir='./runs',
    checkpoint_dir='./checkpoints',
    original_data_available=True,
    rmse_n_samples=50,
)

# Trace the MIWAE graph on the same batch used for the not-MIWAE graph.
trainer_miwae.log_model_graph(x_sample, s_sample)

print("\nTraining MIWAE...")
print("=" * 60)

miwae_fit_config = dict(
    n_epochs=200,
    log_interval=20,
    save_best=True,
    early_stopping_patience=30,
)
history_miwae = trainer_miwae.train(
    train_loader=train_loader,
    val_loader=val_loader,
    checkpoint_name='miwae_best.pt',
    **miwae_fit_config,
)

## 6. Evaluate Imputation Performance

Now we compare the imputation RMSE of both models, together with a mean-imputation baseline, on the training split.

In [None]:
# Restore the best checkpoint written by each training run.
for run_trainer, ckpt_name in (
    (trainer_notmiwae, 'notmiwae_best.pt'),
    (trainer_miwae, 'miwae_best.pt'),
):
    run_trainer.load_checkpoint(ckpt_name)

# Importance samples for the final imputations: far more than the 50 used
# for per-epoch tracking, for better imputations.
n_imp_samples = 1000

print("Computing imputation RMSE...")
print("=" * 60)

In [None]:
# Impute the training split with not-MIWAE and score it against the
# complete (pre-masking) values.
notmiwae_inputs = dict(
    x_original=torch.tensor(X_train_original),
    x_filled=torch.tensor(X_train_filled),
    mask=torch.tensor(mask_train),
)
rmse_notmiwae, X_imputed_notmiwae = imputation_rmse(
    model=notmiwae,
    n_samples=n_imp_samples,
    device=device,
    **notmiwae_inputs,
)

print(f"\nnot-MIWAE Imputation RMSE: {rmse_notmiwae:.5f}")

In [None]:
# Impute the same training split with the MIWAE baseline for a like-for-like
# comparison.
miwae_inputs = dict(
    x_original=torch.tensor(X_train_original),
    x_filled=torch.tensor(X_train_filled),
    mask=torch.tensor(mask_train),
)
rmse_miwae, X_imputed_miwae = imputation_rmse(
    model=miwae,
    n_samples=n_imp_samples,
    device=device,
    **miwae_inputs,
)

print(f"\nMIWAE Imputation RMSE: {rmse_miwae:.5f}")

In [None]:
# Mean-imputation baseline: fill each missing entry with the mean of the
# observed (non-NaN) values in its column, using the training rows only.
# Implemented with NumPy rather than sklearn's SimpleImputer for two reasons:
#   - SimpleImputer silently drops all-NaN columns by default, which would
#     misalign its output with the N x D masks indexed below;
#   - it avoids a mid-notebook dependency.
X_nan_train = X_nan[:n_train].astype(np.float64)
observed = ~np.isnan(X_nan_train)
obs_count = observed.sum(axis=0)
obs_sum = np.where(observed, X_nan_train, 0.0).sum(axis=0)
# A column with no observed values falls back to 0, the centre of the
# standardized data (assumes standardize() zero-centres each column).
col_mean = np.divide(obs_sum, obs_count, out=np.zeros_like(obs_sum), where=obs_count > 0)
X_imputed_mean = np.where(observed, X_nan_train, col_mean)
assert X_imputed_mean.shape == X_train_original.shape, "mean imputation changed the data shape"

# RMSE over the entries that were actually missing (mask: 1=observed, 0=missing).
missing_mask = (1 - mask_train).astype(bool)
rmse_mean = np.sqrt(np.mean((X_train_original[missing_mask] - X_imputed_mean[missing_mask])**2))

print(f"Mean Imputation RMSE: {rmse_mean:.5f}")

In [None]:
# Side-by-side RMSE table. The labels carry their own padding so the "|"
# column lines up.
results = [
    ("Mean Imputation      ", rmse_mean),
    ("MIWAE                ", rmse_miwae),
    ("not-MIWAE            ", rmse_notmiwae),
]

print("\n" + "=" * 60)
print("IMPUTATION RESULTS SUMMARY")
print("=" * 60)
print("Method               | RMSE")
print("-" * 40)
for padded_label, rmse in results:
    print(f"{padded_label}| {rmse:.5f}")
print("-" * 40)

# Relative RMSE reduction of not-MIWAE versus MIWAE, in percent.
improvement = (rmse_miwae - rmse_notmiwae) / rmse_miwae * 100
print(f"\nnot-MIWAE improvement over MIWAE: {improvement:.2f}%")

## 7. Interpret the Missing Process

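A distinctive feature of not-MIWAE is that the learned missing process can be inspected. In the self-masking model, each feature $d$ has its own weight $W_d$ and offset $b_d$. We can therefore see how strongly each feature's value drives its own probability of being missing.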

In [None]:
# Column names of the feature matrix (everything except the 'quality' target).
feature_names = data.drop('quality', axis=1).columns.tolist()

# A fresh not-MIWAE with identical hyperparameters plus feature names, so the
# interpretation report can refer to features by name.
notmiwae_interp = NotMIWAE(
    input_dim=D,
    latent_dim=latent_dim,
    hidden_dim=hidden_dim,
    n_samples=n_samples,
    out_dist='gauss',
    missing_process='selfmasking_known',
    feature_names=feature_names,
)

# Copy the trained parameters into the new model and switch to eval mode.
notmiwae_interp.load_state_dict(notmiwae.state_dict())
notmiwae_interp.eval()

print(notmiwae_interp.interpret_missing_process())

In [None]:
# Visualize the learned self-masking coefficients, one bar per feature.
import torch.nn.functional as F
from matplotlib.patches import Patch

# Bars whose W exceeds this value are drawn in red (stronger self-masking).
STRONG_W_THRESHOLD = 0.5

fig, ax = plt.subplots(figsize=(12, 5))

# For selfmasking_known there is one (W, b) pair per feature, and the missing
# process uses logit(p(s=1|x)) = -W * (x - b). The raw W parameter is passed
# through softplus here, so every plotted weight is >= 0. This presumably
# mirrors what the model does internally — TODO confirm.
# The parameters live on `missing_model` (not `missing_process`).
# np.atleast_1d keeps a 1-D array even when squeeze() would collapse a
# single-feature model to a scalar.
W = np.atleast_1d(F.softplus(notmiwae_interp.missing_model.W).detach().squeeze().cpu().numpy())
b = np.atleast_1d(notmiwae_interp.missing_model.b.detach().squeeze().cpu().numpy())

x_pos = np.arange(D)
colors = ['red' if w > STRONG_W_THRESHOLD else 'blue' for w in W]

ax.bar(x_pos, W, color=colors, alpha=0.7)
# Reference line at the colour threshold. Since W >= 0, a y=0 line would add nothing.
ax.axhline(y=STRONG_W_THRESHOLD, color='gray', linestyle='--', linewidth=0.8)

ax.set_xticks(x_pos)
ax.set_xticklabels(feature_names, rotation=45, ha='right')
ax.set_xlabel('Feature')
ax.set_ylabel('Missing Process Weight (W)')
ax.set_title('Missing Process Coefficients\n(Higher W = stronger self-masking effect)')
ax.grid(True, alpha=0.3, axis='y')
# Proxy patches explain the colour coding. This also works when one colour
# group is empty.
ax.legend(handles=[
    Patch(color='red', alpha=0.7, label=f'W > {STRONG_W_THRESHOLD}'),
    Patch(color='blue', alpha=0.7, label=f'W <= {STRONG_W_THRESHOLD}'),
])

plt.tight_layout()
plt.savefig('missing_process_interpretation.png', dpi=150)
plt.show()

print("\nNote: In self-masking MNAR, the model learns: logit(p(s=1|x)) = -W*(x-b)")
print("Larger W means higher values are more likely to be missing.")
print(f"The first {D//2} features were made MNAR in our simulation.")

## 8. Visualize Imputation Results

In [None]:
# Grouped bars for up to 50 training rows whose value of `feature_idx` is
# missing. Each group shows the true value, the pre-filled value and both
# model imputations.
feature_idx = 0  # First feature (has MNAR missing values)
missing_idx = mask_train[:, feature_idx] == 0

fig, ax = plt.subplots(figsize=(12, 6))

missing_rows = np.where(missing_idx)[0]
n_show = min(50, missing_idx.sum())
sample_indices = missing_rows[:n_show]

x_pos = np.arange(n_show)
width = 0.2  # Narrower bars to fit 4 groups

# (source array, legend label, colour) per bar; groups are centred on x_pos.
bar_series = [
    (X_train_original, 'Original (true)', 'C0'),
    (X_train_filled, 'Filled (mean)', 'C1'),
    (X_imputed_miwae, 'MIWAE imputed', 'C2'),
    (X_imputed_notmiwae, 'not-MIWAE imputed', 'C3'),
]
for slot, (values, label, color) in enumerate(bar_series):
    offset = (slot - 1.5) * width
    ax.bar(x_pos + offset, values[sample_indices, feature_idx], width, label=label, alpha=0.8, color=color)

ax.set_xlabel('Sample Index')
ax.set_ylabel('Value')
ax.set_title(f'Imputation Comparison for Feature {feature_idx}\n(showing {n_show} missing values)')
ax.legend()
ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.savefig('imputation_comparison.png', dpi=150)
plt.show()

In [None]:
# Compare imputation methods on the missing entries of a single feature:
#   (1) imputed vs. true values,
#   (2) signed-error histograms with the mean bias in the legend,
#   (3) box plots of absolute error.
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

feature_idx = 0  # Feature with MNAR missing values
missing_idx = mask_train[:, feature_idx] == 0
true_vals = X_train_original[missing_idx, feature_idx]

# Imputed values per method. Dict order fixes colours, legend entries and box order.
imputed_by_method = {
    'Mean': X_imputed_mean[missing_idx, feature_idx],
    'MIWAE': X_imputed_miwae[missing_idx, feature_idx],
    'not-MIWAE': X_imputed_notmiwae[missing_idx, feature_idx],
}
# Signed error (imputed - true); a positive value means the method over-estimates.
errors_by_method = {name: vals - true_vals for name, vals in imputed_by_method.items()}

# Panel 1: scatter of imputed vs. original values, plus the identity line.
for name, vals in imputed_by_method.items():
    axes[0].scatter(true_vals, vals, alpha=0.5, label=name)
lim = [true_vals.min(), true_vals.max()]
axes[0].plot(lim, lim, 'k--', label='Perfect')
axes[0].set_xlabel('Original Value')
axes[0].set_ylabel('Imputed Value')
axes[0].set_title(f'Imputation Quality (Feature {feature_idx})')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Panel 2: error distributions.
for name, err in errors_by_method.items():
    axes[1].hist(err, bins=30, alpha=0.5, label=f'{name} (bias={err.mean():.3f})')
axes[1].axvline(x=0, color='k', linestyle='--')
axes[1].set_xlabel('Imputation Error')
axes[1].set_ylabel('Count')
axes[1].set_title('Error Distribution')
axes[1].legend()

# Panel 3: absolute-error box plots. Boxplot's `labels=` argument is deprecated
# since Matplotlib 3.9 (renamed `tick_labels`). Setting the tick labels
# afterwards works on both old and new versions.
axes[2].boxplot([np.abs(err) for err in errors_by_method.values()])
axes[2].set_xticklabels(list(errors_by_method))
axes[2].set_ylabel('Absolute Error')
axes[2].set_title('Absolute Error Distribution')
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('imputation_analysis.png', dpi=150)
plt.show()