<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Generative_Adversarial_Networks_(GANs).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Define the Generator
class Generator(nn.Module):
    """Fully connected generator that maps noise vectors to flattened images.

    Layer widths are noise_dim -> 256 -> 512 -> img_dim, with in-place ReLU
    after each hidden layer and a final Tanh, so every output value lies in
    [-1, 1].

    Args:
        noise_dim: Length of each input noise vector.
        img_dim: Length of each flattened output image.
    """

    def __init__(self, noise_dim: int, img_dim: int) -> None:
        super().__init__()
        layers = []
        width_in = noise_dim
        # Hidden layers are built in order so parameter initialisation consumes
        # the RNG exactly as an explicit Linear/ReLU listing would.
        for width_out in (256, 512):
            layers.extend([nn.Linear(width_in, width_out), nn.ReLU(inplace=True)])
            width_in = width_out
        layers.extend([nn.Linear(width_in, img_dim), nn.Tanh()])
        self.model = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return generated images for the noise batch ``x`` (values in [-1, 1])."""
        return self.model(x)

# Define the Discriminator
class Discriminator(nn.Module):
    """Fully connected real/fake classifier for flattened images.

    Layer widths are img_dim -> 512 -> 256 -> 1, with LeakyReLU(0.2) after
    each hidden layer and a final Sigmoid. For a 2-D input of shape
    (batch, img_dim), the output is (batch, 1) with values in (0, 1).

    Args:
        img_dim: Length of each flattened input image.
    """

    def __init__(self, img_dim: int) -> None:
        super().__init__()
        widths = [img_dim, 512, 256]
        blocks = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            blocks.append(nn.Linear(fan_in, fan_out))
            blocks.append(nn.LeakyReLU(0.2, inplace=True))
        # Single-logit head squashed to a probability for use with nn.BCELoss.
        blocks.append(nn.Linear(widths[-1], 1))
        blocks.append(nn.Sigmoid())
        self.model = nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the probability that each image in ``x`` is real."""
        return self.model(x)

# Hyperparameters
noise_dim = 100
img_size = 28  # MNIST images are 28x28 pixels
img_dim = img_size * img_size  # images are flattened to one vector for the MLPs
batch_size = 64
lr = 0.0002
num_epochs = 50

# Data preparation. ToTensor yields pixels in [0, 1]; Normalize((0.5,), (0.5,))
# maps them to [-1, 1], matching the range of the generator's final Tanh.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
dataset = datasets.MNIST(root='data', train=True, transform=transform, download=True)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Model initialization
generator = Generator(noise_dim, img_dim)
discriminator = Discriminator(img_dim)

# Optimizers and loss function. nn.BCELoss expects probabilities, which the
# discriminator's final Sigmoid provides.
optim_gen = optim.Adam(generator.parameters(), lr=lr)
optim_dis = optim.Adam(discriminator.parameters(), lr=lr)
criterion = nn.BCELoss()

# Training loop
for epoch in range(num_epochs):
    for real_imgs, _ in dataloader:
        real_imgs = real_imgs.view(-1, img_dim)
        # The last batch may be smaller than `batch_size`, so size noise and
        # labels from the actual batch. A separate name keeps the
        # `batch_size` hyperparameter from being overwritten with the size of
        # the final batch.
        cur_batch_size = real_imgs.size(0)

        # Train Discriminator: real images -> label 1, generated images -> label 0
        noise = torch.randn(cur_batch_size, noise_dim)
        fake_imgs = generator(noise)
        real_labels = torch.ones(cur_batch_size, 1)
        fake_labels = torch.zeros(cur_batch_size, 1)

        real_loss = criterion(discriminator(real_imgs), real_labels)
        # detach() keeps the discriminator loss from backpropagating into the generator.
        fake_loss = criterion(discriminator(fake_imgs.detach()), fake_labels)
        dis_loss = real_loss + fake_loss

        optim_dis.zero_grad()
        dis_loss.backward()
        optim_dis.step()

        # Train Generator: score the same fakes with the just-updated
        # discriminator against "real" targets. This backward pass also leaves
        # gradients on the discriminator's parameters; optim_dis.zero_grad()
        # clears them before the next discriminator update.
        gen_loss = criterion(discriminator(fake_imgs), real_labels)

        optim_gen.zero_grad()
        gen_loss.backward()
        optim_gen.step()

    # Losses reported are those of the epoch's last batch, not an epoch average.
    print(f"Epoch [{epoch+1}/{num_epochs}], Dis Loss: {dis_loss.item()}, Gen Loss: {gen_loss.item()}")

# Generate new images after training
generator.eval()
with torch.no_grad():
    noise = torch.randn(16, noise_dim)
    generated_imgs = generator(noise).view(-1, 1, img_size, img_size)
    generated_imgs = (generated_imgs + 1) / 2  # Rescale Tanh output from [-1, 1] to [0, 1]