# Improved GAN face synthesis

This notebook improves on our previous GAN implementation by applying techniques from the research literature that are reported to improve training performance.

### References

-   Adding noise to images used for training the GAN
    -   Sønderby, Casper Kaae, Jose Caballero, Lucas Theis, Wenzhe Shi, and Ferenc Huszár. “Amortised MAP Inference for Image Super-Resolution.” ArXiv:1610.04490 [Cs, Stat], February 21, 2017. http://arxiv.org/abs/1610.04490.

-   Use of DNNs in GAN, ADAM Optimizer for Generator and SGD for Discriminator

    -   Radford, Alec, Luke Metz, and Soumith Chintala. “Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks.” ArXiv:1511.06434 [Cs], January 7, 2016. http://arxiv.org/abs/1511.06434.
    
-   One-sided label smoothing

    -   Salimans, Tim, Ian Goodfellow, Wojciech Zaremba, Vicki Cheung, Alec Radford, and Xi Chen. “Improved Techniques for Training GANs.” ArXiv:1606.03498 [Cs], June 10, 2016. http://arxiv.org/abs/1606.03498.

## Import Required libraries

In [1]:
import os
import time
import torch
import torchvision
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn
from torchvision import transforms

from PIL import Image
from typing import Set
import matplotlib.pyplot as plt

%matplotlib inline

## Import required libraries for data loading

In [2]:
import os.path
from glob import glob
from tqdm import tqdm

from torch.utils.data.dataset import Dataset

## GPU Device Configuration

Check whether GPU training is available via CUDA

In [3]:
# Choose the compute device once; later cells read the global `device`.
if not torch.cuda.is_available():
    print("GPU training NOT available")
    device = torch.device("cpu")
    print("Can only train on CPU")
else:
    # Let the cuDNN auto-tuner benchmark algorithms for this hardware; this only
    # pays off when input sizes do not vary between iterations.
    torch.backends.cudnn.benchmark = True
    print("GPU training available")
    device = torch.device("cuda:0")
    print(f"Index of CUDA device in use is {torch.cuda.current_device()}")

GPU training available
Index of CUDA device in use is 0


## Hyper Parameter Object for our model

In [4]:
class HyperParameter:
    """Bundle of GAN training hyper-parameters and filesystem locations.

    Constructing an instance creates the two weight directories and the output
    directory (``os.makedirs(..., exist_ok=True)``); ``data_dir`` is only stored.
    """

    def __init__(
        self,
        latent_sz,
        in_img_size,
        in_img_channel=3,
        data_dir="../data/img_align_celeba",
        output_dir="../generated_imgs",
        lr=0.0002,
        beta1=0.5,
        epochs=100,
        batch_sz=64,
        d_trained_wt_dir="../weights/discriminator_trained_weights",
        g_trained_wt_dir="../weights/generator_trained_weights",
    ):
        """
        Args:
            latent_sz: length of the Generator's input noise vector.
            in_img_size: side length (pixels) of the square training images.
            in_img_channel: number of image channels.
            data_dir: directory holding the training images.
            output_dir: directory where sample images are written.
            lr: learning rate passed to the optimizers.
            beta1: first Adam beta coefficient.
            epochs: number of passes over the training data.
            batch_sz: images per mini-batch.
            d_trained_wt_dir: directory for Discriminator checkpoints.
            g_trained_wt_dir: directory for Generator checkpoints.
        """
        # optimisation settings
        self.latent_size, self.learning_rate, self.beta1 = latent_sz, lr, beta1

        # training-loop and input-shape settings
        self.epochs, self.batch_size = epochs, batch_sz
        self.input_img_channel, self.input_img_size = in_img_channel, in_img_size

        # filesystem locations
        self.discriminator_trained_weight_dir = d_trained_wt_dir
        self.generator_trained_weight_dir = g_trained_wt_dir
        self.output_dir = output_dir
        self.data_dir = data_dir

        for directory in (
            self.discriminator_trained_weight_dir,
            self.generator_trained_weight_dir,
            self.output_dir,
        ):
            os.makedirs(directory, exist_ok=True)

    def __repr__(self):
        # One "name: value" line per setting, each terminated by a newline.
        shown = (
            "latent_size",
            "learning_rate",
            "beta1",
            "input_img_size",
            "input_img_channel",
            "epochs",
            "batch_size",
            "data_dir",
            "output_dir",
            "discriminator_trained_weight_dir",
            "generator_trained_weight_dir",
        )
        return "".join(f"{attr}: {getattr(self, attr)}\n" for attr in shown)

## Data load Utility functions

In [5]:
# Image file extensions accepted when scanning the dataset directory. Matching
# is case-sensitive, so every extension is listed in lower- and upper-case.
VALID_IMG_EXTENSIONS = {
    spelling
    for ext in (".jpg", ".jpeg", ".png", ".ppm", ".bmp")
    for spelling in (ext, ext.upper())
}


def _is_image_file(fpath, valid_img_ext: Set = VALID_IMG_EXTENSIONS) -> bool:
    """Return True if *fpath*'s extension (compared case-sensitively) is in *valid_img_ext*."""
    extension = os.path.splitext(fpath)[1]
    return extension in valid_img_ext


def make_img_dataset(root_dir, valid_img_ext: Set = VALID_IMG_EXTENSIONS):
    """Recursively collect the paths of image files under *root_dir*.

    A file counts as an image when ``_is_image_file`` accepts its name for
    *valid_img_ext*. Paths are returned in ``os.walk`` traversal order (not sorted).
    """
    return [
        os.path.join(dirpath, filename)
        for dirpath, _subdirs, filenames in os.walk(root_dir)
        for filename in filenames
        if _is_image_file(filename, valid_img_ext)
    ]


def default_loader(img):
    """Open the image file *img* and return it converted to RGB mode.

    Best-effort: any error while opening/decoding is printed and ``False`` is
    returned instead of raising, so callers can skip unreadable files.
    """
    try:
        # The context manager closes the underlying file handle; convert()
        # returns a new, fully loaded image, so it stays valid after closing.
        with Image.open(img) as opened_img:
            return opened_img.convert("RGB")
    except Exception as e:  # deliberate: skip any unreadable/corrupt file
        print(f"Exception: {e}. Skipping {img}")
        return False

## Create an ImageDataset Class that inherits from torch.utils.data.dataset.Dataset

In [6]:
class ImageDataset(Dataset):
    """Map-style dataset over every accepted image file found under a directory.

    Each item is the image loaded with PIL, converted to RGB, then passed through
    ``transform`` when one was given.
    """

    def __init__(
        self, root_dir, transform=None, valid_img_ext: Set = VALID_IMG_EXTENSIONS
    ):
        """
        Args:
            root_dir: directory searched recursively via ``make_img_dataset``.
            transform: optional callable applied to each loaded PIL image.
            valid_img_ext: accepted file extensions (case-sensitive).

        Raises:
            IndexError: if no image files are found under ``root_dir``.
        """
        self.transform = transform

        self.face_dataset = make_img_dataset(root_dir, valid_img_ext)
        if not self.face_dataset:
            raise IndexError("Face dataset is empty")

    def __len__(self):
        return len(self.face_dataset)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_path = self.face_dataset[idx]
        # Close the file handle promptly (many DataLoader workers may be reading),
        # and force RGB so grayscale / RGBA / palette files still yield 3 channels,
        # matching the 3-channel Normalize and the flattened model input size.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)

        return image

## Data loader function to iterate over the training set

In [7]:
def get_data_loader(
    root_data_dir,
    data_transform=None,
    batch_size=64,
    num_workers=2,
    shuffle=True,
    drop_last=True,
):
    """Build a ``DataLoader`` over every image found under ``root_data_dir``.

    Args:
        root_data_dir: directory searched recursively for image files.
        data_transform: optional transform applied to each image.
        batch_size: number of images per batch.
        num_workers: number of worker processes used for loading.
        shuffle: reshuffle the data at every epoch.
        drop_last: drop the final incomplete batch. The default (True) keeps
            every batch at exactly ``batch_size`` images.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding batches of transformed images.

    Raises:
        IndexError: propagated from ``ImageDataset`` when no images are found.
    """
    face_dataset = ImageDataset(root_data_dir, data_transform)
    return torch.utils.data.DataLoader(
        face_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=shuffle,
        # Honour the caller's choice; previously this was hard-coded to True.
        drop_last=drop_last,
    )

_____

# Improved GAN implementation

The utility functions and supporting classes above are identical to our baseline GAN implementation; below we add transform classes and Generator/Discriminator networks with enhanced structures.

## Image preprocessing transforms, including a transform that adds Gaussian noise to input images

### Gaussian Noise Transform class

In [8]:
class AddGaussianNoise(object):
    """Transform that adds element-wise Gaussian noise to a tensor.

    Each element receives ``mean + std * N(0, 1)``; the defaults are
    ``mean=0.0`` and ``std=1.0``.
    """

    def __init__(self, mean=0.0, std=1.0):
        """
        Args:
            mean: constant offset added to every element.
            std: standard deviation of the added noise.
        """
        self.std = std
        self.mean = mean

    def __call__(self, tensor):
        # Create the noise on the input's device (and in its floating dtype) so the
        # transform also works for CUDA tensors; integer inputs fall back to the
        # default float dtype, as before.
        dtype = tensor.dtype if tensor.is_floating_point() else None
        noise = torch.randn(tensor.shape, dtype=dtype, device=tensor.device)
        return tensor + noise * self.std + self.mean

    def __repr__(self):
        return f"{self.__class__.__name__}(mean={self.mean}, std={self.std})"

In [9]:
# Preprocessing applied to every training image:
#   CenterCrop(160) -> keep the central 160x160 region,
#   Resize(64)      -> scale the smaller edge to 64 px (square after the crop),
#   ToTensor()      -> PIL image to a channels-first float tensor in [0, 1],
#   Normalize(0.5, 0.5) per channel -> map [0, 1] to [-1, 1], the range of the
#                      Generator's Tanh output.
# NOTE(review): AddGaussianNoise is defined above but is not part of this
# pipeline, so no instance noise is added to the training images - confirm intent.
img_data_transform = transforms.Compose(
    [
        transforms.CenterCrop(size=160),
        transforms.Resize(size=64),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
    ]
)

_____
# Creating our Neural Networks

## Discriminator Module

With our base Discriminator model, we had problems with mode collapse, and the Discriminator loss fell quickly while the Generator loss increased. This usually means the Discriminator is learning too fast and has overfit the training data, so the Generator can no longer learn.

So we regularize our Discriminator by adding `nn.Dropout` layers after the activations; dropout randomly zeroes elements of its input tensor with probability p (the cell below uses `nn.Dropout(0.1)`, i.e. p = 0.1).

In [10]:
class Discriminator(nn.Module):
    """Fully-connected real/fake classifier.

    Expects each sample already flattened to
    ``in_img_size * in_img_size * in_img_channels`` features and returns one
    Sigmoid probability per sample.
    """

    def __init__(self, in_img_size=64, in_img_channels=3, n_gpu=1):
        super().__init__()
        self.n_gpu = n_gpu

        n_features = in_img_size * in_img_size * in_img_channels
        layers = []
        # Two hidden blocks: Linear -> LeakyReLU -> Dropout(0.1).
        for width_in, width_out in ((n_features, 256), (256, 512)):
            layers += [nn.Linear(width_in, width_out), nn.LeakyReLU(), nn.Dropout(0.1)]
        layers += [nn.Linear(512, 1), nn.Sigmoid()]
        # Layer order fixes the state_dict keys (main.0, main.3, main.6) that
        # saved checkpoints rely on - keep it stable.
        self.main = nn.Sequential(*layers)

    def forward(self, X):
        """Return the per-sample probability that *X* is real.

        Uses ``nn.parallel.data_parallel`` across ``n_gpu`` devices only when the
        input is on CUDA and more than one GPU was requested.
        """
        spread_over_gpus = X.is_cuda and self.n_gpu > 1
        if not spread_over_gpus:
            return self.main(X)
        return nn.parallel.data_parallel(self.main, X, range(self.n_gpu))

## Generator Module

For our Generator model, we avoid layers that produce sparse gradients, such as `ReLU` or max-pooling, and instead switch to the `LeakyReLU` activation in the Generator.

In [11]:
class Generator(nn.Module):
    """Fully-connected generator mapping latent vectors to flattened images.

    Output has ``in_img_size * in_img_size * in_img_channels`` features per
    sample, squashed into [-1, 1] by a final Tanh.
    """

    def __init__(self, latent_vector_size, in_img_size=64, in_img_channels=3, n_gpu=1):
        super().__init__()
        self.n_gpu = n_gpu

        # Hidden widths grow from the latent size up to 1024.
        widths = [latent_vector_size, 256, 512, 1024, 1024]
        layers = []
        for width_in, width_out in zip(widths, widths[1:]):
            layers.extend((nn.Linear(width_in, width_out), nn.LeakyReLU()))
        layers.extend(
            (nn.Linear(widths[-1], in_img_size * in_img_size * in_img_channels), nn.Tanh())
        )
        # Layer order fixes the state_dict keys (main.0, main.2, ..., main.8)
        # that saved checkpoints rely on - keep it stable.
        self.main = nn.Sequential(*layers)

    def forward(self, X):
        """Generate flattened images from latent vectors *X*.

        Uses ``nn.parallel.data_parallel`` across ``n_gpu`` devices only when the
        input is on CUDA and more than one GPU was requested.
        """
        spread_over_gpus = X.is_cuda and self.n_gpu > 1
        if not spread_over_gpus:
            return self.main(X)
        return nn.parallel.data_parallel(self.main, X, range(self.n_gpu))

## GAN Module

We switch the optimizer for the Discriminator network from Adam to SGD. (Note: the cell below still constructs `torch.optim.Adam` for both the Generator and the Discriminator.)

More importantly, to penalize and regularize the Discriminator we use one-sided label smoothing: when a label is real, we replace it with a random number between 0.7 and 1.2.

We could also randomly flip some real labels to fake labels when training the Discriminator, to add noise (this is not implemented in the cell below).

In [12]:
class GAN:
    """
    Fully-connected GAN wrapper: builds the Generator/Discriminator pair, their
    Adam optimizers and BCE loss, and trains them with ``fit``.

    Uses the module-level ``Generator``, ``Discriminator`` and ``device``
    defined earlier in this notebook.
    """

    # Losses are printed and appended to the loss history every LOG_INTERVAL iterations.
    LOG_INTERVAL = 20

    def __init__(
        self,
        hyper_parameter,
        load_wt=True,
        save_wt=True,
        save_wt_interval=10,
        save_img_interval=50,
        real_label_range=(0.7, 1.0),
    ):
        """
        Args:
            hyper_parameter: ``HyperParameter`` with sizes, learning rate, beta1,
                epochs, batch size and the output/weight directories.
            load_wt: if True, resume from the newest ``*.pt`` file in each
                weight directory (when one exists).
            save_wt: default for ``fit``'s ``save_wt``.
            save_wt_interval: default for ``fit``'s ``save_wt_interval``.
            save_img_interval: default for ``fit``'s ``save_img_interval``.
            real_label_range: ``(low, high)`` bounds of the uniform soft targets
                for real images when training the Discriminator (one-sided label
                smoothing). ``nn.BCELoss`` expects targets in [0, 1]; a target
                above 1 lets the loss go negative and is minimised by saturating
                D(x) at 1, so the markdown's suggested upper bound of 1.2 is
                capped at 1.0 here.

        Raises:
            ValueError: if ``real_label_range`` is not ``0 <= low <= high <= 1``.
        """
        low, high = real_label_range
        if not 0.0 <= low <= high <= 1.0:
            raise ValueError(
                f"real_label_range must satisfy 0 <= low <= high <= 1, got {real_label_range}"
            )
        self.real_label_range = (float(low), float(high))

        self.hp = hyper_parameter
        self.save_wt = save_wt
        self.save_wt_interval = save_wt_interval
        self.save_img_interval = save_img_interval

        # Build both nets with the configured channel count (previously ignored).
        self.G_net = Generator(
            self.hp.latent_size,
            in_img_size=self.hp.input_img_size,
            in_img_channels=self.hp.input_img_channel,
        ).to(device)
        self.D_net = Discriminator(
            in_img_size=self.hp.input_img_size,
            in_img_channels=self.hp.input_img_channel,
        ).to(device)
        self.D_loss_overtime = []
        self.G_loss_overtime = []

        if load_wt:
            self._load_saved_weights()
        # Binary Cross Entropy Loss
        self.criterion = nn.BCELoss()

        # Optimizers. NOTE: the markdown describes SGD for the Discriminator,
        # but Adam is used for both networks here.
        self.G_optimizer = torch.optim.Adam(
            self.G_net.parameters(),
            lr=self.hp.learning_rate,
            betas=(self.hp.beta1, 0.999),
        )
        self.D_optimizer = torch.optim.Adam(
            self.D_net.parameters(),
            lr=self.hp.learning_rate,
            betas=(self.hp.beta1, 0.999),
        )

    def _load_saved_weights(self):
        """Load the most recently created ``*.pt`` checkpoint for each network, if any.

        Weights are mapped onto ``device`` so GPU-saved checkpoints also load on a
        CPU-only machine. Loaded networks are left in eval mode; ``fit`` switches
        them back to train mode.
        """
        for net, weight_dir, label in (
            (self.D_net, self.hp.discriminator_trained_weight_dir, "Discriminator"),
            (self.G_net, self.hp.generator_trained_weight_dir, "Generator"),
        ):
            weight_files = glob(weight_dir + "/*.pt")
            if weight_files:
                latest_wt = max(weight_files, key=os.path.getctime)
                print(f"Loading weight {latest_wt} for {label}")
                net.load_state_dict(torch.load(latest_wt, map_location=device))
                net.eval()

    @staticmethod
    def denorm(X):
        """Invert transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)): map [-1, 1] to [0, 1], clamped."""
        out = (X + 1) / 2
        return out.clamp(0, 1)

    @staticmethod
    def plot_gan_loss(G_loss, D_loss, save_dir="../loss_curves"):
        """Plot both loss histories, save the figure as a timestamped PNG in *save_dir*, then show it."""
        plt.plot(G_loss, label="Generator Loss")
        plt.plot(D_loss, label="Discriminator Loss")
        plt.title("GAN Loss")
        plt.ylabel("BCE Loss")
        # One point is recorded every LOG_INTERVAL iterations (the old label said x10).
        plt.xlabel(f"Iterations (x{GAN.LOG_INTERVAL})")
        plt.legend()

        os.makedirs(save_dir, exist_ok=True)
        plt.savefig(f"{save_dir}/{time.time()}_GAN_loss.png")
        plt.show()

    def plot_loss(self):
        """Plot this GAN's recorded Generator and Discriminator losses."""
        GAN.plot_gan_loss(self.G_loss_overtime, self.D_loss_overtime)

    def _soft_real_labels(self, n, target_device):
        """Return an (n, 1) tensor of soft "real" targets drawn uniformly from real_label_range."""
        low, high = self.real_label_range
        return torch.empty(n, 1, device=target_device).uniform_(low, high)

    def _train_step(self, X_data):
        """Run one Discriminator update followed by one Generator update.

        Args:
            X_data: batch of flattened real images, shape (batch, features),
                already on the training device.

        Returns:
            ``(D_loss, G_loss)`` as Python floats.
        """
        n = X_data.size(0)
        target_device = X_data.device
        # One-sided smoothing: only the Discriminator's real targets are softened.
        real_labels = self._soft_real_labels(n, target_device)
        fake_labels = torch.zeros(n, 1, device=target_device)
        generator_targets = torch.ones(n, 1, device=target_device)

        ### Train Discriminator which maximizes log(D(x)) + log(1 - D(G(z))) ###
        self.D_net.zero_grad()
        D_real_loss = self.criterion(self.D_net(X_data), real_labels)
        D_real_loss.backward()

        noise = torch.randn(n, self.hp.latent_size, device=target_device)
        G_fake_output = self.G_net(noise)
        # detach(): the Discriminator update must not backpropagate into G.
        D_fake_loss = self.criterion(self.D_net(G_fake_output.detach()), fake_labels)
        D_fake_loss.backward()
        self.D_optimizer.step()

        ### Train Generator with the non-saturating loss: maximize log(D(G(z))) ###
        # (minimizing log(1 - D(G(z))) gives vanishing gradients early in training).
        # Targets are hard 1s: label smoothing is only applied to the Discriminator.
        self.G_net.zero_grad()
        G_loss = self.criterion(self.D_net(G_fake_output), generator_targets)
        G_loss.backward()
        self.G_optimizer.step()

        return (D_real_loss + D_fake_loss).item(), G_loss.item()

    def _save_samples(self, real_batch, fixed_noise, fmt_epoch, fmt_idx):
        """Write the current real batch and the Generator's output for fixed_noise as PNG grids."""
        torchvision.utils.save_image(
            real_batch,
            f"{self.hp.output_dir}/{fmt_epoch}_{fmt_idx}_real_samples.png",
            normalize=True,
        )
        # No autograd graph is needed just to render samples.
        with torch.no_grad():
            fake_gen = self.G_net(fixed_noise)
        fake_gen = GAN.denorm(
            fake_gen.reshape(
                fixed_noise.size(0),
                self.hp.input_img_channel,
                self.hp.input_img_size,
                self.hp.input_img_size,
            )
        )
        torchvision.utils.save_image(
            fake_gen,
            f"{self.hp.output_dir}/{fmt_epoch}_{fmt_idx}_fake_samples.png",
            normalize=True,
        )

    def _save_checkpoint(self, fmt_epoch, fmt_idx):
        """Save both networks' state_dicts into their weight directories."""
        torch.save(
            self.D_net.state_dict(),
            self.hp.discriminator_trained_weight_dir
            + f"/dnet_epoch_{fmt_epoch}_iter_{fmt_idx}.pt",
        )
        torch.save(
            self.G_net.state_dict(),
            self.hp.generator_trained_weight_dir
            + f"/gnet_epoch_{fmt_epoch}_iter_{fmt_idx}.pt",
        )

    def fit(
        self,
        train_data_loader,
        save_wt=None,
        save_img_interval=None,
        save_wt_interval=None,
    ):
        """Train both networks for ``self.hp.epochs`` epochs.

        Args:
            train_data_loader: iterable of image batches whose first dimension is
                the batch; each batch is flattened per sample for the Discriminator.
            save_wt: save checkpoints when True; ``None`` uses the ``__init__`` value.
            save_img_interval: save real/generated sample grids every this many
                iterations; ``None`` uses the ``__init__`` value.
            save_wt_interval: save checkpoints every this many iterations;
                ``None`` uses the ``__init__`` value.
                NOTE(review): small intervals write many checkpoint files per
                epoch - check available disk space.
        """
        if save_wt is None:
            save_wt = self.save_wt
        if save_img_interval is None:
            save_img_interval = self.save_img_interval
        if save_wt_interval is None:
            save_wt_interval = self.save_wt_interval

        # Loading checkpoints leaves the nets in eval mode, which would disable the
        # Discriminator's Dropout for the whole run.
        self.G_net.train()
        self.D_net.train()

        # Fixed noise so sample grids are comparable across iterations.
        fixed_noise = torch.randn(
            self.hp.batch_size, self.hp.latent_size, device=device
        )

        for epoch in tqdm(range(self.hp.epochs)):
            d_running_loss, g_running_loss = 0.0, 0.0

            # mini-batch training
            for idx, data in enumerate(train_data_loader):
                # Flatten each image to one vector; use the actual batch size so a
                # smaller final batch (drop_last=False) also works.
                X_data = data.reshape(data.size(0), -1).to(device)

                D_loss, G_loss = self._train_step(X_data)
                d_running_loss += D_loss
                g_running_loss += G_loss
                fmt_epoch = f"{epoch:04d}"
                fmt_idx = f"{idx:04d}"

                if idx % save_img_interval == 0:
                    self._save_samples(data, fixed_noise, fmt_epoch, fmt_idx)

                if idx % self.LOG_INTERVAL == 0:
                    print(
                        f"Discriminator Loss at epoch: {epoch}, iter {idx} = {D_loss}"
                    )
                    print(
                        f"Generator Loss at epoch: {epoch}, iter {idx} = {G_loss}"
                    )
                    # idx is 0-based, so idx + 1 batches have been accumulated.
                    self.D_loss_overtime.append(d_running_loss / (idx + 1))
                    self.G_loss_overtime.append(g_running_loss / (idx + 1))

                # Save checkpoint weights
                if save_wt and idx % save_wt_interval == 0:
                    self._save_checkpoint(fmt_epoch, fmt_idx)

## Training

### Set Hyper-parameters, load data and apply transformations

In [19]:
hp = HyperParameter(
    latent_sz=100,
    in_img_size=64,
    in_img_channel=3,
    data_dir="../data/img_align_celeba/",
    output_dir="../improved_generated_imgs",
    d_trained_wt_dir="../weights/improved_discriminator_trained_weights",
    g_trained_wt_dir="../weights/improved_generator_trained_weights",
    lr=0.0002,
    beta1=0.5,
    epochs=100,
    batch_sz=64,
)

# GAN.fit reads hp.batch_size, so keep the loader's batch size in sync with it
# instead of relying on get_data_loader's own default.
big_dataset = get_data_loader(
    hp.data_dir, data_transform=img_data_transform, batch_size=hp.batch_size
)

### Run the GAN.fit() function to train

In [None]:
# Build the GAN (loading the newest saved checkpoints, if any) and train it.
gan = GAN(hyper_parameter=hp)
gan.fit(train_data_loader=big_dataset)

## Results

In [None]:
# Plot the recorded Generator/Discriminator losses (also saved as a PNG under ../loss_curves).
GAN.plot_gan_loss(gan.G_loss_overtime, gan.D_loss_overtime)