# FID e Inception Score da GAN quântica MOSAIQ

Este notebook replica a análise de FID e Inception Score realizada para as GANs clássicas, agora dedicada ao gerador MOSAIQ.
Treinamos um gerador por rótulo do BreastMNIST ao longo de várias rodadas, avaliando cada modelo em múltiplas repetições e reportando média e desvio padrão das métricas.


In [None]:
import random

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore

import medmnist
from medmnist import INFO

from quantum_gan_medmnist import (
    MosaiqQuantumGenerator,
    MosaiqDiscriminator,
    train_mosaiq_gan,
    create_mosaiq_pca_loaders,
    scale_data,
)
from medmnist_data import load_medmnist_data


In [None]:
# --- Dataset and image geometry -------------------------------------------
DATA_FLAG = "breastmnist"
TARGET_IMG_SIZE = 8  # side length (pixels) of the down-sampled images
BATCH_SIZE = 128  # batch size handed to load_medmnist_data

# --- DataLoader options forwarded to create_mosaiq_pca_loaders -------------
MOSAIQ_TRAIN_BATCH_SIZE = 8
MOSAIQ_TRAIN_DROP_LAST = True
MOSAIQ_TRAIN_PIN_MEMORY = torch.cuda.is_available()

# --- Generator layout ------------------------------------------------------
# PCA_DIMS is tied to the generator layout: one value per qubit of every
# sub-generator.
N_GENERATORS, N_QUBITS, Q_DEPTH = 8, 5, 6
PCA_DIMS = N_QUBITS * N_GENERATORS

# --- Experiment protocol ---------------------------------------------------
NUM_EPOCHS = 30
NUM_TRAINING_RUNS = NUM_EVAL_REPEATS = 3
BASE_SEED = 2024

# Interval from which generator features are linearly rescaled back to the
# PCA value range in mosaiq_features_to_images.
MOSAIQ_SCALE_MIN, MOSAIQ_SCALE_MAX = -1.0, 1.0

# Keep both a string and a torch.device: train_mosaiq_gan receives the string,
# everything else in this notebook uses the torch.device object.
device_str = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_str)

def _tensor_and_normalize_steps():
    """Return fresh ToTensor + Normalize(mean=0.5, std=0.5) transform steps.

    Normalize computes (x - 0.5) / 0.5; `denormalize` in this notebook is the
    inverse operation.
    """
    return [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]


# Low-resolution pipeline: resize to TARGET_IMG_SIZE x TARGET_IMG_SIZE first,
# then convert to a normalized tensor.
transform_lowres = transforms.Compose(
    [
        transforms.Resize((TARGET_IMG_SIZE, TARGET_IMG_SIZE)),
        *_tensor_and_normalize_steps(),
    ]
)

# High-resolution pipeline: same conversion, without any resizing.
transform_highres = transforms.Compose(_tensor_and_normalize_steps())

# Low-resolution train/test datasets plus label metadata, loaded through the
# project helper.
_bundle_options = {
    "data_flag": DATA_FLAG,
    "batch_size": BATCH_SIZE,
    "download": True,
    "transform": transform_lowres,
    "shuffle_train": True,
}
bundle = load_medmnist_data(**_bundle_options)

train_dataset, test_dataset = bundle.train_dataset, bundle.test_dataset

# Label ids are coerced to int so they can be compared with integer targets.
label_names = dict((int(key), name) for key, name in bundle.label_names.items())
label_ids = sorted(label_names)

# A second copy of the training split without resizing; it is the input to
# the PCA loader builder below.
_python_class_name = INFO[DATA_FLAG]["python_class"]
dataset_class = getattr(medmnist, _python_class_name)
train_dataset_highres = dataset_class(
    split="train",
    transform=transform_highres,
    download=True,
)

# The helper returns four values: per-label loaders, the PCA-projected
# training tensor, the labels and the fitted PCA model.
(
    train_loaders,
    train_tensor_pca,
    train_labels,
    mosaiq_pca_model,
) = create_mosaiq_pca_loaders(
    train_dataset_highres,
    batch_size=MOSAIQ_TRAIN_BATCH_SIZE,
    target_size=TARGET_IMG_SIZE,
    pca_dims=PCA_DIMS,
    drop_last=MOSAIQ_TRAIN_DROP_LAST,
    pin_memory=MOSAIQ_TRAIN_PIN_MEMORY,
)

# Locate the raw image array: prefer `imgs`, fall back to `data`.
for _attr_name in ("imgs", "data"):
    if hasattr(train_dataset_highres, _attr_name):
        base_imgs = getattr(train_dataset_highres, _attr_name)
        break
else:
    raise AttributeError("Dataset precisa expor os atributos `imgs` ou `data`.")

# Bring the images to a float (N, C, H, W) layout.
base_tensor = torch.as_tensor(base_imgs).float()
_has_trailing_channel = base_tensor.ndim == 4 and base_tensor.shape[-1] == 1
if _has_trailing_channel:
    # (N, H, W, 1) -> (N, 1, H, W)
    base_tensor = base_tensor.permute(0, 3, 1, 2)
elif base_tensor.ndim == 3:
    # (N, H, W) -> (N, 1, H, W)
    base_tensor = base_tensor.unsqueeze(1)

# Values above 1 are treated as 8-bit intensities and rescaled to [0, 1].
if base_tensor.max() > 1:
    base_tensor = base_tensor / 255.0

# Down-sample to the target resolution and flatten each image to one row.
lowres_tensor = F.interpolate(
    base_tensor,
    size=(TARGET_IMG_SIZE, TARGET_IMG_SIZE),
    mode="bilinear",
    align_corners=False,
)
_num_images = lowres_tensor.size(0)
flat_imgs = lowres_tensor.reshape(_num_images, -1).cpu().numpy()

# Global value ranges, kept so mosaiq_features_to_images can invert the
# scaling steps later.
MOSAIQ_FLAT_MIN, MOSAIQ_FLAT_MAX = float(flat_imgs.min()), float(flat_imgs.max())

# NOTE(review): scale_data presumably min-max rescales to the given interval;
# confirm against quantum_gan_medmnist.
scaled_inputs = scale_data(flat_imgs, (0.0, 1.0))
pca_data = mosaiq_pca_model.transform(scaled_inputs)
MOSAIQ_PCA_MIN, MOSAIQ_PCA_MAX = float(pca_data.min()), float(pca_data.max())


In [None]:
def set_global_seed(seed: int) -> None:
    """Seed Python's `random`, NumPy and PyTorch (CPU and, if present, CUDA).

    Args:
        seed: Value passed unchanged to every random number generator.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed)
    for apply_seed in seeders:
        apply_seed(seed)

    # Explicitly seed every CUDA device when a GPU is available.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def denormalize(imgs: torch.Tensor) -> torch.Tensor:
    """Map values through ``x * 0.5 + 0.5`` (inverse of Normalize(0.5, 0.5)).

    Args:
        imgs: Tensor of normalized image values.

    Returns:
        A new tensor with the affine map applied element-wise.
    """
    half = 0.5
    return half + half * imgs


def preprocess_for_inception(imgs: torch.Tensor) -> torch.Tensor:
    """Prepare a batch of normalized images for the Inception-based metrics.

    The batch is denormalized, clamped to [0, 1], expanded to three channels
    when it has a single channel, and bilinearly resized to 299 x 299.

    Args:
        imgs: Batch indexed as (N, C, H, W); dimension 1 is read as channels.

    Returns:
        The resized batch with spatial size 299 x 299.
    """
    batch = denormalize(imgs).clamp(0, 1)

    is_single_channel = batch.size(1) == 1
    if is_single_channel:
        # Replicate the grey channel three times along the channel axis.
        batch = batch.repeat(1, 3, 1, 1)

    return F.interpolate(batch, size=(299, 299), mode='bilinear', align_corners=False)


def mosaiq_features_to_images(features: torch.Tensor) -> torch.Tensor:
    """Turn generator feature vectors into normalized single-channel images.

    Steps: rescale features from [MOSAIQ_SCALE_MIN, MOSAIQ_SCALE_MAX] to the
    PCA value range, invert the PCA, map the reconstruction back to the pixel
    value range, reshape to images and apply ``(x - 0.5) / 0.5``.

    Args:
        features: Generator output; assumed to be (N, PCA_DIMS) — confirm
            against MosaiqQuantumGenerator.

    Returns:
        CPU float tensor of shape (N, 1, TARGET_IMG_SIZE, TARGET_IMG_SIZE).

    Raises:
        ValueError: If the stored PCA range is empty or inverted.
    """
    if MOSAIQ_PCA_MIN >= MOSAIQ_PCA_MAX:
        raise ValueError('Intervalo inválido para reescalar as componentes da PCA.')

    pca_range = MOSAIQ_PCA_MAX - MOSAIQ_PCA_MIN
    feature_range = MOSAIQ_SCALE_MAX - MOSAIQ_SCALE_MIN

    # Features -> unit interval -> PCA value range.
    feature_array = features.detach().cpu().numpy()
    unit_interval = (feature_array - MOSAIQ_SCALE_MIN) / feature_range
    pca_coords = unit_interval * pca_range + MOSAIQ_PCA_MIN

    recon = mosaiq_pca_model.inverse_transform(pca_coords)

    # Undo the pixel scaling; a zero-width range means every pixel had the
    # same value, so the reconstruction is replaced by that constant.
    pixel_range = MOSAIQ_FLAT_MAX - MOSAIQ_FLAT_MIN
    if pixel_range == 0:
        pixels = np.full_like(recon, MOSAIQ_FLAT_MIN)
    else:
        pixels = recon * pixel_range + MOSAIQ_FLAT_MIN

    image_shape = (-1, 1, TARGET_IMG_SIZE, TARGET_IMG_SIZE)
    images = torch.from_numpy(pixels).float().view(*image_shape)
    return (images - 0.5) / 0.5


def sample_from_mosaiq(generator, *, batch_size: int, latent_dim: int, device, label_target=None):
    """Draw `batch_size` images from a MOSAIQ generator.

    Args:
        generator: Callable mapping a latent batch to feature vectors.
        batch_size: Number of latent vectors (and images) to draw.
        latent_dim: Size of each latent vector.
        device: Device for the latent noise and the returned images.
        label_target: Ignored; accepted so the function matches the keyword
            arguments `evaluate_generator` passes to its `sample_fn`.

    Returns:
        Image tensor produced by `mosaiq_features_to_images`, moved to `device`.
    """
    # Latent values are uniform in [0, pi/2).
    upper_bound = torch.pi / 2
    latent = torch.rand(batch_size, latent_dim, device=device) * upper_bound
    return mosaiq_features_to_images(generator(latent)).to(device)


def evaluate_generator(
    generator,
    *,
    label_target: int,
    latent_dim: int,
    sample_fn,
    device,
    dataset,
    batch_size: int = 64,
) -> tuple[float, float, float]:
    """Compute FID and Inception Score for one generator on one label.

    For every batch of `dataset`, the real images whose label equals
    `label_target` are compared against the same number of generated images.

    Args:
        generator: Model to evaluate; moved to `device` and put in eval mode.
        label_target: Label id whose real samples are used as reference.
        latent_dim: Latent size forwarded to `sample_fn`.
        sample_fn: Callable invoked as ``sample_fn(generator, batch_size=...,
            latent_dim=..., device=..., label_target=...)`` returning images.
        device: `torch.device` or device string on which metrics are computed.
        dataset: Dataset yielding ``(image, label)`` pairs.
        batch_size: Batch size of the evaluation DataLoader.

    Returns:
        ``(fid, inception_score_mean, inception_score_std)`` as floats.

    Raises:
        RuntimeError: If `dataset` has no sample labelled `label_target`.
    """
    # Accept plain strings such as "cuda" as well as torch.device objects;
    # `device.type` is read below and would fail on a str.
    device = torch.device(device)

    # normalize=True: the metrics receive float images in [0, 1].
    fid = FrechetInceptionDistance(feature=64, normalize=True).to(device)
    is_metric = InceptionScore(normalize=True).to(device)

    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    generator = generator.to(device)
    generator.eval()

    has_samples = False
    with torch.no_grad():
        for real, labels in loader:
            # Flatten rather than squeeze(): on a batch holding one sample,
            # squeeze() yields a 0-d tensor, and indexing with a 0-d boolean
            # mask prepends a dimension instead of selecting rows.
            mask = labels.reshape(-1) == label_target
            if not mask.any():
                continue

            real = real[mask].to(device)
            real = preprocess_for_inception(real)

            # Generate exactly as many fakes as there are matching reals.
            fake = sample_fn(
                generator,
                batch_size=real.size(0),
                latent_dim=latent_dim,
                device=device,
                label_target=label_target,
            )
            fake = preprocess_for_inception(fake)

            fid.update(real, real=True)
            fid.update(fake, real=False)
            is_metric.update(fake)
            has_samples = True

    if not has_samples:
        raise RuntimeError(f'Nenhuma amostra disponível para o rótulo {label_target}.')

    fid_score = fid.compute().item()
    is_mean, is_std = is_metric.compute()

    # Free GPU memory held by the generator between evaluations.
    if device.type == "cuda":
        generator.to('cpu')

    return fid_score, is_mean.item(), is_std.item()


In [None]:
def build_mosaiq_generator():
    """Create a new MosaiqQuantumGenerator from the notebook-level layout.

    The positional arguments are, in order, N_GENERATORS, N_QUBITS and Q_DEPTH.
    """
    layout = (N_GENERATORS, N_QUBITS, Q_DEPTH)
    return MosaiqQuantumGenerator(*layout)


def build_mosaiq_discriminator():
    """Create a new MosaiqDiscriminator whose input size is PCA_DIMS."""
    discriminator_options = {"input_dim": PCA_DIMS}
    return MosaiqDiscriminator(**discriminator_options)


def train_mosaiq_generators(seed: int, trial_index: int, num_trials: int):
    """Train one MOSAIQ generator per label and return them keyed by label id.

    Args:
        seed: Seed applied once, before any model is built.
        trial_index: Zero-based index of the current run (used for logging).
        num_trials: Total number of runs (used for logging).

    Returns:
        Dict mapping each id in `label_ids` to its trained generator, in eval
        mode and moved to the CPU.
    """
    set_global_seed(seed)
    trained: dict[int, MosaiqQuantumGenerator] = {}

    for label in label_ids:
        label_name = label_names.get(label, str(label))
        print(
            f"[Rodada {trial_index + 1}/{num_trials}] Iniciando treinamento para o rótulo {label_name} (seed={seed})"
        )

        gen_net = build_mosaiq_generator().to(device)
        disc_net = build_mosaiq_discriminator().to(device)

        # `label_name` is bound as a default so the callback keeps the value
        # of this iteration even if it is invoked later.
        def log_progress(epoch: int, total_epochs: int, d_loss: float, g_loss: float, *, label_name=label_name):
            print(
                f"[Rodada {trial_index + 1}/{num_trials}] [Rótulo {label_name}] Época {epoch}/{total_epochs} - "
                f"Loss_D={d_loss:.4f} Loss_G={g_loss:.4f}"
            )

        train_mosaiq_gan(
            train_loaders[label],
            gen_net,
            disc_net,
            epochs=NUM_EPOCHS,
            device=device_str,
            progress_callback=log_progress,
        )

        print(f"[Rodada {trial_index + 1}/{num_trials}] Finalizado treinamento para o rótulo {label_name}")

        # Park the trained generator on the CPU so GPU memory is available
        # for the next label.
        gen_net.eval()
        trained[label] = gen_net.to('cpu')

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return trained


In [None]:
# One row per (training run, label, evaluation repeat).
results = []

for trial, seed in enumerate(range(BASE_SEED, BASE_SEED + NUM_TRAINING_RUNS)):
    print(f'===== Rodada {trial + 1}/{NUM_TRAINING_RUNS} - MOSAIQ =====')
    mosaiq_generators = train_mosaiq_generators(seed, trial, NUM_TRAINING_RUNS)

    for label in label_ids:
        label_name = label_names.get(label, str(label))

        for repeat in range(NUM_EVAL_REPEATS):
            # Each repeat gets its own seed derived from the run seed, so the
            # sampled latent noise differs between repeats but is reproducible.
            set_global_seed(seed * 1000 + repeat)

            fid, is_mean, is_std = evaluate_generator(
                mosaiq_generators[label],
                label_target=label,
                latent_dim=N_QUBITS,
                sample_fn=sample_from_mosaiq,
                device=device,
                dataset=test_dataset,
            )

            row = {'Model': 'MOSAIQ', 'Label': label_name, 'Trial': trial, 'Repeat': repeat}
            row.update(FID=fid, IS_Mean=is_mean, IS_Std=is_std)
            results.append(row)

    if torch.cuda.is_available():
        torch.cuda.empty_cache()


In [None]:
# Tabulate the collected metric rows; the bare name on the last line makes the
# notebook render the table as the cell output.
results_df = pd.DataFrame(data=results)
results_df


In [None]:
# Named aggregations shared by both summary tables:
# output column -> (source column, aggregation function).
_METRIC_SUMMARY = {
    'FID_mean': ('FID', 'mean'),
    'FID_std': ('FID', 'std'),
    'IS_mean_mean': ('IS_Mean', 'mean'),
    'IS_mean_std': ('IS_Mean', 'std'),
    'IS_std_mean': ('IS_Std', 'mean'),
    'IS_std_std': ('IS_Std', 'std'),
}

# Mean and standard deviation of every metric per (model, label) pair.
_grouped_by_label = results_df.groupby(['Model', 'Label'])
summary_by_model_label = _grouped_by_label.agg(**_METRIC_SUMMARY).reset_index()

# The same statistics pooled over all labels of a model.
_grouped_by_model = results_df.groupby(['Model'])
summary_by_model = _grouped_by_model.agg(**_METRIC_SUMMARY).reset_index()

for _summary_table in (summary_by_model_label, summary_by_model):
    display(_summary_table)
