# FID e Inception Score das GANs clássicas

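Este notebook automatiza a comparação de três arquiteturas clássicas de GAN (DCGAN, CGAN e WGAN-GP) no BreastMNIST. Cada modelo é treinado 10 vezes, com uma semente distinta por rodada; a DCGAN e a WGAN-GP treinam um gerador por classe, enquanto a CGAN treina um único gerador condicional. Para cada gerador treinado executamos, por classe, 10 avaliações independentes de FID e Inception Score contra o conjunto de teste. Reportamos a média e o desvio padrão das métricas em cada cenário.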


In [None]:
import random

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore

from classical_gans import (
    CGANDiscriminator,
    CGANGenerator,
    DCDiscriminator,
    DCGenerator,
    WGANGPCritic,
    WGANGPGenerator,
    train_cgan,
    train_gan_for_class,
    train_wgangp,
)
from medmnist_data import load_medmnist_data


In [None]:
# --- Experiment configuration -------------------------------------------
DATA_FLAG = 'breastmnist'  # dataset identifier handed to load_medmnist_data
BATCH_SIZE = 128  # batch size handed to load_medmnist_data
LATENT_DIM = 100  # size of the noise vector fed to every generator
NUM_EPOCHS = 500  # epochs passed to each training function
NUM_TRAINING_RUNS = 10  # independent training runs (trials) per model
NUM_EVAL_REPEATS = 10  # FID/IS evaluations per trained generator and label
BASE_SEED = 2024  # trial k is trained with seed BASE_SEED + k

# --- Device and data ------------------------------------------------------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
bundle = load_medmnist_data(data_flag=DATA_FLAG, batch_size=BATCH_SIZE, download=True)
train_loader = bundle.train_loader
test_dataset = bundle.test_dataset
# Force integer keys so labels can be compared against label tensors and used
# as dict keys for per-class generators.
label_names = {int(k): v for k, v in bundle.label_names.items()}
label_ids = sorted(label_names.keys())
num_classes = bundle.num_classes
# Leading dimension of the first training image.
# NOTE(review): assumes samples are (image, label) pairs with channel-first
# image tensors -- confirm against medmnist_data.
img_channels = bundle.train_dataset[0][0].shape[0]


In [None]:
def set_global_seed(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch random number generators.

    The CUDA generators are seeded as well, but only when CUDA is available.

    Args:
        seed: Value handed to every seeding function.
    """
    seeders = [random.seed, np.random.seed, torch.manual_seed]
    if torch.cuda.is_available():
        seeders.append(torch.cuda.manual_seed_all)
    # Same order as listed above: Python, NumPy, PyTorch CPU, then CUDA.
    for seed_rng in seeders:
        seed_rng(seed)


def denormalize(imgs: torch.Tensor) -> torch.Tensor:
    """Apply the affine map ``x -> 0.5 * x + 0.5`` element-wise.

    This sends -1 to 0 and +1 to 1; values outside [-1, 1] are not clipped.

    Args:
        imgs: Image tensor of any shape.

    Returns:
        A new tensor with the same shape as ``imgs``.
    """
    half = 0.5
    return half * imgs + half


def prepare_real_batches(dataset, *, label_target: int | None, batch_size: int = 64):
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    for real, labels in loader:
        if label_target is not None:
            mask = labels.squeeze() == label_target
            if mask.sum() == 0:
                continue
            yield real[mask]
        else:
            yield real


def preprocess_for_inception(imgs: torch.Tensor) -> torch.Tensor:
    """Convert a batch of images to 3-channel 299x299 tensors in [0, 1].

    Steps: ``denormalize``, clamp to [0, 1], replicate a single channel three
    times, then bilinearly resize to 299x299.

    Args:
        imgs: Batch indexed as (batch, channels, height, width).
            NOTE(review): presumably normalized to [-1, 1] upstream -- confirm
            against the data pipeline and the generators' output activation.

    Returns:
        Tensor of shape (batch, C, 299, 299), where C is 3 for single-channel
        input and unchanged otherwise.
    """
    unit_range = denormalize(imgs).clamp(0, 1)
    is_single_channel = unit_range.size(1) == 1
    # Single-channel images are tiled along the channel axis to form 3 channels.
    rgb = unit_range.repeat(1, 3, 1, 1) if is_single_channel else unit_range
    return F.interpolate(rgb, size=(299, 299), mode='bilinear', align_corners=False)


def sample_from_dcgan(generator, *, batch_size: int, latent_dim: int, device, label_target=None):
    """Draw Gaussian noise and run it through an unconditional generator.

    Args:
        generator: Callable taking noise of shape (batch, latent_dim, 1, 1).
        batch_size: Number of samples to generate.
        latent_dim: Size of the noise vector per sample.
        device: Device on which the noise is created.
        label_target: Ignored; accepted so all sampler functions can be
            called with the same keyword arguments.

    Returns:
        Whatever ``generator`` returns for the sampled noise.
    """
    latent_shape = (batch_size, latent_dim, 1, 1)
    latent = torch.randn(latent_shape, device=device)
    return generator(latent)


def sample_from_wgangp(generator, *, batch_size: int, latent_dim: int, device, label_target=None):
    """Draw Gaussian noise and run it through a WGAN-GP generator.

    Args:
        generator: Callable taking noise of shape (batch, latent_dim, 1, 1).
        batch_size: Number of samples to generate.
        latent_dim: Size of the noise vector per sample.
        device: Device on which the noise is created.
        label_target: Ignored; accepted so all sampler functions can be
            called with the same keyword arguments.

    Returns:
        Whatever ``generator`` returns for the sampled noise.
    """
    z = torch.randn((batch_size, latent_dim, 1, 1), device=device)
    samples = generator(z)
    return samples


def sample_from_cgan(generator, *, batch_size: int, latent_dim: int, device, label_target: int):
    """Generate a batch from a conditional generator for a single class.

    Args:
        generator: Callable taking ``(noise, labels)`` where noise has shape
            (batch, latent_dim) and labels is a 1-D int64 tensor.
        batch_size: Number of samples to generate.
        latent_dim: Size of the noise vector per sample.
        device: Device on which noise and labels are created.
        label_target: Class id used for every sample in the batch.

    Returns:
        Whatever ``generator`` returns for the sampled inputs.
    """
    latent = torch.randn((batch_size, latent_dim), device=device)
    # Every sample in the batch is conditioned on the same class id.
    class_ids = torch.full((batch_size,), label_target, dtype=torch.long, device=device)
    return generator(latent, class_ids)


def evaluate_generator(
    generator,
    *,
    label_target: int | None,
    latent_dim: int,
    sample_fn,
    device,
    dataset,
) -> tuple[float, float, float]:
    """Compute FID and Inception Score for one generator against real data.

    For each batch of real images (filtered by ``label_target``) an equally
    sized batch of fakes is sampled, so both FID distributions see the same
    number of images. Inception Score is computed on the fakes only.

    Side effects: the generator is switched to eval mode and is left on the
    CPU when the function returns; the CUDA cache is emptied when CUDA is
    available.

    Args:
        generator: Model to evaluate; moved to ``device`` for the evaluation.
        label_target: Class id to evaluate, or ``None`` for all samples.
        latent_dim: Noise dimensionality forwarded to ``sample_fn``.
        sample_fn: Callable invoked as ``sample_fn(generator, batch_size=...,
            latent_dim=..., device=..., label_target=...)`` returning fakes.
        device: Device used for the metrics and the generator.
        dataset: Source of real images, read via ``prepare_real_batches``.

    Returns:
        ``(fid, inception_score_mean, inception_score_std)`` as Python floats.
    """
    # Both metrics are built with normalize=True.
    # NOTE(review): per the torchmetrics docs this expects float images in
    # [0, 1], which is what preprocess_for_inception is meant to provide.
    fid_metric = FrechetInceptionDistance(feature=64, normalize=True).to(device)
    inception_metric = InceptionScore(normalize=True).to(device)

    generator = generator.to(device)
    generator.eval()

    with torch.no_grad():
        for real_cpu in prepare_real_batches(dataset, label_target=label_target):
            real_imgs = preprocess_for_inception(real_cpu.to(device))
            # Match the fake batch size to the (possibly filtered) real batch.
            raw_fakes = sample_fn(
                generator,
                batch_size=real_imgs.size(0),
                latent_dim=latent_dim,
                device=device,
                label_target=label_target,
            )
            fake_imgs = preprocess_for_inception(raw_fakes)

            fid_metric.update(real_imgs, real=True)
            fid_metric.update(fake_imgs, real=False)
            inception_metric.update(fake_imgs)

    fid_value = fid_metric.compute().item()
    inception_mean, inception_std = inception_metric.compute()

    # Release GPU memory before the next evaluation.
    generator.to('cpu')
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return fid_value, inception_mean.item(), inception_std.item()


In [None]:
def train_dcgan_generators(seed: int):
    """Train one DCGAN per class label and return the generators.

    The global RNGs are seeded once at the start; classes are then trained
    one after another in ``label_ids`` order, each with a freshly built
    generator/discriminator pair.

    Args:
        seed: Seed passed to ``set_global_seed`` before any model is built.

    Returns:
        Dict mapping each label id to the result of ``train_gan_for_class``
        after calling ``.cpu()`` on it.
    """
    set_global_seed(seed)

    per_class = {}
    for class_id in label_ids:
        # Generator is constructed before the discriminator (affects RNG use).
        net_g = DCGenerator(latent_dim=LATENT_DIM, img_channels=img_channels).to(device)
        net_d = DCDiscriminator(img_channels=img_channels).to(device)
        fitted = train_gan_for_class(
            train_loader=train_loader,
            label_target=class_id,
            G=net_g,
            D=net_d,
            latent_dim=LATENT_DIM,
            num_epochs=NUM_EPOCHS,
            device=device,
        )
        # Keep only a CPU copy so GPU memory is free for the next class.
        per_class[class_id] = fitted.cpu()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return per_class


def train_cgan_model(seed: int):
    """Train a single conditional GAN covering all classes.

    Args:
        seed: Seed passed to ``set_global_seed`` before the models are built.

    Returns:
        The result of ``train_cgan`` after calling ``.cpu()`` on it.
    """
    set_global_seed(seed)

    # Generator is constructed before the discriminator (affects RNG use).
    net_g = CGANGenerator(
        latent_dim=LATENT_DIM,
        num_classes=num_classes,
        img_channels=img_channels,
    ).to(device)
    net_d = CGANDiscriminator(
        num_classes=num_classes,
        img_channels=img_channels,
    ).to(device)

    fitted = train_cgan(
        train_loader=train_loader,
        G=net_g,
        D=net_d,
        latent_dim=LATENT_DIM,
        num_classes=num_classes,
        num_epochs=NUM_EPOCHS,
        device=device,
    ).cpu()

    # The model now lives on the CPU; release cached GPU memory.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return fitted


def train_wgangp_generators(seed: int):
    """Train one WGAN-GP per class label and return the generators.

    The global RNGs are seeded once at the start; classes are then trained
    one after another in ``label_ids`` order, each with a freshly built
    generator/critic pair.

    Args:
        seed: Seed passed to ``set_global_seed`` before any model is built.

    Returns:
        Dict mapping each label id to the result of ``train_wgangp`` after
        calling ``.cpu()`` on it.
    """
    set_global_seed(seed)

    per_class = {}
    for class_id in label_ids:
        # Generator is constructed before the critic (affects RNG use).
        net_g = WGANGPGenerator(latent_dim=LATENT_DIM, img_channels=img_channels).to(device)
        critic = WGANGPCritic(img_channels=img_channels).to(device)
        fitted = train_wgangp(
            train_loader=train_loader,
            G=net_g,
            D=critic,
            latent_dim=LATENT_DIM,
            num_epochs=NUM_EPOCHS,
            device=device,
            label_target=class_id,
        )
        # Keep only a CPU copy so GPU memory is free for the next class.
        per_class[class_id] = fitted.cpu()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return per_class


In [None]:
def _evaluate_and_record(
    records,
    *,
    model_name,
    trial,
    eval_seed_base,
    generators_by_label,
    sample_fn,
):
    """Evaluate one trained model on every label and append the result rows.

    For each label and each of ``NUM_EVAL_REPEATS`` repeats, the global RNGs
    are reseeded with ``eval_seed_base + repeat`` and FID / Inception Score
    are computed against ``test_dataset``. Rows are appended one at a time so
    partial results survive an interruption.

    Args:
        records: List that receives one dict per evaluation (mutated in place).
        model_name: Value stored under the ``'Model'`` key.
        trial: Zero-based training-run index stored under ``'Trial'``.
        eval_seed_base: Offset added to the repeat index to form the seed.
        generators_by_label: Mapping from label id to the generator to use.
        sample_fn: Sampler forwarded to ``evaluate_generator``.
    """
    for label in label_ids:
        label_name = label_names.get(label, str(label))
        generator = generators_by_label[label]
        for repeat in range(NUM_EVAL_REPEATS):
            set_global_seed(eval_seed_base + repeat)
            fid, is_mean, is_std = evaluate_generator(
                generator,
                label_target=label,
                latent_dim=LATENT_DIM,
                sample_fn=sample_fn,
                device=device,
                dataset=test_dataset,
            )
            records.append(
                {
                    'Model': model_name,
                    'Label': label_name,
                    'Trial': trial,
                    'Repeat': repeat,
                    'FID': fid,
                    'IS_Mean': is_mean,
                    'IS_Std': is_std,
                }
            )
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


results = []

for trial in range(NUM_TRAINING_RUNS):
    # Every model in a trial is trained from the same seed; the evaluation
    # seeds use a different multiplier per model (1000 / 2000 / 3000).
    seed = BASE_SEED + trial

    print(f'===== Rodada {trial + 1}/{NUM_TRAINING_RUNS} - DCGAN =====')
    dcgan_generators = train_dcgan_generators(seed)
    _evaluate_and_record(
        results,
        model_name='DCGAN',
        trial=trial,
        eval_seed_base=seed * 1000,
        generators_by_label=dcgan_generators,
        sample_fn=sample_from_dcgan,
    )

    print(f'===== Rodada {trial + 1}/{NUM_TRAINING_RUNS} - CGAN =====')
    cgan_generator = train_cgan_model(seed)
    _evaluate_and_record(
        results,
        model_name='CGAN',
        trial=trial,
        eval_seed_base=seed * 2000,
        # The single conditional generator serves every label.
        generators_by_label={label_id: cgan_generator for label_id in label_ids},
        sample_fn=sample_from_cgan,
    )

    print(f'===== Rodada {trial + 1}/{NUM_TRAINING_RUNS} - WGAN-GP =====')
    wgangp_generators = train_wgangp_generators(seed)
    _evaluate_and_record(
        results,
        model_name='WGAN-GP',
        trial=trial,
        eval_seed_base=seed * 3000,
        generators_by_label=wgangp_generators,
        sample_fn=sample_from_wgangp,
    )


In [None]:
# One row per evaluation: columns Model, Label, Trial, Repeat, FID, IS_Mean, IS_Std.
results_df = pd.DataFrame(results)
results_df  # bare expression so the notebook renders the table


In [None]:
# Named-aggregation spec shared by both summaries:
# output column -> (source column, reducer).
# Mean and std are taken over every row in a group, i.e. pooled across all
# trials and evaluation repeats.
metric_aggregations = {
    'FID_mean': ('FID', 'mean'),
    'FID_std': ('FID', 'std'),
    'IS_mean_mean': ('IS_Mean', 'mean'),
    'IS_mean_std': ('IS_Mean', 'std'),
    'IS_std_mean': ('IS_Std', 'mean'),
    'IS_std_std': ('IS_Std', 'std'),
}

# Per model and class label.
grouped_by_model_label = results_df.groupby(['Model', 'Label'])
summary_by_model_label = grouped_by_model_label.agg(**metric_aggregations).reset_index()

# Per model, with all labels pooled together.
grouped_by_model = results_df.groupby(['Model'])
summary_by_model = grouped_by_model.agg(**metric_aggregations).reset_index()

display(summary_by_model_label)
display(summary_by_model)
