In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import os
import time

In [2]:
# ==========================================
# 1. CONFIGURATION
# ==========================================
class Config:
    """Hyperparameters and filesystem locations for the WGAN-GP CIFAR-10 run.

    All settings are class attributes, so they can be read from the class
    itself or from an instance (``conf = Config()``).
    """

    # Compute device: CUDA when available, otherwise CPU.
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # --- Data / image geometry ---
    IMAGE_SIZE = 32
    CHANNELS = 3
    BATCH_SIZE = 64

    # --- Model capacity ---
    Z_DIM = 100             # latent vector length
    FEATURES_DIM = 128      # base width, shared by generator and critic (G/C symmetry)

    # --- Optimisation (advanced WGAN-GP) ---
    LR = 1e-4
    NUM_EPOCHS = 100        # CIFAR needs a long schedule
    CRITIC_ITERATIONS = 5   # critic updates per generator update
    LAMBDA_GP = 10          # gradient-penalty weight
    EMA_DECAY = 0.999       # smoothing factor for the generator-weight EMA

    # --- Paths and save cadence ---
    DATA_ROOT = os.path.join("..", "..", "dataset", "CIFAR")
    IMG_DIR = "samples-cifar"
    CKPT_NAME = "wgan_cifar_advanced_ckpt.pth"
    SAVE_EVERY = 5          # checkpoint / sample interval, in epochs

conf = Config()
print(f"Entraînement sur : {conf.DEVICE}")

# Ensure the sample and dataset directories exist (no-op when already present).
for directory in (conf.IMG_DIR, conf.DATA_ROOT):
    os.makedirs(directory, exist_ok=True)

# The checkpoint is stored under its bare file name, i.e. in the working directory.
ckpt_path_full = conf.CKPT_NAME

Entraînement sur : cuda


In [3]:
# ==========================================
# 2. MODULES UTILITAIRES (EMA & ATTENTION)
# ==========================================
class EMA:
    """Exponential moving average (EMA) of a model's trainable parameters.

    A "shadow" copy is kept for every parameter with ``requires_grad=True`` and
    blended toward the live weights on each call to :meth:`update`::

        shadow = decay * shadow + (1 - decay) * param

    :meth:`apply_shadow` swaps copies of the shadow weights into the model
    (e.g. to sample from the smoothed generator) and :meth:`restore` puts the
    live weights back.

    Only parameters are tracked. Buffers (e.g. BatchNorm running statistics)
    are not averaged, so a model with the shadow applied still uses the live
    model's buffers.

    Args:
        model: the ``nn.Module`` whose parameters are averaged.
        decay: weight of the previous shadow value in each update
            (closer to 1 -> smoother, slower-moving average).
    """

    def __init__(self, model, decay):
        self.model = model
        self.decay = decay
        self.shadow = {}  # parameter name -> smoothed tensor
        self.backup = {}  # parameter name -> live tensor stashed by apply_shadow()

    @torch.no_grad()
    def register(self):
        """Initialise the shadow weights as copies of the model's current parameters."""
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.detach().clone()

    @torch.no_grad()
    def update(self):
        """Blend each shadow tensor toward the current parameter value, in place."""
        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue
            shadow = self.shadow[name]
            if shadow.device != param.device:
                # e.g. a shadow loaded on another device: move it once and keep it there.
                shadow = shadow.to(param.device)
                self.shadow[name] = shadow
            # In place: no new tensor is allocated per parameter per step.
            shadow.mul_(self.decay).add_(param.detach(), alpha=1.0 - self.decay)

    def apply_shadow(self):
        """Load the shadow weights into the model, stashing the live ones in ``self.backup``."""
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.backup[name] = param.data
                # Copy rather than alias the shadow tensor, so neither update() nor an
                # optimiser step can mutate the shadow through the parameter. The copy
                # lands on the parameter's own device (no dependency on global config).
                param.data = self.shadow[name].to(param.device, copy=True)

    def restore(self):
        """Put back the live weights stashed by :meth:`apply_shadow` (no-op if nothing is stashed)."""
        for name, param in self.model.named_parameters():
            if name in self.backup:
                param.data = self.backup[name]
        # Drop the references so a later restore() cannot re-apply stale tensors.
        self.backup = {}

    # Needed to store the EMA state in the .pth checkpoint.
    def state_dict(self):
        """Return the shadow weights (name -> tensor) for inclusion in a checkpoint."""
        return self.shadow

    def load_state_dict(self, state_dict):
        """Replace the shadow weights with copies of the tensors in ``state_dict``.

        The tensors are cloned because :meth:`update` modifies shadows in place
        and must not mutate the caller's checkpoint dictionary.
        """
        self.shadow = {name: tensor.detach().clone() for name, tensor in state_dict.items()}

class SelfAttention(nn.Module):
    """SAGAN-style self-attention over the spatial positions of a feature map.

    Every spatial position attends to every other one. The attended features
    are scaled by a learnable ``gamma`` (initialised to 0, so the block starts
    out as the identity) and added back to the input as a residual.

    Args:
        in_dim: number of input (and output) channels.
        reduction: channel reduction factor for the query/key projections.
            The default of 8 gives the usual ``in_dim // 8``; the projected width
            is clamped to at least 1 so small ``in_dim`` values do not produce
            zero-channel (degenerate, uniform-attention) projections.
    """
    def __init__(self, in_dim, reduction=8):
        super(SelfAttention, self).__init__()
        qk_dim = max(1, in_dim // reduction)
        self.query_conv = nn.Conv2d(in_dim, qk_dim, kernel_size=1)
        self.key_conv = nn.Conv2d(in_dim, qk_dim, kernel_size=1)
        self.value_conv = nn.Conv2d(in_dim, in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Apply attention to ``x`` of shape (batch, C, H, W); returns the same shape."""
        batch_size, C, height, width = x.size()
        n = height * width
        proj_query = self.query_conv(x).view(batch_size, -1, n).permute(0, 2, 1)  # (B, N, C')
        proj_key = self.key_conv(x).view(batch_size, -1, n)                        # (B, C', N)
        energy = torch.bmm(proj_query, proj_key)                                   # (B, N, N)
        attention = self.softmax(energy)  # each query row sums to 1 over the keys
        proj_value = self.value_conv(x).view(batch_size, -1, n)                    # (B, C, N)
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))                    # (B, C, N)
        out = out.view(batch_size, C, height, width)
        return self.gamma * out + x

In [4]:
# ==========================================
# 3. ARCHITECTURES AVANCÉES
# ==========================================
class Critic(nn.Module):
    """WGAN-GP critic: strided-conv downsampling with a self-attention block at 8x8.

    For a (batch, channels_img, 32, 32) input the output is an unbounded score
    of shape (batch, 1, 1, 1) (no final activation). InstanceNorm is used instead
    of BatchNorm, which would mix samples inside the per-sample gradient penalty.

    Args:
        channels_img: number of image channels.
        features_d: base channel width (doubled, then quadrupled, with depth).
    """
    def __init__(self, channels_img, features_d):
        super(Critic, self).__init__()

        def down(c_in, c_out, norm):
            """4x4 stride-2 conv (halves H and W), optional InstanceNorm, LeakyReLU(0.2)."""
            stage = [nn.Conv2d(c_in, c_out, 4, 2, 1)]
            if norm:
                stage.append(nn.InstanceNorm2d(c_out, affine=True))
            stage.append(nn.LeakyReLU(0.2))
            return stage

        # 32x32 -> 16x16 -> 8x8
        self.critic = nn.Sequential(
            *down(channels_img, features_d, norm=False),
            *down(features_d, features_d * 2, norm=True),
        )

        # Global context at 8x8 resolution.
        self.attn = SelfAttention(features_d * 2)

        # 8x8 -> 4x4 -> 1x1 raw score
        self.final_layers = nn.Sequential(
            *down(features_d * 2, features_d * 4, norm=True),
            nn.Conv2d(features_d * 4, 1, kernel_size=4, stride=1, padding=0),
        )

    def forward(self, x):
        attended = self.attn(self.critic(x))
        return self.final_layers(attended)

class Generator(nn.Module):
    """Generator: latent (batch, z_dim, 1, 1) -> image (batch, channels_img, 32, 32) in [-1, 1].

    A transposed conv first projects the latent to 4x4. Each later stage doubles
    the resolution with nearest-neighbour upsampling + 3x3 conv (instead of strided
    transposed convs, to limit checkerboard artefacts). A self-attention block
    sits at 8x8, and a final Tanh bounds the output.

    Args:
        z_dim: latent vector length.
        channels_img: number of output image channels.
        features_g: base channel width (the first stage uses 4x this).
    """
    def __init__(self, z_dim, channels_img, features_g):
        super(Generator, self).__init__()

        def upsample_conv(c_in, c_out, with_bn=True):
            """2x nearest upsample, 3x3 conv, optional BatchNorm, ReLU."""
            stage = [
                nn.Upsample(scale_factor=2, mode='nearest'),
                nn.Conv2d(c_in, c_out, 3, 1, 1),
            ]
            if with_bn:
                stage.append(nn.BatchNorm2d(c_out))
            stage.append(nn.ReLU())
            return stage

        wide, mid, base = features_g * 4, features_g * 2, features_g

        # 1x1 latent -> 4x4
        self.initial = nn.Sequential(
            nn.ConvTranspose2d(z_dim, wide, 4, 1, 0),
            nn.BatchNorm2d(wide),
            nn.ReLU(),
        )

        # 4x4 -> 8x8
        self.layer1 = nn.Sequential(*upsample_conv(wide, mid))

        # Global context at 8x8 resolution.
        self.attn = SelfAttention(mid)

        # 8x8 -> 16x16 -> 32x32, squashed into [-1, 1]
        self.layer2 = nn.Sequential(
            *upsample_conv(mid, base),
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(base, channels_img, 3, 1, 1),
            nn.Tanh(),
        )

    def forward(self, x):
        hidden = self.layer1(self.initial(x))
        hidden = self.attn(hidden)
        return self.layer2(hidden)

In [5]:
# ==========================================
# 4. GRADIENT PENALTY
# ==========================================
def gradient_penalty(critic, real, fake, device):
    """WGAN-GP gradient penalty.

    Draws one interpolation weight ``eps ~ U[0, 1)`` per sample, forms
    ``x_hat = eps * real + (1 - eps) * fake`` and returns
    ``mean((||d critic(x_hat) / d x_hat||_2 - 1) ** 2)``, with the norm taken
    over all non-batch dimensions of each sample.

    Args:
        critic: callable mapping a batch shaped like ``real`` to scores.
        real: real batch; first dimension is the batch (any rank >= 1).
        fake: generated batch with the same shape as ``real``.
        device: device on which the interpolation weights are drawn.

    Returns:
        Scalar tensor. It is differentiable w.r.t. the critic's parameters
        (``create_graph=True``). The interpolation is detached from ``real`` and
        ``fake``, so the penalty never back-propagates into the generator.
    """
    batch_size = real.shape[0]
    # One weight per sample, broadcast over the remaining dims; drawn directly
    # on the target device (no CPU tensor, no full-size `repeat` copy).
    eps_shape = (batch_size,) + (1,) * (real.dim() - 1)
    epsilon = torch.rand(eps_shape, device=device, dtype=real.dtype)
    # Detached leaf: the penalty regularises the critic only.
    interpolated = (real * epsilon + fake * (1 - epsilon)).detach()
    interpolated.requires_grad_(True)
    mixed_scores = critic(interpolated)
    gradient = torch.autograd.grad(
        inputs=interpolated,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradient_norm = gradient.reshape(batch_size, -1).norm(2, dim=1)
    return torch.mean((gradient_norm - 1) ** 2)

In [6]:
# ==========================================
# 5. INITIALISATION
# ==========================================
# Images scaled to [-1, 1] (mean 0.5 / std 0.5 per channel) to match the generator's Tanh range.
transform = transforms.Compose([
    transforms.Resize(conf.IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

dataset = torchvision.datasets.CIFAR10(root=conf.DATA_ROOT, train=True, transform=transform, download=True)
loader = DataLoader(dataset, batch_size=conf.BATCH_SIZE, shuffle=True)

# Generator and critic share the same base width (FEATURES_DIM).
gen = Generator(conf.Z_DIM, conf.CHANNELS, conf.FEATURES_DIM).to(conf.DEVICE)
critic = Critic(conf.CHANNELS, conf.FEATURES_DIM).to(conf.DEVICE)

# Adam with beta1 = 0, beta2 = 0.9, as commonly used for WGAN-GP.
opt_gen = optim.Adam(gen.parameters(), lr=conf.LR, betas=(0.0, 0.9))
opt_critic = optim.Adam(critic.parameters(), lr=conf.LR, betas=(0.0, 0.9))

# EMA of the generator weights (re-registered below if a checkpoint is loaded).
ema = EMA(gen, decay=conf.EMA_DECAY)
ema.register()

# Fixed latent batch (64 = an 8x8 sample grid) so samples from different epochs are comparable.
fixed_noise = torch.randn(64, conf.Z_DIM, 1, 1).to(conf.DEVICE)

# Resume from checkpoint if one exists.
# NOTE: torch.load unpickles the file; only load checkpoints this notebook produced.
start_epoch = 0
if os.path.exists(ckpt_path_full):
    print(f"Chargement du checkpoint : {ckpt_path_full}")
    ckpt = torch.load(ckpt_path_full, map_location=conf.DEVICE)
    gen.load_state_dict(ckpt["gen_state"])
    critic.load_state_dict(ckpt["critic_state"])
    opt_gen.load_state_dict(ckpt["opt_gen_state"])
    opt_critic.load_state_dict(ckpt["opt_critic_state"])
    # The earlier register() captured the *random* initial weights. Re-seed the
    # shadow from the loaded weights so that a checkpoint without "ema_state"
    # falls back to the trained generator instead of random weights.
    ema.register()
    if "ema_state" in ckpt:
        ema.load_state_dict(ckpt["ema_state"])
    start_epoch = ckpt["epoch"] + 1
else:
    print("Démarrage de zéro.")

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ..\..\dataset\CIFAR\cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:18<00:00, 9.36MB/s] 


Extracting ..\..\dataset\CIFAR\cifar-10-python.tar.gz to ..\..\dataset\CIFAR
Démarrage de zéro.


In [7]:
# ==========================================
# 6. BOUCLE D'ENTRAÎNEMENT + TIMING
# ==========================================
print("Début de l'entraînement CIFAR-10 Avancé...")

for epoch in range(start_epoch, conf.NUM_EPOCHS):
    gen.train()
    critic.train()

    for batch_idx, (real, _) in enumerate(loader):
        real = real.to(conf.DEVICE)
        cur_batch_size = real.shape[0]

        # --- Train Critic: CRITIC_ITERATIONS updates per generator update ---
        for _ in range(conf.CRITIC_ITERATIONS):
            noise = torch.randn(cur_batch_size, conf.Z_DIM, 1, 1, device=conf.DEVICE)
            fake = gen(noise)
            # The critic loss must not back-propagate into the generator. Feeding
            # the critic a detached copy avoids a wasted generator backward on every
            # critic step and removes the need for retain_graph=True; `fake` itself
            # keeps its graph for the generator update below.
            fake_for_critic = fake.detach()
            critic_real = critic(real).reshape(-1)
            critic_fake = critic(fake_for_critic).reshape(-1)
            gp = gradient_penalty(critic, real, fake_for_critic, conf.DEVICE)
            # WGAN-GP critic loss: -(E[D(real)] - E[D(fake)]) + lambda * GP
            loss_critic = -(torch.mean(critic_real) - torch.mean(critic_fake)) + conf.LAMBDA_GP * gp

            critic.zero_grad()
            loss_critic.backward()
            opt_critic.step()

        # --- Train Generator: score the last fake batch with the updated critic ---
        gen_fake = critic(fake).reshape(-1)
        loss_gen = -torch.mean(gen_fake)

        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # Update the smoothed generator weights after every generator step.
        ema.update()

    # ==========================================
    # CHECKPOINT, SAMPLE GENERATION AND TIMING
    # ==========================================
    if (epoch + 1) % conf.SAVE_EVERY == 0 or (epoch + 1) == conf.NUM_EPOCHS:
        print(f"--> Epoch {epoch+1}")

        # 1. Checkpoint: live (non-EMA) weights, optimiser states and the EMA shadow.
        torch.save({
            "epoch": epoch,
            "gen_state": gen.state_dict(),
            "critic_state": critic.state_dict(),
            "opt_gen_state": opt_gen.state_dict(),
            "opt_critic_state": opt_critic.state_dict(),
            "ema_state": ema.state_dict(),
        }, ckpt_path_full)

        # 2. Sample with the EMA weights. try/finally guarantees the live weights are
        #    restored (and training mode re-enabled) even if sampling or saving fails;
        #    otherwise training would silently continue on the EMA weights.
        ema.apply_shadow()
        gen.eval()
        try:
            on_cuda = conf.DEVICE.type == "cuda"
            # CUDA kernels run asynchronously: synchronise around the forward pass
            # so the measured interval covers the actual GPU work.
            if on_cuda: torch.cuda.synchronize()
            start_time = time.perf_counter()

            with torch.no_grad():
                fake_img = gen(fixed_noise)  # timed forward pass

            if on_cuda: torch.cuda.synchronize()
            end_time = time.perf_counter()
            forward_duration = end_time - start_time

            # Report against the number of images actually generated (the
            # fixed_noise batch), not the training BATCH_SIZE.
            n_images = fake_img.shape[0]
            print(f"    [TIMING] Temps de Forward (EMA) pour {n_images} images : {forward_duration:.5f} sec")
            print(f"    [TIMING] Temps moyen par image : {forward_duration/n_images:.6f} sec")

            # 3. Map the generator output from [-1, 1] back to [0, 1] and save an 8-wide grid.
            fake_img = (fake_img * 0.5) + 0.5
            save_path = os.path.join(conf.IMG_DIR, f"epoch_{epoch+1}.png")
            save_image(fake_img, save_path, nrow=8)
            print(f"    Image sauvegardée : {save_path}")
        finally:
            # 4. Back to the live weights to continue training.
            ema.restore()
            gen.train()

Début de l'entraînement CIFAR-10 Avancé...
--> Epoch 5
    [TIMING] Temps de Forward (EMA) pour 64 images : 0.00377 sec
    [TIMING] Temps moyen par image : 0.000059 sec
    Image sauvegardée : samples-cifar\epoch_5.png
--> Epoch 10
    [TIMING] Temps de Forward (EMA) pour 64 images : 0.00218 sec
    [TIMING] Temps moyen par image : 0.000034 sec
    Image sauvegardée : samples-cifar\epoch_10.png
--> Epoch 15
    [TIMING] Temps de Forward (EMA) pour 64 images : 0.00103 sec
    [TIMING] Temps moyen par image : 0.000016 sec
    Image sauvegardée : samples-cifar\epoch_15.png
--> Epoch 20
    [TIMING] Temps de Forward (EMA) pour 64 images : 0.00299 sec
    [TIMING] Temps moyen par image : 0.000047 sec
    Image sauvegardée : samples-cifar\epoch_20.png
--> Epoch 25
    [TIMING] Temps de Forward (EMA) pour 64 images : 0.00269 sec
    [TIMING] Temps moyen par image : 0.000042 sec
    Image sauvegardée : samples-cifar\epoch_25.png
--> Epoch 30
    [TIMING] Temps de Forward (EMA) pour 64 images 