In [2]:
import torch                                     # - IMPORT PYTORCH LIBRARY
import torch.nn as nn                            # - IMPORT NEURAL NETWORK MODULE
import torch.optim as optim                      # - IMPORT OPTIMIZATION ALGORITHMS
import torchvision                               # - IMPORT COMPUTER VISION LIBRARY
import torchvision.datasets as datasets          # - IMPORT DATASET UTILITIES
from torch.utils.data import DataLoader          # - IMPORT DATA LOADING TOOL
import torchvision.transforms as transforms      # - IMPORT DATA TRANSFORMATION TOOLS
from torch.utils.tensorboard import SummaryWriter # - IMPORT VISUALIZATION TOOL


In [3]:
class Discriminator(nn.Module):
    def __init__(self, img_dim):
        super().__init__()
        self.disc = nn.Sequential(
            nn.Linear(img_dim,128),      # - INPUT LAYER: TRANSFORMS IMAGE TO HIDDEN LAYER
            nn.LeakyReLU(0.1),           # - ACTIVATION: ALLOWS SMALL NEGATIVE VALUES
            nn.Linear(128,1),            # - OUTPUT LAYER: PRODUCES SINGLE SCORE
            nn.Sigmoid(),                # - ACTIVATION: SQUASHES OUTPUT TO 0-1 RANGE
        )

    def forward(self, x):
        return self.disc(x)    # - FORWARD PASS: PROCESSES INPUT THROUGH NETWORK

class Generator(nn.Module):
    def __init__(self, z_dim, img_dim):
        super().__init__()
        self.gen = nn.Sequential(
            nn.Linear(z_dim, 256),       # - INPUT LAYER: TRANSFORMS NOISE TO HIDDEN LAYER
            nn.LeakyReLU(0.1),           # - ACTIVATION: ALLOWS SMALL NEGATIVE VALUES
            nn.Linear(256, img_dim),     # - OUTPUT LAYER: PRODUCES FAKE IMAGE
            nn.Tanh(),                   # - ACTIVATION: SQUASHES OUTPUT TO -1 TO 1 RANGE
        )

    def forward(self, x):
        return self.gen(x)     # - FORWARD PASS: GENERATES FAKE IMAGE FROM INPUT NOISE

In [6]:
# HYPER PARAMETERS
device = "cuda" if torch.cuda.is_available() else "cpu" # - SET DEVICE TO GPU IF AVAILABLE
lr = 3e-4                                               # - LEARNING RATE FOR OPTIMIZATION
z_dim = 64                                              # - DIMENSION OF NOISE INPUT
image_dim = 28 * 28 * 1                                 # - FLATTENED MNIST IMAGE DIMENSION
batch_size = 32                                         # - NUMBER OF SAMPLES PER BATCH
num_epochs = 50                                         # - NUMBER OF TRAINING ITERATIONS

disc = Discriminator(image_dim).to(device)              # - CREATE AND MOVE DISCRIMINATOR TO DEVICE
gen = Generator(z_dim, image_dim).to(device)            # - CREATE AND MOVE GENERATOR TO DEVICE
fixed_noise = torch.randn((batch_size, z_dim)).to(device) # - CREATE FIXED NOISE FOR VISUALIZATION

transforms = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,),(0.3081,))] # - PREPROCESS IMAGES
)
dataset = datasets.MNIST(root="dataset/", transform=transforms, download=True) # - LOAD MNIST DATASET
loader = DataLoader(dataset,batch_size=batch_size, shuffle=True) # - CREATE DATA LOADER

opt_disc = optim.Adam(disc.parameters(), lr=lr)         # - OPTIMIZER FOR DISCRIMINATOR
opt_gen = optim.Adam(gen.parameters(), lr=lr)           # - OPTIMIZER FOR GENERATOR
criterion = nn.BCELoss()                                # - BINARY CROSS ENTROPY LOSS FUNCTION

writer_fake = SummaryWriter(f"runs/GAN_MNIST/fake")     # - TENSORBOARD WRITER FOR FAKE IMAGES
writer_real = SummaryWriter(f"runs/GAN_MNIST/real")     # - TENSORBOARD WRITER FOR REAL IMAGES
steps = 0                                               # - COUNTER FOR TRAINING STEPS

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to dataset/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:15<00:00, 636368.46it/s] 


Extracting dataset/MNIST/raw/train-images-idx3-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to dataset/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 63882.32it/s]


Extracting dataset/MNIST/raw/train-labels-idx1-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to dataset/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:09<00:00, 181550.56it/s]


Extracting dataset/MNIST/raw/t10k-images-idx3-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 1400156.46it/s]

Extracting dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz to dataset/MNIST/raw






In [8]:
for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(loader):
        real = real.view(-1, 784).to(device)            # - FLATTEN AND MOVE REAL IMAGES TO DEVICE
        batch_size = real.shape[0]
        
        # TRAIN DISCRIMINATOR: MAXIMIZE log(D(x)) + log(1 - D(G(z)))
        noise = torch.randn(batch_size, z_dim).to(device) # - GENERATE RANDOM NOISE
        fake = gen(noise)                               # - GENERATE FAKE IMAGES
        disc_real = disc(real).view(-1)                 # - DISCRIMINATOR OUTPUT FOR REAL IMAGES
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))  # - LOSS FOR REAL IMAGES
        disc_fake = disc(fake).view(-1)                 # - DISCRIMINATOR OUTPUT FOR FAKE IMAGES
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake)) # - LOSS FOR FAKE IMAGES
        lossD = (lossD_real + lossD_fake) / 2           # - TOTAL DISCRIMINATOR LOSS
        disc.zero_grad()                                # - RESET DISCRIMINATOR GRADIENTS
        lossD.backward(retain_graph=True)               # - BACKPROPAGATE DISCRIMINATOR LOSS
        opt_disc.step()                                 # - UPDATE DISCRIMINATOR WEIGHTS

        # TRAIN GENERATOR: min log(1 - D(G(z))) <==> MAX log(D(G(z)))
        output = disc(fake).view(-1)                    # - DISCRIMINATOR OUTPUT FOR FAKE IMAGES
        lossG = criterion(output, torch.ones_like(output)) # - GENERATOR LOSS
        gen.zero_grad()                                 # - RESET GENERATOR GRADIENTS
        lossG.backward()                                # - BACKPROPAGATE GENERATOR LOSS
        opt_gen.step()                                  # - UPDATE GENERATOR WEIGHTS

        if batch_idx == 0:
            print(
                f"Epoch [{epoch}/{num_epochs}] "
                f"Loss D: {lossD:.4f}, Loss G: {lossG:.4f}"
            )

            with torch.no_grad():
                fake = gen(fixed_noise).reshape(-1, 1, 28, 28) # - GENERATE FAKE IMAGES FROM FIXED NOISE
                data = real.reshape(-1, 1, 28, 28)             # - RESHAPE REAL IMAGES
                img_grid_fake = torchvision.utils.make_grid(fake, normalize=True) # - CREATE GRID OF FAKE IMAGES
                img_grid_real = torchvision.utils.make_grid(data, normalize=True) # - CREATE GRID OF REAL IMAGES

                writer_fake.add_image("Fake Images", img_grid_fake, global_step=steps) # - LOG FAKE IMAGES
                writer_real.add_image("Real Images", img_grid_real, global_step=steps) # - LOG REAL IMAGES

                steps += 1

Epoch [0/50] Loss D: 0.6733, Loss G: 0.7109
Epoch [1/50] Loss D: 0.2048, Loss G: 1.9953
Epoch [2/50] Loss D: 0.1054, Loss G: 2.9619
Epoch [3/50] Loss D: 0.0380, Loss G: 3.7025
Epoch [4/50] Loss D: 0.1167, Loss G: 4.6240
Epoch [5/50] Loss D: 0.0120, Loss G: 4.7344
Epoch [6/50] Loss D: 0.0121, Loss G: 5.0336
Epoch [7/50] Loss D: 0.0411, Loss G: 4.7770
Epoch [8/50] Loss D: 0.0186, Loss G: 5.4537
Epoch [9/50] Loss D: 0.1231, Loss G: 4.1787
Epoch [10/50] Loss D: 0.0102, Loss G: 6.5711
Epoch [11/50] Loss D: 0.0070, Loss G: 5.5123
Epoch [12/50] Loss D: 0.0277, Loss G: 6.9773
Epoch [13/50] Loss D: 0.1368, Loss G: 5.9380
Epoch [14/50] Loss D: 0.0093, Loss G: 5.7800
Epoch [15/50] Loss D: 0.0108, Loss G: 5.6349
Epoch [16/50] Loss D: 0.0158, Loss G: 6.5621
Epoch [17/50] Loss D: 0.0188, Loss G: 4.8580
Epoch [18/50] Loss D: 0.0120, Loss G: 5.4685
Epoch [19/50] Loss D: 0.0217, Loss G: 6.3864
Epoch [20/50] Loss D: 0.0027, Loss G: 6.7986
Epoch [21/50] Loss D: 0.0052, Loss G: 5.9710
Epoch [22/50] Loss D