In [1]:
# Cell 1: Imports
import os
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, random_split

from gan_utils.io import H5ImageDataset
from gan.model import Generator, Discriminator, Encoder, DiscriminatorFeatures
from gan.trainer import Trainer

import matplotlib.pyplot as plt




In [2]:
# SECTION 2 — Utilities

def set_seed(seed: int = 42):
    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def ensure_dir(path: str):
    if path and not os.path.exists(path):
        os.makedirs(path, exist_ok=True)


def make_dataloaders(df, h5_path, batch_size=32, workers=16, val_split=0.1):
    """
    df: pandas DataFrame with file_name column that matches keys in HDF5
    h5_path: path to HDF5 file containing images
    returns: train_loader, val_loader
    """
    dataset = H5ImageDataset(df, h5_path, transforms=None)
    n_total = len(dataset)
    n_val = int(val_split * n_total)
    n_train = n_total - n_val

    g = torch.Generator().manual_seed(123)
    train_ds, val_ds = random_split(dataset, [n_train, n_val], generator=g)

    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=workers,
        pin_memory=True,
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=workers,
        pin_memory=True,
    )
    return train_loader, val_loader


def build_models(n_z=256, image_size=512, in_channels=3):
    """
    Construct Generator and Discriminator.
    """
    gen = Generator(
        n_z=n_z,
        input_filt=512,
        final_size=image_size,
        out_channels=in_channels,
        norm=False,
        pool=False
    )

    disc = Discriminator(
        in_channels=in_channels,
        n_layers=6,
        input_size=image_size,
        norm=False,
        pool=False
    )

    return gen, disc


def train_gan(
    gen,
    disc,
    train_loader,
    val_loader,
    outdir,
    epochs=50,
    d_lr=1e-4,
    g_lr=1e-4,
    save_freq=10,
    lr_decay=None,
    decay_freq=5,
    device='cuda'
):
    """Train WGAN-GP using Trainer and save checkpoints + loss metrics."""
    ensure_dir(outdir)

    gen = gen.to(device)
    disc = disc.to(device)

    trainer = Trainer(
        generator=gen,
        discriminator=disc,
        savefolder=outdir,
        device=device
    )

    G_loss_ep, D_loss_ep = trainer.train(
        train_data=train_loader,
        val_data=val_loader,
        epochs=epochs,
        dsc_learning_rate=d_lr,
        gen_learning_rate=g_lr,
        save_freq=save_freq,
        lr_decay=lr_decay,
        decay_freq=decay_freq
    )

    return G_loss_ep, D_loss_ep


In [3]:
# SECTION 3 — Config

h5_path = "anomaly.h5"
parquet_path = "mask_labels_anomaly.gzip"
outdir = "./outputs"
gan_ckpt_dir = os.path.join(outdir, "gan_checkpoints")

batch_size = 32
workers = 16
val_split = 0.1
image_size = 512
n_z = 256
gan_epochs = 50
g_lr = 1e-4
d_lr = 1e-4

device = "cuda" if torch.cuda.is_available() else "cpu"
seed = 42

set_seed(seed)
ensure_dir(outdir)
ensure_dir(gan_ckpt_dir)


In [4]:
# SECTION 4 — Load metadata + dataloaders

df = pd.read_parquet(parquet_path)

train_loader, val_loader = make_dataloaders(
    df,
    h5_path,
    batch_size=batch_size,
    workers=workers,
    val_split=val_split,
)


In [5]:
# SECTION 5 — Build models

# === Rebuild models ===
gen, disc = build_models(n_z=n_z, image_size=image_size, in_channels=3)
gen = gen.to(device)
disc = disc.to(device)

# Quick sanity check
imgs = next(iter(train_loader))[0]
print("Real batch shape:", imgs.shape)

z = torch.randn(4, gen.n_z).to(device)
fake = gen(z)
print("Fake batch shape:", fake.shape)





Real batch shape: torch.Size([3, 512, 512])
Fake batch shape: torch.Size([4, 3, 512, 512])


In [6]:
# === Find latest checkpoints ===
import glob, os, torch

ckpts_g = sorted(glob.glob(os.path.join(gan_ckpt_dir, "generator_ep_*.pth")))
ckpts_d = sorted(glob.glob(os.path.join(gan_ckpt_dir, "discriminator_ep_*.pth")))

latest_g = ckpts_g[-1]
latest_d = ckpts_d[-1]

# Extract epoch number from filename
last_epoch = int(latest_g.split("_ep_")[1].split(".")[0])
print("Last completed epoch:", last_epoch)

# === Load weights ===
gen.load_state_dict(torch.load(latest_g, map_location=device))
disc.load_state_dict(torch.load(latest_d, map_location=device))

print("Loaded generator + discriminator.")


Last completed epoch: 50


  gen.load_state_dict(torch.load(latest_g, map_location=device))


Loaded generator + discriminator.


  disc.load_state_dict(torch.load(latest_d, map_location=device))


In [7]:
# === Create Trainer ===
trainer = Trainer(
    generator=gen,
    discriminator=disc,
    savefolder=gan_ckpt_dir,
    device=device
)

# Tell Trainer where to resume
trainer.start = last_epoch + 1

print(f"Resuming from epoch {trainer.start}")

# === Resume training ===
G_loss_ep, D_loss_ep = trainer.train(
    train_data=train_loader,
    val_data=val_loader,
    epochs=last_epoch + 50,   # train 50 more epochs
    dsc_learning_rate=d_lr,
    gen_learning_rate=g_lr,
    save_freq=10
)

print("Finished resumed training.")


  self.scaler_g = GradScaler()
  self.scaler_d = GradScaler()


Resuming from epoch 51
Epoch 51 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


  with autocast():
  with autocast():
Training: 100%|██████████| 2813/2813 [10:29<00:00,  4.47it/s, gen: -1.46e+02 disc: -3.22e+00 w_dist: -5.57e+00 grad_penalty: 2.35e+00]
Validation: 100%|██████████| 313/313 [00:33<00:00,  9.38it/s, gen: 1.84e+03 disc: 4.72e+00 w_dist: 4.65e+00 grad_penalty: 6.73e-02]


Epoch 52 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:28<00:00,  4.47it/s, gen: -4.86e+01 disc: -3.50e+00 w_dist: -6.16e+00 grad_penalty: 2.66e+00]
Validation: 100%|██████████| 313/313 [00:40<00:00,  7.71it/s, gen: 1.15e+03 disc: -1.06e+00 w_dist: -1.08e+00 grad_penalty: 2.31e-02]


Epoch 53 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:43<00:00,  4.37it/s, gen: -2.93e+02 disc: -3.58e+00 w_dist: -6.64e+00 grad_penalty: 3.06e+00]
Validation: 100%|██████████| 313/313 [00:33<00:00,  9.47it/s, gen: -1.86e+03 disc: -1.82e+01 w_dist: -1.95e+01 grad_penalty: 1.30e+00]


Epoch 54 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:21<00:00,  4.52it/s, gen: 2.32e+02 disc: -3.29e+00 w_dist: -5.99e+00 grad_penalty: 2.70e+00] 
Validation: 100%|██████████| 313/313 [00:33<00:00,  9.34it/s, gen: -1.35e+03 disc: -4.21e+00 w_dist: -6.90e+00 grad_penalty: 2.69e+00]


Epoch 55 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:23<00:00,  4.51it/s, gen: 1.22e+02 disc: -3.49e+00 w_dist: -6.29e+00 grad_penalty: 2.80e+00]
Validation: 100%|██████████| 313/313 [00:33<00:00,  9.28it/s, gen: -6.25e+03 disc: 2.84e+01 w_dist: -5.53e+01 grad_penalty: 8.37e+01]


Epoch 56 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:30<00:00,  4.46it/s, gen: 6.77e+00 disc: -3.17e+00 w_dist: -6.12e+00 grad_penalty: 2.94e+00] 
Validation: 100%|██████████| 313/313 [00:33<00:00,  9.25it/s, gen: -1.36e+01 disc: -9.85e-01 w_dist: -2.75e+00 grad_penalty: 1.76e+00]


Epoch 57 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:25<00:00,  4.50it/s, gen: -1.56e+02 disc: -3.14e+00 w_dist: -5.89e+00 grad_penalty: 2.76e+00]
Validation: 100%|██████████| 313/313 [00:33<00:00,  9.40it/s, gen: -2.24e+03 disc: -2.33e+01 w_dist: -2.46e+01 grad_penalty: 1.31e+00]


Epoch 58 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:27<00:00,  4.48it/s, gen: -2.28e+02 disc: -3.37e+00 w_dist: -6.23e+00 grad_penalty: 2.85e+00]
Validation: 100%|██████████| 313/313 [00:33<00:00,  9.29it/s, gen: -3.12e+03 disc: -3.36e+00 w_dist: -6.84e+00 grad_penalty: 3.47e+00]


Epoch 59 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:27<00:00,  4.49it/s, gen: 1.52e+02 disc: -3.19e+00 w_dist: -6.25e+00 grad_penalty: 3.06e+00] 
Validation: 100%|██████████| 313/313 [00:33<00:00,  9.32it/s, gen: -1.74e+03 disc: 1.52e+00 w_dist: 1.49e+00 grad_penalty: 3.04e-02] 


Epoch 60 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:24<00:00,  4.51it/s, gen: -6.33e+01 disc: -2.92e+00 w_dist: -5.77e+00 grad_penalty: 2.85e+00]
Validation: 100%|██████████| 313/313 [00:33<00:00,  9.24it/s, gen: -2.23e+03 disc: 9.65e+00 w_dist: 9.53e+00 grad_penalty: 1.21e-01]


Saving to ./outputs/gan_checkpoints//generator_ep_060.pth and ./outputs/gan_checkpoints//discriminator_ep_060.pth
Epoch 61 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:26<00:00,  4.49it/s, gen: 4.42e+00 disc: -3.48e+00 w_dist: -6.54e+00 grad_penalty: 3.05e+00] 
Validation: 100%|██████████| 313/313 [00:33<00:00,  9.38it/s, gen: 2.70e+03 disc: -4.04e+00 w_dist: -4.12e+00 grad_penalty: 8.38e-02]


Epoch 62 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:32<00:00,  4.45it/s, gen: -1.55e+02 disc: -2.94e+00 w_dist: -5.74e+00 grad_penalty: 2.80e+00]
Validation: 100%|██████████| 313/313 [00:33<00:00,  9.26it/s, gen: -8.27e+02 disc: 4.05e+00 w_dist: -1.72e+01 grad_penalty: 2.12e+01]


Epoch 63 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:27<00:00,  4.49it/s, gen: 1.50e+02 disc: -3.32e+00 w_dist: -6.08e+00 grad_penalty: 2.77e+00]
Validation: 100%|██████████| 313/313 [00:33<00:00,  9.26it/s, gen: -2.50e+03 disc: 7.31e+00 w_dist: 7.18e+00 grad_penalty: 1.33e-01]


Epoch 64 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:27<00:00,  4.48it/s, gen: 1.06e+02 disc: -3.55e+00 w_dist: -6.47e+00 grad_penalty: 2.92e+00]
Validation: 100%|██████████| 313/313 [00:32<00:00,  9.49it/s, gen: 2.32e+03 disc: -3.47e+00 w_dist: -4.01e+00 grad_penalty: 5.39e-01]


Epoch 65 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:31<00:00,  4.45it/s, gen: 2.73e+02 disc: -3.37e+00 w_dist: -6.18e+00 grad_penalty: 2.80e+00]
Validation: 100%|██████████| 313/313 [00:33<00:00,  9.36it/s, gen: 4.80e+02 disc: 2.29e+00 w_dist: 6.03e-01 grad_penalty: 1.69e+00]


Epoch 66 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:28<00:00,  4.48it/s, gen: 9.35e+01 disc: -3.14e+00 w_dist: -6.00e+00 grad_penalty: 2.86e+00] 
Validation: 100%|██████████| 313/313 [00:33<00:00,  9.35it/s, gen: -3.15e+02 disc: -9.37e+00 w_dist: -9.38e+00 grad_penalty: 1.01e-02]


Epoch 67 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:31<00:00,  4.45it/s, gen: 2.58e+02 disc: -3.39e+00 w_dist: -6.16e+00 grad_penalty: 2.77e+00]
Validation: 100%|██████████| 313/313 [00:34<00:00,  9.10it/s, gen: -4.38e+03 disc: 1.91e+01 w_dist: 1.91e+01 grad_penalty: 4.46e-02]


Epoch 68 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:32<00:00,  4.45it/s, gen: -2.66e+01 disc: -3.07e+00 w_dist: -5.68e+00 grad_penalty: 2.61e+00]
Validation: 100%|██████████| 313/313 [00:33<00:00,  9.36it/s, gen: 1.31e+02 disc: -1.10e+01 w_dist: -1.18e+01 grad_penalty: 7.64e-01]


Epoch 69 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:29<00:00,  4.47it/s, gen: 1.32e+02 disc: -2.75e+00 w_dist: -5.52e+00 grad_penalty: 2.77e+00] 
Validation: 100%|██████████| 313/313 [00:33<00:00,  9.29it/s, gen: -2.69e+03 disc: 7.46e+00 w_dist: 7.06e+00 grad_penalty: 3.96e-01]


Epoch 70 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:27<00:00,  4.48it/s, gen: 1.35e+02 disc: -3.12e+00 w_dist: -5.58e+00 grad_penalty: 2.47e+00] 
Validation: 100%|██████████| 313/313 [00:33<00:00,  9.26it/s, gen: 1.74e+03 disc: 4.30e+00 w_dist: 2.18e+00 grad_penalty: 2.12e+00]


Saving to ./outputs/gan_checkpoints//generator_ep_070.pth and ./outputs/gan_checkpoints//discriminator_ep_070.pth
Epoch 71 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:27<00:00,  4.48it/s, gen: 5.23e+01 disc: -3.21e+00 w_dist: -5.77e+00 grad_penalty: 2.56e+00]
Validation: 100%|██████████| 313/313 [00:33<00:00,  9.29it/s, gen: 2.86e+03 disc: -1.87e+01 w_dist: -2.19e+01 grad_penalty: 3.12e+00]


Epoch 72 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:27<00:00,  4.48it/s, gen: -2.70e+02 disc: -3.28e+00 w_dist: -5.87e+00 grad_penalty: 2.60e+00]
Validation: 100%|██████████| 313/313 [00:33<00:00,  9.43it/s, gen: 4.66e+03 disc: -2.54e+01 w_dist: -3.16e+01 grad_penalty: 6.24e+00]


Epoch 73 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:28<00:00,  4.48it/s, gen: 4.49e+02 disc: -3.06e+00 w_dist: -5.91e+00 grad_penalty: 2.85e+00]
Validation: 100%|██████████| 313/313 [00:33<00:00,  9.46it/s, gen: -2.03e+03 disc: 1.60e+01 w_dist: 1.43e+01 grad_penalty: 1.70e+00]


Epoch 74 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:27<00:00,  4.48it/s, gen: 1.80e+02 disc: -3.35e+00 w_dist: -6.33e+00 grad_penalty: 2.98e+00]
Validation: 100%|██████████| 313/313 [00:33<00:00,  9.37it/s, gen: -3.41e+03 disc: -7.22e+00 w_dist: -8.02e+00 grad_penalty: 7.96e-01]


Epoch 75 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:24<00:00,  4.51it/s, gen: -2.02e+01 disc: -3.26e+00 w_dist: -5.76e+00 grad_penalty: 2.49e+00]
Validation: 100%|██████████| 313/313 [00:33<00:00,  9.36it/s, gen: 1.63e+03 disc: -1.83e+01 w_dist: -1.88e+01 grad_penalty: 5.03e-01]


Epoch 76 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:28<00:00,  4.48it/s, gen: -3.03e+01 disc: -3.34e+00 w_dist: -5.83e+00 grad_penalty: 2.49e+00]
Validation: 100%|██████████| 313/313 [00:33<00:00,  9.37it/s, gen: 1.88e+02 disc: 5.41e-01 w_dist: -4.39e-01 grad_penalty: 9.80e-01] 


Epoch 77 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:31<00:00,  4.46it/s, gen: 1.90e+02 disc: -3.29e+00 w_dist: -5.94e+00 grad_penalty: 2.65e+00]
Validation: 100%|██████████| 313/313 [00:33<00:00,  9.23it/s, gen: 1.21e+03 disc: 9.60e+00 w_dist: 9.07e+00 grad_penalty: 5.25e-01]


Epoch 78 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:33<00:00,  4.44it/s, gen: 7.03e+01 disc: -2.84e+00 w_dist: -5.63e+00 grad_penalty: 2.78e+00]
Validation: 100%|██████████| 313/313 [00:34<00:00,  9.07it/s, gen: -6.94e+02 disc: 9.33e+00 w_dist: 9.30e+00 grad_penalty: 2.59e-02]


Epoch 79 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:30<00:00,  4.46it/s, gen: -7.20e+01 disc: -2.89e+00 w_dist: -5.45e+00 grad_penalty: 2.56e+00]
Validation: 100%|██████████| 313/313 [00:33<00:00,  9.32it/s, gen: -4.54e+03 disc: -2.96e+01 w_dist: -3.46e+01 grad_penalty: 5.02e+00]


Epoch 80 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:36<00:00,  4.42it/s, gen: -1.34e+02 disc: -2.94e+00 w_dist: -5.27e+00 grad_penalty: 2.32e+00]
Validation: 100%|██████████| 313/313 [00:33<00:00,  9.32it/s, gen: 3.47e+03 disc: 2.96e+00 w_dist: 2.95e+00 grad_penalty: 9.45e-03]


Saving to ./outputs/gan_checkpoints//generator_ep_080.pth and ./outputs/gan_checkpoints//discriminator_ep_080.pth
Epoch 81 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:29<00:00,  4.47it/s, gen: -7.32e+01 disc: -3.05e+00 w_dist: -5.72e+00 grad_penalty: 2.67e+00]
Validation: 100%|██████████| 313/313 [00:34<00:00,  9.13it/s, gen: 3.46e+03 disc: 2.85e+00 w_dist: 2.75e+00 grad_penalty: 9.95e-02]


Epoch 82 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:30<00:00,  4.46it/s, gen: -1.84e+02 disc: -3.07e+00 w_dist: -5.91e+00 grad_penalty: 2.84e+00]
Validation: 100%|██████████| 313/313 [00:32<00:00,  9.51it/s, gen: 1.77e+03 disc: 5.38e+00 w_dist: 4.42e+00 grad_penalty: 9.58e-01]


Epoch 83 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:28<00:00,  4.48it/s, gen: -6.14e+01 disc: -2.82e+00 w_dist: -5.24e+00 grad_penalty: 2.42e+00]
Validation: 100%|██████████| 313/313 [00:33<00:00,  9.37it/s, gen: -1.82e+03 disc: -1.33e+01 w_dist: -1.34e+01 grad_penalty: 8.60e-02]


Epoch 84 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:26<00:00,  4.49it/s, gen: 5.52e+01 disc: -3.20e+00 w_dist: -5.90e+00 grad_penalty: 2.70e+00] 
Validation: 100%|██████████| 313/313 [00:33<00:00,  9.21it/s, gen: 2.44e+03 disc: 1.06e+01 w_dist: 1.05e+01 grad_penalty: 1.28e-01]


Epoch 85 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:26<00:00,  4.49it/s, gen: 5.98e+01 disc: -2.81e+00 w_dist: -5.42e+00 grad_penalty: 2.61e+00] 
Validation: 100%|██████████| 313/313 [00:32<00:00,  9.53it/s, gen: 1.24e+03 disc: -2.47e+00 w_dist: -1.21e+01 grad_penalty: 9.67e+00]


Epoch 86 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:27<00:00,  4.48it/s, gen: -2.23e+01 disc: -2.67e+00 w_dist: -5.14e+00 grad_penalty: 2.47e+00]
Validation: 100%|██████████| 313/313 [00:33<00:00,  9.32it/s, gen: -3.09e+03 disc: 4.78e+00 w_dist: 3.96e+00 grad_penalty: 8.14e-01]


Epoch 87 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:29<00:00,  4.47it/s, gen: -3.13e+01 disc: -3.00e+00 w_dist: -5.46e+00 grad_penalty: 2.46e+00]
Validation: 100%|██████████| 313/313 [00:33<00:00,  9.44it/s, gen: 7.59e+01 disc: -1.35e+01 w_dist: -1.41e+01 grad_penalty: 5.57e-01]


Epoch 88 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:30<00:00,  4.46it/s, gen: 7.20e+01 disc: -2.74e+00 w_dist: -5.23e+00 grad_penalty: 2.49e+00] 
Validation: 100%|██████████| 313/313 [00:33<00:00,  9.25it/s, gen: -3.20e+03 disc: -9.60e+00 w_dist: -9.99e+00 grad_penalty: 3.92e-01]


Epoch 89 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:28<00:00,  4.48it/s, gen: 9.68e+01 disc: -2.68e+00 w_dist: -5.20e+00 grad_penalty: 2.52e+00] 
Validation: 100%|██████████| 313/313 [00:33<00:00,  9.46it/s, gen: 1.54e+03 disc: -2.19e+00 w_dist: -3.78e+00 grad_penalty: 1.59e+00]


Epoch 90 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:27<00:00,  4.48it/s, gen: -1.17e+02 disc: -2.78e+00 w_dist: -5.20e+00 grad_penalty: 2.42e+00]
Validation: 100%|██████████| 313/313 [00:33<00:00,  9.23it/s, gen: 2.77e+03 disc: 1.75e+00 w_dist: 1.72e+00 grad_penalty: 3.09e-02] 


Saving to ./outputs/gan_checkpoints//generator_ep_090.pth and ./outputs/gan_checkpoints//discriminator_ep_090.pth
Epoch 91 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:30<00:00,  4.46it/s, gen: 1.81e+02 disc: -2.57e+00 w_dist: -4.79e+00 grad_penalty: 2.22e+00]
Validation: 100%|██████████| 313/313 [00:33<00:00,  9.32it/s, gen: 2.38e+03 disc: -2.59e+00 w_dist: -2.64e+00 grad_penalty: 5.28e-02]


Epoch 92 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:27<00:00,  4.49it/s, gen: 1.23e+02 disc: -2.35e+00 w_dist: -4.72e+00 grad_penalty: 2.37e+00]
Validation: 100%|██████████| 313/313 [00:33<00:00,  9.33it/s, gen: 7.28e+02 disc: 3.78e+00 w_dist: 2.77e+00 grad_penalty: 1.02e+00]


Epoch 93 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:32<00:00,  4.45it/s, gen: -4.78e+01 disc: -3.10e+00 w_dist: -5.63e+00 grad_penalty: 2.53e+00]
Validation: 100%|██████████| 313/313 [00:34<00:00,  9.18it/s, gen: 2.46e+03 disc: 1.32e+00 w_dist: 1.66e-01 grad_penalty: 1.16e+00]


Epoch 94 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:23<00:00,  4.51it/s, gen: 3.91e+01 disc: -2.92e+00 w_dist: -5.35e+00 grad_penalty: 2.43e+00] 
Validation: 100%|██████████| 313/313 [00:32<00:00,  9.59it/s, gen: -2.53e+03 disc: 1.42e+01 w_dist: 1.41e+01 grad_penalty: 1.24e-01]


Epoch 95 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:31<00:00,  4.45it/s, gen: -3.64e+01 disc: -2.57e+00 w_dist: -5.14e+00 grad_penalty: 2.57e+00]
Validation: 100%|██████████| 313/313 [00:31<00:00,  9.83it/s, gen: -3.30e+03 disc: -6.54e+00 w_dist: -6.94e+00 grad_penalty: 3.98e-01]


Epoch 96 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:07<00:00,  4.63it/s, gen: 8.71e+01 disc: -2.82e+00 w_dist: -5.29e+00 grad_penalty: 2.47e+00]
Validation: 100%|██████████| 313/313 [00:30<00:00, 10.25it/s, gen: -7.60e+02 disc: -2.89e+00 w_dist: -1.07e+01 grad_penalty: 7.80e+00]


Epoch 97 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training:  33%|███▎      | 924/2813 [03:19<06:40,  4.72it/s, gen: 7.79e+01 disc: -2.66e+00 w_dist: -5.03e+00 grad_penalty: 2.37e+00] IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

Training: 100%|██████████| 2813/2813 [10:08<00:00,  4.63it/s, gen: 1.30e+02 disc: -2.66e+00 w_dist: -5.54e+00 grad_penalty: 2.87e+00] 
Validation: 100%|██████████| 313/313 [00:29<00:00, 10.44it/s, gen: 3.73e+03 disc: -4.57e+00 w_dist: -4.93e+00 grad_penalty: 3.64e-01]


Epoch 99 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:08<00:00,  4.63it/s, gen: 3.60e+01 disc: -2.48e+00 w_dist: -4.90e+00 grad_penalty: 2.42e+00] 
Validation: 100%|██████████| 313/313 [00:30<00:00, 10.37it/s, gen: -1.40e+03 disc: -1.50e+01 w_dist: -1.51e+01 grad_penalty: 4.84e-02]


Epoch 100 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:10<00:00,  4.60it/s, gen: 3.83e+01 disc: -2.89e+00 w_dist: -5.44e+00 grad_penalty: 2.54e+00]
Validation: 100%|██████████| 313/313 [00:30<00:00, 10.17it/s, gen: -3.22e+03 disc: -2.90e+01 w_dist: -3.19e+01 grad_penalty: 2.87e+00]


Saving to ./outputs/gan_checkpoints//generator_ep_100.pth and ./outputs/gan_checkpoints//discriminator_ep_100.pth
Finished resumed training.


In [8]:
# SECTION 7 — Plot GAN losses

plt.figure()
plt.plot(G_loss_ep, label="Gen loss", color="red")
plt.plot(D_loss_ep, label="Disc loss", color="blue")
plt.legend()
plt.tight_layout()
plt.savefig(os.path.join(outdir, "gan_losses.png"))
plt.close()


In [9]:
# Cell 8: Save final generator and discriminator

torch.save(gen.state_dict(), os.path.join(gan_ckpt_dir, "generator_final.pth"))
torch.save(disc.state_dict(), os.path.join(gan_ckpt_dir, "discriminator_final.pth"))

print("Done training GAN. Checkpoints in:", gan_ckpt_dir)


Done training GAN. Checkpoints in: ./outputs/gan_checkpoints


# 