# In[ ]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt

# Custom GAN generator and discriminator
class Generator(nn.Module):
    """Fully connected GAN generator that maps noise vectors to flat images.

    Architecture: Linear(latent_dim -> 128) -> ReLU -> Linear(128 -> image_size)
    -> Sigmoid, so every output element lies in [0, 1]. Real images shown to
    the discriminator should be scaled to the same range.

    Args:
        latent_dim: Length of each input noise vector.
        image_size: Number of output features (flattened pixel count).
    """

    def __init__(self, latent_dim, image_size):
        super().__init__()
        hidden_units = 128
        layers = [
            nn.Linear(latent_dim, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, image_size),
            nn.Sigmoid(),  # squashes pixel values into [0, 1]
        ]
        # Kept under the attribute name `model` so state_dict keys stay
        # "model.0.*" / "model.2.*".
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Map noise of shape (..., latent_dim) to images of shape (..., image_size)."""
        return self.model(z)

class Discriminator(nn.Module):
    """Fully connected GAN discriminator that scores flat images.

    Architecture: Linear(image_size -> 128) -> LeakyReLU(0.2) -> Linear(128 -> 1)
    -> Sigmoid. The single output per sample is a probability-like score in
    [0, 1], suitable as input to ``nn.BCELoss``.

    Args:
        image_size: Number of input features (flattened pixel count).
    """

    def __init__(self, image_size):
        super().__init__()
        hidden_units = 128
        layers = [
            nn.Linear(image_size, hidden_units),
            nn.LeakyReLU(0.2),  # small negative slope keeps gradients alive for x < 0
            nn.Linear(hidden_units, 1),
            nn.Sigmoid(),
        ]
        # Kept under the attribute name `model` so state_dict keys stay
        # "model.0.*" / "model.2.*".
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Map flat images of shape (..., image_size) to scores of shape (..., 1)."""
        return self.model(img)

# Custom GAN loss function for the generator
def custom_generator_loss(discriminator_output_fake, *, eps=1e-8):
    """Return the minimax generator loss ``mean(log(1 - D(G(z))))``.

    The generator is trained to minimize this value. ``1 - D(G(z))`` is
    clamped to at least ``eps`` before taking the log, so a discriminator
    output of exactly 1.0 (reachable when its Sigmoid saturates in float32)
    yields a large but finite loss instead of ``-inf`` with inf/NaN gradients.

    Args:
        discriminator_output_fake: Discriminator scores for generated samples,
            expected to lie in [0, 1].
        eps: Positive lower bound applied to ``1 - D(G(z))``.

    Returns:
        A scalar tensor.
    """
    one_minus_d = torch.clamp(1 - discriminator_output_fake, min=eps)
    return torch.mean(torch.log(one_minus_d))

# ---- Hyperparameters -------------------------------------------------------
latent_dim = 100        # length of each noise vector fed to the generator
image_size = 28 * 28    # flattened pixel count of one 28x28 image
batch_size = 64         # images per DataLoader batch
epochs = 20             # full passes over the training set
lr = 2e-4               # Adam learning rate, shared by both networks
sample_interval = 10    # show a sample grid every `sample_interval` epochs

# ---- Networks and their optimizers -----------------------------------------
# The generator is constructed first, so under a fixed torch seed the weight
# initialization of the two networks always consumes random numbers in the
# same order.
generator = Generator(latent_dim, image_size)
discriminator = Discriminator(image_size)
optimizer_G = optim.Adam(params=generator.parameters(), lr=lr)
optimizer_D = optim.Adam(params=discriminator.parameters(), lr=lr)

# Data pipeline for MNIST.
# ToTensor() maps uint8 pixels to floats in [0, 1], the same range produced by
# the generator's final Sigmoid. Do NOT add Normalize((0.5,), (0.5,)) here
# unless the generator's output activation is switched to Tanh: Normalize
# rescales real images to [-1, 1], which lets the discriminator separate real
# from generated images by value range alone.
transform = transforms.Compose([transforms.ToTensor()])

# Load the MNIST training split. download=False expects the files to already
# exist under ./data; set it to True to fetch them on first run.
mnist_data = datasets.MNIST(root='./data', train=True, transform=transform, download=False)
data_loader = DataLoader(mnist_data, batch_size=batch_size, shuffle=True)

# Training loop.
# BCELoss holds no state, so build it once instead of once per use.
adversarial_loss = nn.BCELoss()

for epoch in range(epochs):
    for batch_idx, (real_images, _) in enumerate(data_loader):
        # Use the size of THIS batch. The DataLoader is built without
        # drop_last, so the final batch of an epoch is smaller whenever the
        # dataset size is not a multiple of batch_size; a hard-coded size
        # would make `view` produce the wrong shape and mismatch the labels.
        current_batch_size = real_images.size(0)

        # Adversarial ground truths: 1 for real images, 0 for generated ones.
        valid = torch.ones((current_batch_size, 1))
        fake = torch.zeros((current_batch_size, 1))

        # Sample noise as generator input. The generator's parameters do not
        # change until the generator step below, so a single forward pass
        # serves both the discriminator step (detached) and the generator step.
        z = torch.randn((current_batch_size, latent_dim))
        generated = generator(z)

        # Train the discriminator on real vs. detached generated images.
        optimizer_D.zero_grad()
        real_flat = real_images.view(current_batch_size, -1)
        real_loss = adversarial_loss(discriminator(real_flat), valid)
        fake_loss = adversarial_loss(discriminator(generated.detach()), fake)
        d_loss = real_loss + fake_loss
        d_loss.backward()
        optimizer_D.step()

        # Train the generator against the freshly updated discriminator.
        optimizer_G.zero_grad()
        g_loss = custom_generator_loss(discriminator(generated))
        g_loss.backward()
        optimizer_G.step()

        if batch_idx % 100 == 0:
            print(f"Epoch [{epoch}/{epochs}] Batch [{batch_idx}/{len(data_loader)}] D Loss: {d_loss.item()} G Loss: {g_loss.item()}")

    # Show a 5x5 grid of samples at the configured interval, and also after
    # the last epoch so the final generator's output is always displayed.
    if epoch % sample_interval == 0 or epoch == epochs - 1:
        with torch.no_grad():  # inference only; no autograd graph needed
            z = torch.randn(25, latent_dim)
            samples = generator(z).view(-1, 28, 28).cpu().numpy()

        plt.figure(figsize=(5, 5))
        for i, sample in enumerate(samples):
            plt.subplot(5, 5, i + 1)
            plt.imshow(sample, cmap='gray')
            plt.axis('off')
        plt.show()
