In [1]:
import math
import pickle as pkl
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Load data: MNIST training split. ToTensor yields float tensors in [0, 1];
# each batch is rescaled to [-1, 1] later, inside train_gan.
transform = transforms.Compose([
    transforms.ToTensor(),
])
train_ds = datasets.MNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)
dl = DataLoader(train_ds, batch_size=64, shuffle=True)

# Architectures
class Discriminator(nn.Module):
    """Fully connected discriminator that scores inputs as real or fake.

    Each input of shape ``(batch, ...)`` is flattened to ``(batch, -1)``; the
    flattened size must equal ``in_features``. The network has three hidden
    Linear layers (128 -> 64 -> 32 units), each followed by LeakyReLU(0.2) and
    Dropout(0.3), and a final Linear layer producing ``out_features`` raw
    logits. There is no sigmoid, so pair it with a logits-based loss such as
    ``nn.BCEWithLogitsLoss``.
    """

    def __init__(self, in_features=784, out_features=1):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(in_features, 128),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(128, 64),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(64, 32),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(32, out_features)
        )

    def forward(self, x):
        """Return logits of shape ``(batch, out_features)`` for input ``x``."""
        # flatten(start_dim=1) instead of view(batch, -1): view() raises on
        # non-contiguous inputs (e.g. transposed tensors) and on an empty batch
        # (view(0, -1) is ambiguous); flatten copies only when it must.
        return self.model(torch.flatten(x, start_dim=1))

class Generator(nn.Module):
    """Fully connected generator mapping latent vectors to flat samples.

    Layout: three hidden Linear layers widening 32 -> 64 -> 128 units, each
    followed by LeakyReLU(0.2) and Dropout(0.3), then a Linear projection to
    ``out_features`` squashed by Tanh into [-1, 1].
    """

    def __init__(self, in_features=100, out_features=784):
        super().__init__()
        widths = [in_features, 32, 64, 128]
        layers = []
        # Layers are appended in the same order as a hand-written Sequential,
        # so module indices (and state_dict keys) are model.0, model.3, ...
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.2), nn.Dropout(0.3)]
        layers += [nn.Linear(widths[-1], out_features), nn.Tanh()]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Map latent codes ``x`` to generated outputs in [-1, 1]."""
        return self.model(x)

# Loss functions
def compute_loss(logits, target_value, loss_fn, device):
    """Compute ``loss_fn`` between logits and a constant per-sample target.

    Args:
        logits: Discriminator output of shape ``(batch, 1)`` or ``(batch,)``.
        target_value: Label broadcast over the batch, e.g. ``1.0`` for real
            outputs and ``0.0`` for fake ones. Ints are accepted as well.
        loss_fn: Loss called as ``loss_fn(input, target)`` with equal shapes,
            e.g. ``nn.BCEWithLogitsLoss``.
        device: Device on which the target tensor is created.

    Returns:
        Whatever ``loss_fn`` returns (a scalar tensor for mean-reduced losses).
    """
    batch_size = logits.shape[0]
    # Reshape explicitly to (batch,). A bare squeeze() would also drop the
    # batch dimension when batch_size == 1, giving a 0-d input vs (1,) targets.
    flat_logits = logits.reshape(batch_size)
    # Match the logits' dtype so an int target_value (or half-precision
    # logits) does not produce a dtype the loss rejects.
    targets = torch.full(
        (batch_size,), float(target_value), dtype=flat_logits.dtype, device=device
    )
    return loss_fn(flat_logits, targets)

# Training
def _discriminator_step(d, g, d_optim, loss_fn, real_images, z_size, device):
    """Run one discriminator update and return its loss as a Python float.

    The discriminator is pushed towards target 1 on ``real_images`` and towards
    target 0 on a freshly generated fake batch of the same size.
    """
    d_optim.zero_grad()

    d_real_loss = compute_loss(d(real_images), 1.0, loss_fn, device)

    z = torch.randn(real_images.size(0), z_size, device=device)
    # detach() keeps this loss from backpropagating into the generator.
    fake_images = g(z).detach()
    d_fake_loss = compute_loss(d(fake_images), 0.0, loss_fn, device)

    d_loss = d_real_loss + d_fake_loss
    d_loss.backward()
    d_optim.step()
    return d_loss.item()


def _generator_step(d, g, g_optim, loss_fn, batch_size, z_size, device):
    """Run one generator update and return its loss as a Python float.

    The generator is trained so the discriminator labels its samples with
    target 1. ``backward`` also writes gradients into ``d``'s parameters.
    Those gradients are not applied by ``g_optim`` (assuming it holds only
    ``g``'s parameters), and ``d_optim.zero_grad()`` clears them at the start
    of the next discriminator step.
    """
    g_optim.zero_grad()

    z = torch.randn(batch_size, z_size, device=device)
    g_out = d(g(z))
    g_loss = compute_loss(g_out, 1.0, loss_fn, device)

    g_loss.backward()
    g_optim.step()
    return g_loss.item()


def train_gan(d, g, d_optim, g_optim, loss_fn, dl, n_epochs, device, z_size=100,
              samples_path='fixed_samples.pkl'):
    """Train generator ``g`` against discriminator ``d`` for ``n_epochs`` epochs.

    Each batch performs one discriminator update followed by one generator
    update. After every epoch the generator, in eval mode, is run on a fixed
    batch of 16 latent vectors. The collected samples are pickled to
    ``samples_path`` once training finishes.

    Args:
        d, g: Discriminator and generator modules; both are moved to ``device``.
        d_optim, g_optim: Optimizers for ``d`` and ``g`` respectively.
        loss_fn: Loss applied to discriminator logits (e.g. ``nn.BCEWithLogitsLoss``).
        dl: Iterable of ``(images, labels)`` batches; labels are ignored.
            Images are assumed to be in [0, 1] (e.g. from ``ToTensor``); they
            are rescaled to [-1, 1].
        n_epochs: Number of passes over ``dl``.
        device: Device used for the models and all created tensors.
        z_size: Dimensionality of the latent vectors fed to ``g``.
        samples_path: File the list of per-epoch fixed-latent samples is
            pickled to; pass ``None`` to skip saving.

    Returns:
        ``(d_losses, g_losses)``: per-epoch mean discriminator and generator losses.

    Raises:
        ValueError: If ``dl`` yields no batches during an epoch.
    """
    print(f'Training on [{device}]...')

    # The same latent vectors are used every epoch, so generator progress can be compared.
    fixed_z = torch.randn(16, z_size, device=device)
    fixed_samples = []
    d_losses, g_losses = [], []

    d.to(device)
    g.to(device)

    for epoch in range(n_epochs):
        d.train()
        g.train()
        d_running_loss = 0.0
        g_running_loss = 0.0
        n_batches = 0

        for real_images, _ in dl:
            # Rescale to [-1, 1], the range of the generator's Tanh output.
            real_images = real_images.to(device) * 2 - 1
            batch_size = real_images.size(0)

            d_running_loss += _discriminator_step(
                d, g, d_optim, loss_fn, real_images, z_size, device
            )
            g_running_loss += _generator_step(
                d, g, g_optim, loss_fn, batch_size, z_size, device
            )
            n_batches += 1

        # Batches are counted rather than taken from len(dl). This works for
        # loaders without __len__, and an empty loader gets a clear error
        # instead of a ZeroDivisionError.
        if n_batches == 0:
            raise ValueError('dl yielded no batches; cannot train')
        d_losses.append(d_running_loss / n_batches)
        g_losses.append(g_running_loss / n_batches)

        print(f'Epoch [{epoch+1}/{n_epochs}] - D_loss: {d_losses[-1]:.4f}, G_loss: {g_losses[-1]:.4f}')

        # Snapshot generator output on the fixed latents with dropout disabled.
        g.eval()
        with torch.no_grad():
            fixed_samples.append(g(fixed_z).cpu())

    if samples_path is not None:
        with open(samples_path, 'wb') as f:
            pkl.dump(fixed_samples, f)

    return d_losses, g_losses

# Setup: device, models, optimizers and loss, then run training.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

d = Discriminator()
g = Generator()

# NOTE(review): GAN setups often use Adam with lr=2e-4 and betas=(0.5, 0.999);
# 2e-3 with default betas may be aggressive — consider tuning.
d_optim = optim.Adam(d.parameters(), lr=0.002)
g_optim = optim.Adam(g.parameters(), lr=0.002)
# The discriminator outputs raw logits (no final sigmoid), hence the logits-based BCE.
loss_fn = nn.BCEWithLogitsLoss()

n_epochs = 100
d_losses, g_losses = train_gan(
    d, g, d_optim, g_optim, loss_fn, dl, n_epochs, device
)

100%|██████████| 9.91M/9.91M [00:00<00:00, 221MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 32.7MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 31.8MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 6.81MB/s]


Training on [cpu]...
Epoch [1/100] - D_loss: 1.1602, G_loss: 1.8300
Epoch [2/100] - D_loss: 1.2673, G_loss: 1.0781
Epoch [3/100] - D_loss: 1.1982, G_loss: 1.0727
Epoch [4/100] - D_loss: 1.1126, G_loss: 1.3674
Epoch [5/100] - D_loss: 1.1161, G_loss: 1.2542
Epoch [6/100] - D_loss: 1.1521, G_loss: 1.2489
Epoch [7/100] - D_loss: 1.1074, G_loss: 1.3390
Epoch [8/100] - D_loss: 1.1356, G_loss: 1.1910
Epoch [9/100] - D_loss: 1.1379, G_loss: 1.1535
Epoch [10/100] - D_loss: 1.1910, G_loss: 1.0719
Epoch [11/100] - D_loss: 1.2027, G_loss: 1.0357
Epoch [12/100] - D_loss: 1.2057, G_loss: 1.0187
Epoch [13/100] - D_loss: 1.2086, G_loss: 1.0276
Epoch [14/100] - D_loss: 1.2144, G_loss: 0.9990
Epoch [15/100] - D_loss: 1.2322, G_loss: 0.9690
Epoch [16/100] - D_loss: 1.2292, G_loss: 0.9811
Epoch [17/100] - D_loss: 1.2278, G_loss: 0.9831
Epoch [18/100] - D_loss: 1.2296, G_loss: 0.9782
Epoch [19/100] - D_loss: 1.2291, G_loss: 0.9886
Epoch [20/100] - D_loss: 1.2430, G_loss: 0.9702
Epoch [21/100] - D_loss: 1.2