In [42]:
import os

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image

# Hyperparameters
latent_dim = 100       # length of the noise vector fed to the generator
lr = 2e-4              # learning rate shared by both Adam optimizers
batch_size = 128       # samples per mini-batch drawn by the DataLoader
image_size = 28 * 28   # number of pixels in one flattened 28x28 image
num_epochs = 100       # number of full passes over the training set

## Download Dataset

In [43]:
# Device configuration
# Use the GPU when one is available, otherwise fall back to the CPU so that
# the later `.to(device)` calls do not raise on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# MNIST dataset
# Normalize with mean 0.5 / std 0.5 maps ToTensor's [0, 1] pixel range onto
# [-1, 1], the same range the generator's final Tanh produces.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
)
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
# drop_last=True discards the final short batch, so every batch has exactly
# `batch_size` samples.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

## Models

In [44]:
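# Earlier fully-connected (MLP) GAN, kept commented out for reference only;
# the convolutional Generator/Discriminator defined below are the ones used.
# # Generator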
# class Generator(nn.Module):
#     def __init__(self):
#         super(Generator, self).__init__()
#         self.model = nn.Sequential(
#             nn.Linear(latent_dim, 256),
#             nn.BatchNorm1d(256),
#             nn.LeakyReLU(0.2),
#             nn.Linear(256, 512),
#             nn.BatchNorm1d(512),
#             nn.LeakyReLU(0.2),
#             nn.Linear(512, 1024),
#             nn.BatchNorm1d(1024),
#             nn.LeakyReLU(0.2),
#             nn.Linear(1024, image_size),
#             nn.BatchNorm1d(image_size),
#             nn.Tanh()
#         )
# 
#     def forward(self, z):
#         return self.model(z)
# 
# # Discriminator
# class Discriminator(nn.Module):
#     def __init__(self):
#         super(Discriminator, self).__init__()
#         self.model = nn.Sequential(
#             nn.Linear(image_size, 1024),
#             nn.LeakyReLU(0.2),
#             nn.Dropout(0.3),
#             nn.Linear(1024, 512),
#             nn.LeakyReLU(0.2),
#             nn.Dropout(0.3),
#             nn.Linear(512, 256),
#             nn.LeakyReLU(0.2),
#             nn.Dropout(0.3),
#             nn.Linear(256, 1),
#             nn.Sigmoid()
#         )
# 
#     def forward(self, img):
#         img_flat = img.view(img.size(0), -1)
#         return self.model(img_flat)
class Generator(nn.Module):
    """Convolutional generator mapping a latent vector to an image.

    A linear layer projects the latent vector (length ``latent_dim``, read
    from notebook scope) to a 128-channel ``init_size x init_size`` feature
    map. Two x2 upsampling stages then enlarge it (7 -> 14 -> 28) before a
    final transposed convolution reduces it to ``img_channels`` channels and
    ``Tanh`` squashes the values into [-1, 1].

    Args:
        img_channels: Number of channels in the generated image.
    """

    def __init__(self, img_channels=1):
        super().__init__()
        self.init_size = 7  # Spatial side length before any upsampling
        projected_features = 128 * self.init_size ** 2
        self.l1 = nn.Sequential(nn.Linear(latent_dim, projected_features))

        def up_stage(in_ch, out_ch):
            # One x2 upsampling stage; the 3x3 / stride 1 / padding 1
            # transposed convolution keeps the spatial size unchanged.
            return [
                nn.Upsample(scale_factor=2),
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2),
            ]

        # Layers are created in the same order as they are applied so the
        # Sequential indices (and therefore state_dict keys) stay stable.
        layers = [
            nn.BatchNorm2d(128),
            *up_stage(128, 128),
            *up_stage(128, 64),
            nn.ConvTranspose2d(64, img_channels, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Generate a batch of images from a batch of latent vectors ``z``."""
        features = self.l1(z)
        side = self.init_size
        # Reshape the flat projection into a (batch, 128, side, side) map.
        features = features.view(features.shape[0], 128, side, side)
        return self.model(features)


class Discriminator(nn.Module):
    """Convolutional discriminator scoring images as real or generated.

    Four stride-2 convolutions (each followed by LeakyReLU and Dropout)
    reduce the input to a 512-channel feature map. The final linear layer
    expects that map to be 2 x 2, which is what 28 x 28 inputs produce
    (28 -> 14 -> 7 -> 4 -> 2). A Sigmoid turns the score into a value
    in [0, 1].

    Args:
        img_channels: Number of channels in the input image.
    """

    def __init__(self, img_channels=1):
        super().__init__()
        channel_plan = [img_channels, 64, 128, 256, 512]
        conv_layers = []
        for in_ch, out_ch in zip(channel_plan[:-1], channel_plan[1:]):
            conv_layers += [
                nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1),
                nn.LeakyReLU(0.2),
                nn.Dropout(0.3),
            ]
        self.model = nn.Sequential(*conv_layers)

        # Fully connected head: one validity score per image
        self.adv_layer = nn.Sequential(nn.Linear(512 * 2 * 2, 1), nn.Sigmoid())

    def forward(self, img):
        """Return a (batch, 1) tensor of validity scores for ``img``."""
        features = self.model(img)
        flat = features.view(features.size(0), -1)  # one row per image
        return self.adv_layer(flat)
    

## Training Stage

In [45]:
# Initialize models
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# Optimizers
g_optimizer = torch.optim.Adam(generator.parameters(), lr=lr)
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr)

# Loss function
criterion = nn.BCELoss()

# save_image does not create missing directories, so make sure the sample
# folder exists before the first epoch tries to write into it.
sample_dir = './data/gan'
os.makedirs(sample_dir, exist_ok=True)

# Training
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(train_loader):
        real_images = images.to(device)
        # Use a loop-local name: assigning to `batch_size` here would silently
        # overwrite the notebook-level hyperparameter of the same name.
        cur_batch_size = real_images.size(0)
        # Create targets directly on the target device (no CPU -> device copy).
        real_labels = torch.ones(cur_batch_size, 1, device=device)
        fake_labels = torch.zeros(cur_batch_size, 1, device=device)

        # --------- Train the Discriminator --------- #
        d_optimizer.zero_grad()
        outputs = discriminator(real_images)
        d_real_loss = criterion(outputs, real_labels)
        z = torch.randn(cur_batch_size, latent_dim, device=device)
        fake_images = generator(z)
        # detach() keeps the discriminator update from backpropagating into
        # the generator.
        outputs = discriminator(fake_images.detach())
        d_fake_loss = criterion(outputs, fake_labels)
        d_loss = d_real_loss + d_fake_loss
        d_loss.backward()
        d_optimizer.step()

        # --------- Train the Generator --------- #
        g_optimizer.zero_grad()
        # The generator is rewarded when the discriminator labels its fakes
        # as real, hence the `real_labels` target.
        outputs = discriminator(fake_images)
        g_loss = criterion(outputs, real_labels)
        g_loss.backward()
        g_optimizer.step()

        if (i+1) % 400 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], D Loss: {d_loss.item()}, G Loss: {g_loss.item()}')

    # Save generated images every epoch (first 25 samples of the last batch)
    save_image(fake_images.detach()[:25], f'{sample_dir}/fake_image_{epoch+1:03d}.png', nrow=5, normalize=True)

print("Training complete.")

Epoch [1/100], Step [400/468], D Loss: 0.9690850973129272, G Loss: 1.4154253005981445
Epoch [2/100], Step [400/468], D Loss: 0.6552090644836426, G Loss: 1.7175602912902832
Epoch [3/100], Step [400/468], D Loss: 0.9909242391586304, G Loss: 1.3598992824554443
Epoch [4/100], Step [400/468], D Loss: 1.1192089319229126, G Loss: 1.3783931732177734
Epoch [5/100], Step [400/468], D Loss: 0.8978086709976196, G Loss: 1.32565438747406
Epoch [6/100], Step [400/468], D Loss: 1.1752640008926392, G Loss: 1.1328778266906738
Epoch [7/100], Step [400/468], D Loss: 1.0285788774490356, G Loss: 1.328322410583496
Epoch [8/100], Step [400/468], D Loss: 0.9722626805305481, G Loss: 1.393446683883667
Epoch [9/100], Step [400/468], D Loss: 1.0495445728302002, G Loss: 1.2878639698028564
Epoch [10/100], Step [400/468], D Loss: 0.9416558742523193, G Loss: 1.592057466506958
Epoch [11/100], Step [400/468], D Loss: 0.7443684339523315, G Loss: 1.6719391345977783
Epoch [12/100], Step [400/468], D Loss: 0.926053047180175