# GAN

In [None]:
import torch
import torch.nn as nn
from torch.optim import Adam
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

class Generador(nn.Module):
    """Fully connected generator that maps latent vectors to flat images.

    The network is ``Linear(n_latent, 256) -> GELU -> Linear(256, 784) -> Tanh``,
    so every output value lies in [-1, 1] and the last output dimension is
    28 * 28 = 784.

    Args:
        n_latent: Size of the last dimension of the latent input ``z``.
    """

    def __init__(self, n_latent):
        super().__init__()
        n_pixels = 28 * 28
        layers = [
            nn.Linear(n_latent, 256),
            nn.GELU(),
            nn.Linear(256, n_pixels),
            nn.Tanh(),
        ]
        # Keep the attribute name and layer order so state_dict keys stay stable.
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        """Return the generated output for latent input ``z`` (last dim 784)."""
        return self.net(z)

class Discriminador(nn.Module):
    """Fully connected discriminator that scores flat 28x28 images.

    The network is ``Linear(784, 516) -> GELU -> Linear(516, 64) -> GELU ->
    Linear(64, 1) -> Sigmoid``, so it emits one value in [0, 1] per input row.
    """

    def __init__(self):
        super().__init__()
        n_pixels = 28 * 28
        # NOTE(review): 516 hidden units is unusual (512 is more common);
        # kept as-is because changing it would alter the model.
        layers = [
            nn.Linear(n_pixels, 516),
            nn.GELU(),
            nn.Linear(516, 64),
            nn.GELU(),
            nn.Linear(64, 1),
            nn.Sigmoid(),
        ]
        # Keep the attribute name and layer order so state_dict keys stay stable.
        self.net = nn.Sequential(*layers)

    def forward(self, img):
        """Return the discriminator score for ``img`` (last dim must be 784)."""
        score = self.net(img)
        return score

# Training hyperparameters
n_latent = 128  # size of the latent vector fed to the generator
n_epoch = 50  # number of passes over the dataset
n_batch = 64  # samples per mini-batch requested from the DataLoader
lr = 2e-4  # learning rate shared by both optimizers

# Load MNIST. Normalize with mean 0.5 / std 0.5 applies (x - 0.5) / 0.5 to the
# tensor produced by ToTensor.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

dataset = datasets.MNIST(
    root="./data",
    train=True,
    transform=transform,
    download=True,
)
dataloader = DataLoader(dataset, batch_size=n_batch, shuffle=True)

# Build both networks (generator first) and move them to the selected device
G = Generador(n_latent).to(device)
D = Discriminador().to(device)

# One Adam optimizer per network, both with the same settings
_adam_opts = {"lr": lr, "betas": (0.5, 0.999)}
optG = Adam(G.parameters(), **_adam_opts)
optD = Adam(D.parameters(), **_adam_opts)

# Binary cross-entropy on the discriminator's sigmoid outputs
loss_fn = nn.BCELoss()

# GAN training loop: for every batch, update the discriminator once and then
# the generator once; report the losses of the last batch after each epoch.
for epoch in range(n_epoch):
    for img, _ in dataloader:
        # Use the real size of this batch instead of n_batch: the last batch
        # of an epoch can be smaller, and view(n_batch, -1) would then yield
        # rows of the wrong width for the discriminator's first layer.
        batch_size = img.size(0)
        img = img.view(batch_size, -1).to(device)  # flatten each image

        # ----------------------
        # Discriminator step
        # ----------------------
        optD.zero_grad()
        realLabel = torch.ones((batch_size, 1), device=device)
        fakeLabel = torch.zeros((batch_size, 1), device=device)

        # Score the real images
        realPred = D(img)
        realLoss = loss_fn(realPred, realLabel)

        # Generate fake images from random latent vectors
        z = torch.randn(batch_size, n_latent, device=device)
        fake = G(z)
        # detach() keeps this backward pass from reaching the generator
        fakePred = D(fake.detach())
        fakeLoss = loss_fn(fakePred, fakeLabel)

        # Average both losses and update the discriminator
        lossD = (realLoss + fakeLoss) / 2
        lossD.backward()
        optD.step()

        # ----------------------
        # Generator step
        # ----------------------
        optG.zero_grad()
        pred = D(fake)
        # Target the "real" label: the generator is trained to fool D
        lossGen = loss_fn(pred, realLabel)
        lossGen.backward()
        optG.step()

    print(f"Época {epoch+1}/{n_epoch} - Pérdida D: {lossD.item():.4f}, Pérdida G: {lossGen.item():.4f}")
