In [1]:
import sys
from pathlib import Path

current = Path().resolve()
while not (current / "src").exists():
    current = current.parent
sys.path.append(str(current))

import torch
from torch.optim import Adam
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

from src.models.vae import ConvVAE
from src.models.discriminator import Discriminator
from src.training.losses import vae_loss


  warn(


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

dataset = datasets.FashionMNIST(
    root=current / "data" / "raw",
    train=True,
    download=True,
    transform=transform
)

loader = DataLoader(dataset, batch_size=64, shuffle=True)


Device: cpu


In [3]:
vae = ConvVAE(latent_dim=32).to(device)
disc = Discriminator().to(device)

opt_vae = Adam(vae.parameters(), lr=1e-3)
opt_disc = Adam(disc.parameters(), lr=1e-4)

bce = torch.nn.BCELoss()

beta = 0.05   # keep latent structure
lambda_gan = 0.1  # GAN influence (important)


In [4]:
epochs = 10   # start small; increase if stable

for epoch in range(epochs):
    vae.train()
    disc.train()

    total_vae, total_gan = 0, 0

    for x, _ in loader:
        x = x.to(device)
        batch_size = x.size(0)

        # --------------------
        # Train Discriminator
        # --------------------
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)

        with torch.no_grad():
            recon, _, _ = vae(x)

        real_pred = disc(x)
        fake_pred = disc(recon.detach())

        d_loss = bce(real_pred, real_labels) + bce(fake_pred, fake_labels)

        opt_disc.zero_grad()
        d_loss.backward()
        opt_disc.step()

        # --------------------
        # Train VAE (Generator)
        # --------------------
        recon, mu, logvar = vae(x)

        vae_recon_loss, _, _ = vae_loss(
            recon, x, mu, logvar, beta
        )

        gan_loss = bce(disc(recon), real_labels)

        total_loss = vae_recon_loss + lambda_gan * gan_loss

        opt_vae.zero_grad()
        total_loss.backward()
        opt_vae.step()

        total_vae += vae_recon_loss.item()
        total_gan += gan_loss.item()

    print(
        f"Epoch [{epoch+1}/{epochs}] | "
        f"VAE Loss: {total_vae:.0f} | "
        f"GAN Loss: {total_gan:.3f}"
    )


Epoch [1/10] | VAE Loss: 3538066 | GAN Loss: 3480.733
Epoch [2/10] | VAE Loss: 2008665 | GAN Loss: 5491.484
Epoch [3/10] | VAE Loss: 1818541 | GAN Loss: 6461.207
Epoch [4/10] | VAE Loss: 1728517 | GAN Loss: 6869.337
Epoch [5/10] | VAE Loss: 1675413 | GAN Loss: 7642.370
Epoch [6/10] | VAE Loss: 1634277 | GAN Loss: 8052.320
Epoch [7/10] | VAE Loss: 1602920 | GAN Loss: 7859.265
Epoch [8/10] | VAE Loss: 1576469 | GAN Loss: 7527.851
Epoch [9/10] | VAE Loss: 1557347 | GAN Loss: 8252.489
Epoch [10/10] | VAE Loss: 1539527 | GAN Loss: 8822.485


In [5]:
ckpt_dir = current / "checkpoints" / "grayscale"
ckpt_dir.mkdir(parents=True, exist_ok=True)

torch.save(
    vae.state_dict(),
    ckpt_dir / "vae_fashion_gan.pt"
)

print("Saved: vae_fashion_gan.pt")


Saved: vae_fashion_gan.pt
