# FID e Inception Score da GAN quântica PatchQuantum

Este notebook replica a análise de FID e Inception Score realizada para as GANs clássicas, agora dedicada ao gerador PatchQuantum.
Treinamos um gerador por rótulo do BreastMNIST ao longo de várias rodadas, avaliando cada modelo em múltiplas repetições e reportando média e desvio padrão das métricas.


In [None]:
import random

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision import transforms
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore

from quantum_gan_medmnist import PatchQuantumGenerator, Discriminator, train_quantum_gan
from medmnist_data import load_medmnist_data


In [None]:
DATA_FLAG = 'breastmnist'
TARGET_IMG_SIZE = 8
BATCH_SIZE = 128
N_QUBITS = 5
N_A_QUBITS = 1
Q_DEPTH = 6
NUM_EPOCHS = 500
NUM_TRAINING_RUNS = 10
NUM_EVAL_REPEATS = 10
BASE_SEED = 2024

# Pick the accelerator once and derive the plain-string form from it, so the
# torch.device and the string handed to train_quantum_gan always agree.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device_str = device.type

# Downsample to TARGET_IMG_SIZE x TARGET_IMG_SIZE, convert to a tensor and
# shift/scale with mean=std=0.5 (so [0, 1] inputs land in [-1, 1]).
transform_lowres = transforms.Compose(
    [
        transforms.Resize((TARGET_IMG_SIZE, TARGET_IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

bundle = load_medmnist_data(
    data_flag=DATA_FLAG,
    batch_size=BATCH_SIZE,
    download=True,
    transform=transform_lowres,
    shuffle_train=True,
)

train_dataset, test_dataset = bundle.train_dataset, bundle.test_dataset
# Integer keys so they compare equal to the integer labels used for filtering.
label_names = {int(key): name for key, name in bundle.label_names.items()}
label_ids = sorted(label_names)

# Pixels per patch: one patch per sub-generator; the full image must split
# evenly into patches.
patch_size = 2 ** (N_QUBITS - N_A_QUBITS)
if (TARGET_IMG_SIZE ** 2) % patch_size:
    raise ValueError('target_img_size**2 deve ser múltiplo de patch_size para montar a imagem completa')
N_GENERATORS = TARGET_IMG_SIZE ** 2 // patch_size
# One latent value per qubit (assumed to be what PatchQuantumGenerator expects — confirm).
LATENT_DIM = N_QUBITS

def subset_by_label(dataset, label):
    """Return a ``Subset`` of *dataset* containing only the samples of class *label*.

    Parameters
    ----------
    dataset:
        Dataset exposing a ``labels`` array with exactly one label per sample,
        shaped ``(N,)`` or ``(N, 1)`` (MedMNIST uses the latter).
    label:
        Integer class id to keep.

    Returns
    -------
    torch.utils.data.Subset
        View over *dataset* whose ``indices`` are the matching positions, in
        ascending order.

    Raises
    ------
    ValueError
        If ``dataset.labels`` holds more than one label per sample.
    """
    labels = np.asarray(dataset.labels)
    if labels.ndim == 2 and labels.shape[1] == 1:
        labels = labels[:, 0]
    if labels.ndim != 1:
        raise ValueError(
            f'subset_by_label espera um rótulo por amostra; labels tem forma {labels.shape}'
        )
    # Vectorized match instead of a per-sample Python loop. This also avoids
    # int() on 1-element arrays, which NumPy deprecates. astype mirrors int()'s
    # truncation for non-integer label dtypes.
    indices = np.flatnonzero(labels.astype(np.int64) == label).tolist()
    return Subset(dataset, indices)

# One shuffled training loader per class id, each drawing only that class's samples.
train_loaders = {
    class_id: DataLoader(
        dataset=subset_by_label(train_dataset, class_id),
        batch_size=BATCH_SIZE,
        shuffle=True,
    )
    for class_id in label_ids
}


In [None]:
def set_global_seed(seed: int) -> None:
    """Seed Python's, NumPy's and PyTorch's global RNGs with *seed*.

    When CUDA is available, every visible GPU's generator is seeded as well.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


def denormalize(imgs: torch.Tensor) -> torch.Tensor:
    """Invert ``Normalize(mean=0.5, std=0.5)``, mapping [-1, 1] back to [0, 1]."""
    half = 0.5
    # x = x_norm * std + mean with std == mean == 0.5.
    return half + half * imgs


def preprocess_for_inception(imgs: torch.Tensor) -> torch.Tensor:
    """Prepare a normalized image batch for the torchmetrics Inception metrics.

    Expects a 4-D ``(N, C, H, W)`` batch in the [-1, 1] training range. The
    steps are:

    - map back to [0, 1] and clamp;
    - replicate single-channel images to 3 channels;
    - bilinearly resize to 299x299.
    """
    rgb_ready = denormalize(imgs).clamp(0, 1)
    is_grayscale = rgb_ready.size(1) == 1
    if is_grayscale:
        rgb_ready = rgb_ready.repeat(1, 3, 1, 1)
    return F.interpolate(
        rgb_ready,
        size=(299, 299),
        mode='bilinear',
        align_corners=False,
    )


def sample_from_patchquantum(generator, *, batch_size: int, latent_dim: int, device, label_target=None):
    """Generate *batch_size* images from a PatchQuantum generator.

    Latent inputs are drawn uniformly from [0, pi/2). *label_target* is
    accepted only so the signature matches the ``sample_fn`` interface of
    ``evaluate_generator``. It is ignored because each generator is trained
    for a single label.
    """
    half_pi = torch.pi / 2
    angles = half_pi * torch.rand(batch_size, latent_dim, device=device)
    return generator(angles)


def evaluate_generator(
    generator,
    *,
    label_target: int,
    latent_dim: int,
    sample_fn,
    device,
    dataset,
    batch_size: int = 64,
) -> tuple[float, float, float]:
    """Compute FID and Inception Score of *generator* for one class of *dataset*.

    Every real image of class ``label_target`` in *dataset* is paired with the
    same number of generated images (one generated batch per real batch).
    Real and fake images feed FID; only the fakes feed the Inception Score.

    Parameters
    ----------
    generator:
        Model to evaluate. It is moved to *device* and put in eval mode. On a
        CUDA device it is moved back to the CPU afterwards, even on failure.
    label_target:
        Class id whose real samples form the reference distribution.
    latent_dim:
        Forwarded to *sample_fn*.
    sample_fn:
        Callable ``sample_fn(generator, *, batch_size, latent_dim, device,
        label_target)``. It returns generated images in the same normalized
        range as the dataset images.
    device:
        ``torch.device`` or device string (e.g. ``'cuda'``).
    dataset:
        Dataset yielding ``(image, label)`` pairs with one label per sample.
    batch_size:
        Batch size used to iterate *dataset*.

    Returns
    -------
    tuple[float, float, float]
        ``(fid, is_mean, is_std)``.

    Raises
    ------
    RuntimeError
        If *dataset* contains no sample with label ``label_target``.
    """
    device = torch.device(device)  # also accept strings like 'cuda'
    fid = FrechetInceptionDistance(feature=64, normalize=True).to(device)
    is_metric = InceptionScore(normalize=True).to(device)

    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    generator = generator.to(device)
    generator.eval()

    try:
        has_samples = False
        with torch.no_grad():
            for real, labels in loader:
                # Flatten to one label per sample. squeeze() would collapse a
                # single-sample batch to a 0-d mask, and boolean indexing with a
                # 0-d mask adds a leading dimension instead of selecting rows.
                mask = labels.reshape(-1) == label_target
                if not mask.any():
                    continue

                real = preprocess_for_inception(real[mask].to(device))

                fake = sample_fn(
                    generator,
                    batch_size=real.size(0),
                    latent_dim=latent_dim,
                    device=device,
                    label_target=label_target,
                )
                fake = preprocess_for_inception(fake)

                fid.update(real, real=True)
                fid.update(fake, real=False)
                is_metric.update(fake)
                has_samples = True

        if not has_samples:
            raise RuntimeError(f'Nenhuma amostra disponível para o rótulo {label_target}.')

        fid_score = fid.compute().item()
        is_mean, is_std = is_metric.compute()
    finally:
        # Free GPU memory held by the generator even when evaluation fails.
        if device.type == 'cuda':
            generator.to('cpu')

    return fid_score, is_mean.item(), is_std.item()


In [None]:
def build_patchquantum_generator():
    """Build a new PatchQuantumGenerator from the notebook's global quantum settings."""
    quantum_kwargs = dict(n_qubits=N_QUBITS, n_a_qubits=N_A_QUBITS, q_depth=Q_DEPTH)
    return PatchQuantumGenerator(N_GENERATORS, TARGET_IMG_SIZE, **quantum_kwargs)


def build_discriminator():
    """Build a new Discriminator for TARGET_IMG_SIZE x TARGET_IMG_SIZE images."""
    discriminator = Discriminator(img_size=TARGET_IMG_SIZE)
    return discriminator


def train_patchquantum_generators(seed: int) -> dict[int, PatchQuantumGenerator]:
    """Train one PatchQuantum GAN per class label and return the generators.

    The global RNGs are seeded once with *seed*. Labels are then processed in
    ``label_ids`` order, each with a freshly built generator/discriminator
    pair trained for ``NUM_EPOCHS`` epochs on that label's training loader.

    Returns
    -------
    dict[int, PatchQuantumGenerator]
        Trained generators keyed by label, in eval mode. On CUDA runs they are
        moved to the CPU.
    """
    set_global_seed(seed)
    generators: dict[int, PatchQuantumGenerator] = {}

    for label in label_ids:
        generator = build_patchquantum_generator().to(device)
        discriminator = build_discriminator().to(device)

        train_quantum_gan(
            train_loaders[label],
            generator,
            discriminator,
            epochs=NUM_EPOCHS,
            device=device_str,
        )

        generator.eval()
        if device.type == 'cuda':
            generator = generator.to('cpu')
        generators[label] = generator

        # Drop the discriminator before clearing the CUDA cache. While still
        # referenced, its memory cannot be released, and it would coexist with
        # the next label's freshly built discriminator.
        del discriminator
        if device.type == 'cuda':
            torch.cuda.empty_cache()

    return generators


In [None]:
# One record per (training run, label, evaluation repeat).
results = []

for trial in range(NUM_TRAINING_RUNS):
    seed = BASE_SEED + trial
    print(f'===== Rodada {trial + 1}/{NUM_TRAINING_RUNS} - PatchQuantum =====')
    pq_generators = train_patchquantum_generators(seed)

    for label in label_ids:
        trained_generator = pq_generators[label]
        label_name = label_names.get(label, str(label))
        for repeat in range(NUM_EVAL_REPEATS):
            # Reproducible sampling seed per (trial, repeat). It is distinct
            # while NUM_EVAL_REPEATS < 1000 and is reused across labels.
            set_global_seed(seed * 1000 + repeat)
            fid, is_mean, is_std = evaluate_generator(
                trained_generator,
                label_target=label,
                latent_dim=LATENT_DIM,
                sample_fn=sample_from_patchquantum,
                device=device,
                dataset=test_dataset,
            )
            record = {
                'Model': 'PatchQuantum',
                'Label': label_name,
                'Trial': trial,
                'Repeat': repeat,
                'FID': fid,
                'IS_Mean': is_mean,
                'IS_Std': is_std,
            }
            results.append(record)

    if torch.cuda.is_available():
        torch.cuda.empty_cache()


In [None]:
# Tabulate every (trial, label, repeat) evaluation; the bare name displays it.
results_df = pd.DataFrame(data=results)
results_df


In [None]:
# Aggregations shared by both summary tables (pandas named aggregation:
# output column -> (input column, function)). The std columns use pandas'
# sample std (ddof=1). They are taken over every row in the group: all trials
# and evaluation repeats, plus all labels in the per-model table. They
# therefore mix training-run and evaluation-noise variability.
METRIC_AGGREGATIONS = {
    'FID_mean': ('FID', 'mean'),
    'FID_std': ('FID', 'std'),
    'IS_mean_mean': ('IS_Mean', 'mean'),
    'IS_mean_std': ('IS_Mean', 'std'),
    'IS_std_mean': ('IS_Std', 'mean'),
    'IS_std_std': ('IS_Std', 'std'),
}


def summarize_metrics(df: pd.DataFrame, keys: list[str]) -> pd.DataFrame:
    """Aggregate the FID/IS columns of *df* per group of *keys*, keys as columns."""
    return df.groupby(keys).agg(**METRIC_AGGREGATIONS).reset_index()


summary_by_model_label = summarize_metrics(results_df, ['Model', 'Label'])
summary_by_model = summarize_metrics(results_df, ['Model'])

display(summary_by_model_label)
display(summary_by_model)
