In [None]:
import os
# os.environ["CUDA_VISIBLE_DEVICES"]="2"
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from tqdm import tqdm

from imagenet_c import corrupt
from models import Generator_28, Discriminator_28
from utils import get_noise, get_gradient_penalty, show_tensor_images, compute_fid


In [None]:
# ---- Experiment switches ----
architecture = 'DCGAN' # DCGAN or WGAN-GP
conditional = True       # feed the one-hot digit label to both generator and critic
corrupt_dataset = False  # apply `corruption_transform` to the training images

# ---- Training hyper-parameters ----
n_epochs = 30
z_dim = 64                      # dimension of the latent noise vector
display_step = 500              # print losses / plot progress every this many batches
compute_fid_every_x_epochs = 5  # FID evaluation interval, in epochs
batch_size = 128
lr = 0.0002
beta_1 = 0.5
beta_2 = 0.999
# Fall back to CPU instead of failing later at `.to(device)` on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
path_mnist_png = '../../data/MNIST/png/train'  # passed to compute_fid (presumably the real-image reference set)

# Set to an int for reproducible runs; None keeps the default (unseeded) behaviour.
seed = None
if seed is not None:
    torch.manual_seed(seed)
    np.random.seed(seed)  # also drives the random corruption severity

if architecture == 'DCGAN':
    criterion = nn.BCEWithLogitsLoss()
    crit_repeats = 1  # critic updates per generator update
    r1_gamma = 10     # weight of the R1 penalty on real data; 0 disables it
elif architecture == 'WGAN-GP':
    crit_repeats = 5  # critic updates per generator update
    c_lambda = 10     # weight of the gradient penalty
else:
    # Fail here rather than with a NameError (e.g. on `crit_repeats`) deep inside training.
    raise ValueError(f"Unknown architecture {architecture!r}; expected 'DCGAN' or 'WGAN-GP'")

im_shape = (1, 28, 28)                # (channels, height, width)
n_classes = 10 if conditional else 0  # number of classes; 0 when unconditional
im_chan = im_shape[0] # 1 for black and white

# Optional ImageNet-C corruption applied to each sample before tensor conversion.
# The second argument to `corrupt` is drawn per sample from {1, ..., 5}
# (presumably the corruption severity — see imagenet_c.corrupt).
corruption_transform = transforms.Lambda(
    lambda img: corrupt(
        np.uint8(img),
        np.random.randint(1, 6),
        corruption_name='gaussian_noise',
    )
)

# Shared tail: ToTensor, then Normalize(0.5, 0.5), i.e. (x - 0.5) / 0.5 per channel.
transform_steps = [corruption_transform] if corrupt_dataset else []
transform_steps += [
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
]
transform = transforms.Compose(transform_steps)

# download=False: the dataset must already be present under ../../data.
mnist_dataset = MNIST('../../data', download=False, transform=transform)
dataloader = DataLoader(mnist_dataset, batch_size=batch_size, shuffle=True)

In [None]:
# Conditioning widens the inputs: the generator gets noise + one-hot label
# (z_dim + n_classes values), the critic gets the image stacked with one label map
# per class (im_chan + n_classes channels). Both extras are 0 when unconditional.
generator_input_dim = z_dim + n_classes
discriminator_im_chan = im_chan + n_classes

gen = Generator_28(generator_input_dim).to(device)
crit = Discriminator_28(discriminator_im_chan).to(device)

# Both networks share the same Adam settings.
adam_kwargs = dict(lr=lr, betas=(beta_1, beta_2))
gen_opt = torch.optim.Adam(gen.parameters(), **adam_kwargs)
crit_opt = torch.optim.Adam(crit.parameters(), **adam_kwargs)

def weights_init(m):
    """Re-initialise the parameters of a single module; intended for ``Module.apply``.

    - Conv2d / ConvTranspose2d: weight ~ N(0, 0.02); bias left untouched.
    - BatchNorm2d: weight ~ N(0, 0.02), bias set to 0.
    - Any other module: left untouched.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
    elif isinstance(m, nn.BatchNorm2d):
        # NOTE(review): the PyTorch DCGAN tutorial centres BatchNorm scales at 1.0;
        # a 0.0 mean makes BN outputs start near zero — confirm this is intended.
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
        nn.init.zeros_(m.bias)
# Module.apply re-initialises every submodule in place (and returns the module itself),
# so no reassignment is needed. Generator first, then critic.
gen.apply(weights_init)
crit.apply(weights_init)

In [None]:
def r1_penalty(crit, crit_real_input, n_image_channels):
    """R1 gradient penalty on real data (cf. StyleGAN2-ADA training/loss.py).

    Args:
        crit: critic/discriminator network.
        crit_real_input: the critic's input for the real batch, i.e. the image channels
            followed (in the conditional case) by the broadcast one-hot label maps.
        n_image_channels: number of leading channels that hold the image.

    Returns:
        Scalar tensor: batch mean of the squared L2 norm of d crit(x) / d x, restricted to
        the image channels. Penalising the gradient w.r.t. the label maps as well would
        only discourage the critic from using the conditioning.
    """
    real_tmp = crit_real_input.detach().requires_grad_(True)
    crit_real_pred_tmp = crit(real_tmp)
    # create_graph=True so the penalty itself can be back-propagated into the critic.
    r1_grads = torch.autograd.grad(outputs=crit_real_pred_tmp.sum(), inputs=real_tmp, create_graph=True)[0]
    return r1_grads[:, :n_image_channels].square().sum([1, 2, 3]).mean()


def plot_training_progress(generator_losses, critic_losses, fid_epochs, fids, step_bins=20):
    """Plot losses averaged over bins of `step_bins` batches, then FID against epoch.

    `fid_epochs[i]` must be the epoch after which `fids[i]` was computed. The FID figure
    is skipped while no FID has been computed yet.
    """
    num_examples = (len(generator_losses) // step_bins) * step_bins
    bin_idx = range(num_examples // step_bins)
    plt.plot(bin_idx, torch.Tensor(generator_losses[:num_examples]).view(-1, step_bins).mean(1), label="Generator Loss")
    plt.plot(bin_idx, torch.Tensor(critic_losses[:num_examples]).view(-1, step_bins).mean(1), label="Critic Loss")
    plt.xlabel(f'Batches (x{step_bins})')
    plt.legend()
    plt.show()
    if fids:
        plt.figure()
        plt.plot(fid_epochs, fids, marker='o')
        plt.xlabel('Epochs')
        plt.ylabel('FID')
        plt.show()


cur_step = 0
generator_losses = []  # one entry per batch
critic_losses = []     # one entry per batch: mean over the `crit_repeats` critic updates
fids = []
fid_epochs = []        # fid_epochs[i] is the epoch after which fids[i] was computed
for epoch in range(1, n_epochs + 1):
    for real, labels in tqdm(dataloader):
        cur_batch_size = len(real)
        real = real.to(device)

        # Conditioning inputs: a one-hot vector for the generator, and the same vector
        # broadcast to constant (n_classes, H, W) maps stacked onto the critic's input.
        if conditional:
            one_hot_labels = nn.functional.one_hot(labels.to(device), n_classes).float()
            image_one_hot_labels = one_hot_labels[:, :, None, None].repeat(1, 1, im_shape[1], im_shape[2])
            real_image_and_labels = torch.cat((real.float(), image_one_hot_labels), dim=1)

        ### Update critic/discriminator ###
        mean_iteration_critic_loss = 0
        for _ in range(crit_repeats):
            crit_opt.zero_grad()
            fake_noise = get_noise(cur_batch_size, z_dim, device=device)
            if conditional:
                fake = gen(torch.cat((fake_noise.float(), one_hot_labels), dim=1))
                # Detached: the critic update must not back-propagate into the generator.
                fake_image_and_labels = torch.cat((fake.detach().float(), image_one_hot_labels), dim=1)
                crit_fake_pred = crit(fake_image_and_labels)
                crit_real_pred = crit(real_image_and_labels)
            else:
                fake = gen(fake_noise)
                crit_fake_pred = crit(fake.detach())
                crit_real_pred = crit(real)

            if architecture == 'DCGAN':
                crit_fake_loss = criterion(crit_fake_pred, torch.zeros_like(crit_fake_pred))
                crit_real_loss = criterion(crit_real_pred, torch.ones_like(crit_real_pred))
                crit_loss = (crit_fake_loss + crit_real_loss) / 2
                if r1_gamma > 0:
                    crit_real_input = real_image_and_labels if conditional else real
                    crit_loss = crit_loss + r1_penalty(crit, crit_real_input, im_chan) * (r1_gamma / 2)
            elif architecture == 'WGAN-GP':
                # NOTE(review): requires_grad=True is kept; get_gradient_penalty may rely on it
                # to make the real/fake interpolation differentiable (neither input carries grad).
                # In the conditional case, whether label channels enter the gradient norm depends
                # on utils.get_gradient_penalty — confirm.
                epsilon = torch.rand(cur_batch_size, 1, 1, 1, device=device, requires_grad=True)
                if conditional:
                    gp = get_gradient_penalty(crit, real_image_and_labels, fake_image_and_labels.detach(), epsilon)
                else:
                    gp = get_gradient_penalty(crit, real, fake.detach(), epsilon)
                crit_loss = torch.mean(crit_fake_pred - crit_real_pred) + c_lambda * gp

            mean_iteration_critic_loss += crit_loss.item() / crit_repeats
            # No retain_graph: this graph is never back-propagated again (the generator
            # step below runs its own forward pass).
            crit_loss.backward()
            crit_opt.step()
        critic_losses += [mean_iteration_critic_loss]

        ### Update generator ###
        # Fresh samples in both the conditional and unconditional case, scored by the critic
        # as just updated above. gen_loss.backward() also writes gradients into the critic;
        # those are cleared by crit_opt.zero_grad() at the next critic update.
        gen_opt.zero_grad()
        fake_noise_2 = get_noise(cur_batch_size, z_dim, device=device)
        if conditional:
            fake_2 = gen(torch.cat((fake_noise_2.float(), one_hot_labels), dim=1))
            crit_fake_pred = crit(torch.cat((fake_2.float(), image_one_hot_labels), dim=1))
        else:
            fake_2 = gen(fake_noise_2)
            crit_fake_pred = crit(fake_2)

        if architecture == 'DCGAN':
            gen_loss = criterion(crit_fake_pred, torch.ones_like(crit_fake_pred))
        elif architecture == 'WGAN-GP':
            gen_loss = -torch.mean(crit_fake_pred)
        gen_loss.backward()
        gen_opt.step()
        generator_losses += [gen_loss.item()]

        ### Visualization ###
        if cur_step % display_step == 0 and cur_step > 0:
            gen_mean = sum(generator_losses[-display_step:]) / display_step
            crit_mean = sum(critic_losses[-display_step:]) / display_step
            print(f"Epoch {epoch}, step {cur_step}: Generator loss: {gen_mean}, critic loss: {crit_mean}")
            show_tensor_images(fake)
            plot_training_progress(generator_losses, critic_losses, fid_epochs, fids)

        cur_step += 1

    # FID after epoch 1, every `compute_fid_every_x_epochs` epochs, and always after the
    # final epoch (which may not be a multiple of the interval).
    # NOTE(review): gen is still in train mode here; if compute_fid does not switch to
    # eval()/no_grad itself, sampling may update BatchNorm running stats — confirm.
    if epoch == 1 or epoch % compute_fid_every_x_epochs == 0 or epoch == n_epochs:
        fids.append(compute_fid(gen, conditional, n_classes, path_mnist_png))
        fid_epochs.append(epoch)

In [None]:
# Persist the trained generator weights. Off by default so that re-running the notebook
# does not write a new checkpoint each time; set `save_weights = True` to save.
save_weights = False
if save_weights:
    cond_str = '_conditional' if conditional else ''
    corrupted_str = '_corrupted' if corrupt_dataset else ''
    fname = f'../models/{architecture}{cond_str}{corrupted_str}_MNIST_weights_{datetime.now().strftime("%Y%m%d_%H%M")}.pth'
    os.makedirs(os.path.dirname(fname), exist_ok=True)  # torch.save does not create directories
    torch.save(gen.state_dict(), fname)
    print(f'Saved generator weights to {fname}')