<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Generative_Adversarial_Networks_(GANs).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import MNIST
from torchvision.transforms import transforms
from torch.utils.data import DataLoader

class Generator(nn.Module):
    """Fully connected GAN generator mapping noise vectors to flat samples.

    Architecture: ``Linear(noise_dim, hidden_dim) -> ReLU ->
    Linear(hidden_dim, output_dim) -> Tanh``, so every output element lies in
    ``[-1, 1]``.

    Args:
        noise_dim: Size of the last dimension of the input noise tensor.
        output_dim: Size of the last dimension of the generated output
            (e.g. 784 for a flattened 28x28 image).
        hidden_dim: Width of the single hidden layer. Defaults to 128, the
            value that used to be hard-coded, so existing checkpoints for the
            default configuration still load.
    """

    def __init__(self, noise_dim, output_dim, hidden_dim=128):
        super().__init__()
        # Keep the attribute name ``fc`` and this layer order: together they
        # define the state_dict keys (fc.0.*, fc.2.*) that saved weights use.
        self.fc = nn.Sequential(
            nn.Linear(noise_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
            nn.Tanh(),
        )

    def forward(self, x):
        """Map noise ``x`` of shape ``(..., noise_dim)`` to ``(..., output_dim)``."""
        return self.fc(x)

class Discriminator(nn.Module):
    """Fully connected GAN discriminator scoring flat samples as real or fake.

    Architecture: ``Linear(input_dim, hidden_dim) -> ReLU ->
    Linear(hidden_dim, 1) -> Sigmoid``. The output has shape ``(..., 1)`` with
    values in ``[0, 1]``, so it is meant to be paired with ``nn.BCELoss``
    (``nn.BCEWithLogitsLoss`` would apply a second sigmoid).

    Args:
        input_dim: Size of the last dimension of the input samples
            (e.g. 784 for a flattened 28x28 image).
        hidden_dim: Width of the single hidden layer. Defaults to 128, the
            value that used to be hard-coded, so existing checkpoints for the
            default configuration still load.
    """

    def __init__(self, input_dim, hidden_dim=128):
        super().__init__()
        # Keep the attribute name ``fc`` and this layer order: together they
        # define the state_dict keys (fc.0.*, fc.2.*) that saved weights use.
        self.fc = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Score ``x`` of shape ``(..., input_dim)``; returns ``(..., 1)`` in ``[0, 1]``."""
        return self.fc(x)

# Hyperparameters
noise_dim = 100
batch_size = 64
learning_rate = 0.0002
epochs = 50
image_dim = 784  # 28 * 28: images are flattened before entering either model

# Select the compute device once; the models and every per-step tensor use it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Data loading
# Normalize(mean=0.5, std=0.5) maps ToTensor's [0, 1] pixels to [-1, 1],
# the same range as the generator's Tanh output.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
dataset = MNIST(root='./data', train=True, transform=transform, download=True)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Instantiate models and optimizers
generator = Generator(noise_dim=noise_dim, output_dim=image_dim).to(device)
discriminator = Discriminator(input_dim=image_dim).to(device)
optimizer_G = optim.Adam(generator.parameters(), lr=learning_rate)
optimizer_D = optim.Adam(discriminator.parameters(), lr=learning_rate)
# BCELoss expects probabilities, which the discriminator's final Sigmoid produces.
criterion = nn.BCELoss()

# Training loop
for epoch in range(epochs):
    for real_images, _ in dataloader:
        # Flatten the images and move to device
        real_images = real_images.view(-1, image_dim).to(device)
        # The last batch of an epoch can be smaller than batch_size.
        current_batch_size = real_images.size(0)

        # Create labels directly on the device (no host->device copy per step)
        real_labels = torch.ones(current_batch_size, 1, device=device)  # Real label = 1
        fake_labels = torch.zeros(current_batch_size, 1, device=device)  # Fake label = 0

        # ---- Train discriminator: minimise -[log D(x) + log(1 - D(G(z)))] ----
        optimizer_D.zero_grad()

        # Real images loss
        outputs = discriminator(real_images)
        d_loss_real = criterion(outputs, real_labels)

        # Fake images loss
        noise = torch.randn(current_batch_size, noise_dim, device=device)
        fake_images = generator(noise)
        # detach() cuts the graph so d_loss.backward() computes no generator grads
        outputs = discriminator(fake_images.detach())
        d_loss_fake = criterion(outputs, fake_labels)

        # Total discriminator loss
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        optimizer_D.step()

        # ---- Train generator ----
        optimizer_G.zero_grad()

        # Generate fresh fake images (not detached: grads must reach the generator)
        noise = torch.randn(current_batch_size, noise_dim, device=device)
        fake_images = generator(noise)
        outputs = discriminator(fake_images)
        # Non-saturating generator loss: score fakes against the "real" target
        # so minimising BCE pushes D(G(z)) towards 1.
        g_loss = criterion(outputs, real_labels)

        g_loss.backward()
        # Only the generator is stepped here. The grads this backward pass also
        # accumulates on the discriminator's parameters are cleared by
        # optimizer_D.zero_grad() at the start of the next iteration.
        optimizer_G.step()

    # Reported losses are from the epoch's final minibatch, not epoch averages.
    print(f"Epoch [{epoch + 1}/{epochs}], d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}")