<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/GANs_with_PyTorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

class Generator(nn.Module):
    """Fully connected GAN generator mapping latent noise to flat samples.

    Architecture: ``latent_dim -> 256 -> 512 -> output_dim`` with ReLU
    activations and a final ``Tanh``, so every output element lies in
    ``[-1, 1]``. Real data fed to the discriminator should therefore be
    scaled to the same range.

    Args:
        latent_dim: Size of the input noise vector. Defaults to 100.
        output_dim: Number of features per generated sample. Defaults to 784
            (presumably a flattened 28x28 image -- confirm against the data).
    """

    def __init__(self, latent_dim: int = 100, output_dim: int = 784):
        super().__init__()
        # Kept as attributes so callers can size noise / reshape outputs
        # without repeating the magic numbers.
        self.latent_dim = latent_dim
        self.output_dim = output_dim
        # Layer order is unchanged from the original, so state_dict keys
        # (main.0/2/4.weight|bias) remain compatible with saved checkpoints.
        self.main = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(True),
            nn.Linear(256, 512),
            nn.ReLU(True),
            nn.Linear(512, output_dim),
            nn.Tanh(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map noise of shape ``(..., latent_dim)`` to ``(..., output_dim)``."""
        return self.main(x)

class Discriminator(nn.Module):
    """Fully connected GAN discriminator scoring samples as real vs. fake.

    Architecture: ``input_dim -> 512 -> 256 -> 1`` with LeakyReLU(0.2)
    activations and a final ``Sigmoid``, so each score is a probability in
    ``[0, 1]`` (the form ``nn.BCELoss`` expects).

    Args:
        input_dim: Number of features per sample. Defaults to 784.
    """

    def __init__(self, input_dim: int = 784):
        super().__init__()
        self.input_dim = input_dim
        self.main = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return one score per sample.

        Input whose last dimension equals ``input_dim`` is used as-is, e.g.
        ``(batch, input_dim)`` -> ``(batch, 1)``. Input with more than two
        dimensions whose last dimension does NOT equal ``input_dim`` is
        treated as ``(batch, *feature_dims)`` and flattened. This lets image
        batches such as ``(batch, 1, 28, 28)`` be scored directly.
        """
        if x.dim() > 2 and x.shape[-1] != self.input_dim:
            # Only reached for shapes the first Linear layer would reject,
            # so inputs that already worked keep their exact behavior.
            x = x.flatten(start_dim=1)
        return self.main(x)

# Training the GAN
#
# NOTE(review): `real_data` below is a placeholder of standard-normal noise,
# not real samples. Roughly a third of its entries fall outside [-1, 1], a
# range the generator's final Tanh can never leave, so D can tell real from
# fake by magnitude alone. `datasets`/`transforms` are imported but unused in
# this cell, which suggests MNIST (28 * 28 = 784) scaled to [-1, 1] was the
# intended data -- TODO confirm and replace the placeholder.
BATCH_SIZE = 64
LATENT_DIM = 100  # must match the input size of Generator's first Linear layer
DATA_DIM = 784  # must match Generator's output / Discriminator's input size
NUM_EPOCHS = 100  # each "epoch" below is a single mini-batch update

generator = Generator()
discriminator = Discriminator()
criterion = nn.BCELoss()
optimizerG = optim.Adam(generator.parameters(), lr=0.0002)
optimizerD = optim.Adam(discriminator.parameters(), lr=0.0002)

for epoch in range(NUM_EPOCHS):
    # --- Train Discriminator: push D(real) toward 1 and D(G(z)) toward 0 ---
    optimizerD.zero_grad()
    real_data = torch.randn(BATCH_SIZE, DATA_DIM)
    real_labels = torch.ones(BATCH_SIZE, 1)
    # detach(): this update must only touch D. Without it, d_loss.backward()
    # also backpropagates through the whole generator and accumulates G
    # gradients that optimizerG.zero_grad() later discards (wasted compute,
    # and stale gradients if the zero_grad order ever changes).
    fake_data = generator(torch.randn(BATCH_SIZE, LATENT_DIM)).detach()
    fake_labels = torch.zeros(BATCH_SIZE, 1)

    real_output = discriminator(real_data)
    fake_output = discriminator(fake_data)

    d_loss_real = criterion(real_output, real_labels)
    d_loss_fake = criterion(fake_output, fake_labels)
    d_loss = d_loss_real + d_loss_fake
    d_loss.backward()
    optimizerD.step()

    # --- Train Generator ---
    optimizerG.zero_grad()
    # Fresh samples, NOT detached: gradients must flow back into G here.
    fake_data = generator(torch.randn(BATCH_SIZE, LATENT_DIM))
    fake_output = discriminator(fake_data)
    # Targeting the "real" label (non-saturating loss) makes minimizing BCE
    # push D(G(z)) toward 1. This backward also fills D's .grad buffers; they
    # are cleared by optimizerD.zero_grad() at the start of the next step.
    g_loss = criterion(fake_output, real_labels)
    g_loss.backward()
    optimizerG.step()

    print(f"Epoch {epoch}, D Loss: {d_loss.item()}, G Loss: {g_loss.item()}")