# Device selection

In [1]:
import torch

# Select the compute device. GPU index 1 is preferred (as in the original run);
# fall back to GPU 0 when fewer GPUs are present, and to the CPU without CUDA.
PREFERRED_CUDA_INDEX = 1
n_gpus = torch.cuda.device_count()
print(n_gpus)
if not torch.cuda.is_available():
    device = 'cpu'
elif n_gpus > PREFERRED_CUDA_INDEX:
    device = f'cuda:{PREFERRED_CUDA_INDEX}'
else:
    device = 'cuda:0'

8


# Imports

In [3]:
import gc
import math
import matplotlib.pyplot as plt
import numpy as np
import os
from os.path import join
import sys
from tqdm import tqdm

# Taken from https://github.com/Bjarten/early-stopping-pytorch/blob/master/pytorchtools.py 
from pytorchtools import EarlyStopping

import torch.autograd as autograd
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

from torchmetrics.image.fid import FrechetInceptionDistance

from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from torchvision.utils import make_grid, save_image


def make_dir(folder):
    """Ensure ``folder`` exists (creating missing parents) and return it unchanged.

    Returning the path lets directory constants be created and bound in a
    single expression.
    """
    if not os.path.isdir(folder):
        os.makedirs(folder, exist_ok=True)
    return folder


# Root folder for all outputs. Set the WGAN_GP_ROOT environment variable to
# choose your directory; defaults to ./wgan_gp under the working directory.
ROOT = make_dir(os.environ.get("WGAN_GP_ROOT", join(os.getcwd(), "wgan_gp")))
SAMPLES_DIR  = make_dir(join(ROOT, "samples"))
DATASETS_DIR = make_dir(join(ROOT, "datasets"))  # generated train_/test_{g:03}.pt datasets
MODELS_DIR   = make_dir(join(ROOT, "models"))    # checkpoint.pt and gan_{g:03}.pt weights
METRICS_DIR  = make_dir(join(ROOT, "metrics"))   # fid_{g:03}.npz FID traces

2023-03-07 11:22:45.119591: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-03-07 11:22:45.804003: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-03-07 11:22:45.804066: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory


# Generator + Discriminator

In [4]:
class Generator(nn.Module):
    """MLP generator: (latent vector ‖ label encoding) -> image in [-1, 1].

    The input width is ``latent_dim + num_classes``. Four Linear/LeakyReLU
    stages (128, 256, 512, 1024 units) follow; every stage except the first
    also has a BatchNorm1d. A final Linear + Tanh produces ``prod(img_shape)``
    values, which are reshaped to ``(batch, *img_shape)``.
    """

    def __init__(
        self,
        img_shape:   tuple = (1,28,28),
        num_classes: int   = 10,
        latent_dim:  int   = 100
        ):
        super().__init__()
        self.img_shape = img_shape

        widths = [latent_dim + num_classes, 128, 256, 512, 1024]
        layers = []
        for stage, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Linear(fan_in, fan_out))
            # The first hidden stage is deliberately left un-normalised.
            if stage > 0:
                # NOTE(review): BatchNorm1d's second positional argument is `eps`, so
                # this is eps=0.8 (not momentum=0.8). Kept as-is so existing
                # checkpoints and training behaviour are unchanged.
                layers.append(nn.BatchNorm1d(fan_out, eps=0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))

        n_pixels = int(np.prod(self.img_shape))
        layers += [nn.Linear(widths[-1], n_pixels), nn.Tanh()]
        # Layer order (and therefore the `model.<idx>` state_dict keys) is fixed.
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        flat = self.model(z)
        return flat.view(flat.shape[0], *self.img_shape)


class Discriminator(nn.Module):
    """MLP critic scoring (flattened image ‖ label encoding) with one unbounded value.

    The input is flattened to ``(batch, -1)``, so it must contain
    ``prod(img_shape) + num_classes`` features per sample. There is no output
    activation, as a Wasserstein critic requires.
    """

    def __init__(
        self,
        img_shape: tuple = (1, 28, 28),
        num_classes: int = 10
        ):
        super().__init__()
        self.img_shape = img_shape

        in_features = int(np.prod(self.img_shape) + num_classes)
        self.model = nn.Sequential(
            nn.Linear(in_features, 512), nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),         nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
        )

    def forward(self, img):
        batch_size = img.shape[0]
        return self.model(img.view(batch_size, -1))

# WGAN_GP Class

In [5]:
def preprocess(dataset: TensorDataset):
    """Rescale a dataset's images from [0, 255] to [-1, 1].

    Only the first two tensors (images, labels) are kept; labels are passed
    through untouched.
    """
    images, targets = dataset.tensors[0], dataset.tensors[1]
    scaled = images / 255 * 2 - 1
    return TensorDataset(scaled, targets)


class WGAN_GP(nn.Module):
    """Conditional WGAN-GP (Wasserstein GAN with gradient penalty).

    The generator receives a latent vector concatenated with a one-hot label.
    The critic receives a flattened image concatenated with the same one-hot
    label. The optimizers are plain attributes, not registered sub-modules, so
    ``state_dict()`` holds only the two networks' weights.

    Args:
        device: Device on which both networks and all sampled tensors live.
        generation: Index of the training generation this model belongs to;
            used to name the datasets written by ``make_new_dataset``.
        hparams: Hyperparameters. Keys that are omitted fall back to
            ``WGAN_GP.DEFAULT_HPARAMS``; ``None`` uses the defaults unchanged.
    """

    DEFAULT_HPARAMS = {
        'num_classes': 10,
        'channels'   : 1,
        'width'      : 28,
        'height'     : 28,
        'latent_dim' : 100,
        'batch_size' : 64,
        'n_critic'   : 5,
        'lr'         : 0.0002,
        'b1'         : 0.5,
        'b2'         : 0.999,
        'lambda_gp'  : 10
    }

    def __init__(
        self,
        device: torch.device,
        generation: int,
        hparams: dict = None
    ):
        super().__init__()
        self.device = device
        self.generation = generation
        # Merge into a fresh dict so instances never share (or mutate) the defaults.
        self.hparams = {**self.DEFAULT_HPARAMS, **(hparams or {})}
        hp = self.hparams

        # Modules, sized from hparams (identical to the module defaults for MNIST).
        img_shape = (hp['channels'], hp['height'], hp['width'])
        self.generator = Generator(
            img_shape=img_shape,
            num_classes=hp['num_classes'],
            latent_dim=hp['latent_dim'],
        ).to(self.device)
        self.discriminator = Discriminator(
            img_shape=img_shape,
            num_classes=hp['num_classes'],
        ).to(self.device)

        # Optimizers: Adam with identical settings for both networks.
        betas = (hp['b1'], hp['b2'])
        self.opt_G = torch.optim.Adam(self.generator.parameters(), lr=hp['lr'], betas=betas)
        self.opt_D = torch.optim.Adam(self.discriminator.parameters(), lr=hp['lr'], betas=betas)

        # Fixed latent vectors paired with one label per class.
        self.validation_z = self.sample_z(hp['num_classes'])
        self.validation_labels = torch.arange(hp['num_classes'], device=self.device)

    def compute_gradient_penalty(self, real_samples, fake_samples):
        """Return the WGAN-GP penalty ``mean((||grad_x D(x_hat)||_2 - 1) ** 2)``.

        ``x_hat`` are random convex combinations of paired real and fake samples,
        with one mixing weight per sample. The graph is kept
        (``create_graph=True``) so the penalty can be back-propagated into the critic.

        Args:
            real_samples: Real critic inputs, batch dimension first.
            fake_samples: Fake critic inputs with the same shape.
        Returns:
            Scalar tensor.
        """
        n = real_samples.size(0)
        # One mixing weight per sample, broadcast over all remaining dimensions
        # (shape (N, 1) for the flattened (N, D) inputs used in training_step).
        alpha = torch.rand(n, *([1] * (real_samples.dim() - 1)), device=self.device)
        interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)
        d_interpolates = self.discriminator(interpolates)
        gradients = torch.autograd.grad(
            outputs=d_interpolates,
            inputs=interpolates,
            grad_outputs=torch.ones_like(d_interpolates),
            create_graph=True,
            retain_graph=True,
        )[0]
        gradients = gradients.reshape(n, -1)
        return ((gradients.norm(2, dim=1) - 1) ** 2).mean()

    def one_hot(self, labels):
        """One-hot encode integer class labels into ``num_classes`` (int64) columns."""
        return F.one_hot(labels.long(), self.hparams['num_classes'])

    def cat(self, imgs, labels):
        """Append label information to each row of ``imgs`` along dim 1.

        ``labels`` may be a 1-D tensor of class indices (one-hot encoded here) or
        an already-encoded 2-D tensor. The label part is cast to ``imgs.dtype`` so
        the result does not depend on ``torch.cat``'s implicit type promotion.
        """
        encoded = self.one_hot(labels) if labels.dim() == 1 else labels
        return torch.cat((imgs, encoded.to(imgs.dtype)), dim=1)

    def sample_z(self, length):
        """Draw ``length`` standard-normal latent vectors of size ``latent_dim``."""
        return torch.randn(length, self.hparams['latent_dim'], device=self.device)

    def forward(self, z, labels):
        """Generate images conditioned on ``labels`` from latents ``z``."""
        return self.generator(self.cat(z, labels))

    def flatten(self, imgs):
        """Reshape images to ``(N, channels * width * height)``."""
        return imgs.reshape(-1, self.hparams['channels'] * self.hparams['width'] * self.hparams['height'])

    def training_step(self, batch, batch_idx):
        """Run one critic update and, every ``n_critic`` batches, one generator update.

        Args:
            batch: ``(images, labels)`` pair; images are expected in [-1, 1]
                (see ``preprocess``).
            batch_idx: Index of the batch within the epoch. The generator is
                updated when ``batch_idx % n_critic == 0``.
        Returns:
            ``(d_loss, lambda_gp * gradient_penalty, g_loss)`` as Python floats;
            ``g_loss`` is ``None`` on critic-only steps.
        """
        imgs = batch[0].to(self.device)
        labels = batch[1].to(self.device)
        lambda_gp = self.hparams['lambda_gp']

        # The same latents are used by the critic step and the generator step.
        z = self.sample_z(len(labels))

        # ---- Critic --------------------------------------------------------
        self.opt_D.zero_grad()
        # Detached: the critic loss must not back-propagate into the generator.
        fake_imgs = self.cat(self.flatten(self(z, labels).detach()), labels)
        fake_validity = self.discriminator(fake_imgs)
        real_imgs = self.cat(self.flatten(imgs), labels)
        real_validity = self.discriminator(real_imgs)

        gradient_penalty = self.compute_gradient_penalty(real_imgs, fake_imgs)
        d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penalty
        d_loss.backward()
        self.opt_D.step()

        gp_term = float(lambda_gp * gradient_penalty.item())
        if batch_idx % self.hparams['n_critic'] != 0:
            return float(d_loss.item()), gp_term, None

        # ---- Generator -----------------------------------------------------
        # This backward pass also writes .grad on the critic's parameters;
        # opt_D.zero_grad() at the top of the next step clears them.
        self.opt_G.zero_grad()
        fake_imgs = self.cat(self.flatten(self(z, labels)), labels)
        g_loss = -torch.mean(self.discriminator(fake_imgs))
        g_loss.backward()
        self.opt_G.step()
        return float(d_loss.item()), gp_term, float(g_loss.item())

    def postprocess(self, tensor: torch.Tensor):
        """Map values in [-1, 1] to uint8 pixels in [0, 255] (clamping out-of-range values)."""
        return ((tensor / 2 + 0.5).clamp(0, 1) * 255).round().to(torch.uint8)

    @torch.no_grad()
    def make_new_dataset(self, previous_dataset, train=False, save=False):
        """Generate a synthetic dataset with the same labels as ``previous_dataset``.

        One image is generated per label of ``previous_dataset``, in a single
        forward pass. If ``save`` is True, the uint8 dataset is written to
        ``DATASETS_DIR/{train|test}_{generation:03}.pt``.

        Returns:
            The generated dataset rescaled to [-1, 1] via ``preprocess``.
        """
        labels = previous_dataset.tensors[1].to(self.device)
        imgs = self(self.sample_z(len(labels)), labels).cpu()
        new_dataset = TensorDataset(self.postprocess(imgs), labels.cpu())

        if save:
            split = 'train' if train else 'test'
            # Name the file after this model's own generation, not any global loop variable.
            torch.save(new_dataset, join(DATASETS_DIR, f'{split}_{int(self.generation):03}.pt'))

        return preprocess(new_dataset)

    @torch.no_grad()
    def calculate_fid(self, dataset: TensorDataset, batch_size: int = 50):
        """Compute the FID between ``dataset``'s images and generated images with the same labels.

        Args:
            dataset: ``(images, labels)`` with images in [-1, 1].
            batch_size: Number of images passed to the metric per update; a
                final partial batch is included as well.
        Returns:
            The FID as a Python float.
        """
        real_imgs = dataset.tensors[0].to(self.device)
        labels = dataset.tensors[1].to(self.device)

        # NOTE(review): the generator is left in train mode here, so its BatchNorm
        # layers use batch statistics and update running stats — confirm intended.
        fake_imgs = self.postprocess(self(self.sample_z(len(labels)), labels))
        real_imgs = self.postprocess(real_imgs)

        def to_rgb(x):
            # The Inception network expects 3 channels: replicate grayscale images.
            return x.repeat(1, 3, 1, 1) if x.shape[1] == 1 else x

        fake_imgs, real_imgs = to_rgb(fake_imgs), to_rgb(real_imgs)

        # normalize=False: the metric expects uint8 images in [0, 255].
        fid = FrechetInceptionDistance(normalize=False).to(self.device)
        for start in range(0, len(labels), batch_size):
            stop = start + batch_size
            fid.update(real_imgs[start:stop], real=True)
            fid.update(fake_imgs[start:stop], real=False)
        return float(fid.compute().item())

# Training

In [6]:
def load_real_mnist(train: bool) -> TensorDataset:
    """Download (if needed) one MNIST split and return it rescaled to [-1, 1].

    A channel dimension is inserted at dim 1 of the image tensor; labels are
    the split's original integer targets.
    """
    split = MNIST(os.getcwd(), train=train, download=True)
    return preprocess(TensorDataset(split.data.unsqueeze(1), split.targets))


real_train = load_real_mnist(train=True)
real_test  = load_real_mnist(train=False)

In [None]:
generations = 50
epochs = 1000 + 1       # +1 so epoch 1000 (a multiple of sample_interval) is evaluated too; earlier runs used 400 + 1
sample_interval = 20    # evaluate FID / early stopping every this many epochs (earlier runs used 10)


def saved_generations(prefix):
    """Return the set of generation indices g with a file ``{prefix}_{g:03}.pt`` in DATASETS_DIR."""
    found = set()
    for fname in os.listdir(DATASETS_DIR):
        stem, ext = os.path.splitext(fname)
        head, _, number = stem.partition('_')
        if ext == '.pt' and head == prefix and number.isdigit():
            found.add(int(number))
    return found


# Resume from the newest generation for which BOTH the train and the test
# dataset were saved; otherwise start from the real MNIST data. Unrelated or
# half-written files in DATASETS_DIR are ignored instead of crashing the parse.
completed = saved_generations('train') & saved_generations('test')
if not completed:
    train_dataset = real_train
    test_dataset  = real_test
    init_gen = 0
else:
    prev_gen = max(completed)
    # Saved datasets hold uint8 images; normalize them to [-1, 1].
    # NOTE(review): torch>=2.6 defaults torch.load(weights_only=True), which cannot
    # unpickle a TensorDataset; pass weights_only=False there (trusted local files only).
    train_dataset = preprocess(torch.load(join(DATASETS_DIR, f'train_{prev_gen:03}.pt')))
    test_dataset  = preprocess(torch.load(join(DATASETS_DIR, f'test_{prev_gen:03}.pt')))
    init_gen = prev_gen + 1
    del prev_gen


for g in range(init_gen, init_gen + generations):
    wgan = WGAN_GP(device, g)
    checkpoint_path = join(MODELS_DIR, 'checkpoint.pt')
    # FID (lower is better) is passed to EarlyStopping as the validation "loss";
    # the best weights are expected to be written to checkpoint_path.
    early_stopping = EarlyStopping(patience=2, path=checkpoint_path,
                                   trace_func=lambda x: None)
    train_dataloader = DataLoader(train_dataset, batch_size=wgan.hparams['batch_size'],
                                  shuffle=True, pin_memory=True)
    fid_madc = []  # FID against this generation's test set, every sample_interval epochs

    for epoch in tqdm(range(epochs)):
        for i, batch in enumerate(train_dataloader):
            wgan.training_step(batch, i)

        if epoch % sample_interval == 0:
            newest_fid = wgan.calculate_fid(test_dataset)
            fid_madc.append(newest_fid)
            early_stopping(newest_fid, wgan)

        if early_stopping.early_stop:
            break

    # Restore the best weights and archive them as gan_{g:03}.pt
    # (os.replace overwrites an existing file on every platform).
    wgan.load_state_dict(torch.load(checkpoint_path))
    os.replace(checkpoint_path, join(MODELS_DIR, f'gan_{g:03}.pt'))

    # Save both FID traces: against the current test set and against real MNIST.
    np.savez(join(METRICS_DIR, f'fid_{g:03}.npz'),
             fid_madc=fid_madc,
             fid_real=wgan.calculate_fid(real_test))

    # Generate (and save) the next generation's datasets from this model.
    train_dataset = wgan.make_new_dataset(train_dataset, train=True, save=True)
    test_dataset  = wgan.make_new_dataset(test_dataset, train=False, save=True)

    # Release this generation's objects before building the next model.
    del wgan, train_dataloader, early_stopping
    gc.collect()

 14%|██████████▉                                                                   | 140/1001 [18:53<1:56:08,  8.09s/it]
  7%|█████▌                                                                         | 71/1001 [09:31<1:24:47,  5.47s/it]

In [None]:
# TODO: check that each saved generation's datasets are correctly normalized

In [None]:
# Uncomment to check CUDA memory:
#print(torch.cuda.memory_summary(device=device, abbreviated=False))