# Recursos computacionais das GANs clássicas

Este notebook resume o custo de treinamento e de inferência das três arquiteturas de GAN usadas nos experimentos clássicos (DCGAN, CGAN e WGAN-GP).
A ideia é executar apenas o ciclo de treinamento/inferência e registrar, para cada arquitetura, o tempo total de treinamento, o tempo médio de inferência por imagem e o número de parâmetros do gerador.


In [None]:
import time
from statistics import mean

import pandas as pd
import torch

from classical_gans import (
    CGANDiscriminator,
    CGANGenerator,
    DCDiscriminator,
    DCGenerator,
    WGANGPCritic,
    WGANGPGenerator,
    train_cgan,
    train_gan_for_class,
    train_wgangp,
)
from medmnist_data import load_medmnist_data


In [None]:
# --- Experiment settings -------------------------------------------------
DATA_FLAG = 'breastmnist'  # MedMNIST subset passed to load_medmnist_data
BATCH_SIZE = 128
LATENT_DIM = 100  # length of the generator's noise vector
NUM_EPOCHS = 500  # training epochs for every generator

# Use the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

bundle = load_medmnist_data(data_flag=DATA_FLAG, batch_size=BATCH_SIZE, download=True)
train_loader = bundle.train_loader

# Normalize label-name keys to int (they may not arrive as ints) and keep the
# class ids sorted so every loop over classes runs in a deterministic order.
label_names = dict((int(key), name) for key, name in bundle.label_names.items())
label_ids = sorted(label_names)
num_classes = bundle.num_classes

# Channel count taken from the first training sample; assumes item 0 is
# (image, label) with a channel-first image -- TODO confirm.
img_channels = bundle.train_dataset[0][0].shape[0]


In [None]:
def count_parameters(model):
    """Return the number of scalar values held by ``model``'s parameters.

    Every tensor yielded by ``model.parameters()`` is counted, whether or not
    it requires gradients.
    """
    total = 0
    for tensor in model.parameters():
        total += tensor.numel()
    return total

def measure_inference_time(generator, *, latent_dim, label=None, num_runs=32, warmup_runs=1):
    """Return the mean wall-clock seconds needed to generate one image.

    The generator is switched to eval mode and, under ``torch.no_grad()``,
    called ``num_runs`` times with a fresh noise sample of shape
    ``(1, latent_dim, 1, 1)``. When ``label`` is given, a one-element
    ``long`` label tensor is passed as the second argument as well.

    Args:
        generator: ``torch.nn.Module`` to benchmark. Noise and labels are
            created on the device of its first parameter. If it has no
            parameters, the notebook-level ``device`` is used instead.
        latent_dim: Size of the noise vector.
        label: Class id for conditional generators. ``None`` means the
            generator is called as ``generator(noise)``.
        num_runs: Number of timed forward passes (must be >= 1).
        warmup_runs: Untimed forward passes executed first, so one-off costs
            (CUDA context/kernel initialization, allocator growth) are not
            charged to the measurement (must be >= 0).

    Returns:
        Average seconds per forward pass, measured with ``time.perf_counter``.

    Raises:
        ValueError: if ``num_runs`` < 1 or ``warmup_runs`` < 0.
    """
    if num_runs < 1:
        raise ValueError(f'num_runs must be >= 1, got {num_runs}')
    if warmup_runs < 0:
        raise ValueError(f'warmup_runs must be >= 0, got {warmup_runs}')

    # Put inputs where the model actually lives.
    first_param = next(generator.parameters(), None)
    run_device = first_param.device if first_param is not None else device

    def sync():
        # CUDA kernels run asynchronously: wait for them so the timer covers
        # the real computation, not just the kernel launches.
        if run_device.type == 'cuda':
            torch.cuda.synchronize(run_device)

    # The label tensor is identical on every run. Build it once so its
    # host-to-device copy is not counted as inference time.
    labels = (
        None
        if label is None
        else torch.tensor([label], device=run_device, dtype=torch.long)
    )

    def forward():
        noise = torch.randn(1, latent_dim, 1, 1, device=run_device)
        if labels is None:
            generator(noise)
        else:
            generator(noise, labels)

    generator.eval()
    with torch.no_grad():
        for _ in range(warmup_runs):
            forward()
        sync()
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        for _ in range(num_runs):
            forward()
        sync()
        elapsed = time.perf_counter() - start
    return elapsed / num_runs

def average_inference_time(generators, *, requires_label):
    """Average the per-image inference time over a ``{label: generator}`` dict.

    Each generator is timed with ``measure_inference_time`` using
    ``LATENT_DIM``-sized noise. When ``requires_label`` is true, the
    generator's own dict key is passed as its class label. Otherwise it is
    called with noise only. An empty mapping yields ``nan``.
    """
    if not generators:
        return float('nan')
    return mean(
        measure_inference_time(
            model,
            latent_dim=LATENT_DIM,
            label=(class_id if requires_label else None),
        )
        for class_id, model in generators.items()
    )


In [None]:
def _sync_device():
    """Wait for queued CUDA work on ``device`` to finish (no-op on CPU).

    CUDA kernels are launched asynchronously. Without this, the training
    timer could stop while work is still running on the GPU.
    """
    if device.type == 'cuda':
        torch.cuda.synchronize(device)

def _benchmark_per_class(gan_name, train_one, *, requires_label):
    """Train one generator per class in ``label_ids`` and collect cost metrics.

    Args:
        gan_name: Value stored under the ``'GAN'`` key of the result.
        train_one: Callable ``label -> trained generator`` that builds the
            models for that class and trains them.
        requires_label: Forwarded to ``average_inference_time``. True when the
            generator's forward takes ``(noise, labels)``.

    Returns:
        Dict with:
        - the total training wall time for all classes, in seconds (model
          construction included);
        - the parameter count of ONE generator;
        - the mean per-image inference time across the per-class generators.

    Raises:
        ValueError: if ``label_ids`` is empty.
    """
    if not label_ids:
        raise ValueError('label_ids is empty: no class to train a generator for')

    _sync_device()
    start = time.perf_counter()
    generators = {label: train_one(label).eval() for label in label_ids}
    _sync_device()
    total_time = time.perf_counter() - start

    return {
        'GAN': gan_name,
        'Tempo_treinamento_seg': total_time,
        # Every per-class generator is built with the same constructor
        # arguments, so one of them gives the per-generator size (the total
        # footprint is len(label_ids) times this).
        'Parametros_Gerador': count_parameters(next(iter(generators.values()))),
        'Tempo_inferência_img_seg': average_inference_time(
            generators, requires_label=requires_label
        ),
    }

def run_dcgan():
    """Benchmark DCGAN: one unconditional generator trained per class."""
    def train_one(label):
        G = DCGenerator(latent_dim=LATENT_DIM, img_channels=img_channels).to(device)
        D = DCDiscriminator(img_channels=img_channels).to(device)
        return train_gan_for_class(
            train_loader=train_loader,
            label_target=label,
            G=G,
            D=D,
            latent_dim=LATENT_DIM,
            num_epochs=NUM_EPOCHS,
            device=device,
        )
    return _benchmark_per_class('DCGAN', train_one, requires_label=False)

def run_cgan():
    """Benchmark CGAN: one conditional generator trained per class.

    NOTE(review): a separate conditional model is still trained for each
    ``label_target``. Presumably ``train_cgan`` restricts training to that
    class -- confirm against its implementation.
    """
    def train_one(label):
        G = CGANGenerator(
            latent_dim=LATENT_DIM, num_classes=num_classes, img_channels=img_channels
        ).to(device)
        D = CGANDiscriminator(
            num_classes=num_classes, img_channels=img_channels
        ).to(device)
        return train_cgan(
            train_loader=train_loader,
            G=G,
            D=D,
            latent_dim=LATENT_DIM,
            num_classes=num_classes,
            num_epochs=NUM_EPOCHS,
            device=device,
            label_target=label,
        )
    return _benchmark_per_class('CGAN', train_one, requires_label=True)

def run_wgangp():
    """Benchmark WGAN-GP: one unconditional generator/critic pair per class."""
    def train_one(label):
        G = WGANGPGenerator(latent_dim=LATENT_DIM, img_channels=img_channels).to(device)
        D = WGANGPCritic(img_channels=img_channels).to(device)
        return train_wgangp(
            train_loader=train_loader,
            G=G,
            D=D,
            latent_dim=LATENT_DIM,
            num_epochs=NUM_EPOCHS,
            device=device,
            label_target=label,
        )
    return _benchmark_per_class('WGAN-GP', train_one, requires_label=False)


In [None]:
# Run the three benchmarks in a fixed order (DCGAN, CGAN, WGAN-GP) and tabulate them.
results = [runner() for runner in (run_dcgan, run_cgan, run_wgangp)]
df_results = pd.DataFrame(results)
# Convenience column: training time expressed in minutes.
df_results['Tempo_treinamento_min'] = df_results['Tempo_treinamento_seg'].div(60)
df_results
