# VAE Training

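Purpose: Train convolutional VAEs (`ConvVAE`) on the grayscale datasets `mnist`, `fashion`, and `emnist`, one model per dataset.

Includes:
- Loss: reconstruction term + β-weighted KL divergence
- β scheduling: a 0 → 1 KL warm-up over 5000 optimizer steps for the base models, and a fixed low β (0.05) for the "sharp" models
- Model checkpointing to `checkpoints/grayscale/`

Excludes: classification and evaluation.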


In [34]:
import sys
from pathlib import Path
import torch
from torch.optim import Adam

# Walk upward from the working directory until a folder containing `src/` is found.
# That folder is the project root; it is put on sys.path so `src.*` imports resolve.
current = Path().resolve()
while not (current / "src").exists():
    # Path('/').parent is Path('/') itself, so without this guard the loop would spin
    # forever when the notebook is opened outside the project tree.
    if current.parent == current:
        raise FileNotFoundError(
            f"Could not find a project root containing 'src/' above {Path().resolve()}"
        )
    current = current.parent

# Only add the root once, so re-running this cell stays idempotent.
if str(current) not in sys.path:
    sys.path.append(str(current))
print("Project root:", current)


Project root: /workspace


In [36]:
from src.datasets.grayscale_datasets import get_grayscale_loader
from src.models.vae import ConvVAE
from src.training.losses import vae_loss


In [37]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)


Using device: cpu


In [38]:
# Hyperparameters read by the training cells below.
epochs: int = 10
learning_rate: float = 1e-3
batch_size: int = 32       # kept small so training stays feasible on CPU
latent_dim: int = 32


In [39]:
# All grayscale VAE checkpoints live under <project root>/checkpoints/grayscale;
# create the directory tree up front so the training cells can save into it.
ckpt_dir = current.joinpath("checkpoints", "grayscale")
ckpt_dir.mkdir(exist_ok=True, parents=True)
print("Checkpoint dir:", ckpt_dir)


Checkpoint dir: /workspace/checkpoints/grayscale


In [41]:
# Train one "base" VAE per dataset with a KL warm-up (beta ramps from beta_start to
# beta_end over beta_warmup_steps optimizer steps, then stays at beta_end).
datasets_to_train = ["mnist", "fashion", "emnist"]

seed = 0                    # re-seeded per dataset so each run is reproducible on its own
beta_start = 0.0
beta_end = 1.0
beta_warmup_steps = 5000    # optimizer steps over which beta ramps linearly

for ds in datasets_to_train:
    print("\n==============================")
    print(f"Training BASE VAE on {ds.upper()}")
    print("==============================")

    # Seed before building the loader/model so weight init (and, presumably, loader
    # shuffling if it uses torch's global RNG) does not depend on earlier cells.
    torch.manual_seed(seed)

    loader = get_grayscale_loader(
        dataset_name=ds,
        root=current / "data" / "raw",
        batch_size=batch_size,
    )

    model = ConvVAE(latent_dim=latent_dim).to(device)
    optimizer = Adam(model.parameters(), lr=learning_rate)
    model.train()

    global_step = 0  # optimizer steps across all epochs; drives the beta schedule
    for epoch in range(epochs):
        total_loss = 0.0
        n_batches = 0

        for x, _ in loader:
            x = x.to(device)

            # Explicit linear KL warm-up. This replaces `BetaScheduler`, which is never
            # imported in this notebook and raised NameError on a fresh kernel.
            # NOTE(review): assumed to match BetaScheduler(start, end, n_steps) as a
            # linear ramp — confirm against src/ if a project scheduler exists.
            progress = min(global_step / beta_warmup_steps, 1.0)
            beta = beta_start + (beta_end - beta_start) * progress
            global_step += 1

            optimizer.zero_grad()
            recon, mu, logvar = model(x)
            loss, _, _ = vae_loss(recon, x, mu, logvar, beta)

            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            n_batches += 1

        # total_loss is summed over batches and scales with dataset size; the per-batch
        # mean is what is comparable across datasets.
        mean_loss = total_loss / max(n_batches, 1)
        print(
            f"{ds.upper()} | "
            f"Epoch [{epoch+1}/{epochs}] | "
            f"Loss: {total_loss:.0f} | "
            f"Mean loss/batch: {mean_loss:.2f}"
        )

    # NOTE(review): the "_64" suffix does not match latent_dim (32); presumably it refers
    # to image size. Kept unchanged because other notebooks may load this exact path.
    ckpt_path = ckpt_dir / f"vae_{ds}_64.pt"
    torch.save(model.state_dict(), ckpt_path)
    print(f"Saved → {ckpt_path}")



Training BASE VAE on MNIST
MNIST | Epoch [1/10] | Loss: 3046001
MNIST | Epoch [2/10] | Loss: 3118639
MNIST | Epoch [3/10] | Loss: 3823196
MNIST | Epoch [4/10] | Loss: 3878317
MNIST | Epoch [5/10] | Loss: 3790832
MNIST | Epoch [6/10] | Loss: 3739327
MNIST | Epoch [7/10] | Loss: 3693429
MNIST | Epoch [8/10] | Loss: 3655617
MNIST | Epoch [9/10] | Loss: 3625138
MNIST | Epoch [10/10] | Loss: 3601726
Saved → /workspace/checkpoints/grayscale/vae_mnist_64.pt

Training BASE VAE on FASHION
FASHION | Epoch [1/10] | Loss: 3513797
FASHION | Epoch [2/10] | Loss: 3502546
FASHION | Epoch [3/10] | Loss: 3915385
FASHION | Epoch [4/10] | Loss: 3906579
FASHION | Epoch [5/10] | Loss: 3823171
FASHION | Epoch [6/10] | Loss: 3764531
FASHION | Epoch [7/10] | Loss: 3723475
FASHION | Epoch [8/10] | Loss: 3691465
FASHION | Epoch [9/10] | Loss: 3664722
FASHION | Epoch [10/10] | Loss: 3640680
Saved → /workspace/checkpoints/grayscale/vae_fashion_64.pt

Training BASE VAE on EMNIST
EMNIST | Epoch [1/10] | Loss: 68164

In [42]:
print("\n==============================")
print("Training SHARP VAE on ALL DATASETS")
print("==============================")

# Same datasets as the base run, but with a constant, low KL weight and no warm-up.
datasets_to_train = ["mnist", "fashion", "emnist"]

beta = 0.05          # LOW beta = sharper visuals
epochs_sharp = 12    # enough for clarity, CPU-safe
seed = 0             # re-seeded per dataset so each SHARP run is reproducible on its own

for ds in datasets_to_train:
    print(f"\n--- SHARP VAE on {ds.upper()} ---")

    # Seed before building the loader/model so weight init (and, presumably, loader
    # shuffling if it uses torch's global RNG) does not depend on earlier cells.
    torch.manual_seed(seed)

    loader = get_grayscale_loader(
        dataset_name=ds,
        root=current / "data" / "raw",
        batch_size=batch_size
    )

    model = ConvVAE(latent_dim=latent_dim).to(device)
    optimizer = Adam(model.parameters(), lr=learning_rate)
    model.train()

    for epoch in range(epochs_sharp):
        total_loss = 0.0
        total_recon = 0.0
        total_kl = 0.0
        n_batches = 0

        for x, _ in loader:
            x = x.to(device)

            optimizer.zero_grad()
            recon, mu, logvar = model(x)

            loss, recon_loss, kl_loss = vae_loss(
                recon, x, mu, logvar, beta
            )

            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            total_recon += recon_loss.item()
            total_kl += kl_loss.item()
            n_batches += 1

        # Totals are summed over batches, so they scale with dataset size; the per-batch
        # mean is comparable across datasets. The logged totals satisfy
        # Loss ≈ Recon + beta * KL, i.e. KL is reported unweighted.
        mean_loss = total_loss / max(n_batches, 1)
        print(
            f"{ds.upper()} | "
            f"Epoch [{epoch+1}/{epochs_sharp}] | "
            f"Loss: {total_loss:.0f} | "
            f"Recon: {total_recon:.0f} | "
            f"KL: {total_kl:.0f} | "
            f"Mean loss/batch: {mean_loss:.2f}"
        )

    ckpt_path = ckpt_dir / f"vae_{ds}_sharp_64.pt"
    torch.save(model.state_dict(), ckpt_path)
    print(f"Saved SHARP → {ckpt_path}")



Training SHARP VAE on ALL DATASETS

--- SHARP VAE on MNIST ---
MNIST | Epoch [1/12] | Loss: 2759755 | Recon: 2496538 | KL: 5264341
MNIST | Epoch [2/12] | Loss: 1215020 | Recon: 941084 | KL: 5478731
MNIST | Epoch [3/12] | Loss: 1113698 | Recon: 840369 | KL: 5466592
MNIST | Epoch [4/12] | Loss: 1059621 | Recon: 787309 | KL: 5446244
MNIST | Epoch [5/12] | Loss: 1021342 | Recon: 750243 | KL: 5421996
MNIST | Epoch [6/12] | Loss: 992118 | Recon: 722446 | KL: 5393444
MNIST | Epoch [7/12] | Loss: 968967 | Recon: 700161 | KL: 5376114
MNIST | Epoch [8/12] | Loss: 951817 | Recon: 683818 | KL: 5359982
MNIST | Epoch [9/12] | Loss: 937565 | Recon: 670294 | KL: 5345433
MNIST | Epoch [10/12] | Loss: 924583 | Recon: 657818 | KL: 5335300
MNIST | Epoch [11/12] | Loss: 915095 | Recon: 648927 | KL: 5323354
MNIST | Epoch [12/12] | Loss: 905336 | Recon: 639623 | KL: 5314251
Saved SHARP → /workspace/checkpoints/grayscale/vae_mnist_sharp_64.pt

--- SHARP VAE on FASHION ---
FASHION | Epoch [1/12] | Loss: 30209