# GANs Clássicas com MedMNIST

Este notebook demonstra como treinar três variantes de GANs utilizando o dataset **MedMNIST**: **DCGAN**, **Conditional GAN** (CGAN) e **Wasserstein GAN com Gradient Penalty** (WGAN-GP). O objetivo é gerar imagens sintéticas semelhantes às do conjunto de dados.

## 1. Setup e Importações
Instale as dependências necessárias e carregue as principais bibliotecas utilizadas ao longo do notebook.

In [None]:
!pip install torch torchvision medmnist matplotlib torchmetrics seaborn scikit-learn scipy --quiet
!pip install torch-fidelity --quiet
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
import medmnist
from medmnist import INFO
import matplotlib.pyplot as plt
import seaborn as sns

## 2. Carregamento e Preparação do Dataset
Utilizaremos o subconjunto `PathMNIST`, composto por imagens histológicas coloridas (RGB) de 28×28 pixels distribuídas em 9 classes — pequenas o suficiente para demonstrações rápidas.

In [None]:
# Select the MedMNIST subset; INFO[flag] holds its metadata (dataset class name,
# number of channels, label map, ...).
DATA_FLAG = 'pathmnist'
info = INFO[DATA_FLAG]
download = True

n_channels = info['n_channels']
n_classes = len(info['label'])

# ToTensor() yields values in [0, 1]; Normalize(0.5, 0.5) per channel maps them to
# [-1, 1], the same range as the generators' Tanh output.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * n_channels, std=[0.5] * n_channels),
])

# Resolve the dataset class once and build both splits from it.
DataClass = getattr(medmnist, info['python_class'])
train_dataset = DataClass(split='train', transform=transform, download=download)
test_dataset = DataClass(split='test', transform=transform, download=download)

# NOTE(review): MedMNIST labels are per-sample arrays of shape (1,), so batched
# labels arrive as (B, 1) — verify before feeding them to label embeddings.
batch_size = 128
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

## 3. Implementações das GANs
A seguir estão as implementações dos três modelos de GAN.

### DCGAN

In [None]:
class DCGenerator(nn.Module):
    """DCGAN generator mapping a latent noise tensor to a 28x28 image.

    Args:
        latent_dim: number of noise channels; ``forward`` expects noise shaped
            ``(B, latent_dim, 1, 1)``.
        img_channels: number of output image channels.

    ``forward`` returns ``(B, img_channels, 28, 28)`` with values in ``[-1, 1]``
    (Tanh). 28x28 is the size the discriminators/critic in this notebook assume
    when they flatten ``128*7*7`` features.
    """

    def __init__(self, latent_dim=100, img_channels=3):
        super().__init__()
        self.model = nn.Sequential(
            # 1x1 -> 7x7. A 7x7 kernel is needed so that the two stride-2
            # upsamplings below reach 28x28 (a 4x4 kernel here would give 16x16).
            nn.ConvTranspose2d(latent_dim, 128, 7, 1, 0, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            # 7x7 -> 14x14
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            # 14x14 -> 28x28
            nn.ConvTranspose2d(64, img_channels, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, x):
        return self.model(x)

class DCDiscriminator(nn.Module):
    """DCGAN discriminator giving each 28x28 image a probability of being real.

    Input ``(B, img_channels, 28, 28)``; output ``(B, 1)`` after a Sigmoid.
    """

    def __init__(self, img_channels=3):
        super().__init__()
        layers = [
            nn.Conv2d(img_channels, 64, 4, 2, 1, bias=False),  # 28 -> 14
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),           # 14 -> 7
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Flatten(),
            nn.Linear(128 * 7 * 7, 1),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        probs = self.model(x)
        return probs

### Conditional GAN

In [None]:
class CGANGenerator(nn.Module):
    """Conditional generator: (noise, class label) -> 28x28 image.

    Args:
        latent_dim: number of noise channels; noise is expected as
            ``(B, latent_dim, 1, 1)``.
        num_classes: number of classes; also the size of the label embedding.
        img_channels: number of output image channels.

    ``forward`` returns ``(B, img_channels, 28, 28)`` in ``[-1, 1]`` (Tanh).
    """

    def __init__(self, latent_dim=100, num_classes=9, img_channels=3):
        super().__init__()
        # Each label becomes a num_classes-dim vector concatenated to the noise channels.
        self.label_emb = nn.Embedding(num_classes, num_classes)
        self.model = nn.Sequential(
            # 1x1 -> 7x7 (7x7 kernel so that two stride-2 upsamplings reach 28x28)
            nn.ConvTranspose2d(latent_dim + num_classes, 128, 7, 1, 0, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            # 7x7 -> 14x14
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            # 14x14 -> 28x28
            nn.ConvTranspose2d(64, img_channels, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, noise, labels):
        """Generate images for ``noise`` conditioned on integer ``labels``.

        ``labels`` may be shaped ``(B,)`` or ``(B, 1)``; MedMNIST label batches
        are typically ``(B, 1)``. They are flattened to ``(B,)`` before embedding.
        """
        labels = labels.reshape(-1).long()
        label_input = self.label_emb(labels).unsqueeze(2).unsqueeze(3)  # (B, num_classes, 1, 1)
        x = torch.cat([noise, label_input], 1)
        return self.model(x)

class CGANDiscriminator(nn.Module):
    """Conditional discriminator scoring (image, label) pairs.

    Each label is embedded into a ``num_classes``-dim vector, broadcast to the
    image's spatial size, and concatenated to the image channels. Expects
    28x28 images (the head flattens ``128*7*7`` features); returns ``(B, 1)``
    after a Sigmoid.
    """

    def __init__(self, num_classes=9, img_channels=3):
        super().__init__()
        self.label_emb = nn.Embedding(num_classes, num_classes)
        self.conv = nn.Sequential(
            nn.Conv2d(img_channels + num_classes, 64, 4, 2, 1, bias=False),  # 28 -> 14
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),                         # 14 -> 7
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 7 * 7, 1),
            nn.Sigmoid(),
        )

    def forward(self, img, labels):
        """Score ``img`` given integer ``labels`` shaped ``(B,)`` or ``(B, 1)``.

        Labels are flattened first; otherwise a ``(B, 1)`` batch (typical for
        MedMNIST) would yield a 5-D embedding that ``expand`` cannot handle.
        """
        labels = labels.reshape(-1).long()
        height, width = img.size(2), img.size(3)
        label_img = self.label_emb(labels).unsqueeze(2).unsqueeze(3).expand(-1, -1, height, width)
        x = torch.cat([img, label_img], 1)
        x = self.conv(x)
        return self.fc(x)

### WGAN-GP

In [None]:
class WGANGPGenerator(DCGenerator):
    """WGAN-GP generator: identical architecture to ``DCGenerator``; only the training objective differs."""

class WGANGPCritic(nn.Module):
    """WGAN-GP critic mapping 28x28 images to an unbounded real score.

    There is no Sigmoid, because the Wasserstein critic outputs raw scores.

    Normalization is per sample: ``GroupNorm(1, C)`` is layer norm over (C, H, W).
    BatchNorm is deliberately avoided. It mixes statistics across the batch,
    whereas the gradient penalty constrains each sample's gradient independently.

    Input ``(B, img_channels, 28, 28)``; output ``(B, 1)``.
    """

    def __init__(self, img_channels=3):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(img_channels, 64, 4, 2, 1),   # 28 -> 14
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1),            # 14 -> 7
            nn.GroupNorm(1, 128),                   # per-sample layer norm over (C, H, W)
            nn.LeakyReLU(0.2, inplace=True),
            nn.Flatten(),
            nn.Linear(128 * 7 * 7, 1),
        )

    def forward(self, x):
        return self.model(x)

def gradient_penalty(critic, real, fake):
    """Return the WGAN-GP gradient penalty as a scalar tensor.

    Draws one ``eps ~ U[0, 1)`` per sample and evaluates the critic at
    ``x_hat = eps * real + (1 - eps) * fake``. The result is
    ``mean((||d critic(x_hat) / d x_hat||_2 - 1) ** 2)`` over the batch.

    ``create_graph=True`` keeps the penalty differentiable, so it can be added
    to the critic loss and backpropagated.

    Args:
        critic: callable mapping a batch shaped like ``real`` to scores.
        real: batch of real samples; ``eps`` is shaped ``(B, 1, 1, 1)``, so
            image batches ``(B, C, H, W)`` are assumed.
        fake: batch of generated samples, same shape as ``real``.
    """
    n = real.size(0)
    eps = torch.rand(n, 1, 1, 1, device=real.device)
    x_hat = (eps * real + (1 - eps) * fake).requires_grad_(True)
    scores = critic(x_hat)
    (grads,) = torch.autograd.grad(
        outputs=scores,
        inputs=x_hat,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
        retain_graph=True,
    )
    per_sample_norm = grads.view(n, -1).norm(2, dim=1)
    return per_sample_norm.sub(1).pow(2).mean()

## 4. Loop de Treinamento
Apresentamos um laço de treinamento simplificado para a DCGAN; a mesma estrutura pode ser adaptada para a CGAN e a WGAN-GP. Ajuste os hiperparâmetros conforme necessário.

In [None]:
# DCGAN training example (the same structure can be adapted to the CGAN and WGAN-GP).
latent_dim = 100
device = 'cuda' if torch.cuda.is_available() else 'cpu'

G = DCGenerator(latent_dim=latent_dim).to(device)
D = DCDiscriminator().to(device)

criterion = nn.BCELoss()
optim_G = torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))
optim_D = torch.optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))

# Per-batch loss history, kept so the training curves can be plotted afterwards.
losses_D, losses_G = [], []

num_epochs = 50
G.train()
D.train()
for epoch in range(num_epochs):
    for imgs, _ in train_loader:  # class labels are unused by the unconditional DCGAN
        real = imgs.to(device)
        b_size = real.size(0)
        real_targets = torch.ones(b_size, 1, device=device)
        fake_targets = torch.zeros(b_size, 1, device=device)
        noise = torch.randn(b_size, latent_dim, 1, 1, device=device)

        # Discriminator step: real -> 1, fake -> 0. `fake` is detached so this
        # loss does not backpropagate into G.
        optim_D.zero_grad()
        fake = G(noise)
        loss_real = criterion(D(real), real_targets)
        loss_fake = criterion(D(fake.detach()), fake_targets)
        loss_D = loss_real + loss_fake
        loss_D.backward()
        optim_D.step()

        # Generator step: push D(fake) toward the "real" label. This also fills
        # D's gradients, but they are cleared by optim_D.zero_grad() next iteration.
        optim_G.zero_grad()
        loss_G = criterion(D(fake), real_targets)
        loss_G.backward()
        optim_G.step()

        losses_D.append(loss_D.item())
        losses_G.append(loss_G.item())

    print(f"Epoch {epoch+1}/{num_epochs} - Loss D: {loss_D.item():.3f} | Loss G: {loss_G.item():.3f}")

# Training loss curves.
plt.figure(figsize=(8, 4))
plt.plot(losses_D, label='D')
plt.plot(losses_G, label='G')
plt.xlabel('Iteração')
plt.ylabel('Perda')
plt.legend()
plt.show()

## 5. Avaliação com Métricas
Avaliamos as imagens geradas utilizando as métricas **FID** (Fréchet Inception Distance) e **Inception Score**, disponíveis no pacote `torchmetrics` (que, para essas métricas, requer também o pacote `torch-fidelity`).

In [None]:
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore

# By default both metrics expect uint8 images in [0, 255]. With normalize=True
# they accept float images in [0, 1] instead, which is what we feed below.
# NOTE: torchmetrics' FID/IS also require the `torch-fidelity` package.
fid = FrechetInceptionDistance(feature=64, normalize=True).to(device)
is_metric = InceptionScore(normalize=True).to(device)

# Evaluate on at most this many batches (None = the whole training set, which is slow).
max_eval_batches = None


def to_unit_range(x):
    """Map images from [-1, 1] (Tanh output / Normalize(0.5, 0.5)) to [0, 1]."""
    return ((x + 1) / 2).clamp(0, 1)


# Sample in eval mode so BatchNorm uses its running statistics; restore the mode afterwards.
was_training = G.training
G.eval()
with torch.no_grad():
    for batch_idx, (real, _) in enumerate(train_loader):
        if max_eval_batches is not None and batch_idx >= max_eval_batches:
            break
        real = to_unit_range(real.to(device))
        noise = torch.randn(real.size(0), latent_dim, 1, 1, device=device)
        fake = to_unit_range(G(noise))
        fid.update(real, real=True)
        fid.update(fake, real=False)
        is_metric.update(fake)
G.train(was_training)

fid_score = fid.compute()
is_mean, is_std = is_metric.compute()
print('FID:', fid_score.item())
print('IS:', is_mean.item(), '+/-', is_std.item())

## 6. Visualização dos Resultados
Mostramos amostras geradas pelos modelos e gráficos de perda.

In [None]:
from torchvision.utils import make_grid

# Draw 16 samples in eval mode (BatchNorm uses running statistics).
G.eval()
noise = torch.randn(16, latent_dim, 1, 1, device=device)
with torch.no_grad():
    samples = G(noise).cpu()

# value_range=(-1, 1) matches the generator's Tanh output, so colors are mapped
# consistently instead of being stretched to this batch's own min/max.
grid = make_grid(samples, nrow=4, normalize=True, value_range=(-1, 1))
plt.figure(figsize=(6, 6))
plt.imshow(grid.permute(1, 2, 0))
plt.axis('off')
plt.show()

## 7. Conclusão
Discutimos rapidamente os resultados obtidos e possíveis caminhos para melhorias futuras.