In [82]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
import torchvision.models as models
import torchvision.transforms as transforms

import os
import random
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

In [58]:
# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Reproducibility: seed every RNG this notebook touches (Python, NumPy, PyTorch).
# torch.manual_seed seeds all devices, so the DataLoader shuffle, weight init and latent noise repeat.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

transform = transforms.Compose([
    transforms.ToTensor(),                        # PIL uint8 [0, 255] -> float tensor in [0, 1]
    transforms.Normalize(mean=[0.5], std=[0.5]),  # (x - 0.5) / 0.5 -> [-1, 1], matching the generator's tanh range
])

# Load Fashion-MNIST. download=True on BOTH splits: the train split is loaded first,
# so without it a fresh environment fails before the test split can fetch the files.
train_dataset = torchvision.datasets.FashionMNIST(root="./data", train=True, download=True, transform=transform)
train_loader  = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True, drop_last=True)

test_dataset = torchvision.datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)

In [None]:
# Pull the first (shuffled) batch from the training loader: pixel tensors plus class labels.
numbers, labels = next(iter(train_loader))

# Move the batch to NumPy and display its first image, with the channel axis squeezed away.
# Pixel values are in [-1, 1] because of the Normalize transform.
images = numbers.numpy()
img = images[0].squeeze()
fig, ax = plt.subplots(figsize=(1, 1))
ax.imshow(img, cmap='gray')

In [None]:
# Visualize a whole batch as an 8-column grid.
# Pixels are in [-1, 1] after Normalize. normalize=True rescales each image to [0, 1] so imshow
# does not clip the negative half; scale_each only takes effect when normalize=True.
grid = torchvision.utils.make_grid(numbers, nrow=8, padding=0, normalize=True, scale_each=True)
fig = plt.figure(figsize=(16, 4))
plt.imshow(grid.cpu().permute(1, 2, 0))  # CHW -> HWC for matplotlib
plt.axis('off')
plt.show()

Define the Model: A GAN consists of two adversarial networks, a discriminator and a generator.

Discriminator: The discriminator network is a fairly standard fully connected (MLP) classifier. To make this network a universal function approximator, we need at least one hidden layer. These hidden layers should apply a Leaky ReLU activation to their outputs, so that negative inputs still pass a small gradient. A final sigmoid turns the output into the probability that the input image is real.

In [84]:
class Discriminator(nn.Module):
    """MLP discriminator mapping an image to the probability that it is real.

    Layers narrow from ``hidden_dim*4`` to ``hidden_dim*2`` to ``hidden_dim``.
    Each hidden layer is followed by LeakyReLU and Dropout, and a final Sigmoid
    squashes the output into (0, 1).

    Args:
        input_size: number of features per (flattened) image, e.g. 28*28 = 784.
        hidden_dim: width of the LAST hidden layer.
        output_size: number of outputs (1 for a single real/fake probability).
        negative_slope: LeakyReLU slope for negative inputs (default 0.2).
        dropout: dropout probability after each hidden layer (default 0.3).
    """

    def __init__(self, input_size, hidden_dim, output_size, negative_slope=0.2, dropout=0.3):
        super(Discriminator, self).__init__()
        self.input_size = input_size

        # NOTE: the attribute name (typo included) and the layer order are kept unchanged,
        # so state_dict keys stay compatible with any previously saved checkpoints.
        # LeakyReLU fixes the old nn.ReLU(0.2), which passed 0.2 as `inplace`
        # and silently produced a plain ReLU.
        self.disciminator = nn.Sequential(
            nn.Linear(input_size, hidden_dim * 4),
            nn.LeakyReLU(negative_slope),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 4, hidden_dim * 2),
            nn.LeakyReLU(negative_slope),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.LeakyReLU(negative_slope),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_size),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return real-probabilities of shape (N, output_size).

        ``x`` may be image-shaped (e.g. (N, 1, 28, 28)) or already flat (N, input_size).
        """
        # flatten image; a no-op for input that is already (N, input_size)
        x = x.reshape(-1, self.input_size)
        return self.disciminator(x)

Generator: The generator network is almost the same as the discriminator network (mirrored, so its hidden layers widen instead of narrow), except that we apply a tanh activation to the output layer. Tanh scales the output to lie between -1 and 1 instead of between 0 and 1. We want these outputs to be comparable to the real input pixel values, which the `Normalize(mean=[0.5], std=[0.5])` transform maps to the range [-1, 1].

In [74]:
class Generator(nn.Module):
    """MLP generator mapping a latent vector to a flattened image in [-1, 1].

    Layers widen from ``hidden_dim`` to ``hidden_dim*2`` to ``hidden_dim*4``.
    Each hidden layer is followed by LeakyReLU and Dropout, and a final Tanh
    keeps outputs in [-1, 1], the same range as the normalized real images.

    Args:
        input_size: latent vector size (z dimension).
        hidden_dim: width of the FIRST hidden layer.
        output_size: number of output pixels, e.g. 28*28 = 784.
        negative_slope: LeakyReLU slope for negative inputs (default 0.2).
        dropout: dropout probability after each hidden layer (default 0.3).
    """

    def __init__(self, input_size, hidden_dim, output_size, negative_slope=0.2, dropout=0.3):
        super(Generator, self).__init__()

        # Layer order is unchanged, so state_dict keys stay compatible.
        # LeakyReLU fixes the old nn.ReLU(0.2), which passed 0.2 as `inplace`
        # and silently produced a plain ReLU.
        self.generator = nn.Sequential(
            nn.Linear(input_size, hidden_dim),
            nn.LeakyReLU(negative_slope),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim * 2),
            nn.LeakyReLU(negative_slope),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 2, hidden_dim * 4),
            nn.LeakyReLU(negative_slope),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 4, output_size),
            nn.Tanh(),
        )

    def forward(self, x):
        """Map latent vectors (N, input_size) to flat images (N, output_size) in [-1, 1]."""
        return self.generator(x)

In [83]:
# Discriminator hyperparams
input_size    = 28 * 28  # Size of the flattened input image fed to the discriminator (784)
d_hidden_size = 32       # Width of the LAST hidden layer in the discriminator
d_output_size = 1        # Size of discriminator output (single real/fake probability)

# Generator hyperparams
z_size        = 100         # Size of latent vector given to the generator
g_hidden_size = 32          # Width of the FIRST hidden layer in the generator
g_output_size = input_size  # Size of GENERATOR output: a flattened 28x28 image, same shape the discriminator consumes

In [76]:
# Instantiate each network and place it on the selected device in one expression.
# D is constructed before G, so weight initialisation consumes the RNG in the same order.
D = Discriminator(input_size, d_hidden_size, d_output_size).to(device)
G = Generator(z_size, g_hidden_size, g_output_size).to(device)

In [77]:
# Binary cross-entropy on the discriminator's sigmoid output is the loss for both players.
criterion = nn.BCELoss()

# One Adam optimizer per network (D first, then G), each updating only its own parameters.
d_optimizer, g_optimizer = (optim.Adam(net.parameters(), lr=1e-4) for net in (D, G))

In [78]:
# Latent-noise sampler shared by training and visualization
def generate_noise(batch_size, latent_size):
    """Draw a batch of standard-normal latent vectors for the generator.

    Args:
        batch_size: number of latent vectors (rows).
        latent_size: dimensionality of each vector (columns).

    Returns:
        A (batch_size, latent_size) tensor sampled from N(0, 1), allocated on the
        notebook-level ``device``.
    """
    shape = (batch_size, latent_size)
    return torch.randn(shape, device=device)

In [None]:
num_epochs   = 150
sample_every = 30  # show generated samples every `sample_every` epochs
n_samples    = 16  # images in the sample grid (4 x 4)

# Fixed latent vectors: rendering the SAME z at every checkpoint makes progress comparable across epochs.
fixed_noise = generate_noise(n_samples, z_size)

# Training loop
for epoch in range(num_epochs):
    # The sampling step below switches G to eval mode, so re-enable dropout for training.
    D.train()
    G.train()

    for real_images, _ in train_loader:
        # Flatten (N, 1, 28, 28) images to (N, input_size) for the MLP discriminator
        real_images = real_images.view(-1, input_size).to(device)
        batch_size_curr = real_images.size(0)

        # Targets: 1 = real, 0 = fake (allocated directly on the device)
        real_labels = torch.ones(batch_size_curr, 1, device=device)
        fake_labels = torch.zeros(batch_size_curr, 1, device=device)

        # ---- Train Discriminator: push D(real) -> 1 and D(fake) -> 0 ----
        d_loss_real = criterion(D(real_images), real_labels)

        noise = generate_noise(batch_size_curr, z_size)
        fake_images = G(noise)
        # detach(): the discriminator update must not backpropagate into G
        d_loss_fake = criterion(D(fake_images.detach()), fake_labels)

        d_loss = d_loss_real + d_loss_fake

        # D.zero_grad() also clears gradients the previous G step accumulated in D's parameters
        D.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # ---- Train Generator: we want fresh fakes to be classified as real ----
        noise = generate_noise(batch_size_curr, z_size)
        g_loss = criterion(D(G(noise)), real_labels)

        G.zero_grad()
        g_loss.backward()
        g_optimizer.step()

    # Losses reported are from the last batch of the epoch
    print(f'Epoch [{epoch+1}/{num_epochs}], d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}')

    # Generate and show fake images every `sample_every` epochs
    if (epoch + 1) % sample_every == 0:
        G.eval()  # disable dropout so samples reflect the trained generator, not a noisy training pass
        with torch.no_grad():
            samples = G(fixed_noise).view(-1, 1, 28, 28).cpu()
        # normalize=True maps the tanh output range onto [0, 1] for display
        grid = torchvision.utils.make_grid(samples, nrow=4, normalize=True)
        plt.figure(figsize=(3, 3))
        plt.imshow(grid.permute(1, 2, 0))
        plt.axis('off')
        plt.title(f'Fake images at epoch {epoch+1}')
        plt.show()