In [1]:
from __future__ import print_function, division

from keras.models import Sequential, Model
from keras.layers import *
from keras.layers.advanced_activations import LeakyReLU
from keras.activations import relu
from keras.initializers import RandomNormal
from keras.applications import *
import keras.backend as K
from tensorflow.contrib.distributions import Beta
import tensorflow as tf
from keras.optimizers import Adam
from keras import losses
from keras.utils import to_categorical

Using TensorFlow backend.


In [2]:
from image_augmentation import random_transform
from image_augmentation import random_warp
from utils import get_image_paths, load_images, stack_images
from pixel_shuffler import PixelShuffler

In [3]:
import time
import numpy as np
from PIL import Image
import cv2
import glob
from random import randint, shuffle
from IPython.display import clear_output
from IPython.display import display
import matplotlib.pyplot as plt
%matplotlib inline

Code borrowed from [eriklindernoren](https://github.com/eriklindernoren) and [fchollet](https://github.com/fchollet):

https://github.com/eriklindernoren/Keras-GAN/blob/master/aae/adversarial_autoencoder.py

https://github.com/fchollet/deep-learning-with-python-notebooks/blob/master/8.5-introduction-to-gans.ipynb

In [7]:
class FaceSwapGAN():
    """Face-swap GAN with one shared encoder and a decoder + discriminator per face.

    Each generator (``netGA`` / ``netGB``) maps an ``img_shape`` image to a pair
    ``[alpha, rgb]``: a 1-channel sigmoid mask and a 3-channel tanh image. The
    image shown to the discriminator is ``alpha * rgb + (1 - alpha) * input``.
    Training images are scaled to [-1, 1] (see ``train``).
    """
    def __init__(self):
        self.img_size = 64
        self.channels = 3
        self.img_shape = (self.img_size, self.img_size, self.channels)
        # NOTE(review): not referenced by the builders below (Encoder hard-codes 1024).
        self.encoded_dim = 1024
        # Glob patterns for the training images of each identity.
        self.img_dirA = './faceA/*.*'
        self.img_dirB = './faceB/*.*'
        # Keyword arguments forwarded to image_augmentation.random_transform.
        self.random_transform_args = {
            'rotation_range': 20,
            'zoom_range': 0.05,
            'shift_range': 0.05,
            'random_flip': 0.5,
            }

        # NOTE(review): one Adam instance is shared by every model compiled
        # below; consider separate instances if optimizer state should not be shared.
        optimizer = Adam(2e-4, 0.5, 0.999)

        # Build and compile the discriminators (trained directly with MSE).
        self.netDA, self.netDB = self.build_discriminator()
        self.netDA.compile(loss='mse', optimizer=optimizer, metrics=['accuracy'])
        self.netDB.compile(loss='mse', optimizer=optimizer, metrics=['accuracy'])

        # Build and compile the generators (only used for predict() in train/showG).
        self.netGA, self.netGB = self.build_generator()
        self.netGA.compile(loss=['mae', 'mse'], optimizer=optimizer)
        self.netGB.compile(loss=['mae', 'mse'], optimizer=optimizer)

        img = Input(shape=self.img_shape)
        alphaA, reconstructed_imgA = self.netGA(img)
        alphaB, reconstructed_imgB = self.netGB(img)

        # Freeze the discriminators inside the stacked models below; they were
        # already compiled above as trainable, so their own train_on_batch still updates them.
        self.netDA.trainable = False
        self.netDB.trainable = False

        def one_minus(x): return 1 - x
        # masked_img = alpha * reconstructed_img + (1 - alpha) * img
        masked_imgA = add([multiply([alphaA, reconstructed_imgA]), multiply([Lambda(one_minus)(alphaA), img])])
        masked_imgB = add([multiply([alphaB, reconstructed_imgB]), multiply([Lambda(one_minus)(alphaB), img])])
        out_discriminatorA = self.netDA(masked_imgA)
        out_discriminatorB = self.netDB(masked_imgB)

        # Stacked generator + frozen discriminator: input image -> (raw RGB
        # reconstruction scored with MAE, discriminator output on the masked
        # image scored with MSE), weighted 1 : 0.5.
        self.adversarial_autoencoderA = Model(img, [reconstructed_imgA, out_discriminatorA])
        self.adversarial_autoencoderB = Model(img, [reconstructed_imgB, out_discriminatorB])
        self.adversarial_autoencoderA.compile(loss=['mae', 'mse'],
                                              loss_weights=[1, 0.5],
                                              optimizer=optimizer)
        self.adversarial_autoencoderB.compile(loss=['mae', 'mse'],
                                              loss_weights=[1, 0.5],
                                              optimizer=optimizer)

    def build_generator(self):
        """Build netGA and netGB, which share one encoder and have separate decoders.

        Attempts to restore weights from ``models/netGA.h5`` / ``models/netGB.h5``.
        A failure to load is reported and ignored, leaving fresh weights.

        Returns:
            (netGA, netGB): Keras models mapping an image to ``[alpha, rgb]``.
        """
        def conv_block(input_tensor, f):
            # Strided 3x3 conv (halves spatial size) + LeakyReLU.
            x = input_tensor
            x = Conv2D(f, kernel_size=3, strides=2, kernel_initializer=RandomNormal(0, 0.02),
                       use_bias=False, padding="same")(x)
            x = LeakyReLU(alpha=0.2)(x)
            return x

        def res_block(input_tensor, f):
            # Two 3x3 convs with an identity skip; f must equal the input channel count.
            x = input_tensor
            x = Conv2D(f, kernel_size=3, kernel_initializer=RandomNormal(0, 0.02),
                       use_bias=False, padding="same")(x)
            x = LeakyReLU(alpha=0.2)(x)
            x = Conv2D(f, kernel_size=3, kernel_initializer=RandomNormal(0, 0.02),
                       use_bias=False, padding="same")(x)
            x = add([x, input_tensor])
            x = LeakyReLU(alpha=0.2)(x)
            return x

        def upscale_ps(filters, use_norm=True):
            # Conv to 4*filters channels, then PixelShuffler. The shuffler is assumed
            # to trade channels for 2x spatial upscaling, which is consistent with
            # Decoder_ps taking img_size//8 inputs. `use_norm` is currently unused.
            def block(x):
                x = Conv2D(filters*4, kernel_size=3, use_bias=False,
                           kernel_initializer=RandomNormal(0, 0.02), padding='same')(x)
                x = LeakyReLU(0.1)(x)
                x = PixelShuffler()(x)
                return x
            return block

        def Encoder(img_shape):
            inp = Input(shape=img_shape)
            x = Conv2D(64, kernel_size=5, kernel_initializer=RandomNormal(0, 0.02),
                       use_bias=False, padding="same")(inp)
            x = conv_block(x, 128)
            x = conv_block(x, 256)
            x = conv_block(x, 512)
            x = conv_block(x, 1024)
            # Dense bottleneck, then back to a 4x4x1024 feature map.
            x = Dense(1024)(Flatten()(x))
            x = Dense(4*4*1024)(x)
            x = Reshape((4, 4, 1024))(x)
            out = upscale_ps(512)(x)
            return Model(inputs=inp, outputs=out)

        def Decoder_ps(img_shape):
            nc_in = 512
            input_size = img_shape[0]//8
            inp = Input(shape=(input_size, input_size, nc_in))
            x = inp
            x = upscale_ps(256)(x)
            x = upscale_ps(128)(x)
            x = upscale_ps(64)(x)
            x = res_block(x, 64)
            x = res_block(x, 64)
            # Two heads: a sigmoid mask in [0, 1] and a tanh image in [-1, 1].
            alpha = Conv2D(1, kernel_size=5, padding='same', activation="sigmoid")(x)
            rgb = Conv2D(3, kernel_size=5, padding='same', activation="tanh")(x)
            return Model(inp, [alpha, rgb])

        encoder = Encoder(self.img_shape)
        decoder_A = Decoder_ps(self.img_shape)
        decoder_B = Decoder_ps(self.img_shape)
        x = Input(shape=self.img_shape)
        netGA = Model(x, decoder_A(encoder(x)))
        netGB = Model(x, decoder_B(encoder(x)))
        try:
            netGA.load_weights("models/netGA.h5")
            netGB.load_weights("models/netGB.h5")
            print("Generator models loaded.")
        except Exception as e:
            # Best-effort restore: start from fresh weights, but say why and
            # don't swallow KeyboardInterrupt/SystemExit like a bare except did.
            print("Generator weights not loaded: %s" % (e,))

        return netGA, netGB

    def build_discriminator(self):
        """Build netDA and netDB, two identical 3-stage conv discriminators.

        Each outputs a sigmoid map, not a single score. Attempts to restore weights
        from ``models/netDA.h5`` / ``models/netDB.h5``; a failure to load is
        reported and ignored.

        Returns:
            (netDA, netDB)
        """
        def conv_block_d(input_tensor, f, use_instance_norm=True):
            # Strided 4x4 conv + LeakyReLU. `use_instance_norm` is currently unused.
            x = input_tensor
            x = Conv2D(f, kernel_size=4, strides=2, kernel_initializer=RandomNormal(0, 0.02),
                       use_bias=False, padding="same")(x)
            x = LeakyReLU(alpha=0.2)(x)
            return x

        def Discriminator(img_shape):
            inp = Input(shape=img_shape)
            x = conv_block_d(inp, 64, False)
            x = conv_block_d(x, 128, False)
            x = conv_block_d(x, 256, False)
            out = Conv2D(1, kernel_size=4, kernel_initializer=RandomNormal(0, 0.02),
                         use_bias=False, padding="same", activation="sigmoid")(x)
            return Model(inputs=[inp], outputs=out)

        netDA = Discriminator(self.img_shape)
        netDB = Discriminator(self.img_shape)
        try:
            netDA.load_weights("models/netDA.h5")
            netDB.load_weights("models/netDB.h5")
            print("Discriminator models loaded.")
        except Exception as e:
            # Best-effort restore; report the actual reason instead of assuming "not found".
            print("Discriminator weights not loaded: %s" % (e,))

        return netDA, netDB

    def train(self, epochs, batch_size=8, save_interval=50, max_iters=20000):
        """Alternately train the discriminators and the stacked generators.

        Args:
            epochs: Reported in the progress line only. Training length is
                controlled by ``max_iters``.
            batch_size: Images per identity per iteration.
            save_interval: Every this many iterations, log losses, save all four
                networks to ``models/`` and display 14-image previews per identity.
            max_iters: Number of generator iterations to run. The default of 20000
                matches the previously hard-coded limit.

        Raises:
            AssertionError: If either image glob matches no files.
            ValueError: If a requested batch is larger than an identity's dataset.
            IOError: If OpenCV cannot decode one of the matched files.
        """
        import os

        def load_data(file_pattern):
            return glob.glob(file_pattern)

        def read_image(fn, random_transform_args=self.random_transform_args):
            # Load one file as a (warped, target) pair with pixel values in [-1, 1].
            image = cv2.imread(fn)
            if image is None:
                # cv2.imread returns None instead of raising, and the '*.*'
                # glob can also match non-image files.
                raise IOError("Could not read image file: " + str(fn))
            image = cv2.resize(image, (256, 256)) / 255 * 2 - 1
            image = random_transform(image, **random_transform_args)
            warped_img, target_img = random_warp(image)
            return warped_img, target_img

        def minibatch(data, batchsize):
            # Endless generator yielding (epoch, warped, target).
            # send(n) overrides the size of the next batch; next()/send(None)
            # uses `batchsize`. `data` is reshuffled in place every epoch.
            length = len(data)
            epoch = i = 0
            tmpsize = None
            shuffle(data)
            while True:
                size = tmpsize if tmpsize else batchsize
                if size > length:
                    raise ValueError("Requested a batch of %d images but only %d are available."
                                     % (size, length))
                if i + size > length:
                    shuffle(data)
                    i = 0
                    epoch += 1
                rtn = np.float32([read_image(data[j]) for j in range(i, i + size)])
                i += size
                tmpsize = yield epoch, rtn[:, 0, :, :, :], rtn[:, 1, :, :, :]

        def minibatchAB(dataA, batchsize):
            # Thin wrapper that forwards send() values to `minibatch`.
            batchA = minibatch(dataA, batchsize)
            tmpsize = None
            while True:
                ep1, warped_img, target_img = batchA.send(tmpsize)
                tmpsize = yield ep1, warped_img, target_img

        # Load the dataset
        train_A = load_data(self.img_dirA)
        train_B = load_data(self.img_dirB)
        assert len(train_A), "No image found in " + str(self.img_dirA) + "."
        assert len(train_B), "No image found in " + str(self.img_dirB) + "."
        train_batchA = minibatchAB(train_A, batch_size)
        train_batchB = minibatchAB(train_B, batch_size)

        print("Training starts...")
        t0 = time.time()
        gen_iterations = 0
        while gen_iterations < max_iters:
            # ---------------------
            #  Train Discriminators
            # ---------------------
            epoch, warped_A, target_A = next(train_batchA)
            epoch, warped_B, target_B = next(train_batchB)

            # Composite generator output over the warped input, exactly as the
            # stacked model does symbolically in __init__.
            gen_alphasA, gen_imgsA = self.netGA.predict(warped_A)
            gen_alphasB, gen_imgsB = self.netGB.predict(warped_B)
            gen_masked_imgsA = gen_alphasA * gen_imgsA + (1 - gen_alphasA) * warped_A
            gen_masked_imgsB = gen_alphasB * gen_imgsB + (1 - gen_alphasB) * warped_B

            # Label maps match the discriminator's spatial output shape.
            valid = np.ones((batch_size, ) + self.netDA.output_shape[1:])
            fake = np.zeros((batch_size, ) + self.netDA.output_shape[1:])

            d_lossA = self.netDA.train_on_batch(np.concatenate([target_A, gen_masked_imgsA], axis=0),
                                                np.concatenate([valid, fake], axis=0))
            d_lossB = self.netDB.train_on_batch(np.concatenate([target_B, gen_masked_imgsB], axis=0),
                                                np.concatenate([valid, fake], axis=0))

            # ---------------------
            #  Train Generators
            # ---------------------
            g_lossA = self.adversarial_autoencoderA.train_on_batch(warped_A, [target_A, valid])
            g_lossB = self.adversarial_autoencoderB.train_on_batch(warped_B, [target_B, valid])

            gen_iterations += 1

            # If at save interval => save models & show results
            if gen_iterations % save_interval == 0:
                clear_output()
                print('[%d/%s][%d] Loss_DA: %f Loss_DB: %f Loss_GA: %f Loss_GB: %f time: %f'
                      % (epoch, epochs, gen_iterations, d_lossA[0],
                         d_lossB[0], g_lossA[0], g_lossB[0], time.time()-t0))

                # save_weights does not create missing directories.
                if not os.path.isdir("models"):
                    os.makedirs("models")
                self.netGA.save_weights("models/netGA.h5")
                self.netGB.save_weights("models/netGB.h5")
                self.netDA.save_weights("models/netDA.h5")
                self.netDB.save_weights("models/netDB.h5")
                print("Models saved.")

                # 14 images per identity fill the 4x7 grid used by showG.
                _, wA, tA = train_batchA.send(14)
                _, wB, tB = train_batchB.send(14)
                self.showG(tA, tB)

    def showG(self, test_A, test_B):
        """Display masked outputs, raw outputs and alpha masks for both generators.

        Args:
            test_A, test_B: Image batches in [-1, 1] that together hold exactly 28
                images, as required by the 4x7 grid reshape (``train`` passes 14 each).
        """
        def display_fig(figure_A, figure_B):
            figure = np.concatenate([figure_A, figure_B], axis=0)
            figure = figure.reshape((4, 7) + figure.shape[1:])
            figure = stack_images(figure)  # project helper; presumably tiles the grid into one image
            figure = np.clip((figure + 1) * 255 / 2, 0, 255).astype('uint8')
            figure = cv2.cvtColor(figure, cv2.COLOR_BGR2RGB)
            display(Image.fromarray(figure))

        def blend(out, img):
            # alpha * rgb + (1 - alpha) * img, for a predict() result [alpha, rgb].
            alpha, rgb = out
            return alpha * rgb + (1 - alpha) * img

        def mask_as_rgb(out):
            # Repeat the 1-channel mask to 3 channels and map [0, 1] -> [-1, 1].
            return np.tile(out[0], 3) * 2 - 1

        out_test_A_netGA = self.netGA.predict(test_A)
        out_test_A_netGB = self.netGB.predict(test_A)
        out_test_B_netGA = self.netGA.predict(test_B)
        out_test_B_netGB = self.netGB.predict(test_B)

        # Columns: input, reconstruction by own decoder, swap via the other decoder.
        figure_A = np.stack([
            test_A,
            blend(out_test_A_netGA, test_A),
            blend(out_test_A_netGB, test_A),
            ], axis=1)
        figure_B = np.stack([
            test_B,
            blend(out_test_B_netGB, test_B),
            blend(out_test_B_netGA, test_B),
            ], axis=1)
        print("Masked results")
        display_fig(figure_A, figure_B)

        figure_A = np.stack([
            test_A,
            out_test_A_netGA[1],
            out_test_A_netGB[1],
            ], axis=1)
        figure_B = np.stack([
            test_B,
            out_test_B_netGB[1],
            out_test_B_netGA[1],
            ], axis=1)
        print("Raw results")
        display_fig(figure_A, figure_B)

        figure_A = np.stack([
            test_A,
            mask_as_rgb(out_test_A_netGA),
            mask_as_rgb(out_test_A_netGB),
            ], axis=1)
        figure_B = np.stack([
            test_B,
            mask_as_rgb(out_test_B_netGB),
            mask_as_rgb(out_test_B_netGA),
            ], axis=1)
        print("Alpha masks")
        display_fig(figure_A, figure_B)

In [None]:
# Build all four networks and try to restore each from models/*.h5.
# If a restore fails, a message is printed and training starts from fresh weights.
gan = FaceSwapGAN()

Discriminator weights files not found.
Generator weights files not found.


In [None]:
# NOTE: train() stops on an iteration count, not on `epochs`.
# Checkpoints and previews are produced every `save_interval` iterations.
train_config = dict(epochs=1, batch_size=8, save_interval=100)
gan.train(**train_config)

Training starts...
