In [39]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split, Dataset
import torchvision.datasets as datasets
from torchvision.datasets import MNIST, CIFAR10
import torch.optim as optim

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter  

import warnings
warnings.filterwarnings("ignore")

In [40]:
# Runtime / training configuration.
device = "cuda" if torch.cuda.is_available() else "cpu"
lr = 3e-4
z_dim = 784                 # length of the latent noise vector fed to the Generator
image_dim = 28 * 28 * 1     # 784 -- flattened size of one MNIST image
batch_size = 32
num_epochs = 50

# Seed torch's RNG (CPU and all CUDA devices) so weight init, sampled noise and
# DataLoader shuffling repeat across Restart & Run All.
# NOTE: cuDNN kernels may still introduce small run-to-run differences on GPU.
seed = 42
torch.manual_seed(seed);

In [41]:
# dataset module

class MNISTDataset(Dataset):
    '''
    Unlabelled MNIST image dataset for GAN training.

    Wraps torchvision's MNIST dataset (downloading it into ``root`` when requested)
    and yields only the transformed image; the digit label is discarded. No
    train/validation splitting is performed here.

    Args:
        root: directory where MNIST is stored / downloaded to.
        download: download the data if it is not already present under ``root``.
        transform: callable applied to each image. When None, defaults to
            ToTensor() followed by Normalize(mean=0.1307, std=0.3081).
        train: load the training split (True, the default) or the test split (False).
    '''
    def __init__(self, root = './data', download = True, transform = None, train = True):
        # download / load the requested MNIST split
        self.mnist = MNIST(root = root, train = train, download = download)

        # default transformation if no specific transformation is provided
        if transform is None:
            self.transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ])
        else:
            self.transform = transform

        # Identity index map over the underlying dataset; presumably kept as a hook
        # for restricting to a subset later (nothing here modifies it).
        self.indices = list(range(len(self.mnist)))

    def __len__(self):
        '''Number of images exposed by this dataset.'''
        return len(self.indices)

    def __getitem__(self, idx):
        '''Return the transformed image at position ``idx``; the label is dropped.'''
        img, _ = self.mnist[self.indices[idx]]

        if self.transform is not None:
            img = self.transform(img)

        return img

    def get_dataloader(self, batch_size = batch_size, shuffle = True):
        '''
        Wrap this dataset in a DataLoader.

        Note: the ``batch_size`` default is the config-cell value captured when
        this class cell was executed, not the value at call time.
        '''
        return DataLoader(self, batch_size = batch_size, shuffle = shuffle)


In [42]:
# Full (unsplit) MNIST training images plus a shuffled loader over them.
train_dataset = MNISTDataset(root='./data', download=True)
train_dataloader = train_dataset.get_dataloader(shuffle=True)


In [43]:
# discriminator class
class Discriminator(nn.Module):
    '''
    Convolutional discriminator that maps an image batch to one probability per image.

    Input:  (N, in_channels, 28, 28). The final linear layer is sized for the
            256 x 3 x 3 map that three k=4/s=2/p=1 convolutions produce from 28x28.
    Output: (N, 1) values in [0, 1] (sigmoid applied).
    '''
    def __init__(self, in_channels = 1):
        super().__init__()
        self.in_channels = in_channels

        # Spatial trace: 28x28 -> 14x14 -> 7x7 -> 3x3
        down = dict(kernel_size=4, stride=2, padding=1)
        self.conv1 = nn.Conv2d(in_channels, 64, **down)
        self.conv2 = nn.Conv2d(64, 128, **down)
        self.conv3 = nn.Conv2d(128, 256, **down)

        # Batch norm follows conv2 and conv3 only; conv1 feeds the activation directly.
        self.bn1 = nn.BatchNorm2d(128)
        self.bn2 = nn.BatchNorm2d(256)

        self.fc = nn.Linear(256 * 3 * 3, 1)

    def forward(self, x):
        # Each stage: convolution, optional batch norm, then LeakyReLU(0.2).
        stages = ((self.conv1, None), (self.conv2, self.bn1), (self.conv3, self.bn2))
        for conv, norm in stages:
            x = conv(x)
            if norm is not None:
                x = norm(x)
            x = F.leaky_relu(x, 0.2, inplace=True)
        # Flatten per sample before the linear scoring head.
        logits = self.fc(x.view(x.size(0), -1))
        return torch.sigmoid(logits)


In [44]:
class Generator(nn.Module):
    '''
    Maps latent noise vectors to single-channel 28x28 images.

    Shapes (N = batch size):
        input:  (N, z_dim)
        output: (N, 1, 28, 28), raw unbounded values (no output activation)
    '''
    def __init__(self, z_dim):
        super().__init__()
        # Project the latent vector and reshape it into a 64-channel 7x7 map.
        project = [
            nn.Linear(z_dim, 64 * 7 * 7),
            nn.ReLU(),
            nn.Unflatten(1, (64, 7, 7)),
        ]
        # Unpadded transposed convs: out = (in - 1) * 2 + 4
        upsample = [
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2),  # 7x7 -> 16x16
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2),  # 16x16 -> 34x34
            nn.ReLU(),
        ]
        self.gen = nn.Sequential(*project, *upsample)
        # Unpadded 7x7 conv trims 34x34 -> 28x28 and collapses to one channel.
        self.conv = nn.Conv2d(16, 1, kernel_size=7)

    def forward(self, x):
        features = self.gen(x)
        return self.conv(features)

In [None]:
# Build models, optimizers and TensorBoard writers, then train the GAN.
disc = Discriminator().to(device)
gen = Generator(z_dim).to(device)

# Fixed latent vectors so the logged fake-image grids are comparable across epochs.
fixed_noise = torch.randn((batch_size, z_dim)).to(device)

opt_disc = optim.Adam(disc.parameters(), lr=lr)
opt_gen = optim.Adam(gen.parameters(), lr=lr)
# The Discriminator already applies a sigmoid, so BCELoss works on probabilities.
criterion = nn.BCELoss()

writer_fake = SummaryWriter("logs/fake")
writer_real = SummaryWriter("logs/real")
step = 0

try:
    for epoch in range(num_epochs):
        for batch_idx, real in enumerate(tqdm(train_dataloader)):
            real = real.to(device)
            # Local name on purpose: do not overwrite the global `batch_size`
            # config (hidden state on re-run); the last batch may also be smaller.
            cur_batch_size = real.shape[0]

            ### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
            noise = torch.randn(cur_batch_size, z_dim).to(device)
            fake = gen(noise)
            disc_real = disc(real).view(-1)
            lossD_real = criterion(disc_real, torch.ones_like(disc_real))
            # detach(): the discriminator update must not backprop into the
            # generator. This keeps the generator graph intact for the step
            # below without retain_graph=True.
            disc_fake = disc(fake.detach()).view(-1)
            lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
            lossD = (lossD_real + lossD_fake) / 2
            disc.zero_grad()
            lossD.backward()
            opt_disc.step()

            ### Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z)))
            # where the second option of maximizing doesn't suffer from
            # saturating gradients
            output = disc(fake).view(-1)
            lossG = criterion(output, torch.ones_like(output))
            gen.zero_grad()
            lossG.backward()
            opt_gen.step()

            # Once per epoch: report losses and log image grids to TensorBoard.
            if batch_idx == 0:
                print(
                    f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(train_dataloader)} "
                    f"Loss D: {lossD.item():.4f}, loss G: {lossG.item():.4f}"
                )

                with torch.no_grad():
                    fake_fixed = gen(fixed_noise)
                    img_grid_fake = torchvision.utils.make_grid(fake_fixed, normalize=True)
                    img_grid_real = torchvision.utils.make_grid(real, normalize=True)

                    writer_fake.add_image(
                        "Mnist Fake Images", img_grid_fake, global_step=step
                    )
                    writer_real.add_image(
                        "Mnist Real Images", img_grid_real, global_step=step
                    )
                step += 1
finally:
    # Flush pending events and release file handles, even if training is interrupted.
    writer_fake.close()
    writer_real.close()

In [None]:
# TODO: GAN wrapper class (placeholder — not implemented yet)


In [None]:
# TODO: training loop (placeholder — training currently runs in the cell that builds disc/gen)

In [None]:
# TODO: inference block (placeholder — e.g. sample images from the trained generator)