# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import os

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Directory where the per-epoch sample grids are saved.
os.makedirs('gan_images', exist_ok=True)

# Hyperparameters.
latent_dim = 100        # length of the generator's input noise vector
image_size = 28 * 28    # a 28x28 MNIST image, flattened to one vector
batch_size = 64

# Generator: maps a latent_dim noise vector to a flattened image_size image.
# The layer sizes are derived from latent_dim / image_size so that changing
# either hyperparameter keeps the network consistent with the rest of the
# script. The final Tanh keeps outputs in [-1, 1].
generator = nn.Sequential(
    nn.Linear(latent_dim, 256),
    nn.ReLU(),
    nn.Linear(256, 512),
    nn.ReLU(),
    nn.Linear(512, 1024),
    nn.ReLU(),
    nn.Linear(1024, image_size),
    nn.Tanh(),
).to(device)

# Discriminator: maps a flattened image_size image to a single value in
# (0, 1) via the final Sigmoid, read as "probability the input is real".
# Dropout(0.3) follows each hidden ReLU.
discriminator = nn.Sequential(
    nn.Linear(image_size, 1024),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(1024, 512),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(512, 256),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(256, 1),
    nn.Sigmoid(),
).to(device)

# Load the MNIST training split. ToTensor yields pixels in [0, 1], and
# Normalize(0.5, 0.5) rescales them to [-1, 1], the same range as the
# generator's Tanh output.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])
mnist_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
data_loader = DataLoader(mnist_dataset, batch_size=batch_size, shuffle=True)

# One Adam optimizer per network, both with learning rate 2e-4.
optimizer_generator = optim.Adam(generator.parameters(), lr=2e-4)
optimizer_discriminator = optim.Adam(discriminator.parameters(), lr=2e-4)

# Binary cross-entropy, applied to the discriminator's sigmoid output.
criterion = nn.BCELoss()

# Train the GAN. Each batch runs one discriminator update followed by one
# generator update. After every epoch, a 2x8 grid of samples is saved and shown.
num_epochs = 60
# NOTE: this GAN is unconditional, so target_digit does not steer what the
# generator draws; it is only used in the saved image filename.
target_digit = 7
for epoch in range(num_epochs):
    for real_images, _ in data_loader:
        real_images = real_images.view(-1, image_size).to(device)
        n = real_images.size(0)  # the final batch may be smaller than batch_size
        # Create targets directly on the device to avoid a host->device copy
        # on every batch.
        real_labels = torch.ones(n, 1, device=device)
        fake_labels = torch.zeros(n, 1, device=device)

        # Discriminator step: push real images toward 1 and generated ones toward 0.
        optimizer_discriminator.zero_grad()
        loss_real = criterion(discriminator(real_images), real_labels)
        loss_real.backward()

        noise = torch.randn(n, latent_dim, device=device)
        fake_images = generator(noise)
        # detach() keeps the discriminator loss from back-propagating into G.
        loss_fake = criterion(discriminator(fake_images.detach()), fake_labels)
        loss_fake.backward()
        optimizer_discriminator.step()

        # Generator step (non-saturating loss): train G so that D labels its
        # samples as real. Stale grads this leaves on D are cleared by the
        # zero_grad() at the start of the next discriminator step.
        optimizer_generator.zero_grad()
        loss_generator = criterion(discriminator(fake_images), real_labels)
        loss_generator.backward()
        optimizer_generator.step()

    # Report the losses of the epoch's last batch (only a few .item() syncs per epoch).
    print(f'Epoch {epoch + 1}/{num_epochs} | '
          f'D loss: {loss_real.item() + loss_fake.item():.4f} | '
          f'G loss: {loss_generator.item():.4f}')

    # Draw 16 samples from fresh noise without tracking gradients.
    with torch.no_grad():
        noise = torch.randn(16, latent_dim, device=device)
        generated_images = generator(noise).cpu()

    fig, axes = plt.subplots(2, 8, figsize=(15, 3))
    # axes.flat walks the grid row by row, matching sample order 0..15.
    for ax, image in zip(axes.flat, generated_images):
        ax.imshow(image.view(28, 28), cmap='gray')
        ax.axis('off')
    plt.savefig(f'gan_images/epoch_{epoch + 1}_digit_{target_digit}.png')
    plt.show()
    plt.close(fig)
