<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Generative_Adversarial_Networks_(GANs)_with_PyTorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

# Generator network
class Generator(nn.Module):
    """MLP generator mapping latent noise vectors to flattened images.

    Args:
        latent_dim: Size of the input noise vector. Defaults to 100.
        img_dim: Size of the flattened output image. Defaults to 28*28.

    The final ``Tanh`` bounds every output value to [-1, 1].
    """

    def __init__(self, latent_dim=100, img_dim=28 * 28):
        super().__init__()
        self.latent_dim = latent_dim
        self.img_dim = img_dim
        self.main = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, img_dim),
            nn.Tanh(),
        )

    def forward(self, input):
        """Map noise of shape ``(batch, latent_dim)`` to images of shape ``(batch, img_dim)``."""
        return self.main(input)

# Discriminator network
class Discriminator(nn.Module):
    """MLP discriminator scoring flattened images as real or fake.

    Args:
        img_dim: Size of the flattened input image. Defaults to 28*28.

    The final ``Sigmoid`` squashes each score into [0, 1], the range
    expected by ``nn.BCELoss``.
    """

    def __init__(self, img_dim=28 * 28):
        super().__init__()
        self.img_dim = img_dim
        self.main = nn.Sequential(
            nn.Linear(img_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, input):
        """Score images of shape ``(batch, img_dim)``; returns a ``(batch, 1)`` tensor."""
        return self.main(input)

# Training loop
def train_gan(epochs=10, batch_size=64, lr=0.0002, data_root='.',
              device=None, log_interval=100):
    """Train the MLP GAN on MNIST and return the trained networks.

    Args:
        epochs: Number of passes over the MNIST training set.
        batch_size: Mini-batch size for the training ``DataLoader``.
        lr: Adam learning rate used for both the generator and discriminator.
        data_root: Directory MNIST is downloaded to / loaded from.
        device: Device to train on. ``None`` selects CUDA when available,
            otherwise CPU.
        log_interval: Print the losses every ``log_interval`` batches.

    Returns:
        Tuple ``(G, D)`` with the trained generator and discriminator.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Normalize((0.5,), (0.5,)) maps ToTensor's [0, 1] pixels to [-1, 1],
    # the same range as the generator's Tanh output.
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(data_root, train=True, download=True, transform=transform),
        batch_size=batch_size,
        shuffle=True,
    )

    G = Generator().to(device)
    D = Discriminator().to(device)
    # Take the noise size from the generator's first layer so the two can't drift apart.
    latent_dim = G.main[0].in_features
    criterion = nn.BCELoss()
    optimizerG = optim.Adam(G.parameters(), lr=lr)
    optimizerD = optim.Adam(D.parameters(), lr=lr)

    for epoch in range(epochs):
        for i, (data, _) in enumerate(train_loader):
            # Flatten each image to a vector; the last batch may be smaller
            # than batch_size, so size the labels and noise from the data.
            real_data = data.view(data.size(0), -1).to(device)
            n = real_data.size(0)
            real_labels = torch.ones(n, 1, device=device)
            fake_labels = torch.zeros(n, 1, device=device)

            # Discriminator step: push D(real) -> 1 and D(G(z)) -> 0.
            optimizerD.zero_grad()
            real_loss = criterion(D(real_data), real_labels)
            noise = torch.randn(n, latent_dim, device=device)
            fake_data = G(noise)
            # detach() keeps this loss from backpropagating into G.
            fake_loss = criterion(D(fake_data.detach()), fake_labels)
            d_loss = real_loss + fake_loss
            d_loss.backward()
            optimizerD.step()

            # Generator step: label fakes as real so G learns to fool D.
            # This backward also writes grads into D's parameters; they are
            # cleared by optimizerD.zero_grad() before D's next update.
            optimizerG.zero_grad()
            gen_loss = criterion(D(fake_data), real_labels)
            gen_loss.backward()
            optimizerG.step()

            if i % log_interval == 0:
                print(f"Epoch {epoch + 1}/{epochs}, Batch {i}/{len(train_loader)}, "
                      f"D Loss: {d_loss.item()}, G Loss: {gen_loss.item()}")

    return G, D

# Train only when executed as a script or notebook cell, not on import.
if __name__ == "__main__":
    train_gan()