In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
import matplotlib.pyplot as plt

In [None]:
# Prefer a CUDA GPU when PyTorch can see one; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

In [None]:
# Hyperparameters
batch_size = 32
lr = 0.0002      # learning rate shared by the generator and discriminator optimizers
epochs = 100
noise_dim = 100  # length of the latent vector z fed to the generator
seed = 42        # fixes weight init, sampled noise and DataLoader shuffling

# Seed torch's global RNG (CPU and CUDA) so a Restart & Run All reproduces results.
torch.manual_seed(seed)

In [None]:
# Dataset loading and transformation
transform = transforms.Compose([
    # A (h, w) tuple forces exactly 28x28. An int would only match the shorter
    # edge, and the Discriminator's final 4x4 conv needs 28x28 input.
    transforms.Resize((28, 28)),
    transforms.ToTensor(),               # PIL image -> float tensor in [0, 1]
    transforms.Normalize([0.5], [0.5]),  # (x - 0.5) / 0.5 -> [-1, 1], matching the generator's Tanh
])

In [None]:
# Fetch (downloading on first run) the MNIST training split and wrap it in a
# loader that reshuffles every epoch.
mnist_data = datasets.MNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)
data_loader = DataLoader(mnist_data, batch_size=batch_size, shuffle=True)

In [None]:
# Generator Model
class Generator(nn.Module):
    """Map a noise vector to a single-channel 28x28 image with values in [-1, 1].

    A linear layer produces a 128x7x7 seed feature map. Two stride-2
    transposed convolutions then upsample it 7 -> 14 -> 28, and a Tanh
    bounds the output.

    Args:
        noise_dim: length of the latent vector expected by ``forward``.
    """

    def __init__(self, noise_dim):
        super().__init__()
        # Spatial size of the seed map (28 / 2 / 2).
        self.init_size = 7
        seed_features = 128 * self.init_size * self.init_size
        self.fc = nn.Linear(noise_dim, seed_features)

        layers = [
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),  # -> (64, 14, 14)
            nn.InstanceNorm2d(64, affine=True),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1, bias=False),    # -> (1, 28, 28)
            nn.Tanh(),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Return images of shape (batch, 1, 28, 28) for noise ``x``.

        ``x`` may carry trailing singleton dims (e.g. (batch, noise_dim, 1, 1)).
        It is flattened per sample before the linear layer.
        """
        n = x.size(0)
        z = x.view(n, -1)
        seed_map = self.fc(z).view(n, 128, self.init_size, self.init_size)
        return self.block(seed_map)

In [None]:
# Discriminator Model
class Discriminator(nn.Module):
    """Score 1x28x28 images with a probability in [0, 1] of being real.

    Three stride-2 convolutions downsample 28 -> 14 -> 7 -> 4. A final 4x4
    convolution collapses the map to one logit, and a Sigmoid turns it into
    a probability.
    """

    def __init__(self):
        super().__init__()

        def down(in_ch, out_ch, kernel, normalize):
            # Stride-2, padding-1 conv, then optional BatchNorm, then LeakyReLU(0.2).
            stage = [nn.Conv2d(in_ch, out_ch, kernel, 2, 1, bias=False)]
            if normalize:
                stage.append(nn.BatchNorm2d(out_ch))
            stage.append(nn.LeakyReLU(0.2, inplace=True))
            return stage

        self.model = nn.Sequential(
            *down(1, 64, 4, normalize=False),       # (1, 28, 28)  -> (64, 14, 14)
            *down(64, 128, 4, normalize=True),      # (64, 14, 14) -> (128, 7, 7)
            *down(128, 256, 3, normalize=True),     # (128, 7, 7)  -> (256, 4, 4)
            nn.Conv2d(256, 1, 4, 1, 0, bias=False),  # (256, 4, 4)  -> (1, 1, 1)
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return a (batch, 1) tensor of real-vs-fake probabilities."""
        scores = self.model(x)
        return scores.view(-1, 1)

In [None]:
# Build both networks and move their parameters onto the selected device.
generator = Generator(noise_dim=noise_dim).to(device)
discriminator = Discriminator().to(device)

In [None]:
# One Adam optimizer per network, both with the shared learning rate and beta1 = 0.5.
optim_G, optim_D = (
    optim.Adam(net.parameters(), lr=lr, betas=(0.5, 0.999))
    for net in (generator, discriminator)
)

In [None]:
# Binary cross-entropy on the discriminator's Sigmoid probabilities, averaged over the batch.
criterion = nn.BCELoss(reduction='mean')

In [None]:
# Training function
def train_GAN(generator, discriminator, data_loader, epochs,
              optimizer_G=None, optimizer_D=None, loss_fn=None,
              latent_dim=None, device=None, vis_every=10):
    """Train ``generator`` and ``discriminator`` with alternating updates.

    Each batch takes two steps:

    1. A discriminator step: real images are labelled 1 and detached fakes
       are labelled 0.
    2. A generator step: the generator is pushed to make the discriminator
       label those same fakes as 1.

    The losses of the last batch are printed every epoch.

    Args:
        generator: module mapping a ``(batch, latent_dim)`` noise tensor to images.
        discriminator: module mapping images to ``(batch, 1)`` probabilities.
        data_loader: iterable of ``(images, labels)`` pairs; labels are ignored.
        epochs: number of passes over ``data_loader``.
        optimizer_G: optimizer for the generator; defaults to the global ``optim_G``.
        optimizer_D: optimizer for the discriminator; defaults to the global ``optim_D``.
        loss_fn: loss taking ``(probabilities, targets)``; defaults to the global ``criterion``.
        latent_dim: noise vector length; defaults to the global ``noise_dim``.
        device: where batches and noise are placed. Defaults to the device of
            the generator's parameters; both models are assumed to live there.
        vis_every: call ``visualize_comparison`` every this many epochs;
            ``0`` or ``None`` disables plotting.

    Raises:
        ValueError: if ``data_loader`` yields no batches, rather than failing
            with an ``UnboundLocalError`` on the loss variables.
    """
    # Resolve defaults lazily, so the notebook globals are only needed when the
    # caller does not pass its own objects.
    opt_G = optim_G if optimizer_G is None else optimizer_G
    opt_D = optim_D if optimizer_D is None else optimizer_D
    bce = criterion if loss_fn is None else loss_fn
    z_dim = noise_dim if latent_dim is None else latent_dim
    dev = next(generator.parameters()).device if device is None else device

    for epoch in range(epochs):
        D_loss = G_loss = None
        real_images = fake_images = None

        for real_images, _ in data_loader:
            n = real_images.size(0)  # the final batch may be smaller than batch_size
            real_images = real_images.to(dev)
            real_labels = torch.ones(n, 1, device=dev)
            fake_labels = torch.zeros(n, 1, device=dev)

            # Discriminator step: push D(real) -> 1 and D(G(z)) -> 0.
            opt_D.zero_grad()
            real_loss = bce(discriminator(real_images), real_labels)
            z = torch.randn(n, z_dim, device=dev)
            fake_images = generator(z)
            # Detach so this step does not backpropagate into the generator.
            fake_loss = bce(discriminator(fake_images.detach()), fake_labels)
            D_loss = real_loss + fake_loss
            D_loss.backward()
            opt_D.step()

            # Generator step: make the freshly updated D label the fakes as real.
            # Gradients also reach D's parameters here; they are cleared by the
            # next opt_D.zero_grad() before D is stepped again.
            opt_G.zero_grad()
            G_loss = bce(discriminator(fake_images), real_labels)
            G_loss.backward()
            opt_G.step()

        if D_loss is None:
            raise ValueError('data_loader yielded no batches; nothing to train on')

        print(f'Epoch [{epoch+1}/{epochs}] | D Loss: {D_loss.item():.4f} | G Loss: {G_loss.item():.4f}')

        if vis_every and (epoch + 1) % vis_every == 0:
            visualize_comparison(real_images, fake_images)

In [None]:
# Function to display real images (top row) above generated images (bottom row)
def visualize_comparison(real_images, fake_images, n=5):
    """Plot the first ``n`` real images above the first ``n`` generated images.

    Args:
        real_images: batch of real images normalized to [-1, 1]; assumed to be
            (N, 1, H, W) or (N, H, W) — TODO confirm for other datasets.
        fake_images: generator output in [-1, 1], same layout as ``real_images``.
        n: maximum number of columns. It is clipped to the smaller batch so a
            short final batch does not raise ``IndexError``.
    """
    # Detach and move to CPU before any arithmetic so no autograd graph is built.
    real = (real_images.detach().cpu() + 1) / 2  # rescale [-1, 1] -> [0, 1]
    fake = (fake_images.detach().cpu() + 1) / 2
    n = min(n, real.size(0), fake.size(0))
    if n == 0:
        return

    # squeeze=False keeps axs 2-D even when n == 1.
    fig, axs = plt.subplots(2, n, figsize=(2 * n, 4), squeeze=False)
    for col in range(n):
        for row, (imgs, title) in enumerate(((real, 'Real'), (fake, 'Fake'))):
            ax = axs[row, col]
            # Fixed vmin/vmax so imshow does not contrast-stretch each panel
            # independently; otherwise faint fakes would look as crisp as real digits.
            ax.imshow(imgs[col].squeeze().numpy(), cmap='gray', vmin=0, vmax=1)
            ax.set_title(title)
            ax.axis('off')

    plt.show()

In [None]:
# Run the full training loop; sample grids are plotted periodically during training.
train_GAN(generator, discriminator, data_loader, epochs=epochs)