<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/GANs.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.utils import make_grid  # Import make_grid
import matplotlib.pyplot as plt
import os

# Generator: an MLP mapping a latent noise vector to a flattened image.
class Generator(nn.Module):
    """Three-layer MLP that maps latent vectors to flattened images.

    Args:
        input_dim: Length of each input (latent) vector.
        output_dim: Length of each output vector (28*28 for flattened MNIST).
        hidden_dim: Width of both hidden layers.

    Every output value passes through ``tanh`` and so lies in [-1, 1].
    """

    def __init__(self, input_dim, output_dim, hidden_dim=256):
        super().__init__()
        sizes = (input_dim, hidden_dim, hidden_dim, output_dim)
        layers = []
        for n_in, n_out in zip(sizes[:-1], sizes[1:]):
            layers += [nn.Linear(n_in, n_out), nn.ReLU()]
        # The head uses tanh instead of ReLU so samples land in [-1, 1].
        layers[-1] = nn.Tanh()
        # Attribute name ``net`` and the layer order are kept so state_dict
        # keys (net.0.weight, net.2.weight, net.4.weight, ...) stay stable.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Map a batch of latent vectors ``x`` to generated samples."""
        result = self.net(x)
        return result

# Discriminator: an MLP scoring a flattened image with a probability.
class Discriminator(nn.Module):
    """Three-layer MLP that maps a flattened image to one probability.

    Args:
        input_dim: Length of each flattened input (28*28 for MNIST).
        hidden_dim: Width of both hidden layers.

    ``forward`` returns a tensor of shape ``(batch, 1)`` squashed by a sigmoid
    into (0, 1), which is the form ``nn.BCELoss`` expects.
    """

    def __init__(self, input_dim, hidden_dim=256):
        super().__init__()
        sizes = (input_dim, hidden_dim, hidden_dim, 1)
        layers = []
        for n_in, n_out in zip(sizes[:-1], sizes[1:]):
            layers += [nn.Linear(n_in, n_out), nn.ReLU()]
        # The head uses a sigmoid instead of ReLU to produce a probability.
        layers[-1] = nn.Sigmoid()
        # Attribute name ``net`` and the layer order are kept so state_dict
        # keys stay stable.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Score a batch of flattened images ``x``."""
        scores = self.net(x)
        return scores

# Data preparation: ToTensor yields pixels in [0, 1]; Normalize(0.5, 0.5)
# maps them to [-1, 1], the same range as the generator's tanh output.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)
dataset = datasets.MNIST(
    root='data',
    train=True,
    download=True,
    transform=transform,
)
dataloader = DataLoader(dataset=dataset, batch_size=128, shuffle=True)

# Model, loss, and optimizers.
latent_dim = 100  # length of the noise vector fed to the generator
generator = Generator(input_dim=latent_dim, output_dim=28 * 28)
discriminator = Discriminator(input_dim=28 * 28)
criterion = nn.BCELoss()  # takes probabilities, matching the sigmoid head

# Both networks train with Adam at lr 2e-4. Each StepLR halves its optimizer's
# learning rate every 30 scheduler steps.
optimizer_g = optim.Adam(params=generator.parameters(), lr=2e-4)
optimizer_d = optim.Adam(params=discriminator.parameters(), lr=2e-4)
scheduler_g, scheduler_d = (
    optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)
    for optimizer in (optimizer_g, optimizer_d)
)

# Create the directory that receives the periodic sample grids.
# exist_ok=True avoids the check-then-create race. It also raises
# FileExistsError right away if a regular file named 'gan_images' exists,
# rather than letting savefig fail later.
os.makedirs('gan_images', exist_ok=True)

# Function to save generated images
def save_generated_images(epoch, noise=None):
    """Save a grid of generator samples to ``gan_images/epoch_{epoch}.png``.

    Args:
        epoch: Value embedded in the output filename. The training loop
            passes its 0-based loop counter.
        noise: Optional latent batch of shape ``(N, latent_dim)``. Passing the
            same tensor on every call makes successive snapshots directly
            comparable. When omitted, 16 fresh samples are drawn as before.

    The grid has 4 images per row.
    """
    # Sample on whatever device the generator's weights live on.
    device = next(generator.parameters()).device
    if noise is None:
        noise = torch.randn(16, latent_dim, device=device)
    with torch.no_grad():
        images = generator(noise.to(device)).view(-1, 1, 28, 28).cpu()
    # Map the generator's tanh range [-1, 1] onto [0, 1] with a fixed scale,
    # instead of stretching each grid to its own min/max. Brightness then
    # means the same thing from one epoch to the next.
    grid = make_grid(images, nrow=4, normalize=True, value_range=(-1, 1))
    # An explicit figure avoids drawing onto whatever pyplot figure is current.
    # finally guarantees it is released even if saving fails.
    fig, ax = plt.subplots()
    try:
        ax.imshow(grid.permute(1, 2, 0).numpy())
        ax.axis('off')
        fig.savefig(f'gan_images/epoch_{epoch}.png')
    finally:
        plt.close(fig)

# Training loop: for each minibatch, one discriminator update then one
# generator update. Both learning rates decay once per epoch.
epochs = 100
# Put batches, labels and noise on the same device as the model weights, so the
# loop keeps working if the networks are moved off the CPU.
device = next(generator.parameters()).device
for epoch in range(epochs):
    # Accumulate per-batch losses so the report is an epoch mean rather than
    # the noisy loss of whichever minibatch happened to come last.
    sum_loss_d = 0.0
    sum_loss_g = 0.0
    num_batches = 0
    for real_images, _ in dataloader:
        real_images = real_images.view(-1, 28*28).to(device)
        batch_size = real_images.size(0)  # the final batch can be smaller

        # Train discriminator: real images are labelled 1, generated ones 0.
        labels_real = torch.ones(batch_size, 1, device=device)
        labels_fake = torch.zeros(batch_size, 1, device=device)
        latent_noise = torch.randn(batch_size, latent_dim, device=device)
        fake_images = generator(latent_noise)

        optimizer_d.zero_grad()
        output_real = discriminator(real_images)
        # detach(): the discriminator loss must not back-propagate into G.
        output_fake = discriminator(fake_images.detach())
        loss_d_real = criterion(output_real, labels_real)
        loss_d_fake = criterion(output_fake, labels_fake)
        loss_d = loss_d_real + loss_d_fake
        loss_d.backward()
        optimizer_d.step()

        # Train generator with the non-saturating loss: score fakes as "real".
        # This backward also writes gradients into D's parameters. They are
        # cleared by optimizer_d.zero_grad() before D's next update.
        optimizer_g.zero_grad()
        output_fake = discriminator(fake_images)
        loss_g = criterion(output_fake, labels_real)
        loss_g.backward()
        optimizer_g.step()

        sum_loss_d += loss_d.item()
        sum_loss_g += loss_g.item()
        num_batches += 1

    if num_batches == 0:
        raise RuntimeError("dataloader yielded no batches; nothing to train on")

    scheduler_g.step()
    scheduler_d.step()

    print(f"Epoch {epoch+1}, Loss D: {sum_loss_d / num_batches:.4f}, "
          f"Loss G: {sum_loss_g / num_batches:.4f}")

    # `epoch` is 0-based. Snapshot after epochs 0, 10, 20, ... as before, and
    # also after the final epoch, which `epoch % 10 == 0` alone skips when
    # epochs is 100.
    if epoch % 10 == 0 or epoch == epochs - 1:
        save_generated_images(epoch)