In [None]:
# Code adapted from: https://github.com/aladdinpersson/Machine-Learning-Collection/blob/ac5dcd03a40a08a8af7e1a67ade37f28cf88db43/ML/Pytorch/GANs/2.%20DCGAN/train.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as tfms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import os, math
import random
import numpy as np

## Declare GAN Structure

In [None]:
class Discriminator(nn.Module):
    """Fully convolutional critic that maps an image batch to one score per sample.

    All convolutions are 3x3 with stride 1 and padding 1, so the spatial size is
    preserved until the final adaptive average pool collapses it to 1x1.

    Args:
        features: Channel width of the first block; later blocks use 2x and 4x.
        device: Device on which the layer parameters are created.
        in_channels: Number of channels in the input images (default 1, which
            matches the previous hard-coded behaviour).

    The forward pass returns a tensor of shape (N, 1, 1, 1). No sigmoid is
    applied, so the output is an unbounded score.
    """

    def __init__(self, features, device='cpu', in_channels=1):
        super().__init__()

        self.block1 = nn.Sequential(
            nn.Conv2d(in_channels, features, kernel_size=3, stride=1, padding=1, device=device),
            nn.LeakyReLU(0.2)
        )

        # bias=False: the affine InstanceNorm that follows has its own shift term.
        self.block2 = nn.Sequential(
            nn.Conv2d(features, features * 2, kernel_size=3, stride=1, padding=1, bias=False, device=device),
            nn.InstanceNorm2d(features * 2, affine=True, device=device),
            nn.LeakyReLU(0.2)
        )

        self.block3 = nn.Sequential(
            nn.Conv2d(features * 2, features * 4, kernel_size=3, stride=1, padding=1, bias=False, device=device),
            nn.InstanceNorm2d(features * 4, affine=True, device=device),
            nn.LeakyReLU(0.2)
        )

        self.block4 = nn.Sequential(
            nn.Conv2d(features * 4, 1, kernel_size=3, stride=1, padding=1, device=device), # convert to single channel
            nn.AdaptiveAvgPool2d(1),    # average the score map down to a single value per sample
        )

    def forward(self, x):
        """Return the critic score for each image in ``x`` with shape (N, 1, 1, 1)."""
        y = self.block1(x)
        y = self.block2(y)
        y = self.block3(y)
        y = self.block4(y)
        return y

In [None]:
# Input channel count is configurable; the default of 1 is a single noise channel.
class Generator(nn.Module):
    """Fully convolutional generator producing a single-channel image per sample.

    All convolutions are 3x3 with stride 1 and padding 1, so the output has the
    same spatial size as the input.

    Args:
        features: Channel width of the first and third blocks (the second uses 2x).
        device: Device on which the layer parameters are created.
        in_channels: Number of channels in the generator input (default 1, which
            matches the previous hard-coded behaviour).

    The forward pass scales the sigmoid output by 255 and rounds it with a
    straight-through estimator, so values are (approximately) integers in [0, 255].
    NOTE(review): real images produced by ``ToTensor`` lie in [0, 1]; confirm that
    the 0-255 output range here is intended.
    """

    def __init__(self, features, device='cpu', in_channels=1):
        super().__init__()

        self.block1 = nn.Sequential(
            nn.Conv2d(in_channels, features, 3, 1, 1, device=device),
            nn.ReLU()
        )

        self.block2 = nn.Sequential(
            nn.Conv2d(features, features*2, 3, 1, 1, device=device),
            nn.InstanceNorm2d(features*2, affine=True, device=device),
            nn.ReLU()
        )

        self.block3 = nn.Sequential(
            nn.Conv2d(features*2, features, 3, 1, 1, device=device),
            nn.InstanceNorm2d(features, affine=True, device=device),
            nn.ReLU()
        )

        self.block4 = nn.Sequential(
            nn.Conv2d(features, 1, 3, 1, 1, device=device),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Map input ``x`` of shape (N, in_channels, H, W) to an image of shape (N, 1, H, W)."""
        y = self.block1(x)
        y = self.block2(y)
        y = self.block3(y)
        y = self.block4(y)

        # Sigmoid output is in [0, 1]; stretch to [0, 255] and snap to whole numbers.
        y = y*255
        y = self._round(y)
        return y

    def _round(self, mat):
        """Round ``mat`` in the forward pass while passing gradients through unchanged."""
        # TODO: cite something? (this function is based off of Thor's code)
        mat_hard = torch.round(mat)
        # The detached term carries no gradient, so the backward pass sees d(out)/d(mat) = 1.
        return (mat_hard - mat.detach()) + mat

## Define Essential Functions

In [None]:
def initialize_weights(model):
    """Re-initialise conv / transposed-conv / batch-norm weights from N(0, 0.02).

    Follows the initialisation described in the DCGAN paper. Layers of the
    matched types that have no weight tensor (e.g. ``BatchNorm2d(affine=False)``)
    are skipped instead of raising. Other layer types are left untouched.

    Args:
        model: Any ``nn.Module``; all of its sub-modules are visited.
    """
    target_types = (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)
    for m in model.modules():
        if not isinstance(m, target_types):
            continue
        weight = getattr(m, "weight", None)
        if weight is None:
            # Non-affine normalisation layers carry no learnable weight.
            continue
        # nn.init functions already run under no_grad, so ``.data`` is not needed.
        nn.init.normal_(weight, 0.0, 0.02)

In [None]:
def gradient_penalty(coeff, critic, real, fake, device="cpu"):
    """Compute the WGAN-GP gradient penalty ``coeff * E[(||grad critic(x_hat)||_2 - 1)^2]``.

    ``x_hat`` is a per-sample random interpolation between ``real`` and ``fake``
    with a mixing coefficient drawn uniformly from [0, 1).

    Args:
        coeff: Multiplier applied to the mean penalty (lambda in the paper).
        critic: Callable mapping a batch of inputs to critic scores.
        real: Batch of real samples; the first dimension is the batch.
        fake: Batch of generated samples with the same shape as ``real``.
        device: Kept for backward compatibility. The random coefficients are
            created on ``real``'s device so they always match the inputs.

    Returns:
        A scalar tensor holding the already-scaled penalty. It stays attached to
        the critic's graph so it can be added to the critic loss.
    """
    # One uniform mixing coefficient per sample, broadcast over the remaining dims.
    eps_shape = [real.shape[0]] + [1] * (real.dim() - 1)
    eps = torch.rand(eps_shape, device=real.device, dtype=real.dtype)

    # Detach both inputs so the penalty does not backpropagate into the generator,
    # and make x_hat a leaf that requires grad so the gradient below always exists.
    x_hat = (eps * real.detach() + (1 - eps) * fake.detach()).requires_grad_(True)

    critic_output = critic(x_hat)
    grad_ones = torch.ones_like(critic_output)

    # Gradient of the critic score with respect to the interpolated input.
    # create_graph=True keeps the result differentiable with respect to the critic parameters.
    grad = torch.autograd.grad(
        inputs = x_hat,
        outputs = critic_output,
        grad_outputs = grad_ones,
        create_graph = True,
        retain_graph = True
    )[0]

    # L2 norm of each sample's gradient, taken over all non-batch dimensions.
    grad_norm = grad.flatten(start_dim=1).norm(2, dim=1)
    penalty = (grad_norm - 1) ** 2

    return coeff * torch.mean(penalty)

## Set Constants, etc.

Testbed Parameters

In [None]:
# Identifiers used to build the checkpoint locations below.
DATASET = 'mnist'   # Data at benchmark/{DATASET}/data/
RUN_ID = 't1'       # Checkpoints at benchmark/{DATASET}/checkpoints/{RUN_ID}/

# Checkpoint directories for the two networks (trailing slash included, since
# file names are appended directly to these strings).
GEN_PATH = f'./benchmark/{DATASET}/checkpoints/{RUN_ID}/generator/'
DISC_PATH = f'./benchmark/{DATASET}/checkpoints/{RUN_ID}/critic/'

# Checkpoint switches: LOAD restores from disk, SAVE writes one checkpoint per epoch.
LOAD = False
SAVE = False
LOAD_EPOCH = 0  # The epoch checkpoint to load

# Running epoch counter; overwritten from the checkpoint when LOAD is set.
curr_epoch = 0

GAN Hyperparameters

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Optimisation schedule
BATCH_SIZE = 50
NUM_EPOCHS = 10
LEARNING_RATE = 1e-5

# Image description (not referenced by the other cells visible in this notebook)
IMAGE_SIZE = 64
CHANNELS_IMG = 1

# Base channel widths of the two networks
FEATURES_GEN = 64
FEATURES_DISC = 64

# Placeholder only: the training loop reassigns NOISE_SHAPE from each batch's
# shape before any noise is sampled.
NOISE_SHAPE = (BATCH_SIZE, 1, 0, 0)

# Specific to WGAN
CRITIC_ITERATIONS = 5 # critic updates performed for every generator update
LAMBDA_GP = 10        # coefficient passed to the gradient penalty

## Initialize Data & GAN

Define dataset transforms

In [None]:
class AddAffine(nn.Module):
    """Transform that appends an extra channel containing a fixed line pattern.

    The added channel is zero everywhere except on the two corner-to-corner
    diagonals and along the top row, where it is 1.
    """

    def __init__(self):
        super().__init__()

    # Assumes input is a 3D matrix (C,H,W)
    def forward(self, image):
        """Return ``image`` with the pattern channel concatenated along axis 0.

        Args:
            image: Tensor indexed as (C, H, W).

        Returns:
            Tensor of shape (C + 1, H, W); the last channel is the pattern.
        """
        # Largest valid column / row index (avoids shadowing the builtin ``max``).
        max_x = image.shape[2] - 1
        max_y = image.shape[1] - 1
        affine = torch.zeros_like(image[:1, :, :])

        # Draw both diagonals by stepping over columns. Skipped for a single
        # column, where the slope would divide by zero; the row pass covers it.
        if max_x > 0:
            slope = max_y / max_x
            for x in range(max_x + 1):
                affine[0, round(slope * x), x] = 1
                affine[0, round(-slope * x + max_y), x] = 1

        # Draw the same diagonals by stepping over rows, which fills gaps when
        # the image is taller than it is wide. Skipped for a single row.
        if max_y > 0:
            slope = max_x / max_y
            for y in range(max_y + 1):
                affine[0, y, round(slope * y)] = 1
                affine[0, y, round(-slope * y + max_x)] = 1

        # draw line along the top
        affine[0, 0, :] = 1

        # Append affine to image along channels axis
        return torch.concat((image, affine), axis=0)

In [None]:
class RoundImg(nn.Module):
    """Transform that rounds values to the nearest integer with a straight-through gradient."""

    def __init__(self):
        super().__init__()

    def forward(self, image):
        """Return ``image`` rounded in value; the gradient with respect to ``image`` is 1."""
        # TODO: cite something? (this function is based off of Thor's code)
        # The offset is built from a detached copy, so it carries no gradient of
        # its own and only the trailing ``+ image`` contributes to the backward pass.
        rounding_offset = torch.round(image) - image.detach()
        return rounding_offset + image

In [None]:
# Tensor-level transforms applied after the image has been converted to a tensor.
# Currently empty; the disabled candidates are kept for reference.
tensor_stage = nn.Sequential(
    # RoundImg(),
    # AddAffine(),
    # tfms.RandomAffine(degrees=180, translate=(0.5,0.5), scale=(0.5,1.5), shear=None)
    # AddLabel()
)

# Full per-sample pipeline: convert to a tensor, then run the tensor stage.
tf = tfms.Compose([tfms.ToTensor(), tensor_stage])

Load Data & Initialize GAN

In [None]:
# MNIST training split, downloaded on first use and passed through `tf`.
# NOTE(review): the root here is 'benchmark/datasets/', while the testbed comment
# mentions benchmark/{DATASET}/data/ - confirm which location is intended.
data_train = datasets.MNIST(
    root='benchmark/datasets/',
    train=True,
    download=True,
    transform=tf,
)
# drop_last=True keeps every batch at exactly BATCH_SIZE samples.
dataloader = DataLoader(
    data_train,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
)

# Networks are created directly on the target device.
gen = Generator(FEATURES_GEN, device=device)
critic = Discriminator(FEATURES_DISC, device=device)

# One Adam optimiser per network, both with the same learning rate and betas.
opt_gen, opt_critic = (
    optim.Adam(net.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))
    for net in (gen, critic)
)

In [None]:
# Either restore both networks (and their optimisers) from the LOAD_EPOCH
# checkpoint, or start from freshly initialised weights.
if LOAD:
    # map_location remaps stored tensors onto the current device, so a checkpoint
    # written on a GPU machine can still be restored on a CPU-only one.

    # Load gen
    checkpoint = torch.load(f'{GEN_PATH}epoch-{LOAD_EPOCH}.tar', map_location=device)
    gen.load_state_dict(checkpoint['model_state_dict'])
    opt_gen.load_state_dict(checkpoint['optimizer_state_dict'])
    curr_epoch = checkpoint['epoch']
    loss_gen = checkpoint['loss']

    # Load critic
    checkpoint = torch.load(f'{DISC_PATH}epoch-{LOAD_EPOCH}.tar', map_location=device)
    critic.load_state_dict(checkpoint['model_state_dict'])
    opt_critic.load_state_dict(checkpoint['optimizer_state_dict'])
    curr_epoch = checkpoint['epoch']   # the critic checkpoint's epoch takes precedence
    loss_critic = checkpoint['loss']
else:
    initialize_weights(gen)
    initialize_weights(critic)

In [None]:
# Noise batch reused for every logged sample grid; created lazily by the training
# loop once the batch shape is known.
fixed_noise = None

# Counter used as the TensorBoard global step for logged image grids.
step = 0

# Separate TensorBoard runs for real and generated image grids.
writer_real = SummaryWriter(log_dir=f"logs/real")
writer_fake = SummaryWriter(log_dir=f"logs/fake")
# writer_affine = SummaryWriter(f'logs/affine')
# writer_labels = SummaryWriter(f'logs/labels')

## Train GAN

In [None]:
# Put both networks into training mode before the optimisation loop runs.
gen.train()
critic.train()

In [None]:
# WGAN-GP training loop: for every batch the critic is updated CRITIC_ITERATIONS
# times, then the generator is updated once.

# Last epoch number of this run, counting epochs completed before this cell ran
# (e.g. restored from a checkpoint), so the progress line shows the right total.
final_epoch = curr_epoch + NUM_EPOCHS

# torch.save does not create missing directories.
if SAVE:
    os.makedirs(GEN_PATH, exist_ok=True)
    os.makedirs(DISC_PATH, exist_ok=True)

for epoch in range(NUM_EPOCHS):
    curr_epoch += 1
    for batch_idx, (imgs, labels) in enumerate(dataloader):
        # Noise has one channel and the same batch / spatial size as the real images.
        NOISE_SHAPE = (imgs.shape[0], 1, imgs.shape[2], imgs.shape[3])
        if fixed_noise is None:
            fixed_noise = torch.randn(NOISE_SHAPE, device=device)
        # labels = labels[:,None,None,None]
        # labels = labels.expand(NOISE_SHAPE).to(device)

        real = imgs.to(device)
        # real = torch.concat((real, labels), axis=1)

        # affine = real[:,1:2,:,:]
        # fixed_input = torch.concat((fixed_noise, affine, labels), axis=1)

        ### Training critic: min E(critic(fake)) - E(critic(real)) + gradient penalty
        for _ in range(CRITIC_ITERATIONS):
            noise = torch.randn(NOISE_SHAPE, device=device)
            # noise = torch.concat((noise, affine, labels), axis=1)
            fake = gen(noise)
            # fake = torch.concat((fake, affine, labels), axis=1)

            # The critic update must not backpropagate into the generator, so it
            # sees a detached copy. The copy requires grad so that the interpolated
            # sample built from it is differentiable with respect to the critic input.
            fake_for_critic = fake.detach().requires_grad_(True)

            critic_real = critic(real)
            critic_fake = critic(fake_for_critic)
            gp = gradient_penalty(LAMBDA_GP, critic, real, fake_for_critic, device=device) # already scaled by LAMBDA_GP
            loss_critic = (
                torch.mean(critic_fake) - torch.mean(critic_real) + gp
            )
            # Optimisers minimise, so the critic objective is written with its sign flipped.
            critic.zero_grad()
            loss_critic.backward()
            opt_critic.step()

        ### Training generator: min -E(critic(gen_fake))
        # Reuses the last `fake` from the critic loop; its generator graph is still
        # intact because the critic updates only saw the detached copy.
        output = critic(fake)
        loss_gen = -torch.mean(output)
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # Print losses occasionally and print to tensorboard
        if batch_idx % 100 == 0:
            print(
                f"Epoch [{curr_epoch}/{final_epoch}] Batch {batch_idx}/{len(dataloader)} " +
                  f"Loss D: {loss_critic:.4f}, Lambda GP: {gp:.4f}, loss G: {loss_gen:.4f}"
            )

            with torch.no_grad():
                fake = gen(fixed_noise)
                # fake = gen(fixed_input)
                # take out (up to) BATCH_SIZE examples
                img_grid_real = torchvision.utils.make_grid(
                    real[:BATCH_SIZE,:,:,:], normalize=True
                )
                img_grid_fake = torchvision.utils.make_grid(
                    fake[:BATCH_SIZE,:,:,:], normalize=True
                )
                # img_grid_labels = torchvision.utils.make_grid(
                #     real[:BATCH_SIZE,:,:,:], normalize=True
                # )
                # img_grid_affine = torchvision.utils.make_grid(
                #     real[:BATCH_SIZE,:,:,:], normalize=True
                # )

                writer_real.add_image("Digits", img_grid_real, global_step=step)
                writer_fake.add_image("Digits", img_grid_fake, global_step=step)
                # writer_labels.add_image('Info', img_grid_labels, global_step=step)
                # writer_labels.add_image('Info', img_grid_labels, global_step=step)

            step += 1

    # Checkpoints are keyed by the running epoch counter rather than the loop
    # index, so a resumed run does not overwrite files from an earlier run.
    # Losses are stored as plain floats.
    if SAVE:
        # save generator checkpoint
        torch.save({
                    'epoch': curr_epoch,
                    'model_state_dict': gen.state_dict(),
                    'optimizer_state_dict': opt_gen.state_dict(),
                    'loss': loss_gen.item(),
        }, f"{GEN_PATH}epoch-{curr_epoch}.tar")

        # save critic checkpoint
        torch.save({
                    'epoch': curr_epoch,
                    'model_state_dict': critic.state_dict(),
                    'optimizer_state_dict': opt_critic.state_dict(),
                    'loss': loss_critic.item(),
        }, f"{DISC_PATH}epoch-{curr_epoch}.tar")