<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc" style="margin-top: 1em;"><ul class="toc-item"></ul></div>

In [0]:
%%bash
# Fetch the face-swap project sources into ./face-swap
git clone https://github.com/5agado/face-swap.git
# Install the cloned project into the current Python environment
cd face-swap
pip install .

In [0]:
%%bash
# Create required folders.
# NOTE: the %%bash cell magic has to be the very first line of the cell,
# otherwise IPython does not treat the cell as a bash script.
# -p keeps the cell re-runnable: no error if a folder already exists.
mkdir -p models
mkdir -p data
mkdir -p checkpoints

In [0]:
# GDrive authentication
!pip install -U -q PyDrive

from utils.colaboratory_utils import *
drive_client = get_authenticated_drive_client()
drive_service = get_drive_service()

In [0]:
# IDs of the Google Drive folders used as export targets.
# Both values are placeholders: replace them with the IDs of your own folders.

# target folder for the sample plots exported during training
sample_plots_dir_id = 'id_of_your_drive_dir_where_to_save_sample_plots'
# target folder for the zipped model checkpoint
face_swap_dir_id = 'id_of_your_drive_dir_where_to_save_models'

In [0]:
# Import the archives needed for training.
# Each pair is (id string handed to import_file, local file name); the id strings
# look like placeholders to be replaced with real Drive file IDs - TODO confirm.
for _file_id, _local_name in (
    ("data_zipfile_id", "data.zip"),
    ("models_zipfile_id", "models.zip"),
):
    import_file(drive_client, _file_id, _local_name)

In [0]:
%%bash
# Install needed dependencies
# NOTE: the %%bash cell magic has to be the very first line of the cell,
# otherwise IPython does not treat the cell as a bash script.

# Unpack the imported archives into their folders
unzip -q data.zip -d data
unzip -q models.zip -d models

#apt-get -qq install -y libsm6 libxext6 build-essential cmake
# Install the Python requirements listed by the cloned project for training
pip install --trusted-host pypi.python.org -r face-swap/requirements_train.txt

In [0]:
# Third-party numerical / plotting / notebook helpers
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from numpy.random import shuffle
from IPython.display import clear_output

# Standard library and config parsing
from pathlib import Path
import sys
import os
import yaml

import cv2

from tqdm import tqdm

#%matplotlib notebook
%matplotlib inline

# Make the cloned repository and its inner face_swap folder importable
sys.path.append('./face-swap')
sys.path.append('./face-swap/face_swap')

# Project-local modules
from utils import image_processing

from face_swap import autoencoder
from face_swap import gan, gan_utils
from face_swap import faceswap_utils as utils
from face_swap.train import get_original_data, get_training_data
from face_swap.plot_utils import plot_sample
from face_swap.FaceGenerator import random_transform, random_warp
from face_swap import FaceGenerator

# autoreload mode 2: reload modules before executing each cell
%load_ext autoreload
%autoreload 2

In [0]:
# Folder holding the training data and folder holding the model files to load
data_folder = Path("data") / "data"
models_path = Path("models") / "checkpoints"

In [0]:
# Folders with the training images of the two faces (A and B)
img_dirA = data_folder.joinpath("cage")
img_dirB = data_folder.joinpath("trump")

In [0]:
# Load the models configuration shipped with the cloned project.
# safe_load is used instead of yaml.load: calling yaml.load without an explicit
# Loader is deprecated (and rejected by recent PyYAML versions), and safe_load
# only builds plain Python objects, so the file cannot execute arbitrary code.
with open("face-swap/models.cfg", 'r') as ymlfile:
    cfg = yaml.safe_load(ymlfile)

In [0]:
# Which entry of the models configuration to use
model_name = "base_gan"
model_version = "v0"

# Take the configuration of the selected model/version and set its models path
# manually (to the notebook's models_path) instead of getting it from the config
model_cfg = cfg[model_name][model_version]
model_cfg.update(models_path=str(models_path))

# Build the two generator networks (netG*) and two discriminator networks (netD*)
netGA, netGB, netDA, netDB = gan.get_gan(model_cfg)

In [0]:
# Define the training functions and the generation functions used for plotting.
# Which builders are used depends on whether the selected model is a masked GAN.
if not model_cfg['masked']:
    netGA_train, netGB_train, netDA_train, netDB_train = gan.build_training_functions(
        cfg[model_name][model_version], netGA, netGB, netDA, netDB)

    def gen_plot_a(x):
        """Generate images for plotting with generator A."""
        return netGA.predict(x)

    def gen_plot_b(x):
        """Generate images for plotting with generator B."""
        return netGB.predict(x)
else:
    netGA_train, netGB_train, netDA_train, netDB_train = gan.build_training_functions_masked(
        cfg[model_name][model_version], netGA, netGB, netDA, netDB)

    # Only path_* and fun_mask_* are used below; the other returned values are kept
    # under the same names as before (fun_abgr ends up holding the value for netGB).
    distorted_A, fake_A, mask_A, path_A, fun_mask_A, fun_abgr = gan_utils.cycle_variables_masked(netGA)
    distorted_B, fake_B, mask_B, path_B, fun_mask_B, fun_abgr = gan_utils.cycle_variables_masked(netGB)

    def gen_plot_a(x):
        """Generate images for plotting through path_A."""
        return np.array(path_A([x])[0])

    def gen_plot_b(x):
        """Generate images for plotting through path_B."""
        return np.array(path_B([x])[0])

    # The mask helpers rescale the output with *2-1
    # (presumably mapping masks from [0, 1] to [-1, 1] - TODO confirm)
    def gen_plot_mask_a(x):
        """Generate the mask of generator A for plotting."""
        return np.array(fun_mask_A([x])[0])*2-1

    def gen_plot_mask_b(x):
        """Generate the mask of generator B for plotting."""
        return np.array(fun_mask_B([x])[0])*2-1

In [0]:
# Loss histories, one independent list per network; the training loop appends to them
errsGA, errsGB, errsDA, errsDB = [], [], [], []

In [0]:
# Running count of training iterations; kept in its own cell so that re-running
# the training cell keeps counting (the value is also used to name sample images)
total_epochs = 0

In [0]:
# From here on models_path points to the local "checkpoints" folder:
# the training loop saves the network weights under this path
models_path = Path("checkpoints")

In [0]:
# Training settings
show_plot = True            # True: show sample plots inline; False: save them to file and export to Drive
batch_size = 32
NB_EPOCH_CHECKPOINT = 1000  # plot samples and save weights every this many iterations
nb_epochs = 10000           # number of training iterations run by this cell

# Make sure the checkpoint folder exists so saving weights cannot fail on a missing directory
models_path.mkdir(parents=True, exist_ok=True)

# images_* feed the training batches; samples_* (loaded with size (64, 64)) are only used for the sample plots
images_a, images_b = get_original_data(img_dirA, img_dirB, tanh_fix=True)
samples_a, samples_b = get_original_data(img_dirA, img_dirB, (64, 64), tanh_fix=True)

for gen_iterations in tqdm(range(nb_epochs)):
    total_epochs += 1

    # NOTE(review): the full cfg is passed here, not model_cfg - confirm this is what get_training_data expects
    warped_A, target_A = get_training_data(images_a, batch_size, cfg)
    warped_B, target_B = get_training_data(images_b, batch_size, cfg)

    # Train discriminators for one batch
    errDA = netDA_train([warped_A, target_A])
    errDB = netDB_train([warped_B, target_B])
    errsDA.append(errDA[0])
    errsDB.append(errDB[0])

    # Train generators for one batch
    errGA = netGA_train([warped_A, target_A])
    errGB = netGB_train([warped_B, target_B])
    errsGA.append(errGA[0])
    errsGB.append(errGB[0])

    # range(nb_epochs) stops at nb_epochs - 1, so that is the index of the final
    # iteration; checking against it guarantees the final state is plotted and saved
    is_last_iteration = gen_iterations == nb_epochs - 1
    if (gen_iterations % NB_EPOCH_CHECKPOINT == 0) or is_last_iteration:
        print("Loss_DA: {} Loss_DB: {} Loss_GA: {} Loss_GB: {}".format(errDA, errDB, errGA, errGB))

        # reshuffle the sample images in place before generating results for visualization
        shuffle(samples_a)
        shuffle(samples_b)

        # if specified, show sample plot in the notebook
        if show_plot:
            # every third checkpoint, clear the accumulated cell output first
            if gen_iterations % (3*NB_EPOCH_CHECKPOINT) == 0:
                clear_output()
            plot_sample(samples_a, samples_b,
                        gen_plot_a, gen_plot_b,
                        tanh_fix=True)
            # the mask plot helpers only exist for masked models
            if model_cfg['masked']:
                plot_sample(samples_a, samples_b,
                            gen_plot_mask_a, gen_plot_mask_b,
                            tanh_fix=True)
        # otherwise save to file and export to Drive
        else:
            sample_img_name = "sample_{}.jpg".format(total_epochs)
            plot_sample(samples_a, samples_b,
                        gen_plot_a, gen_plot_b,
                        tanh_fix=True, save_to=sample_img_name)
            export_file(drive_service, sample_img_name, sample_img_name, sample_plots_dir_id)

        # save models
        netGA.layers[1].save_weights(str(models_path / "encoder.h5"))
        netGA.layers[2].save_weights(str(models_path / "decoder_A.h5"))
        netGB.layers[2].save_weights(str(models_path / "decoder_B.h5"))
        netDA.save_weights(str(models_path / "netDA.h5"))
        netDB.save_weights(str(models_path / "netDB.h5"))

In [0]:
# Name of the zip archive that will contain the model checkpoint
zip_name = "gan_v1.zip"

In [0]:
# Zip model checkpoint: recursively archive everything under models_path into zip_name
# ($zip_name and $models_path are expanded by IPython from the Python variables)
!zip -r $zip_name $models_path/*

In [0]:
# Export zip to GDrive; face_swap_dir_id is passed as the target
# (presumably the archive is uploaded under the same name - TODO confirm against export_file)
export_file(drive_service, zip_name, zip_name, face_swap_dir_id)