# Deepfakes with GANs
> Re-enactment using Pix2Pix

<img src="deepfake_banner.png">

We covered image-to-image translation GAN architectures in Chapter 5. In particular, we discussed in detail how **pix2pix GAN** is a powerful architecture for paired translation tasks. In this notebook, we will leverage pix2pix GAN to develop a face re-enactment setup from scratch. We will:
+ build a pix2pix network
+ prepare the dataset using a video
+ train the model for re-enactment using facial landmarks

The actual re-enactment part is covered in the second notebook for this chapter.

## Load Libraries

In [None]:
import os
import cv2
import dlib
import numpy as np
import torch
from torch.autograd import Variable
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import torchvision.transforms as transforms
from PIL import Image

In [None]:
from gan_utils import PATCH_GAN_SHAPE
from gan_utils import Generator, Discriminator
from gan_utils import (IMG_WIDTH,
                        IMG_HEIGHT,
                        NUM_CHANNELS,
                        BATCH_SIZE,
                        N_EPOCHS,
                        SAMPLE_INTERVAL)

In [None]:
from dataset_utils import ImageDataset, prepare_data
from dataset_utils import DATASET_PATH, DOWNSAMPLE_RATIO

## Set Parameters

In [None]:
CUDA = torch.cuda.is_available()
os.makedirs("saved_models/", exist_ok=True)

In [None]:
# download and extract the dlib 68-point landmarks model only if it is not already available
!test -f shape_predictor_68_face_landmarks.dat || (wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2 && bunzip2 shape_predictor_68_face_landmarks.dat.bz2)

In [None]:
# instantiate objects for face and landmark detection
detector = dlib.get_frontal_face_detector()
predictor = dlib.shape_predictor('shape_predictor_68_face_landmarks.dat')
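
`dataset_utils.prepare_data` is not shown in this notebook, so exactly how it turns the predictor output into a conditioning image is unknown. As a minimal sketch under that caveat, the hypothetical helper `landmarks_to_image` below renders a `(68, 2)` array of `(x, y)` landmark coordinates as white dots on a black canvas, and `FACIAL_REGIONS` lists the index ranges of the standard 68-point layout that `shape_predictor_68_face_landmarks.dat` follows. Drawing dots rather than connected contours is an assumption made for brevity.

```python
import numpy as np

# Index ranges (end exclusive) of the 68-point iBUG 300-W layout
# predicted by dlib's shape_predictor_68_face_landmarks model.
FACIAL_REGIONS = {
    "jaw": (0, 17),
    "right_eyebrow": (17, 22),
    "left_eyebrow": (22, 27),
    "nose": (27, 36),
    "right_eye": (36, 42),
    "left_eye": (42, 48),
    "mouth": (48, 68),
}

def landmarks_to_image(points, height, width, radius=1):
    """Hypothetical helper: draw each (x, y) landmark as a small white square
    on a black single-channel uint8 canvas of shape (height, width)."""
    canvas = np.zeros((height, width), dtype=np.uint8)
    for x, y in np.round(points).astype(int):
        # skip landmarks that fall outside the canvas
        if not (0 <= x < width and 0 <= y < height):
            continue
        # clip the square around the landmark to the canvas borders
        y0, y1 = max(y - radius, 0), min(y + radius + 1, height)
        x0, x1 = max(x - radius, 0), min(x + radius + 1, width)
        canvas[y0:y1, x0:x1] = 255
    return canvas
```

With the `detector` and `predictor` objects above, the input array could be built as `np.array([(p.x, p.y) for p in predictor(gray, face).parts()])`, where `gray` is a grayscale frame and `face` is one rectangle returned by `detector(gray, 1)`.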

# Pix2Pix GAN for Re-enactment

In their work titled [“Image-to-Image Translation with Conditional Adversarial Networks”](https://arxiv.org/abs/1611.07004), Isola et al. present a conditional GAN that learns not only the mapping from input to output image but also a task-specific loss function for training that mapping, so the same generic approach works across tasks and datasets. As the name suggests, this GAN architecture takes an image from one domain as input and translates it into a different domain. It is called paired (pair-wise) translation because the training set must contain aligned pairs: every source-domain sample has a corresponding target-domain sample.
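
For reference, the paper's objective combines a conditional adversarial term with an L1 reconstruction term weighted by $\lambda$ (the paper uses $\lambda = 100$, the same value as `weight_pixel_wise_identity` in the code below). Here $x$ is the conditioning image, $y$ the target image and $z$ the noise, which pix2pix injects only through dropout:

$$
\mathcal{L}_{cGAN}(G, D) = \mathbb{E}_{x,y}\left[\log D(x, y)\right] + \mathbb{E}_{x,z}\left[\log\left(1 - D(x, G(x, z))\right)\right]
$$

$$
\mathcal{L}_{L1}(G) = \mathbb{E}_{x,y,z}\left[\lVert y - G(x, z) \rVert_1\right]
$$

$$
G^{*} = \arg\min_{G}\max_{D}\; \mathcal{L}_{cGAN}(G, D) + \lambda\,\mathcal{L}_{L1}(G)
$$

This notebook replaces the log-likelihood adversarial term with a least-squares one (`torch.nn.MSELoss` against all-ones and all-zeros Patch-GAN targets), so the training code is a least-squares variant of this objective rather than a literal transcription.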

## U-Net Generator
The U-Net architecture uses skip connections to shuttle important low-level features directly from the input side of the network to the output side. In the pix2pix generator, a skip connection is added between every $i$-th down-sampling layer and the $(n-i)$-th up-sampling layer, where $n$ is the total number of layers in the generator. Each skip connection concatenates all channels of the $i$-th layer with those of the $(n-i)$-th layer.
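
To make the $i \leftrightarrow (n-i)$ pairing concrete, here is a toy NumPy forward pass with no learned weights. It is only a structural sketch: 2×2 average pooling stands in for the strided convolutions and nearest-neighbour repetition stands in for the transposed convolutions, and the function names are hypothetical rather than taken from `gan_utils`.

```python
import numpy as np

def downsample(x):
    # 2x2 average pooling on a (C, H, W) array; stands in for a strided conv
    c, h, w = x.shape
    return x.reshape(c, h // 2, 2, w // 2, 2).mean(axis=(2, 4))

def upsample(x):
    # nearest-neighbour 2x upsampling; stands in for a transposed conv
    return x.repeat(2, axis=1).repeat(2, axis=2)

def toy_unet(x, depth=3):
    """Toy U-Net pass; H and W must be divisible by 2**depth."""
    skips = []
    for _ in range(depth):
        skips.append(x)            # activation entering down-sampling layer i
        x = downsample(x)
    for skip in reversed(skips):   # deepest skip pairs with the first up-sampling layer
        x = upsample(x)
        x = np.concatenate([x, skip], axis=0)  # channel-wise concatenation
    return x

x = np.random.rand(3, 8, 8)
print(toy_unet(x).shape)  # (12, 8, 8)
```

Because nothing here projects channels back down, every concatenation grows the channel count; in a real generator each decoder convolution chooses its own number of output channels.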

## Patch-GAN Discriminator
The pix2pix authors propose a Patch-GAN discriminator, which takes the required inputs (the conditioning image together with a real or generated target image) and produces an $N \times N$ output instead of a single scalar. Each element $x_{ij}$ of this output indicates whether the corresponding patch $ij$ of the image is real or fake. Each output element can be traced back to its input patch through the effective receptive field of the stacked layers.
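
The receptive-field argument can be checked with a little arithmetic. The sketch below assumes the layer stack of the authors' reference 70×70 Patch-GAN (4×4 kernels, padding 1, strides 2, 2, 2, 1, 1). The `Discriminator` in `gan_utils` is not shown here and may use a different stack, so `PATCH_GAN_SHAPE` in this notebook need not match these numbers.

```python
def receptive_field(layers):
    """Receptive field (in input pixels) of one output unit.

    layers: list of (kernel_size, stride) tuples ordered from input to output.
    """
    rf = 1
    for kernel, stride in reversed(layers):
        # walking back towards the input, a span of rf units in a layer's output
        # covers (rf - 1) * stride + kernel units of that layer's input
        rf = (rf - 1) * stride + kernel
    return rf

def output_size(size, layers, padding=1):
    """Side length of the output map for a square input of side `size`."""
    for kernel, stride in layers:
        size = (size + 2 * padding - kernel) // stride + 1
    return size

# Assumed 70x70 Patch-GAN layout: (kernel_size, stride) per conv layer
PATCHGAN_70 = [(4, 2), (4, 2), (4, 2), (4, 1), (4, 1)]

print(receptive_field(PATCHGAN_70))   # 70
print(output_size(256, PATCHGAN_70))  # 30
```

Under these assumptions a 256×256 input yields a 30×30 grid of real/fake scores, each judging an overlapping 70×70 patch of the input.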

## Initialize Generator and Discriminator Model Objects

In [None]:
# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

# Loss functions
adversarial_loss = torch.nn.MSELoss()
pixelwise_loss = torch.nn.L1Loss()

# Loss weight of L1 pixel-wise loss between translated image and real image
weight_pixel_wise_identity = 100

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
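
The two loss objects and the weight above combine in the training loop as a least-squares adversarial term plus a heavily weighted L1 term. The following NumPy sketch mirrors that arithmetic with hypothetical helper names; it reproduces the mean reduction that `torch.nn.MSELoss` and `torch.nn.L1Loss` use by default, but it is an illustration, not the notebook's code.

```python
import numpy as np

LAMBDA_L1 = 100  # same value as weight_pixel_wise_identity

def mse(a, b):
    return np.mean((a - b) ** 2)

def l1(a, b):
    return np.mean(np.abs(a - b))

def generator_loss(pred_fake, fake_B, real_B, lam=LAMBDA_L1):
    # push the discriminator's patch scores on generated images towards 1,
    # and keep the generated image close to the real one pixel by pixel
    adv = mse(pred_fake, np.ones_like(pred_fake))
    return adv + lam * l1(fake_B, real_B)

def discriminator_loss(pred_real, pred_fake):
    # real patches -> 1, generated patches -> 0, averaged as in the training loop
    loss_real = mse(pred_real, np.ones_like(pred_real))
    loss_fake = mse(pred_fake, np.zeros_like(pred_fake))
    return 0.5 * (loss_real + loss_fake)

pred_fake = np.full((1, 1, 30, 30), 0.5)  # undecided discriminator
fake_B = np.full((1, 3, 8, 8), 0.01)
real_B = np.zeros((1, 3, 8, 8))
print(generator_loss(pred_fake, fake_B, real_B))  # approximately 1.25
```

With $\lambda = 100$, a mean absolute pixel error of only 0.01 already contributes 1.0, four times the adversarial term of an undecided discriminator (0.25). This weighting keeps the generator anchored to the target frame while the adversarial term sharpens details.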

## Prepare Dataset

In [None]:
# prepare data
prepare_data('obama.mp4',
             detector,
             predictor,
             num_samples=400,
             downsample_ratio=DOWNSAMPLE_RATIO)
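
The on-disk format written by `prepare_data` is not shown. The training loop only tells us that `ImageDataset` yields a dict where `"B"` is the condition and `"A"` is the target frame. A common pix2pix convention stores each pair as a single image with A and B side by side; assuming `dataset_utils` does something similar, the hypothetical helpers below show how such a pair is built and split again.

```python
import numpy as np

def make_pair(frame, landmark_img):
    """Hypothetical: place the target frame (A) and landmark image (B) side by side."""
    if frame.shape != landmark_img.shape:
        raise ValueError("frame and landmark image must have the same shape")
    return np.concatenate([frame, landmark_img], axis=1)  # H x 2W (x C)

def split_pair(pair):
    """Inverse of make_pair: return the (A, B) halves of an H x 2W image."""
    w = pair.shape[1] // 2
    return pair[:, :w], pair[:, w:]
```

Keeping both halves in one file guarantees that a frame and its landmarks are always loaded together, which is exactly the alignment a paired model like pix2pix requires.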

## Setup Objects based on GPU Availability

In [None]:
if CUDA:
    generator = generator.cuda()
    discriminator = discriminator.cuda()
    adversarial_loss.cuda()
    pixelwise_loss.cuda()
    Tensor = torch.cuda.FloatTensor
else:
    Tensor = torch.FloatTensor

## Define Transformations and Dataloaders

In [None]:
image_transformations = [
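    # InterpolationMode replaces the deprecated PIL interpolation constants in torchvision
    transforms.Resize((IMG_HEIGHT, IMG_WIDTH), transforms.InterpolationMode.BICUBIC),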
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
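
`ToTensor()` scales pixels to $[0, 1]$, and `Normalize` with mean 0.5 and standard deviation 0.5 per channel then computes $(x - 0.5) / 0.5$, mapping them to $[-1, 1]$. The NumPy sketch below spells out that mapping and its inverse. Generators whose last layer is `tanh` produce outputs in this same range, although whether the `gan_utils` generator does so is not shown here.

```python
import numpy as np

MEAN, STD = 0.5, 0.5  # the per-channel values passed to transforms.Normalize

def normalize(x):
    # [0, 1] -> [-1, 1], as applied after ToTensor()
    return (x - MEAN) / STD

def denormalize(x):
    # [-1, 1] -> [0, 1], e.g. before viewing or saving generated images
    return x * STD + MEAN

print(normalize(np.array([0.0, 0.5, 1.0])))  # [-1.  0.  1.]
```

Note that `save_image(..., normalize=True)` in `sample_images` does not apply this exact inverse; it rescales using the minimum and maximum of the tensor being saved.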

In [None]:
train_dataloader = DataLoader(
    ImageDataset(DATASET_PATH, image_transformations=image_transformations),
    batch_size=BATCH_SIZE,
    shuffle=True
)

In [None]:
val_dataloader = DataLoader(
    ImageDataset(DATASET_PATH, image_transformations=image_transformations),
    batch_size=BATCH_SIZE//8,
    shuffle=True
)

## Training Begins!

In [None]:
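def sample_images(val_dataloader, batches_done):
    """
        Method to generate sample images for validation
        Parameters:
            val_dataloader: instance of dataloader
            batches_done: training iteration counter
    """
    imgs = next(iter(val_dataloader))
    # condition (landmarks)
    real_A = Variable(imgs["B"].type(Tensor))
    # real target frame
    real_B = Variable(imgs["A"].type(Tensor))
    # generated frame: eval mode and no gradient tracking while sampling
    generator.eval()
    with torch.no_grad():
        fake_B = generator(real_A)
    # restore training mode, otherwise all later training steps would run in eval mode
    generator.train()
    # stack condition, generated and real images along the height axis
    img_sample = torch.cat((real_A.data, fake_B.data, real_B.data), -2)
    save_image(img_sample, f"{DATASET_PATH}/{batches_done}.png", nrow=4, normalize=True)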

In [None]:
for epoch in range(0, N_EPOCHS):
    for i, batch in enumerate(train_dataloader):

        # prepare inputs
        real_A = Variable(batch["B"].type(Tensor))
        real_B = Variable(batch["A"].type(Tensor))

        # ground truth
        valid = Variable(Tensor(np.ones((real_A.size(0), *PATCH_GAN_SHAPE))), requires_grad=False)
        fake = Variable(Tensor(np.zeros((real_A.size(0), *PATCH_GAN_SHAPE))), requires_grad=False)

        #  Train Generator
        optimizer_G.zero_grad()

        # generator loss
        fake_B = generator(real_A)
        pred_fake = discriminator(fake_B, real_A)
        adv_loss = adversarial_loss(pred_fake, valid)
        loss_pixel = pixelwise_loss(fake_B, real_B)

        # Overall Generator loss
        g_loss = adv_loss + weight_pixel_wise_identity * loss_pixel

        g_loss.backward()

        optimizer_G.step()

        #  Train Discriminator
        optimizer_D.zero_grad()

        pred_real = discriminator(real_B, real_A)
        loss_real = adversarial_loss(pred_real, valid)
        pred_fake = discriminator(fake_B.detach(), real_A)
        loss_fake = adversarial_loss(pred_fake, fake)

        # Overall Discriminator loss
        d_loss = 0.5 * (loss_real + loss_fake)

        d_loss.backward()
        optimizer_D.step()

        # Progress Report
        batches_done = epoch * len(train_dataloader) + i
        print(f'Epoch: {epoch}/{N_EPOCHS}-Batch: {i}/{len(train_dataloader)}--D.loss:{d_loss.item():.4f},G.loss:{g_loss.item():.4f}--Adv.Loss:{adv_loss.item():.4f}')

        # generate samples
        if batches_done % SAMPLE_INTERVAL == 0:
            sample_images(val_dataloader,batches_done)

## Save the Trained Models

In [None]:
torch.save(generator.state_dict(), "saved_models/generator.pt")
torch.save(discriminator.state_dict(), "saved_models/discriminator.pt")