# Experimento 5





In [None]:
# ==========================================================
# IMPORTS
# ==========================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

import numpy as np
import matplotlib.pyplot as plt

from sklearn.neural_network import MLPClassifier
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score, confusion_matrix, ConfusionMatrixDisplay

# ==========================================================
# CONFIGURAÇÕES
# ==========================================================
# Prefer the GPU whenever PyTorch reports one; fall back to the CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Mini-batch size used when building the data loaders.
batch_size = 128

# Number of training epochs for the CNN and for each autoencoder.
epochs_cnn = 10
epochs_ae = 15

# Number of output classes of the CNN classification head.
num_classes = 10

# Latent sizes swept by the autoencoder experiment (student activity).
latent_dims = [16, 32, 64, 128]

# ==========================================================
# CNN PARA CLASSIFICAÇÃO E EXTRAÇÃO DE FEATURES
# ==========================================================
class CNNFeatureExtractor(nn.Module):
    """Two-stage convolutional network usable as classifier or feature extractor.

    Each stage is a 3x3 convolution (padding 1) followed by ReLU and 2x2 max
    pooling. The linear head expects 64 * 7 * 7 flattened values, which is
    what the two stages produce for single-channel 28x28 inputs.
    """

    def __init__(self):
        super().__init__()

        # Layers are created in this order so parameter initialisation is
        # deterministic under a fixed seed.
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.fc = nn.Linear(64 * 7 * 7, num_classes)

    def forward(self, x, return_features=False):
        """Run the network on a batch.

        Args:
            x: Input batch; the first dimension is the batch dimension.
            return_features: When true, return the flattened activations of
                the last pooling stage instead of the class logits.

        Returns:
            A 2-D tensor: flattened features if ``return_features`` is true,
            otherwise the output of the linear head.
        """
        # Both stages share the same conv -> ReLU -> 2x2 max-pool pattern.
        for conv in (self.conv1, self.conv2):
            x = F.max_pool2d(F.relu(conv(x)), 2)

        # Collapse everything except the batch dimension.
        flat = x.view(x.shape[0], -1)

        return flat if return_features else self.fc(flat)

# ==========================================================
# AUTOENCODER
# ==========================================================
class Autoencoder(nn.Module):
    """Fully connected autoencoder over flattened 28x28 inputs.

    The encoder maps 784 values to ``latent_dim`` through one hidden layer of
    256 units; the decoder mirrors it and ends in a sigmoid, so
    reconstructions lie in [0, 1].
    """

    def __init__(self, latent_dim):
        """Build the encoder and decoder.

        Args:
            latent_dim: Size of the bottleneck (encoder output) vector.
        """
        super().__init__()

        n_inputs = 28 * 28
        n_hidden = 256

        # Encoder first, decoder second: creation order fixes the order in
        # which parameters are initialised.
        encoder_layers = [
            nn.Linear(n_inputs, n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, latent_dim),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_layers = [
            nn.Linear(latent_dim, n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_inputs),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` to the latent space and return its reconstruction."""
        return self.decoder(self.encoder(x))

# ==========================================================
# TREINAMENTO
# ==========================================================
def train_cnn(model, loader, epochs=None):
    """Train a classifier with Adam (lr=1e-3) and cross-entropy loss.

    The loss printed for each epoch is the sample-weighted mean over all
    batches of that epoch, not the loss of whichever batch happened to be
    last.

    Args:
        model: Module mapping a batch of inputs to class logits.
        loader: Iterable yielding ``(inputs, targets)`` batches.
        epochs: Number of passes over ``loader``. When omitted, the
            module-level ``epochs_cnn`` setting is used.

    Returns:
        A list with one mean loss per epoch. An epoch in which ``loader``
        yielded no batches contributes ``nan``.
    """
    if epochs is None:
        epochs = epochs_cnn

    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    crit = nn.CrossEntropyLoss()

    # Send batches to wherever the model's weights live so inputs and
    # parameters are always on the same device.
    model_device = next(model.parameters()).device

    history = []
    model.train()
    for epoch in range(epochs):
        # Accumulate on-device to avoid a host sync on every batch.
        loss_sum = torch.zeros((), device=model_device)
        n_seen = 0

        for x, y in loader:
            x, y = x.to(model_device), y.to(model_device)
            loss = crit(model(x), y)

            opt.zero_grad()
            loss.backward()
            opt.step()

            # Weight by batch size so a short final batch does not skew the mean.
            loss_sum += loss.detach() * x.size(0)
            n_seen += x.size(0)

        mean_loss = loss_sum.item() / n_seen if n_seen else float("nan")
        history.append(mean_loss)
        print(f"[CNN] Epoch {epoch+1}/{epochs} | Loss: {mean_loss:.4f}")

    return history

def train_autoencoder(model, loader, epochs=None):
    """Train an autoencoder with Adam (lr=1e-3) and MSE reconstruction loss.

    Every batch is flattened to ``(batch, -1)`` and used as both the input
    and the reconstruction target; labels in the loader are ignored. The loss
    printed for each epoch is the sample-weighted mean over all batches of
    that epoch, not the loss of whichever batch happened to be last.

    Args:
        model: Module mapping a flattened batch to a reconstruction of the
            same shape.
        loader: Iterable yielding ``(inputs, labels)`` batches.
        epochs: Number of passes over ``loader``. When omitted, the
            module-level ``epochs_ae`` setting is used.

    Returns:
        A list with one mean loss per epoch. An epoch in which ``loader``
        yielded no batches contributes ``nan``.
    """
    if epochs is None:
        epochs = epochs_ae

    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    crit = nn.MSELoss()

    # Send batches to wherever the model's weights live so inputs and
    # parameters are always on the same device.
    model_device = next(model.parameters()).device

    history = []
    model.train()
    for epoch in range(epochs):
        # Accumulate on-device to avoid a host sync on every batch.
        loss_sum = torch.zeros((), device=model_device)
        n_seen = 0

        for x, _ in loader:
            x = x.view(x.size(0), -1).to(model_device)
            loss = crit(model(x), x)

            opt.zero_grad()
            loss.backward()
            opt.step()

            # Weight by batch size so a short final batch does not skew the mean.
            loss_sum += loss.detach() * x.size(0)
            n_seen += x.size(0)

        mean_loss = loss_sum.item() / n_seen if n_seen else float("nan")
        history.append(mean_loss)
        print(f"[AE] Epoch {epoch+1}/{epochs} | Loss: {mean_loss:.4f}")

    return history

# ==========================================================
# EXTRAÇÃO DE FEATURES
# ==========================================================
def extract_cnn_features(model, loader):
    """Collect the model's flattened features and the labels for a whole loader.

    The model is switched to eval mode and called with
    ``return_features=True`` under ``torch.no_grad()``.

    Args:
        model: Module whose ``forward`` accepts ``return_features=True``.
        loader: Iterable yielding ``(images, labels)`` batches.

    Returns:
        ``(features, labels)`` as NumPy arrays: per-batch features stacked
        row-wise and per-batch labels concatenated, both in the order the
        loader produced them.
    """
    model.eval()
    feature_chunks = []
    label_chunks = []

    with torch.no_grad():
        for batch, targets in loader:
            batch_features = model(batch.to(device), return_features=True)

            # Move results back to host memory before converting to NumPy.
            feature_chunks.append(batch_features.cpu().numpy())
            label_chunks.append(targets.numpy())

    stacked_features = np.vstack(feature_chunks)
    stacked_labels = np.hstack(label_chunks)
    return stacked_features, stacked_labels

def extract_ae_features(model, loader):
    """Collect the encoder's latent codes and the labels for a whole loader.

    The model is switched to eval mode; each batch is flattened to
    ``(batch, -1)`` and passed through ``model.encoder`` under
    ``torch.no_grad()``.

    Args:
        model: Module exposing an ``encoder`` submodule.
        loader: Iterable yielding ``(images, labels)`` batches.

    Returns:
        ``(codes, labels)`` as NumPy arrays: per-batch codes stacked row-wise
        and per-batch labels concatenated, both in the order the loader
        produced them.
    """
    model.eval()
    code_chunks = []
    label_chunks = []

    with torch.no_grad():
        for batch, targets in loader:
            flat_batch = batch.view(batch.size(0), -1).to(device)
            codes = model.encoder(flat_batch)

            # Move results back to host memory before converting to NumPy.
            code_chunks.append(codes.cpu().numpy())
            label_chunks.append(targets.numpy())

    stacked_codes = np.vstack(code_chunks)
    stacked_labels = np.hstack(label_chunks)
    return stacked_codes, stacked_labels

# ==========================================================
# DATASET
# ==========================================================
# Convert every image to a tensor as it is loaded.
transform = transforms.ToTensor()

# MNIST training and test splits, downloaded into ./data when missing.
train_data, test_data = (
    datasets.MNIST("data", train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Only the training loader shuffles, so each pass over it yields the samples
# in a new random order; the test loader keeps the dataset order.
train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=batch_size, shuffle=False)

# ==========================================================
# CNN BASELINE
# ==========================================================
# Train the CNN end-to-end, then reuse it as a feature extractor.
cnn = CNNFeatureExtractor().to(device)
train_cnn(model=cnn, loader=train_loader)

# Features and labels are returned together from a single pass over each
# loader, so every row of X_* lines up with the matching entry of y_*.
X_train_cnn, y_train = extract_cnn_features(cnn, train_loader)
X_test_cnn, y_test = extract_cnn_features(cnn, test_loader)

# RBF-kernel SVM fitted on the CNN features.
svm_cnn = SVC(kernel="rbf", gamma="scale")
svm_cnn.fit(X_train_cnn, y_train)

cnn_test_predictions = svm_cnn.predict(X_test_cnn)
acc_cnn = accuracy_score(y_test, cnn_test_predictions)

print(f"\nBaseline CNN features | Acurácia SVM: {acc_cnn:.4f}")

# ==========================================================
# EXPERIMENTO AUTOENCODER
# ==========================================================
# One entry per latent size: latent_dim, compression ratio and SVM accuracy.
results = []

for z_dim in latent_dims:
    print(f"\n===== Latent dim: {z_dim} =====")

    # Fresh autoencoder for every latent size.
    ae = Autoencoder(z_dim).to(device)
    train_autoencoder(ae, train_loader)

    # Keep the labels returned by the SAME pass that produced the codes.
    # train_loader shuffles, so each pass over it has its own sample order;
    # labels collected during an earlier pass would not line up with these
    # rows and the SVM would be fitted on mismatched targets.
    X_train_ae, y_train_ae = extract_ae_features(ae, train_loader)
    X_test_ae, y_test_ae = extract_ae_features(ae, test_loader)

    svm = SVC(kernel="rbf", gamma="scale")
    svm.fit(X_train_ae, y_train_ae)
    acc = accuracy_score(y_test_ae, svm.predict(X_test_ae))

    print(f"Acurácia SVM (AE): {acc:.4f}")

    results.append({
        "latent_dim": z_dim,
        # Fraction of the original 28x28 = 784 input values kept in the code.
        "compression": z_dim / (28*28),
        "accuracy": acc
    })

# ==========================================================
# RESUMO FINAL
# ==========================================================
# Print one line per latent size: dimension, compression ratio and accuracy.
print("\n===== TRADE-OFF COMPRESSÃO × DESEMPENHO =====")
for r in results:
    latent_size = r["latent_dim"]
    ratio = r["compression"]
    score = r["accuracy"]
    print(f"Latent: {latent_size:3d} | Compressão: {ratio:.4f} | Acurácia: {score:.4f}")


100%|██████████| 9.91M/9.91M [00:00<00:00, 35.1MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.04MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 9.84MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 6.43MB/s]
