# Recursos computacionais do LaSt-QGAN

Este notebook replica a metodologia usada em `gans_classical_resources.ipynb`, agora aplicada ao modelo LaSt-QGAN.
O objetivo é medir o custo de treinamento e de inferência, bem como o número de parâmetros do gerador,
executando um ciclo de treinamento enxuto (poucas épocas) para cada classe do MNIST.


In [None]:
import time
from statistics import mean

import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import pennylane as qml

import sys
from pathlib import Path

sys.path.append(str(Path('LaSt-QGAN-main').resolve()))
from utils import build_model_from_config, parse_config, seed_everything, DigitsDataset
from models import WassersteinLoss, PenaltyLoss


In [None]:
# Paths (relative to the notebook) of the GAN and autoencoder YAML configs.
GAN_CONFIG_PATH = 'LaSt-QGAN-main/gan.yaml'
AUTOENCODER_CONFIG_PATH = 'LaSt-QGAN-main/autoencoder.yaml'

gan_cfg, ae_cfg = (parse_config(path) for path in (GAN_CONFIG_PATH, AUTOENCODER_CONFIG_PATH))

# Benchmark hyper-parameters: fixed values first, then values read from the GAN
# config (batch size and generator dropout fall back to defaults when absent).
NUM_EPOCHS = 5
N_ROTS = 6  # the QuantumGenerator circuit indexes angle rows 0-5 for every qubit
BATCH_SIZE = gan_cfg.get('batch_size', 128)
LATENT_DIM, N_CIRCUITS = gan_cfg['n_qubits'], gan_cfg['n_circuits']
GEN_DROPOUT = gan_cfg.get('generator_dropout', 0.0)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
seed_everything(gan_cfg.get('random_state', 0))

# NOTE(review): full_dataset (all ten digits) is not referenced by any visible cell;
# training reloads one class at a time inside train_lastqgan_for_label.
full_dataset = DigitsDataset(path_to_csv=gan_cfg['path_to_mnist'], label=range(10))
label_ids = [*range(10)]


In [None]:
def count_parameters(model):
    """Return the total number of scalar elements over every parameter of ``model``.

    Trainable and frozen parameters are both counted.
    """
    total = 0
    for tensor in model.parameters():
        total += tensor.numel()
    return total


def build_autoencoder_modules():
    """Build the autoencoder from ``ae_cfg`` and return its ``(encoder, decoder)`` halves.

    The model is cast to float64 and moved to ``device``; the encoder is element 0
    and the decoder element 1 of the built container.

    NOTE(review): the weights come straight from ``build_model_from_config`` --
    nothing here loads a trained checkpoint, so confirm that is intended for
    this resource benchmark.
    """
    model = build_model_from_config(ae_cfg["autoencoder"])
    model = model.double().to(device)
    return model[0], model[1]


def make_discriminator():
    """Build the discriminator described by ``gan_cfg`` as a float64 module on ``device``."""
    return build_model_from_config(gan_cfg["discriminator"]).double().to(device)


def make_generator():
    """Instantiate a ``QuantumGenerator`` from the notebook constants, as float64 on ``device``."""
    generator = QuantumGenerator(LATENT_DIM, N_ROTS, N_CIRCUITS, GEN_DROPOUT)
    return generator.double().to(device)


class QuantumGenerator(nn.Module):
    """Hybrid classical/quantum generator producing latent vectors for the decoder.

    For each of ``n_circuits`` blocks, a ``Linear`` + ``Dropout`` layer maps a noise
    vector of length ``n_qubits`` to ``n_qubits * n_rots`` rotation angles. The
    angles of all blocks parameterize one PennyLane circuit simulated on
    ``default.qubit`` and differentiated by backprop. Each sample's output is the
    ``PauliX`` expectation of every wire followed by the ``PauliZ`` expectation of
    every wire (``2 * n_qubits`` values).

    Note: the circuit reads angle rows 0-5 of each block, so ``n_rots`` must be
    at least 6; any extra rows are produced but never used.
    """

    def __init__(self, n_qubits, n_rots, n_circuits, dropout):
        """Create the angle-producing layers and the quantum node.

        Args:
            n_qubits: number of wires; also the expected noise length.
            n_rots: angle rows produced per qubit for each circuit block.
            n_circuits: number of repeated rotation + entangling blocks.
            dropout: dropout probability applied to the produced angles.
        """
        super().__init__()
        self.n_qubits = n_qubits
        self.n_rots = n_rots
        self.n_circuits = n_circuits
        # One independent angle generator per circuit block.
        self.rot_params = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(in_features=n_qubits, out_features=n_qubits * n_rots, bias=True),
                    nn.Dropout(p=dropout),
                )
                for _ in range(n_circuits)
            ]
        )
        self.q_device = qml.device("default.qubit", wires=n_qubits)

        @qml.qnode(self.q_device, interface="torch", diff_method="backprop")
        def quantum_circuit(weights):
            # weights is indexed as weights[block][row, qubit]; block count and
            # wire count are taken from the tensor itself.
            n_circuits_local = weights.size(0)
            n_qubits_local = weights.size(-1)
            for i in range(n_circuits_local):
                # Single-qubit layer: RY-RZ-RY-RZ on every wire (rows 0-3).
                for q in range(n_qubits_local):
                    qml.RY(weights[i][0, q], wires=q)
                    qml.RZ(weights[i][1, q], wires=q)
                    qml.RY(weights[i][2, q], wires=q)
                    qml.RZ(weights[i][3, q], wires=q)
                # Entangling layer: controlled rotations on a ring q -> q+1 (mod n),
                # angles taken from rows 4 and 5.
                for q in range(n_qubits_local):
                    qml.CRY(weights[i][4, q], wires=[q, (q + 1) % n_qubits_local])
                    qml.CRZ(weights[i][5, q], wires=[q, (q + 1) % n_qubits_local])
            # All <X_q> first, then all <Z_q>.
            return [qml.expval(qml.PauliX(q)) for q in range(n_qubits_local)] + [
                qml.expval(qml.PauliZ(q)) for q in range(n_qubits_local)
            ]

        self.quantum_circuit = quantum_circuit
        self.init_weights()

    def init_weights(self):
        """Initialise every Linear weight and bias uniformly in [-0.01, 0.01] (small initial angles)."""
        for layer in self.rot_params:
            for module in layer:
                if isinstance(module, nn.Linear):
                    nn.init.uniform_(module.weight, -0.01, 0.01)
                    nn.init.uniform_(module.bias, -0.01, 0.01)

    def partial_measure(self, noise):
        """Run the circuit for a single noise vector and return its stacked expectation values.

        ``noise`` must have a last dimension of ``n_qubits`` (the Linear input size);
        presumably a 1-D vector, since it is unsqueezed to a batch of one.
        The angles are stacked to shape ``(n_circuits, n_rots, n_qubits)``.
        """
        rotations = torch.stack(
            [linear(noise.unsqueeze(0)).reshape(self.n_rots, self.n_qubits) for linear in self.rot_params]
        )
        exps = self.quantum_circuit(rotations)
        # The qnode returns one value per measurement; stack into one tensor.
        exps = torch.stack(exps)
        return exps

    def forward(self, x):
        """Generate one latent vector per row of ``x``.

        Samples are simulated one at a time (no batched circuit execution), and
        the stacked result is moved to ``x``'s device.
        """
        hidden_states = [self.partial_measure(elem) for elem in x]
        hidden_states = torch.stack(hidden_states).to(x.device)
        return hidden_states


def measure_inference_time(generator, decoder, *, latent_dim, num_runs=32):
    """Return the mean number of seconds to generate and decode a single image.

    Both modules are switched to eval mode (and left there). Each run draws one
    float64 noise vector on ``device``, passes it through ``generator`` and then
    ``decoder`` with autograd disabled.

    Args:
        generator: module mapping noise of shape ``(1, latent_dim)`` to a latent code.
        decoder: module mapping the latent code to an image.
        latent_dim: length of the noise vector.
        num_runs: number of timed single-image runs to average over.

    Raises:
        ValueError: if ``num_runs`` is smaller than 1 (the average would be undefined).
    """
    if num_runs < 1:
        raise ValueError(f"num_runs must be >= 1, got {num_runs}")
    generator.eval()
    decoder.eval()
    # CUDA kernels run asynchronously: synchronize so the timed region covers
    # the actual device work, not just its enqueueing.
    sync = torch.cuda.synchronize if torch.cuda.is_available() else (lambda: None)
    with torch.no_grad():
        sync()
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        for _ in range(num_runs):
            noise = torch.randn(1, latent_dim, device=device, dtype=torch.double)
            hidden = generator(noise)
            decoder(hidden)
        sync()
        elapsed = time.perf_counter() - start
    return elapsed / num_runs


def average_inference_time(generators, decoders):
    """Average the per-image inference time over all trained labels.

    ``generators`` and ``decoders`` are dicts keyed by label; every generator key
    must also be present in ``decoders``. Returns NaN when ``generators`` is empty.
    """
    per_label_times = [
        measure_inference_time(gen, decoders[lbl], latent_dim=LATENT_DIM)
        for lbl, gen in generators.items()
    ]
    if not per_label_times:
        return float("nan")
    return mean(per_label_times)


In [None]:


def train_lastqgan_for_label(label):
    """Train one LaSt-QGAN on a single MNIST class and time its training loop.

    Fresh autoencoder halves, a quantum generator and a discriminator are built.
    The discriminator is updated with ``WassersteinLoss`` plus ``PenaltyLoss``, then
    the generator with the Wasserstein generator objective, once per batch. Only
    the epoch loop is timed; model construction and dataset loading are excluded.

    Args:
        label: MNIST class passed to ``DigitsDataset`` to select the training samples.

    Returns:
        ``(generator, decoder, total_time)``: both modules switched to eval mode,
        and the seconds spent in the training loop.
    """
    encoder, decoder = build_autoencoder_modules()
    generator = make_generator()
    discriminator = make_discriminator()

    criterion = WassersteinLoss()
    penalty = PenaltyLoss(alpha=gan_cfg["alpha"])

    gen_opt_cfg = gan_cfg["optimizers"]["generator"]
    disc_opt_cfg = gan_cfg["optimizers"]["discriminator"]
    gen_optimizer = getattr(optim, gen_opt_cfg["type"])(generator.parameters(), **gen_opt_cfg["parameters"])
    disc_optimizer = getattr(optim, disc_opt_cfg["type"])(discriminator.parameters(), **disc_opt_cfg["parameters"])

    dataset = DigitsDataset(path_to_csv=gan_cfg["path_to_mnist"], label=label)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

    # CUDA work is asynchronous: synchronize around the timed region so queued
    # kernels are counted in this measurement.
    sync = torch.cuda.synchronize if torch.cuda.is_available() else (lambda: None)
    sync()
    start = time.perf_counter()
    for _ in range(NUM_EPOCHS):
        for batch in dataloader:
            batch = batch.unsqueeze(1).double().to(device)
            # The encoder belongs to no optimizer. Backpropagating through it only
            # inflates the measured cost and accumulates grads that are never
            # cleared, so encode without autograd. Re-enable grad on the encoded
            # batch itself so the penalty term still receives a real_hidden that
            # requires grad, exactly as when it came out of the encoder graph.
            with torch.no_grad():
                real_hidden = encoder(batch)
            real_hidden.requires_grad_(True)
            noise = torch.randn(batch.size(0), LATENT_DIM, device=device, dtype=torch.double)
            fake_hidden = generator(noise)

            # Discriminator step: generator output detached so only D is updated.
            real_scores = discriminator(real_hidden)
            fake_scores = discriminator(fake_hidden.detach())
            loss_disc = criterion(real_scores, fake_scores)
            loss_disc = loss_disc + penalty(real_hidden, fake_hidden.detach(), discriminator)
            disc_optimizer.zero_grad()
            loss_disc.backward()
            disc_optimizer.step()

            # Generator step: re-score with the freshly updated discriminator.
            fake_scores_for_gen = discriminator(fake_hidden)
            loss_gen = criterion(torch.zeros_like(fake_scores_for_gen), -fake_scores_for_gen)
            gen_optimizer.zero_grad()
            loss_gen.backward()
            gen_optimizer.step()
    sync()
    total_time = time.perf_counter() - start
    return generator.eval(), decoder.eval(), total_time


def run_lastqgan():
    """Train one LaSt-QGAN per label in ``label_ids`` and summarise resource usage.

    Returns:
        A pair ``(summary, per_label_training)``.
        ``summary`` holds the model name, the total seconds for all labels
        (including model construction and data loading), the generator
        parameter count (taken from the first trained generator, 0 if none),
        and the mean per-image inference time.
        ``per_label_training`` lists ``{"Classe", "Tempo_treinamento_classe_seg"}``
        dicts, one per label.
    """
    # perf_counter is monotonic; time.time() can jump with system clock changes.
    start_total = time.perf_counter()
    generators = {}
    decoders = {}
    per_label_training = []

    for label in label_ids:
        generator, decoder, tempo = train_lastqgan_for_label(label)
        generators[label] = generator
        decoders[label] = decoder
        per_label_training.append({
            "Classe": label,
            "Tempo_treinamento_classe_seg": tempo,
        })

    total_time = time.perf_counter() - start_total
    # All generators share one architecture, so any of them gives the count.
    generator_params = count_parameters(next(iter(generators.values()))) if generators else 0
    inference_time = average_inference_time(generators, decoders)
    return {
        "GAN": "LaStQGAN",
        "Tempo_treinamento_seg": total_time,
        "Parametros_Gerador": generator_params,
        "Tempo_inferência_img_seg": inference_time,
    }, per_label_training


In [None]:
summary, per_label_training = run_lastqgan()

# One-row summary table, with timings also expressed in minutes / milliseconds.
df_summary = pd.DataFrame([summary]).assign(
    **{
        'Tempo_treinamento_min': lambda frame: frame['Tempo_treinamento_seg'] / 60,
        'Tempo_inferência_img_ms': lambda frame: frame['Tempo_inferência_img_seg'] * 1_000,
    }
)
df_summary


In [None]:
# Per-class training times, in seconds and in minutes.
df_treinamento = pd.DataFrame(per_label_training)
df_treinamento = df_treinamento.assign(
    Tempo_treinamento_classe_min=df_treinamento['Tempo_treinamento_classe_seg'] / 60
)
df_treinamento
