### Notebook Teaching Objective

The aim of this notebook is to show the participants how the ***GAN losses*** are computed and what the ***inputs*** of the models are, since these implicitly affect the losses.

Compared to the other notebooks, this one also introduces the VGG loss and the Feature Matching loss.

##### Analysis of original function

In the forward pass the different losses used for the training are defined. The most important components of the forward pass are the following:

Networks to be trained to generate and discriminate
1. ``self.netG``: the generator network
2. ``self.netD``: the discriminator network

Loss Discriminator
1. ``loss_D_fake``, discriminator loss from discriminating fake generated images: -log(1-D(G(z)))
2. ``loss_D_real``, discriminator loss from discriminating real images sampled from data: -log(D(x))

Loss Generator
1. ``loss_G_GAN``, generator loss from the "Fake Passability Loss", which labels the fake images as real
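
As a quick numeric check of the three terms listed above, the following sketch evaluates them for made-up discriminator outputs. The probabilities `d_real` and `d_fake` are hypothetical values, not outputs of the notebook's networks. The log form used here matches the formulas in the list; a least-squares criterion would replace the logarithms with squared errors.

```python
import math

def loss_d_fake(d_fake):
    # discriminator loss on a generated image: -log(1 - D(G(z)))
    return -math.log(1.0 - d_fake)

def loss_d_real(d_real):
    # discriminator loss on a real image: -log(D(x))
    return -math.log(d_real)

def loss_g_gan(d_fake):
    # generator loss: the fake image is labelled as real, -log(D(G(z)))
    return -math.log(d_fake)

# hypothetical discriminator outputs (probability that the input is real)
d_real, d_fake = 0.9, 0.2

print(round(loss_d_real(d_real), 4))  # 0.1054
print(round(loss_d_fake(d_fake), 4))  # 0.2231
print(round(loss_g_gan(d_fake), 4))   # 1.6094
```

The generator term is large here because the discriminator is confident that the generated image is fake; it shrinks as `d_fake` approaches 1.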


***Original Function***

In [None]:
def forward(self, label, inst, image, feat, infer=False):
    # Encode Inputs
    input_label, inst_map, real_image, feat_map = self.encode_input(label, inst, image, feat)  

    # Fake Generation
    if self.use_features:
        if not self.opt.load_features:
            feat_map = self.netE.forward(real_image, inst_map)                     
        input_concat = torch.cat((input_label, feat_map), dim=1)                        
    else:
        input_concat = input_label
    # TODO----------------------#    
    fake_image = self.netG.forward(input_concat.float())

    # Fake Detection and Loss
    pred_fake_pool = self.discriminate(input_label, fake_image, use_pool=True)
    loss_D_fake = self.criterionGAN(pred_fake_pool, False)        

    # Real Detection and Loss        
    pred_real = self.discriminate(input_label, real_image)
    loss_D_real = self.criterionGAN(pred_real, True)

    # GAN loss (Fake Passability Loss)        
    pred_fake = self.netD.forward(torch.cat((input_label, fake_image), dim=1))        
    loss_G_GAN = self.criterionGAN(pred_fake, True)               

    # GAN feature matching loss
    loss_G_GAN_Feat = 0
    if not self.opt.no_ganFeat_loss:
        feat_weights = 4.0 / (self.opt.n_layers_D + 1)
        D_weights = 1.0 / self.opt.num_D
        for i in range(self.opt.num_D):
            for j in range(len(pred_fake[i])-1):
                loss_G_GAN_Feat += D_weights * feat_weights * \
                    self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach()) * self.opt.lambda_feat

    # VGG feature matching loss
    loss_G_VGG = 0
    if not self.opt.no_vgg_loss:
        loss_G_VGG = self.criterionVGG(fake_image, real_image) * self.opt.lambda_feat

    # Only return the fake_B image if necessary to save BW
    return [ self.loss_filter( loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_D_real, loss_D_fake ), None if not infer else fake_image ]

To make the notebook easier to follow, the functions have been simplified and commented.

``forward`` function modifications:

- ``netE``: not in use

- ``netG`` and ``netD`` become standalone objects (``generator_model`` and ``discriminator_model``) which are initialized beforehand

- ``self.encode_input`` becomes a separate function placed outside the forward pass, to show the audience what goes into the model and what needs to be encoded. The encoding itself is probably not essential, but an independent function is needed if an example of the forward pass is to be run.

The forward pass is the most interesting part. Here we can show it detached from the class itself, and then, when we show the training loop, highlight where the model calls it.



*Ideas to clean up the function more*:

- ``criterionGAN`` becomes a function

- condition of training the GAN based on the generator and discriminator features' neurons

- condition of training the GAN using VGG to recognize images


TO DO:


1) Maybe further simplify the functions based on the input parameters of ``opt``.

2) Load the poses and images

3) Create the cell where the network's inference ability is shown at different training phases.

4) Simplify function for criterionGAN

5) Simplify criterionFeat and criterionVGG

6) Double check with Gaetan and Thibault about the content and what should be added, and learn the most important ``opt`` parameters.

Qs: 

- Is the explanation of the input dimension correct?
- Does the encode_input add noise to label_map?
- Weights in VGG loss and loop
- Feature matching loss part

***All below is the actual notebook which will be used in the workshop***

In [2]:
#necessary package imports
import numpy as np
import torch
import os
from torch.autograd import Variable
import sys
sys.path.append('../..')
from src.pix2pixHD.util.image_pool import ImagePool
from src.pix2pixHD.models.base_model import BaseModel
from src.pix2pixHD.models import networks
import src.config.train_opt_notebook as opt

## Forward Pass Pix2Pix

The aim of this notebook is to show how the ***GAN losses*** are computed and what the ***inputs*** of the models involved in the training are.

VGG and Feature Matching Loss are introduced and their function is explained.

#### Initialize Models

Here the generator and the discriminator models are initialized, since they are going to be used in the training process.

Among the config variables used to define the networks, the following are worth noting:

- ``netG_input_nc`` is the number of input channels fed into the generator. It is set to ``opt.label_nc``; in other words, it is the number of coordinates used to define a pose.
- ``use_sigmoid`` is set to ``opt.no_lsgan``: when it is *True*, a sigmoid is applied to the discriminator output and the least-squares loss is **not** used; when it is *False*, the setup corresponds to the least-squares loss.
- ``netD_input_nc`` is the number of input channels going into the discriminator, i.e. the channels defining a pose plus those defining the image.

As the printed model shows, the structure of the generator and the discriminator is quite complex.

In [4]:
netG_input_nc = opt.label_nc
generator_model=networks.define_G(netG_input_nc, opt.output_nc, opt.ngf, opt.netG, 
                                      opt.n_downsample_global, opt.n_blocks_global, opt.n_local_enhancers, 
                                      opt.n_blocks_local, opt.norm, gpu_ids=opt.gpu_ids)   
use_sigmoid = opt.no_lsgan
netD_input_nc = netG_input_nc + opt.output_nc
discriminator_model=networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt.norm, use_sigmoid, 
                                          opt.num_D, not opt.no_ganFeat_loss, gpu_ids=opt.gpu_ids)

GlobalGenerator(
  (model): Sequential(
    (0): ReflectionPad2d((3, 3, 3, 3))
    (1): Conv2d(18, 64, kernel_size=(7, 7), stride=(1, 1))
    (2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (3): ReLU(inplace)
    (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (5): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (6): ReLU(inplace)
    (7): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (8): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (9): ReLU(inplace)
    (10): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (11): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (12): ReLU(inplace)
    (13): Conv2d(512, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (14): InstanceNorm2d(1024, eps=1e-05, momentum=0.1, affine=False, track_run

#### Inputs

(This part might be removed) To feed the data to the network, the inputs must first be encoded as tensors so that they can be used for training.

In [5]:
def encode_input(label_map, real_image, infer=False): 
    # `infer` is kept for compatibility; for inference, wrap the call in torch.no_grad()

    # there is no `self` outside the class, so the device is derived from opt.gpu_ids
    device = torch.device('cuda:{}'.format(opt.gpu_ids[0]) if opt.gpu_ids else 'cpu')

    # create one-hot vector for the pose
    size = label_map.size()
    oneHot_size = (size[0], opt.label_nc, size[2], size[3])
    input_label = torch.zeros(oneHot_size, device=device)
    input_label = input_label.scatter_(1, label_map.long().to(device), 1.0)
    if opt.data_type == 16:
        input_label = input_label.half()

    # real images for training
    if real_image is not None:
        real_image = real_image.to(device)


    return input_label, real_image
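
To make the encoding step more concrete, here is a minimal, self-contained sketch of the one-hot encoding on a toy label map, followed by the channel-wise concatenation used for the discriminator input. The values of `label_nc` and `output_nc` and the 2x2 tensors are hypothetical; in the notebook the channel counts come from `opt`.

```python
import torch

label_nc = 4   # hypothetical number of label classes (read from opt.label_nc in the notebook)
output_nc = 3  # hypothetical number of image channels (RGB)

# toy label map: batch of 1, a single channel holding one class index per pixel, 2x2 pixels
label_map = torch.tensor([[[[0, 1],
                            [3, 1]]]])

# one-hot encoding: write 1.0 into the channel selected by each class index
one_hot = torch.zeros(1, label_nc, 2, 2)
one_hot.scatter_(1, label_map.long(), 1.0)

print(one_hot.shape)         # torch.Size([1, 4, 2, 2])
print(one_hot[0, :, 1, 0])   # pixel (1, 0) holds class 3 -> tensor([0., 0., 0., 1.])

# the discriminator input stacks the encoded label and the image along the channel axis
image = torch.rand(1, output_nc, 2, 2)
d_input = torch.cat((one_hot, image), dim=1)
print(d_input.shape)         # torch.Size([1, 7, 2, 2])
```

The last shape shows why the discriminator needs `label_nc + output_nc` input channels.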

To illustrate the GAN's inference ability, a pose is loaded to show how well the network can generate dancing individuals.

In [9]:

label = None  # TODO: load the label (the poses)
image = None  # TODO: load the image (the real images)

# encode the inputs once the two placeholders above have been filled in
if label is not None:
    input_label, real_image = encode_input(label, image)

Before exploring the forward pass, it is useful to inspect how the losses are calculated. Hence, the functions that compute the GAN loss, the VGG loss and the Feature Matching loss are reported below.

In [None]:
def GAN_loss():
    pass


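def VGG_loss(x,y):
    
    #loss criteria
    criterion = torch.nn.L1Loss()
    
    #weights of the five VGG feature levels: deeper levels count more
    weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]
    
    #Vgg19 is a module: instantiate it, then run it on each image separately
    #(in practice it should be created once outside the function, not at every call)
    vgg = networks.Vgg19().to(x.device)
    x_vgg, y_vgg = vgg(x), vgg(y)
    
    loss = 0
    for i in range(len(x_vgg)):
        loss += weights[i] * criterion(x_vgg[i], y_vgg[i].detach())
    return loss


def Feat_Match_loss(x, y):
    
    #L1 distance between the discriminator features of the fake and of the real image
    criterion=torch.nn.L1Loss()
    
    return criterion(x, y)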
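
`GAN_loss` is left as a stub above. One possible way to fill it in is sketched below. It assumes a least-squares objective and a multi-scale discriminator that returns one list of tensors per scale, with the prediction as the last element (as the indexing `pred_fake[i][j]` in the original `forward` suggests). The name `lsgan_loss` and the toy tensors are illustrative and not part of the repository.

```python
import torch
import torch.nn.functional as F

def lsgan_loss(preds, target_is_real):
    """Least-squares GAN loss summed over the scales of a multi-scale discriminator.

    preds: list with one entry per scale; each entry is a list of tensors
           whose last element is the discriminator prediction.
    """
    target_value = 1.0 if target_is_real else 0.0
    loss = 0.0
    for scale_outputs in preds:
        pred = scale_outputs[-1]
        # the target has the same shape as the prediction and is filled with 1 (real) or 0 (fake)
        target = torch.full_like(pred, target_value)
        loss = loss + F.mse_loss(pred, target)
    return loss

# toy predictions of two discriminator scales (hypothetical values)
preds = [[torch.full((1, 1, 4, 4), 0.8)], [torch.full((1, 1, 2, 2), 0.5)]]

print(lsgan_loss(preds, True).item())   # approximately 0.04 + 0.25 = 0.29
print(lsgan_loss(preds, False).item())  # approximately 0.64 + 0.25 = 0.89
```

With this form, the same function serves all three GAN terms: only the prediction and the `target_is_real` flag change.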

By going through the forward pass, the participants will get acquainted with the losses involved in the GAN and will understand how the previously encoded inputs are used to generate images and to discriminate fake from real ones. Additionally, the Feature Matching Loss and the VGG Loss will be presented.

In [6]:
def forward(input_label, real_image, infer=False):

    ### Fake Generation
    
    #the encoded pose map is the only conditioning input of the generator
    fake_image = generator_model.forward(input_label.float())

    ### Fake Detection
    
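    #concatenate the condition input_label used to generate the image and the fake image 
    input_concat = torch.cat((input_label, fake_image.detach()), dim=1) 
    
    #draw from the pool of previously generated images (a freshly created pool simply returns its input)
    fake_query = ImagePool(opt.pool_size).query(input_concat)
        
    #predict whether the generated image is real or fake
    pred_fake_pool = discriminator_model.forward(fake_query)
    
    #predict whether the real image is real or fake; the discriminator takes a single tensor,
    #so the condition and the real image are concatenated along the channel dimension
    pred_real = discriminator_model.forward(torch.cat((input_label, real_image), dim=1))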
    
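    ### Discriminator Loss Based on Fake Images
    
    #criterion shared by the GAN losses below; this assumes that networks.GANLoss in this repository
    #keeps the upstream pix2pixHD signature (in practice it should be created once, outside forward)
    tensor = torch.cuda.FloatTensor if opt.gpu_ids else torch.FloatTensor
    criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=tensor)
    
    # label_loss=  .....the ground truth to compute the discriminator loss for fake images......
    
    loss_D_fake = criterionGAN(pred_fake_pool, label_loss)      
    
    ### Discriminator Loss Based on Real Images
    
    # label_loss=  .....the ground truth to compute the discriminator loss for real images......
    
    loss_D_real = criterionGAN(pred_real, label_loss)

    ### Generator loss (Fake Passability Loss)        
    pred_fake = discriminator_model.forward(torch.cat((input_label, fake_image), dim=1))  
    
    # label_loss=  .....the ground truth to compute the generator loss, which teaches it to generate convincing fake images......
    
    loss_G_GAN = criterionGAN(pred_fake, label_loss)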
    

    ### Feature Matching Loss
    loss_G_GAN_Feat = 0
    feat_weights = 4.0 / (opt.n_layers_D + 1)
    D_weights = 1.0 / opt.num_D
    for i in range(opt.num_D):
        for j in range(len(pred_fake[i])-1):
            loss_G_GAN_Feat += D_weights * feat_weights * \
                Feat_Match_loss(pred_fake[i][j], pred_real[i][j].detach()) * opt.lambda_feat
            
    

    ### VGG Loss
    
    # exercise: which two inputs does the VGG loss compare? The L1 distance between their VGG
    # features should be small, so that VGG "sees" the fake and the real image in the same way
    loss_G_VGG = VGG_loss(fake_image, real_image) * opt.lambda_feat
    
    
    # set the flags based on whether the VGG and the Feature Matching losses are used
    flags = (True, not opt.no_ganFeat_loss, not opt.no_vgg_loss, True, True)
    
    # only return the fake image if necessary, to save bandwidth
    losses = (loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_D_real, loss_D_fake)
    losses_to_return = [[l for (l, f) in zip(losses, flags) if f], None if not infer else fake_image]
    

    return losses_to_return
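
To see how the weights in the feature matching loop of `forward` combine, the following sketch reproduces the double loop with plain Python lists instead of tensors. It assumes that each discriminator scale returns `n_layers_D + 1` intermediate feature maps followed by the final prediction; the option values and the toy features are hypothetical, and `l1` is a stand-in for the L1 criterion.

```python
def l1(a, b):
    # mean absolute difference between two equally long lists of numbers
    return sum(abs(x - y) for x, y in zip(a, b)) / len(a)

def feature_matching_loss(pred_fake, pred_real, n_layers_D, num_D, lambda_feat):
    feat_weights = 4.0 / (n_layers_D + 1)
    D_weights = 1.0 / num_D
    loss = 0.0
    for i in range(num_D):
        # the last entry of each scale is the final prediction, not a feature map
        for j in range(len(pred_fake[i]) - 1):
            loss += D_weights * feat_weights * l1(pred_fake[i][j], pred_real[i][j]) * lambda_feat
    return loss

# hypothetical settings and toy "feature maps" (flat lists instead of tensors)
n_layers_D, num_D, lambda_feat = 3, 2, 10.0
fake_scale = [[0.0, 0.0]] * (n_layers_D + 1) + [[0.3]]
real_scale = [[1.0, 1.0]] * (n_layers_D + 1) + [[0.9]]
pred_fake = [fake_scale, fake_scale]
pred_real = [real_scale, real_scale]

print(feature_matching_loss(pred_fake, pred_real, n_layers_D, num_D, lambda_feat))  # 40.0
```

Each of the 8 terms (2 scales, 4 feature maps each) contributes 0.5 * 1.0 * 1.0 * 10 = 5. The final predictions (0.3 and 0.9) are skipped, since they are already covered by the GAN loss.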

#### Inference Ability

As it can be seen the inference ability of the network,

The participants either run the inference themselves or just load a checkpoint.