<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Generative_Adversarial_Networks_(GANs).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

# Define the generator
class Generator(nn.Module):
    """Map latent noise vectors to synthetic samples.

    A two-layer MLP:
    Linear(noise_dim -> hidden_dim) -> ReLU -> Linear(hidden_dim -> output_dim) -> Tanh.
    Because of the final Tanh, every output element lies in [-1, 1].

    Args:
        noise_dim: Size of the last dimension of the input noise tensor.
        output_dim: Size of the last dimension of each generated sample.
        hidden_dim: Width of the hidden layer. Defaults to 256, the value
            that was previously hard-coded.
    """

    def __init__(self, noise_dim, output_dim, hidden_dim=256):
        super().__init__()
        # The layer layout and the `model` attribute name are unchanged,
        # so state_dict keys ("model.0.weight", ...) stay compatible with
        # checkpoints saved from the original class.
        self.model = nn.Sequential(
            nn.Linear(noise_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
            nn.Tanh(),
        )

    def forward(self, x):
        """Return generated samples for noise ``x`` of shape (..., noise_dim)."""
        return self.model(x)

# Define the discriminator
class Discriminator(nn.Module):
    """Score samples with the probability that they are real.

    A two-layer MLP:
    Linear(input_dim -> hidden_dim) -> LeakyReLU(0.2) -> Linear(hidden_dim -> 1) -> Sigmoid.
    The Sigmoid makes the single output a probability in [0, 1], which is
    what ``nn.BCELoss`` expects as its input.

    Args:
        input_dim: Size of the last dimension of each input sample.
        hidden_dim: Width of the hidden layer. Defaults to 256, the value
            that was previously hard-coded.
    """

    def __init__(self, input_dim, hidden_dim=256):
        super().__init__()
        # The layer layout and the `model` attribute name are unchanged,
        # so existing state_dict keys remain valid.
        self.model = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return a (..., 1) tensor of real-vs-fake probabilities for ``x``."""
        return self.model(x)

# --- Model, optimizer and loss setup ---------------------------------------
# Latent size and flattened sample size. 28 * 28 presumably corresponds to
# a flattened 28x28 image; the training loop below uses random placeholder
# data of this width.
noise_dim = 100
data_dim = 28 * 28

# The generator is built before the discriminator, so parameter
# initialisation draws from the global RNG in a fixed order.
generator = Generator(noise_dim, data_dim)
discriminator = Discriminator(data_dim)

# Both players use Adam with the same learning rate.
learning_rate = 2e-4
optimizer_G = optim.Adam(generator.parameters(), lr=learning_rate)
optimizer_D = optim.Adam(discriminator.parameters(), lr=learning_rate)

# Binary cross-entropy on the discriminator's Sigmoid (probability) output.
criterion = nn.BCELoss()

# Training loop
# NOTE: `real_data` is random placeholder data, not a real dataset, and each
# "epoch" below is a single mini-batch update rather than a pass over data.
batch_size = 64
num_steps = 10000
log_every = 1000

# Targets are constant, so allocate them once instead of on every step.
real_labels = torch.ones(batch_size, 1)
fake_labels = torch.zeros(batch_size, 1)

for epoch in range(num_steps):
    # Placeholder "real" samples, rescaled from [0, 1) to [-1, 1) so they
    # share the range of the generator's Tanh output. A real dataset should
    # be normalised to [-1, 1] the same way.
    real_data = torch.rand(batch_size, data_dim) * 2 - 1
    # Zero-mean standard-normal latent noise (the conventional GAN prior),
    # instead of the non-centred U[0, 1) used previously.
    noise = torch.randn(batch_size, noise_dim)
    fake_data = generator(noise)

    # Train discriminator: push D(real) towards 1 and D(fake) towards 0.
    optimizer_D.zero_grad()
    output_real = discriminator(real_data)
    loss_real = criterion(output_real, real_labels)
    # detach(): this update must not backpropagate into the generator.
    output_fake = discriminator(fake_data.detach())
    loss_fake = criterion(output_fake, fake_labels)
    loss_D = loss_real + loss_fake
    loss_D.backward()
    optimizer_D.step()

    # Train generator (non-saturating loss): push D(G(z)) towards 1.
    optimizer_G.zero_grad()
    output_fake = discriminator(fake_data)
    loss_G = criterion(output_fake, real_labels)
    # This backward also accumulates gradients in the discriminator's
    # parameters; optimizer_D.zero_grad() clears them on the next step.
    loss_G.backward()
    optimizer_G.step()

    if epoch % log_every == 0:
        print(f"Epoch {epoch}, Loss D: {loss_D.item()}, Loss G: {loss_G.item()}")