In [1]:
import torch
import torch.nn as nn
import torchvision
import os
import gdown

In [2]:
# ==========================================
# 1. DOWNLOAD
# ==========================================
# Google Drive id of the checkpoint and the local file name it is cached under.
file_id = "1axjh_HshUWWEXoOLbf2L83WTtMKv0VkE"
out_path = "wgan_cifar_advanced_ckpt.pth"
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The checkpoint is cached on disk: only contact Google Drive when it is missing.
if not os.path.exists(out_path):
    print(f"Téléchargement de {out_path}...")
    gdown.download(f"https://drive.google.com/uc?id={file_id}", out_path, quiet=False)
    # NOTE(review): gdown may report a failure (quota, permissions) without
    # raising, depending on its version. Checking for the file here turns a
    # confusing torch.load() error further down into an explicit failure.
    if not os.path.exists(out_path):
        raise RuntimeError(
            f"Échec du téléchargement : {out_path} est introuvable après gdown.download (id={file_id})"
        )
else:
    print("Fichier checkpoint déjà présent.")

Téléchargement de wgan_cifar_advanced_ckpt.pth...


Downloading...
From (original): https://drive.google.com/uc?id=1axjh_HshUWWEXoOLbf2L83WTtMKv0VkE
From (redirected): https://drive.google.com/uc?id=1axjh_HshUWWEXoOLbf2L83WTtMKv0VkE&confirm=t&uuid=9c84ef28-40d8-4fab-8d81-5656ab7b408a
To: /home/onyxia/work/denoising-diffusion-model/measurement code/wgan_cifar_advanced_ckpt.pth
100%|██████████| 70.8M/70.8M [00:01<00:00, 60.6MB/s]


In [3]:
# ==========================================
# 2. ARCHITECTURE CORRIGÉE (Instance Norm)
# ==========================================

class SelfAttention(nn.Module):
    """Self-attention over the spatial positions of a 4-D feature map.

    1x1 convolutions produce queries and keys with ``in_dim // 8`` channels and
    values with ``in_dim`` channels. The attended values are scaled by the
    learnable scalar ``gamma`` (initialised to zero) and added back to the
    input, so a freshly constructed layer returns its input unchanged.

    Args:
        in_dim: Number of channels of the incoming feature map.
    """

    def __init__(self, in_dim):
        super().__init__()
        reduced_dim = in_dim // 8
        # Attribute names double as state_dict keys, so they must not change.
        self.query_conv = nn.Conv2d(in_dim, reduced_dim, kernel_size=1)
        self.key_conv = nn.Conv2d(in_dim, reduced_dim, kernel_size=1)
        self.value_conv = nn.Conv2d(in_dim, in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Return ``gamma * attention(x) + x`` with the same shape as ``x``."""
        n, channels, dim_a, dim_b = x.shape
        positions = dim_a * dim_b

        # Flatten the two spatial axes into a single axis of `positions` entries.
        queries = self.query_conv(x).view(n, -1, positions).transpose(1, 2)
        keys = self.key_conv(x).view(n, -1, positions)
        values = self.value_conv(x).view(n, -1, positions)

        # weights[b, i, j]: softmax over j of <query_i, key_j>.
        weights = self.softmax(torch.bmm(queries, keys))

        attended = torch.bmm(values, weights.transpose(1, 2))
        attended = attended.view(n, channels, dim_a, dim_b)
        return x + self.gamma * attended

class GeneratorCorrected(nn.Module):
    """Upsampling generator that uses InstanceNorm2d as its normalisation layer.

    Pipeline: ``initial`` (transposed conv) -> ``layer1`` (one upsampling block)
    -> ``attn`` (self-attention) -> ``layer2`` (one upsampling block, a final
    upsample + conv, then Tanh). For a ``[B, z_dim, 1, 1]`` input the spatial
    size goes 1 -> 4 -> 8 -> 16 -> 32 and the output has ``channels_img``
    channels with values in [-1, 1].

    Args:
        z_dim: Number of channels of the latent input.
        channels_img: Number of channels of the generated image.
        features_g: Base feature width; the stages use 4x, 2x and 1x this value.
    """

    def __init__(self, z_dim, channels_img, features_g=128):
        super().__init__()
        wide, mid, narrow = features_g * 4, features_g * 2, features_g

        def up_block(c_in, c_out, normalize=True):
            """Return the layers of one block: Upsample -> Conv -> [InstanceNorm] -> ReLU."""
            modules = [
                nn.Upsample(scale_factor=2, mode='nearest'),
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1),
            ]
            if normalize:
                # InstanceNorm2d rather than BatchNorm2d: affine=True still
                # learns a weight/bias pair like BatchNorm does, but no
                # running_mean / running_var buffers are stored.
                modules += [nn.InstanceNorm2d(c_out, affine=True)]
            modules += [nn.ReLU()]
            return modules

        # The position of every module inside each nn.Sequential determines its
        # state_dict key, so the ordering below must stay exactly as it is.
        self.initial = nn.Sequential(
            nn.ConvTranspose2d(z_dim, wide, kernel_size=4, stride=1, padding=0),
            # InstanceNorm here as well, for the same reason as in up_block.
            nn.InstanceNorm2d(wide, affine=True),
            nn.ReLU(),
        )

        self.layer1 = nn.Sequential(*up_block(wide, mid))

        self.attn = SelfAttention(mid)

        self.layer2 = nn.Sequential(
            *up_block(mid, narrow),
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(narrow, channels_img, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Map a latent tensor ``x`` to an image tensor."""
        features = self.layer1(self.initial(x))
        return self.layer2(self.attn(features))

In [4]:
# ==========================================
# 3. LOADING AND GENERATION
# ==========================================
Z_DIM = 100
CHANNELS = 3
FEATURES_DIM = 128

# Instantiate the corrected (InstanceNorm) generator on the selected device.
model = GeneratorCorrected(Z_DIM, CHANNELS, features_g=FEATURES_DIM).to(device)

print(f"Chargement du checkpoint : {out_path}")
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code
# on PyTorch versions where weights_only is not the default. Only load
# checkpoints from a trusted source; consider weights_only=True if this file
# holds nothing but tensors and plain containers.
ckpt = torch.load(out_path, map_location=device)

# Pick the weight dictionary: EMA weights first, then the raw generator
# weights, otherwise assume the checkpoint itself is the state dict.
if "ema_state" in ckpt:
    state_dict = ckpt["ema_state"]
    print(">> Poids EMA détectés (Meilleure qualité)")
elif "gen_state" in ckpt:
    state_dict = ckpt["gen_state"]
    print(">> Poids standards détectés")
else:
    state_dict = ckpt

# Load strictly first; fall back to a best-effort load only if that fails.
try:
    model.load_state_dict(state_dict)
    print(">> Poids chargés avec SUCCÈS ! (Architecture InstanceNorm validée)")
except RuntimeError as e:
    print(f"\nERREUR ENCORE PRÉSENTE : {e}")
    print("Essai de chargement avec strict=False (Risqué mais peut marcher)...")
    load_report = model.load_state_dict(state_dict, strict=False)
    # Make the cost of the non-strict load visible: every missing key is a
    # parameter that keeps its random initialisation.
    print(f">> Clés manquantes (restées à leur initialisation) : {list(load_report.missing_keys)}")
    print(f">> Clés inattendues (ignorées) : {list(load_report.unexpected_keys)}")

Chargement du checkpoint : wgan_cifar_advanced_ckpt.pth
>> Poids EMA détectés (Meilleure qualité)
>> Poids chargés avec SUCCÈS ! (Architecture InstanceNorm validée)


In [10]:
# Generation setup: put the generator in evaluation mode and draw a latent batch.
model.eval()
num_samples = 16
# Latent noise of shape [num_samples, Z_DIM, 1, 1]; it is created on the default
# device and then moved to `device`.
# NOTE(review): `noise` and `num_samples` are not referenced by the benchmark
# code in this cell, which samples its own latents - confirm whether a later
# cell still needs them.
noise = torch.randn(num_samples, Z_DIM, 1, 1).to(device)

import time
import numpy as np
import torch
import torch.nn as nn

# ---------------------------
# Utilities
# ---------------------------
def _sync_if_cuda(device):
    if torch.cuda.is_available() and ("cuda" in str(device)):
        torch.cuda.synchronize()

def _percentiles(xs, ps=(50, 95)):
    xs = np.asarray(xs, dtype=np.float64)
    return {f"p{p}": float(np.percentile(xs, p)) for p in ps}

def sample_z(batch_size: int, z_dim: int, device: str, conv_style: bool = True):
    """
    conv_style=True  -> z is [B, Z, 1, 1] (ConvTranspose2d generators)
    conv_style=False -> z is [B, Z]       (MLP / linear-first generators)
    """
    if conv_style:
        return torch.randn(batch_size, z_dim, 1, 1, device=device)
    return torch.randn(batch_size, z_dim, device=device)

import time
import numpy as np
import torch
import torch.nn as nn

# ---------------------------
# Utilities
# ---------------------------
def _sync_if_cuda(device):
    if torch.cuda.is_available() and ("cuda" in str(device)):
        torch.cuda.synchronize()

def _percentiles(xs, ps=(50, 95)):
    xs = np.asarray(xs, dtype=np.float64)
    return {f"p{p}": float(np.percentile(xs, p)) for p in ps}

def sample_z(batch_size: int, z_dim: int, device: str, conv_style: bool = True):
    """
    conv_style=True  -> z is [B, Z, 1, 1] (ConvTranspose2d generators)
    conv_style=False -> z is [B, Z]       (Linear/MLP generators)
    """
    if conv_style:
        return torch.randn(batch_size, z_dim, 1, 1, device=device)
    return torch.randn(batch_size, z_dim, device=device)

# ---------------------------
# Timed Generator Wrapper
# ---------------------------
class TimedGenerator(nn.Module):
    """Wrap a generator and accumulate the time spent in its forward passes.

    ``nfe`` counts calls to ``forward`` and ``model_ms`` sums their duration in
    milliseconds. For GAN generation, NFE/sample is always 1.

    Args:
        generator: Module whose forward pass is timed; stored as ``self.G``.
        device: Device the generator runs on. CUDA events are used for timing
            when CUDA is available and ``"cuda"`` appears in ``str(device)``;
            otherwise ``time.perf_counter`` is used.
    """

    def __init__(self, generator: nn.Module, device: str):
        super().__init__()
        self.G = generator
        self.device = device
        self.reset_stats()

    def reset_stats(self):
        """Zero the forward-call counter and the accumulated model time."""
        self.nfe, self.model_ms = 0, 0.0

    @torch.no_grad()
    def forward(self, z):
        """Run the wrapped generator on ``z`` and record how long it took."""
        self.nfe += 1

        on_gpu = torch.cuda.is_available() and ("cuda" in str(self.device))
        if not on_gpu:
            # Wall-clock timing for the non-CUDA path.
            tic = time.perf_counter()
            result = self.G(z)
            toc = time.perf_counter()
            self.model_ms += (toc - tic) * 1000.0
            return result

        # CUDA path: bracket the forward pass with timing events and wait for
        # the device before reading the elapsed time between them.
        begin_evt = torch.cuda.Event(enable_timing=True)
        finish_evt = torch.cuda.Event(enable_timing=True)
        begin_evt.record()
        result = self.G(z)
        finish_evt.record()
        torch.cuda.synchronize()
        self.model_ms += begin_evt.elapsed_time(finish_evt)
        return result

# ---------------------------
# Benchmark runner
# ---------------------------
@torch.no_grad()
def run_gan_benchmark(
    generator: nn.Module,
    *,
    z_dim: int,
    batch_size: int,
    device: str,
    conv_style_z: bool = True,
    warmup_runs: int = 10,
    timed_runs: int = 10,
    reset_cuda_peak_mem: bool = True,
):
    """Benchmark single-pass generation latency, throughput and peak memory.

    The generator is put in eval mode, moved to ``device`` and wrapped in a
    ``TimedGenerator``. After ``warmup_runs`` untimed batches, ``timed_runs``
    batches are measured; each run samples a fresh latent batch.

    Args:
        generator: The generator to benchmark (switched to eval mode in place).
        z_dim: Latent dimensionality passed to ``sample_z``.
        batch_size: Images generated per forward pass; must be >= 1.
        device: Device to run on (device string or ``torch.device``).
        conv_style_z: Forwarded to ``sample_z`` to pick the latent shape.
        warmup_runs: Number of untimed warm-up batches.
        timed_runs: Number of measured batches; must be >= 1.
        reset_cuda_peak_mem: Reset CUDA peak-memory counters before every
            timed run, so each recorded peak covers that run only.

    Returns:
        dict with ``batch_size``, ``timed_runs``, ``nfe_sample`` (always 1),
        p50/p95 dicts for ``total_ms_img``, ``model_ms_img``,
        ``overhd_ms_img``, ``ms_forward_model`` and ``throughput`` (img/s), and
        ``peak_alloc_mb_p50`` / ``peak_reserved_mb_p50`` (``None`` when CUDA is
        not used).

    Raises:
        ValueError: If ``batch_size`` or ``timed_runs`` is smaller than 1.
    """
    # Validate up front: batch_size == 0 would otherwise end in a
    # ZeroDivisionError and timed_runs == 0 in percentiles of empty lists.
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if timed_runs < 1:
        raise ValueError(f"timed_runs must be >= 1, got {timed_runs}")

    # Evaluated once; the same condition guards every CUDA-only statement below.
    use_cuda = torch.cuda.is_available() and ("cuda" in str(device))

    generator.eval().to(device)
    timed_G = TimedGenerator(generator, device=device).to(device)

    # Warmup: results are discarded, stats are reset before every run.
    for _ in range(warmup_runs):
        timed_G.reset_stats()
        z = sample_z(batch_size, z_dim, device, conv_style=conv_style_z)
        _sync_if_cuda(device)
        _ = timed_G(z)
        _sync_if_cuda(device)

    per_img_total_ms = []
    per_img_model_ms = []
    per_img_overhd_ms = []
    ms_per_forward_model = []
    throughput_img_s = []
    peak_alloc_mb = []
    peak_reserved_mb = []

    for _ in range(timed_runs):
        timed_G.reset_stats()

        if reset_cuda_peak_mem and use_cuda:
            torch.cuda.reset_peak_memory_stats()

        # Latent sampling happens before the timer starts, so it is not part
        # of the measured total.
        z = sample_z(batch_size, z_dim, device, conv_style=conv_style_z)

        _sync_if_cuda(device)
        t0 = time.perf_counter()
        _ = timed_G(z)
        _sync_if_cuda(device)
        t1 = time.perf_counter()

        total_ms_batch = (t1 - t0) * 1000.0
        model_ms_batch = float(timed_G.model_ms)
        # Whatever the wall clock saw beyond the wrapped forward pass itself.
        overhd_ms_batch = total_ms_batch - model_ms_batch

        total_ms_img = total_ms_batch / batch_size
        model_ms_img = model_ms_batch / batch_size
        overhd_ms_img = overhd_ms_batch / batch_size

        per_img_total_ms.append(total_ms_img)
        per_img_model_ms.append(model_ms_img)
        per_img_overhd_ms.append(overhd_ms_img)

        # Kept identical to model_ms_img: the model time of the single forward
        # pass divided by batch_size. This is an amortised per-image figure,
        # not the latency of the whole batched forward pass.
        ms_per_forward_model.append(model_ms_img)

        throughput_img_s.append(1000.0 / max(total_ms_img, 1e-12))

        if use_cuda:
            peak_alloc_mb.append(torch.cuda.max_memory_allocated() / (1024**2))
            peak_reserved_mb.append(torch.cuda.max_memory_reserved() / (1024**2))

    summary = {
        "batch_size": batch_size,
        "timed_runs": timed_runs,
        "nfe_sample": 1,
        "total_ms_img": _percentiles(per_img_total_ms, ps=(50, 95)),
        "model_ms_img": _percentiles(per_img_model_ms, ps=(50, 95)),
        "overhd_ms_img": _percentiles(per_img_overhd_ms, ps=(50, 95)),
        "ms_forward_model": _percentiles(ms_per_forward_model, ps=(50, 95)),
        "throughput": _percentiles(throughput_img_s, ps=(50, 95)),
        "peak_alloc_mb_p50": _percentiles(peak_alloc_mb, ps=(50,))["p50"] if peak_alloc_mb else None,
        "peak_reserved_mb_p50": _percentiles(peak_reserved_mb, ps=(50,))["p50"] if peak_reserved_mb else None,
    }
    return summary

def pretty_print_gan(summary: dict, name: str):
    """Print a benchmark summary dict as a short human-readable report.

    Args:
        summary: Dict with ``batch_size``, ``timed_runs``, ``nfe_sample``, the
            p50/p95 dicts ``total_ms_img``, ``model_ms_img``, ``overhd_ms_img``,
            ``ms_forward_model`` and ``throughput``, plus ``peak_alloc_mb_p50``
            and ``peak_reserved_mb_p50``.
        name: Title shown in the report header.
    """
    print(f"\n=== Benchmark: {name} ===")
    print(f"batch={summary['batch_size']} | timed_runs={summary['timed_runs']}")
    print(f"NFE/sample: {summary['nfe_sample']}")

    # (label, summary key, number of decimals) for every p50/p95 line, in order.
    metric_rows = (
        ("total ms/img", "total_ms_img", 6),
        ("model ms/img", "model_ms_img", 6),
        ("overhd ms/img", "overhd_ms_img", 6),
        ("ms/forward (model)", "ms_forward_model", 6),
        ("throughput img/s", "throughput", 2),
    )
    for label, key, digits in metric_rows:
        stats = summary[key]
        print(f"{label} p50: {stats['p50']:.{digits}f} | p95: {stats['p95']:.{digits}f}")

    # The memory line is only printed when a peak-allocation figure is present.
    peak_alloc = summary["peak_alloc_mb_p50"]
    if peak_alloc is None:
        return
    peak_reserved = summary["peak_reserved_mb_p50"]
    print(f"peak alloc MB p50: {peak_alloc:.1f} | peak reserved MB p50: {peak_reserved:.1f}")


# ============================================================
# Bind to your notebook's variables (confirmed)
# ============================================================
# Aliases so the benchmark calls below read independently of the names used
# when the model was built and loaded.
generator = model     # <- your GeneratorCorrected instance
z_dim = Z_DIM         # <- your latent dimension
# device is already defined in your notebook

# Latent layout passed to run_gan_benchmark (conv_style_z):
# If you get a conv_transpose2d 4D error, keep True.
# If you get a Linear matmul shape error, set False.
CONV_STYLE_Z = True

# ============================================================
# Run
# ============================================================
# Each call uses run_gan_benchmark's default warmup_runs / timed_runs and
# returns a summary dict that is printed right away; only batch_size differs.
b16 = run_gan_benchmark(generator, z_dim=z_dim, batch_size=16, device=device, conv_style_z=CONV_STYLE_Z)
pretty_print_gan(b16, name="WGAN CIFAR | Generator | batch=16")

b128 = run_gan_benchmark(generator, z_dim=z_dim, batch_size=128, device=device, conv_style_z=CONV_STYLE_Z)
pretty_print_gan(b128, name="WGAN CIFAR | Generator | batch=128")



=== Benchmark: WGAN CIFAR | Generator | batch=16 ===
batch=16 | timed_runs=10
NFE/sample: 1
total ms/img p50: 0.275251 | p95: 0.277356
model ms/img p50: 0.269369 | p95: 0.270967
overhd ms/img p50: 0.005852 | p95: 0.006702
ms/forward (model) p50: 0.269369 | p95: 0.270967
throughput img/s p50: 3633.06 | p95: 3649.90
peak alloc MB p50: 113.6 | peak reserved MB p50: 508.0

=== Benchmark: WGAN CIFAR | Generator | batch=128 ===
batch=128 | timed_runs=10
NFE/sample: 1
total ms/img p50: 0.115927 | p95: 0.117655
model ms/img p50: 0.114861 | p95: 0.116727
overhd ms/img p50: 0.000894 | p95: 0.001538
ms/forward (model) p50: 0.114861 | p95: 0.116727
throughput img/s p50: 8626.12 | p95: 8673.31
peak alloc MB p50: 299.7 | peak reserved MB p50: 508.0
