In [1]:
import torch
import torch.nn as nn
import torchvision
import os
import gdown

In [2]:
# ==========================================
# 1. TÉLÉCHARGEMENT
# ==========================================
file_id = "1axjh_HshUWWEXoOLbf2L83WTtMKv0VkE"
out_path = "wgan_cifar_advanced_ckpt.pth"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if not os.path.exists(out_path):
    print(f"Téléchargement de {out_path}...")
    gdown.download(f"https://drive.google.com/uc?id={file_id}", out_path, quiet=False)
else:
    print("Fichier checkpoint déjà présent.")

Téléchargement de wgan_cifar_advanced_ckpt.pth...


Downloading...
From (original): https://drive.google.com/uc?id=1axjh_HshUWWEXoOLbf2L83WTtMKv0VkE
From (redirected): https://drive.google.com/uc?id=1axjh_HshUWWEXoOLbf2L83WTtMKv0VkE&confirm=t&uuid=25c2df4e-6e6b-4f17-82de-cda1b9089602
To: c:\Users\alban\Documents\Cursor code\denoising-diffusion-model\reproducibility\wgan_cifar_advanced_ckpt.pth
100%|██████████| 70.8M/70.8M [00:11<00:00, 6.21MB/s]


In [3]:
# ==========================================
# 2. ARCHITECTURE CORRIGÉE (Instance Norm)
# ==========================================

class SelfAttention(nn.Module):
    def __init__(self, in_dim):
        super(SelfAttention, self).__init__()
        self.query_conv = nn.Conv2d(in_dim, in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_dim, in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_dim, in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        batch_size, C, width, height = x.size()
        proj_query = self.query_conv(x).view(batch_size, -1, width * height).permute(0, 2, 1)
        proj_key = self.key_conv(x).view(batch_size, -1, width * height)
        energy = torch.bmm(proj_query, proj_key)
        attention = self.softmax(energy)
        proj_value = self.value_conv(x).view(batch_size, -1, width * height)
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(batch_size, C, width, height)
        return self.gamma * out + x

class GeneratorCorrected(nn.Module):
    def __init__(self, z_dim, channels_img, features_g=128):
        super(GeneratorCorrected, self).__init__()
        
        # Helper block: Upsample + Conv + INSTANCE NORM
        def block(in_channels, out_channels, normalize=True):
            layers = [
                nn.Upsample(scale_factor=2, mode='nearest'),
                nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            ]
            if normalize:
                # CORRECTION ICI : InstanceNorm2d au lieu de BatchNorm2d
                # affine=True permet d'apprendre des poids (weight/bias) comme BatchNorm
                # mais sans stocker running_mean/var
                layers.append(nn.InstanceNorm2d(out_channels, affine=True))
            layers.append(nn.ReLU())
            return layers

        self.initial = nn.Sequential(
            nn.ConvTranspose2d(z_dim, features_g * 4, 4, 1, 0),
            # CORRECTION ICI AUSSI
            nn.InstanceNorm2d(features_g * 4, affine=True),
            nn.ReLU(),
        )

        self.layer1 = nn.Sequential(
            *block(features_g * 4, features_g * 2),
        )
        
        self.attn = SelfAttention(features_g * 2)

        self.layer2 = nn.Sequential(
            *block(features_g * 2, features_g),
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(features_g, channels_img, 3, 1, 1),
            nn.Tanh()
        )

    def forward(self, x):
        out = self.initial(x)
        out = self.layer1(out)
        out = self.attn(out)
        return self.layer2(out)

In [4]:
# ==========================================
# 3. CHARGEMENT ET GÉNÉRATION
# ==========================================
Z_DIM = 100
CHANNELS = 3
FEATURES_DIM = 128

# Instanciation du modèle corrigé
model = GeneratorCorrected(Z_DIM, CHANNELS, features_g=FEATURES_DIM).to(device)

print(f"Chargement du checkpoint : {out_path}")
ckpt = torch.load(out_path, map_location=device)

# Sélection du bon dictionnaire de poids
if "ema_state" in ckpt:
    state_dict = ckpt["ema_state"]
    print(">> Poids EMA détectés (Meilleure qualité)")
elif "gen_state" in ckpt:
    state_dict = ckpt["gen_state"]
    print(">> Poids standards détectés")
else:
    state_dict = ckpt

# Chargement
try:
    model.load_state_dict(state_dict)
    print(">> Poids chargés avec SUCCÈS ! (Architecture InstanceNorm validée)")
except RuntimeError as e:
    print(f"\nERREUR ENCORE PRÉSENTE : {e}")
    print("Essai de chargement avec strict=False (Risqué mais peut marcher)...")
    model.load_state_dict(state_dict, strict=False)

Chargement du checkpoint : wgan_cifar_advanced_ckpt.pth
>> Poids EMA détectés (Meilleure qualité)
>> Poids chargés avec SUCCÈS ! (Architecture InstanceNorm validée)


  ckpt = torch.load(out_path, map_location=device)


In [5]:
# Génération
model.eval()
num_samples = 16
noise = torch.randn(num_samples, Z_DIM, 1, 1).to(device)

with torch.no_grad():
    fake = model(noise)
    fake = (fake * 0.5) + 0.5 # Dénormalisation
    
    save_name = "cifar_fixed_result.png"
    torchvision.utils.save_image(fake, save_name, nrow=4)
    print(f"\nImage générée : {save_name}")


Image générée : cifar_fixed_result.png
