<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Generative_Adversarial_Networks_(GANs).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Define the Generator class
class Generator(nn.Module):
    """Two-layer fully connected generator network.

    A latent vector of size ``latent_dim`` is passed through a 128-unit
    hidden layer with ReLU and then projected to ``output_dim`` values,
    which a final Tanh squashes into the range [-1, 1].
    """

    def __init__(self, latent_dim, output_dim):
        super().__init__()
        hidden_units = 128
        # Layer order matters: it fixes both the parameter names
        # (fc.0.*, fc.2.*) and the order in which weights are initialised.
        layers = [
            nn.Linear(latent_dim, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, output_dim),
            nn.Tanh(),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, z):
        """Return the generated output for the latent input ``z``."""
        generated = self.fc(z)
        return generated

# Define the Discriminator class
class Discriminator(nn.Module):
    """Two-layer fully connected discriminator network.

    An input of size ``input_dim`` is passed through a 128-unit hidden
    layer with ReLU and then reduced to a single value, which a final
    Sigmoid maps into the range [0, 1].
    """

    def __init__(self, input_dim):
        super().__init__()
        hidden_units = 128
        # Layer order matters: it fixes both the parameter names
        # (fc.0.*, fc.2.*) and the order in which weights are initialised.
        stack = (
            nn.Linear(input_dim, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, 1),
            nn.Sigmoid(),
        )
        self.fc = nn.Sequential(*stack)

    def forward(self, x):
        """Return one score per row of ``x``."""
        score = self.fc(x)
        return score

# Define the train_gan function
def train_gan(generator, discriminator, dataloader, epochs, latent_dim):
    """Train a generator/discriminator pair with the standard GAN objective.

    For every batch, one discriminator update is performed (real samples
    labelled 1, generated samples labelled 0), followed by one generator
    update (generated samples labelled 1, so the generator is rewarded for
    fooling the discriminator). Both networks are updated in place using
    Adam (lr=0.0002) and binary cross-entropy. After each epoch the losses
    of the *last* batch of that epoch are printed.

    Args:
        generator: Module mapping ``(batch, latent_dim)`` noise to samples.
        discriminator: Module mapping flattened samples to a ``(batch, 1)``
            output. ``nn.BCELoss`` requires that output to lie in [0, 1].
        dataloader: Iterable yielding ``(data, label)`` pairs; the labels
            are ignored. It is iterated once per epoch.
        epochs: Number of passes over ``dataloader``.
        latent_dim: Size of each noise vector fed to the generator.

    Returns:
        None.
    """
    criterion = nn.BCELoss()  # Binary cross-entropy on discriminator outputs
    g_optimizer = optim.Adam(generator.parameters(), lr=0.0002)
    d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002)

    # Create every tensor on the device the networks live on, so the loop
    # also works when the caller has moved the models off the CPU.
    device = next(generator.parameters()).device

    for epoch in range(epochs):
        # Reset per epoch so an epoch without batches is detectable below.
        d_loss = g_loss = None

        for real_data, _ in dataloader:
            batch_size = real_data.size(0)
            # Flatten each sample; reshape (unlike view) also accepts
            # non-contiguous batches.
            real_data = real_data.reshape(batch_size, -1).to(device)
            real_labels = torch.ones(batch_size, 1, device=device)
            fake_labels = torch.zeros(batch_size, 1, device=device)

            # --- Discriminator step ---
            d_optimizer.zero_grad()
            real_output = discriminator(real_data)
            d_loss_real = criterion(real_output, real_labels)

            z = torch.randn(batch_size, latent_dim, device=device)
            fake_data = generator(z)
            # detach(): this update must not back-propagate into the generator.
            fake_output = discriminator(fake_data.detach())
            d_loss_fake = criterion(fake_output, fake_labels)

            d_loss = d_loss_real + d_loss_fake
            d_loss.backward()
            d_optimizer.step()

            # --- Generator step ---
            g_optimizer.zero_grad()
            z = torch.randn(batch_size, latent_dim, device=device)
            fake_data = generator(z)
            fake_output = discriminator(fake_data)
            # Real labels on fake data: the generator's goal is to be
            # classified as real.
            g_loss = criterion(fake_output, real_labels)

            g_loss.backward()
            g_optimizer.step()

        if d_loss is None:
            # No batch was seen this epoch, so there is no loss to report.
            print(f'Epoch [{epoch + 1}/{epochs}] | no batches were produced by the dataloader')
            continue

        print(f'Epoch [{epoch + 1}/{epochs}] | D Loss: {d_loss.item():.4f} | G Loss: {g_loss.item():.4f}')

# Example usage
latent_dim = 100
# 784 is the length of one flattened image as fed to both networks.
generator = Generator(latent_dim, 784)
discriminator = Discriminator(784)

# Convert each image to a tensor, then normalise with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=(0.5,), std=(0.5,))]
)

# MNIST training split, downloaded into ./mnist_data if not already present.
dataset = datasets.MNIST('mnist_data', train=True, transform=transform, download=True)
dataloader = DataLoader(dataset=dataset, batch_size=64, shuffle=True)

# Run 20 epochs of adversarial training.
train_gan(generator, discriminator, dataloader, 20, latent_dim)