# Experimento 3

In [6]:
# main_experimento3.py
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, TensorDataset, ConcatDataset
import numpy as np

#removi os imports das pastas pois estou usando colab.
#os outros arquivos estao nas celulas abaixo

device = "cuda" if torch.cuda.is_available() else "cpu"

# ======================
# Dados
# ======================
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_data = datasets.MNIST("data", train=True, download=True, transform=transform)
test_data  = datasets.MNIST("data", train=False, download=True, transform=transform)

test_loader = DataLoader(test_data, batch_size=128, shuffle=False)

# ======================
# CNN baseline (somente real)
# ======================
train_loader_real = DataLoader(train_data, batch_size=128, shuffle=True)

cnn_real = CNNClassifier().to(device)
train_cnn(cnn_real, train_loader_real, device)

# ======================
# Experimentos
# ======================
latent_dims = [32, 64, 128]
synthetic_ratios = [0.1, 0.3, 0.5]

results = []

def generate_fake(E, D, n):
    z = torch.randn(n, E.fc.out_features, device=device)
    labels = torch.randint(0, 10, (n,), device=device)
    with torch.no_grad():
        images = D(z, labels)
    return TensorDataset(images.cpu(), labels.cpu())

for z_dim in latent_dims:
    print(f"\n===== Latent dim: {z_dim} =====")

    # Treina CAAE
    E = CAAEEncoder(z_dim=z_dim).to(device)
    D = CAAEDecoder(z_dim=z_dim).to(device)
    C = CAAEDiscriminator(z_dim=z_dim).to(device)

    train_caae(E, D, C, train_loader_real, device)

    # Gera dados sintéticos
    fake_dataset = generate_fake(E, D, len(train_data))

    for ratio in synthetic_ratios:
        print(f"\n--- Proporção sintética: {int(ratio*100)}% ---")

        n_fake = int(len(train_data) * ratio)
        fake_subset, _ = torch.utils.data.random_split(
            fake_dataset, [n_fake, len(fake_dataset) - n_fake]
        )

        mixed_dataset = ConcatDataset([train_data, fake_subset])
        mixed_loader = DataLoader(mixed_dataset, batch_size=128, shuffle=True)

        # Treina CNN mista
        cnn_mix = CNNClassifier().to(device)
        train_cnn(cnn_mix, mixed_loader, device)

        # Avaliação
        plot_confusion_matrix(
            cnn_mix, test_loader, device,
            f"CAAE z={z_dim} | {int(ratio*100)}% sintético"
        )

        # ======================
        # FID
        # ======================
        fake_loader = DataLoader(fake_subset, batch_size=128, shuffle=False)

        real_feat = extract_features(cnn_real, train_loader_real, device)
        fake_feat = extract_features(cnn_real, fake_loader, device)

        mu_r, sig_r = compute_statistics(real_feat)
        mu_f, sig_f = compute_statistics(fake_feat)

        fid = calculate_fid(mu_r, sig_r, mu_f, sig_f)

        print(f"FID: {fid:.2f}")

        results.append({
            "latent_dim": z_dim,
            "synthetic_ratio": ratio,
            "fid": fid
        })

print("\nResumo dos resultados:")
for r in results:
    print(r)


100%|██████████| 9.91M/9.91M [00:00<00:00, 22.2MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 685kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 5.56MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 2.69MB/s]


Epoch [1/10] | Loss: 0.0608


KeyboardInterrupt: 

In [2]:
# EVALUATION:

#confusion_matrix.py
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

def plot_confusion_matrix(model, dataloader, device, title):
    """
    Calcula e plota a matriz de confusão de um classificador

    Parâmetros:
    - model: CNN treinada
    - dataloader: loader do conjunto de teste
    - device: cpu ou cuda
    - title: título do gráfico
    """

    model.eval()  # coloca o modelo em modo de avaliação

    all_preds = []   # lista para armazenar predições
    all_labels = []  # lista para armazenar rótulos reais

    # Desativa cálculo de gradiente (economia de memória)
    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs = model(images)

            # Classe predita = argmax dos logits
            preds = torch.argmax(outputs, dim=1)

            # Armazena resultados
            all_preds.append(preds.cpu().numpy())
            all_labels.append(labels.cpu().numpy())

    # Concatena todos os batches
    all_preds = np.concatenate(all_preds)
    all_labels = np.concatenate(all_labels)

    # Calcula a matriz de confusão
    cm = confusion_matrix(all_labels, all_preds)

    # Cria objeto de visualização
    disp = ConfusionMatrixDisplay(confusion_matrix=cm)

    # Plota
    fig, ax = plt.subplots(figsize=(7, 7))
    disp.plot(ax=ax, cmap="Blues", colorbar=True)
    ax.set_title(title)

    plt.show()


    #fid.py
    import torch
import numpy as np
from scipy import linalg

def extract_features(model, dataloader, device):
    """
    Extrai features intermediárias de uma CNN para cálculo do FID

    Parâmetros:
    - model: CNN treinada (usada como extrator)
    - dataloader: DataLoader (real ou fake)
    - device: cpu ou cuda

    Retorno:
    - features: array numpy (N, D)
    """

    model.eval()               # modo avaliação
    features = []              # lista para armazenar features

    with torch.no_grad():      # desativa gradientes
        for images, _ in dataloader:
            images = images.to(device)

            # Forward apenas até a camada convolucional
            x = model.conv1(images)
            x = torch.relu(x)
            x = torch.max_pool2d(x, 2)

            x = model.conv2(x)
            x = torch.relu(x)
            x = torch.max_pool2d(x, 2)

            # Flatten das features
            x = x.view(x.size(0), -1)

            features.append(x.cpu().numpy())

    # Concatena todos os batches
    features = np.concatenate(features, axis=0)
    return features


def compute_statistics(features):
    """
    Calcula média e covariância das features

    Parâmetros:
    - features: array (N, D)

    Retorno:
    - mu: média (D,)
    - sigma: covariância (D, D)
    """

    mu = np.mean(features, axis=0)
    sigma = np.cov(features, rowvar=False)

    return mu, sigma


def calculate_fid(mu1, sigma1, mu2, sigma2):
    """
    Calcula Fréchet Inception Distance (FID)

    Fórmula:
    ||μ1 - μ2||² + Tr(Σ1 + Σ2 - 2(Σ1Σ2)¹ᐟ²)
    """

    diff = mu1 - mu2

    # Produto das covariâncias
    covmean, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False)

    # Correção numérica (parte imaginária pequena)
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    fid = diff @ diff + np.trace(sigma1 + sigma2 - 2 * covmean)
    return fid


In [3]:
# MODELS:

#caae_decoder.py
import torch
import torch.nn as nn

class CAAEDecoder(nn.Module):
    """
    Decoder convolucional condicional
    Recebe z + rótulo e reconstrói imagem
    """
    def __init__(self, z_dim=100, n_classes=10):
        super().__init__()

        self.label_emb = nn.Embedding(n_classes, n_classes)

        self.fc = nn.Linear(z_dim + n_classes, 64 * 7 * 7)

        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 4, 2, 1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, 4, 2, 1),
            nn.Tanh()
        )

    def forward(self, z, labels):
        y = self.label_emb(labels)
        x = torch.cat([z, y], dim=1)
        x = self.fc(x)
        x = x.view(-1, 64, 7, 7)
        return self.deconv(x)

#caae_discriminator.py
import torch
import torch.nn as nn

class CAAEDiscriminator(nn.Module):
    """
    Discriminador no espaço latente
    Força z ~ N(0, I)
    """
    def __init__(self, z_dim=100):
        super().__init__()

        self.net = nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, z):
        return self.net(z)



#caae_encoder.py
import torch
import torch.nn as nn
import torch.nn.functional as F

class CAAEEncoder(nn.Module):
    """
    Encoder convolucional condicional
    Mapeia imagem -> espaço latente z
    """
    def __init__(self, z_dim=100):
        super().__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, 4, 2, 1),   # 14x14
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, 2, 1),  # 7x7
            nn.ReLU()
        )

        self.fc = nn.Linear(64 * 7 * 7, z_dim)

    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)




#cnn.py
import torch
import torch.nn as nn
import torch.nn.functional as F

class CNNClassifier(nn.Module):
    """
    CNN simples para classificação do MNIST
    Entrada: imagem 1x28x28
    Saída: logits para 10 classes
    """
    def __init__(self):
        super().__init__()

        # Bloco convolucional 1
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)

        # Camadas totalmente conectadas
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        # Entrada: (B, 1, 28, 28)

        x = F.relu(self.conv1(x))        # (B, 32, 28, 28)
        x = F.max_pool2d(x, 2)           # (B, 32, 14, 14)

        x = F.relu(self.conv2(x))        # (B, 64, 14, 14)
        x = F.max_pool2d(x, 2)           # (B, 64, 7, 7)

        x = x.view(x.size(0), -1)        # Flatten
        x = F.relu(self.fc1(x))
        return self.fc2(x)               # Logits


In [4]:
# TRAINING:


#train_caae.py
import torch
import torch.nn as nn
import torch.optim as optim

def train_caae(E, D, C, dataloader, device, epochs=10):
    """
    Treinamento do Autoencoder Adversarial Condicional
    E: Encoder
    D: Decoder
    C: Discriminador latente
    """

    recon_loss = nn.MSELoss()
    adv_loss = nn.BCELoss()

    opt_E = optim.Adam(E.parameters(), lr=2e-4)
    opt_D = optim.Adam(D.parameters(), lr=2e-4)
    opt_C = optim.Adam(C.parameters(), lr=2e-4)

    for epoch in range(epochs):
        for imgs, labels in dataloader:
            imgs, labels = imgs.to(device), labels.to(device)
            batch = imgs.size(0)

            valid = torch.ones(batch, 1, device=device)
            fake = torch.zeros(batch, 1, device=device)

            # ======================
            # Discriminador latente
            # ======================
            z_real = torch.randn(batch, E.fc.out_features, device=device)
            z_fake = E(imgs)

            loss_C = adv_loss(C(z_real), valid) + adv_loss(C(z_fake.detach()), fake)

            opt_C.zero_grad()
            loss_C.backward()
            opt_C.step()

            # ======================
            # Encoder + Decoder
            # ======================
            z = E(imgs)
            recon = D(z, labels)

            loss_recon = recon_loss(recon, imgs)
            loss_adv = adv_loss(C(z), valid)

            loss_ED = loss_recon + 0.01 * loss_adv

            opt_E.zero_grad()
            opt_D.zero_grad()
            loss_ED.backward()
            opt_E.step()
            opt_D.step()

        print(f"Epoch {epoch+1}/{epochs} | Recon: {loss_recon.item():.4f} | Adv: {loss_adv.item():.4f}")


#train_cnn.py
import torch
import torch.nn as nn
import torch.optim as optim

def train_cnn(model, dataloader, device, epochs=10):
    """
    Treinamento supervisionado da CNN
    """
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()

    model.train()

    for epoch in range(epochs):
        for imgs, labels in dataloader:
            imgs, labels = imgs.to(device), labels.to(device)

            outputs = model(imgs)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f"Epoch [{epoch+1}/{epochs}] | Loss: {loss.item():.4f}")
