In [5]:
pip install gdown

Note: you may need to restart the kernel to use updated packages.


In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import os

In [7]:
class Config:
    """Static settings for the MNIST WGAN-GP: hardware, hyperparameters and paths.

    Every value is a class attribute, so it can be read from the class itself
    (``Config.Z_DIM``) or from an instance (``Config().Z_DIM``).
    """

    # --- Hardware -----------------------------------------------------------
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # --- WGAN-GP hyperparameters ---------------------------------------------
    LR = 1e-4
    BATCH_SIZE = 64
    IMAGE_SIZE = 28            # the Critic/Generator layer comments assume 28x28
    CHANNELS = 1               # single-channel images
    Z_DIM = 100                # latent vector size fed to the Generator
    NUM_EPOCHS = 20
    FEATURES_DIM = 64          # base channel width for Critic and Generator
    CRITIC_ITERATIONS = 5      # presumably critic steps per generator step — confirm in training code
    LAMBDA_GP = 10             # gradient-penalty weight (the "GP" in WGAN-GP)

    # --- Paths -----------------------------------------------------------------
    # These relative paths assume the working directory is "model code/GANs/".
    # NOTE(review): the download output of this notebook shows it running from
    # "measurement code/", where DATA_ROOT would NOT resolve to the repository's
    # dataset folder — confirm before using DATA_ROOT from here.

    # -> denoising-diffusion-model/dataset (torchvision adds the MNIST subfolder itself)
    DATA_ROOT = os.path.join(os.pardir, os.pardir, "dataset")

    # -> denoising-diffusion-model/model code/GANs/samples
    IMG_DIR = "samples"

    # Checkpoint file name.
    CKPT_NAME = "wgan_mnist_ckpt.pth"

    # Save frequency (unit not shown here — presumably epochs; confirm).
    SAVE_EVERY = 5

class Critic(nn.Module):
    """WGAN critic: maps an image batch to one unbounded real score per image.

    For (N, channels_img, 28, 28) inputs the spatial size goes
    28 -> 14 -> 7 -> 3 -> 1, so the output is (N, 1, 1, 1). There is no final
    activation. The hidden stages use InstanceNorm2d with learnable affine
    parameters.

    Args:
        channels_img: number of input image channels.
        features_d: channel width of the first conv; doubled at each later stage.
    """

    def __init__(self, channels_img, features_d):
        super().__init__()
        widths = (features_d, features_d * 2, features_d * 4)

        # 28x28 -> 14x14, no normalization on the first layer
        layers = [
            nn.Conv2d(channels_img, widths[0], kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        ]

        # 14x14 -> 7x7, then 7x7 -> 3x3
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.InstanceNorm2d(c_out, affine=True),
                nn.LeakyReLU(0.2),
            ]

        # 3x3 -> 1x1 score map
        layers.append(nn.Conv2d(widths[-1], 1, kernel_size=3, stride=1, padding=0))

        # The attribute name and the exact layer order determine the state_dict
        # keys ("critic.0.weight", ...), which saved checkpoints may depend on.
        self.critic = nn.Sequential(*layers)

    def forward(self, x):
        return self.critic(x)

class Generator(nn.Module):
    """Maps latent vectors (N, z_dim, 1, 1) to images (N, channels_img, 28, 28).

    The spatial size goes 1 -> 7 -> 14 -> 28. The final Tanh bounds the pixel
    values to [-1, 1].

    Args:
        z_dim: latent dimensionality (input channels).
        channels_img: number of output image channels.
        features_g: base channel width; the first stage uses 4x this.
    """

    def __init__(self, z_dim, channels_img, features_g):
        super().__init__()

        def upsample(c_in, c_out, kernel, stride, padding):
            # ConvTranspose2d -> BatchNorm2d -> ReLU, returned as a list to splice in.
            return [
                nn.ConvTranspose2d(c_in, c_out, kernel_size=kernel, stride=stride, padding=padding),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ]

        # Layer order fixes the state_dict keys ("gen.0.weight", ...) that the
        # checkpoint's "gen_state" is loaded into.
        self.gen = nn.Sequential(
            *upsample(z_dim, features_g * 4, 7, 1, 0),           # 1x1   -> 7x7
            *upsample(features_g * 4, features_g * 2, 4, 2, 1),  # 7x7   -> 14x14
            nn.ConvTranspose2d(features_g * 2, channels_img, kernel_size=4, stride=2, padding=1),  # 14x14 -> 28x28
            nn.Tanh(),  # output in [-1, 1]
        )

    def forward(self, x):
        return self.gen(x)

In [8]:
import gdown
import os

# Fetch the pre-trained WGAN checkpoint from Google Drive unless it already
# exists in the working directory.
file_id = "1_85f6DEJ4lEZl0VWx0V5PzSRMd4-l4yK"
out_path = Config.CKPT_NAME  # single source of truth for the checkpoint file name

if not os.path.exists(out_path):
    downloaded = gdown.download(
        f"https://drive.google.com/uc?id={file_id}",
        out_path,
        quiet=False,
    )
    # gdown may report failure by returning None rather than raising. Stop
    # here instead of failing later with a confusing torch.load error.
    if downloaded is None or not os.path.exists(out_path):
        raise RuntimeError(f"Download of {out_path!r} from Google Drive failed.")
else:
    print("Already downloaded.")

Downloading...
From (original): https://drive.google.com/uc?id=1_85f6DEJ4lEZl0VWx0V5PzSRMd4-l4yK
From (redirected): https://drive.google.com/uc?id=1_85f6DEJ4lEZl0VWx0V5PzSRMd4-l4yK&confirm=t&uuid=87e0042d-05f3-4d0c-bbd6-c92fa5070d87
To: /home/onyxia/work/denoising-diffusion-model/measurement code/wgan_mnist_ckpt.pth
100%|██████████| 29.3M/29.3M [00:00<00:00, 97.4MB/s]


In [13]:
conf = Config()
gen = Generator(conf.Z_DIM, conf.CHANNELS, conf.FEATURES_DIM).to(conf.DEVICE)

# Load the trained generator weights (the checkpoint is a dict holding them
# under "gen_state").
# SECURITY NOTE(review): torch.load unpickles the file, and this file was
# downloaded from Google Drive. If the checkpoint only contains tensors and
# plain containers, pass weights_only=True — verify the contents first.
ckpt = torch.load(conf.CKPT_NAME, map_location=conf.DEVICE)
if "gen_state" not in ckpt:
    raise KeyError(
        f"Checkpoint {conf.CKPT_NAME!r} has no 'gen_state' entry; found: {list(ckpt)}"
    )
gen.load_state_dict(ckpt["gen_state"])
gen.eval()

import time
import numpy as np
import torch
import torch.nn as nn

# ---------------------------
# Utilities
# ---------------------------
def _sync_if_cuda(device):
    """Call torch.cuda.synchronize() when CUDA is available and `device` names a CUDA device.

    No-op otherwise. The argument-less synchronize() waits on the current CUDA
    device.

    NOTE(review): duplicate definition. The same name is defined again later in
    this cell, and that later definition shadows this one. TODO: delete this copy.
    """
    if torch.cuda.is_available() and ("cuda" in str(device)):
        torch.cuda.synchronize()

def _percentiles(xs, ps=(50, 95)):
    """Return {"p<q>": float} with np.percentile(xs, q) for each q in `ps` (computed in float64).

    NOTE(review): duplicate definition. The same name is defined again later in
    this cell, and that later definition shadows this one. TODO: delete this copy.
    """
    xs = np.asarray(xs, dtype=np.float64)
    return {f"p{p}": float(np.percentile(xs, p)) for p in ps}

import time
import numpy as np
import torch
import torch.nn as nn

# ---------------------------
# Utilities
# ---------------------------
def _sync_if_cuda(device):
    if torch.cuda.is_available() and ("cuda" in str(device)):
        torch.cuda.synchronize()

def _percentiles(xs, ps=(50, 95)):
    xs = np.asarray(xs, dtype=np.float64)
    return {f"p{p}": float(np.percentile(xs, p)) for p in ps}

def sample_z(batch_size: int, z_dim: int, device: str, conv_style: bool = True):
    """Draw a batch of standard-normal latent vectors.

    Args:
        batch_size: number of latent vectors (B).
        z_dim: latent dimensionality (Z).
        device: device on which to allocate the tensor.
        conv_style: if True, return shape [B, Z, 1, 1] (the layout a
            ConvTranspose2d-first generator consumes); otherwise a flat [B, Z].

    Returns:
        torch.Tensor sampled with ``torch.randn``.
    """
    shape = (batch_size, z_dim, 1, 1) if conv_style else (batch_size, z_dim)
    return torch.randn(*shape, device=device)

# ---------------------------
# Timed Generator Wrapper
# ---------------------------
class TimedGenerator(nn.Module):
    """Wrap a GAN generator and accumulate the time spent in its forward passes.

    A GAN makes a single network evaluation per batch, so each call counts as
    exactly one NFE.

    Attributes:
        nfe: number of forward calls since the last ``reset_stats()``.
        model_ms: milliseconds spent inside ``self.G`` since the last
            ``reset_stats()``. It is measured with CUDA events when CUDA is
            available and ``device`` names a CUDA device, and with
            ``time.perf_counter()`` otherwise.
    """

    def __init__(self, generator: nn.Module, device: str):
        super().__init__()
        self.G = generator
        self.device = device
        self.reset_stats()

    def reset_stats(self):
        """Zero the NFE counter and the accumulated model time."""
        self.nfe = 0
        self.model_ms = 0.0

    @torch.no_grad()
    def forward(self, z):
        self.nfe += 1

        on_cuda = torch.cuda.is_available() and ("cuda" in str(self.device))
        if not on_cuda:
            tic = time.perf_counter()
            out = self.G(z)
            self.model_ms += (time.perf_counter() - tic) * 1000.0
            return out

        # CUDA kernels run asynchronously. The events bracket the GPU work, and
        # the synchronize makes elapsed_time() valid.
        ev_begin = torch.cuda.Event(enable_timing=True)
        ev_finish = torch.cuda.Event(enable_timing=True)
        ev_begin.record()
        out = self.G(z)
        ev_finish.record()
        torch.cuda.synchronize()
        self.model_ms += ev_begin.elapsed_time(ev_finish)
        return out

# ---------------------------
# Benchmark runner
# ---------------------------
@torch.no_grad()
def run_gan_benchmark(
    generator: nn.Module,
    *,
    z_dim: int,
    batch_size: int,
    device: str,
    conv_style_z: bool = True,
    warmup_runs: int = 2,
    timed_runs: int = 10,
    reset_cuda_peak_mem: bool = True,
):
    """Time batched sampling from a GAN generator and summarize the results.

    Each run draws one batch of latents and makes a single generator call.
    Per-run timings are normalized per image and summarized as p50/p95.

    Side effect: ``generator`` is switched to eval mode and moved to ``device``
    in place.

    Args:
        generator: module to benchmark.
        z_dim: latent dimensionality passed to ``sample_z``.
        batch_size: images generated per run; must be > 0.
        device: device to run on (string or ``torch.device``).
        conv_style_z: latent shape [B, Z, 1, 1] if True, else [B, Z].
        warmup_runs: untimed runs executed before measuring.
        timed_runs: runs included in the statistics; must be > 0.
        reset_cuda_peak_mem: reset CUDA peak-memory counters before each timed
            run, so each recorded peak covers that run only.

    Returns:
        dict with the following entries:

        - ``batch_size``, ``timed_runs``, ``nfe_sample`` (always 1).
        - p50/p95 dicts for ``total_ms_img`` (host wall time around the call,
          including syncs), ``model_ms_img`` (time inside the generator),
          ``overhd_ms_img`` (their difference), ``ms_forward_model`` (model
          time per image per forward) and ``throughput`` (images/s).
          Throughput's "p95" is the *fast* tail, unlike the latency p95 values.
        - On CUDA only: ``peak_alloc_mb_p50`` and ``peak_reserved_mb_p50``.

    Raises:
        ValueError: if ``batch_size`` or ``timed_runs`` is not positive.
    """
    # Fail early with a clear message instead of ZeroDivisionError (batch_size)
    # or an obscure numpy error on empty percentile input (timed_runs).
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    if timed_runs <= 0:
        raise ValueError(f"timed_runs must be positive, got {timed_runs}")

    use_cuda = torch.cuda.is_available() and ("cuda" in str(device))

    generator.eval().to(device)
    timed_G = TimedGenerator(generator, device=device).to(device)

    # Warmup: untimed calls, so one-off startup costs stay out of the measured runs.
    for _ in range(warmup_runs):
        timed_G.reset_stats()
        z = sample_z(batch_size, z_dim, device, conv_style=conv_style_z)
        _sync_if_cuda(device)
        _ = timed_G(z)
        _sync_if_cuda(device)

    per_img_total_ms = []
    per_img_model_ms = []
    per_img_overhd_ms = []
    ms_per_forward_model = []
    throughput_img_s = []
    peak_alloc_mb = []
    peak_reserved_mb = []

    for _ in range(timed_runs):
        timed_G.reset_stats()

        if reset_cuda_peak_mem and use_cuda:
            # Target `device` explicitly; the argument-less form uses the current device.
            torch.cuda.reset_peak_memory_stats(device)

        # Latent sampling happens outside the timed region.
        z = sample_z(batch_size, z_dim, device, conv_style=conv_style_z)

        _sync_if_cuda(device)
        t0 = time.perf_counter()
        _ = timed_G(z)
        _sync_if_cuda(device)
        t1 = time.perf_counter()

        total_ms_batch = (t1 - t0) * 1000.0
        model_ms_batch = float(timed_G.model_ms)
        overhd_ms_batch = total_ms_batch - model_ms_batch

        # Per-image figures
        total_ms_img = total_ms_batch / batch_size
        model_ms_img = model_ms_batch / batch_size
        overhd_ms_img = overhd_ms_batch / batch_size

        per_img_total_ms.append(total_ms_img)
        per_img_model_ms.append(model_ms_img)
        per_img_overhd_ms.append(overhd_ms_img)

        # ms/forward is model time per image per network evaluation. A GAN makes
        # one evaluation per batch, so it equals model_ms_img.
        ms_per_forward_model.append(model_ms_img)

        throughput_img_s.append(1000.0 / max(total_ms_img, 1e-12))

        if use_cuda:
            peak_alloc_mb.append(torch.cuda.max_memory_allocated(device) / (1024**2))
            peak_reserved_mb.append(torch.cuda.max_memory_reserved(device) / (1024**2))

    summary = {
        "batch_size": batch_size,
        "timed_runs": timed_runs,
        "nfe_sample": 1,

        "total_ms_img": _percentiles(per_img_total_ms, ps=(50, 95)),
        "model_ms_img": _percentiles(per_img_model_ms, ps=(50, 95)),
        "overhd_ms_img": _percentiles(per_img_overhd_ms, ps=(50, 95)),
        "ms_forward_model": _percentiles(ms_per_forward_model, ps=(50, 95)),
        "throughput": _percentiles(throughput_img_s, ps=(50, 95)),
    }

    if peak_alloc_mb:
        summary["peak_alloc_mb_p50"] = _percentiles(peak_alloc_mb, ps=(50,))["p50"]
        summary["peak_reserved_mb_p50"] = _percentiles(peak_reserved_mb, ps=(50,))["p50"]

    return summary

def pretty_print_gan(summary: dict, name: str):
    """Print a human-readable report for a ``run_gan_benchmark`` summary.

    Args:
        summary: dict in the format returned by ``run_gan_benchmark``.
        name: title shown in the report header.
    """
    latency_rows = (
        ("total ms/img", "total_ms_img"),
        ("model ms/img", "model_ms_img"),
        ("overhd ms/img", "overhd_ms_img"),
        ("ms/forward (model)", "ms_forward_model"),
    )

    print(f"\n=== Benchmark: {name} ===")
    print(f"batch={summary['batch_size']} | timed_runs={summary['timed_runs']}")
    print(f"NFE/sample: {summary['nfe_sample']}")
    for label, key in latency_rows:
        stats = summary[key]
        print(f"{label} p50: {stats['p50']:.6f} | p95: {stats['p95']:.6f}")

    tput = summary["throughput"]
    print(f"throughput img/s p50: {tput['p50']:.2f} | p95: {tput['p95']:.2f}")

    # run_gan_benchmark only adds peak-memory entries when it ran on CUDA.
    if "peak_alloc_mb_p50" not in summary:
        return
    print(
        f"peak alloc MB p50: {summary['peak_alloc_mb_p50']:.1f} | "
        f"peak reserved MB p50: {summary['peak_reserved_mb_p50']:.1f}"
    )

# ---------------------------
# Benchmark the generator loaded at the top of this cell
# ---------------------------
# Uses `conf` and `gen` created by the checkpoint-loading code above.
device, z_dim, generator = conf.DEVICE, conf.Z_DIM, gen

# The Generator begins with a ConvTranspose2d, so latents are shaped [B, Z, 1, 1].
CONV_STYLE_Z = True

bench_kwargs = dict(z_dim=z_dim, device=device, conv_style_z=CONV_STYLE_Z)

b16 = run_gan_benchmark(generator, batch_size=16, **bench_kwargs)
pretty_print_gan(b16, name="WGAN MNIST | Generator | batch=16")

b128 = run_gan_benchmark(generator, batch_size=128, **bench_kwargs)
pretty_print_gan(b128, name="WGAN MNIST | Generator | batch=128")




=== Benchmark: WGAN MNIST | Generator | batch=16 ===
batch=16 | timed_runs=10
NFE/sample: 1
total ms/img p50: 0.110475 | p95: 0.111467
model ms/img p50: 0.105405 | p95: 0.106354
overhd ms/img p50: 0.005031 | p95: 0.005405
ms/forward (model) p50: 0.105405 | p95: 0.106354
throughput img/s p50: 9051.82 | p95: 9112.21
peak alloc MB p50: 45.2 | peak reserved MB p50: 100.0

=== Benchmark: WGAN MNIST | Generator | batch=128 ===
batch=128 | timed_runs=10
NFE/sample: 1
total ms/img p50: 0.086938 | p95: 0.097352
model ms/img p50: 0.086240 | p95: 0.096720
overhd ms/img p50: 0.000648 | p95: 0.000740
ms/forward (model) p50: 0.086240 | p95: 0.096720
throughput img/s p50: 11502.50 | p95: 11520.68
peak alloc MB p50: 67.5 | peak reserved MB p50: 100.0
