# 🎲 VAE (Variational Autoencoder): Probabilistic Latent Space

**Latent Space Tutorial - Notebook 3**

## 🎯 Objectives
- Understand the difference between an Autoencoder and a VAE
- Understand the reparameterization trick
- Train a VAE on MNIST
- Generate new samples
- Compare with a traditional Autoencoder

## 🆚 Autoencoder vs VAE

### Traditional Autoencoder
```
x → Encoder → z (fixed point) → Decoder → x'
```
- z is deterministic
- The latent space may have "holes"

### VAE (Variational Autoencoder)
```
x → Encoder → (μ, σ²) → z ~ N(μ, σ²) → Decoder → x'
```
- z is sampled from a **distribution** (it is not a fixed point)
- The encoder learns μ (mean) and σ² (variance)
- The latent space is encouraged to be continuous and smooth
- Can generate new samples!
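
The sampling step `z ~ N(μ, σ²)` is where the reparameterization trick comes in: sampling directly would block gradients from reaching μ and σ², so z is instead written as a deterministic function of μ, σ and an independent noise term ε ~ N(0, 1). The sketch below illustrates the idea. The function name `reparameterize` and the use of a log-variance input are assumptions chosen for illustration, not necessarily how `src.models.vae` implements it.

```python
import torch


def reparameterize(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    """Draw z ~ N(mu, sigma^2) as z = mu + sigma * eps, with eps ~ N(0, 1)."""
    std = torch.exp(0.5 * logvar)  # sigma = exp(log(sigma^2) / 2)
    eps = torch.randn_like(std)    # all randomness lives in eps, not in mu or sigma
    return mu + std * eps


torch.manual_seed(0)
mu = torch.zeros(4, 2, requires_grad=True)
logvar = torch.zeros(4, 2, requires_grad=True)

z = reparameterize(mu, logvar)
z.sum().backward()  # gradients flow back through the sample

print(z.shape)               # torch.Size([4, 2])
print(mu.grad is not None)   # True
```

Because ε carries the randomness, `mu` and `logvar` stay ordinary differentiable inputs, which is what lets the encoder be trained by backpropagation.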

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt

from src.models.vae import VAE, vae_loss
from src.utils.data_loader import load_mnist
from src.utils.training import train_vae
from src.utils.visualization import (
    plot_vae_results,
    visualize_latent_space,
    plot_latent_grid,
    plot_training_history
)

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Device: {DEVICE}")

## 📊 Loading the Data

In [None]:
train_loader, val_loader, test_loader = load_mnist(batch_size=128)
print(f"Data loaded: {len(train_loader.dataset)} train samples")

## 🏗️ Creating the VAE
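
The `VAE` class is imported from `src.models.vae` and its code is not shown in this notebook. As a rough guide only, a model exposing the same interface used here (`vae(x)` returning `x_recon, mu, logvar, z`, `vae.encode(x)` returning `mu, logvar`, and `vae.sample(...)`) might look like the sketch below. The class name `TinyVAE`, the ReLU activations and the sigmoid output are assumptions, not the project's actual implementation.

```python
import torch
from torch import nn


class TinyVAE(nn.Module):
    """Hypothetical stand-in for src.models.vae.VAE (not the project's code)."""

    def __init__(self, input_dim=784, latent_dim=2, hidden_dims=(512, 256)):
        super().__init__()
        self.latent_dim = latent_dim
        # Encoder trunk: input_dim -> hidden_dims[0] -> ... -> hidden_dims[-1]
        layers, prev = [], input_dim
        for h in hidden_dims:
            layers += [nn.Linear(prev, h), nn.ReLU()]
            prev = h
        self.encoder = nn.Sequential(*layers)
        self.fc_mu = nn.Linear(prev, latent_dim)      # mean head
        self.fc_logvar = nn.Linear(prev, latent_dim)  # log-variance head
        # Decoder mirrors the encoder; sigmoid keeps outputs in [0, 1] for BCE
        layers, prev = [], latent_dim
        for h in reversed(hidden_dims):
            layers += [nn.Linear(prev, h), nn.ReLU()]
            prev = h
        layers += [nn.Linear(prev, input_dim), nn.Sigmoid()]
        self.decoder = nn.Sequential(*layers)

    def encode(self, x):
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_logvar(h)

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x):
        mu, logvar = self.encode(x)
        # Reparameterized sample: z = mu + sigma * eps
        z = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)
        return self.decode(z), mu, logvar, z

    @torch.no_grad()
    def sample(self, num_samples, device="cpu"):
        # Decode draws from the prior N(0, 1); the encoder is not involved
        z = torch.randn(num_samples, self.latent_dim, device=device)
        return self.decode(z)


model = TinyVAE()
x_recon, mu, logvar, z = model(torch.rand(4, 784))
print(x_recon.shape, mu.shape, z.shape)
```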

In [None]:
vae = VAE(
    input_dim=784,
    latent_dim=2,  # 2D for visualization
    hidden_dims=[512, 256]
)

print(vae)
print(f"\nParameters: {sum(p.numel() for p in vae.parameters()):,}")

## 🎓 VAE Loss Function

$$\mathcal{L} = \underbrace{\text{BCE}(x, x')}_{\text{Reconstruction}} + \beta \cdot \underbrace{\text{KL}(q(z|x) \,\|\, p(z))}_{\text{Regularization}}$$

- **Reconstruction Loss**: how well the image is reconstructed
- **KL Divergence**: pulls the encoder's distribution q(z|x) toward the prior N(0, 1)
- **β**: weight of the KL term (β=1 for the standard VAE)
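
For a diagonal Gaussian q(z|x) = N(μ, σ²) and a standard normal prior, the KL term has the closed form $\tfrac{1}{2}\sum_j (\mu_j^2 + \sigma_j^2 - 1 - \log \sigma_j^2)$. The sketch below shows one way the two terms could be combined. The name `vae_loss_sketch` and the reduction (sum over pixels and latent dimensions, mean over the batch) are assumptions. The imported `vae_loss` may scale its terms differently, although the dictionary keys mirror the ones used in this notebook.

```python
import torch
import torch.nn.functional as F


def vae_loss_sketch(x_recon, x, mu, logvar, beta=1.0):
    """Illustrative VAE loss: per-sample BCE plus beta-weighted closed-form KL."""
    batch = x.size(0)
    # Reconstruction: BCE summed over pixels, averaged over the batch (assumed reduction)
    recon = F.binary_cross_entropy(x_recon, x, reduction="sum") / batch
    # KL(N(mu, sigma^2) || N(0, 1)), summed over latent dims, averaged over the batch
    kl = 0.5 * torch.sum(mu.pow(2) + logvar.exp() - 1 - logvar) / batch
    return {"total": recon + beta * kl, "reconstruction": recon, "kl": kl}


x = torch.rand(4, 784)                 # fake "images" in [0, 1]
x_recon = torch.full((4, 784), 0.5)    # an uninformative reconstruction
mu = torch.ones(4, 2)                  # posterior mean shifted away from the prior
logvar = torch.zeros(4, 2)             # sigma^2 = 1

out = vae_loss_sketch(x_recon, x, mu, logvar, beta=1.0)
print(out["kl"].item())  # 1.0
```

With σ² = 1, each latent dimension contributes μ²/2 = 0.5 to the KL term, so two dimensions give 1.0 per sample. Setting μ = 0 would bring the KL term to zero.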

In [None]:
# Loss demonstration
data, _ = next(iter(train_loader))
data = data.view(-1, 784).to(DEVICE)
vae = vae.to(DEVICE)

# Forward pass
x_recon, mu, logvar, z = vae(data[:4])
loss_dict = vae_loss(x_recon, data[:4], mu, logvar, beta=1.0)

print("Loss components (before training):")
print(f"  Total: {loss_dict['total'].item():.2f}")
print(f"  Reconstruction: {loss_dict['reconstruction'].item():.2f}")
print(f"  KL Divergence: {loss_dict['kl'].item():.2f}")

## 🎓 Training

In [None]:
# Train the VAE
history = train_vae(
    model=vae,
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=30,
    learning_rate=1e-3,
    beta=1.0,
    device=DEVICE,
    early_stopping_patience=7,
    verbose=True
)

# Plot the training history
plot_training_history(history)

## 📊 Full Results Visualization

In [None]:
# Full overview
plot_vae_results(vae, test_loader, device=DEVICE)

## 🌌 Exploring the Latent Manifold

One of the VAE's main advantages: because the KL term keeps the encodings close to the prior, points across the central region of the latent space tend to decode to plausible digits!
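
To make the idea concrete, the sketch below builds such a grid by hand: it lays out evenly spaced 2D latent points, decodes each one, and tiles the results into a single image. It is only an approximation of what a helper like `plot_latent_grid` might do internally. The function `decode_latent_grid` and the untrained stand-in decoder are hypothetical.

```python
import torch
from torch import nn


def decode_latent_grid(decoder, n_samples=20, latent_range=(-3.0, 3.0), side=28):
    """Decode an n_samples x n_samples grid of 2D latent points into one canvas."""
    lo, hi = latent_range
    axis = torch.linspace(lo, hi, n_samples)
    # z1 varies along rows, z0 along columns
    z1, z0 = torch.meshgrid(axis, axis, indexing="ij")
    z = torch.stack([z0.reshape(-1), z1.reshape(-1)], dim=1)  # (n*n, 2)
    with torch.no_grad():
        images = decoder(z).view(n_samples, n_samples, side, side)
    # Tile (rows, cols, H, W) into (rows*H, cols*W)
    return images.permute(0, 2, 1, 3).reshape(n_samples * side, n_samples * side)


# Untrained stand-in decoder, used only so the example runs on its own
decoder = nn.Sequential(nn.Linear(2, 784), nn.Sigmoid())
canvas = decode_latent_grid(decoder, n_samples=5)
print(canvas.shape)  # torch.Size([140, 140])
```

With a trained model, passing `vae.decode` (if the model exposes it) and showing the result with `plt.imshow(canvas, cmap="gray")` should reveal digits morphing smoothly from one class to another across the grid.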

In [None]:
# 2D grid over the latent space
plot_latent_grid(vae, n_samples=20, latent_range=(-3, 3), device=DEVICE)

## 🎨 Generating New Samples

A VAE can generate entirely new digits by decoding points sampled from the prior!

In [None]:
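# Generate samples from the prior N(0,1)
n_samples = 16
vae.eval()
samples = vae.sample(num_samples=n_samples, device=DEVICE)
# detach() so imshow can convert the tensor to NumPy even if it tracks gradients
samples = samples.detach().view(-1, 28, 28).cpu()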

fig, axes = plt.subplots(4, 4, figsize=(8, 8))
for i, ax in enumerate(axes.flat):
    ax.imshow(samples[i], cmap='gray')
    ax.axis('off')
plt.suptitle('Generated Digits (sampled from N(0,1))', fontweight='bold')
plt.tight_layout()
plt.show()

## 🔬 Comparison: VAE vs Autoencoder

In [None]:
from src.models.autoencoder import Autoencoder

# Train an autoencoder for comparison
ae = Autoencoder(input_dim=784, latent_dim=2, hidden_dims=[512, 256, 128])

from src.utils.training import train_model
history_ae = train_model(
    model=ae,
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=15,
    device=DEVICE,
    verbose=False
)

# Visualize both latent spaces
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

# Autoencoder
ae.eval()
latents_ae = []
labels_ae = []
with torch.no_grad():
    for data, labels in test_loader:
        data = data.view(-1, 784).to(DEVICE)
        _, latent = ae(data)
        latents_ae.append(latent.cpu())
        labels_ae.append(labels)
latents_ae = torch.cat(latents_ae).numpy()[:1000]
labels_ae = torch.cat(labels_ae).numpy()[:1000]

ax1.scatter(latents_ae[:, 0], latents_ae[:, 1], c=labels_ae, cmap='tab10', alpha=0.5, s=5)
ax1.set_title('Autoencoder Latent Space', fontweight='bold', fontsize=14)
ax1.set_xlabel('z[0]')
ax1.set_ylabel('z[1]')
ax1.grid(True, alpha=0.3)

# VAE
vae.eval()
latents_vae = []
labels_vae = []
with torch.no_grad():
    for data, labels in test_loader:
        data = data.view(-1, 784).to(DEVICE)
        mu, _ = vae.encode(data)
        latents_vae.append(mu.cpu())
        labels_vae.append(labels)
latents_vae = torch.cat(latents_vae).numpy()[:1000]
labels_vae = torch.cat(labels_vae).numpy()[:1000]

scatter = ax2.scatter(latents_vae[:, 0], latents_vae[:, 1], c=labels_vae, cmap='tab10', alpha=0.5, s=5)
ax2.set_title('VAE Latent Space', fontweight='bold', fontsize=14)
ax2.set_xlabel('z[0]')
ax2.set_ylabel('z[1]')
ax2.grid(True, alpha=0.3)

plt.colorbar(scatter, ax=ax2, label='Digit')
plt.tight_layout()
plt.show()

print("\n📊 Observations:")
print("- The VAE latent space is 'smoother' and more continuous")
print("- The Autoencoder may have 'holes' between clusters")
print("- The VAE pulls the latent distribution toward a Gaussian → better suited to generation")
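
A rough numerical complement to the scatter plots is to compare the per-dimension mean and standard deviation of the encoded points: a VAE's latents would be expected to sit roughly on the scale of N(0, 1), while an autoencoder's latents are unconstrained. The sketch below uses synthetic arrays as stand-ins so that it runs on its own. The helper `latent_summary` and the synthetic numbers are illustrative assumptions, not results from this notebook.

```python
import numpy as np


def latent_summary(latents: np.ndarray) -> dict:
    """Per-dimension mean and standard deviation of encoded points."""
    return {"mean": latents.mean(axis=0), "std": latents.std(axis=0)}


rng = np.random.default_rng(0)
# Synthetic stand-ins (NOT real model outputs): one cloud near N(0, 1),
# one with an arbitrary offset and scale, as an unregularized encoder might produce
vae_like = rng.normal(0.0, 1.0, size=(1000, 2))
ae_like = rng.normal(5.0, 12.0, size=(1000, 2))

for name, latents in [("vae-like", vae_like), ("ae-like", ae_like)]:
    summary = latent_summary(latents)
    print(name, np.round(summary["mean"], 2), np.round(summary["std"], 2))
```

In the notebook itself, `latent_summary(latents_vae)` and `latent_summary(latents_ae)` could be called on the arrays computed above.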

## 📝 Summary

✅ A VAE learns a probability distribution over the latent space  
✅ Loss = Reconstruction + KL Divergence  
✅ New samples can be generated easily  
✅ The latent space is more continuous than an Autoencoder's  
✅ Trade-off between reconstruction quality and regularization (controlled by β)  

---

## 🚀 Next Notebook

In **Notebook 04**, we will explore Beta-VAE and the concept of disentanglement!

→→ `04_beta_vae_experimento.ipynb`