<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Generative_Models_GANs.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Define the Generator class
class Generator(nn.Module):
    """Fully connected generator that maps latent vectors to flat images.

    The network is four linear layers (latent_dim -> 256 -> 512 -> 1024 ->
    img_dim) with ReLU between them and a final Tanh, so every output value
    lies in [-1, 1].

    Args:
        latent_dim: Width of each input noise vector. Defaults to 100.
        img_dim: Number of values in each generated, flattened image.
            Defaults to 784.
    """

    def __init__(self, latent_dim: int = 100, img_dim: int = 784) -> None:
        super().__init__()
        self.latent_dim = latent_dim
        self.img_dim = img_dim
        # Layer order is unchanged, so parameters keep the names
        # ``main.0`` ... ``main.6`` in the state dict.
        self.main = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(True),
            nn.Linear(256, 512),
            nn.ReLU(True),
            nn.Linear(512, 1024),
            nn.ReLU(True),
            nn.Linear(1024, img_dim),
            nn.Tanh(),
        )

    def forward(self, x):
        """Generate samples for a batch of latent vectors.

        Args:
            x: Tensor whose last dimension has size ``latent_dim``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size ``img_dim``; values are bounded by the Tanh.
        """
        return self.main(x)

# Define the Discriminator class
class Discriminator(nn.Module):
    """Fully connected discriminator that scores flat images.

    The network is four linear layers (img_dim -> 1024 -> 512 -> 256 -> 1)
    with LeakyReLU (negative slope 0.2) between them and a final Sigmoid, so
    each input row yields a single score between 0 and 1.

    Args:
        img_dim: Number of values in each flattened input image.
            Defaults to 784.
    """

    def __init__(self, img_dim: int = 784) -> None:
        super().__init__()
        self.img_dim = img_dim
        # Layer order is unchanged, so parameters keep the names
        # ``main.0`` ... ``main.6`` in the state dict.
        self.main = nn.Sequential(
            nn.Linear(img_dim, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Score a batch of flattened images.

        Args:
            x: Tensor whose last dimension has size ``img_dim``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size 1, holding the Sigmoid output per row.
        """
        return self.main(x)

# Preprocessing pipeline: convert each image to a tensor, then normalize the
# single channel with mean 0.5 and std 0.5.
# NOTE(review): presumably chosen so real pixels share the [-1, 1] range of
# the generator's Tanh output; this assumes ToTensor yields values in [0, 1]
# -- confirm against the torchvision docs.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=(0.5,), std=(0.5,))]
)

# MNIST training split, stored under ./data (download requested).
dataset = datasets.MNIST(
    './data',
    train=True,
    download=True,
    transform=transform,
)

# Shuffled mini-batches of 64 samples.
dataloader = DataLoader(dataset=dataset, batch_size=64, shuffle=True)

# Build both networks with their default layer sizes (generator first, so
# parameter initialization consumes the RNG in the same order as before).
gen, disc = Generator(), Discriminator()

# Binary cross-entropy on the discriminator's Sigmoid scores.
criterion = nn.BCELoss()

# One Adam optimizer per network, both with a learning rate of 0.0002.
optimizer_gen, optimizer_disc = (
    optim.Adam(network.parameters(), lr=0.0002) for network in (gen, disc)
)

# Training loop
num_epochs = 200
latent_dim = 100  # width of the noise vectors fed to the generator

for epoch in range(num_epochs):
    for real_images, _ in dataloader:
        # Flatten every image to one row; labels from the dataset are unused.
        batch_size = real_images.size(0)
        real_images = real_images.view(batch_size, -1)

        # Targets are sized from the actual batch rather than a fixed constant.
        real_labels = torch.ones(batch_size, 1)
        fake_labels = torch.zeros(batch_size, 1)

        # Train the Discriminator: real images should score 1 ...
        disc.zero_grad()
        outputs = disc(real_images)
        d_loss_real = criterion(outputs, real_labels)
        d_loss_real.backward()

        # ... and generated images should score 0. detach() keeps this
        # backward pass from reaching the generator's parameters.
        z = torch.randn(batch_size, latent_dim)
        fake_images = gen(z)
        outputs = disc(fake_images.detach())
        d_loss_fake = criterion(outputs, fake_labels)
        d_loss_fake.backward()

        # Both backward calls accumulated into the same gradients; the sum is
        # kept only for reporting.
        d_loss = d_loss_real + d_loss_fake
        optimizer_disc.step()

        # Train the Generator: it is rewarded when the (just updated)
        # discriminator labels its samples as real.
        gen.zero_grad()
        outputs = disc(fake_images)
        g_loss = criterion(outputs, real_labels)
        g_loss.backward()
        optimizer_gen.step()

    # Report the losses of the last batch; epochs are shown 1-based so the
    # final line reads [num_epochs/num_epochs].
    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss D: {d_loss.item()}, Loss G: {g_loss.item()}')

# Generate and visualize some fake images after training
import matplotlib.pyplot as plt

# Sampling only: gradients are not needed here, so no autograd graph is built
# and the results can be converted to NumPy directly.
with torch.no_grad():
    z = torch.randn(64, latent_dim)
    fake_images = gen(z)
fake_images = fake_images.view(fake_images.size(0), 28, 28)

# Show the first eight samples side by side; zip stops once the eight axes
# are used up.
fig, axes = plt.subplots(1, 8, figsize=(20, 2))
for ax, image in zip(axes, fake_images):
    ax.imshow(image.numpy(), cmap='gray')
    ax.axis('off')
plt.show()