In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import os

# ==========================================
# 1. CONFIGURATION & CHEMINS
# ==========================================
class Config:
    """Single container for every tunable value used by this notebook."""

    # --- Hardware ---
    # Pick the GPU when CUDA is available, otherwise fall back to the CPU.
    DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # --- WGAN-GP hyperparameters ---
    LR = 0.0001             # learning rate shared by both Adam optimisers
    BATCH_SIZE = 64         # samples per DataLoader batch
    IMAGE_SIZE = 28         # side length passed to transforms.Resize
    CHANNELS = 1            # image channels given to the generator and the critic
    Z_DIM = 100             # number of channels of the latent noise tensor
    NUM_EPOCHS = 20         # upper bound of the epoch loop
    FEATURES_DIM = 64       # base feature-map width of both networks
    CRITIC_ITERATIONS = 5   # critic updates performed per generator update
    LAMBDA_GP = 10          # weight of the gradient-penalty term in the critic loss

    # --- Paths ---
    # The relative paths below assume the notebook runs from "model code/GANs/".

    # Dataset root two directories up (denoising-diffusion-model/dataset).
    # NOTE: torchvision is expected to add its own /MNIST sub-folder below it.
    DATA_ROOT = os.path.join(os.pardir, os.pardir, "dataset")

    # Directory (relative to the notebook) that receives the generated sample grids.
    IMG_DIR = "samples"

    # File name of the training checkpoint.
    CKPT_NAME = "wgan_mnist_ckpt.pth"

    # A checkpoint and a sample grid are written every SAVE_EVERY epochs.
    SAVE_EVERY = 5

conf = Config()
print(f"Entraînement sur : {conf.DEVICE}")

# Make sure the sample directory exists before anything tries to write into it.
os.makedirs(conf.IMG_DIR, exist_ok=True)

# The checkpoint is stored in the notebook's current directory, so its
# path is simply the configured file name.
ckpt_path_full = conf.CKPT_NAME

# Show the absolute location of every artefact, one per line.
print(
    f"Dataset sera dans : {os.path.abspath(conf.DATA_ROOT)}",
    f"Samples seront dans : {os.path.abspath(conf.IMG_DIR)}",
    f"Checkpoint sera : {os.path.abspath(ckpt_path_full)}",
    sep="\n",
)

# ==========================================
# 2. ARCHITECTURES (MNIST 28x28)
# ==========================================
class Critic(nn.Module):
    """Convolutional critic that reduces an image batch to one score per sample.

    For a 28x28 input the spatial size goes 28 -> 14 -> 7 -> 3 -> 1, so the
    output has shape (N, 1, 1, 1).

    Args:
        channels_img: number of channels of the input images.
        features_d: base number of feature maps; deeper layers use 2x and 4x.
    """

    def __init__(self, channels_img, features_d):
        super().__init__()

        # Settings shared by the three resolution-halving convolutions.
        halve = dict(kernel_size=4, stride=2, padding=1)

        # The layers are kept in one flat Sequential, in this exact order.
        stack = [
            # 28x28 -> 14x14 (no normalisation on the first layer)
            nn.Conv2d(channels_img, features_d, **halve),
            nn.LeakyReLU(0.2),
            # 14x14 -> 7x7
            nn.Conv2d(features_d, features_d * 2, **halve),
            nn.InstanceNorm2d(features_d * 2, affine=True),
            nn.LeakyReLU(0.2),
            # 7x7 -> 3x3
            nn.Conv2d(features_d * 2, features_d * 4, **halve),
            nn.InstanceNorm2d(features_d * 4, affine=True),
            nn.LeakyReLU(0.2),
            # 3x3 -> 1x1: a single unbounded score, no final activation
            nn.Conv2d(features_d * 4, 1, kernel_size=3, stride=1, padding=0),
        ]
        self.critic = nn.Sequential(*stack)

    def forward(self, x):
        """Return the critic scores for the image batch ``x``."""
        scores = self.critic(x)
        return scores

class Generator(nn.Module):
    """Transposed-convolution generator mapping latent noise to images.

    For a latent input of shape (N, z_dim, 1, 1) the spatial size goes
    1 -> 7 -> 14 -> 28, and the final Tanh keeps every pixel in [-1, 1].

    Args:
        z_dim: number of channels of the latent noise tensor.
        channels_img: number of channels of the generated images.
        features_g: base number of feature maps; earlier layers use 4x and 2x.
    """

    def __init__(self, z_dim, channels_img, features_g):
        super().__init__()

        # Settings shared by the two resolution-doubling transposed convolutions.
        double = dict(kernel_size=4, stride=2, padding=1)

        # The layers are kept in one flat Sequential, in this exact order.
        stack = [
            # (N, z_dim, 1, 1) -> 7x7
            nn.ConvTranspose2d(z_dim, features_g * 4, kernel_size=7, stride=1, padding=0),
            nn.BatchNorm2d(features_g * 4),
            nn.ReLU(),
            # 7x7 -> 14x14
            nn.ConvTranspose2d(features_g * 4, features_g * 2, **double),
            nn.BatchNorm2d(features_g * 2),
            nn.ReLU(),
            # 14x14 -> 28x28
            nn.ConvTranspose2d(features_g * 2, channels_img, **double),
            # Squash the output into [-1, 1]
            nn.Tanh(),
        ]
        self.gen = nn.Sequential(*stack)

    def forward(self, x):
        """Return the images generated from the latent batch ``x``."""
        images = self.gen(x)
        return images

# ==========================================
# 3. UTILS : GRADIENT PENALTY
# ==========================================
def gradient_penalty(critic, real, fake, device):
    """Compute the WGAN-GP gradient penalty on random real/fake interpolates.

    Each sample is mixed as ``eps * real + (1 - eps) * fake`` with one
    ``eps ~ U[0, 1)`` per sample. The penalty is the batch mean of
    ``(||d critic / d interpolate||_2 - 1) ** 2``.

    Args:
        critic: callable mapping a batch of inputs to a batch of scores.
        real: batch of real samples, shape (N, ...); any number of trailing
            dimensions is accepted.
        fake: batch of generated samples with the same shape as ``real``.
        device: kept for backward compatibility. The mixing coefficients are
            now created directly on ``real.device``, which is where they had
            to live anyway for the interpolation to work.

    Returns:
        A scalar tensor holding the penalty. It is differentiable with
        respect to the critic's parameters (``create_graph=True``).
    """
    batch_size = real.shape[0]

    # One coefficient per sample, shaped (N, 1, ..., 1) so it broadcasts over
    # every non-batch dimension. Creating it on the right device avoids a
    # host-to-device copy and the materialised ``repeat`` on every call.
    eps_shape = (batch_size,) + (1,) * (real.dim() - 1)
    epsilon = torch.rand(eps_shape, device=real.device)
    interpolated = real * epsilon + fake * (1 - epsilon)

    # Required so autograd can differentiate the scores w.r.t. the input.
    interpolated.requires_grad_(True)

    mixed_scores = critic(interpolated)

    # create_graph=True keeps the penalty itself differentiable, so that
    # backpropagating the critic loss reaches the critic's weights.
    gradient = torch.autograd.grad(
        inputs=interpolated,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]

    # reshape (not view): the gradient is not guaranteed to be contiguous.
    gradient = gradient.reshape(batch_size, -1)
    gradient_norm = gradient.norm(2, dim=1)
    return torch.mean((gradient_norm - 1) ** 2)

# ==========================================
# 4. PRÉPARATION DES DONNÉES
# ==========================================
# Preprocessing pipeline: resize to the configured size, convert to a tensor,
# then normalise with mean 0.5 / std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize(conf.IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)

# MNIST training split, downloaded under conf.DATA_ROOT when missing.
dataset = torchvision.datasets.MNIST(
    conf.DATA_ROOT, train=True, transform=transform, download=True
)
loader = DataLoader(dataset, conf.BATCH_SIZE, shuffle=True)

# Build both networks (generator first, then critic) and move them to the device.
gen = Generator(
    z_dim=conf.Z_DIM, channels_img=conf.CHANNELS, features_g=conf.FEATURES_DIM
).to(conf.DEVICE)
critic = Critic(channels_img=conf.CHANNELS, features_d=conf.FEATURES_DIM).to(conf.DEVICE)

# One Adam optimiser per network, both with the same learning rate and betas.
opt_gen, opt_critic = (
    optim.Adam(net.parameters(), lr=conf.LR, betas=(0.0, 0.9)) for net in (gen, critic)
)

# Fixed latent batch reused at every save, so sample grids are comparable
# across epochs.
fixed_noise = torch.randn((64, conf.Z_DIM, 1, 1)).to(device=conf.DEVICE)

# ==========================================
# 5. REPRISE DE CHECKPOINT (Load)
# ==========================================
# Default: start from the first epoch unless a checkpoint says otherwise.
start_epoch = 0
if not os.path.exists(ckpt_path_full):
    print("Aucun checkpoint trouvé. Démarrage de zéro.")
else:
    print(f"Chargement du checkpoint : {ckpt_path_full}")
    # NOTE(review): torch.load may unpickle arbitrary objects depending on the
    # PyTorch version; only load checkpoint files you created yourself.
    ckpt = torch.load(ckpt_path_full, map_location=conf.DEVICE)

    # Restore the weights of both networks ...
    gen.load_state_dict(ckpt["gen_state"])
    critic.load_state_dict(ckpt["critic_state"])
    # ... and the state of both optimisers.
    opt_gen.load_state_dict(ckpt["opt_gen_state"])
    opt_critic.load_state_dict(ckpt["opt_critic_state"])

    # The checkpoint stores the last completed epoch; resume with the next one.
    start_epoch = ckpt["epoch"] + 1

# ==========================================
# 6. BOUCLE D'ENTRAÎNEMENT
# ==========================================
# Main WGAN-GP training loop.
#
# For every batch of real images the critic is updated
# conf.CRITIC_ITERATIONS times, then the generator is updated once.
# A checkpoint and a grid of samples are written every conf.SAVE_EVERY
# epochs and after the last epoch.
print("Début de l'entraînement...")

for epoch in range(start_epoch, conf.NUM_EPOCHS):
    gen.train()
    critic.train()

    for batch_idx, (real, _) in enumerate(loader):
        real = real.to(conf.DEVICE)
        cur_batch_size = real.shape[0]

        # ---------------------
        # 1. Train the critic
        # ---------------------
        for _ in range(conf.CRITIC_ITERATIONS):
            # Sample the noise directly on the target device (no host copy).
            noise = torch.randn(cur_batch_size, conf.Z_DIM, 1, 1, device=conf.DEVICE)

            # The critic update must not backpropagate into the generator,
            # so the fake batch is produced without building an autograd
            # graph. This removes the need for retain_graph=True and the
            # wasted generator backward pass on every critic iteration.
            with torch.no_grad():
                fake = gen(noise)

            critic_real = critic(real).reshape(-1)
            critic_fake = critic(fake).reshape(-1)

            gp = gradient_penalty(critic, real, fake, conf.DEVICE)

            # Wasserstein critic loss plus the weighted gradient penalty.
            loss_critic = (
                -(torch.mean(critic_real) - torch.mean(critic_fake)) + conf.LAMBDA_GP * gp
            )

            critic.zero_grad()
            loss_critic.backward()
            opt_critic.step()

        # ---------------------
        # 2. Train the generator
        # ---------------------
        # Draw a fresh latent batch and run it through the generator with
        # autograd enabled, then score it with the freshly updated critic.
        noise = torch.randn(cur_batch_size, conf.Z_DIM, 1, 1, device=conf.DEVICE)
        fake = gen(noise)
        gen_fake = critic(fake).reshape(-1)
        loss_gen = -torch.mean(gen_fake)

        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # Periodic log of the latest critic and generator losses.
        if batch_idx % 400 == 0:
            print(f"Epoch [{epoch}/{conf.NUM_EPOCHS}] Batch {batch_idx}/{len(loader)} "
                  f"Loss D: {loss_critic.item():.4f}, Loss G: {loss_gen.item():.4f}")

    # ==========================================
    # PERIODIC SAVE (checkpoint + sample images)
    # ==========================================
    if (epoch + 1) % conf.SAVE_EVERY == 0 or (epoch + 1) == conf.NUM_EPOCHS:
        print(f"--> Sauvegarde Epoch {epoch+1}")

        # A. Save everything needed to resume: the epoch just completed plus
        #    the state of both networks and both optimisers.
        torch.save({
            "epoch": epoch,
            "gen_state": gen.state_dict(),
            "critic_state": critic.state_dict(),
            "opt_gen_state": opt_gen.state_dict(),
            "opt_critic_state": opt_critic.state_dict(),
        }, ckpt_path_full)

        # B. Generate a grid from the fixed noise and save it under conf.IMG_DIR.
        gen.eval()
        with torch.no_grad():
            fake_img = gen(fixed_noise)
            # Undo the normalisation: [-1, 1] -> [0, 1]
            fake_img = (fake_img * 0.5) + 0.5

            save_name = f"epoch_{epoch+1}.png"
            save_path = os.path.join(conf.IMG_DIR, save_name)

            save_image(fake_img, save_path, nrow=8)
            print(f"    Checkpoint : {ckpt_path_full}")
            print(f"    Image : {save_path}")

        gen.train()  # back to training mode

Entraînement sur : cuda
Dataset sera dans : c:\Users\alban\Documents\Cursor code\denoising-diffusion-model\dataset
Samples seront dans : c:\Users\alban\Documents\Cursor code\denoising-diffusion-model\model code\GANs\samples
Checkpoint sera : c:\Users\alban\Documents\Cursor code\denoising-diffusion-model\model code\GANs\wgan_mnist_ckpt.pth
Aucun checkpoint trouvé. Démarrage de zéro.
Début de l'entraînement...
Epoch [0/20] Batch 0/938 Loss D: -3.6450, Loss G: 0.9150
Epoch [0/20] Batch 400/938 Loss D: -21.4977, Loss G: 19.3148
Epoch [0/20] Batch 800/938 Loss D: -21.1267, Loss G: 20.0049
Epoch [1/20] Batch 0/938 Loss D: -21.4401, Loss G: 21.1709
Epoch [1/20] Batch 400/938 Loss D: -22.2393, Loss G: 22.2292
Epoch [1/20] Batch 800/938 Loss D: -21.7653, Loss G: 23.4041
Epoch [2/20] Batch 0/938 Loss D: -21.8310, Loss G: 24.3075
Epoch [2/20] Batch 400/938 Loss D: -5.6216, Loss G: 18.8907
Epoch [2/20] Batch 800/938 Loss D: -5.3838, Loss G: 19.6463
Epoch [3/20] Batch 0/938 Loss D: -5.0966, Loss G: