<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Generative_Adversarial_Networks_(GANs).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

class Generator(nn.Module):
    """MLP generator that maps latent noise vectors to flattened images.

    Args:
        z_dim: Length of each input noise vector.
        img_dim: Length of each flattened output image. Defaults to 784
            (28 * 28), the size of a flattened MNIST digit, so existing
            callers are unaffected.

    The final ``Tanh`` bounds every output value to [-1, 1], so real images
    must be normalized to the same range before being shown to a
    discriminator alongside generated ones.
    """

    def __init__(self, z_dim: int, img_dim: int = 784):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, img_dim),
            nn.Tanh(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map noise whose last dimension is ``z_dim`` to images whose last dimension is ``img_dim``."""
        return self.model(x)

class Discriminator(nn.Module):
    """MLP discriminator that scores flattened images as real vs. generated.

    Args:
        img_dim: Length of each flattened input image. Defaults to 784
            (28 * 28), the size of a flattened MNIST digit, so existing
            callers are unaffected. It must match the generator's output
            width.

    The final ``Sigmoid`` turns the score into a probability in [0, 1],
    which is the input form ``nn.BCELoss`` expects.
    """

    def __init__(self, img_dim: int = 784):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(img_dim, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return a ``(..., 1)`` tensor of "real" probabilities for inputs whose last dimension is ``img_dim``."""
        return self.model(x)

# ---- Hyperparameters ------------------------------------------------------
z_dim = 64        # length of the latent noise vector fed to the generator
batch_size = 64   # images per DataLoader mini-batch
# NOTE(review): one epoch is a full pass over the training split, so 10000
# epochs is a very long run — confirm this was not meant as an iteration count.
epochs = 10000

# ---- Data -----------------------------------------------------------------
# ToTensor() yields pixel values in [0, 1]; Normalize with mean 0.5 / std 0.5
# rescales them to [-1, 1], the same range the generator's Tanh produces.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)
mnist_data = datasets.MNIST(
    './data',
    train=True,
    transform=transform,
    download=True,
)
data_loader = torch.utils.data.DataLoader(
    dataset=mnist_data,
    batch_size=batch_size,
    shuffle=True,
)

# ---- Models ---------------------------------------------------------------
generator = Generator(z_dim)
discriminator = Discriminator()

# ---- Objective and optimizers ---------------------------------------------
# The discriminator ends in a Sigmoid, so plain binary cross-entropy on its
# probabilities serves as the loss for both players. Each network gets its
# own Adam optimizer with the same learning rate.
criterion = nn.BCELoss()
optimizer_gen = optim.Adam(params=generator.parameters(), lr=2e-4)
optimizer_dis = optim.Adam(params=discriminator.parameters(), lr=2e-4)

# Training loop: for every mini-batch, take one discriminator step and then
# one generator step.
for epoch in range(epochs):
    for real_data, _ in data_loader:
        # Size of *this* batch. The last batch of an epoch can be smaller than
        # `batch_size` when the dataset size is not a multiple of it. Use a
        # separate name so the `batch_size` hyperparameter is not silently
        # overwritten for later cells that read it.
        cur_batch = real_data.size(0)
        # Flatten each image into a vector for the fully-connected discriminator.
        real_data = real_data.view(cur_batch, -1)

        # ---- Discriminator step: push D(real) -> 1 and D(G(z)) -> 0 ----
        z = torch.randn(cur_batch, z_dim)
        fake_data = generator(z)
        real_label = torch.ones(cur_batch, 1)
        fake_label = torch.zeros(cur_batch, 1)

        discriminator.zero_grad()
        output_real = discriminator(real_data)
        # detach() keeps this loss from backpropagating into the generator.
        output_fake = discriminator(fake_data.detach())
        loss_real = criterion(output_real, real_label)
        loss_fake = criterion(output_fake, fake_label)
        loss_dis = loss_real + loss_fake
        loss_dis.backward()
        optimizer_dis.step()

        # ---- Generator step: non-saturating loss, push D(G(z)) -> 1 ----
        # Re-score the same fakes with the just-updated discriminator. This
        # backward pass also leaves gradients on the discriminator's
        # parameters. They are cleared by discriminator.zero_grad() at the
        # start of the next iteration, before optimizer_dis.step() reads them.
        generator.zero_grad()
        output_fake = discriminator(fake_data)
        loss_gen = criterion(output_fake, real_label)
        loss_gen.backward()
        optimizer_gen.step()

    # Report the losses from the epoch's final mini-batch every 1000 epochs.
    # NOTE(review): each epoch is a full pass over data_loader, so this
    # interval (and the configured `epochs`) looks sized for iterations
    # rather than epochs — confirm the intended training length.
    if epoch % 1000 == 0:
        print(f"Epoch {epoch}, Loss D: {loss_dis.item()}, Loss G: {loss_gen.item()}")