### GANcing Model

The aim of this notebook is to show participants how the ***GAN losses*** are computed and which model ***inputs*** implicitly affect them.

In [None]:
!nvidia-smi

!pip install googledrivedownloader

!git clone https://github.com/VisiumCH/AMLD2020-Dirty-Gancing.git dancing

!rm -rf dancing/data/targets/gianluca

from google_drive_downloader import GoogleDriveDownloader as gdd

gdd.download_file_from_google_drive(file_id='1NnV2OCYV3EchLSDdMTfI3KwGZkwENqvv',
                                    dest_path='./gianluca.zip',
                                    unzip=True)

!mv gianluca_1000 dancing/data/targets/gianluca

gdd.download_file_from_google_drive(file_id='1HO3YQJlumkyA-7RVuR_TCAhqFZg1vY8J',
                                    dest_path='./gianluca_pretrained.zip',
                                    unzip=True)

!mv gianluca_pretrained dancing/checkpoints/gianluca

gdd.download_file_from_google_drive(file_id='1ljUgt6XGvv5pCMhzFChkhe0BYXkHKEoq',
                                    dest_path='./gianluca.mp4',
                                    unzip=False)

!mv gianluca.mp4 dancing/results

!rm -rf dancing/checkpoints/testnotebook

%cd dancing/notebooks/gan_loss_function_input
%pwd

#### Imports

In [None]:
import os
import sys
import time
import torch
import torch.nn as nn
import warnings
import json 
import matplotlib.pyplot as plt

from collections import OrderedDict
from IPython.display import HTML
from base64 import b64encode

sys.path.append('../..')
sys.path.append('../../src/pix2pixHD/')

import src.config.train_opt_notebook as opt
import src.pix2pixHD.util.util as util

from src.utils.torch_utils import get_torch_device
from src.utils.plt_utils import init_figure, plot_current_results
from src.pix2pixHD.data.data_loader import CreateDataLoader
from src.pix2pixHD.models import networks
from src.pix2pixHD.models.base_model import BaseModel
from src.pix2pixHD.models.models import create_model
from src.pix2pixHD.util.image_pool import ImagePool

warnings.filterwarnings('ignore')

device = get_torch_device()

## How to make people Dance

The model that we implemented for this workshop is taken from [Chan et al.](https://arxiv.org/pdf/1808.07371.pdf), who first developed this architecture at UC Berkeley.

![GAN Architecture](imgs/gan_architecture.png)

The main idea behind the model is relatively simple. We want to teach a GAN to generate pictures of a person in a given environment, conditioned on the pose that this person should have. As in a conditional GAN, we have two models that compete against each other:

- The **Generator**, which receives a target **pose** as input.
- The **Discriminator**, which receives the target **pose** as well as an image of the subject, which can be real or generated.

To train these two models, the loss functions that were designed are a bit more complex. You will implement three of them:

- The **Adversarial Loss** which represents the competition between **G** and **D**.
- The **Feature Matching Loss** which represents how similar the generated images are to the real images.
- The **Perceptual Loss** which represents how realistic the generated images are.

# Loss Functions

Before initializing the model class, we define the losses used to train the GAN.

## Adversarial Loss (LSGAN)

<blockquote>X. Mao, Q. Li, H. Xie, R. Y. Lau, and Z. Wang. Least Squares Generative Adversarial Networks. 2016. <a href="https://arxiv.org/abs/1611.04076">https://arxiv.org/abs/1611.04076</a></blockquote>

<p>In the original GAN, the discriminator acts as a classifier. The sigmoid of its output is seen as the probability that its input is real. The adversarial loss is based on a binary cross entropy loss. The objective of the generator is for fake samples to be classified as real, i.e. to elicit a high discriminator output.</p>

<p>In LSGAN, the discriminator acts as a regressor. The two classes are assigned a code, for instance 0 for fake and 1 for real. The output of the discriminator is seen as an estimation of its input class. The adversarial loss is based on a mean squared error (MSE). The objective of the generator is for fake samples to elicit a discriminator output close to 1.</p>

<img src="../images/gan_losses.png" width="500px">

<p>What is the advantage of using such a loss when training a GAN?</p>

<details>
<p>The core intuition is that, in a vanilla GAN, some fake samples can be classified as real by the discriminator even if they are not close to the real data. With the cross-entropy loss, the gradient vanishes for these samples, and the generator learns barely anything from them.</p>

<p>Using the mean squared error, fake samples that are 'on the right side of the decision boundary' (closer to real data than to typical fake data) are still penalized if they are too far from the real data.</p>

<p>The two plots below illustrate that idea in a simple setup. The real data is one-dimensional and its underlying distribution is a Gaussian. The generator learns the parameters of this Gaussian, and the discriminator output is an affine transformation of the sample value.</p>

<img src="../images/gan_loss_crossentropy.png" width="500px">
<img src="../images/gan_loss_mse.png" width="500px">
</details>
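
The vanishing-gradient argument can be checked numerically. The following is a minimal sketch in plain Python, assuming `d` is the raw (pre-sigmoid) discriminator output for a fake sample; the function names are illustrative and not part of the workshop code. It compares the gradient the generator receives from the cross-entropy loss with the one it receives from the least-squares loss.

```python
import math


def sigmoid(d):
    """Logistic function applied to a raw discriminator output."""
    return 1.0 / (1.0 + math.exp(-d))


def bce_generator_grad(d):
    """Derivative w.r.t. d of -log(sigmoid(d)), the cross-entropy loss with target 'real'."""
    return -(1.0 - sigmoid(d))


def mse_generator_grad(d):
    """Derivative w.r.t. d of (d - 1)^2, the least-squares loss with target code 1."""
    return 2.0 * (d - 1.0)


# Fake samples with increasingly confident "real" scores
for d in (0.0, 1.0, 5.0, 10.0):
    print(f"d={d:5.1f}  BCE grad={bce_generator_grad(d):+.5f}  MSE grad={mse_generator_grad(d):+.1f}")
```

For large `d` the cross-entropy gradient shrinks towards zero, while the least-squares gradient keeps growing with the distance from the target code 1.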

### To do: Fill in the missing parts of the Adversarial Loss

In [None]:
def GAN_Loss(x,
            target_is_real):
    """ Adversarial loss function for vanilla GAN or LS-GAN
    
    Args:
        x (iterable of iterables of torch.Tensors): discriminator activations
        target_is_real (bool): whether x should be treated as fake or real samples
    
    Returns:
        torch.Tensor: adversarial loss value for x
    """
    
    criterion = nn.MSELoss()
    target_value = ... # Implement me

    loss = 0
    for activations in x:
        pred = activations[-1].to(device)
        target_tensor = torch.FloatTensor(pred.size()).requires_grad_(False) \
                             .fill_(target_value).to(device)
        
        loss += ... # Implement me
    return loss

## Feature Matching Loss

<blockquote>Salimans, T., Goodfellow, I., Zaremba, W., Cheung, V., Radford, A., and Chen, X. Improved Techniques for Training GANs. In Advances in Neural Information Processing Systems, 2016. <a href="https://arxiv.org/abs/1606.03498">https://arxiv.org/abs/1606.03498</a></blockquote>

The purpose of the **Feature Matching Loss** is to help the generator create images that look similar to the real images of the **training set**. The idea is to compare the **feature values** at the different layers of **D** when it is fed real images and when it is fed fake images. If the generated images are close to the real ones, the features of **D** will take approximately the same values, and the loss will be small.
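
To make the idea concrete, here is a toy sketch in plain Python. It assumes a single discriminator whose activations are given as a list of flat feature vectors, with the final prediction as the last entry. The helper names and numbers are made up for illustration, and the plain average used here stands in for the weighting factors of the notebook's own function.

```python
def l1_mean(a, b):
    """Mean absolute difference between two equally long feature vectors."""
    return sum(abs(x - y) for x, y in zip(a, b)) / len(a)


def toy_feature_matching(feats_fake, feats_real):
    """Average L1 distance over intermediate layers; the last entry (final prediction) is skipped."""
    n_layers = len(feats_fake) - 1
    return sum(l1_mean(feats_fake[j], feats_real[j]) for j in range(n_layers)) / n_layers


# Hypothetical activations: two intermediate layers plus the final prediction
feats_real = [[0.2, 0.4], [1.0, 0.0, 0.5], [0.9]]
close_fake = [[0.2, 0.5], [1.0, 0.1, 0.5], [0.1]]
far_fake = [[0.9, -0.4], [0.0, 1.0, -0.5], [0.1]]

print("close fake:", toy_feature_matching(close_fake, feats_real))
print("far fake:  ", toy_feature_matching(far_fake, feats_real))
```

The fake whose intermediate features stay near those of the real image gets the smaller loss, even though both fakes receive the same final prediction.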

### To do: Fill in the missing parts of the Feature Matching Loss

In [None]:
def Feat_Match_loss(pred_fake,
                    pred_real):
    """ Feature Matching loss for Pix2Pix
    
    Args:
        pred_fake (iterable of iterables of torch.Tensors): discriminator activations for fake samples
        pred_real (iterable of iterables of torch.Tensors): discriminator activations for real samples
    
    Returns:
        torch.Tensor: feature matching loss value for pred_fake and pred_real
    """

    criterion = nn.L1Loss()
    
    D_weights = 1 / opt.num_D
    feat_weights = 4 / (opt.n_layers_D + 1)
    factor = D_weights * feat_weights * opt.lambda_feat
    
    loss = 0
    for i in range(opt.num_D):
        for j in range(len(pred_fake[i]) - 1):
            loss += factor * ... # Implement me

    return loss


## Perceptual Loss (based on VGG19)

<blockquote>J. Johnson, A. Alahi, and L. Fei-Fei. Perceptual losses for real-time style transfer and super-resolution. In ECCV, 2016. <a href="https://arxiv.org/abs/1603.08155">https://arxiv.org/abs/1603.08155</a></blockquote>

<img src="../images/vgg19.png" width="600px">

The purpose of the **Perceptual Loss** is to help the generator create images that look **realistic** in general. To compute this loss, we feed the generated and the real images to a **pretrained VGG19** network and compare the values of the features at different layers. If the generated images look real, the value of the loss will be small.
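
The per-layer comparison is combined into one number by a weighted sum. The sketch below, in plain Python, uses the five layer weights of the `VGGLoss` class in this notebook. The per-layer distances are made-up numbers standing in for the L1 distances between VGG features, and `weighted_perceptual` is an illustrative name.

```python
# Layer weights as used by VGGLoss: shallow layers count least, the deepest layer most
LAYER_WEIGHTS = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]


def weighted_perceptual(layer_distances, weights=LAYER_WEIGHTS):
    """Weighted sum of per-layer feature distances."""
    return sum(w * d for w, d in zip(weights, layer_distances))


# Hypothetical distances: the same mismatch placed in a shallow or in a deep layer
shallow_mismatch = weighted_perceptual([1.0, 0.0, 0.0, 0.0, 0.0])
deep_mismatch = weighted_perceptual([0.0, 0.0, 0.0, 0.0, 1.0])

print("mismatch in first layer:", shallow_mismatch)
print("mismatch in last layer: ", deep_mismatch)
```

With these weights, a mismatch in the deepest features is penalized 32 times more than the same mismatch in the shallowest ones.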

### To do: Fill in the missing parts of the Perceptual Loss

In [None]:
class VGGLoss(nn.Module):
    def __init__(self, gpu_ids):
        super(VGGLoss, self).__init__()
        if len(gpu_ids) > 0:
            self.vgg = networks.Vgg19().cuda()
        else:
            self.vgg = networks.Vgg19()
        self.criterion = nn.L1Loss()
        self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]

    def forward(self, x, y):
        """ Perceptual loss for Pix2Pix
        Args:
            x, y (torch.Tensor): samples to compare

        Returns:
            torch.Tensor: perceptual loss value for x, y
        """ 
        x_vgg, y_vgg = ... # Implement me
        
        loss = 0
        for i in range(len(x_vgg)):
            loss += self.weights[i] * ... # Implement me
        return loss

## Initialize Model Class

Here the model class is initialized.

First, the generator (``self.netG``) and the discriminator (``self.netD``) are created in the ``initialize`` function.

Among the config variables used to define the networks, two are worth noting:

- ``netG_input_nc`` is the number of input channels fed to the generator. In other words, it is the number of channels used to describe a pose, which corresponds to the number of joints. The poses are fed as images with ``netG_input_nc`` channels.
- ``netD_input_nc`` is the number of input channels fed to the discriminator: the channels defining the pose plus those defining the image, i.e. the 3 RGB channels are added to the pose channels.
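
The channel bookkeeping can be sketched in plain Python. The example assumes a tiny 2×2 label map and `label_nc = 3`; both are hypothetical values chosen for illustration, and the real value comes from the workshop config. The `one_hot` helper mimics what `encode_input` does with `scatter_`, without using PyTorch.

```python
def one_hot(label_map, num_classes):
    """Turn an H x W map of integer labels into num_classes binary H x W channels."""
    return [[[1.0 if label == c else 0.0 for label in row] for row in label_map]
            for c in range(num_classes)]


label_map = [[0, 1],
             [2, 1]]   # hypothetical pose labels, one integer per pixel
label_nc = 3           # hypothetical number of label classes
output_nc = 3          # RGB channels of the image

pose_channels = one_hot(label_map, label_nc)
netG_input_nc = len(pose_channels)           # generator sees the pose channels only
netD_input_nc = netG_input_nc + output_nc    # discriminator sees pose + image channels

print("generator input channels:    ", netG_input_nc)
print("discriminator input channels:", netD_input_nc)
```

Each pixel is set to 1 in exactly one pose channel, and the discriminator input is the pose channels concatenated with the image channels.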

Second, the forward pass is defined in ``forward``.

The forward pass is the essential step: it computes all the losses that are needed to train the networks.

#### To do: Implement the missing parts of the forward pass

In [None]:
class Gancing(BaseModel):
    
    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        if opt.resize_or_crop != 'none' or not opt.isTrain: # when training at full res this causes OOM
            torch.backends.cudnn.benchmark = True
        self.isTrain = opt.isTrain

        ######check for gpu
        self.gpu_ids = opt.gpu_ids
        self.device = 'cuda' if (len(self.gpu_ids) > 0) else 'cpu'
        
        input_nc = opt.label_nc if opt.label_nc != 0 else opt.input_nc

        # Generator network
        netG_input_nc = input_nc        
        self.netG = networks.define_G(netG_input_nc, opt.output_nc, opt.ngf, opt.netG, 
                                      opt.n_downsample_global, opt.n_blocks_global, opt.n_local_enhancers, 
                                      opt.n_blocks_local, opt.norm, gpu_ids=self.gpu_ids)        

        # Discriminator network
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            netD_input_nc = input_nc + opt.output_nc
            self.netD = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt.norm, use_sigmoid, 
                                          opt.num_D, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids)


        # load networks
        if not self.isTrain or opt.continue_train or opt.load_pretrain:
            pretrained_path = '' if not self.isTrain else opt.load_pretrain
            self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)            
            if self.isTrain:
                self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path)         

        # set loss functions and optimizers
        if self.isTrain:            
            
            # Names so we can breakout loss
            self.loss_names = ['G_GAN','G_GAN_Feat','G_VGG','D_real', 'D_fake']
            
            # Define VGGloss so we don't load VGG at each batch
            self.VGG_Loss = VGGLoss(opt.gpu_ids)

            # initialize optimizers
            params = list(self.netG.parameters())
            self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))                            

            # optimizer D                        
            params = list(self.netD.parameters())    
            self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
            
    
    def encode_input(self, label_map, real_image): 
    
        # create one-hot vector for the pose
        size = label_map.size()
        oneHot_size = (size[0], opt.label_nc, size[2], size[3])
        input_label = torch.FloatTensor(torch.Size(oneHot_size)).zero_().to(self.device)
        input_label = input_label.scatter_(1, label_map.data.long().to(self.device), 1.0)
        if opt.data_type == 16:
            input_label = input_label.half()

        # real images for training
        if real_image is not None:
            real_image = real_image.to(self.device)

        return input_label, real_image

    def forward(self, label, image, infer=False):
        
        
        # Encode Inputs
        input_label, real_image = self.encode_input(label, image)  

        # Fake Generation
        fake_image = self.netG.forward(input_label.float())
        
        ### Detection

        # discriminator prediction (real or fake) for the real image, conditioned on the pose
        pred_real = self.netD.forward(torch.cat((input_label, real_image.detach()), dim=1))
        
        # discriminator prediction (real or fake) for the generated image, detached from the generator
        pred_fake_D = self.netD.forward(torch.cat((input_label, fake_image.detach()), dim=1))

        # Get fake Generation for backprop in Generator
        pred_fake_G = self.netD.forward(torch.cat((input_label, fake_image), dim=1))

        ##################### Discriminator Loss Based on Fake Images #####################

        # Try to find out by yourself the ground truth to compute the GAN loss for the discriminator for fake images
        GAN_label_Discriminator_Fake = ... # Implement me
        
        loss_D_fake = GAN_Loss(pred_fake_D, GAN_label_Discriminator_Fake)  

        ##################### Discriminator Loss Based on Real Images #####################
        
        # Try to find out by yourself the ground truth to compute the GAN loss for the discriminator for real images
        GAN_label_Discriminator_Real =  ... # Implement me

        loss_D_real = GAN_Loss(pred_real, GAN_label_Discriminator_Real)

        ##################### Generator Loss #####################       
        
        # Try to find out by yourself the ground truth to compute the GAN loss for the generator
        GAN_label_Generator =  ... # Implement me

        loss_G_GAN = GAN_Loss(pred_fake_G, GAN_label_Generator)  
        
        ##################### Feature Matching Loss ##################### 
        
        # Try to find out by yourself what should be the inputs of the Feature Matching loss
        loss_G_GAN_Feat = Feat_Match_loss(..., ...) # Implement me

        ##################### VGG Loss ##################### 
        loss_G_VGG = 0

        # Try to find out by yourself the inputs needed to calculate the VGG loss using the L1 function to ensure VGG classifies fake and real images equally
        input1, input2 = ... , ...

        loss_G_VGG = self.VGG_Loss(input1, input2) * opt.lambda_feat

        # only return the generated image when infer=True, to save bandwidth
        losses_to_return=[[loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_D_real, loss_D_fake], None if not infer else fake_image]

        return losses_to_return

    def inference(self, label):
        # Encode Inputs
        # encode_input requires both arguments; no real image is available at inference time
        input_label, _ = self.encode_input(label, None)

        # Fake Generation
        fake_image = self.netG.forward(input_label)
        return fake_image


    def save(self, which_epoch):
        self.save_network(self.netG, 'G', which_epoch, self.gpu_ids)
        self.save_network(self.netD, 'D', which_epoch, self.gpu_ids)


## Training Loop and Inference

Now that we have defined our model, we can train it and make people dance!

#### Training loop

In [None]:
dir_name = os.path.abspath('')

def update_opt(opt, target_dir, run_name):
    opt.dataroot = os.path.join(target_dir, 'train')
    opt.name = run_name
    opt.checkpoints_dir = os.path.join(dir_name, "../../checkpoints")

    if os.path.isdir(os.path.join(opt.checkpoints_dir, run_name)):
        print("Run already exists, will try to resume training")
        opt.load_pretrain = os.path.join(dir_name, "../../checkpoints", run_name)

    if device == torch.device('cpu'):
        opt.gpu_ids = []
    else:
        os.environ['CUDA_VISIBLE_DEVICES'] = "0"

    opt.temporal_smoothing = False

    return opt


def train_pose2vid(target_dir, run_name):
    import src.config.train_opt_notebook as opt

    opt = update_opt(opt, target_dir, run_name)

    os.makedirs(os.path.join(opt.checkpoints_dir, run_name), exist_ok=True)
    iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.json')
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#training images = %d' % dataset_size)

    if opt.load_pretrain != '':
        with open(iter_path, 'r') as f:
            iter_json = json.load(f)
    else:
        iter_json = {'start_epoch': 1, 'epoch_iter': 0}

    start_epoch = iter_json['start_epoch']
    epoch_iter = iter_json['epoch_iter']
    total_steps = (start_epoch - 1) * dataset_size + epoch_iter
    display_delta = total_steps % opt.display_freq
    print_delta = total_steps % opt.print_freq
    save_delta = total_steps % opt.save_latest_freq

    model = Gancing()
    model.initialize(opt)
    model = model.to(device)
    
    for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        if epoch != start_epoch:
            epoch_iter = epoch_iter % dataset_size
        for i, data in enumerate(dataset, start=epoch_iter):
            iter_start_time = time.time()
            total_steps += opt.batchSize
            epoch_iter += opt.batchSize

            ############## Forward Pass ######################

            losses, generated = model(data['label'], data['image'], infer=True)

            # sum per device losses
            losses = [torch.mean(x) if not isinstance(x, int) else x for x in losses]
            loss_dict = dict(zip(model.loss_names, losses))

            # calculate final loss scalar
            loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5
            loss_G = loss_dict['G_GAN'] + loss_dict.get('G_GAN_Feat', 0) + loss_dict.get('G_VGG', 0)

            ############### Backward Pass ####################
            # update generator weights
            model.optimizer_G.zero_grad()
            loss_G.backward()
            model.optimizer_G.step()

            # update discriminator weights
            model.optimizer_D.zero_grad()
            loss_D.backward()
            model.optimizer_D.step()


            ############## Display results and errors ##########

            print(f"Epoch {epoch} batch {i}:")
            print(f"loss_D: {loss_D}, loss_G: {loss_G}")
            print(f"loss_D_fake: {loss_dict['D_fake']}, loss_D_real: {loss_dict['D_real']}")
            print(f"loss_G_GAN {loss_dict['G_GAN']}, loss_G_GAN_Feat: {loss_dict.get('G_GAN_Feat', 0)}, loss_G_VGG: {loss_dict.get('G_VGG', 0)}\n")

            ### print out errors
            if total_steps % opt.print_freq == print_delta:
                errors = {k: v.item() if not isinstance(v, int) else v for k, v in loss_dict.items()}
                t = (time.time() - iter_start_time) / opt.batchSize


            ### display output images
            visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
                                    ('synthesized_image', util.tensor2im(generated.data[0])),
                                    ('real_image', util.tensor2im(data['image'][0]))])
            fig, axes = init_figure()
            plot_current_results(visuals, fig, axes)

            ### save latest model
            if total_steps % opt.save_latest_freq == save_delta:
                print('saving the latest model (epoch %d, total_steps %d)' % (epoch, total_steps))
                model.save('latest')
                iter_json['start_epoch'] = epoch
                iter_json['epoch_iter'] = epoch_iter
                with open(iter_path, 'w') as f:
                    json.dump(iter_json, f)

            if epoch_iter >= dataset_size:
                break

        # end of epoch
        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))

        ### save model for this epoch
        if epoch % opt.save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps))
            model.save('latest')
            model.save(epoch)
            iter_json['start_epoch'] = epoch + 1
            iter_json['epoch_iter'] = 0
            with open(iter_path, 'w') as f:
                json.dump(iter_json, f)
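
The resume bookkeeping of the training loop can be looked at in isolation. This is a minimal sketch in plain Python with illustrative helper names. It assumes progress is resumed whenever `iter.json` exists, whereas the loop above checks `opt.load_pretrain` instead.

```python
import json
import os
import tempfile


def load_progress(iter_path):
    """Read saved progress if the file exists, else start from epoch 1."""
    if os.path.isfile(iter_path):
        with open(iter_path, 'r') as f:
            return json.load(f)
    return {'start_epoch': 1, 'epoch_iter': 0}


def steps_done(progress, dataset_size):
    """Total number of samples already processed, as computed in the training loop."""
    return (progress['start_epoch'] - 1) * dataset_size + progress['epoch_iter']


with tempfile.TemporaryDirectory() as tmp:
    iter_path = os.path.join(tmp, 'iter.json')
    fresh = load_progress(iter_path)            # no file yet: fresh run
    with open(iter_path, 'w') as f:
        json.dump({'start_epoch': 3, 'epoch_iter': 40}, f)
    resumed = load_progress(iter_path)          # file present: resume

fresh_steps = steps_done(fresh, dataset_size=100)
resumed_steps = steps_done(resumed, dataset_size=100)
print(fresh_steps, resumed_steps)
```

A run interrupted 40 samples into epoch 3 of a 100-image dataset resumes with 240 samples counted as already processed.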

In [None]:
target_dir = '../../data/targets/gianluca/'
run_name = 'testnotebook'

train_pose2vid(target_dir, run_name)

#### Inference of pretrained model

We can quickly see that the training is working. The model slowly grasps the background of the picture, and then understands how to generate the missing figure. Let's see what happens later in the training!

In [None]:
target_dir = '../../data/targets/gianluca/'
run_name = 'gianluca'

train_pose2vid(target_dir, run_name)

#### Video Transfer

For the moment we have been training our model to recreate a picture of a person based on the pose of this same person. But since the pose can be extracted from the picture of anyone, we can take a video of someone dancing, extract the pose of the dancer on each of the frames and ask the model to create a new frame based on this pose!
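
The frame-by-frame procedure can be sketched as follows. The two callables are hypothetical stand-ins: `extract_pose` for a pose estimator and `generate_frame` for the trained generator. Neither is a function of this repository, and the stubs only show the data flow.

```python
def transfer_video(source_frames, extract_pose, generate_frame):
    """Map every source frame to a generated frame of the target person in the same pose."""
    return [generate_frame(extract_pose(frame)) for frame in source_frames]


# Hypothetical stand-ins so that the sketch runs without any model
def stub_extract_pose(frame):
    return {'pose_of': frame}


def stub_generate_frame(pose):
    return f"target person posed like {pose['pose_of']}"


result = transfer_video(['frame_0', 'frame_1'], stub_extract_pose, stub_generate_frame)
print(result)
```

The source dancer only contributes the poses; the appearance of every output frame comes from the generator trained on the target person.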

Here is an example of what it could look like:

In [None]:
# read the video bytes and close the file right away
with open('../../results/gianluca.mp4', 'rb') as f:
    mp4 = f.read()
data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
HTML("""
<video controls>
      <source src="%s" type="video/mp4">
</video>
""" % data_url)