### Notebook Teaching Objective

The aim of this notebook is to show participants how the ***GAN losses*** are computed and what the ***inputs*** of the models are, since these inputs implicitly affect the losses.

Compared to the other notebooks, this one introduces the VGG loss and the Feature Matching loss.

TO DO:

- code GAN_loss
- clean imports
- debug training loop
- Explain LSGAN loss 
- Explain VGG loss 
- Explain Feature matching loss part
- make sure information written is correct and flow is good


#### Imports

In [1]:
# necessary package imports
import argparse
import json
import os
import sys
import time
import warnings
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable

warnings.filterwarnings('ignore')

# make the pix2pixHD sources importable
sys.path.append('../..')
sys.path.append('../../src/pix2pixHD')
from src.pix2pixHD.models.pix2pixHD_model import Pix2PixHDModel
from src.pix2pixHD.util.image_pool import ImagePool
from src.pix2pixHD.models.base_model import BaseModel
from src.pix2pixHD.models import networks
import src.config.train_opt_notebook as opt

# make the GANcing, pix2pixHD and utils folders importable
dir_name = '../../src/GANcing/'
pix2pixhd_dir = os.path.join(dir_name, '../pix2pixHD/')
sys.path.append(pix2pixhd_dir)
sys.path.append(os.path.join(dir_name, '../..'))
sys.path.append(os.path.join(dir_name, '../utils'))

from data.data_loader import CreateDataLoader
from models.models import create_model
import util.util as util
# from util.visualizer import Visualizer
from torch_utils import get_torch_device

## Forward Pass Pix2Pix

This section shows how the ***GAN losses*** are computed and what the ***inputs*** of the models involved in training are.

The VGG and Feature Matching losses are introduced and their role is explained.
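As a reference for the cells below, the objectives assembled by the code can be summarized as follows. The notation is our own, not the project's: $s$ is the one-hot pose map, $x$ the real image, $G(s)$ the generated image, $D_k^{(j)}$ the $j$-th of the $J$ intermediate outputs of the $k$-th discriminator (every output except the final prediction), $F^{(i)}$ the output of the $i$-th of the five VGG19 slices, and $\lVert\cdot\rVert_1$ the mean absolute error computed by `nn.L1Loss`. $\mathcal{L}_{D,\text{fake}}$, $\mathcal{L}_{D,\text{real}}$ and $\mathcal{L}_{G,\text{GAN}}$ stand for `loss_D_fake`, `loss_D_real` and `loss_G_GAN`; the factor $\tfrac{1}{2}$ and the sum mirror how the training loop combines them.

$$
\begin{aligned}
\mathcal{L}_D &= \tfrac{1}{2}\left(\mathcal{L}_{D,\text{fake}} + \mathcal{L}_{D,\text{real}}\right) \\
\mathcal{L}_G &= \mathcal{L}_{G,\text{GAN}} + \mathcal{L}_{G,\text{Feat}} + \mathcal{L}_{G,\text{VGG}} \\
\mathcal{L}_{G,\text{Feat}} &= \lambda_{\text{feat}} \sum_{k=1}^{\text{num\_D}} \frac{1}{\text{num\_D}} \sum_{j=1}^{J} \frac{4}{\text{n\_layers\_D} + 1} \left\lVert D_k^{(j)}(s, G(s)) - D_k^{(j)}(s, x) \right\rVert_1 \\
\mathcal{L}_{G,\text{VGG}} &= \lambda_{\text{feat}} \sum_{i=1}^{5} w_i \left\lVert F^{(i)}(G(s)) - F^{(i)}(x) \right\rVert_1, \qquad w = \left(\tfrac{1}{32}, \tfrac{1}{16}, \tfrac{1}{8}, \tfrac{1}{4}, 1\right)
\end{aligned}
$$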

## Initialize Model Class

Before the model class is defined, the losses used to train the GAN are defined.

In [2]:
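def GAN_loss(prediction, target_is_real, use_lsgan=not opt.no_lsgan):
    # pix2pixHD's GANLoss compares the discriminator output with a target tensor filled
    # with 1.0 (target_is_real=True) or 0.0 (target_is_real=False):
    # mean squared error for LSGAN (use_lsgan=True), binary cross-entropy otherwise.
    # The target tensor type must match the device of the prediction.
    tensor = torch.cuda.FloatTensor if len(opt.gpu_ids) > 0 else torch.FloatTensor
    criterion = networks.GANLoss(use_lsgan=use_lsgan, tensor=tensor)
    return criterion(prediction, target_is_real)


# VGG19 feature extractor: created once and reused, since it loads pretrained ImageNet weights
_vgg = None


def VGG_loss(x, y):
    global _vgg
    if _vgg is None:
        _vgg = networks.Vgg19().to(x.device)

    # loss criterion
    criterion = nn.L1Loss()

    # The weights increase with depth: the shallowest VGG features get 1/32, the deepest 1.0.
    weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]

    # The features at five depths of the VGG are computed for the fake (x) and real (y) images
    x_vgg, y_vgg = _vgg(x), _vgg(y)

    loss = 0
    for i in range(len(x_vgg)):
        # the real-image features are treated as a fixed target
        loss += weights[i] * criterion(x_vgg[i], y_vgg[i].detach())
    return loss


def Feat_Match_loss(fake_features, real_features):
    # L1 distance between discriminator features of the fake and the real images
    criterion = torch.nn.L1Loss()
    return criterion(fake_features, real_features)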
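To make the two options of `networks.GANLoss` concrete, here is a NumPy sketch of what we understand it to compute. It is our own re-implementation, not the library code, and the function names are hypothetical. With LSGAN (`use_lsgan=True`, i.e. `opt.no_lsgan` is False) the discriminator output is compared with a target full of 1.0 (real) or 0.0 (fake) using the mean squared error; otherwise the discriminator ends with a sigmoid and binary cross-entropy is used. For the multi-scale discriminator, only the last output (the prediction) of each scale enters the loss, and the per-scale losses are summed.

```python
import numpy as np


def lsgan_loss(pred, target_is_real):
    """LSGAN: mean squared error to a target of 1.0 (real) or 0.0 (fake)."""
    pred = np.asarray(pred, dtype=float)
    target = 1.0 if target_is_real else 0.0
    return float(np.mean((pred - target) ** 2))


def bce_gan_loss(prob, target_is_real, eps=1e-12):
    """Vanilla GAN: binary cross-entropy on sigmoid outputs in (0, 1)."""
    prob = np.clip(np.asarray(prob, dtype=float), eps, 1.0 - eps)
    if target_is_real:
        return float(np.mean(-np.log(prob)))
    return float(np.mean(-np.log(1.0 - prob)))


def multiscale_lsgan_loss(preds_per_scale, target_is_real):
    """Sum over discriminator scales, using only the last output of each scale."""
    return sum(lsgan_loss(outputs[-1], target_is_real) for outputs in preds_per_scale)


pred = [0.9, 0.8, 1.0, 0.7]       # a discriminator fairly sure these patches are real
print(lsgan_loss(pred, True))      # ~0.035: close to the "real" target
print(lsgan_loss(pred, False))     # ~0.735: far from the "fake" target
```

The squared error keeps penalizing predictions that are on the correct side of the decision boundary but still far from the target; this is the motivation usually given for LSGAN, together with more stable gradients than the log loss.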

The model class is defined below.

First, the generator (``self.netG``) and the discriminator (``self.netD``) are created in the ``initialize`` function.

Among the config variables used to define the networks, the following are worth noting:

- ``netG_input_nc`` is the number of input channels of the generator (``opt.label_nc``, or ``opt.input_nc`` when ``label_nc`` is 0). The pose is fed as an image with one channel per class, where the classes correspond to the joints of the pose: ``encode_input`` turns the label map into a one-hot tensor with that many channels. In the run shown further down, the first convolution of the generator takes 18 input channels.
- ``netD_input_nc`` is the number of input channels of the discriminator: the pose channels plus the ``opt.output_nc`` channels of the image (3 for RGB), because the discriminator receives the pose and the real or generated image concatenated along the channel dimension.
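The channel arithmetic above can be checked with a small NumPy sketch of the one-hot encoding done in `encode_input`. It assumes a label map of shape `(batch, 1, H, W)` holding integer class ids; `np.put_along_axis` plays the role of the `scatter_` call on dimension 1, and `one_hot_pose` is a hypothetical helper name. The value 18 matches the generator printed in the training run; 3 stands for an RGB image.

```python
import numpy as np


def one_hot_pose(label_map, label_nc):
    """(batch, 1, H, W) integer class map -> (batch, label_nc, H, W) one-hot array."""
    label_map = np.asarray(label_map, dtype=np.int64)
    batch, _, height, width = label_map.shape
    one_hot = np.zeros((batch, label_nc, height, width), dtype=np.float32)
    # write 1.0 into the channel selected by each pixel's class id (like scatter_ on dim 1)
    np.put_along_axis(one_hot, label_map, 1.0, axis=1)
    return one_hot


label_nc, output_nc = 18, 3
rng = np.random.default_rng(0)
label_map = rng.integers(0, label_nc, size=(2, 1, 4, 4))

pose = one_hot_pose(label_map, label_nc)              # generator input
image = np.zeros((2, output_nc, 4, 4), np.float32)    # placeholder for a real or fake image
d_input = np.concatenate([pose, image], axis=1)       # discriminator input

print(pose.shape)     # (2, 18, 4, 4): netG_input_nc = label_nc
print(d_input.shape)  # (2, 21, 4, 4): netD_input_nc = label_nc + output_nc
```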

Second, the forward pass is defined in ``forward``:

The forward pass is the key step of training: it runs the generator and the discriminator and computes all the losses used to update them.

In [3]:
class Gancing(BaseModel):
    
    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        if opt.resize_or_crop != 'none' or not opt.isTrain: # when training at full res this causes OOM
            torch.backends.cudnn.benchmark = True
        self.isTrain = opt.isTrain

        ######check for gpu
        self.gpu_ids = opt.gpu_ids
        if len(self.gpu_ids) > 0:
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')
            
        
        input_nc = opt.label_nc if opt.label_nc != 0 else opt.input_nc

        # Generator network
        netG_input_nc = input_nc        
        self.netG = networks.define_G(netG_input_nc, opt.output_nc, opt.ngf, opt.netG, 
                                      opt.n_downsample_global, opt.n_blocks_global, opt.n_local_enhancers, 
                                      opt.n_blocks_local, opt.norm, gpu_ids=self.gpu_ids)        

        # Discriminator network
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            netD_input_nc = input_nc + opt.output_nc
            self.netD = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt.norm, use_sigmoid, 
                                          opt.num_D, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids)


        # load networks
        if not self.isTrain or opt.continue_train or opt.load_pretrain:
            pretrained_path = '' if not self.isTrain else opt.load_pretrain
            self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)            
            if self.isTrain:
                self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path)         

        # set loss functions and optimizers
        if self.isTrain:            
            
            # Names so we can breakout loss
            self.loss_names = ['G_GAN','G_GAN_Feat','G_VGG','D_real', 'D_fake']

            # initialize optimizers
            params = list(self.netG.parameters())
            self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))                            

            # optimizer D                        
            params = list(self.netD.parameters())    
            self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
            
    
    def encode_input(self, label_map, real_image, infer=False): 
    
        # create one-hot vector for the pose
        size = label_map.size()
        oneHot_size = (size[0], opt.label_nc, size[2], size[3])
        input_label = torch.FloatTensor(torch.Size(oneHot_size)).zero_().to(self.device)
        input_label = input_label.scatter_(1, label_map.data.long().to(self.device), 1.0)
        if opt.data_type == 16:
            input_label = input_label.half()

        input_label = Variable(input_label, volatile=infer)

        # real images for training
        if real_image is not None:
            real_image = Variable(real_image.data.to(self.device))

        return input_label, real_image

    def forward(self, label, image, infer=False):
        
        
        # Encode Inputs
        input_label, real_image = self.encode_input(label, image)  

        # Fake Generation
        fake_image = self.netG.forward(input_label.float())
        

        ### Detection

        # discriminator predictions (real or fake) for the real images
        pred_real = self.netD.forward(torch.cat((input_label, real_image), dim=1))

        # discriminator predictions for the fake images, with the fake image detached
        # so that the discriminator loss does not backpropagate into the generator
        pred_fake_D = self.netD.forward(torch.cat((input_label, fake_image.detach()), dim=1))

        # discriminator predictions for the fake images, keeping the graph to the generator
        # (used by the generator losses)
        pred_fake = self.netD.forward(torch.cat((input_label, fake_image), dim=1))

        ##################### Discriminator Loss Based on Fake Images #####################

        # Try to find out by yourself the ground truth (True or False) to compute the GAN loss for the discriminator on fake images
        # GAN_label_Discriminator_Fake = ...

        loss_D_fake = GAN_loss(pred_fake_D, GAN_label_Discriminator_Fake)

        ##################### Discriminator Loss Based on Real Images #####################

        # Try to find out by yourself the ground truth to compute the GAN loss for the discriminator on real images
        # GAN_label_Discriminator_Real = ...

        loss_D_real = GAN_loss(pred_real, GAN_label_Discriminator_Real)

        ##################### Generator Loss #####################

        # Try to find out by yourself the ground truth to compute the GAN loss for the generator
        # GAN_label_Generator = ...

        loss_G_GAN = GAN_loss(pred_fake, GAN_label_Generator)
                   
        
        ##################### Feature Matching Loss ##################### 
        
        loss_G_GAN_Feat = 0
        feat_weights = 4.0 / (opt.n_layers_D + 1)
        D_weights = 1.0 / opt.num_D
        for i in range(opt.num_D):
            for j in range(len(pred_fake[i])-1):
                loss_G_GAN_Feat += D_weights * feat_weights * \
                    Feat_Match_loss(pred_fake[i][j], pred_real[i][j].detach()) * opt.lambda_feat


        ##################### VGG Loss ##################### 
        loss_G_VGG = 0

        # Try to find out by yourself the two images whose VGG features are compared with the L1 distance (the generated image should have the same VGG features as the real one)
        # input1, input2 = ... , ...

        loss_G_VGG = VGG_loss(input1, input2) * opt.lambda_feat


        # only return the fake_B image if necessary to save BW
        losses_to_return=[[loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_D_real, loss_D_fake], None if not infer else fake_image]


        return losses_to_return

    def inference(self, label):
        # Encode Inputs (no real image is needed at inference time)
        input_label, _ = self.encode_input(Variable(label), None, infer=True)

        # Fake Generation
        fake_image = self.netG.forward(input_label.float())
        return fake_image


    def save(self, which_epoch):
        self.save_network(self.netG, 'G', which_epoch, self.gpu_ids)
        self.save_network(self.netD, 'D', which_epoch, self.gpu_ids)


## Training Loop and Inference

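The training loop loads the dataset, resumes from ``iter.json`` when a previous run exists, runs the forward pass, combines the losses (``loss_D`` is the mean of the real and fake discriminator losses, ``loss_G`` the sum of the generator losses), updates the generator and then the discriminator, and periodically prints the losses and saves the model.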
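The step bookkeeping at the top of `train_pose2vid` is easy to misread, so here is a pure-Python sketch of it together with the final loss combination. The helper names are hypothetical; the formulas are copied from the loop. The `*_delta` offsets make the display, print and save schedules count `display_freq` (etc.) steps from the point where training resumed, rather than from step 0.

```python
def resume_counters(start_epoch, epoch_iter, dataset_size, display_freq):
    """Step counter and display offset, as train_pose2vid computes them from iter.json."""
    total_steps = (start_epoch - 1) * dataset_size + epoch_iter
    display_delta = total_steps % display_freq
    return total_steps, display_delta


def display_steps(total_steps, display_delta, display_freq, batch_size, n_batches):
    """Step counts at which save_fake would be True over the next n_batches batches."""
    shown = []
    for _ in range(n_batches):
        total_steps += batch_size
        if total_steps % display_freq == display_delta:
            shown.append(total_steps)
    return shown


def combine_losses(loss_dict):
    """Mean of the two discriminator terms, sum of the available generator terms."""
    loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5
    loss_G = loss_dict['G_GAN'] + loss_dict.get('G_GAN_Feat', 0) + loss_dict.get('G_VGG', 0)
    return loss_D, loss_G


# fresh run: 100 images, batch size 1, show results every 10 steps
print(display_steps(*resume_counters(1, 0, 100, 10), 10, 1, 25))    # [10, 20]
# resumed in epoch 3 after 35 images: still every 10 steps, counted from step 235
print(display_steps(*resume_counters(3, 35, 100, 10), 10, 1, 25))   # [245, 255]
```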

#### Training loop

In [7]:
def train_pose2vid(target_dir, run_name, temporal_smoothing=False):
    import src.config.train_opt_notebook as opt

    opt = update_opt(opt, target_dir, run_name, temporal_smoothing)

    iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.json')
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#training images = %d' % dataset_size)

    if opt.load_pretrain != '':
        with open(iter_path, 'r') as f:
            iter_json = json.load(f)
    else:
        iter_json = {'start_epoch': 1, 'epoch_iter': 0}

    start_epoch = iter_json['start_epoch']
    epoch_iter = iter_json['epoch_iter']
    total_steps = (start_epoch - 1) * dataset_size + epoch_iter
    display_delta = total_steps % opt.display_freq
    print_delta = total_steps % opt.print_freq
    save_delta = total_steps % opt.save_latest_freq

    model = create_model(opt)
    model = model.to(device)
#     visualizer = Visualizer(opt)

    for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        if epoch != start_epoch:
            epoch_iter = epoch_iter % dataset_size
        for i, data in enumerate(dataset, start=epoch_iter):
            iter_start_time = time.time()
            total_steps += opt.batchSize
            epoch_iter += opt.batchSize

            # whether to collect output images
            save_fake = total_steps % opt.display_freq == display_delta

            ############## Forward Pass ######################
            if temporal_smoothing:
                losses, generated = model(Variable(data['label']), Variable(data['inst']),
                                          Variable(data['image']), Variable(data['feat']),
                                          Variable(data['previous_label']), Variable(data['previous_image']), infer=save_fake)
            else:
                losses, generated = model(Variable(data['label']), Variable(data['inst']),
                                        Variable(data['image']), Variable(data['feat']), infer=save_fake)

            # sum per device losses
            losses = [torch.mean(x) if not isinstance(x, int) else x for x in losses]
            loss_dict = dict(zip(model.loss_names, losses))

            # calculate final loss scalar
            loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5
            loss_G = loss_dict['G_GAN'] + loss_dict.get('G_GAN_Feat', 0) + loss_dict.get('G_VGG', 0)

            ############### Backward Pass ####################
            # update generator weights
            model.optimizer_G.zero_grad()
            loss_G.backward(retain_graph=True)
            model.optimizer_G.step()

            # update discriminator weights
            model.optimizer_D.zero_grad()
            loss_D.backward()
            model.optimizer_D.step()


            ############## Display results and errors ##########

            print(f"Epoch {epoch} batch {i}:")
            print(f"loss_D: {loss_D}, loss_G: {loss_G}")
            print(f"loss_D_fake: {loss_dict['D_fake']}, loss_D_real: {loss_dict['D_real']}")
            print(f"loss_G_GAN {loss_dict['G_GAN']}, loss_G_GAN_Feat: {loss_dict.get('G_GAN_Feat', 0)}, loss_G_VGG: {loss_dict.get('G_VGG', 0)}\n")

            ### print out errors
            if total_steps % opt.print_freq == print_delta:
                errors = {k: v.item() if not isinstance(v, int) else v for k, v in loss_dict.items()}
                # errors = {k: v.data[0] if not isinstance(v, int) else v for k, v in loss_dict.items()}
                t = (time.time() - iter_start_time) / opt.batchSize
#                 visualizer.print_current_errors(epoch, epoch_iter, errors, t)
#                 visualizer.plot_current_errors(errors, total_steps)

            ### display output images
            if save_fake:
                visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
                                       ('synthesized_image', util.tensor2im(generated.data[0])),
                                       ('real_image', util.tensor2im(data['image'][0]))])
#                 visualizer.display_current_results(visuals, epoch, total_steps)

            ### save latest model
            if total_steps % opt.save_latest_freq == save_delta:
                print('saving the latest model (epoch %d, total_steps %d)' % (epoch, total_steps))
                model.save('latest')
                iter_json['start_epoch'] = epoch
                iter_json['epoch_iter'] = epoch_iter
                with open(iter_path, 'w') as f:
                    json.dump(iter_json, f)

            if epoch_iter >= dataset_size:
                break

        # end of epoch
        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))

        ### save model for this epoch
        if epoch % opt.save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps))
            model.save('latest')
            model.save(epoch)
            iter_json['start_epoch'] = epoch + 1
            iter_json['epoch_iter'] = 0
            with open(iter_path, 'w') as f:
                json.dump(iter_json, f)

        ### instead of only training the local enhancer, train the entire network after certain iterations
        if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global):
            model.update_fixed_params()

        ### linearly decay learning rate after certain iterations
        if epoch > opt.niter:
            model.update_learning_rate()

In [8]:
def update_opt(opt, target_dir, run_name, temporal_smoothing):
    opt.dataroot = os.path.join(target_dir, 'train')
    opt.name = run_name
    if os.path.isdir(os.path.join(dir_name, "../../checkpoints", run_name)):
        print("Run already exists, will try to resume training")
        opt.load_pretrain = os.path.join(dir_name, "../../checkpoints", run_name)

    if device == torch.device('cpu'):
        opt.gpu_ids = []

    opt.temporal_smoothing = temporal_smoothing

    return opt

In [None]:
device = get_torch_device()
train_pose2vid('../../data/targets/gianluca', 'trial_workshop', temporal_smoothing=False)

CustomDatasetDataLoader
dataset [AlignedDataset] was created
#training images = 100
GlobalGenerator(
  (model): Sequential(
    (0): ReflectionPad2d((3, 3, 3, 3))
    (1): Conv2d(18, 64, kernel_size=(7, 7), stride=(1, 1))
    (2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (3): ReLU(inplace)
    (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (5): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (6): ReLU(inplace)
    (7): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (8): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (9): ReLU(inplace)
    (10): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (11): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (12): ReLU(inplace)
    (13): Conv2d(512, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1,

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /Users/gianlucamancini/.torch/models/vgg19-dcbb9e9d.pth
100%|██████████| 574673361/574673361 [02:19<00:00, 4121658.56it/s]


#### Inference

## Draft

In [7]:
class GancingModel(Pix2PixHDModel):
    
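    def __init__(self, opt):
        # set up torch.nn.Module first, then let Pix2PixHDModel.initialize define the options,
        # the loss functions (e.g. self.criterionGAN) and, when training, the image pool (self.fake_pool)
        super().__init__()
        Pix2PixHDModel.initialize(self, opt)
        self.device = torch.device('cuda' if len(opt.gpu_ids) > 0 else 'cpu')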
    
        netG_input_nc = opt.label_nc
        self.generator_model=networks.define_G(netG_input_nc, opt.output_nc, opt.ngf, opt.netG, 
                                              opt.n_downsample_global, opt.n_blocks_global, opt.n_local_enhancers, 
                                              opt.n_blocks_local, opt.norm, gpu_ids=opt.gpu_ids)   

        use_sigmoid = opt.no_lsgan
        netD_input_nc = netG_input_nc + opt.output_nc
        self.discriminator_model=networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt.norm, use_sigmoid, 
                                                  opt.num_D, not opt.no_ganFeat_loss, gpu_ids=opt.gpu_ids)
    
    def encode_input(self, label_map, real_image, infer=False): 
    
        # create one-hot vector for the pose
        size = label_map.size()
        oneHot_size = (size[0], opt.label_nc, size[2], size[3])
        input_label = torch.FloatTensor(torch.Size(oneHot_size)).zero_().to(self.device)
        input_label = input_label.scatter_(1, label_map.data.long().to(self.device), 1.0)
        if opt.data_type == 16:
            input_label = input_label.half()

        input_label = Variable(input_label, volatile=infer)

        # real images for training
        if real_image is not None:
            real_image = Variable(real_image.data.to(self.device))


        return input_label, real_image
        
    def forward(self, label, image, feat, infer=False):
        
        input_label, real_image = self.encode_input(label, image)

        input_concat = input_label

        ### Fake Generation
        fake_image = self.generator_model.forward(input_concat.float())

        ### Fake Detection

        # concatenate the condition input_label used to generate the image and the (detached) fake image
        input_concat = torch.cat((input_label, fake_image.detach()), dim=1)

        # query the image pool created by Pix2PixHDModel.initialize: it returns a mix of current
        # and previously generated fakes (or the input unchanged when opt.pool_size is 0)
        fake_query = self.fake_pool.query(input_concat)

        # predict whether the pooled fake images are real or fake
        pred_fake_pool = self.discriminator_model.forward(fake_query)

        # predict whether the real images are real or fake
        pred_real = self.discriminator_model.forward(torch.cat((input_label, real_image), dim=1))

        # predict whether the fake images are real or fake
        pred_fake = self.discriminator_model.forward(torch.cat((input_label, fake_image), dim=1)) 

        ##################### Discriminator Loss Based on Fake Images #####################

        # Try to find out by yourself the ground truth to compute the GAN loss for the discriminator for fake images
        # GAN_label_Discriminator = ...

        loss_D_fake = self.criterionGAN(pred_fake_pool, GAN_label_Discriminator)      

        ##################### Discriminator Loss Based on Real Images #####################
        
        # Try to find out by yourself the ground truth to compute the GAN loss for the discriminator for real images
        # GAN_label_Discriminator =  ... 

        loss_D_real = self.criterionGAN(pred_real, GAN_label_Discriminator)

        ##################### Generator Loss #####################
        
        # Try to find out by yourself the ground truth to compute the GAN loss for the generator
        # label_loss=  ...

        loss_G_GAN = self.criterionGAN(pred_fake, label_loss)


        ##################### Feature Matching Loss ##################### 
        
        loss_G_GAN_Feat = 0
        feat_weights = 4.0 / (opt.n_layers_D + 1)
        D_weights = 1.0 / opt.num_D
        for i in range(opt.num_D):
            for j in range(len(pred_fake[i])-1):
                loss_G_GAN_Feat += D_weights * feat_weights * \
                    Feat_Match_loss(pred_fake[i][j], pred_real[i][j].detach()) * opt.lambda_feat



        ##################### VGG Loss ##################### 
        loss_G_VGG = 0

        # Try to find out by yourself the two images whose VGG features are compared with the L1 distance (the generated image should have the same VGG features as the real one)
        # input1, input2 = ... , ...

        loss_G_VGG = VGG_loss(input1, input2) * opt.lambda_feat


        # set the flags based on if the VGG and the Feature Matching Losses are used
        flags = (True, not opt.no_ganFeat_loss, not opt.no_vgg_loss, True, True)

        # only return the fake_B image if necessary to save BW
        losses_to_return=[[l for (l,f) in zip((loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_D_real, loss_D_fake),flags) if f], None if not infer else fake_image]


        return losses_to_return
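The `fake_query` step above relies on pix2pixHD's `ImagePool`, which keeps a history of generated images so that the discriminator is not trained only on the latest fakes. Below is a pure-Python sketch of the query logic as we read it in `src/pix2pixHD/util/image_pool.py`; `HistoryPool` is a hypothetical name and it stores plain Python objects instead of tensors. Note that the pool only has an effect if the same instance is reused across iterations: a freshly created pool just fills up and returns the current fakes.

```python
import random


class HistoryPool:
    """Hypothetical stand-in for ImagePool, storing plain Python objects."""

    def __init__(self, pool_size, rng=None):
        self.pool_size = pool_size
        self.images = []
        self.rng = rng if rng is not None else random.Random()

    def query(self, batch):
        if self.pool_size == 0:
            return list(batch)   # pool disabled: pass the batch through unchanged
        out = []
        for image in batch:
            if len(self.images) < self.pool_size:
                # pool not full yet: store the new image and return it
                self.images.append(image)
                out.append(image)
            elif self.rng.random() > 0.5:
                # return an older image and replace it with the new one
                idx = self.rng.randrange(self.pool_size)
                out.append(self.images[idx])
                self.images[idx] = image
            else:
                # return the new image, leaving the pool unchanged
                out.append(image)
        return out
```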

##### Analysis of original function

The forward pass computes the different losses used for training. Its most important components are the following:

Networks to be trained to generate and discriminate
1. ``self.netG``: the generator network
2. ``self.netD``: the discriminator network

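Discriminator losses
1. ``loss_D_fake``, the discriminator loss on fake generated images: -log(1-D(G(z))) with the vanilla GAN loss, or the mean of D(G(z))^2 with LSGAN
2. ``loss_D_real``, the discriminator loss on real images sampled from the data: -log(D(x)), or the mean of (D(x)-1)^2 with LSGAN

Generator losses
1. ``loss_G_GAN``, the generator's adversarial loss ("Fake Passability Loss" in the code), which labels the fake images as real: -log(D(G(z))), or the mean of (D(G(z))-1)^2 with LSGAN
2. ``loss_G_GAN_Feat``, the feature matching loss: the L1 distance between the discriminator's intermediate features for the fake and the real images
3. ``loss_G_VGG``, the VGG loss: the L1 distance between the VGG19 features of the fake and the real images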
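Besides the adversarial terms, the generator is trained with two weighted sums of L1 distances between feature maps. The NumPy sketch below reproduces the weighting used in `forward` for the feature matching loss and in `VGG_loss`; the function names are hypothetical. It assumes each discriminator scale returns a list of intermediate feature maps followed by the final prediction, which is why the last element is skipped. The example uses `n_layers_D = 3`, two scales and `lambda_feat = 10.0` as illustrative values.

```python
import numpy as np


def l1(a, b):
    """Mean absolute error, like torch.nn.L1Loss with its default 'mean' reduction."""
    return float(np.mean(np.abs(np.asarray(a) - np.asarray(b))))


def feature_matching_loss(pred_fake, pred_real, n_layers_D, lambda_feat):
    """pred_fake / pred_real: one list per discriminator scale; the last item is the prediction."""
    num_D = len(pred_fake)
    feat_weights = 4.0 / (n_layers_D + 1)
    D_weights = 1.0 / num_D
    loss = 0.0
    for i in range(num_D):
        for j in range(len(pred_fake[i]) - 1):   # skip the final prediction
            loss += D_weights * feat_weights * l1(pred_fake[i][j], pred_real[i][j]) * lambda_feat
    return loss


def vgg_style_loss(fake_feats, real_feats, weights=(1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0)):
    """Weighted sum of L1 distances between the feature maps of five VGG19 slices."""
    return sum(w * l1(f, r) for w, f, r in zip(weights, fake_feats, real_feats))


# two discriminator scales, each with 4 intermediate maps + 1 final prediction
fake = [[np.ones((1, 8, 4, 4))] * 4 + [np.full((1, 1, 4, 4), 100.0)] for _ in range(2)]
real = [[np.zeros((1, 8, 4, 4))] * 4 + [np.zeros((1, 1, 4, 4))] for _ in range(2)]
print(feature_matching_loss(fake, real, n_layers_D=3, lambda_feat=10.0))  # 40.0: predictions ignored
```

Both losses push the generated image to produce the same internal representations as the real one, in the discriminator for feature matching and in a pretrained VGG19 for the VGG loss.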


***Original Function***

In [None]:
def forward(self, label, inst, image, feat, infer=False):
    # Encode Inputs
    input_label, inst_map, real_image, feat_map = self.encode_input(label, inst, image, feat)  

    # Fake Generation
    if self.use_features:
        if not self.opt.load_features:
            feat_map = self.netE.forward(real_image, inst_map)                     
        input_concat = torch.cat((input_label, feat_map), dim=1)                        
    else:
        input_concat = input_label
    # TODO----------------------#    
    fake_image = self.netG.forward(input_concat.float())

    # Fake Detection and Loss
    pred_fake_pool = self.discriminate(input_label, fake_image, use_pool=True)
    loss_D_fake = self.criterionGAN(pred_fake_pool, False)        

    # Real Detection and Loss        
    pred_real = self.discriminate(input_label, real_image)
    loss_D_real = self.criterionGAN(pred_real, True)

    # GAN loss (Fake Passability Loss)        
    pred_fake = self.netD.forward(torch.cat((input_label, fake_image), dim=1))        
    loss_G_GAN = self.criterionGAN(pred_fake, True)               

    # GAN feature matching loss
    loss_G_GAN_Feat = 0
    if not self.opt.no_ganFeat_loss:
        feat_weights = 4.0 / (self.opt.n_layers_D + 1)
        D_weights = 1.0 / self.opt.num_D
        for i in range(self.opt.num_D):
            for j in range(len(pred_fake[i])-1):
                loss_G_GAN_Feat += D_weights * feat_weights * \
                    self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach()) * self.opt.lambda_feat

    # VGG feature matching loss
    loss_G_VGG = 0
    if not self.opt.no_vgg_loss:
        loss_G_VGG = self.criterionVGG(fake_image, real_image) * self.opt.lambda_feat

    # Only return the fake_B image if necessary to save BW
    return [ self.loss_filter( loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_D_real, loss_D_fake ), None if not infer else fake_image ]

To make the notebook easier to follow, the functions have been simplified and commented.

``forward`` function modifications:

- ``netE``: not used

- ``netG`` and ``netD`` are defined beforehand in ``initialize`` and called directly in ``forward`` (the ``self.discriminate`` helper of the original is not used)

- ``self.encode_input`` is kept as a separate method to show what goes into the model and what has to be encoded before the forward pass. The encoding itself is of secondary importance, but a standalone function is needed to run an example forward pass.

The forward pass is the most interesting part: it can be shown on its own, detached from the class, and then highlighted in the training loop where the model calls it.



*Ideas to clean up the function further*:

- ``criterionGAN`` becomes a function

- make the training of the GAN with the feature matching loss (based on the discriminator's features of the generated and real images) conditional on a flag

- make the training of the GAN with the VGG loss (VGG features of the images) conditional on a flag
