In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import os

In [2]:
class Config:
    # Matériel
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Hyperparamètres WGAN-GP
    LR = 1e-4
    BATCH_SIZE = 64
    IMAGE_SIZE = 28
    CHANNELS = 1
    Z_DIM = 100
    NUM_EPOCHS = 20
    FEATURES_DIM = 64
    CRITIC_ITERATIONS = 5
    LAMBDA_GP = 10
    
    # --- GESTION DES CHEMINS ---
    # On suppose que le script tourne dans "model code/GANs/"
    
    # Chemin vers: denoising-diffusion-model/dataset
    # Torchvision ajoutera automatiquement le sous-dossier /MNIST
    DATA_ROOT = os.path.join("..", "..", "dataset") 
    
    # Chemin vers: denoising-diffusion-model/model code/GANs/samples
    IMG_DIR = "samples"
    
    # Nom du fichier checkpoint
    CKPT_NAME = "wgan_mnist_ckpt.pth"
    
    # Fréquence de sauvegarde
    SAVE_EVERY = 5

class Critic(nn.Module):
    def __init__(self, channels_img, features_d):
        super(Critic, self).__init__()
        self.critic = nn.Sequential(
            # Input: N x 1 x 28 x 28
            nn.Conv2d(channels_img, features_d, kernel_size=4, stride=2, padding=1), # -> 14x14
            nn.LeakyReLU(0.2),
            
            # 14x14 -> 7x7
            nn.Conv2d(features_d, features_d * 2, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(features_d * 2, affine=True),
            nn.LeakyReLU(0.2),
            
            # 7x7 -> 3x3
            nn.Conv2d(features_d * 2, features_d * 4, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(features_d * 4, affine=True),
            nn.LeakyReLU(0.2),
            
            # 3x3 -> 1x1
            nn.Conv2d(features_d * 4, 1, kernel_size=3, stride=1, padding=0),
        )

    def forward(self, x):
        return self.critic(x)

class Generator(nn.Module):
    def __init__(self, z_dim, channels_img, features_g):
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            # Input: Z (N x 100 x 1 x 1) -> 7x7
            nn.ConvTranspose2d(z_dim, features_g * 4, kernel_size=7, stride=1, padding=0),
            nn.BatchNorm2d(features_g * 4),
            nn.ReLU(),
            
            # 7x7 -> 14x14
            nn.ConvTranspose2d(features_g * 4, features_g * 2, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(features_g * 2),
            nn.ReLU(),
            
            # 14x14 -> 28x28
            nn.ConvTranspose2d(features_g * 2, channels_img, kernel_size=4, stride=2, padding=1),
            nn.Tanh(), # Sortie entre [-1, 1]
        )

    def forward(self, x):
        return self.gen(x)

In [3]:
import gdown
import os

file_id = "1_85f6DEJ4lEZl0VWx0V5PzSRMd4-l4yK"
out_path = "wgan_mnist_ckpt.pth"

if not os.path.exists(out_path):
    gdown.download(
        f"https://drive.google.com/uc?id={file_id}",
        out_path,
        quiet=False,
    )
else:
    print("Already downloaded.")

Downloading...
From (original): https://drive.google.com/uc?id=1_85f6DEJ4lEZl0VWx0V5PzSRMd4-l4yK
From (redirected): https://drive.google.com/uc?id=1_85f6DEJ4lEZl0VWx0V5PzSRMd4-l4yK&confirm=t&uuid=f6de9514-d1d0-4cf5-b0de-963cf7701eb9
To: c:\Users\alban\Documents\Cursor code\denoising-diffusion-model\reproducibility\wgan_mnist_ckpt.pth
100%|██████████| 29.3M/29.3M [00:05<00:00, 4.93MB/s]


In [4]:
conf = Config()
gen = Generator(conf.Z_DIM, conf.CHANNELS, conf.FEATURES_DIM).to(conf.DEVICE)

# Charger le checkpoint
ckpt = torch.load("wgan_mnist_ckpt.pth", map_location=conf.DEVICE)
gen.load_state_dict(ckpt["gen_state"])
gen.eval()

# Générer 16 images
noise = torch.randn(16, conf.Z_DIM, 1, 1).to(conf.DEVICE)
with torch.no_grad():
    fake = gen(noise)
    fake = (fake * 0.5) + 0.5
    torchvision.utils.save_image(fake, "generated_result.png", nrow=4)
    print("Image générée : generated_result.png")

  ckpt = torch.load("wgan_mnist_ckpt.pth", map_location=conf.DEVICE)


Image générée : generated_result.png
