# Day 23: Variational Lossy Autoencoder (VLAE)

> Xi Chen et al. (2017) - [Variational Lossy Autoencoder](https://arxiv.org/abs/1611.02731)

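**Core problem:** Posterior collapse. VAEs with powerful autoregressive decoders (e.g. PixelCNN) tend to ignore the latent code. The decoder can model local pixel dependencies on its own, and any information placed in $z$ costs KL in the ELBO, so the optimizer often pushes $q(z|x)$ toward the prior instead of learning global semantics.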

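**VLAE solution:**
1. **Cripple the decoder:** Give the PixelCNN decoder a limited receptive field, so global structure can only be captured through $z$.
2. **Boost the prior:** Replace the fixed Gaussian prior with an autoregressive flow, which gives a more flexible latent distribution. (The paper uses an autoregressive flow (AF) prior and relates it to an IAF posterior.)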
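To make "limited receptive field" concrete, the sketch below traces which input pixels can influence one output pixel of a stack of masked $3 \times 3$ convolutions: a type-A mask on the first layer (excludes the centre pixel) and type-B masks afterwards (include it), as in PixelCNN. It is pure Python and tracks connectivity only, not weights. The 28x28 grid, kernel size, and target pixel are illustrative assumptions, not the notebook's `VLAE` configuration.

```python
def mask_offsets(mask_type, k=3):
    """(row, col) offsets that a k x k PixelCNN mask lets an output pixel read.

    Type 'A' (first layer) excludes the centre pixel; type 'B' includes it.
    """
    c = k // 2
    offsets = []
    for i in range(k):
        for j in range(k):
            above = i < c                      # any column in a row above
            left = i == c and j < c            # strictly left in the same row
            centre = i == c and j == c and mask_type == "B"
            if above or left or centre:
                offsets.append((i - c, j - c))
    return offsets


def receptive_field(n_layers, k=3, height=28, width=28, target=(14, 14)):
    """Set of input pixels that can influence `target` after n_layers masked convs."""
    layer_types = ["A"] + ["B"] * (n_layers - 1)
    reachable = {target}
    # Walk backwards from the output, through each layer, to the input image.
    for mask_type in reversed(layer_types):
        offs = mask_offsets(mask_type, k)
        reachable = {
            (r + dr, c + dc)
            for (r, c) in reachable
            for (dr, dc) in offs
            if 0 <= r + dr < height and 0 <= c + dc < width
        }
    return reachable


for n in (1, 3, 5):
    print(n, "layers ->", len(receptive_field(n)), "input pixels visible")
```

The visible region grows with depth but stays a small causal patch above and to the left of the target (including the known PixelCNN blind spot to the upper right), so anything outside that patch has to reach the decoder through $z$.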

In this notebook, we implement and verify the VLAE architecture on binarized MNIST.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

from implementation import VLAE, loss_function
from visualization import plot_reconstructions, plot_latent_sampling, plot_bits_heatmap

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## 1. Technical Verification: Autoregressive Masking

VLAE relies on a strict autoregressive property: output $i$ may depend only on inputs $j < i$. We use **MADE** (Masked Autoencoder for Distribution Estimation) for the flows and **MaskedConv2d** for the PixelCNN decoder.
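MADE enforces this with binary masks built from "degrees": input and output unit $d$ get degree $d$, hidden units get degrees in $1..D-1$, and a weight is kept only if it cannot route input $j$ to output $i$ with $j \ge i$. The NumPy sketch below builds both masks for a single hidden layer and checks that the resulting connectivity matrix is strictly lower triangular. The cyclic hidden-degree assignment is an assumption for reproducibility (MADE itself can sample degrees at random), and the notebook's `MADE` class may construct its masks differently.

```python
import numpy as np


def made_masks(n_in, n_hidden):
    """Masks for a one-hidden-layer MADE with a fixed (cyclic) degree assignment."""
    deg_in = np.arange(1, n_in + 1)                     # input/output degrees 1..D
    deg_hidden = np.arange(n_hidden) % (n_in - 1) + 1   # hidden degrees 1..D-1
    # Hidden unit k may see input d only if deg_hidden[k] >= deg_in[d].
    mask_in = (deg_hidden[:, None] >= deg_in[None, :]).astype(np.float32)   # (H, D)
    # Output d may see hidden unit k only if deg_in[d] > deg_hidden[k] (strict).
    mask_out = (deg_in[:, None] > deg_hidden[None, :]).astype(np.float32)   # (D, H)
    return mask_in, mask_out


D, H = 10, 32
mask_in, mask_out = made_masks(D, H)
# connectivity[i, j] > 0 iff some path carries input j to output i.
connectivity = mask_out @ mask_in
print((connectivity > 0).astype(int))
```

In a real layer, these masks are multiplied elementwise into the weight matrices on every forward pass, so the zero pattern is guaranteed by construction rather than learned.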

In [None]:
from implementation import MADE
dim = 10
model = MADE(dim, 32, dim)
x = torch.randn(1, dim, requires_grad=True)
out = model(x)

# Verification: Output i should NOT depend on input >= i
i = 5
out[0, i].backward()
print(f"Gradients for dimension {i}:")
print(x.grad.numpy())
assert torch.all(x.grad[0, i:] == 0)
print("Autoregressive condition MET.")

## 2. Load and Binarize MNIST

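We use binarized MNIST because the PixelCNN decoder factorizes the image as $p(x) = \prod_i p(x_i \mid x_{<i})$. With binary pixels each conditional is a Bernoulli, so the reconstruction term is a binary cross-entropy. Pixels are binarized by thresholding at 0.5.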
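With binary pixels, each conditional is a Bernoulli with parameter $\sigma(\ell_i)$ for a decoder logit $\ell_i$, so $-\log p(x)$ is the summed binary cross-entropy. The NumPy sketch below checks, on made-up logits, that the direct Bernoulli negative log-likelihood matches a common numerically stable rewriting, $\max(\ell, 0) - \ell x + \log(1 + e^{-|\ell|})$, which avoids computing $\sigma(\ell)$ explicitly. The random logits and image are stand-ins, not model outputs.

```python
import numpy as np


def bernoulli_nll_naive(logits, x):
    """-log p(x) for independent Bernoulli pixels, computed directly."""
    p = 1.0 / (1.0 + np.exp(-logits))
    return -np.sum(x * np.log(p) + (1 - x) * np.log(1 - p))


def bce_with_logits_stable(logits, x):
    """Same quantity without forming sigmoid(logits); safe for large |logits|."""
    return np.sum(np.maximum(logits, 0) - logits * x + np.log1p(np.exp(-np.abs(logits))))


rng = np.random.default_rng(0)
logits = rng.normal(scale=3.0, size=(28, 28))          # stand-in decoder outputs
x = (rng.uniform(size=(28, 28)) > 0.5).astype(float)   # stand-in binary image
nll_nats = bce_with_logits_stable(logits, x)
print(f"{nll_nats:.3f} nats = {nll_nats / np.log(2):.3f} bits")
```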

In [None]:
batch_size = 64
transform = transforms.Compose([
    transforms.ToTensor(),                           # floats in [0, 1], shape (1, 28, 28)
    transforms.Lambda(lambda x: (x > 0.5).float()),  # threshold to {0, 1}
])

train_loader = DataLoader(datasets.MNIST('./data', train=True, download=True, transform=transform), 
                          batch_size=batch_size, shuffle=True)
test_loader = DataLoader(datasets.MNIST('./data', train=False, transform=transform), 
                         batch_size=batch_size, shuffle=False)

model = VLAE(input_dim=1, latent_dim=32, n_layers=3, use_flow=True).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

## 3. Training Loop

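We track the reconstruction loss (BCE, i.e. $-\log p(x|z)$) and the KL divergence. Averaged over the data, the KL upper-bounds how much information $z$ carries about $x$, so it measures latent usage. Both are reported in nats per image. The KL is recovered as `loss - BCE`, which assumes `loss_function` returns the summed BCE plus the KL term.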
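When the prior is a flow, the KL term has no closed form, so it is typically estimated from the sampled $z$ as $\log q(z \mid x) - \log p(z)$ (presumably why `loss_function` receives `z` and `log_det`, though its internals are not shown here). The NumPy sketch below uses a diagonal Gaussian posterior and a standard normal prior, an assumption chosen only because the exact KL is then known, and checks that averaging the single-sample estimate over many draws recovers the closed form.

```python
import numpy as np

LOG_2PI = np.log(2 * np.pi)


def log_normal_diag(z, mu, logvar):
    """log N(z; mu, diag(exp(logvar))), summed over the last axis."""
    return -0.5 * np.sum(LOG_2PI + logvar + (z - mu) ** 2 / np.exp(logvar), axis=-1)


def kl_closed_form(mu, logvar):
    """KL(N(mu, diag(exp(logvar))) || N(0, I)) in nats."""
    return 0.5 * np.sum(mu ** 2 + np.exp(logvar) - 1.0 - logvar, axis=-1)


def kl_monte_carlo(mu, logvar, n_samples, rng):
    """Average of log q(z|x) - log p(z) over reparameterized samples z ~ q."""
    eps = rng.standard_normal((n_samples, mu.shape[-1]))
    z = mu + np.exp(0.5 * logvar) * eps                  # reparameterization trick
    log_q = log_normal_diag(z, mu, logvar)
    # Standard normal prior; with a flow prior only this line would change.
    log_p = log_normal_diag(z, np.zeros_like(mu), np.zeros_like(logvar))
    return np.mean(log_q - log_p)


rng = np.random.default_rng(0)
mu = rng.normal(size=32)                  # stand-in encoder outputs for one image
logvar = rng.normal(scale=0.5, size=32)
print("closed form:", kl_closed_form(mu, logvar))
print("Monte Carlo:", kl_monte_carlo(mu, logvar, 50_000, rng))
```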

In [None]:
epochs = 3  # 10+ is ideal for crisp results
n_train = len(train_loader.dataset)
model.train()
for epoch in range(1, epochs + 1):
    total_loss, total_kl, total_recon = 0.0, 0.0, 0.0
    for data, _ in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        logits, mu, logvar, z, log_det = model(data)
        loss = loss_function(logits, data, mu, logvar, z, log_det)  # assumed: negative ELBO summed over the batch
        loss.backward()
        optimizer.step()

        with torch.no_grad():
            # Reconstruction term: summed Bernoulli NLL of the batch
            recon = torch.nn.functional.binary_cross_entropy_with_logits(logits, data, reduction='sum')
            total_recon += recon.item()
            # KL = loss - recon (assumes loss_function sums BCE + KL with the same reduction)
            total_kl += loss.item() - recon.item()
            total_loss += loss.item()

    # Per-image averages, in nats
    print(f"Epoch {epoch} | Loss: {total_loss / n_train:.2f} | "
          f"R: {total_recon / n_train:.2f} | KL: {total_kl / n_train:.2f}")

## 4. Visual Analysis

### Posterior Information Usage (Bits Heatmap)
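The heatmap shows how much information each latent dimension carries, computed from the posterior parameters `mu` and `logvar`. In a collapsed VAE this plot would be flat (close to zero bits in every dimension). In VLAE, we expect clear spikes in the dimensions that capture global features.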
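The internals of `plot_bits_heatmap` are not shown. One plausible computation, sketched here as an assumption, is the closed-form per-dimension KL between $q(z_j \mid x) = \mathcal{N}(\mu_j, \sigma_j^2)$ and $\mathcal{N}(0, 1)$, divided by $\ln 2$ to convert nats to bits. With a flow prior this is only a proxy, since the model's actual KL is taken against the learned prior. The "two informative dimensions" below are fabricated stand-in values for illustration.

```python
import numpy as np


def bits_per_dim(mu, logvar):
    """Per-example, per-dimension KL(q(z_j|x) || N(0, 1)) in bits.

    mu, logvar: arrays of shape (batch, latent_dim).
    """
    kl_nats = 0.5 * (mu ** 2 + np.exp(logvar) - 1.0 - logvar)
    return kl_nats / np.log(2.0)


rng = np.random.default_rng(0)
batch, latent_dim = 64, 32
mu = np.zeros((batch, latent_dim))
logvar = np.zeros((batch, latent_dim))
# Pretend only dimensions 3 and 7 are used: input-dependent means, small variance.
mu[:, [3, 7]] = rng.normal(scale=2.0, size=(batch, 2))
logvar[:, [3, 7]] = -4.0
bits = bits_per_dim(mu, logvar)
print(bits.mean(axis=0).round(2))   # average bits carried by each latent dimension
```

Unused dimensions sit at exactly zero bits (posterior equals prior), while the informative ones appear as spikes.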

In [None]:
model.eval()
test_data, _ = next(iter(test_loader))  # one test batch: batch_size = 64 images
with torch.no_grad():
    _, mu, logvar, _, _ = model(test_data.to(device))
    plot_bits_heatmap(mu, logvar)

### Reconstructions
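Because $z$ is meant to be lossy, reconstructions should preserve global structure (digit identity, overall shape), while local details such as exact stroke edges may differ from the input; the decoder fills those in from its local context.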

In [None]:
plot_reconstructions(model, test_data[:8], device)

## 5. Latent Space Traversals

Does $z$ capture high-level semantics? We pick an image, encode it, and sweep one dimension of $z$ to see what changes.

In [None]:
def plot_traversal(model, image, dim_idx, device, n_steps=8):
    model.eval()
    with torch.no_grad():
        mu, logvar = model.encoder(image.to(device))
        z = mu  # posterior mean: a deterministic starting point for the sweep
        
        traversals = []
        for val in np.linspace(-3, 3, n_steps):
            z_step = z.clone()
            z_step[0, dim_idx] += val
            h_latent = z_step.view(1, -1, 1, 1).expand(-1, -1, 28, 28)
            
            # Quick preview, not a true sample: run the decoder once on an all-zero
            # canvas, so the logits reflect z with no pixel context.
            # Proper decoding is autoregressive, one pixel at a time.
            h = model.initial_conv(torch.zeros_like(image).to(device))
            for layer in model.decoder_layers:
                h = layer(h, h_latent)
            logits = model.final_conv(h)
            traversals.append(torch.sigmoid(logits).cpu().squeeze())
            
    fig, axes = plt.subplots(1, n_steps, figsize=(15, 2))
    for i, img in enumerate(traversals):
        axes[i].imshow(img, cmap='gray')
        axes[i].axis('off')
    plt.show()

print("Traversal for dimension 0:")
plot_traversal(model, test_data[0:1], dim_idx=0, device=device)

## 6. Full Sampling (Slow)

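We generate new digits by sampling $z \sim p(z)$ from the learned prior and then decoding pixel by pixel. Each of the $28 \times 28 = 784$ pixels needs its own decoder pass, which is why this step is slow.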
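Both stages are sequential. Sampling from an affine autoregressive-flow prior, $z_i = m_i(z_{<i}) + s_i(z_{<i})\,\epsilon_i$, fills in one latent dimension at a time (while evaluating $\log p(z)$ is parallel), and the PixelCNN then needs one decoder pass per pixel. The NumPy sketch below mirrors that control flow with hypothetical toy stand-ins: a linear AF prior with random strictly-lower-triangular weights and a hand-written causal "decoder" that only reads the pixel above and the pixel to the left. Neither is the notebook's `VLAE` or `plot_latent_sampling`.

```python
import numpy as np

rng = np.random.default_rng(0)
D = 8            # toy latent size
H, W = 28, 28    # image size

# Hypothetical affine AF prior. Strictly lower-triangular weights make
# m_i and log s_i depend only on z_<i.
W_m = np.tril(rng.normal(scale=0.3, size=(D, D)), k=-1)
W_s = np.tril(rng.normal(scale=0.1, size=(D, D)), k=-1)


def af_density(z):
    """Return (log p(z), eps); every dimension is computed in parallel."""
    m, log_s = W_m @ z, W_s @ z
    eps = (z - m) / np.exp(log_s)
    log_p = np.sum(-0.5 * (np.log(2 * np.pi) + eps ** 2) - log_s)
    return log_p, eps


def af_sample(eps):
    """Invert the flow one dimension at a time: z_i needs z_<i first."""
    z = np.zeros(D)
    for i in range(D):
        # W_m[i, j] = W_s[i, j] = 0 for j >= i, so only z_<i matters here.
        m_i, log_s_i = W_m[i] @ z, W_s[i] @ z
        z[i] = m_i + np.exp(log_s_i) * eps[i]
    return z


def toy_decoder(canvas, z):
    """Hypothetical causal decoder: the logit of pixel (r, c) uses the pixel
    above, the pixel to the left, and a global bias computed from z."""
    above = np.zeros_like(canvas)
    left = np.zeros_like(canvas)
    above[1:, :] = canvas[:-1, :]
    left[:, 1:] = canvas[:, :-1]
    return 2.0 * above + 2.0 * left - 2.0 + np.tanh(z.mean())


def sample_pixels(z):
    """Raster-order ancestral sampling; returns (image, number of decoder passes)."""
    canvas = np.zeros((H, W))
    n_passes = 0
    for r in range(H):
        for c in range(W):
            logits = toy_decoder(canvas, z)   # full forward pass, one pixel used
            n_passes += 1
            p = 1.0 / (1.0 + np.exp(-logits[r, c]))
            canvas[r, c] = float(rng.uniform() < p)   # Bernoulli draw
    return canvas, n_passes


eps0 = rng.standard_normal(D)
z = af_sample(eps0)                  # D sequential steps
image, n_passes = sample_pixels(z)   # H * W sequential decoder passes
print("decoder passes:", n_passes)
```

A real PixelCNN pass is far more expensive than this toy one, so the $H \cdot W$ sequential passes dominate sampling time, while the latent stage adds only $D$ cheap steps.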

In [None]:
plot_latent_sampling(model, device, latent_dim=32, n_samples=9)

### Summary

- VLAE forces latent usage by restricting the decoder's receptive field.
- An autoregressive flow prior is more expressive than a fixed Gaussian prior.
- The result is a division of labor: $z$ encodes global, semantic structure, while the decoder models local texture.