## **Code for latent augmentation in GAN training**

Code largely adapted from official pytorch DCGAN tutorial:
https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html

**Make sure GPU is on. Change Runtime in toolbar above if necessary**

In [1]:
!nvidia-smi

Fri May 21 21:25:54 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.19.01    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   52C    P8    10W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

**Download dataset**


In [2]:
!gdown https://drive.google.com/uc?id=1BLs6H4j8cZmFXSK5ID3jjFHYP2GeqcnJ # 3K image dataset used in paper
!cp TrainingFaces.zip .
!unzip -q TrainingFaces.zip
!rm TrainingFaces.zip

Downloading...
From: https://drive.google.com/uc?id=1BLs6H4j8cZmFXSK5ID3jjFHYP2GeqcnJ
To: /content/TrainingFaces.zip
90.3MB [00:00, 118MB/s] 
cp: 'TrainingFaces.zip' and './TrainingFaces.zip' are the same file
replace fewfaces/67000/67000.png? [y]es, [n]o, [A]ll, [N]one, [r]ename: N


**Hyperparameters and helper functions**

In [3]:
from __future__ import print_function
#%matplotlib inline
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML


# Number of workers for dataloader
workers = 0

# Batch size during training
batch_size = 128

# Number of channels in the training images. For color images this is 3
nc = 3

# Size of z latent vector to input to the generator
nz = 100

# Size of feature maps in generator
ngf = 64
# Size of feature maps in discriminator
ndf = 64

# Number of training epochs
num_epochs = 400

# Learning rate for optimizers
lr = 0.0001

# Beta1 hyperparam for Adam optimizers
beta1 = 0.5

# Number of GPUs available. Use 0 for CPU mode.
ngpu = 1

# We can use an image folder dataset the way we have it setup.
# Create the dataset
dataset = dset.ImageFolder(root="fewfaces",
                           transform=transforms.Compose([
                               transforms.Resize(64),
                               transforms.CenterCrop(64), # Images most be resized to 64x64 to fir the networks
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
# Create the dataloader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=0)

# Decide which device we want to run on
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

def weights_init(m):
    # Initialize weights of networks as normally distributed
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

def PCA_torch(X,k, center=True, scale=False):
    # Calculate principal components using pytorch
    # Adapted from: https://medium.com/@ravikalia/pca-done-from-scratch-with-python-2b5eb2790bfc
    n,p = X.size()
    ones = torch.ones(n, device=device).view([n,1])
    h = ((1/n) * torch.mm(ones, ones.t())) if center  else torch.zeros(n*n).view([n,n])
    H = torch.eye(n, device=device) - h
    X_center =  torch.mm(H.double(), X.double())
    covariance = 1/(n-1) * torch.mm(X_center.t(), X_center).view(p,p)
    scaling =  torch.sqrt(1/torch.diag(covariance)).double() if scale else torch.ones(p,device=device).double()
    scaled_covariance = torch.mm(torch.diag(scaling).view(p,p), covariance)
    eigenvalues, eigenvectors = torch.eig(scaled_covariance, True)
    components = (eigenvectors[:, :k]).t()
    explained_variance = eigenvalues[:k, 0]
    return {'components':components,     
       'explained_variance':explained_variance }

**Create and initialize Generator and Discriminator networks**

In [4]:
class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        return self.main(input)

netG = Generator(ngpu).to(device)
netG.apply(weights_init)

class Discriminator(nn.Module):
   def __init__(self, ngpu, augtensor=0, augment_strength=0.15, 
                augment_probability=0.15,directions = 0):
       super(Discriminator, self).__init__()
       self.ngpu = ngpu
       self.augment_strength = augment_strength 
       self.augment_probability = augment_probability * 2
       self.augtensor = augtensor
       self.directions = directions

       self.main = nn.Sequential(
           # input is (nc) x 64 x 64
           nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
           nn.LeakyReLU(0.2),
           # state size. (ndf) x 32 x 32
           nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
           nn.BatchNorm2d(ndf * 2),
           nn.LeakyReLU(0.2),
           # state size. (ndf*2) x 16 x 16
           nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
           nn.BatchNorm2d(ndf * 4),
           nn.LeakyReLU(0.2),
           # state size. (ndf*4) x 8 x 8
           nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
           nn.BatchNorm2d(ndf * 8),
           nn.LeakyReLU(0.2),
           # state size. (ndf*8) x 4 x 4
           nn.Conv2d(ndf * 8, 64, 4, 1, 0, bias=False),
           nn.LeakyReLU(0.2),
           nn.Flatten(),

           nn.Linear(64,16),
           nn.LeakyReLU(0.2),
           )
       self.final = nn.Sequential(
           nn.Linear(16,8),
           nn.LeakyReLU(0.2),
           nn.Linear(8,1),
           nn.Sigmoid()
       )
   def set_augmentstrength(self, strength):
      # Fucntion to dynamically change the strength of augmentations
      self.aug_strength = strength
   def set_augmentprob(self, probability):
      # Fucntion to dynamically change the probability of augmentations
      self.augment_probability = probability
   def encode(self, real_ims):    
      # encode data to 
      self.augtensor = self.main(real_ims).detach()
   def calculate_covariance(self):
      temp = PCA_torch(self.augtensor,4) # Calculate the 4 biggerst components
      directions = temp["components"]
      variance = temp["explained_variance"].reshape(-1,1)

      scaled_directions = torch.mul(directions, variance)
      self.directions = scaled_directions.detach()
   def forward(self, input):
       augment_probability = min(self.augment_probability,1) 
       augment_strength = self.augment_strength + 0.00001
       aug = self.directions
       x = self.main(input)
       
       empty = torch.zeros(x.shape[0],aug.shape[0],dtype=torch.float32)
       init_sparse = torch.empty_like(empty,dtype=torch.float32, device=device).uniform_(0, augment_probability)
       sparse_matrix = torch.bernoulli(init_sparse)
       scaled_directions = self.directions.detach()
       augmentations = torch.matmul(sparse_matrix, scaled_directions.float()).detach()

       x += augmentations
       x = self.final(x) 
       return x
netD = Discriminator(ngpu).to(device)
netD.apply(weights_init)
netG.apply(weights_init)
netD.encode(next(iter(dataloader))[0].to(device))
netD.calculate_covariance()

**Training loop**

In [5]:
criterion = nn.BCELoss()

# Create batch of latent vectors that we will use to visualize
#  the progression of the generator
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

# Establish convention for real and fake labels during training
real_label = 1.
fake_label = 0.

# Setup Adam optimizers for both G and D
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

# Training Loop

# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
iters = 0
images = 0 # Number of images generated

print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    
    if images >= 100000:
      # Once threshhold images is reached, perform latent augmenations

      # Apply augmentation strength and probability
      netD.set_augmentstrength(0.15)
      netD.augmentprob(0.1)
      
      # Calculate directions of covariance onze every epoch
      netD.encode(next(iter(dataloader))[0].to(device)) # encode directions of covariance
      netD.calculate_covariance()

    for i, data in enumerate(dataloader, 0):
        ## Train with all-real batch
        
        netD.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        # Forward pass real batch through D

        output = netD(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG(noise)
        label.fill_(fake_label)
        # Classify all fake batch with D
        output = netD(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Add the gradients from the all-real and all-fake batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()
        images += batch_size
        # Output training stats
        if i % 10 == 0:

            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))


        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
        iters += 1

Starting Training Loop...
[0/400][0/24]	Loss_D: 1.3911	Loss_G: 0.6786	D(x): 0.5428	D(G(z)): 0.5414 / 0.5074
[0/400][10/24]	Loss_D: 0.8638	Loss_G: 1.0009	D(x): 0.7351	D(G(z)): 0.4230 / 0.3679


KeyboardInterrupt: ignored

**Plot generated images**

In [None]:
#%%capture
fig = plt.figure(figsize=(8,8))
plt.axis("off")
ims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)

HTML(ani.to_jshtml())