# FID e Inception Score da LaSt-QGAN

Este notebook replica a metodologia de avaliação usada em `gans_classical_fid_is.ipynb`, mas agora aplicada à arquitetura LaSt-QGAN localizada na pasta `LaSt-QGAN-main`. Treinamos e avaliamos a LaSt-QGAN no MNIST, repetindo múltiplas execuções com sementes controladas e calculando FID e Inception Score em cada rodada.

In [3]:
import sys
!{sys.executable} -m pip install lightning


Collecting lightning
  Downloading lightning-2.5.6-py3-none-any.whl.metadata (42 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.4/42.4 kB[0m [31m435.5 kB/s[0m eta [36m0:00:00[0m [36m0:00:01[0m
Downloading lightning-2.5.6-py3-none-any.whl (827 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m827.9/827.9 kB[0m [31m5.4 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hInstalling collected packages: lightning
Successfully installed lightning-2.5.6


In [2]:
import os
import sys
import random
import json

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore

# Garantir que possamos importar os módulos da LaSt-QGAN
PROJECT_ROOT = os.path.abspath('LaSt-QGAN-main')
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

from utils import (
    DigitsDataset,
    GANModule,
    AutoencoderModule,
    build_model_from_config,
    parse_config,
    seed_everything,
)
#from train_qgan import QuantumGenerator

In [3]:
# Hyperparameters and paths
MNIST_CSV = os.path.join(PROJECT_ROOT, 'mnist.csv')
GAN_CONFIG_PATH = os.path.join(PROJECT_ROOT, 'gan.yaml')
AUTOENCODER_CONFIG_PATH = os.path.join(PROJECT_ROOT, 'autoencoder.yaml')

BATCH_SIZE = 128
# Length of the noise vector fed to the quantum generator. Each generator layer starts
# with nn.Linear(in_features=n_qubits, ...), so this must equal ``n_qubits`` from
# gan.yaml (it is not the number of rotations). Read it from the config so the two
# cannot drift apart.
LATENT_DIM = int(parse_config(GAN_CONFIG_PATH)['n_qubits'])
NUM_EPOCHS = 50
NUM_TRAINING_RUNS = 3
NUM_EVAL_REPEATS = 3  # evaluations (fresh noise) per trained model and per digit
BASE_SEED = 2024
LABELS = list(range(10))

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Usando dispositivo:', device)

Usando dispositivo: cuda


In [15]:
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks import ModelSummary
from lightning.pytorch.loggers import WandbLogger
from torch.utils.data import DataLoader
from argparse import ArgumentParser
import torch.optim as optim
import pennylane as qml
from tqdm import tqdm
import torch.nn as nn
from utils import *
import numpy as np
import warnings
import torch
import wandb
import yaml


class Config:
    """Attribute-style view of a config file parsed with ``parse_config``.

    Every top-level key of the parsed config becomes an attribute
    (e.g. ``cfg.batch_size``).
    """

    def __init__(self, config_path):
        config = parse_config(config_path)
        for key, value in config.items():
            setattr(self, key, value)

    @staticmethod
    def _device_index(device) -> int:
        """Return the integer index of a device spec such as ``"cuda:0"`` or ``"cuda:12"``.

        Raises:
            ValueError: if the part after the last ``:`` is not an integer.
        """
        # Take everything after the last ':' so multi-digit indices ("cuda:10") are
        # not truncated to their final digit.
        return int(str(device).rsplit(":", 1)[-1])

    def generate_fields(self):
        """Replace the discriminator spec with a built model and derive the Lightning device list."""
        self.discriminator = build_model_from_config(self.discriminator)
        self.l_device = [self._device_index(self.device)]


class QuantumGenerator(nn.Module):
    """Quantum generator that maps noise vectors to qubit expectation values.

    Independent of argparse/config: the PennyLane device is created inside the class.
    Each of the ``n_circuits`` layers owns a ``Linear(n_qubits -> n_qubits * n_rots)``
    followed by dropout. Its output is reshaped to ``(n_rots, n_qubits)`` rotation
    angles, and the per-layer angles are stacked and fed to one QNode. The QNode
    returns ``<X>`` for every qubit followed by ``<Z>`` for every qubit.
    """

    def __init__(self, n_qubits, n_rots, n_circuits, dropout):
        super().__init__()

        self.n_qubits = n_qubits
        self.n_rots = n_rots

        # PennyLane simulator bound to this generator.
        simulator = qml.device("default.qubit", wires=n_qubits)

        @qml.qnode(simulator, interface="torch", diff_method="backprop")
        def quantum_circuit(weights):
            # weights is indexed as [layer][row, wire]; rows 0-5 are read per layer.
            layer_count = weights.size(0)
            wire_count = weights.size(-1)

            for layer in range(layer_count):
                angles = weights[layer]
                # Single-qubit RY-RZ-RY-RZ rotations on every wire.
                for wire in range(wire_count):
                    qml.RY(angles[0, wire], wires=wire)
                    qml.RZ(angles[1, wire], wires=wire)
                    qml.RY(angles[2, wire], wires=wire)
                    qml.RZ(angles[3, wire], wires=wire)
                # Ring of controlled rotations, each wire controlling its neighbour.
                for wire in range(wire_count):
                    ring = [wire, (wire + 1) % wire_count]
                    qml.CRY(angles[4, wire], wires=ring)
                    qml.CRZ(angles[5, wire], wires=ring)

            x_terms = [qml.expval(qml.PauliX(wire)) for wire in range(wire_count)]
            z_terms = [qml.expval(qml.PauliZ(wire)) for wire in range(wire_count)]
            return x_terms + z_terms

        self.quantum_circuit = quantum_circuit

        def make_projection():
            # Noise -> flattened rotation angles for one circuit layer.
            return nn.Sequential(
                nn.Linear(
                    in_features=n_qubits,
                    out_features=n_qubits * n_rots,
                    bias=True,
                ),
                nn.Dropout(p=dropout),
            )

        self.rot_params = nn.ModuleList([make_projection() for _ in range(n_circuits)])
        self.init_weights()

    def init_weights(self):
        """Re-initialise every Linear in ``rot_params`` with U(-0.01, 0.01) weights and biases."""
        projections = (
            module
            for block in self.rot_params
            for module in block
            if isinstance(module, nn.Linear)
        )
        for projection in projections:
            nn.init.uniform_(projection.weight, -0.01, 0.01)
            nn.init.uniform_(projection.bias, -0.01, 0.01)

    def partial_measure(self, noise):
        """Run the circuit for one noise vector (length ``n_qubits``); return a 1-D tensor of expectations."""
        per_layer = []
        for projection in self.rot_params:
            flat_angles = projection(noise.unsqueeze(0))
            per_layer.append(flat_angles.reshape(self.n_rots, self.n_qubits))
        expectations = self.quantum_circuit(torch.stack(per_layer))
        return torch.stack(expectations)

    def forward(self, x):
        """Apply ``partial_measure`` to each row of ``x``; stack the results on the parameters' device."""
        target_device = next(self.parameters()).device
        measured = torch.stack([self.partial_measure(sample) for sample in x])
        return measured.to(target_device)


def run_training(config_path: str, autoencoder_config_path: str):
    """Run the complete LaSt-QGAN training pipeline described by two config files.

    Can be called from a CLI ``main()`` or from another script/notebook.

    Args:
        config_path: GAN config read through ``Config``. Fields used here:
            ``random_state``, ``discriminator``, ``device``, ``path_to_mnist``,
            ``batch_size``, ``n_qubits``, ``n_circuits``, ``generator_dropout``,
            ``path_to_autoencoder``, ``alpha``, ``optimizers``,
            ``step_disc_every_n_steps``, ``run_name``, ``epochs``,
            ``log_every_n_steps`` and ``debug``.
        autoencoder_config_path: autoencoder config (``autoencoder`` and
            ``optimizers`` fields).

    Returns:
        The trained ``GANModule``. Previously nothing was returned, so existing
        callers are unaffected.
    """
    # Explicit import instead of relying on ``l`` leaking out of ``from utils import *``.
    import lightning as l

    warnings.filterwarnings("ignore", category=UserWarning)
    torch.set_float32_matmul_precision("high")

    config = Config(config_path)
    # Seed before building any model, so the discriminator's and autoencoder's
    # initial weights are governed by ``random_state`` (they used to be built first).
    seed_everything(config.random_state)
    config.generate_fields()
    autoencoder_config = Config(autoencoder_config_path)
    base_autoencoder = build_model_from_config(autoencoder_config.autoencoder)

    dataset = DigitsDataset(path_to_csv=config.path_to_mnist, label=range(10))
    dataloader = DataLoader(
        dataset=dataset,
        batch_size=config.batch_size,
        shuffle=True,
        pin_memory=True,
        drop_last=True,
    )

    # Fixed at 6: the quantum circuit reads exactly six rotation rows per layer.
    config.n_rots = 6

    generator = QuantumGenerator(
        n_qubits=config.n_qubits,
        n_rots=config.n_rots,
        n_circuits=config.n_circuits,
        dropout=config.generator_dropout,
    ).double()

    # Pretrained autoencoder: the GAN is trained in its latent space.
    autoencoder = AutoencoderModule.load_from_checkpoint(
        checkpoint_path=config.path_to_autoencoder,
        autoencoder=base_autoencoder,
        optimizer=autoencoder_config.optimizers,
    ).double()

    # Everything set on the config (including derived fields) is logged to W&B.
    pushed_config = {
        k: v for k, v in dict(vars(config)).items() if not k.startswith("__")
    }

    wandb_logger = WandbLogger(
        project="QGAN", name=config.run_name, config=pushed_config
    )

    # every_n_epochs=0 disables periodic checkpoints; only "last" is kept.
    checkpoint_callback = ModelCheckpoint(
        dirpath="./weights",
        filename="qgan-{epoch}",
        verbose=False,
        every_n_epochs=0,
        save_last=True,
    )

    gan = GANModule(
        alpha=config.alpha,
        n_qubits=config.n_qubits,
        n_rots=config.n_rots,
        autoencoder=autoencoder,
        generator=generator,
        discriminator=config.discriminator,
        optimizers_config=config.optimizers,
        step_disc_every_n_steps=config.step_disc_every_n_steps,
    ).double()

    trainer = l.Trainer(
        accelerator="cuda",
        devices=config.l_device,
        max_epochs=config.epochs,
        enable_progress_bar=True,
        log_every_n_steps=config.log_every_n_steps,
        logger=wandb_logger,
        num_sanity_val_steps=0,
        fast_dev_run=config.debug,
        callbacks=[checkpoint_callback],
    )

    try:
        trainer.fit(model=gan, train_dataloaders=dataloader)
    finally:
        # Close the W&B run even if training raises, so the next call starts cleanly.
        wandb.finish()
    return gan

In [16]:

def set_global_seed(seed: int) -> None:
    """Seed Python's ``random``, NumPy's global RNG and PyTorch (CPU, plus all GPUs if present)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


def denormalize(imgs: torch.Tensor) -> torch.Tensor:
    """Map values from [-1, 1] to [0, 1] (computes ``imgs * 0.5 + 0.5``)."""
    return imgs.mul(0.5).add(0.5)


def preprocess_for_inception(imgs: torch.Tensor, size: int = 299) -> torch.Tensor:
    """Turn a batch of images into RGB tensors in [0, 1] at Inception resolution.

    Steps:
    1. ``denormalize`` (x * 0.5 + 0.5) and clamp to [0, 1].
    2. Add a channel axis if the batch is ``(B, H, W)``.
    3. Replicate single-channel images to 3 channels.
    4. Bilinearly resize to ``size x size``.

    NOTE(review): step 1 assumes inputs lie in [-1, 1]; confirm for both the dataset
    and the decoder output.

    Args:
        imgs: ``(B, H, W)`` or ``(B, C, H, W)`` image batch.
        size: output side length; 299 (default) is Inception-v3's input size.

    Returns:
        ``(B, C', size, size)`` tensor, where ``C' = 3`` for grayscale input.
    """
    imgs = denormalize(imgs).clamp(0, 1)
    if imgs.dim() == 3:
        # (B, H, W) -> (B, 1, H, W); a 3-D tensor would otherwise reach F.interpolate,
        # which rejects a 2-element ``size`` for 1-D spatial input.
        imgs = imgs.unsqueeze(1)
    if imgs.size(1) == 1:
        imgs = imgs.repeat(1, 3, 1, 1)
    return F.interpolate(imgs, size=(size, size), mode='bilinear', align_corners=False)


def prepare_real_batches(dataset, *, label_target: int | None, batch_size: int = 64):
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    for real in loader:
        if label_target is not None:
            # DigitsDataset retorna apenas as imagens; filtramos pelas linhas do csv que já são separadas por label
            # (as colunas são filtradas na criação do dataset).
            yield real
        else:
            yield real

In [17]:
def load_lastqgan_components(pretrained_autoencoder: bool = True):
    """Build an untrained LaSt-QGAN ``GANModule`` from ``gan.yaml`` and ``autoencoder.yaml``.

    Mirrors ``run_training``:
    - the discriminator comes from the GAN config;
    - the autoencoder is restored from the checkpoint at ``path_to_autoencoder`` in
      the GAN config, so generated latents are decoded by trained weights rather than
      a randomly initialised decoder.

    Args:
        pretrained_autoencoder: when ``True`` (default) and the GAN config defines
            ``path_to_autoencoder``, load the autoencoder from that checkpoint.
            Otherwise build a freshly initialised autoencoder.

    Returns:
        ``(gan_module, autoencoder)``, both converted to float64 via ``.double()``.

    Raises:
        Whatever ``AutoencoderModule.load_from_checkpoint`` raises when the checkpoint
        cannot be read (e.g. a missing file).
    """
    gan_cfg = parse_config(GAN_CONFIG_PATH)
    auto_cfg = parse_config(AUTOENCODER_CONFIG_PATH)

    # Rotation rows per circuit layer: the quantum circuit reads exactly rows 0-5.
    n_rots = 6

    discriminator = build_model_from_config(gan_cfg['discriminator'])
    autoencoder_model = build_model_from_config(auto_cfg['autoencoder'])

    checkpoint_path = gan_cfg.get('path_to_autoencoder') if pretrained_autoencoder else None
    if checkpoint_path:
        # The path in gan.yaml is presumably relative to the repository root (where the
        # CLI runs), not to this notebook's working directory; fall back to PROJECT_ROOT.
        if not os.path.isabs(checkpoint_path) and not os.path.exists(checkpoint_path):
            checkpoint_path = os.path.join(PROJECT_ROOT, checkpoint_path)
        autoencoder = AutoencoderModule.load_from_checkpoint(
            checkpoint_path=checkpoint_path,
            autoencoder=autoencoder_model,
            optimizer=auto_cfg['optimizers'],
        )
    else:
        if pretrained_autoencoder:
            warnings.warn(
                "gan config has no 'path_to_autoencoder'; using a randomly initialised autoencoder"
            )
        autoencoder = AutoencoderModule(autoencoder=autoencoder_model, optimizer=auto_cfg['optimizers'])
    autoencoder = autoencoder.double()

    generator = QuantumGenerator(
        n_qubits=gan_cfg['n_qubits'],
        n_rots=n_rots,
        n_circuits=gan_cfg['n_circuits'],
        dropout=gan_cfg['generator_dropout'],
    ).double()

    gan_module = GANModule(
        alpha=gan_cfg['alpha'],
        n_qubits=gan_cfg['n_qubits'],
        n_rots=n_rots,
        autoencoder=autoencoder,
        generator=generator,
        discriminator=discriminator,
        optimizers_config=gan_cfg['optimizers'],
        step_disc_every_n_steps=gan_cfg['step_disc_every_n_steps'],
    ).double()

    return gan_module, autoencoder

In [18]:
def train_lastqgan(seed: int, dataset=None, max_epochs: int | None = None):
    """Train a freshly built LaSt-QGAN in-process (no CLI) and return it.

    Args:
        seed: passed to ``seed_everything`` (from utils) before the model is built.
        dataset: optional pre-built training dataset. When ``None`` (default), a
            ``DigitsDataset`` over ``MNIST_CSV`` restricted to ``LABELS`` is created.
            Passing one lets repeated runs skip re-reading the CSV.
        max_epochs: number of training epochs; defaults to ``NUM_EPOCHS``.

    Returns:
        The trained ``GANModule``.
    """
    import lightning as l

    seed_everything(seed)
    torch.set_float32_matmul_precision("high")
    gan_module, _ = load_lastqgan_components()

    if dataset is None:
        dataset = DigitsDataset(path_to_csv=MNIST_CSV, label=LABELS)
    # Leave one core for the notebook process, capped at 8. There is always at least
    # one worker, so persistent workers and prefetching are always valid.
    num_workers = max(1, min(8, (os.cpu_count() or 1) - 1))
    loader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        drop_last=True,
        pin_memory=True,
        num_workers=num_workers,
        persistent_workers=True,
        prefetch_factor=2,
    )

    trainer = l.Trainer(
        accelerator='cuda' if torch.cuda.is_available() else 'cpu',
        devices=1,
        max_epochs=NUM_EPOCHS if max_epochs is None else max_epochs,
        enable_progress_bar=True,
        logger=False,
        enable_checkpointing=False,
    )

    trainer.fit(model=gan_module, train_dataloaders=loader)
    return gan_module

def sample_from_lastqgan(generator, *, batch_size: int, latent_dim: int, device, label_target=None):
    """Draw ``batch_size`` images from a trained LaSt-QGAN module.

    1. Draw standard-normal float64 noise of shape ``(batch_size, latent_dim)``.
    2. Switch ``generator.generator`` to eval mode and map the noise to latent codes.
    3. Decode the codes with ``generator.autoencoder.decode``.

    Runs under ``torch.no_grad()``. ``label_target`` is not used: the sampling path
    takes no label.

    Returns:
        The decoded images cast to float32.
    """
    latent_noise = torch.randn(batch_size, latent_dim, device=device, dtype=torch.double)
    with torch.no_grad():
        quantum_net = generator.generator
        quantum_net.eval()
        codes = quantum_net(latent_noise)
        images = generator.autoencoder.decode(codes)
    return images.float()


def evaluate_generator(generator_module, *, label_target: int | None, latent_dim: int, device, dataset):
    """Compute FID (64-d Inception features) and Inception Score for a LaSt-QGAN module.

    For every real batch from ``prepare_real_batches(dataset, ...)``, an equally sized
    fake batch is sampled. Both batches go through ``preprocess_for_inception``, both
    feed FID, and only the fakes feed IS. Metric objects are created fresh per call.

    NOTE(review): real batches get a channel axis via ``unsqueeze(1)``, so the dataset
    is assumed to yield ``(B, H, W)``. Both real and fake images are also assumed to lie
    in [-1, 1] (see ``denormalize``); confirm this, otherwise FID is biased.

    Args:
        generator_module: GAN module exposing ``.generator`` and ``.autoencoder``;
            moved to ``device`` and left in eval mode.
        label_target: forwarded to the batch/sampling helpers (neither filters on it).
        latent_dim: noise dimension for the quantum generator.
        device: device for the metrics and the model.
        dataset: source of real images.

    Returns:
        ``(fid, is_mean, is_std)`` as Python floats.
    """
    fid_metric = FrechetInceptionDistance(feature=64, normalize=True).to(device)
    inception_metric = InceptionScore(normalize=True).to(device)
    model = generator_module.to(device)
    model.eval()

    with torch.no_grad():
        for batch in prepare_real_batches(dataset, label_target=label_target):
            real_imgs = preprocess_for_inception(batch.to(device).unsqueeze(1))
            fake_imgs = preprocess_for_inception(
                sample_from_lastqgan(
                    model,
                    batch_size=real_imgs.size(0),
                    latent_dim=latent_dim,
                    device=device,
                    label_target=label_target,
                )
            )

            fid_metric.update(real_imgs, real=True)
            fid_metric.update(fake_imgs, real=False)
            inception_metric.update(fake_imgs)

    fid_value = float(fid_metric.compute())
    is_mean, is_std = (float(value) for value in inception_metric.compute())
    return fid_value, is_mean, is_std

In [None]:
results = []
# One real-image pool per digit, so the 'Label' column reflects a digit-specific
# comparison (prepare_real_batches does no filtering itself).
# Assumes DigitsDataset(label=[d]) keeps only rows of digit d; TODO confirm against
# utils.DigitsDataset.
real_datasets = {
    label: DigitsDataset(path_to_csv=MNIST_CSV, label=[label]) for label in LABELS
}

for trial in range(NUM_TRAINING_RUNS):
    seed = BASE_SEED + trial
    print(f'===== Rodada {trial + 1}/{NUM_TRAINING_RUNS} - LaSt-QGAN =====')
    qgan_model = train_lastqgan(seed)

    for label in LABELS:
        for repeat in range(NUM_EVAL_REPEATS):
            # Fold the label into the seed so different digits do not reuse the same
            # fake samples. Stays unique per trial while
            # len(LABELS) * NUM_EVAL_REPEATS < 1000.
            set_global_seed(seed * 1000 + label * NUM_EVAL_REPEATS + repeat)
            fid, is_mean, is_std = evaluate_generator(
                qgan_model,
                label_target=label,
                latent_dim=LATENT_DIM,
                device=device,
                dataset=real_datasets[label],
            )
            results.append(
                {
                    'Model': 'LaSt-QGAN',
                    'Label': label,
                    'Trial': trial,
                    'Repeat': repeat,
                    'FID': fid,
                    'IS_Mean': is_mean,
                    'IS_Std': is_std,
                }
            )
    # Drop the last reference first; empty_cache cannot release memory still owned by a live model.
    del qgan_model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

===== Rodada 1/3 - LaSt-QGAN =====


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type              | Params | Mode 
------------------------------------------------------------
0 | autoencoder   | AutoencoderModule | 13.2 M | train
1 | generator     | QuantumGenerator  | 4.0 K  | train
2 | discriminator | Sequential        | 8.6 K  | train
3 | criterion     | WassersteinLoss   | 0      | train
4 | penalty_loss  | PenaltyLoss       | 0      | train
------------------------------------------------------------
13.2 M    Trainable params
0         Non-trainable params
13.2 M    Total params
52.704    Total estimated model params size (MB)
59        Modules in train mode
0         Modules in eval mode
/home/mahlow/anaconda3/envs/my_env/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value o

Training: |                                               | 0/? [00:00<?, ?it/s]

In [None]:
results_df = pd.DataFrame(results)
# Every (trial, label, repeat) combination must have produced exactly one row;
# catches an interrupted evaluation loop before the summaries below are computed.
expected_rows = NUM_TRAINING_RUNS * len(LABELS) * NUM_EVAL_REPEATS
assert len(results_df) == expected_rows, f'expected {expected_rows} rows, got {len(results_df)}'
results_df

In [None]:
# Mean and standard deviation (across trials and evaluation repeats) of the FID, and
# of each evaluation's Inception Score mean and std. Defined once, used for both tables.
METRIC_AGGREGATIONS = {
    'FID_mean': ('FID', 'mean'),
    'FID_std': ('FID', 'std'),
    'IS_mean_mean': ('IS_Mean', 'mean'),
    'IS_mean_std': ('IS_Mean', 'std'),
    'IS_std_mean': ('IS_Std', 'mean'),
    'IS_std_std': ('IS_Std', 'std'),
}


def summarize_results(frame: pd.DataFrame, group_cols: list[str]) -> pd.DataFrame:
    """Aggregate ``frame`` by ``group_cols`` using ``METRIC_AGGREGATIONS``; one row per group."""
    return frame.groupby(group_cols).agg(**METRIC_AGGREGATIONS).reset_index()


summary_by_model_label = summarize_results(results_df, ['Model', 'Label'])
summary_by_model = summarize_results(results_df, ['Model'])

display(summary_by_model_label)
display(summary_by_model)