In [1]:
import sys
from pathlib import Path

current = Path().resolve()
while not (current / "src").exists():
    current = current.parent

sys.path.append(str(current))


In [2]:
import torch
from torch.optim import Adam
from src.models.vae import ConvVAE
from src.training.losses import vae_loss
from src.training.scheduler import BetaScheduler


In [8]:
from src.datasets.grayscale_datasets import get_grayscale_loader

datasets_to_train = ["mnist", "fashion", "emnist"]

for ds in datasets_to_train:
    print(f"\nTraining on {ds.upper()}")

    loader = get_grayscale_loader(
        dataset_name=ds,
        root=current / "data" / "raw",
        batch_size=64
    )

    model = ConvVAE(latent_dim=32).to(device)
    optimizer = Adam(model.parameters(), lr=1e-3)
    beta_scheduler = BetaScheduler(0.0, 1.0, 5000)

    epochs = 5
    model.train()

    for epoch in range(epochs):
        total_loss = 0
        for x, _ in loader:
            x = x.to(device)
            optimizer.zero_grad()
            recon, mu, logvar = model(x)
            beta = beta_scheduler.step()
            loss, _, _ = vae_loss(recon, x, mu, logvar, beta)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        print(f"{ds} | Epoch {epoch+1} | Loss {total_loss:.0f}")

    # save model
    ckpt_dir = current / "checkpoints" / "grayscale"
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), ckpt_dir / f"vae_{ds}.pt")



Training on MNIST
mnist | Epoch 1 | Loss 3844250
mnist | Epoch 2 | Loss 2315004
mnist | Epoch 3 | Loss 2784911
mnist | Epoch 4 | Loss 3186971
mnist | Epoch 5 | Loss 3550053

Training on FASHION
fashion | Epoch 1 | Loss 3730620
fashion | Epoch 2 | Loss 2940892
fashion | Epoch 3 | Loss 3212252
fashion | Epoch 4 | Loss 3457891
fashion | Epoch 5 | Loss 3674759

Training on EMNIST
emnist | Epoch 1 | Loss 6481904
emnist | Epoch 2 | Loss 6142954
emnist | Epoch 3 | Loss 7715624
emnist | Epoch 4 | Loss 8091581
emnist | Epoch 5 | Loss 7957298


In [2]:
# ============================================================
# FashionMNIST â€” Sharper VAE (Demo-biased training)
# Purpose:
# - Improve visual quality for demo
# - Keep same latent space + traversal
# - Used for Streamlit comparison
# ============================================================

# ---- FIX: make sure `src/` is importable in Jupyter ----
import sys
from pathlib import Path

current = Path().resolve()
while not (current / "src").exists():
    current = current.parent

sys.path.append(str(current))

# -------------------------------------------------------

from src.datasets.grayscale_datasets import get_grayscale_loader
from src.models.vae import ConvVAE
from src.training.losses import vae_loss
from torch.optim import Adam
import torch

print("\nTraining SHARP VAE on FASHIONMNIST")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Load ONLY FashionMNIST
loader = get_grayscale_loader(
    dataset_name="fashion",
    root=current / "data" / "raw",
    batch_size=64
)

# Model + optimizer
model = ConvVAE(latent_dim=32).to(device)
optimizer = Adam(model.parameters(), lr=1e-3)

# Fixed low beta for sharper images
beta = 0.05   # <-- key change

epochs = 15   # reasonable time, visible improvement
model.train()

for epoch in range(epochs):
    total_loss = 0.0
    total_recon = 0.0
    total_kl = 0.0

    for x, _ in loader:
        x = x.to(device)

        optimizer.zero_grad()

        recon, mu, logvar = model(x)
        loss, recon_loss, kl = vae_loss(
            recon, x, mu, logvar, beta
        )

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_recon += recon_loss.item()
        total_kl += kl.item()

    print(
        f"Epoch [{epoch+1}/{epochs}] | "
        f"Loss: {total_loss:.0f} | "
        f"Recon: {total_recon:.0f} | "
        f"KL: {total_kl:.0f} | "
        f"Beta: {beta}"
    )

# Save as NEW checkpoint (do not overwrite vanilla VAE)
ckpt_dir = current / "checkpoints" / "grayscale"
ckpt_dir.mkdir(parents=True, exist_ok=True)

torch.save(
    model.state_dict(),
    ckpt_dir / "vae_fashion_sharp.pt"
)

print("Saved: checkpoints/grayscale/vae_fashion_sharp.pt")


  warn(



Training SHARP VAE on FASHIONMNIST
Using device: cpu
Epoch [1/15] | Loss: 3420405 | Recon: 3175921 | KL: 4889688 | Beta: 0.05
Epoch [2/15] | Loss: 1978993 | Recon: 1728571 | KL: 5008446 | Beta: 0.05
Epoch [3/15] | Loss: 1807790 | Recon: 1557618 | KL: 5003438 | Beta: 0.05
Epoch [4/15] | Loss: 1723000 | Recon: 1473196 | KL: 4996096 | Beta: 0.05
Epoch [5/15] | Loss: 1667348 | Recon: 1418137 | KL: 4984232 | Beta: 0.05
Epoch [6/15] | Loss: 1629582 | Recon: 1381422 | KL: 4963193 | Beta: 0.05
Epoch [7/15] | Loss: 1599350 | Recon: 1351673 | KL: 4953536 | Beta: 0.05
Epoch [8/15] | Loss: 1575049 | Recon: 1328196 | KL: 4937065 | Beta: 0.05
Epoch [9/15] | Loss: 1556721 | Recon: 1310554 | KL: 4923341 | Beta: 0.05
Epoch [10/15] | Loss: 1539980 | Recon: 1294222 | KL: 4915160 | Beta: 0.05
Epoch [11/15] | Loss: 1526421 | Recon: 1281038 | KL: 4907648 | Beta: 0.05
Epoch [12/15] | Loss: 1514304 | Recon: 1269125 | KL: 4903598 | Beta: 0.05
Epoch [13/15] | Loss: 1502201 | Recon: 1257314 | KL: 4897739 | Beta