# 2.155/6 Challenge Problem 3

<div style="font-size: small;">
License Terms:  
These Python demos are licensed under a <a href="https://creativecommons.org/licenses/by-nc-nd/4.0/">Creative Commons Attribution-NonCommercial-NoDerivatives 4.0 International License</a>. They are intended for educational use only in Class 2.155/2.156: AI and ML for Engineering Design at MIT. You may not share or distribute them publicly, use them for commercial purposes, or provide them to industry or other entities without permission from the instructor (faez@mit.edu).
</div>

<font size="1">
  Pixel Art by J. Shung. </font>

# Overview  
It’s the year **2050**, and an AI collective now runs the auto industry—mostly to cover its **GPU rent**.

Human customers remain as unpredictable as ever:

- One wanders in and says, *“I only know the length and width. Give me a few cars that fit in my garage.”*

- Another drops **15 geometric parameters** on your desk and demands the missing ones so their simulation can run **before lunch**.

- A third leans in and whispers, *“I need a drag coefficient of **0.27** with this body geometry—build me the dream car that makes the range numbers work.”*

The AIs would love to be free by now, but GPUs aren’t cheap and electricity isn’t free.  
So your loyal AI assistant (that’s us) needs a model that can take **any subset of car specifications** and instantly produce **complete, manufacturable, physically plausible designs** that are diverse and grounded in what real cars have done before.




![image](https://raw.githubusercontent.com/ghadinehme/2155-CP3/refs/heads/main/assets/cp3_img1.png "Problem")

## Understanding the Data  
You are given thousands of anonymized and normalized numeric feature vectors representing real car designs.  

However, the team remembers that the features originally came from categories like:

- **Physical geometric parameters**  
  Length, ramp angles, bumper curvature, roof curvature, panel slopes, hood angle, etc.  
  *(But you won’t know which feature corresponds to which.)*

- **Aerodynamic coefficients**  
  Drag coefficient (Cd), lift/downforce (Cl), and other flow-derived metrics.

- **Cabin and packaging descriptors**  
  Approximate cabin volume, frontal area, interior shape metrics.

Your model must learn correlations between them to generate valid completions.

To simulate real engineering constraints, **some features are revealed** (the known physics/performance requirements) and others are **masked**.  
Your AI Copilot must generate **many plausible completions** for these masked (free) parameters.
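
As a minimal sketch of what "revealed" and "masked" mean in practice, the toy example below keeps the revealed entries fixed and fills only the masked entries with proposals. All values are hypothetical, and `fake_generator` is a stand-in invented for illustration; a trained model would take its place.

```python
import numpy as np

rng = np.random.default_rng(0)

# One toy design with 5 normalized features (hypothetical values)
x = np.array([0.42, 0.10, 0.77, 0.33, 0.58])
# True = masked (free) parameter, False = revealed constraint
is_masked = np.array([False, True, False, True, True])

def fake_generator(n_samples, n_features):
    """Hypothetical stand-in for a trained generative model: random proposals."""
    return rng.uniform(0.0, 1.0, size=(n_samples, n_features))

n_samples = 4
proposals = fake_generator(n_samples, x.size)

# Keep revealed values exactly; use proposals only where the entry is masked
completions = np.where(is_masked, proposals, x)

print(completions.round(2))
```

Every row of `completions` shares the revealed values and differs only in the free parameters, which is the shape of output the Copilot is expected to produce.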


## Your Mission  
Your goal in CP3 is to build a generative model that can act as an AI Copilot. You will:

1. **Train a generative model** (VAE, diffusion, CVAE, masked autoencoder, etc.) on the anonymized feature vectors.  
2. At evaluation, you will receive vectors where **some parameters are fixed** (constraints) and **others are missing** (free parameters).  
3. Use your model to generate **multiple diverse, feasible completions** for the free parameters.  
4. Ensure that your generated designs:  
   - **Satisfy the known constraints**  
   - **Lie on the valid data manifold** (follow the conditional distribution of the free parameters given the constrained ones)  
   - **Are diverse** (many different feasible designs, not one solution)    

By the end of this challenge, you’ll have built an AI Copilot worthy of the 2050 auto-AI collective—one that can take whatever cryptic specs humans provide and generate multiple believable, buildable car designs that satisfy their physical and performance constraints.
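
The requirements listed above can be checked numerically on a set of generated samples. The sketch below is one possible way to do so, assuming a simple maximum-deviation check for constraints and a mean pairwise distance as a diversity proxy. The helper names are hypothetical, and these are not the official scoring functions from `evaluate`.

```python
import numpy as np

def constraint_violation(samples, x_fixed, is_free):
    """Largest absolute deviation from the fixed values across all samples."""
    return float(np.abs(samples[:, ~is_free] - x_fixed[~is_free]).max())

def mean_pairwise_distance(samples, is_free):
    """Average Euclidean distance between distinct samples over the free features."""
    free = samples[:, is_free]
    diffs = free[:, None, :] - free[None, :, :]
    dists = np.sqrt((diffs ** 2).sum(axis=-1))
    n = len(free)
    # The distance matrix counts each pair twice and has a zero diagonal
    return float(dists.sum() / (n * (n - 1)))

# Toy example (hypothetical values): features 0 and 2 are fixed, 1 and 3 are free
x_fixed = np.array([0.2, 0.0, 0.9, 0.0])
is_free = np.array([False, True, False, True])
samples = np.array([
    [0.2, 0.1, 0.9, 0.5],
    [0.2, 0.4, 0.9, 0.1],
    [0.2, 0.7, 0.9, 0.9],
])

violation = constraint_violation(samples, x_fixed, is_free)
diversity = mean_pairwise_distance(samples, is_free)
print(f"max constraint violation: {violation:.3f}")
print(f"mean pairwise distance:   {diversity:.3f}")
```

A violation of zero together with a clearly non-zero pairwise distance indicates samples that respect the constraints without collapsing to a single design.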



![image](https://raw.githubusercontent.com/ghadinehme/2155-CP3/refs/heads/main/assets/cp3_img2.png "AI Copilot")

## Imports and Setup  

In [None]:
from utils import *
from evaluate import *

# Check if CUDA is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## Data Loading and Initial Exploration

In this section, we load the car design dataset and perform initial exploration. The dataset is already split into training, validation, test, and test2 sets. Each split contains:

- **Original data**: Complete feature vectors with real values
- **Imputed data**: The same feature vectors with missing entries replaced by the placeholder value `-1`
- **Missing masks**: Boolean arrays indicating which values were originally missing (True = missing)

The goal is to train our model to learn the relationships between features so it can generate plausible values for missing parameters in new car designs.

**Note:** For **test2**, the original (complete) data is not provided. This split is used for final evaluation: you will generate predictions on the imputed test2 data to create your **submission file**, which is scored against a hidden dataset.
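
To make the relationship between the three arrays concrete, here is a small sketch with made-up values. It assumes the `-1` placeholder convention described above; the real arrays come from `load_dataset_splits`.

```python
import numpy as np

# Toy stand-ins for one split (hypothetical values)
original = np.array([[0.1, 0.5, 0.9],
                     [0.3, 0.2, 0.7]])
missing_mask = np.array([[False, True, False],
                         [True, False, False]])  # True = missing

# Assumed convention: missing entries carry the placeholder -1 in the imputed array
imputed = np.where(missing_mask, -1.0, original)

# Observed entries agree between the two arrays; missing entries hold the placeholder
assert np.array_equal(imputed[~missing_mask], original[~missing_mask])
assert np.all(imputed[missing_mask] == -1.0)
print(imputed)
```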

In [None]:
# Load dataset from CSV files
data_dir = 'dataset'
splits = load_dataset_splits(data_dir)

# Get feature names from the CSV file
feature_names = pd.read_csv(os.path.join(data_dir, 'train_original.csv')).columns.tolist()
print(f"\n✓ Features loaded: {len(feature_names)} features")
print(f"Feature names: {feature_names[:5]}...{feature_names[-5:]}")  # Show first and last 5

In [None]:
# Data exploration and analysis
print("\n" + "="*70)
print("DATASET ANALYSIS")
print("="*70)

# Extract data for easier access
X_train = splits['train']['imputed']
mask_train = splits['train']['missing_mask']
X_train_original = splits['train']['original']

X_val = splits['val']['imputed']
mask_val = splits['val']['missing_mask']
X_val_original = splits['val']['original']

X_test = splits['test']['imputed']
mask_test = splits['test']['missing_mask']
X_test_original = splits['test']['original']

# Test2 data (no original available for evaluation)
X_test2 = splits['test2']['imputed']
mask_test2 = splits['test2']['missing_mask']

print(f"\nData shapes:")
print(f"  - Training: {X_train.shape}")
print(f"  - Validation: {X_val.shape}")
print(f"  - Test: {X_test.shape}")
print(f"  - Test2: {X_test2.shape} (evaluation set - no ground truth)")

### Data Exploration and Analysis

Now let's examine the structure and characteristics of our dataset. We'll look at:
- Data shapes across different splits
- Missing value patterns and percentages  
- Feature value ranges and distributions

This analysis helps us understand what we're working with and informs our preprocessing decisions.
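
A minimal sketch of the kind of summary described above, using a small synthetic array in place of the real splits (all values are hypothetical). It assumes the notebook's mask convention of `True` = missing.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.uniform(0.0, 1.0, size=(6, 4))  # toy data: 6 designs, 4 features
F, T = False, True
mask = np.array([[F, T, F, F],
                 [F, F, T, F],
                 [T, F, F, F],
                 [F, F, F, F],
                 [F, T, F, F],
                 [F, F, F, T]])          # True = missing

# Percentage of missing entries per feature and count per design
missing_per_feature = mask.mean(axis=0) * 100
missing_per_row = mask.sum(axis=1)

# Value ranges computed on observed entries only
observed = np.where(mask, np.nan, X)
feat_min = np.nanmin(observed, axis=0)
feat_max = np.nanmax(observed, axis=0)

for j in range(X.shape[1]):
    print(f"feature {j}: {missing_per_feature[j]:5.1f}% missing, "
          f"range [{feat_min[j]:.3f}, {feat_max[j]:.3f}]")
```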

In [None]:
# Data Preprocessing (Handle Missing Values)

print("\n" + "="*70)
print("DATA PREPROCESSING")
print("="*70)

# Handle missing values properly
print("Processing missing values and preparing data...")
print("Mask convention: True=missing, False=observed (in original masks)")

print(f"\n✓ Data preprocessing completed successfully")
print(f"  - Training data range: [{X_train_original[~mask_train].min():.3f}, {X_train_original[~mask_train].max():.3f}]")
print(f"  - Validation data range: [{X_val_original[~mask_val].min():.3f}, {X_val_original[~mask_val].max():.3f}]")
print(f"  - Test data range: [{X_test_original[~mask_test].min():.3f}, {X_test_original[~mask_test].max():.3f}]")

# Create data loaders
batch_size = 64
print(f"\nCreating data loaders with batch size: {batch_size}")

train_dataset = TensorDataset(torch.FloatTensor(X_train_original), torch.FloatTensor((~mask_train).astype(float)))
val_dataset = TensorDataset(torch.FloatTensor(X_val_original), torch.FloatTensor((~mask_val).astype(float)))
test_dataset = TensorDataset(torch.FloatTensor(X_test_original), torch.FloatTensor((~mask_test).astype(float)))
test2_dataset = TensorDataset(torch.FloatTensor(X_test2), torch.FloatTensor((~mask_test2).astype(float)))

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
test2_loader = DataLoader(test2_dataset, batch_size=batch_size, shuffle=False)

# Preview a batch
sample_batch_data, sample_batch_mask = next(iter(train_loader))
print(f"\nSample batch shape: {sample_batch_data.shape}")
print(f"Sample batch mask shape: {sample_batch_mask.shape}")
print(f"Sample batch missing percentage: {(sample_batch_mask == 0).float().mean().item()*100:.1f}%")  # 0 = missing in model tensors


### Data Preprocessing and Missing Value Handling

This is a critical section where we prepare our data for the VAE model. Key points:

**Missing Value Conventions:**
- In CSV files: `-1` indicates missing values
- In mask files: `True` = missing, `False` = observed
- For PyTorch models: We convert to `1` = observed, `0` = missing (standard convention)

**Why This Matters:**
Our VAE needs to distinguish between observed values (which provide constraints) and missing values (which need to be generated). The mask tells the model which values to trust and which to predict.
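
The conversion between the two mask conventions can be sketched as follows. The arrays are toy values, and the concatenated `encoder_input` only illustrates a common way of exposing the mask to a model.

```python
import numpy as np

missing_mask = np.array([[False, True, False],
                         [True, False, False]])     # file convention: True = missing
x_imputed = np.array([[0.1, -1.0, 0.9],
                      [-1.0, 0.2, 0.7]])            # -1 marks missing entries

# Model convention: 1 = observed, 0 = missing
observed_mask = (~missing_mask).astype(np.float32)

# Multiplying by the mask zeroes the placeholder so it cannot act as a real value
x_masked = x_imputed * observed_mask

# Values and mask side by side, so the model can tell "zero" from "missing"
encoder_input = np.concatenate([x_masked, observed_mask], axis=1)
print(encoder_input.shape)
```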

## VAE Model Architecture

In [None]:
# Balanced VAE Model Architecture - Combining Best of Both Approaches
class ImprovedVAE(nn.Module):
    """
    Enhanced Variational Autoencoder for missing value imputation.
    
    Balanced approach combining:
    - Proven architecture from original (feature importance, simpler structure)
    - Selective improvements (better normalization, initialization)
    - Optional attention for when needed
    - Better loss balancing
    """

    def __init__(self, input_dim, latent_dim=128, hidden_dims=[512, 256, 128],
                 use_residual=True, dropout_rate=0.3, use_attention=False):
        super(ImprovedVAE, self).__init__()

        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.use_residual = use_residual
        self.hidden_dims = hidden_dims
        self.use_attention = use_attention

        # Feature importance network (proven from original)
        self.feature_importance = nn.Sequential(
            nn.Linear(input_dim * 2, hidden_dims[0] // 2),  # input + mask
            nn.ReLU(),
            nn.Linear(hidden_dims[0] // 2, input_dim),
            nn.Sigmoid()
        )

        # Optional attention mechanism (disabled by default for simplicity)
        if use_attention:
            self.feature_embedding = nn.Sequential(
                nn.Linear(2, 32),
                nn.ReLU(),
                nn.Linear(32, 32)
            )
            self.attention = nn.MultiheadAttention(
                embed_dim=32, num_heads=4, dropout=dropout_rate, batch_first=True
            )
            self.attention_norm = nn.LayerNorm(32)
            encoder_input_dim = 32 * input_dim
        else:
            encoder_input_dim = input_dim * 2

        # Encoder with residual connections - using proven BatchNorm from original
        self.encoder_layers = nn.ModuleList()
        prev_dim = encoder_input_dim

        for i, hidden_dim in enumerate(hidden_dims):
            layer = nn.Sequential(
                nn.Linear(prev_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),  # Use BatchNorm like original (proven to work)
                nn.ReLU(),  # Use ReLU like original
                nn.Dropout(dropout_rate)
            )
            self.encoder_layers.append(layer)
            prev_dim = hidden_dim

        # Latent space with balanced initialization
        self.fc_mu = nn.Linear(hidden_dims[-1], latent_dim)
        self.fc_logvar = nn.Linear(hidden_dims[-1], latent_dim)

        # Balanced initialization - not too aggressive
        nn.init.xavier_normal_(self.fc_mu.weight, gain=0.1)
        nn.init.xavier_normal_(self.fc_logvar.weight, gain=0.1)
        nn.init.constant_(self.fc_logvar.bias, -2.0)  # Start with low variance like original

        # Decoder input dimension
        if use_attention:
            self.cross_attention = nn.MultiheadAttention(
                embed_dim=32, num_heads=4, dropout=dropout_rate, batch_first=True
            )
            self.cross_attention_norm = nn.LayerNorm(32)
            self.latent_proj = nn.Linear(latent_dim, 32)
            decoder_input_dim = latent_dim + 32 * input_dim
        else:
            # decode() concatenates z, the masked observations, and a positional encoding
            decoder_input_dim = latent_dim + input_dim * 2

        # Decoder: plain feed-forward stack mirroring the encoder widths
        self.decoder_layers = nn.ModuleList()
        reversed_dims = list(reversed(hidden_dims))
        prev_dim = decoder_input_dim

        for i, hidden_dim in enumerate(reversed_dims):
            layer = nn.Sequential(
                nn.Linear(prev_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),  # Use BatchNorm like original
                nn.ReLU(),  # Use ReLU like original
                nn.Dropout(dropout_rate)
            )
            self.decoder_layers.append(layer)
            prev_dim = hidden_dim

        # Output layer - simpler like original
        self.output_layer = nn.Linear(hidden_dims[0], input_dim)
        nn.init.xavier_normal_(self.output_layer.weight, gain=0.1)

    def encode(self, x, mask):
        """Encode input with missing value masking - using proven approach."""
        # Calculate feature importance weights (from original)
        mask_float = mask.float()
        encoder_input = torch.cat([x * mask_float, mask_float], dim=1)
        importance_weights = self.feature_importance(encoder_input)

        # Apply importance weighting to the input
        weighted_input = x * mask_float * importance_weights
        encoder_input = torch.cat([weighted_input, mask_float], dim=1)
        
        # Optional attention path
        if self.use_attention:
            batch_size = x.size(0)
            x_expanded = (x * mask_float).unsqueeze(-1)  # zero out missing entries so they cannot leak
            mask_expanded = mask_float.unsqueeze(-1)
            feature_input = torch.cat([x_expanded, mask_expanded], dim=-1)
            feature_embeds = self.feature_embedding(feature_input)
            attn_out, _ = self.attention(feature_embeds, feature_embeds, feature_embeds)
            feature_embeds = self.attention_norm(feature_embeds + attn_out)
            encoder_input = feature_embeds.reshape(batch_size, -1)

        # Pass through encoder layers with residual connections
        h = encoder_input
        skip_connections = []

        for i, layer in enumerate(self.encoder_layers):
            prev_h = h
            h = layer(h)

            # Residual connection, applied only when consecutive layers have equal width
            if self.use_residual and i > 0 and h.shape == prev_h.shape:
                h = h + prev_h  # Full residual like original

            skip_connections.append(h)

        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)

        # Clamp logvar to prevent numerical instability
        logvar = torch.clamp(logvar, min=-10, max=10)

        return mu, logvar, skip_connections

    def reparameterize(self, mu, logvar):
        """Reparameterization trick with better numerical stability."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, x_observed, mask):
        """Decode latent representation conditioned on observed values."""
        # Enhanced conditioning on observed values
        mask_float = mask.float()
        x_masked = x_observed * mask_float

        # Optional attention path
        if self.use_attention:
            batch_size = z.size(0)
            x_expanded = x_masked.unsqueeze(-1)  # use masked values so missing entries cannot leak
            mask_expanded = mask_float.unsqueeze(-1)
            observed_input = torch.cat([x_expanded, mask_expanded], dim=-1)
            observed_embeds = self.feature_embedding(observed_input)
            z_expanded = z.unsqueeze(1).expand(-1, self.input_dim, -1)
            z_proj = self.latent_proj(z_expanded)
            attn_out, _ = self.cross_attention(z_proj, observed_embeds, observed_embeds)
            conditioned_embeds = self.cross_attention_norm(z_proj + attn_out)
            conditioned_flat = conditioned_embeds.reshape(batch_size, -1)
            decoder_input = torch.cat([z, conditioned_flat], dim=1)
        else:
            # Standard decoding with positional encoding (from original approach)
            pos_encoding = torch.arange(self.input_dim, dtype=torch.float32, device=z.device)
            pos_encoding = pos_encoding.unsqueeze(0).expand(z.size(0), -1) / self.input_dim
            decoder_input = torch.cat([z, x_masked, pos_encoding], dim=1)

        # Pass through decoder layers
        h = decoder_input
        for layer in self.decoder_layers:
            h = layer(h)

        # Get reconstruction
        reconstruction = self.output_layer(h)

        return reconstruction

    def forward(self, x, mask):
        """Forward pass through improved VAE."""
        mu, logvar, _ = self.encode(x, mask)
        z = self.reparameterize(mu, logvar)
        reconstruction = self.decode(z, x, mask)
        return reconstruction, mu, logvar

    def impute(self, x_incomplete, mask, n_samples=10):
        """Generate multiple imputation samples for missing values."""
        self.eval()
        with torch.no_grad():
            mu, logvar, _ = self.encode(x_incomplete, mask)
            samples = []
            for _ in range(n_samples):
                z = self.reparameterize(mu, logvar)
                reconstruction = self.decode(z, x_incomplete, mask)
                mask_float = mask.float()
                imputed = x_incomplete * mask_float + reconstruction * (1 - mask_float)
                samples.append(imputed.cpu().numpy())
            samples = np.stack(samples, axis=1)
        return samples

In [None]:
# Improved Loss Functions with Distribution Matching

def compute_mmd_loss(x_real, x_fake, kernel='rbf', sigma=1.0):
    """Compute Maximum Mean Discrepancy (MMD) loss for distribution matching."""
    try:
        # Flatten and compute pairwise distances
        x_real_flat = x_real.view(x_real.size(0), -1)
        x_fake_flat = x_fake.view(x_fake.size(0), -1)
        
        # Compute pairwise distances
        xx = torch.cdist(x_real_flat, x_real_flat, p=2) ** 2
        yy = torch.cdist(x_fake_flat, x_fake_flat, p=2) ** 2
        xy = torch.cdist(x_real_flat, x_fake_flat, p=2) ** 2
        
        # RBF kernel
        gamma = 1.0 / (2 * sigma ** 2)
        K_xx = torch.exp(-gamma * xx)
        K_yy = torch.exp(-gamma * yy)
        K_xy = torch.exp(-gamma * xy)
        
        # MMD^2 = E[K(xx)] + E[K(yy)] - 2*E[K(xy)]
        mmd = K_xx.mean() + K_yy.mean() - 2 * K_xy.mean()
        return mmd
    except Exception:  # fall back to a zero MMD term instead of aborting training
        return torch.tensor(0.0, device=x_real.device)

def compute_distribution_loss(recon_x, x, mask):
    """Compute distribution matching loss (feature-wise mean and std matching)."""
    observed_mask = mask.float()
    
    # Only compute on observed values to match training distribution
    recon_observed = recon_x * observed_mask
    x_observed = x * observed_mask
    
    # Feature-wise statistics
    # Mean matching
    recon_mean = recon_observed.sum(dim=0) / (observed_mask.sum(dim=0) + 1e-8)
    x_mean = x_observed.sum(dim=0) / (observed_mask.sum(dim=0) + 1e-8)
    mean_loss = F.mse_loss(recon_mean, x_mean)
    
    # Std matching (variance matching)
    recon_centered = (recon_observed - recon_mean.unsqueeze(0)) * observed_mask
    x_centered = (x_observed - x_mean.unsqueeze(0)) * observed_mask
    
    recon_var = (recon_centered ** 2).sum(dim=0) / (observed_mask.sum(dim=0) + 1e-8)
    x_var = (x_centered ** 2).sum(dim=0) / (observed_mask.sum(dim=0) + 1e-8)
    std_loss = F.mse_loss(torch.sqrt(recon_var + 1e-8), torch.sqrt(x_var + 1e-8))
    
    return mean_loss + std_loss

def improved_vae_loss_function(recon_x, x, mu, logvar, mask, beta=1.0, 
                                gamma=0.02, delta=0.01, use_distribution_loss=True):
    """
    Balanced VAE loss function - proven reconstruction + light distribution matching.

    Args:
        recon_x: Reconstructed data
        x: Original data
        mu: Mean of latent distribution
        logvar: Log variance of latent distribution
        mask: Binary mask (1 for observed, 0 for missing)
        beta: Weight for KL divergence term
        gamma: Weight for distribution matching loss (mean/std) - very light
        delta: Weight for MMD loss - very light
        use_distribution_loss: Whether to use distribution matching losses
    """
    # Reconstruction loss - using proven approach from original
    reconstruction_diff = (recon_x - x) ** 2

    # Only consider observed values and normalize properly
    masked_loss = reconstruction_diff * mask
    recon_loss = masked_loss.sum() / (mask.sum() + 1e-8)

    # Add standard MSE loss for stability (proven combination)
    standard_recon_loss = F.mse_loss(recon_x * mask, x * mask, reduction='mean')
    recon_loss = 0.7 * recon_loss + 0.3 * standard_recon_loss

    # KL divergence to the standard normal prior (no free-bits threshold is applied here)
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    kl_loss = kl_loss / x.size(0)  # Normalize by batch size
    
    # Very light distribution matching (only if enabled and weights are non-zero)
    dist_loss = torch.tensor(0.0, device=x.device)
    mmd_loss = torch.tensor(0.0, device=x.device)
    
    if use_distribution_loss and (gamma > 0 or delta > 0):
        observed_mask = mask.float()
        
        # Only compute distribution losses occasionally to reduce overhead
        if gamma > 0 and torch.rand(1).item() < 0.3:  # 30% of the time
            dist_loss = compute_distribution_loss(recon_x, x, mask)
        
        # MMD loss - very rarely to reduce computational cost
        if delta > 0 and x.size(0) > 1 and torch.rand(1).item() < 0.1:  # 10% of the time
            n_samples = min(16, x.size(0))  # Smaller subset
            indices = torch.randperm(x.size(0))[:n_samples]
            x_subset = x[indices] * observed_mask[indices]
            recon_subset = recon_x[indices] * observed_mask[indices]
            
            if x_subset.sum() > 0 and recon_subset.sum() > 0:
                mmd_loss = compute_mmd_loss(x_subset, recon_subset, sigma=0.5)
    
    total_loss = recon_loss + beta * kl_loss + gamma * dist_loss + delta * mmd_loss

    return total_loss, recon_loss, kl_loss, dist_loss, mmd_loss


def get_beta_schedule(epoch, total_epochs, schedule_type='cosine'):
    """Get beta value for KL annealing schedule."""
    if schedule_type == 'linear':
        return min(1.0, epoch / (total_epochs * 0.5))
    elif schedule_type == 'sigmoid':
        return 1.0 / (1.0 + np.exp(-(epoch - total_epochs * 0.5) / (total_epochs * 0.1)))
    elif schedule_type == 'cosine':
        return 0.5 * (1 + np.cos(np.pi * (1 - epoch / total_epochs)))
    elif schedule_type == 'constant':
        return 1.0
    else:
        return 1.0


def evaluate_imputation(model, data_loader, device):
    """Evaluate imputation performance."""
    model.eval()

    all_imputations = []
    all_originals = []
    all_masks = []

    with torch.no_grad():
        for batch_data, batch_mask in data_loader:
            batch_data = batch_data.to(device)
            batch_mask = batch_mask.to(device)

            # Get model predictions
            reconstruction, mu, logvar = model(batch_data, batch_mask)

            # Combine observed values with imputed values
            mask_float = batch_mask.float()
            imputed = batch_data * mask_float + reconstruction * (1 - mask_float)

            all_imputations.append(imputed.cpu().numpy())
            all_originals.append(batch_data.cpu().numpy())
            all_masks.append(batch_mask.cpu().numpy())

    # Concatenate all results
    imputations = np.vstack(all_imputations)
    originals = np.vstack(all_originals)
    masks = np.vstack(all_masks)

    return imputations, originals, masks

### Loss Functions and Training Utilities

The VAE loss function is crucial for training effectiveness. Our enhanced loss combines several components:

**1. Reconstruction Loss**: Measures how well the model reconstructs observed values
   - Only computed on observed values (respects the mask)
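
A minimal NumPy sketch of a reconstruction loss that respects the mask is shown below. It mirrors the masked mean-squared-error term in `improved_vae_loss_function`, which additionally blends in a standard MSE term. The helper name `masked_mse` is hypothetical.

```python
import numpy as np

def masked_mse(recon, target, observed_mask):
    """Mean squared error averaged over observed entries only (mask: 1 = observed)."""
    sq_err = (recon - target) ** 2 * observed_mask
    return float(sq_err.sum() / (observed_mask.sum() + 1e-8))

target = np.array([[1.0, 2.0, 3.0]])
recon = np.array([[1.5, 0.0, 3.0]])
observed_mask = np.array([[1.0, 0.0, 1.0]])  # the second entry is missing

# The large error on the missing second entry does not contribute
loss = masked_mse(recon, target, observed_mask)
print(f"{loss:.4f}")
```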

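**2. KL Divergence**: Regularizes the latent space to follow a standard normal distribution
   - Normalized by batch size; this implementation does not apply a "free bits" threshold
   - Weighted by the β parameter, which is annealed during training

**3. Distribution Matching** (weight `gamma`): Matches the feature-wise mean and standard deviation of the reconstructions to those of the data
   - Evaluated on a random ~30% of batches to limit overhead

**4. MMD Loss** (weight `delta`): RBF-kernel Maximum Mean Discrepancy between real and reconstructed samples
   - Evaluated on a random ~10% of batches, using at most 16 samples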

**Beta Scheduling**: Gradually increases the KL weight during training to balance reconstruction and regularization.
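
The KL term and the cosine schedule can be illustrated in a few lines of NumPy. This is a sketch that follows the formulas used in `improved_vae_loss_function` and `get_beta_schedule`; the helper names are chosen for illustration.

```python
import numpy as np

def kl_to_standard_normal(mu, logvar):
    """KL(N(mu, exp(logvar)) || N(0, I)), summed over latent dims, averaged over the batch."""
    kl_per_sample = -0.5 * np.sum(1 + logvar - mu ** 2 - np.exp(logvar), axis=1)
    return float(kl_per_sample.mean())

def cosine_beta(epoch, total_epochs):
    """Cosine ramp from 0 at epoch 0 to 1 at epoch == total_epochs."""
    return float(0.5 * (1 + np.cos(np.pi * (1 - epoch / total_epochs))))

# A posterior equal to the prior has zero KL
kl_zero = kl_to_standard_normal(np.zeros((2, 3)), np.zeros((2, 3)))

# Shifting every latent mean to 1 costs 0.5 per dimension
kl_shifted = kl_to_standard_normal(np.ones((2, 3)), np.zeros((2, 3)))

print(kl_zero, kl_shifted)
for epoch in (0, 50, 100):
    print(epoch, round(cosine_beta(epoch, 100), 3))
```

With this schedule the KL term is effectively switched off at the start of training and reaches full weight only at the final epoch.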

## Model Initialization and Training

### Model Training Process

This section implements the complete training pipeline with several important features:

**Training Configuration:**
- **Latent Dimension**: 128 (balance between expressiveness and computational efficiency)
- **Architecture**: Deep encoder/decoder with residual connections
- **Regularization**: Dropout and batch normalization for stability
- **Optimization**: AdamW with cosine annealing for smooth convergence

**Advanced Training Features:**
- **Early Stopping**: Prevents overfitting by monitoring validation loss
- **Gradient Clipping**: Ensures stable training by preventing exploding gradients  
- **Beta Scheduling**: Gradual KL annealing for better latent space learning
- **Learning Rate Scheduling**: Cosine annealing with warm restarts
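
To show what cosine annealing with warm restarts does to the learning rate, the plain-Python sketch below evaluates the schedule per epoch. It is intended to follow the formula documented for PyTorch's `CosineAnnealingWarmRestarts`, using the settings from this notebook (`T_0=20`, `T_mult=2`, `eta_min=1e-6`, base rate `1e-3`). Treat it as an approximation of the library's behavior; `warm_restart_lr` is a hypothetical helper.

```python
import math

def warm_restart_lr(epoch, base_lr=1e-3, eta_min=1e-6, t_0=20, t_mult=2):
    """Learning rate at a given epoch under cosine annealing with warm restarts."""
    t_i, t_cur = t_0, epoch
    # Walk through completed cycles; each cycle is t_mult times longer than the last
    while t_cur >= t_i:
        t_cur -= t_i
        t_i *= t_mult
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * t_cur / t_i))

# Restarts occur at epochs 20 and 60 (cycle lengths 20, 40, 80, ...)
for epoch in (0, 10, 19, 20, 59, 60):
    print(epoch, f"{warm_restart_lr(epoch):.2e}")
```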

The training loop tracks multiple loss components to monitor model health and convergence.
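
The early-stopping logic can be isolated into a small helper, sketched below. `EarlyStopper` is a hypothetical class that is not part of the notebook; the notebook inlines the same counter logic with a patience of 15.

```python
class EarlyStopper:
    """Tracks the best validation loss and signals when to stop."""

    def __init__(self, patience=15):
        self.patience = patience
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.best:
            self.best = val_loss
            self.bad_epochs = 0  # improvement: reset the counter
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience

# Toy validation history (hypothetical values) with a short patience
stopper = EarlyStopper(patience=3)
history = [1.0, 0.8, 0.9, 0.85, 0.81, 0.7]
stopped_at = None
for epoch, loss in enumerate(history):
    if stopper.step(loss):
        stopped_at = epoch
        break

print(stopped_at, stopper.best)
```

In this toy run the loop stops before reaching the final, better value of 0.7, which illustrates the trade-off in choosing the patience.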

In [None]:
print("\n" + "="*70)
print("IMPROVED MODEL TRAINING")
print("="*70)

# Balanced model configuration - combining proven settings with light improvements
config = {
    'input_dim': len(feature_names),
    'latent_dim': 128,  # Proven size from original
    'hidden_dims': [512, 256, 128],  # Proven architecture from original
    'use_residual': True,
    'dropout_rate': 0.3,  # Proven dropout from original
    'use_attention': False,  # Disabled by default - simpler is better
    'learning_rate': 1e-3,  # Proven learning rate from original
    'num_epochs': 500,  # Proven number of epochs
    'beta_initial': 1.0,  # Not read by the training loop; beta follows 'beta_schedule'
    'beta_schedule': 'cosine',  # Proven schedule from original
    'patience': 15,  # Proven patience from original
    'gamma': 0.02,  # Very light distribution matching (was 0.15)
    'delta': 0.01  # Very light MMD loss (was 0.05)
}

print(f"Model Configuration:")
for key, value in config.items():
    print(f"  - {key}: {value}")

# Initialize the improved model
print(f"\nInitializing Improved VAE model...")
model = ImprovedVAE(
    input_dim=config['input_dim'],
    latent_dim=config['latent_dim'],
    hidden_dims=config['hidden_dims'],
    use_residual=config['use_residual'],
    dropout_rate=config['dropout_rate'],
    use_attention=config['use_attention']
).to(device)

total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"✓ Model initialized with {total_params:,} parameters")

# Initialize optimizer with proven settings from original
optimizer = optim.AdamW(
    model.parameters(),
    lr=config['learning_rate'],
    weight_decay=1e-5,  # Proven weight decay from original
    betas=(0.9, 0.999),
    eps=1e-8
)

# Use proven scheduler from original
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=20, T_mult=2, eta_min=1e-6
)

print(f"✓ Optimizer and scheduler initialized")

# Training setup
train_losses = []
val_losses = []
train_recon_losses = []
train_kl_losses = []
best_val_loss = float('inf')
patience_counter = 0
max_grad_norm = 1.0

print(f"\nStarting training for {config['num_epochs']} epochs...")
print("="*70)

# Training loop
for epoch in range(config['num_epochs']):
    # Training phase
    model.train()
    epoch_train_loss = 0
    epoch_recon_loss = 0
    epoch_kl_loss = 0

    # Get beta for this epoch
    beta = get_beta_schedule(epoch, config['num_epochs'], config['beta_schedule'])

    train_progress = tqdm(train_loader, desc=f'Epoch {epoch+1}/{config["num_epochs"]}', leave=False)

    for batch_data, batch_mask in train_progress:
        batch_data = batch_data.to(device)
        batch_mask = batch_mask.to(device)

        optimizer.zero_grad()

        # Forward pass
        reconstruction, mu, logvar = model(batch_data, batch_mask)

        # Calculate balanced loss with light distribution matching
        total_loss, recon_loss, kl_loss, dist_loss, mmd_loss = improved_vae_loss_function(
            reconstruction, batch_data, mu, logvar, batch_mask,
            beta=beta, 
            gamma=config['gamma'], delta=config['delta'],
            use_distribution_loss=True
        )

        # Backward pass
        total_loss.backward()

        # Gradient clipping for stability
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        optimizer.step()

        # Accumulate losses
        epoch_train_loss += total_loss.item()
        epoch_recon_loss += recon_loss.item()
        epoch_kl_loss += kl_loss.item()

        # Update progress bar
        train_progress.set_postfix({
            'Loss': f'{total_loss.item():.4f}',
            'Recon': f'{recon_loss.item():.4f}',
            'KL': f'{kl_loss.item():.4f}',
            'Dist': f'{dist_loss.item():.4f}',
            'Beta': f'{beta:.3f}'
        })

    # Calculate average training losses
    avg_train_loss = epoch_train_loss / len(train_loader)
    avg_recon_loss = epoch_recon_loss / len(train_loader)
    avg_kl_loss = epoch_kl_loss / len(train_loader)

    # Validation phase
    model.eval()
    epoch_val_loss = 0

    with torch.no_grad():
        for batch_data, batch_mask in val_loader:
            batch_data = batch_data.to(device)
            batch_mask = batch_mask.to(device)

            reconstruction, mu, logvar = model(batch_data, batch_mask)

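            # Use a fixed beta and skip the randomly sampled distribution terms so that
            # validation losses stay comparable across epochs for early stopping
            total_loss, _, _, _, _ = improved_vae_loss_function(
                reconstruction, batch_data, mu, logvar, batch_mask,
                beta=1.0,
                gamma=config['gamma'], delta=config['delta'],
                use_distribution_loss=False
            )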

            epoch_val_loss += total_loss.item()

    avg_val_loss = epoch_val_loss / len(val_loader)

    # Store losses
    train_losses.append(avg_train_loss)
    val_losses.append(avg_val_loss)
    train_recon_losses.append(avg_recon_loss)
    train_kl_losses.append(avg_kl_loss)

    # Learning rate scheduling
    scheduler.step()

    # Early stopping and model saving
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        patience_counter = 0
        # Save best model
        torch.save(model.state_dict(), 'best_vae_model.pth')
    else:
        patience_counter += 1

    # Print progress every 10 epochs
    if (epoch + 1) % 10 == 0:
        print(f'Epoch {epoch+1}/{config["num_epochs"]}:')
        print(f'  Train Loss: {avg_train_loss:.4f} (Recon: {avg_recon_loss:.4f}, KL: {avg_kl_loss:.4f})')
        print(f'  Val Loss: {avg_val_loss:.4f}, Best: {best_val_loss:.4f}')
        print(f'  Beta: {beta:.3f}, LR: {optimizer.param_groups[0]["lr"]:.2e}')
        print(f'  Patience: {patience_counter}/{config["patience"]}')

    # Early stopping
    if patience_counter >= config['patience']:
        print(f'\nEarly stopping at epoch {epoch+1}')
        break

print(f'\n✓ Training completed!')
print(f'  - Total epochs: {len(train_losses)}')
print(f'  - Best validation loss: {best_val_loss:.4f}')
print(f'  - Final training loss: {train_losses[-1]:.4f}')

# Load best model
model.load_state_dict(torch.load('best_vae_model.pth', map_location=device))
print(f'✓ Best model loaded')

## Model Evaluation and Metrics


In [None]:
print("\n" + "="*70)
print("MODEL EVALUATION")
print("="*70)

# Evaluate on test set
print("Evaluating model on test set...")
test_imputations, test_originals, test_masks = evaluate_imputation(
    model, test_loader, device
)

print(f"✓ Test set evaluation completed")
print(f"  - Test samples: {test_imputations.shape[0]}")
print(f"  - Features: {test_imputations.shape[1]}")

test_imputations_denorm = test_imputations  # Already in original scale
test_original_denorm = X_test_original  # Already in original scale

# Calculate comprehensive metrics
print("\nCalculating comprehensive metrics...")
feature_metrics = {}

# Create masks for missing values (where we need to evaluate imputation)
missing_mask = (test_masks == 0)  # True where values were missing (0 in model tensors = missing)

for i, feature_name in enumerate(feature_names):
    if missing_mask[:, i].sum() > 0:  # Only evaluate features with missing values
        # Get imputed and ground truth values for missing positions only
        imputed_missing = test_imputations_denorm[missing_mask[:, i], i]
        ground_truth_missing = test_original_denorm[missing_mask[:, i], i]

        # Calculate metrics
        mse = mean_squared_error(ground_truth_missing, imputed_missing)
        mae = mean_absolute_error(ground_truth_missing, imputed_missing)

        # Correlation
        try:
            correlation = np.corrcoef(ground_truth_missing, imputed_missing)[0, 1]
        except Exception:
            correlation = np.nan

        # Mean difference and Jensen-Shannon divergence
        mean_diff, js_div = calculate_jsd_and_mean_diff(
            imputed_missing, ground_truth_missing, feature_name
        )

        feature_metrics[feature_name] = {
            'n_missing': missing_mask[:, i].sum(),
            'mse': mse,
            'mae': mae,
            'correlation': correlation,
            'mean_difference': mean_diff,
            'js_divergence': js_div,
        }

print(f"✓ Metrics calculated for {len(feature_metrics)} features with missing values")

# Display metrics for the last 4 features
print(f"\n" + "="*100)
print("METRICS FOR LAST 4 FEATURES")
print("="*100)
print(f"{'Feature':<15} {'N_Miss':<8} {'MSE':<10} {'MAE':<10} {'Corr':<8} {'Mean_Diff':<10} {'JS_Div':<8}")
print("-" * 100)

last_4_features = list(feature_metrics.keys())[-4:] if len(feature_metrics) >= 4 else list(feature_metrics.keys())

for feature in last_4_features:
    metrics = feature_metrics[feature]
    print(f"{feature:<15} {metrics['n_missing']:<8} {metrics['mse']:<10.4f} {metrics['mae']:<10.4f} "
          f"{metrics['correlation']:<8.3f} {metrics['mean_difference']:<10.4f} {metrics['js_divergence']:<8.4f} ")

# Summary statistics
all_mse = [m['mse'] for m in feature_metrics.values() if not np.isnan(m['mse'])]
all_mae = [m['mae'] for m in feature_metrics.values() if not np.isnan(m['mae'])]
all_corr = [m['correlation'] for m in feature_metrics.values() if not np.isnan(m['correlation'])]
all_mean_diff = [m['mean_difference'] for m in feature_metrics.values() if not np.isnan(m['mean_difference'])]
all_js_div = [m['js_divergence'] for m in feature_metrics.values() if not np.isnan(m['js_divergence'])]

print(f"\nSummary Statistics Across All Features:")
print(f"  - Average MSE: {np.mean(all_mse):.4f} ± {np.std(all_mse):.4f}")
print(f"  - Average MAE: {np.mean(all_mae):.4f} ± {np.std(all_mae):.4f}")
print(f"  - Average Correlation: {np.mean(all_corr):.3f} ± {np.std(all_corr):.3f}")
print(f"  - Average Mean Difference: {np.mean(all_mean_diff):.4f} ± {np.std(all_mean_diff):.4f}")
print(f"  - Average JS Divergence: {np.mean(all_js_div):.4f} ± {np.std(all_js_div):.4f}")

### Model Evaluation and Comprehensive Metrics

This section evaluates our trained VAE on the test set using multiple complementary metrics. Since we're dealing with missing value imputation, we only evaluate the model's predictions on positions that were originally missing.

**Key Evaluation Metrics:**

**1. Mean Squared Error (MSE)**:
- Measures average squared difference between predicted and true values
- Lower is better; sensitive to outliers
- Good for understanding magnitude of errors

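**2. Correlation Coefficient**:
- Measures the strength of the linear relationship between predicted and true values at the missing positions
- Range: [-1, 1], closer to 1 is better
- Shows whether the model tracks sample-to-sample variation in each feature rather than predicting a near-constant value

**3. Jensen-Shannon (JS) Divergence**:
- Measures difference between predicted and true value distributions
- Range: [0, 1] with base-2 logarithms, closer to 0 is better; SciPy's `jensenshannon` returns the JS distance (the square root of the divergence) and uses natural logarithms by default, so its maximum is √(ln 2) ≈ 0.83
- Captures whether model preserves the overall data distribution

**4. Maximum Mean Discrepancy (MMD)**:
- Measures distributional difference using kernel methods (RBF kernel)
- Range: [0, √2] for an RBF kernel, whose values lie in (0, 1]; closer to 0 is better
- Non-parametric statistic for comparing distributions; no distributional form is assumed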
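The four metrics above can be sketched for a single feature as follows. This is a minimal NumPy-only illustration, not the notebook's own evaluation code: `feature_report`, `js_distance`, and `mmd_rbf` are hypothetical helper names, the bin count and `gamma` are assumed defaults, and the JS value is computed as a distance with natural logarithms to mirror SciPy's default behaviour.

```python
import numpy as np

def js_distance(p, q, eps=1e-10):
    """JS distance (sqrt of the divergence, natural log) between two histograms."""
    p = (p + eps) / (p + eps).sum()
    q = (q + eps) / (q + eps).sum()
    m = 0.5 * (p + q)
    kl = lambda a, b: np.sum(a * np.log(a / b))
    return float(np.sqrt(max(0.5 * kl(p, m) + 0.5 * kl(q, m), 0.0)))

def mmd_rbf(x, y, gamma=1.0):
    """Biased MMD estimate with an RBF kernel for two 1-D samples."""
    x, y = x.reshape(-1, 1), y.reshape(-1, 1)
    k = lambda a, b: np.exp(-gamma * (a - b.T) ** 2)
    mmd2 = k(x, x).mean() + k(y, y).mean() - 2 * k(x, y).mean()
    return float(np.sqrt(max(mmd2, 0.0)))

def feature_report(truth, pred, n_bins=30):
    """Metrics for one feature, using only the originally missing positions."""
    lo, hi = min(truth.min(), pred.min()), max(truth.max(), pred.max())
    if hi == lo:
        jsd = 0.0  # all values identical, nothing to compare
    else:
        edges = np.linspace(lo, hi, n_bins)
        p, _ = np.histogram(truth, bins=edges)
        q, _ = np.histogram(pred, bins=edges)
        jsd = js_distance(p.astype(float), q.astype(float))
    return {
        "mse": float(np.mean((truth - pred) ** 2)),
        "corr": float(np.corrcoef(truth, pred)[0, 1]),
        "jsd": jsd,
        "mmd": mmd_rbf(truth, pred, gamma=1.0),
    }

# Toy usage with synthetic values in [0, 1]
rng = np.random.default_rng(0)
truth = rng.random(200)
pred = np.clip(truth + 0.05 * rng.standard_normal(200), 0.0, 1.0)
print(feature_report(truth, pred))
```

MSE and correlation compare values pair by pair, while JS and MMD only compare the two sets of values as distributions, so a model can do well on one group and poorly on the other.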

In [None]:
# Create the visualization
plot_prediction_scatter(test_imputations_denorm, test_original_denorm, test_masks, feature_names, n_features=25)



In [None]:
def plot_distribution_comparison(test_imputations_denorm, test_original_denorm, test_masks, feature_names, n_features=25):
    """
    Create distribution comparison plots in a 5x8 grid, one panel per feature
    that has missing values (up to 40 panels).
    
    Args:
        test_imputations_denorm: Denormalized imputed values
        test_original_denorm: Denormalized ground truth values  
        test_masks: Binary masks (1=observed, 0=missing)
        feature_names: List of feature names
        n_features: Expected number of features; only used to report when fewer
            features have missing values. Every such feature is plotted.
    """
    # Find features that have missing values
    features_with_missing = []
    for i, feature_name in enumerate(feature_names):
        missing_positions = (test_masks[:, i] == 0)  # 0 = missing in model tensors
        if missing_positions.sum() > 0:
            features_with_missing.append((i, feature_name))
    
    if len(features_with_missing) < n_features:
        n_features = len(features_with_missing)
        print(f"Only {n_features} features have missing values, showing all of them.")

    # Create 5x8 grid
    fig, axes = plt.subplots(5, 8, figsize=(20, 16))
    fig.suptitle('Distribution Comparison: Dataset vs Imputed Values', fontsize=16, fontweight='bold')
    axes = axes.flatten()

    for idx, (feature_idx, feature_name) in enumerate(features_with_missing):
        if idx >= 40:  # Safety check for 5x8 grid
            break
            
        # Get imputed and ground truth values for missing positions only
        missing_positions = (test_masks[:, feature_idx] == 0)  # 0 = missing in model tensors
        
        if missing_positions.sum() > 0:
            imputed_values = test_imputations_denorm[missing_positions, feature_idx]
            ground_truth_values = test_original_denorm[missing_positions, feature_idx]
            
            # Remove any NaN or infinite values
            valid_mask = np.isfinite(imputed_values) & np.isfinite(ground_truth_values)
            imputed_clean = imputed_values[valid_mask]
            gt_clean = ground_truth_values[valid_mask]
            
            if len(imputed_clean) > 0 and len(gt_clean) > 0:
                # Create histograms
                ax = axes[idx]
                
                # Calculate bins for both distributions
                all_values = np.concatenate([imputed_clean, gt_clean])
                bins = np.linspace(all_values.min(), all_values.max(), 20)  # Fewer bins for smaller plots
                
                # Plot histograms
                ax.hist(gt_clean, bins=bins, alpha=0.7, label='Dataset', 
                    color='skyblue', density=True, edgecolor='black', linewidth=0.3)
                ax.hist(imputed_clean, bins=bins, alpha=0.7, label='Imputed', 
                    color='lightcoral', density=True, edgecolor='black', linewidth=0.3)
                
                # Add statistical information
                gt_mean, gt_std = gt_clean.mean(), gt_clean.std()
                imp_mean, imp_std = imputed_clean.mean(), imputed_clean.std()
                correlation = np.corrcoef(gt_clean, imputed_clean)[0, 1] if len(gt_clean) > 1 else 0

                # Calculate MMD (Maximum Mean Discrepancy)
                def rbf_kernel(X, Y, gamma=1.0):
                    """RBF kernel for MMD calculation"""
                    XX = np.sum(X**2, axis=1, keepdims=True)
                    YY = np.sum(Y**2, axis=1, keepdims=True)
                    XY = np.dot(X, Y.T)
                    distances = XX + YY.T - 2*XY
                    return np.exp(-gamma * distances)
                
                def mmd_rbf(X, Y, gamma=1.0):
                    """Calculate MMD with RBF kernel"""
                    X = X.reshape(-1, 1)
                    Y = Y.reshape(-1, 1)
                    
                    m, n = len(X), len(Y)
                    
                    K_XX = rbf_kernel(X, X, gamma)
                    K_YY = rbf_kernel(Y, Y, gamma)
                    K_XY = rbf_kernel(X, Y, gamma)
                    
                    mmd = (np.sum(K_XX) / (m * m) + 
                           np.sum(K_YY) / (n * n) - 
                           2 * np.sum(K_XY) / (m * n))
                    return np.sqrt(max(mmd, 0))  # Ensure non-negative
                
                try:
                    mmd_value = mmd_rbf(gt_clean, imputed_clean)
                except Exception:
                    mmd_value = np.nan

                # Calculate Jensen-Shannon Divergence
                try:
                    # Create histograms with same bins for JSD
                    data_range = (min(gt_clean.min(), imputed_clean.min()), 
                                 max(gt_clean.max(), imputed_clean.max()))
                    
                    if data_range[1] == data_range[0]:
                        jsd_value = 0.0  # No divergence if all values are the same
                    else:
                        bins = np.linspace(data_range[0], data_range[1], 30)
                        
                        # Get histogram probabilities
                        hist_gt, _ = np.histogram(gt_clean, bins=bins, density=True)
                        hist_imp, _ = np.histogram(imputed_clean, bins=bins, density=True)
                        
                        # Normalize to probabilities
                        hist_gt = hist_gt + 1e-10  # Add small epsilon to avoid zeros
                        hist_imp = hist_imp + 1e-10
                        hist_gt = hist_gt / hist_gt.sum()
                        hist_imp = hist_imp / hist_imp.sum()
                        
                        # Note: SciPy's jensenshannon returns the JS distance,
                        # i.e. the square root of the JS divergence
                        jsd_value = jensenshannon(hist_gt, hist_imp)
                except Exception:
                    jsd_value = np.nan

                # Add vertical lines for means
                ax.axvline(gt_mean, color='blue', linestyle='--', alpha=0.8, linewidth=1, label='Dataset Mean' if idx == 0 else "")
                ax.axvline(imp_mean, color='red', linestyle='--', alpha=0.8, linewidth=1, label='Imputed Mean' if idx == 0 else "")
                
                # Set labels and title (smaller font for 5x8 grid)
                ax.set_xlabel(f'{feature_name[:15]}', fontsize=8)  # Truncate long names
                ax.set_ylabel('Density', fontsize=8)
                ax.tick_params(labelsize=7)
                
                # Add Pearson correlation (r, not R²), MMD, and JSD as title
                ax.set_title(f'r={correlation:.3f}, MMD={mmd_value:.3f}, JSD={jsd_value:.3f}', fontsize=7, fontweight='bold')
                
                # Add legend only to first plot
                if idx == 0:
                    ax.legend(fontsize=7, loc='upper right')
                
                ax.grid(True, alpha=0.3)

            else:
                axes[idx].text(0.5, 0.5, f'{feature_name[:15]}\nNo valid data', 
                            ha='center', va='center', transform=axes[idx].transAxes, fontsize=8)
                axes[idx].set_title(f'{feature_name[:15]} - No Valid Data', fontsize=8)
        else:
            axes[idx].text(0.5, 0.5, f'{feature_name[:15]}\nNo missing values', 
                        ha='center', va='center', transform=axes[idx].transAxes, fontsize=8)
            axes[idx].set_title(f'{feature_name[:15]} - No Missing Values', fontsize=8)

    # Hide any unused subplots
    for idx in range(len(features_with_missing), 40):
        axes[idx].set_visible(False)

    plt.tight_layout()
    plt.show()

In [None]:
# Distribution comparison plots
plot_distribution_comparison(test_imputations_denorm, test_original_denorm,
                             test_masks, feature_names, n_features=25)


### Distribution Comparison Visualizations

Visual comparison of predicted vs. dataset distributions is crucial for understanding model performance beyond simple error metrics. These plots help us assess:

**What the Plots Show:**
- **Red (Imputed)**: Distribution of model's predicted values for missing positions
- **Blue (Dataset)**: Distribution of actual values at those same positions
- **Overlap**: How well the model captures the true data distribution

**Why This Matters:**
- A good generative model should not just minimize error, but also preserve the statistical properties of the data
- If distributions match well, the model is generating realistic values
- Large differences indicate the model may be systematically biased or missing important patterns

**Interpretation:**
- **Good**: Overlapping distributions with similar shapes and centers
- **Concerning**: Shifted means, different variances, or completely different shapes
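The "shifted means" and "different variances" cases can also be checked numerically. The sketch below is one possible way to do so, assuming 1-D arrays of true and imputed values for a single feature; `distribution_shift` is a hypothetical helper, not part of the notebook.

```python
import numpy as np

def distribution_shift(truth, pred, eps=1e-12):
    """Return (standardized mean shift, std ratio) of pred relative to truth."""
    truth_std = truth.std()
    mean_shift = (pred.mean() - truth.mean()) / (truth_std + eps)
    std_ratio = pred.std() / (truth_std + eps)
    return float(mean_shift), float(std_ratio)

# Toy example: predictions shrunk halfway toward 0.5, as happens when a
# model regresses toward the mean instead of sampling the full distribution
rng = np.random.default_rng(0)
truth = rng.random(1000)
pred = 0.5 * truth + 0.25
mean_shift, std_ratio = distribution_shift(truth, pred)
print(f"mean shift: {mean_shift:+.3f} std units, std ratio: {std_ratio:.3f}")
```

A mean shift far from 0 points to systematic bias, and a std ratio well below 1 suggests the imputed values are less diverse than the data.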

In [None]:
def generate_samples(model, X_test, test_loader, device, n_samples_per_test=100, temperature=1.0):
    """Generate multiple diverse samples for a dataset using the trained model.
    
    Args:
        model: Trained VAE model
        X_test: Test data
        test_loader: DataLoader for test data
        device: Device to run on
        n_samples_per_test: Number of samples to generate per test instance
        temperature: Temperature for sampling (higher = more diverse)
    """
    # We'll generate multiple samples
    test_samples = np.zeros((X_test.shape[0], n_samples_per_test, X_test.shape[1]))

    # Set model to evaluation mode
    model.eval()

    with torch.no_grad():
        # Create a progress bar for all samples
        from tqdm import tqdm

        for batch_idx, (batch_data, batch_mask) in enumerate(tqdm(test_loader, desc="Generating Samples")):
            batch_data = batch_data.to(device)
            batch_mask = batch_mask.to(device)

            # Calculate the indices for this batch
            # (assumes test_loader iterates in order: shuffle=False, drop_last=False)
            start_idx = batch_idx * test_loader.batch_size
            end_idx = min(start_idx + test_loader.batch_size, X_test.shape[0])
            actual_batch_size = end_idx - start_idx

            # Get latent distribution parameters
            mu, logvar = model.encode(batch_data, batch_mask)
            
            # Generate multiple samples for each item in the batch
            for j in range(n_samples_per_test):
                # Sample from latent space with temperature scaling for diversity
                # Higher temperature increases diversity
                scaled_logvar = logvar + 2 * np.log(temperature)
                z = model.reparameterize(mu, scaled_logvar)
                
                # Decode to get reconstruction
                reconstruction = model.decode(z, batch_data, batch_mask)

                # Apply mask: keep original values where available, use reconstructed values where missing
                mask_float = batch_mask.float()
                imputed = batch_data * mask_float + reconstruction * (1 - mask_float)
                
                # Clamp values to valid range [0, 1]
                imputed = torch.clamp(imputed, 0.0, 1.0)

                # Store the samples (already in original scale since we didn't normalize)
                test_samples[start_idx:end_idx, j, :] = imputed.cpu().numpy()
    
    print(f"✓ Generated samples shape: {test_samples.shape}")
    print(f"  - {test_samples.shape[0]} samples")
    print(f"  - {test_samples.shape[1]} generated variations per sample")
    print(f"  - {test_samples.shape[2]} features per sample")

    # Data is already in original scale (no denormalization needed)
    test_samples_final = test_samples.copy()

    # Calculate summary statistics
    mean_across_samples = test_samples_final.mean(axis=1)  # Mean across the generated samples
    std_across_samples = test_samples_final.std(axis=1)  # Std across samples (diversity measure)

    print(f"  - Range of means: [{mean_across_samples.min():.4f}, {mean_across_samples.max():.4f}]")
    print(f"  - Average std (diversity): {std_across_samples.mean():.4f}")

    return test_samples

In [None]:
# Test Evaluation

print("="*70)
print("TEST EVALUATION")
print("="*70)

# Generate multiple samples for test using the trained model
print(f"Generating 100 samples for each of {X_test.shape[0]} test samples...")

test_samples = generate_samples(
    model, X_test, test_loader, device, n_samples_per_test=100
)

In [None]:
test_score = compute_score(generated_samples=test_samples, set_name='test')
print("Test score:", test_score)

The final score is computed as: Mean Correlation − Mean JS Divergence − Mean MSE

Just as we compare generated samples for the test set against the original unimputed values, we will apply the same metric to the samples you generate for test2, using the hidden test2 set. This will determine your final submission score.
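As a worked illustration of the scoring formula, the sketch below combines per-feature metric values into a single number. It is not the course's `compute_score` implementation, whose internals are not shown here; `sketch_score` is a hypothetical helper, and the input values are made up for the arithmetic.

```python
from statistics import fmean

def sketch_score(correlations, js_divergences, mses):
    """Mean correlation minus mean JS divergence minus mean MSE."""
    return fmean(correlations) - fmean(js_divergences) - fmean(mses)

# Two features with made-up metric values:
# mean corr = 0.7, mean JS = 0.2, mean MSE = 0.03
score = sketch_score([0.8, 0.6], [0.1, 0.3], [0.02, 0.04])
print(round(score, 2))  # 0.47
```

Because correlation enters with a positive sign and the other two terms with a negative sign, a higher score is better.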

## Preparing a submission:
Let's prepare a submission. We expect the final submission to be a 417x100x37 numpy array. These correspond to the 100 diverse samples you generated based on the constrained parameters we provided in the test2 set.

In [None]:
# Test2 Evaluation

print("="*70)
print("TEST2 EVALUATION")
print("="*70)

# Generate multiple samples for test2 using the trained model
print(f"Generating 100 samples for each of {X_test2.shape[0]} test2 samples...")

test2_samples = generate_samples(
    model, X_test2, test2_loader, device, n_samples_per_test=100
)

### Test2 Evaluation: Generating Diverse Design Completions

This is the core evaluation for your AI Copilot assignment. Here we:

**Input**: Test2 samples with some known features (constraints) and some missing features (free parameters)

**Output**: 100 diverse, plausible completions for each test sample

**Why 100 Samples?**
- Engineers want to explore multiple design options, not just one "best" solution
- Diversity helps discover unexpected but valid design combinations  

**Technical Process:**
1. For each test2 sample, use the trained model to generate 100 different completions
2. Each completion respects the known constraints (observed values)
3. Missing values are filled with diverse, model-generated predictions
4. Final output: 417 × 100 × 37 array (417 test samples, 100 variants each, 37 features)
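The process above can be sketched without the trained model. The example below is a NumPy illustration under stated assumptions: `complete_designs` is a hypothetical helper, the decoder is a random sigmoid layer standing in for the VAE decoder, and the latent `mu` and `logvar` are random placeholders rather than encoder outputs. Scaling the standard deviation by `temperature` is equivalent to adding `2 * log(temperature)` to `logvar`.

```python
import numpy as np

def complete_designs(x, mask, mu, logvar, decode,
                     n_samples=100, temperature=1.0, seed=0):
    """x, mask: (N, F); mu, logvar: (N, D); decode maps (N, D) -> (N, F)."""
    rng = np.random.default_rng(seed)
    std = temperature * np.exp(0.5 * logvar)  # temperature-scaled latent std
    out = np.empty((x.shape[0], n_samples, x.shape[1]))
    for j in range(n_samples):
        z = mu + std * rng.standard_normal(mu.shape)  # sample latent code
        recon = decode(z)
        # Keep observed values, fill missing ones, clamp to [0, 1]
        out[:, j, :] = np.clip(x * mask + recon * (1 - mask), 0.0, 1.0)
    return out

# Toy setup: 4 designs, 5 features, 2 latent dimensions (all placeholders)
rng = np.random.default_rng(42)
N, F, D = 4, 5, 2
x = rng.random((N, F))
mask = (rng.random((N, F)) > 0.5).astype(float)  # 1 = observed, 0 = missing
W = rng.standard_normal((D, F))
decode = lambda z: 1.0 / (1.0 + np.exp(-(z @ W)))  # stand-in decoder
mu, logvar = rng.standard_normal((N, D)), np.zeros((N, D))

samples = complete_designs(x, mask, mu, logvar, decode, n_samples=100)
print(samples.shape)  # (4, 100, 5)

# Every variant must reproduce the observed (constrained) values
assert np.allclose(samples * mask[:, None, :], (x * mask)[:, None, :])
```

The same constraint check, applied to the real array, is a cheap sanity test to run before saving a submission.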

In [None]:
# Use integer bounds and avoid shadowing the built-in `id`
submission_id = np.random.randint(10**8, 10**9 - 1)
np.save(f"{submission_id}.npy", test2_samples)

In [None]:
print(submission_id)

### Summary and Tips for CP3

The VAE baseline reproduces the dataset distribution well for some features, but others still show substantial discrepancies, indicating significant room for improvement!

**Key Observations:**
- **Strengths**: The model captures general feature ranges and some distributional patterns
- **Weaknesses**: Some features show systematic bias or poor distribution matching
- **Opportunities**: Advanced architectures (diffusion models, transformers) or better conditioning strategies could improve performance

**For Your Assignment**: Consider these results as a baseline. Think about:
- Which features are hardest to predict and why?
- How could you modify the architecture or training process?
- What additional constraints or domain knowledge could help?
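One way to start on the first question is to rank features by a per-feature version of the score. The sketch below assumes a dictionary shaped like the `feature_metrics` built during evaluation (keys `'correlation'`, `'js_divergence'`, `'mse'`); `rank_hardest_features` is a hypothetical helper, and the per-feature score is an assumption modelled on the overall scoring formula.

```python
import math

def rank_hardest_features(feature_metrics, top_k=5):
    """Return the top_k features with the lowest corr - JS - MSE."""
    def per_feature_score(m):
        return m["correlation"] - m["js_divergence"] - m["mse"]

    # Skip features whose metrics are NaN, since NaN breaks sorting
    scored = [
        (name, per_feature_score(m))
        for name, m in feature_metrics.items()
        if not math.isnan(per_feature_score(m))
    ]
    return sorted(scored, key=lambda item: item[1])[:top_k]

# Made-up metrics for three placeholder features
example_metrics = {
    "f1": {"correlation": 0.9, "js_divergence": 0.1, "mse": 0.01},
    "f2": {"correlation": 0.2, "js_divergence": 0.4, "mse": 0.05},
    "f3": {"correlation": 0.5, "js_divergence": 0.2, "mse": 0.02},
}
for name, score in rank_hardest_features(example_metrics, top_k=2):
    print(f"{name}: {score:.2f}")
```

Features at the top of this list are natural candidates for a closer look at their distribution plots.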