In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import os
import time

In [2]:
# ==========================================
# 1. CONFIGURATION
# ==========================================
class Config:
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # WGAN-GP MNIST
    LR = 1e-4
    BATCH_SIZE = 64
    IMAGE_SIZE = 28
    CHANNELS = 1
    Z_DIM = 100
    NUM_EPOCHS = 50
    FEATURES_DIM = 64
    CRITIC_ITERATIONS = 5
    LAMBDA_GP = 10
    
    # Chemins
    # Dataset : ../../dataset (Torchvision ajoutera /MNIST automatiquement)
    DATA_ROOT = os.path.join("..", "..", "dataset") 
    IMG_DIR = "samples"
    CKPT_NAME = "wgan_mnist_ckpt.pth"
    SAVE_EVERY = 5

conf = Config()
print(f"Entraînement sur : {conf.DEVICE}")

os.makedirs(conf.IMG_DIR, exist_ok=True)
ckpt_path_full = conf.CKPT_NAME

Entraînement sur : cuda


In [3]:
# ==========================================
# 2. ARCHITECTURES (MNIST 28x28)
# ==========================================
class Critic(nn.Module):
    def __init__(self, channels_img, features_d):
        super(Critic, self).__init__()
        self.critic = nn.Sequential(
            nn.Conv2d(channels_img, features_d, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(features_d, features_d * 2, 4, 2, 1),
            nn.InstanceNorm2d(features_d * 2, affine=True),
            nn.LeakyReLU(0.2),
            nn.Conv2d(features_d * 2, features_d * 4, 4, 2, 1),
            nn.InstanceNorm2d(features_d * 4, affine=True),
            nn.LeakyReLU(0.2),
            nn.Conv2d(features_d * 4, 1, 3, 1, 0),
        )

    def forward(self, x):
        return self.critic(x)

class Generator(nn.Module):
    def __init__(self, z_dim, channels_img, features_g):
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            nn.ConvTranspose2d(z_dim, features_g * 4, 7, 1, 0),
            nn.BatchNorm2d(features_g * 4),
            nn.ReLU(),
            nn.ConvTranspose2d(features_g * 4, features_g * 2, 4, 2, 1),
            nn.BatchNorm2d(features_g * 2),
            nn.ReLU(),
            nn.ConvTranspose2d(features_g * 2, channels_img, 4, 2, 1),
            nn.Tanh(),
        )

    def forward(self, x):
        return self.gen(x)

In [4]:
# ==========================================
# 3. GRADIENT PENALTY
# ==========================================
def gradient_penalty(critic, real, fake, device):
    BATCH_SIZE, C, H, W = real.shape
    epsilon = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
    interpolated = real * epsilon + fake * (1 - epsilon)
    interpolated.requires_grad_(True)
    mixed_scores = critic(interpolated)
    gradient = torch.autograd.grad(
        inputs=interpolated,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradient = gradient.view(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    return torch.mean((gradient_norm - 1) ** 2)

# ==========================================
# 4. INITIALISATION
# ==========================================
transform = transforms.Compose([
    transforms.Resize(conf.IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

dataset = torchvision.datasets.MNIST(root=conf.DATA_ROOT, train=True, transform=transform, download=True)
loader = DataLoader(dataset, batch_size=conf.BATCH_SIZE, shuffle=True)

gen = Generator(conf.Z_DIM, conf.CHANNELS, conf.FEATURES_DIM).to(conf.DEVICE)
critic = Critic(conf.CHANNELS, conf.FEATURES_DIM).to(conf.DEVICE)
opt_gen = optim.Adam(gen.parameters(), lr=conf.LR, betas=(0.0, 0.9))
opt_critic = optim.Adam(critic.parameters(), lr=conf.LR, betas=(0.0, 0.9))
fixed_noise = torch.randn(64, conf.Z_DIM, 1, 1).to(conf.DEVICE)

start_epoch = 0
if os.path.exists(ckpt_path_full):
    print(f"Chargement du checkpoint : {ckpt_path_full}")
    ckpt = torch.load(ckpt_path_full, map_location=conf.DEVICE)
    gen.load_state_dict(ckpt["gen_state"])
    critic.load_state_dict(ckpt["critic_state"])
    opt_gen.load_state_dict(ckpt["opt_gen_state"])
    opt_critic.load_state_dict(ckpt["opt_critic_state"])
    start_epoch = ckpt["epoch"] + 1

In [5]:
# ==========================================
# 5. BOUCLE D'ENTRAÎNEMENT + TIMING
# ==========================================
print("Début de l'entraînement MNIST...")

for epoch in range(start_epoch, conf.NUM_EPOCHS):
    gen.train()
    critic.train()
    
    for batch_idx, (real, _) in enumerate(loader):
        real = real.to(conf.DEVICE)
        cur_batch_size = real.shape[0]

        # Train Critic
        for _ in range(conf.CRITIC_ITERATIONS):
            noise = torch.randn(cur_batch_size, conf.Z_DIM, 1, 1).to(conf.DEVICE)
            fake = gen(noise)
            critic_real = critic(real).reshape(-1)
            critic_fake = critic(fake).reshape(-1)
            gp = gradient_penalty(critic, real, fake, conf.DEVICE)
            loss_critic = -(torch.mean(critic_real) - torch.mean(critic_fake)) + conf.LAMBDA_GP * gp
            critic.zero_grad()
            loss_critic.backward(retain_graph=True)
            opt_critic.step()

        # Train Generator
        gen_fake = critic(fake).reshape(-1)
        loss_gen = -torch.mean(gen_fake)
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

    # --- Sauvegarde & Mesure du temps de Forward ---
    if (epoch + 1) % conf.SAVE_EVERY == 0 or (epoch + 1) == conf.NUM_EPOCHS:
        print(f"--> Epoch {epoch+1}")
        
        # Sauvegarde PTH
        torch.save({
            "epoch": epoch,
            "gen_state": gen.state_dict(),
            "critic_state": critic.state_dict(),
            "opt_gen_state": opt_gen.state_dict(),
            "opt_critic_state": opt_critic.state_dict(),
        }, ckpt_path_full)
        
        # Génération & Timing
        gen.eval()
        
        # Synchronisation GPU pour une mesure précise du temps (si CUDA)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            
        start_time = time.time()  # <--- DÉBUT CHRONO
        
        with torch.no_grad():
            fake_img = gen(fixed_noise) # FORWARD PASS
            
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            
        end_time = time.time()    # <--- FIN CHRONO
        forward_duration = end_time - start_time
        
        print(f"    [TIMING] Temps de Forward pour {conf.BATCH_SIZE} images : {forward_duration:.5f} sec")
        print(f"    [TIMING] Temps moyen par image : {forward_duration/conf.BATCH_SIZE:.6f} sec")

        # Sauvegarde Image
        fake_img = (fake_img * 0.5) + 0.5
        save_path = os.path.join(conf.IMG_DIR, f"epoch_{epoch+1}.png")
        save_image(fake_img, save_path, nrow=8)

Début de l'entraînement MNIST...
--> Epoch 5
    [TIMING] Temps de Forward pour 64 images : 0.00250 sec
    [TIMING] Temps moyen par image : 0.000039 sec
--> Epoch 10
    [TIMING] Temps de Forward pour 64 images : 0.00102 sec
    [TIMING] Temps moyen par image : 0.000016 sec
--> Epoch 15
    [TIMING] Temps de Forward pour 64 images : 0.00098 sec
    [TIMING] Temps moyen par image : 0.000015 sec
--> Epoch 20
    [TIMING] Temps de Forward pour 64 images : 0.00707 sec
    [TIMING] Temps moyen par image : 0.000110 sec
--> Epoch 25
    [TIMING] Temps de Forward pour 64 images : 0.00289 sec
    [TIMING] Temps moyen par image : 0.000045 sec
--> Epoch 30
    [TIMING] Temps de Forward pour 64 images : 0.00000 sec
    [TIMING] Temps moyen par image : 0.000000 sec
--> Epoch 35
    [TIMING] Temps de Forward pour 64 images : 0.00199 sec
    [TIMING] Temps moyen par image : 0.000031 sec
--> Epoch 40
    [TIMING] Temps de Forward pour 64 images : 0.00033 sec
    [TIMING] Temps moyen par image : 0.000