# Évaluation StyleGAN2 (FairFace) — FID / IS / LPIPS

Ce notebook :
1) charge le générateur (EMA si présent) depuis `runs/fairface_cpu_lite/last.pt`  
2) génère un échantillon synthétique  
3) calcule **FID**, **Inception Score**, et **LPIPS diversité**  
4) sauvegarde un grid d’images et un fichier `metrics.json`.

> Remarque : si le checkpoint ne contient **pas** `mapper` et `label_emb`, on échantillonne
directement **w ~ N(0,1)** (non conditionnel). Si `ckpt_*.pt` contient ces modules,
ils seront utilisés pour un **échantillonnage conditionnel**.

In [1]:
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"

In [3]:
!pip install torchmetrics lpips tqdm pillow scipy

Collecting torchmetrics
  Downloading torchmetrics-1.8.2-py3-none-any.whl.metadata (22 kB)
Collecting lightning-utilities>=0.8.0 (from torchmetrics)
  Downloading lightning_utilities-0.15.2-py3-none-any.whl.metadata (5.7 kB)
Downloading torchmetrics-1.8.2-py3-none-any.whl (983 kB)
   ---------------------------------------- 0.0/983.2 kB ? eta -:--:--
   ---------------------------------------- 983.2/983.2 kB 23.3 MB/s  0:00:00
Downloading lightning_utilities-0.15.2-py3-none-any.whl (29 kB)
Installing collected packages: lightning-utilities, torchmetrics

   -------------------- ------------------- 1/2 [torchmetrics]
   -------------------- ------------------- 1/2 [torchmetrics]
   -------------------- ------------------- 1/2 [torchmetrics]
   -------------------- ------------------- 1/2 [torchmetrics]
   -------------------- ------------------- 1/2 [torchmetrics]
   -------------------- ------------------- 1/2 [torchmetrics]
   -------------------- ------------------- 1/2 [torchmetrics

In [2]:
# --- Réglages Windows/CPU pour éviter l’erreur OpenMP et limiter les threads ---
import os, sys
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"

# --- RÉPERTOIRE RACINE DU PROJET : ADAPTE SI BESOIN ---
ROOT = r"C:\Users\ilyes\Downloads\stylegan2_cond"

# Se placer dans la racine du projet et l’ajouter au PYTHONPATH
os.chdir(ROOT)
if ROOT not in sys.path:
    sys.path.insert(0, ROOT)

print("cwd:", os.getcwd())
print("facegan in path?", any(p.endswith("stylegan2_cond") for p in sys.path))


cwd: C:\Users\ilyes\Downloads\stylegan2_cond
facegan in path? True


In [3]:
import os, json, random
import numpy as np
from tqdm import tqdm

import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms as T
from torchvision.utils import save_image, make_grid
from PIL import Image

from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
import lpips  # LPIPS (réseau perceptuel)

# Import des modules du projet (ça doit maintenant marcher)
from facegan.models.generator import Generator, MappingNetwork
from facegan.models.discriminator import ProjectionDiscriminator  # pas utilisé pour l'éval
from facegan.data.dataset import FaceAttrsDataset  # pour récupérer les chemins réels

ROOT   = r"C:\Users\ilyes\Downloads\stylegan2_cond"
CSV    = os.path.join(ROOT, "attrs.csv")
CKPT   = os.path.join(ROOT, "runs", "fairface_cpu_lite", "last.pt")   # ou un ckpt_*.pt si tu préfères
OUTDIR = os.path.join(ROOT, "runs", "fairface_cpu_lite", "eval")
os.makedirs(OUTDIR, exist_ok=True)

DEVICE = torch.device("cpu")           # si tu as un GPU: torch.device("cuda")
SEED   = 123
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)

# Génération
IMG_SIZE      = 256
N_GEN         = 2000   # nb d'images synthétiques pour FID/IS (augmenter si tu peux)
BATCH_GEN     = 16
Z_DIM         = 128
W_DIM         = 256
LITE          = True   # = True pour correspondre à ton entraînement lite
BASE_CH       = 32 if LITE else 64

# LPIPS
LPIPS_PAIRS   = 400    # nb de paires pour diversité (augmente si tu veux)


In [4]:
# 2) Charge dataset réel (pour FID) — on lit les images en 256x256 puis on convertira pour FID
class RealImages(Dataset):
    def __init__(self, csv, image_size=256):
        import pandas as pd
        df = pd.read_csv(csv)
        self.paths = df["path"].tolist()
        self.t = T.Compose([
            T.Resize((image_size, image_size), interpolation=T.InterpolationMode.BILINEAR),
            T.ToTensor(),                    # [0,1]
        ])
    def __len__(self): return len(self.paths)
    def __getitem__(self, i):
        img = Image.open(self.paths[i]).convert("RGB")
        return self.t(img)

real_ds = RealImages(CSV, IMG_SIZE)
# On ne prend que N_GEN images réelles pour rendre la comparaison équitable
subset_idx = np.random.choice(len(real_ds), size=min(N_GEN, len(real_ds)), replace=False)
real_loader = DataLoader(torch.utils.data.Subset(real_ds, subset_idx),
                         batch_size=BATCH_GEN, shuffle=False, num_workers=0)
len(real_loader.dataset), "real images used"


(2000, 'real images used')

In [5]:
# 3) Construit le générateur et charge le checkpoint
#    Si 'emaG' est présent -> on l'utilise; sinon 'G'.
#    Si le ckpt contient 'mapper'/'label_emb', on utilisera le sampling conditionnel.

# Embedding dims utilisés dans le train lite plus haut
D_AGE, D_GEN, D_ETH = 16, 8, 16

G = Generator(w_dim=W_DIM, base_ch=BASE_CH).to(DEVICE)
mapper = MappingNetwork(z_dim=Z_DIM, c_dim=D_AGE + D_GEN + D_ETH, w_dim=W_DIM,
                        n_layers=4 if LITE else 8).to(DEVICE)

# Labels embeddings (si on peut charger depuis ckpt)
label_emb = torch.nn.ModuleDict({
    "age": torch.nn.Embedding(5, D_AGE),
    "gen": torch.nn.Embedding(2, D_GEN),
    "eth": torch.nn.Embedding(7, D_ETH),
}).to(DEVICE)

ckpt = torch.load(CKPT, map_location=DEVICE)
loaded = []

if "emaG" in ckpt:
    G.load_state_dict(ckpt["emaG"]); loaded.append("emaG")
elif "G" in ckpt:
    G.load_state_dict(ckpt["G"]); loaded.append("G")

if "mapper" in ckpt:
    mapper.load_state_dict(ckpt["mapper"]); loaded.append("mapper")
else:
    mapper = None

if "label_emb" in ckpt:
    label_emb.load_state_dict(ckpt["label_emb"]); loaded.append("label_emb")
else:
    label_emb = None

print("Checkpoint keys chargés:", loaded)
G.eval()


Checkpoint keys chargés: ['emaG']


  ckpt = torch.load(CKPT, map_location=DEVICE)


Generator(
  (blocks): ModuleList(
    (0-2): 3 x StyledConv(
      (conv): ModulatedConv2d(
        (affine): Linear(in_features=256, out_features=256, bias=True)
      )
    )
    (3-4): 2 x StyledConv(
      (conv): ModulatedConv2d(
        (affine): Linear(in_features=256, out_features=128, bias=True)
      )
    )
    (5-6): 2 x StyledConv(
      (conv): ModulatedConv2d(
        (affine): Linear(in_features=256, out_features=64, bias=True)
      )
    )
    (7-13): 7 x StyledConv(
      (conv): ModulatedConv2d(
        (affine): Linear(in_features=256, out_features=32, bias=True)
      )
    )
  )
  (torgb): ToRGB(
    (conv): ModulatedConv2d(
      (affine): Linear(in_features=256, out_features=32, bias=True)
    )
  )
)

In [6]:
# 4) Utilitaires

def to_uint8(img):
    """
    img: tensor BxCxHxW dans [-1,1] ou [0,1].
    Retourne uint8 [0..255] (B,C,H,W).
    """
    if img.min() < 0:
        x = (img.clamp(-1,1) + 1) * 0.5
    else:
        x = img.clamp(0,1)
    return (x * 255).round().to(torch.uint8)

@torch.no_grad()
def sample_fake(batch, device, conditional=True):
    """ Génère batch images soit conditionnelles (si mapper+embeds connus), sinon w~N(0,1). """
    if conditional and (mapper is not None) and (label_emb is not None):
        # échantillonner labels uniformes (ou selon CSV si tu préfères)
        age = torch.randint(0,5,(batch,), device=device)
        gen = torch.randint(0,2,(batch,), device=device)
        eth = torch.randint(0,7,(batch,), device=device)
        z = torch.randn(batch, Z_DIM, device=device)
        c = torch.cat([label_emb["age"](age), label_emb["gen"](gen), label_emb["eth"](eth)], dim=1)
        w = mapper(z, c)
        x = G(w)
    else:
        # non conditionnel: w ~ N(0,1)
        w = torch.randn(batch, W_DIM, device=device)
        x = G(w)
    return x

# Sauvegarde un petit grid d'images générées pour sanity check
with torch.no_grad():
    sample = sample_fake(32, DEVICE, conditional=True)
grid = make_grid((sample.clamp(-1,1)+1)/2, nrow=8)
save_image(grid, os.path.join(OUTDIR, "samples_grid.png"))
os.path.join(OUTDIR, "samples_grid.png")


'C:\\Users\\ilyes\\Downloads\\stylegan2_cond\\runs\\fairface_cpu_lite\\eval\\samples_grid.png'

In [7]:
# 5) FID & IS
fid = FrechetInceptionDistance(feature=2048)   # torchmetrics gère le resize interne
isc = InceptionScore(splits=10, normalize=True)

# (A) accumulate REAL
for real in tqdm(real_loader, desc="FID: real"):
    # real: [B,3,H,W] en [0,1]
    fid.update((real*255).to(torch.uint8), real=True)

# (B) accumulate FAKE
n_done = 0
while n_done < N_GEN:
    b = min(BATCH_GEN, N_GEN - n_done)
    with torch.no_grad():
        fake = sample_fake(b, DEVICE, conditional=True)    # [-1,1]
    # FID attend uint8 0..255
    fid.update(to_uint8(fake).cpu(), real=False)

    # IS attend float [0,1]
    isc.update(((fake.clamp(-1,1)+1)/2).cpu())

    n_done += b

fid_score = fid.compute().item()
is_mean, is_std = isc.compute()
is_mean, is_std = float(is_mean), float(is_std)
print(f"FID: {fid_score:.3f} | IS: {is_mean:.3f} ± {is_std:.3f}")


Downloading: "https://github.com/toshas/torch-fidelity/releases/download/v0.2.0/weights-inception-2015-12-05-6726825d.pth" to C:\Users\ilyes/.cache\torch\hub\checkpoints\weights-inception-2015-12-05-6726825d.pth
100%|██████████| 91.2M/91.2M [00:01<00:00, 66.3MB/s]
FID: real: 100%|██████████| 125/125 [06:56<00:00,  3.33s/it]


FID: 265.540 | IS: 2.682 ± 0.173


In [9]:
# 6) LPIPS diversité (moyenne sur paires synthétiques)
loss_fn = lpips.LPIPS(net='alex')  # plus léger
loss_fn = loss_fn.to(DEVICE).eval()

pairs = 0
lpips_vals = []
with torch.no_grad():
    while pairs < LPIPS_PAIRS:
        b = 2   # on génère 2 images et on calcule 1 distance
        imgs = sample_fake(b, DEVICE, conditional=True)  # [-1,1]
        d = loss_fn(imgs[0:1], imgs[1:2]).item()
        lpips_vals.append(d)
        pairs += 1

lpips_mean = float(np.mean(lpips_vals))
lpips_std  = float(np.std(lpips_vals))
print(f"LPIPS (diversité) mean={lpips_mean:.4f} ± {lpips_std:.4f}")


Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: c:\Users\ilyes\anaconda3\Lib\site-packages\lpips\weights\v0.1\alex.pth
LPIPS (diversité) mean=0.4186 ± 0.1295
