In [1]:
# Cell 1: Imports
import os
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, random_split

from gan_utils.io import H5ImageDataset
from gan.model import Generator, Discriminator, Encoder, DiscriminatorFeatures
from gan.trainer import Trainer

import matplotlib.pyplot as plt




In [2]:
# Cell 2: Utilities

def set_seed(seed: int = 42):
    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def ensure_dir(path: str):
    if path and not os.path.exists(path):
        os.makedirs(path, exist_ok=True)


def make_dataloaders(df, h5_path, batch_size=32, workers=16, val_split=0.1):
    """
    df: pandas DataFrame with file_name column that matches keys in HDF5
    h5_path: path to HDF5 file containing images
    returns: train_loader, val_loader
    """
    dataset = H5ImageDataset(df, h5_path, transforms=None)
    n_total = len(dataset)
    n_val = int(val_split * n_total)
    n_train = n_total - n_val

    g = torch.Generator().manual_seed(123)
    train_ds, val_ds = random_split(dataset, [n_train, n_val], generator=g)

    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=workers,
        pin_memory=True,
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=workers,
        pin_memory=True,
    )
    return train_loader, val_loader


def build_models(n_z=256, image_size=512, in_channels=3):
    """
    Construct Generator, Discriminator, Encoder, and DiscriminatorFeatures.
    """
    gen = Generator(
        n_z=n_z, input_filt=512, final_size=image_size, out_channels=in_channels, norm=False, pool=False
    )
    disc = Discriminator(in_channels=in_channels, n_layers=6, input_size=image_size, norm=False, pool=False,
    )
#     enc = Encoder(
#         in_channels=in_channels,
#         n_z=n_z,
#         n_layers=5,
#         input_filt=64,
#     )
#     disc_feat = DiscriminatorFeatures(
#         in_channels=in_channels,
#         n_layers=6,
#         input_filt=64,
#     )
    return gen, disc


def train_gan(gen, disc, train_loader, val_loader, outdir, epochs=50, d_lr=1e-4, g_lr=1e-4, save_freq=10, lr_decay=None, decay_freq=5, device='cuda'): 
    """ Train WGAN-GP using Trainer and save checkpoints + loss metrics. """ 
    ensure_dir(outdir) 
    gen = gen.to(device) 
    disc = disc.to(device) 
    trainer = Trainer(generator=gen, discriminator=disc, savefolder=outdir, device=device) 
    G_loss_ep, D_loss_ep = trainer.train( train_data=train_loader, val_data=val_loader, epochs=epochs, dsc_learning_rate=d_lr, gen_learning_rate=g_lr, save_freq=save_freq, lr_decay=lr_decay, decay_freq=decay_freq ) 
    return G_loss_ep, D_loss_ep




In [3]:
# Cell 3: Config

h5_path = "anomaly.h5"
parquet_path = "mask_labels_anomaly.gzip"
outdir = "./outputs"
gan_ckpt_dir = os.path.join(outdir, "gan_checkpoints")

batch_size = 32
workers = 32
val_split = 0.1
image_size = 512
n_z = 256
gan_epochs = 50
g_lr = 1e-4
d_lr = 1e-4
device = "cuda" if torch.cuda.is_available() else "cpu"
seed = 42

set_seed(seed)
ensure_dir(outdir)
ensure_dir(gan_ckpt_dir)




In [4]:
# Cell 4: Load metadata + dataloaders

df = pd.read_parquet(parquet_path)
train_loader, val_loader = make_dataloaders(
    df,
    h5_path,
    batch_size=batch_size,
    workers=workers,
    val_split=val_split,
)






In [5]:
# Cell 5: Build models (we only use gen, disc here)

gen, disc = build_models(
    n_z=n_z,
    image_size=image_size,
    in_channels=3,
)

gen = gen.to(device)
disc = disc.to(device)


# Quick sanity check on shapes
imgs = next(iter(train_loader))[0]
print("Real batch shape:", imgs.shape)

z = torch.randn(4, gen.n_z).to(device)
fake = gen(z)
print("Fake batch shape:", fake.shape)




Real batch shape: torch.Size([3, 512, 512])
Fake batch shape: torch.Size([4, 3, 512, 512])


In [None]:
# Cell 6: Train GAN

print("Training GAN (WGAN-GP)...")
G_loss_ep, D_loss_ep = train_gan(gen, disc, train_loader, val_loader, gan_ckpt_dir, epochs=gan_epochs, d_lr=d_lr, g_lr=g_lr, save_freq=10, lr_decay=None, decay_freq=5, device=device) 





  self.scaler_g = GradScaler()
  self.scaler_d = GradScaler()


Training GAN (WGAN-GP)...
Epoch 1 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


  with autocast():
  with autocast():
Training: 100%|██████████| 2813/2813 [15:50<00:00,  2.96it/s, gen: 1.59e+02 disc: -4.77e+01 w_dist: -8.85e+01 grad_penalty: 4.08e+01]
Validation: 100%|██████████| 313/313 [00:43<00:00,  7.14it/s, gen: 8.87e+01 disc: -4.28e+00 w_dist: -5.01e+00 grad_penalty: 7.24e-01]


Epoch 2 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [17:21<00:00,  2.70it/s, gen: 1.09e+02 disc: -7.89e+00 w_dist: -1.49e+01 grad_penalty: 7.01e+00]
Validation: 100%|██████████| 313/313 [00:29<00:00, 10.53it/s, gen: 2.06e+02 disc: -3.89e+01 w_dist: -4.03e+01 grad_penalty: 1.45e+00]


Epoch 3 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:03<00:00,  4.66it/s, gen: 8.38e+01 disc: -7.62e+00 w_dist: -1.25e+01 grad_penalty: 4.89e+00]
Validation: 100%|██████████| 313/313 [00:29<00:00, 10.43it/s, gen: 2.05e+00 disc: 5.62e+00 w_dist: 4.36e+00 grad_penalty: 1.26e+00]


Epoch 4 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:07<00:00,  4.63it/s, gen: 1.25e+02 disc: -7.39e+00 w_dist: -1.22e+01 grad_penalty: 4.84e+00]
Validation: 100%|██████████| 313/313 [00:29<00:00, 10.48it/s, gen: 3.65e+02 disc: -1.13e+01 w_dist: -2.14e+01 grad_penalty: 1.01e+01]


Epoch 5 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:11<00:00,  4.60it/s, gen: 1.33e+02 disc: -1.00e+01 w_dist: -1.61e+01 grad_penalty: 6.07e+00]
Validation: 100%|██████████| 313/313 [00:30<00:00, 10.19it/s, gen: 9.32e+02 disc: -1.25e+01 w_dist: -1.93e+01 grad_penalty: 6.82e+00]


Epoch 6 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:14<00:00,  4.58it/s, gen: 1.87e+02 disc: -1.08e+01 w_dist: -1.77e+01 grad_penalty: 6.81e+00]
Validation: 100%|██████████| 313/313 [00:30<00:00, 10.28it/s, gen: 2.85e+02 disc: 2.47e+01 w_dist: 2.46e+01 grad_penalty: 9.19e-02]


Epoch 7 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:15<00:00,  4.57it/s, gen: 2.40e+02 disc: -9.90e+00 w_dist: -1.58e+01 grad_penalty: 5.86e+00]
Validation: 100%|██████████| 313/313 [00:29<00:00, 10.46it/s, gen: 2.19e+01 disc: -4.57e+01 w_dist: -4.81e+01 grad_penalty: 2.44e+00]


Epoch 8 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:12<00:00,  4.59it/s, gen: 1.28e+02 disc: -1.24e+01 w_dist: -2.09e+01 grad_penalty: 8.50e+00]
Validation: 100%|██████████| 313/313 [00:30<00:00, 10.42it/s, gen: -1.02e+03 disc: -3.29e+01 w_dist: -4.55e+01 grad_penalty: 1.27e+01]


Epoch 9 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:10<00:00,  4.61it/s, gen: 6.42e+01 disc: -9.62e+00 w_dist: -1.62e+01 grad_penalty: 6.54e+00]
Validation: 100%|██████████| 313/313 [00:29<00:00, 10.59it/s, gen: -3.57e+02 disc: 2.63e+01 w_dist: 2.57e+01 grad_penalty: 5.56e-01]


Epoch 10 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:07<00:00,  4.63it/s, gen: 1.82e+02 disc: -8.31e+00 w_dist: -1.38e+01 grad_penalty: 5.45e+00]
Validation: 100%|██████████| 313/313 [00:30<00:00, 10.43it/s, gen: 9.36e+02 disc: 3.26e+01 w_dist: 3.22e+01 grad_penalty: 4.06e-01]


Saving to ./outputs/gan_checkpoints//generator_ep_010.pth and ./outputs/gan_checkpoints//discriminator_ep_010.pth
Epoch 11 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:07<00:00,  4.63it/s, gen: 1.75e+02 disc: -1.16e+01 w_dist: -1.87e+01 grad_penalty: 7.11e+00]
Validation: 100%|██████████| 313/313 [00:29<00:00, 10.51it/s, gen: -4.18e+02 disc: -2.29e+01 w_dist: -2.56e+01 grad_penalty: 2.69e+00]


Epoch 12 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:06<00:00,  4.64it/s, gen: 8.53e+01 disc: -1.14e+01 w_dist: -1.89e+01 grad_penalty: 7.53e+00]
Validation: 100%|██████████| 313/313 [00:30<00:00, 10.18it/s, gen: -1.54e+02 disc: -1.10e+01 w_dist: -1.26e+01 grad_penalty: 1.53e+00]


Epoch 13 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:15<00:00,  4.57it/s, gen: 1.05e+02 disc: -1.21e+01 w_dist: -2.02e+01 grad_penalty: 8.09e+00]
Validation: 100%|██████████| 313/313 [00:31<00:00, 10.06it/s, gen: 1.55e+03 disc: 1.13e+01 w_dist: 3.58e+00 grad_penalty: 7.71e+00]


Epoch 14 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:17<00:00,  4.55it/s, gen: -5.46e+00 disc: -1.03e+01 w_dist: -1.75e+01 grad_penalty: 7.14e+00]
Validation: 100%|██████████| 313/313 [00:30<00:00, 10.33it/s, gen: -8.36e+02 disc: -1.24e+01 w_dist: -1.48e+01 grad_penalty: 2.38e+00]


Epoch 15 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:23<00:00,  4.51it/s, gen: -8.31e+00 disc: -9.02e+00 w_dist: -1.52e+01 grad_penalty: 6.16e+00]
Validation: 100%|██████████| 313/313 [00:31<00:00,  9.94it/s, gen: -1.12e+03 disc: 1.28e+01 w_dist: 1.20e+01 grad_penalty: 8.22e-01]


Epoch 16 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:21<00:00,  4.52it/s, gen: -1.12e+01 disc: -8.26e+00 w_dist: -1.39e+01 grad_penalty: 5.65e+00]
Validation: 100%|██████████| 313/313 [00:30<00:00, 10.11it/s, gen: 4.58e+02 disc: 2.10e+01 w_dist: 2.01e+01 grad_penalty: 8.82e-01]


Epoch 17 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:22<00:00,  4.52it/s, gen: -1.78e+02 disc: -7.13e+00 w_dist: -1.25e+01 grad_penalty: 5.39e+00]
Validation: 100%|██████████| 313/313 [00:30<00:00, 10.22it/s, gen: -2.02e+02 disc: 1.17e+01 w_dist: 1.17e+01 grad_penalty: 2.38e-02]


Epoch 18 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:19<00:00,  4.54it/s, gen: -1.84e+01 disc: -6.33e+00 w_dist: -1.09e+01 grad_penalty: 4.61e+00]
Validation: 100%|██████████| 313/313 [00:30<00:00, 10.20it/s, gen: -1.53e+03 disc: -3.38e+00 w_dist: -1.07e+01 grad_penalty: 7.34e+00]


Epoch 19 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:13<00:00,  4.59it/s, gen: 2.56e+00 disc: -5.67e+00 w_dist: -9.95e+00 grad_penalty: 4.28e+00] 
Validation: 100%|██████████| 313/313 [00:31<00:00,  9.89it/s, gen: -5.24e+00 disc: -2.17e+01 w_dist: -2.20e+01 grad_penalty: 2.89e-01]


Epoch 20 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:08<00:00,  4.63it/s, gen: 7.70e+01 disc: -6.21e+00 w_dist: -1.04e+01 grad_penalty: 4.15e+00]
Validation: 100%|██████████| 313/313 [00:30<00:00, 10.25it/s, gen: 1.33e+03 disc: -5.33e+01 w_dist: -5.62e+01 grad_penalty: 2.90e+00]


Saving to ./outputs/gan_checkpoints//generator_ep_020.pth and ./outputs/gan_checkpoints//discriminator_ep_020.pth
Epoch 21 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:14<00:00,  4.58it/s, gen: -1.42e+02 disc: -6.14e+00 w_dist: -1.06e+01 grad_penalty: 4.45e+00]
Validation: 100%|██████████| 313/313 [00:31<00:00,  9.82it/s, gen: -3.28e+03 disc: 4.28e+00 w_dist: -1.02e+01 grad_penalty: 1.45e+01]


Epoch 22 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:16<00:00,  4.56it/s, gen: 9.90e+00 disc: -6.10e+00 w_dist: -1.05e+01 grad_penalty: 4.38e+00] 
Validation: 100%|██████████| 313/313 [00:30<00:00, 10.20it/s, gen: -1.13e+03 disc: 8.45e+00 w_dist: 7.72e+00 grad_penalty: 7.35e-01]


Epoch 23 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:18<00:00,  4.55it/s, gen: -5.34e+01 disc: -5.89e+00 w_dist: -9.97e+00 grad_penalty: 4.08e+00]
Validation: 100%|██████████| 313/313 [00:30<00:00, 10.20it/s, gen: -7.38e+02 disc: 1.18e+01 w_dist: 1.18e+01 grad_penalty: 1.51e-02]


Epoch 24 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:18<00:00,  4.55it/s, gen: -5.61e+01 disc: -4.87e+00 w_dist: -8.50e+00 grad_penalty: 3.63e+00]
Validation: 100%|██████████| 313/313 [00:30<00:00, 10.10it/s, gen: 1.21e+03 disc: -3.67e+00 w_dist: -1.57e+01 grad_penalty: 1.20e+01]


Epoch 25 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:30<00:00,  4.46it/s, gen: 3.39e+01 disc: -5.19e+00 w_dist: -9.04e+00 grad_penalty: 3.85e+00] 
Validation: 100%|██████████| 313/313 [00:32<00:00,  9.66it/s, gen: -1.49e+03 disc: 1.09e+01 w_dist: 1.08e+01 grad_penalty: 1.02e-01]


Epoch 26 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:24<00:00,  4.51it/s, gen: -6.18e+01 disc: -4.34e+00 w_dist: -7.69e+00 grad_penalty: 3.35e+00]
Validation: 100%|██████████| 313/313 [00:32<00:00,  9.65it/s, gen: -2.91e+02 disc: -4.24e+00 w_dist: -4.48e+00 grad_penalty: 2.34e-01]


Epoch 27 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:23<00:00,  4.51it/s, gen: 2.56e+01 disc: -5.47e+00 w_dist: -9.60e+00 grad_penalty: 4.13e+00] 
Validation: 100%|██████████| 313/313 [00:30<00:00, 10.26it/s, gen: -6.02e+01 disc: -2.34e+00 w_dist: -5.83e+00 grad_penalty: 3.49e+00]


Epoch 28 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:17<00:00,  4.56it/s, gen: 6.77e+01 disc: -4.89e+00 w_dist: -8.62e+00 grad_penalty: 3.73e+00] 
Validation: 100%|██████████| 313/313 [00:31<00:00,  9.83it/s, gen: -1.23e+03 disc: 2.67e+00 w_dist: 2.05e+00 grad_penalty: 6.17e-01]


Epoch 29 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:21<00:00,  4.52it/s, gen: -3.25e+01 disc: -5.31e+00 w_dist: -9.09e+00 grad_penalty: 3.77e+00]
Validation: 100%|██████████| 313/313 [00:31<00:00,  9.92it/s, gen: -2.20e+03 disc: -1.92e+01 w_dist: -2.26e+01 grad_penalty: 3.33e+00]


Epoch 30 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:20<00:00,  4.54it/s, gen: -1.01e+02 disc: -5.59e+00 w_dist: -9.70e+00 grad_penalty: 4.11e+00]
Validation: 100%|██████████| 313/313 [00:31<00:00,  9.90it/s, gen: 8.90e+02 disc: 1.82e+01 w_dist: 1.19e+01 grad_penalty: 6.25e+00]


Saving to ./outputs/gan_checkpoints//generator_ep_030.pth and ./outputs/gan_checkpoints//discriminator_ep_030.pth
Epoch 31 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:26<00:00,  4.49it/s, gen: -1.33e+01 disc: -5.60e+00 w_dist: -9.84e+00 grad_penalty: 4.24e+00]
Validation: 100%|██████████| 313/313 [00:32<00:00,  9.73it/s, gen: 4.20e+02 disc: 2.13e+01 w_dist: 2.03e+01 grad_penalty: 1.00e+00]


Epoch 32 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:23<00:00,  4.51it/s, gen: 4.67e+01 disc: -5.55e+00 w_dist: -9.55e+00 grad_penalty: 4.00e+00] 
Validation: 100%|██████████| 313/313 [00:32<00:00,  9.61it/s, gen: -6.94e+02 disc: -1.89e+01 w_dist: -3.93e+01 grad_penalty: 2.04e+01]


Epoch 33 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:19<00:00,  4.54it/s, gen: 7.28e+01 disc: -5.26e+00 w_dist: -9.04e+00 grad_penalty: 3.78e+00] 
Validation: 100%|██████████| 313/313 [00:31<00:00,  9.90it/s, gen: -2.04e+03 disc: -5.70e+00 w_dist: -5.77e+00 grad_penalty: 6.50e-02]


Epoch 34 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training: 100%|██████████| 2813/2813 [10:20<00:00,  4.53it/s, gen: 1.43e+02 disc: -4.21e+00 w_dist: -7.66e+00 grad_penalty: 3.44e+00]
Validation: 100%|██████████| 313/313 [00:31<00:00,  9.93it/s, gen: -3.25e+03 disc: -4.83e+01 w_dist: -6.30e+01 grad_penalty: 1.46e+01]


Epoch 35 -- lr: 1.000e-04, 1.000e-04
-------------------------------------------------------


Training:   2%|▏         | 45/2813 [00:14<09:50,  4.69it/s, gen: -4.71e+02 disc: -4.87e-01 w_dist: -2.34e+00 grad_penalty: 1.86e+00] 

In [None]:
# Cell 7: Plot GAN losses

plt.figure()
plt.plot(G_loss_ep, label="Gen loss", color="red")
plt.plot(D_loss_ep, label="Disc loss", color="blue")
plt.legend()
plt.tight_layout()
plt.savefig(os.path.join(outdir, "gan_losses.png"))
plt.close()




In [None]:
# Cell 8: Save final generator and discriminator

torch.save(gen.state_dict(), os.path.join(gan_ckpt_dir, "generator_final.pth"))
torch.save(disc.state_dict(), os.path.join(gan_ckpt_dir, "discriminator_final.pth"))

print("Done training GAN. Checkpoints in:", gan_ckpt_dir)
