# Temporal Smoothing GAN

In this notebook, we will implement the same model as before, with the addition of a temporal smoothing component.

In [None]:
!nvidia-smi

!pip install googledrivedownloader

!git clone https://github.com/VisiumCH/AMLD2020-Dirty-Gancing.git dancing

!rm -rf dancing/data/targets/gianluca

from google_drive_downloader import GoogleDriveDownloader as gdd

gdd.download_file_from_google_drive(file_id='1NnV2OCYV3EchLSDdMTfI3KwGZkwENqvv',
                                    dest_path='./gianluca.zip',
                                    unzip=True)

!mv gianluca_1000 dancing/data/targets/gianluca

gdd.download_file_from_google_drive(file_id='1KvLXCvse8hUATR8ZbZgYGwHVQOCcqMdS',
                                    dest_path='./gianluca_pretrained_ts.zip',
                                    unzip=True)

!mv gianluca_pretrained_ts dancing/checkpoints/gianluca_ts

gdd.download_file_from_google_drive(file_id='19jq_fuGfPjbEI67dpo01e5ka2JJW8uTO',
                                    dest_path='./gianluca_ts.mp4',
                                    unzip=False)

!mv gianluca_ts.mp4 dancing/results

!rm -rf dancing/checkpoints/testnotebook_ts

%cd dancing/notebooks/temporal_smoothing
%pwd

In [None]:
import os
import sys
import time
import torch
import torch.nn as nn
import warnings
import json
import matplotlib.pyplot as plt

from collections import OrderedDict
from IPython.display import HTML
from base64 import b64encode

sys.path.append('../..')
sys.path.append('../../src/pix2pixHD/')

import src.config.train_opt_notebook as opt
import src.pix2pixHD.util.util as util

from src.utils.torch_utils import get_torch_device
from src.utils.plt_utils import init_figure, plot_current_results
from src.pix2pixHD.data.data_loader import CreateDataLoader
from src.pix2pixHD.models import networks
from src.pix2pixHD.models.base_model import BaseModel
from src.pix2pixHD.models.models import create_model
from src.pix2pixHD.util.image_pool import ImagePool

warnings.filterwarnings('ignore')

device = get_torch_device()

## How to make people Dance (a bit better)

As you may have noticed in the previous notebook, there are a few problems with the model that was implemented. For example:

- The background flickers from time to time: This occurs because of the objects in the pictures. Our subject could not provide training images on top of these objects, so the model does not know what to do when we ask it to draw a person at this position.
- Missed detections by the Pose Estimation algorithm cause missing limbs in the generated frames: This can be improved if we use a better (but slower) Pose Estimation model.
- The face of the generated subject doesn't look very natural: This is addressed by [Chan et al.](https://arxiv.org/pdf/1808.07371.pdf) with an additional Face Generator. It is included in this repo for those who are curious to try.
- There are artifacts on the clothes and a lack of temporal coherence: This is the problem we are going to address now!

The lack of temporal coherence occurs mostly because all the frames are generated independently. The model has no notion of temporality. It creates frames, not videos.
The model of [Chan et al.](https://arxiv.org/pdf/1808.07371.pdf) includes a temporal smoothing component, which we will now detail.

![GAN Architecture](imgs/gan_architecture.png)

The core of the model is the same as before. The main difference is that we want the **Generator** to create a frame that is **temporally** coherent with the previous frame. To do so, we have to slightly change the inputs of the two networks:

- The **Generator** will receive as input the **target pose** and the **previously generated frame**. Moreover, we will always generate **two** frames so as to create a sequence!
- The **Discriminator** will receive a sequence of **two images** and **two poses**
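
The two-step generation described above can be sketched with toy tensors. This is only an illustration under stated assumptions: `toy_generator` is a single convolution standing in for the real network, and the sizes (`label_nc = 6`, 64x32 frames) are made up. The pose and the previous frame are joined along the channel dimension, and the very first step starts from a blank frame.

```python
import torch
import torch.nn as nn

# Hypothetical toy sizes, chosen only for illustration.
batch, label_nc, height, width = 1, 6, 64, 32

# Stand-in generator (assumption): maps pose channels + 3 RGB channels to an RGB frame.
toy_generator = nn.Conv2d(label_nc + 3, 3, kernel_size=3, padding=1)

previous_pose = torch.rand(batch, label_nc, height, width)
current_pose = torch.rand(batch, label_nc, height, width)

# No frame exists before the first one, so start from a blank image.
blank_frame = torch.zeros(batch, 3, height, width)

# Step 1: generate the "previous" frame from its pose and the blank frame.
previous_fake = toy_generator(torch.cat((previous_pose, blank_frame), dim=1))

# Step 2: generate the current frame, conditioned on the frame from step 1.
current_fake = toy_generator(torch.cat((current_pose, previous_fake), dim=1))

print(previous_fake.shape, current_fake.shape)
```

Both outputs have the shape of a single RGB frame, so the second frame can in turn be fed back as the "previous frame" of the next step.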

In order to force the discriminator to learn to use the temporal aspect, we have to concatenate the images and the poses along the width or height, but not the channels!
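
As a rough sketch of that concatenation (toy shapes only; `label_nc = 6` and the 64x32 resolution are assumptions for illustration), the two time steps can be placed side by side along the width, and each pose is then paired with its image along the channels:

```python
import torch

# Hypothetical toy sizes, chosen only for illustration.
batch, label_nc, height, width = 1, 6, 64, 32

previous_pose = torch.rand(batch, label_nc, height, width)
current_pose = torch.rand(batch, label_nc, height, width)
previous_image = torch.rand(batch, 3, height, width)
current_image = torch.rand(batch, 3, height, width)

# Place the two time steps side by side along the width (dim=3).
poses = torch.cat((previous_pose, current_pose), dim=3)
images = torch.cat((previous_image, current_image), dim=3)

# Pair poses with images along the channels (dim=1).
discriminator_input = torch.cat((poses, images), dim=1)

print(discriminator_input.shape)  # torch.Size([1, 9, 64, 64])
```

With this layout the number of input channels is still the pose channels plus the image channels; only the width doubles.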

Feel ready? First, let's define our loss functions again!

In [None]:
def GAN_Loss(x,
            target_is_real):
    """ Adversarial loss function for LS-GAN
    
    Args:
        x (iterable of iterables of torch.Tensors): discriminator activations
        target_is_real (bool): True to score x against the real label (1.0), False for the fake label (0.0)
    
    Returns:
        torch.Tensor: adversarial loss value for x
    """
    
    criterion = nn.MSELoss()
    target_value = 1.0 if target_is_real else 0.0

    loss = 0
    for activations in x:
        pred = activations[-1].to(device)
        target_tensor = torch.full_like(pred, target_value)
        loss += criterion(pred, target_tensor)
    return loss

In [None]:
def Feat_Match_loss(pred_fake,
                    pred_real,
                    opt):
    """ Feature Matching loss for Pix2Pix
    
    Args:
        pred_fake (iterable of iterables of torch.Tensors): discriminator activations for fake samples
        pred_real (iterable of iterables of torch.Tensors): discriminator activations for real samples
        opt: options namespace providing n_layers_D (number of layers in the discriminator),
            num_D (number of multi-scale discriminators) and lambda_feat (scaling factor)
    
    Returns:
        torch.Tensor: feature matching loss value for pred_fake and pred_real
    """

    criterion = nn.L1Loss()
    
    D_weights = 1 / opt.num_D
    feat_weights = 4 / (opt.n_layers_D + 1)
    factor = D_weights * feat_weights * opt.lambda_feat
    
    loss = 0
    for i in range(opt.num_D):
        for j in range(len(pred_fake[i]) - 1):
            loss += factor * criterion(pred_fake[i][j], pred_real[i][j].detach())

    return loss


In [None]:
class VGGLoss(nn.Module):
    def __init__(self, gpu_ids):
        super(VGGLoss, self).__init__()
        if len(gpu_ids) > 0:
            self.vgg = networks.Vgg19().cuda()
        else:
            self.vgg = networks.Vgg19()
        self.criterion = nn.L1Loss()
        self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]

    def forward(self, x, y):
        """ Perceptual loss for Pix2Pix
        Args:
            x, y (torch.Tensor): samples to compare

        Returns:
            torch.Tensor: perceptual loss value for x, y
        """ 
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)
        loss = 0
        for i in range(len(x_vgg)):
            loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
        return loss


We will now implement the **forward pass** of the model with the temporal smoothing.

Keep in mind that:

- The Generator needs to create a sequence of two images (hint: For the very first generation, you can give a blank image with torch.zeros(...))
- The Discriminator needs to receive two images and two poses (hint: Try to first concatenate the two poses, then the two images, and then the two resulting tensors together)

In [None]:
class SmoothGancing(BaseModel):
    
    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        if opt.resize_or_crop != 'none' or not opt.isTrain: # when training at full res this causes OOM
            torch.backends.cudnn.benchmark = True
        self.isTrain = opt.isTrain

        ######check for gpu
        self.gpu_ids = opt.gpu_ids
        if len(self.gpu_ids) > 0:
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')        
        input_nc = opt.label_nc if opt.label_nc != 0 else opt.input_nc

        # Generator network
        netG_input_nc = input_nc + 3       
        self.netG = networks.define_G(netG_input_nc, opt.output_nc, opt.ngf, opt.netG, 
                                      opt.n_downsample_global, opt.n_blocks_global, opt.n_local_enhancers, 
                                      opt.n_blocks_local, opt.norm, gpu_ids=self.gpu_ids)        

        # Discriminator network
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            netD_input_nc = input_nc + opt.output_nc
            self.netD = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt.norm, use_sigmoid, 
                                          opt.num_D, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids)


        # load networks
        if not self.isTrain or opt.continue_train or opt.load_pretrain:
            pretrained_path = '' if not self.isTrain else opt.load_pretrain
            self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)            
            if self.isTrain:
                self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path)         

        # set loss functions and optimizers
        if self.isTrain:            
            
            # Names so we can breakout loss
            self.loss_names = ['G_GAN','G_GAN_Feat','G_VGG','D_real', 'D_fake']
            
            # Define VGGloss so we don't load VGG at each batch
            self.VGG_Loss = VGGLoss(opt.gpu_ids)
            
            # initialize optimizers
            params = list(self.netG.parameters())
            self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))                            

            # optimizer D                        
            params = list(self.netD.parameters())    
            self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
            
    
    def encode_input(self, label_map, real_image): 
    
        # create one-hot vector for the pose
        size = label_map.size()
        oneHot_size = (size[0], opt.label_nc, size[2], size[3])
        input_label = torch.FloatTensor(torch.Size(oneHot_size)).zero_().to(self.device)
        input_label = input_label.scatter_(1, label_map.data.long().to(self.device), 1.0).to(self.device)
        if opt.data_type == 16:
            input_label = input_label.half()

        # real images for training
        if real_image is not None:
            real_image = real_image.to(self.device)

        return input_label, real_image

    def forward(self, label, image, previous_label=None, previous_image=None, infer=False):
        
        
        # Encode Inputs
        input_label, real_image = self.encode_input(label, image)  
        
        if previous_label is not None and previous_image is not None:
            previous_input_label, previous_real_image = self.encode_input(previous_label, previous_image)

        # Fake Generation
        previous_fake_image = ... # Implement me 
        
        fake_image = ... # Implement me

        ### Detection

        # predict whether the real images are real or fake
        real_inputD = ... # Implement me
        
        pred_real = self.netD.forward(real_inputD)
        
        # predict whether the fake images are real or fake
        fake_inputD_D = ... # Implement me
        
        pred_fake_D = self.netD.forward(fake_inputD_D)
        
        # get the prediction on the fake images for backprop in the Generator
        fake_inputD_G = ... # Implement me
        
        pred_fake_G = self.netD.forward(fake_inputD_G)

        ##################### Discriminator Loss Based on Fake Images #####################

        GAN_label_Discriminator_Fake = False
        
        loss_D_fake = GAN_Loss(pred_fake_D, GAN_label_Discriminator_Fake)  

        ##################### Discriminator Loss Based on Real Images #####################
        
        GAN_label_Discriminator_Real =  True

        loss_D_real = GAN_Loss(pred_real, GAN_label_Discriminator_Real)

        ##################### Generator Loss #####################       
        
        GAN_label_Generator=  True

        loss_G_GAN = GAN_Loss(pred_fake_G, GAN_label_Generator)  
                   
        ##################### Feature Matching Loss ##################### 
        
        loss_G_GAN_Feat = Feat_Match_loss(pred_fake_G, pred_real, opt)

        ##################### VGG Loss ##################### 
        loss_G_VGG = 0

        # Try to find out by yourself the inputs needed to calculate the VGG loss using the L1 function to ensure VGG classifies fake and real images equally
        input1, input2 = ... , ...
        input3, input4 = ... , ...


        loss_G_VGG = (self.VGG_Loss(input1, input2) + self.VGG_Loss(input3, input4)) * opt.lambda_feat


        # only return the fake_B image if necessary to save BW
        losses_to_return=[[loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_D_real, loss_D_fake], None if not infer else fake_image]


        return losses_to_return

    def inference(self, label, previous_image):
        # Encode Inputs
        input_label, previous_image_input = self.encode_input(label, previous_image) 

        # Fake Generation
        fake_image = self.netG.forward(torch.cat((input_label, previous_image_input), dim=1))
        return fake_image


    def save(self, which_epoch):
        self.save_network(self.netG, 'G', which_epoch, self.gpu_ids)
        self.save_network(self.netD, 'D', which_epoch, self.gpu_ids)


In [None]:
dir_name = os.path.abspath('')

def update_opt(opt, target_dir, run_name):
    opt.dataroot = os.path.join(target_dir, 'train')
    opt.checkpoints_dir = os.path.join(dir_name, "../../checkpoints")
    opt.name = run_name
    if os.path.isdir(os.path.join(opt.checkpoints_dir, run_name)):
        print("Run already exists, will try to resume training")
        opt.load_pretrain = os.path.join(dir_name, "../../checkpoints", run_name)

    if device == torch.device('cpu'):
        opt.gpu_ids = []
    else:
        os.environ['CUDA_VISIBLE_DEVICES'] = "0"

    opt.temporal_smoothing = True

    return opt
    
def train_pose2vid_temporal_smoothing(target_dir, run_name):
    import src.config.train_opt_notebook as opt

    opt = update_opt(opt, target_dir, run_name)
    
    os.makedirs(os.path.join(opt.checkpoints_dir, run_name), exist_ok=True)
    iter_path = os.path.join(opt.checkpoints_dir, run_name, 'iter.json')
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#training images = %d' % dataset_size)

    if opt.load_pretrain != '':
        with open(iter_path, 'r') as f:
            iter_json = json.load(f)
    else:
        iter_json = {'start_epoch': 1, 'epoch_iter': 0}

    start_epoch = iter_json['start_epoch']
    epoch_iter = iter_json['epoch_iter']
    total_steps = (start_epoch - 1) * dataset_size + epoch_iter
    display_delta = total_steps % opt.display_freq
    print_delta = total_steps % opt.print_freq
    save_delta = total_steps % opt.save_latest_freq

    model = SmoothGancing()
    model.initialize(opt)
    model = model.to(device)

    for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        if epoch != start_epoch:
            epoch_iter = epoch_iter % dataset_size
        for i, data in enumerate(dataset, start=epoch_iter):
            iter_start_time = time.time()
            total_steps += opt.batchSize
            epoch_iter += opt.batchSize

            ############## Forward Pass ######################
            losses, generated = model(data['label'], data['image'],
                                      data['previous_label'], data['previous_image'], infer=True)


            # sum per device losses
            losses = [torch.mean(x) if not isinstance(x, int) else x for x in losses]
            loss_dict = dict(zip(model.loss_names, losses))

            # calculate final loss scalar
            loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5
            loss_G = loss_dict['G_GAN'] + loss_dict.get('G_GAN_Feat', 0) + loss_dict.get('G_VGG', 0)

            ############### Backward Pass ####################
            # update generator weights
            model.optimizer_G.zero_grad()
            loss_G.backward()
            model.optimizer_G.step()

            # update discriminator weights
            model.optimizer_D.zero_grad()
            loss_D.backward()
            model.optimizer_D.step()


            ############## Display results and errors ##########

            print(f"Epoch {epoch} batch {i}:")
            print(f"loss_D: {loss_D}, loss_G: {loss_G}")
            print(f"loss_D_fake: {loss_dict['D_fake']}, loss_D_real: {loss_dict['D_real']}")
            print(f"loss_G_GAN {loss_dict['G_GAN']}, loss_G_GAN_Feat: {loss_dict.get('G_GAN_Feat', 0)}, loss_G_VGG: {loss_dict.get('G_VGG', 0)}\n")

            ### print out errors
            if total_steps % opt.print_freq == print_delta:
                errors = {k: v.item() if not isinstance(v, int) else v for k, v in loss_dict.items()}
                t = (time.time() - iter_start_time) / opt.batchSize


            ### display output images
            visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
                                    ('synthesized_image', util.tensor2im(generated.data[0])),
                                    ('real_image', util.tensor2im(data['image'][0]))])
            fig, axes = init_figure()
            plot_current_results(visuals, fig, axes)

            ### save latest model
            if total_steps % opt.save_latest_freq == save_delta:
                print('saving the latest model (epoch %d, total_steps %d)' % (epoch, total_steps))
                model.save('latest')
                iter_json['start_epoch'] = epoch
                iter_json['epoch_iter'] = epoch_iter
                with open(iter_path, 'w') as f:
                    json.dump(iter_json, f)

            if epoch_iter >= dataset_size:
                break

        # end of epoch
        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))

        ### save model for this epoch
        if epoch % opt.save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps))
            model.save('latest')
            model.save(epoch)
            iter_json['start_epoch'] = epoch + 1
            iter_json['epoch_iter'] = 0
            with open(iter_path, 'w') as f:
                json.dump(iter_json, f)


In [None]:
target_dir = '../../data/targets/gianluca/'
run_name = 'testnotebook_ts'

train_pose2vid_temporal_smoothing(target_dir, run_name)

In [None]:
target_dir = '../../data/targets/gianluca/'
run_name = 'gianluca_ts'

train_pose2vid_temporal_smoothing(target_dir, run_name)

In [None]:
with open('../../results/gianluca_ts.mp4', 'rb') as f:
    mp4 = f.read()
data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
HTML("""
<video controls>
      <source src="%s" type="video/mp4">
</video>
""" % data_url)