# Latent Space Learning Progression

Purpose: Visualize how a single, fixed latent vector is decoded by the VAE at
different training epochs, to observe how the learned representation
stabilizes as training progresses.

This notebook uses only saved VAE checkpoints.
No classifiers or evaluation steps are involved.


In [None]:
import torch
import matplotlib.pyplot as plt
from pathlib import Path

from src.models.vae import VAE  # your existing VAE implementation

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


In [None]:
# Dataset to visualize (change only this if you want another dataset)
DATASET = "mnist"   # options: "mnist", "fashion", "emnist"

# Where VAE checkpoints are stored
CHECKPOINT_DIR = "../checkpoints/grayscale"

# Epochs whose checkpoints exist as .pt files
EPOCHS_TO_VISUALIZE = [1, 10, 50, 100]

# Latent dimension used during VAE training
LATENT_DIM = 32


In [None]:
# Fix the global RNG seed so the same latent vector is sampled on every run.
torch.manual_seed(42)

# One latent sample of shape (1, LATENT_DIM), drawn from the default generator
# and then moved to `device`; every checkpoint decodes this exact vector.
z_fixed = torch.randn(size=(1, LATENT_DIM)).to(device=device)


In [None]:
def load_vae_checkpoint(epoch, dataset=None, checkpoint_dir=None, latent_dim=None):
    """Load the VAE weights saved after ``epoch`` and return the model in eval mode.

    The checkpoint is looked up at ``<checkpoint_dir>/<dataset>_epoch_<epoch>.pt``.
    Its contents are passed straight to ``VAE.load_state_dict``, so the file
    must hold a state_dict.

    Args:
        epoch: Training epoch whose checkpoint should be loaded.
        dataset: Dataset prefix of the checkpoint file name. Defaults to the
            notebook-level ``DATASET``, read at call time.
        checkpoint_dir: Directory containing the checkpoints. Defaults to
            ``CHECKPOINT_DIR``, read at call time.
        latent_dim: Latent size used to construct the ``VAE``. Defaults to
            ``LATENT_DIM``, read at call time.

    Returns:
        A ``VAE`` moved to ``device`` with the checkpoint weights loaded and
        switched to eval mode.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
    """
    # Resolve defaults lazily (None sentinels) so edits to the config cell are
    # picked up without redefining this function.
    if dataset is None:
        dataset = DATASET
    if checkpoint_dir is None:
        checkpoint_dir = CHECKPOINT_DIR
    if latent_dim is None:
        latent_dim = LATENT_DIM

    ckpt_path = Path(checkpoint_dir) / f"{dataset}_epoch_{epoch}.pt"

    # Fail early with a message naming the exact file that was expected.
    if not ckpt_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")

    vae = VAE(latent_dim=latent_dim).to(device)
    # A state_dict only needs tensors and plain containers, so restrict
    # unpickling to those (weights_only requires torch >= 1.13).
    state_dict = torch.load(ckpt_path, map_location=device, weights_only=True)
    vae.load_state_dict(state_dict)
    vae.eval()
    return vae


In [None]:
# Decode the same latent vector with every checkpoint, in epoch order.
# This is inference only, so autograd tracking is disabled for the whole loop.
decoded_images = []

with torch.no_grad():
    for epoch in EPOCHS_TO_VISUALIZE:
        vae = load_vae_checkpoint(epoch)
        # Keep CPU copies so the plotting cells can hand them to matplotlib.
        decoded_images.append(vae.decode(z_fixed).cpu())


In [None]:
plt.figure(figsize=(12, 2))

# One panel per checkpoint, left to right, labelled with its training epoch.
# NOTE(review): assumes vae.decode returns an image-shaped tensor that
# squeezes to 2-D (e.g. (1, 1, H, W)) - confirm against the VAE implementation.
for panel, (img, epoch) in enumerate(zip(decoded_images, EPOCHS_TO_VISUALIZE), start=1):
    plt.subplot(1, len(decoded_images), panel)
    plt.imshow(img.squeeze(), cmap="gray")
    plt.title(f"Epoch {epoch}")
    plt.axis("off")

plt.suptitle(f"Latent Progression for {DATASET.upper()}", fontsize=12)
plt.show()


In [None]:
from pathlib import Path

# Persist the progression strip to disk. The figure is rebuilt here rather
# than reusing the one shown above, so this cell works on its own.
output_dir = Path("../outputs/progression")
output_dir.mkdir(parents=True, exist_ok=True)

fig = plt.figure(figsize=(12, 2))
for i, img in enumerate(decoded_images):
    ax = fig.add_subplot(1, len(decoded_images), i + 1)
    ax.imshow(img.squeeze(), cmap="gray")
    ax.set_title(f"Epoch {EPOCHS_TO_VISUALIZE[i]}")
    ax.axis("off")

# Use the same title size as the on-screen version so both outputs match.
fig.suptitle(f"Latent Progression for {DATASET.upper()}", fontsize=12)
fig.savefig(output_dir / f"{DATASET}_latent_progression.png", dpi=200)
# Close this specific figure (not whatever pyplot considers "current").
plt.close(fig)
