# 🔵 Lesson 4: VAE and Latent Space

We are entering **Module 2: Architecture Deep Dive**. We start with the Variational Autoencoder (VAE).

### The Problem: Images are Big
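A 512x512 RGB image has `512 * 512 = 262,144` pixels, which is `512 * 512 * 3 (RGB) = 786,432` values.
The cost of attention grows with the square of the number of positions, so running it directly at pixel resolution would exhaust the memory of even a 4090.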
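
To get a feel for the numbers, here is a rough back-of-the-envelope sketch. It assumes one attention token per spatial position, a single attention head, and 2-byte (fp16) scores. The helper name `attention_matrix_bytes` is made up for this illustration, and real models use further tricks, so treat the result as an order-of-magnitude estimate only.

```python
def attention_matrix_bytes(height, width, bytes_per_value=2):
    """Size of one full self-attention score matrix (tokens x tokens)."""
    tokens = height * width  # assumption: one token per spatial position
    return tokens * tokens * bytes_per_value

pixel_bytes = attention_matrix_bytes(512, 512)   # attention at pixel resolution
latent_bytes = attention_matrix_bytes(64, 64)    # attention at 64x64 resolution

print(f"Pixel space:  {pixel_bytes / 1e9:.1f} GB per attention matrix")
print(f"Latent space: {latent_bytes / 1e6:.1f} MB per attention matrix")
print(f"Ratio: {pixel_bytes // latent_bytes}x")
```

Shrinking each side by 8 cuts the number of positions by 64, and the attention matrix by 64 squared.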

### The Solution: Latent Space
Images have a lot of redundancy (the sky is mostly blue, the wall is flat), so we can compress them.
The Stable Diffusion VAE shrinks each side of the image by a **factor of 8**: `512 / 8 = 64`.
The latent then has `64 * 64 * 4 (Channels) = 16,384` values.

This is **48x smaller** than the original image (`786,432 / 16,384 = 48`)! Working on this much smaller tensor is what makes the diffusion process fast.
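
The arithmetic above can be checked with a few lines of plain Python. The helper `tensor_size` is a hypothetical name used only for this sketch.

```python
def tensor_size(height, width, channels):
    """Number of values in a (channels, height, width) tensor."""
    return height * width * channels

image_values = tensor_size(512, 512, 3)             # RGB image
latent_values = tensor_size(512 // 8, 512 // 8, 4)  # 8x smaller per side, 4 channels

ratio = image_values / latent_values
kept_fraction = latent_values / image_values

print(f"Image values:  {image_values:,}")
print(f"Latent values: {latent_values:,}")
print(f"Compression:   {ratio:.0f}x (latent keeps {kept_fraction:.1%} of the values)")
```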

In [None]:
# 1. Setup
import notebook_utils
project_root, device, dtype = notebook_utils.setup_notebook()

from diffusers import AutoencoderKL
import torch
from PIL import Image
import requests
from io import BytesIO
import numpy as np

## 1. Load the VAE

In [None]:
model_id = "runwayml/stable-diffusion-v1-5"
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae").to(device)
print("VAE Loaded!")

## 2. Compress an Image (Encode)
Let's take our mountain image from Lesson 1 and squeeze it into latent space.

In [None]:
# Load image
url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
response = requests.get(url, timeout=30)
response.raise_for_status()  # Fail early if the download did not succeed
image_pil = Image.open(BytesIO(response.content)).convert("RGB").resize((512, 512))

# Preprocess for VAE (Convert to -1 to 1 range)
image_arr = np.array(image_pil).astype(np.float32) / 127.5 - 1.0
image_tensor = torch.from_numpy(image_arr).permute(2, 0, 1).unsqueeze(0).to(device)

print(f"Original Shape: {image_tensor.shape} (Batch, Channel, Height, Width)")

# ENCODE
with torch.no_grad():
    latents = vae.encode(image_tensor).latent_dist.sample()
    
# Scale factor for SD 1.5 latents (also stored as vae.config.scaling_factor)
latents = latents * 0.18215

print(f"Latent Shape:   {latents.shape}")
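
Note that `latent_dist.sample()` draws a random sample, so encoding the same image twice gives slightly different latents. As a minimal sketch of the idea (not the diffusers code itself), sampling from a diagonal Gaussian means taking the predicted mean and adding noise scaled by the predicted standard deviation. The function `sample_latent` and its arguments are hypothetical names for illustration, working on a single scalar instead of a tensor.

```python
import math
import random

def sample_latent(mean, logvar, rng):
    """Draw one value from N(mean, exp(logvar)) via mean + std * noise."""
    std = math.exp(0.5 * logvar)  # log-variance -> standard deviation
    noise = rng.gauss(0.0, 1.0)   # standard normal noise
    return mean + std * noise

rng = random.Random(0)  # fixed seed so the run is repeatable

# With a very small variance the sample stays very close to the mean.
tight = sample_latent(mean=0.5, logvar=-60.0, rng=rng)

# With a larger variance, repeated samples spread out around the mean.
loose = [sample_latent(mean=0.5, logvar=0.0, rng=rng) for _ in range(5)]

print(f"Tight sample: {tight:.6f}")
print("Loose samples:", [round(v, 3) for v in loose])
```

If you want a repeatable encoding instead, the distribution object returned by `vae.encode(...)` also exposes `latent_dist.mode()`, which returns the mean without added noise.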

## 3. Visualizing Latent Space

The latent tensor has 4 channels. We can't view it as a normal image, but we can visualize the channels separately to see what "concepts" the VAE found.

In [None]:
import matplotlib.pyplot as plt

# Convert to numpy for plotting
l_vis = latents[0].cpu().numpy()

fig, axs = plt.subplots(1, 4, figsize=(20, 5))
for i in range(4):
    axs[i].imshow(l_vis[i], cmap='viridis')
    axs[i].set_title(f"Latent Channel {i+1}")
    axs[i].axis('off')
plt.show()

## 4. Decompress (Decode)

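Now we reverse the process. Even though the latent holds only about 2% as many values as the original image (1/48), the VAE decoder can hallucinate the missing details back. The reconstruction is very close to the input, but not pixel-identical.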

In [None]:
# Unscale
latents = latents / 0.18215

# DECODE
with torch.no_grad():
    decoded_image = vae.decode(latents).sample

# Post-process (Back to 0-255)
decoded_image = (decoded_image / 2 + 0.5).clamp(0, 1)
decoded_image = decoded_image.cpu().permute(0, 2, 3, 1).numpy()
# Round before casting: astype(np.uint8) truncates, which would darken some values by 1
final_image = Image.fromarray((decoded_image[0] * 255).round().astype(np.uint8))

notebook_utils.show_image(final_image, title="Reconstructed from Latent")
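
The preprocessing before encoding (`/ 127.5 - 1.0`) and the post-processing after decoding (`/ 2 + 0.5`, then `* 255`) are inverses of each other. The sketch below checks this on plain Python numbers, without the VAE in between. The function names `to_model_range` and `to_pixel` are invented for this example.

```python
def to_model_range(pixel):
    """Map a pixel value in [0, 255] to the [-1, 1] range the VAE expects."""
    return pixel / 127.5 - 1.0

def to_pixel(value):
    """Map a value in [-1, 1] back to an integer pixel in [0, 255]."""
    unit = min(max(value / 2 + 0.5, 0.0), 1.0)  # same role as .clamp(0, 1)
    return round(unit * 255)                    # round, do not truncate

# Every possible pixel value should survive the round trip unchanged.
roundtrip_ok = all(to_pixel(to_model_range(p)) == p for p in range(256))

print("Black maps to:", to_model_range(0))
print("White maps to:", to_model_range(255))
print("Round trip exact for all 256 values:", roundtrip_ok)
```

Any difference you see between the original and the reconstructed image therefore comes from the VAE compression itself, not from the range conversion.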