# FID e Inception Score da LaSt-QGAN

Este notebook replica a metodologia de avaliação usada em `gans_classical_fid_is.ipynb`, mas agora aplicada à arquitetura LaSt-QGAN localizada na pasta `LaSt-QGAN-main`. Treinamos e avaliamos a LaSt-QGAN no MNIST, repetindo múltiplas execuções com sementes controladas e calculando FID e Inception Score em cada rodada.

In [None]:
import os
import sys
import random
import json

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore

# Make the LaSt-QGAN modules importable. The directory is inserted at the FRONT
# of sys.path: the project ships generically named modules (``utils``,
# ``train_qgan``), and appending would let any same-named module that is
# already importable (e.g. an installed ``utils`` package or a ``utils.py``
# next to this notebook) shadow them.
PROJECT_ROOT = os.path.abspath('LaSt-QGAN-main')
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

from utils import (
    DigitsDataset,
    GANModule,
    AutoencoderModule,
    build_model_from_config,
    parse_config,
    seed_everything,
)
from train_qgan import QuantumGenerator

In [None]:
# Data and configuration files shipped with the LaSt-QGAN project.
MNIST_CSV = os.path.join(PROJECT_ROOT, 'mnist.csv')
GAN_CONFIG_PATH = os.path.join(PROJECT_ROOT, 'gan.yaml')
AUTOENCODER_CONFIG_PATH = os.path.join(PROJECT_ROOT, 'autoencoder.yaml')

# Training / evaluation hyperparameters.
BATCH_SIZE = 128
LATENT_DIM = 6  # number of rotations used by the quantum generator; also the noise size when sampling
NUM_EPOCHS = 50
NUM_TRAINING_RUNS = 3  # independent training runs, seeded BASE_SEED, BASE_SEED + 1, ...
NUM_EVAL_REPEATS = 3  # metric evaluations per label for each trained model
BASE_SEED = 2024
LABELS = [*range(10)]  # MNIST digits 0..9

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Usando dispositivo:', device)

In [None]:
def set_global_seed(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch random number generators.

    When CUDA is available, every CUDA device is seeded as well.

    Args:
        seed: Value given to each generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


def denormalize(imgs: torch.Tensor) -> torch.Tensor:
    """Map values from ``[-1, 1]`` to ``[0, 1]`` via ``x * 0.5 + 0.5``."""
    return imgs * 0.5 + 0.5


def preprocess_for_inception(imgs: torch.Tensor) -> torch.Tensor:
    """Turn an image batch into float input for the Inception-based metrics.

    Steps: denormalise with ``denormalize`` and clamp to ``[0, 1]``; replicate a
    single channel to 3 channels; resize bilinearly to 299x299. The input is
    assumed to be in ``[-1, 1]`` (TODO confirm for both ``DigitsDataset`` and
    the decoder output). The ``[0, 1]`` float output is what torchmetrics
    expects when the metrics are built with ``normalize=True``.

    Args:
        imgs: Batch shaped ``(N, C, H, W)``, or ``(N, H, W)``. A 3-D batch is
            treated as single-channel and gets a channel axis added.

    Returns:
        Batch shaped ``(N, C', 299, 299)``, where ``C'`` is 3 if the input
        had one channel and ``C`` otherwise.

    Raises:
        ValueError: If ``imgs`` is neither 3-D nor 4-D.
    """
    if imgs.dim() == 3:
        # (N, H, W): add the channel axis so that ``size(1)`` below is the
        # channel count rather than the image height, and so that the
        # bilinear interpolation receives the 4-D input it requires.
        imgs = imgs.unsqueeze(1)
    elif imgs.dim() != 4:
        raise ValueError(
            f'expected a 3-D or 4-D image batch, got shape {tuple(imgs.shape)}'
        )
    imgs = denormalize(imgs).clamp(0, 1)
    if imgs.size(1) == 1:
        imgs = imgs.repeat(1, 3, 1, 1)
    return F.interpolate(imgs, size=(299, 299), mode='bilinear', align_corners=False)


def prepare_real_batches(dataset, *, label_target: int | None, batch_size: int = 64):
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    for real in loader:
        if label_target is not None:
            # DigitsDataset retorna apenas as imagens; filtramos pelas linhas do csv que já são separadas por label
            # (as colunas são filtradas na criação do dataset).
            yield real
        else:
            yield real

In [None]:
def load_lastqgan_components(n_rots: int | None = None):
    """Build an untrained LaSt-QGAN from the project's YAML configs.

    The discriminator and autoencoder architectures, the generator's
    hyperparameters and the optimizer settings come from ``GAN_CONFIG_PATH``
    and ``AUTOENCODER_CONFIG_PATH``. The autoencoder, quantum generator and
    GAN module are each converted with ``.double()``.

    Args:
        n_rots: Number of rotations of the quantum generator. Defaults to
            ``LATENT_DIM``, the value this notebook also uses as the noise
            size when sampling. Generator and sampler therefore cannot drift
            apart when the constant is changed.

    Returns:
        ``(gan_module, autoencoder)``: the ``GANModule`` wrapping autoencoder,
        generator and discriminator, plus the ``AutoencoderModule`` itself.
    """
    if n_rots is None:
        n_rots = LATENT_DIM

    gan_cfg = parse_config(GAN_CONFIG_PATH)
    auto_cfg = parse_config(AUTOENCODER_CONFIG_PATH)

    discriminator = build_model_from_config(gan_cfg['discriminator'])
    autoencoder_model = build_model_from_config(auto_cfg['autoencoder'])

    autoencoder = AutoencoderModule(autoencoder=autoencoder_model, optimizer=auto_cfg['optimizers']).double()
    generator = QuantumGenerator(
        n_qubits=gan_cfg['n_qubits'],
        n_rots=n_rots,
        n_circuits=gan_cfg['n_circuits'],
        dropout=gan_cfg['generator_dropout'],
    ).double()

    gan_module = GANModule(
        alpha=gan_cfg['alpha'],
        n_qubits=gan_cfg['n_qubits'],
        n_rots=n_rots,  # must match the generator's n_rots above
        autoencoder=autoencoder,
        generator=generator,
        discriminator=discriminator,
        optimizers_config=gan_cfg['optimizers'],
        step_disc_every_n_steps=gan_cfg['step_disc_every_n_steps'],
    ).double()

    return gan_module, autoencoder

In [None]:
def train_lastqgan(seed: int):
    """Train a fresh LaSt-QGAN on the MNIST CSV and return the fitted module.

    This is a small in-notebook stand-in for the project's CLI. It calls
    ``seed_everything(seed)``, builds the model with
    ``load_lastqgan_components`` and fits it with a Lightning ``Trainer`` on
    a single device for ``NUM_EPOCHS`` epochs. Logging and checkpointing are
    disabled.

    Args:
        seed: Seed passed to ``seed_everything`` before the model is built.

    Returns:
        The trained GAN module.
    """
    import lightning

    seed_everything(seed)
    gan_module = load_lastqgan_components()[0]

    train_loader = DataLoader(
        DigitsDataset(path_to_csv=MNIST_CSV, label=LABELS),
        batch_size=BATCH_SIZE,
        shuffle=True,
        drop_last=True,
        pin_memory=True,
    )
    trainer_kwargs = dict(
        accelerator='cuda' if torch.cuda.is_available() else 'cpu',
        devices=1,
        max_epochs=NUM_EPOCHS,
        enable_progress_bar=True,
        logger=False,
        enable_checkpointing=False,
    )
    trainer = lightning.Trainer(**trainer_kwargs)
    trainer.fit(model=gan_module, train_dataloaders=train_loader)
    return gan_module


def sample_from_lastqgan(generator, *, batch_size: int, latent_dim: int, device, label_target=None):
    """Draw ``batch_size`` decoded samples from a trained LaSt-QGAN.

    Standard-normal float64 noise of shape ``(batch_size, latent_dim)`` is fed
    to ``generator.generator``, which is switched to eval mode and stays there.
    The output is decoded with ``generator.autoencoder.decode`` under
    ``torch.no_grad()``. Sampling is unconditional: ``label_target`` is
    accepted but not used.

    Args:
        generator: Object exposing ``.generator`` (a module) and
            ``.autoencoder.decode``, e.g. the module returned by ``train_lastqgan``.
        batch_size: Number of samples to draw.
        latent_dim: Length of each noise vector.
        device: Device on which the noise is created.
        label_target: Ignored.

    Returns:
        The decoded samples converted to float32.
    """
    noise_shape = (batch_size, latent_dim)
    z = torch.randn(*noise_shape, device=device, dtype=torch.double)
    quantum_generator = generator.generator
    with torch.no_grad():
        quantum_generator.eval()
        decoded = generator.autoencoder.decode(quantum_generator(z))
    return decoded.float()


def evaluate_generator(generator_module, *, label_target: int | None, latent_dim: int, device, dataset):
    """Compute FID and Inception Score of a LaSt-QGAN against ``dataset``.

    Each batch of real images is paired with an equally sized batch of
    generated images, so both FID distributions see the same number of
    samples. Both sides go through ``preprocess_for_inception``. FID uses
    64-dimensional Inception features; both metrics are built with
    ``normalize=True``. The Inception Score is computed on the generated
    images only.

    Args:
        generator_module: Trained LaSt-QGAN; moved to ``device`` and put in eval mode.
        label_target: Forwarded to ``prepare_real_batches`` and
            ``sample_from_lastqgan``; neither uses it to filter or condition.
            For a per-label score, pass a ``dataset`` already restricted to
            that label.
        latent_dim: Noise size used when sampling.
        device: Device for the model and the metrics.
        dataset: Real images to compare against.

    Returns:
        ``(fid, is_mean, is_std)`` as Python floats.
    """
    fid_metric = FrechetInceptionDistance(feature=64, normalize=True).to(device)
    inception_score = InceptionScore(normalize=True).to(device)
    model = generator_module.to(device)
    model.eval()

    with torch.no_grad():
        for batch in prepare_real_batches(dataset, label_target=label_target):
            # Assumes dataset batches are (N, H, W); unsqueeze adds the channel axis.
            real_imgs = preprocess_for_inception(batch.to(device).unsqueeze(1))
            generated = sample_from_lastqgan(
                model,
                batch_size=real_imgs.size(0),
                latent_dim=latent_dim,
                device=device,
                label_target=label_target,
            )
            fake_imgs = preprocess_for_inception(generated)

            fid_metric.update(real_imgs, real=True)
            fid_metric.update(fake_imgs, real=False)
            inception_score.update(fake_imgs)

    fid_value = float(fid_metric.compute())
    is_mean, is_std = (float(value) for value in inception_score.compute())
    return fid_value, is_mean, is_std

In [None]:
results = []
# Full dataset with every digit in LABELS; kept available to later cells.
dataset = DigitsDataset(path_to_csv=MNIST_CSV, label=LABELS)

# DigitsDataset returns only images and the generator is unconditional, so the
# only thing that can differ between labels is the real reference set.
# Evaluating every label against the full ``dataset`` (with the same noise
# seed) would yield the same FID/IS for all ten labels. It would also repeat a
# full-dataset Inception pass per label. Build one dataset per digit, once,
# outside the training loop.
label_datasets = {
    label: DigitsDataset(path_to_csv=MNIST_CSV, label=[label]) for label in LABELS
}

for trial in range(NUM_TRAINING_RUNS):
    seed = BASE_SEED + trial
    print(f'===== Rodada {trial + 1}/{NUM_TRAINING_RUNS} - LaSt-QGAN =====')
    qgan_model = train_lastqgan(seed)

    for label in LABELS:
        for repeat in range(NUM_EVAL_REPEATS):
            # The noise seed depends on (trial, repeat) only, so every label of a
            # given repeat is scored with the same generated samples and the
            # labels differ only by their reference images.
            set_global_seed(seed * 1000 + repeat)
            fid, is_mean, is_std = evaluate_generator(
                qgan_model,
                label_target=label,
                latent_dim=LATENT_DIM,
                device=device,
                dataset=label_datasets[label],
            )
            results.append(
                {
                    'Model': 'LaSt-QGAN',
                    'Label': label,
                    'Trial': trial,
                    'Repeat': repeat,
                    'FID': fid,
                    'IS_Mean': is_mean,
                    'IS_Std': is_std,
                }
            )
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

In [None]:
# One row per (trial, label, repeat) evaluation; the bare name renders the table.
results_df = pd.DataFrame(data=results)
results_df

In [None]:
# Named aggregations shared by both summary tables: mean and standard
# deviation, across the grouped rows, of FID, of the IS mean and of the IS std.
summary_aggregations = {
    'FID_mean': ('FID', 'mean'),
    'FID_std': ('FID', 'std'),
    'IS_mean_mean': ('IS_Mean', 'mean'),
    'IS_mean_std': ('IS_Mean', 'std'),
    'IS_std_mean': ('IS_Std', 'mean'),
    'IS_std_std': ('IS_Std', 'std'),
}

summary_by_model_label = (
    results_df.groupby(['Model', 'Label']).agg(**summary_aggregations).reset_index()
)
summary_by_model = (
    results_df.groupby(['Model']).agg(**summary_aggregations).reset_index()
)

display(summary_by_model_label)
display(summary_by_model)