# Pose Guided Person Image Generation (PG²)

This is an attempt to re-implement the PG² paper (Ma et al., NIPS 2017) with TensorFlow / Keras.

Paper: https://arxiv.org/pdf/1705.09368.pdf

Other resources:
* Authors' reference implementation: https://github.com/charliememory/Pose-Guided-Person-Image-Generation

In [1]:
import tensorflow as tf

In [2]:
class ResBlock(tf.keras.layers.Layer):
    """Residual block: two (3x3 conv, stride 1, 'same' padding) + ReLU stages plus an identity skip.

    The skip connection is a plain element-wise add, so the incoming tensor must
    already have ``filters`` channels; stride 1 and 'same' padding keep H and W.
    """

    def __init__(self, filters, **kwargs):
        super().__init__(**kwargs)
        # Both convolutions share the same hyper-parameters.
        conv_args = dict(filters=filters, kernel_size=(3, 3), strides=(1, 1), padding='same')
        # Construction order (conv1, act1, conv2, act2) is kept stable so that
        # auto-generated layer names and tracked attributes do not change.
        self.conv1 = tf.keras.layers.Conv2D(**conv_args)
        self.act1 = tf.keras.layers.ReLU()
        self.conv2 = tf.keras.layers.Conv2D(**conv_args)
        self.act2 = tf.keras.layers.ReLU()

    def call(self, inputs):
        branch = inputs
        for conv, activation in ((self.conv1, self.act1), (self.conv2, self.act2)):
            branch = activation(conv(branch))
        # Identity shortcut added after the second activation.
        return tf.add(branch, inputs)

In [3]:
class GAN(object):
    """Builder for the three PG² networks: stage-1 generator, stage-2 generator, discriminator.

    Each network is exposed as a property that constructs a *new* Keras model with
    freshly initialised weights on every access. Read a property once and keep the
    returned model instead of reading it repeatedly.

    Args:
        img_shape: per-sample, channels-last shape of the appearance images,
            e.g. (256, 256, 3). Stage 1 down-samples five times with stride 2 and
            concatenates every up-sampled level with the matching encoder level,
            so height and width should be multiples of 32.
        pose_img_shape: per-sample shape of the pose input to stage 1. It is
            concatenated with the image along the last axis, so only its channel
            count may differ from ``img_shape``.
    """

    def __init__(self, img_shape, pose_img_shape):
        self.img_shape = img_shape
        self.pose_img_shape = pose_img_shape

    def __down_sample(self, x, filters, name):
        """Apply a named stride-2 3x3 conv + ReLU to ``x`` (halves height and width)."""
        model = tf.keras.models.Sequential([
            tf.keras.layers.Conv2D(filters = filters, kernel_size = (3, 3), strides = (2, 2), padding = 'same'),
            tf.keras.layers.ReLU()
        ], name = name)
        return model(x)

    def __up_sample(self, x, filters, name):
        """Apply a named 2x UpSampling2D followed by a 3x3 conv + ReLU (doubles height and width)."""
        model = tf.keras.models.Sequential([
            tf.keras.layers.UpSampling2D(size = (2, 2)),
            tf.keras.layers.Conv2D(filters = filters, kernel_size = (3, 3), strides = (1, 1), padding = 'same'),
            tf.keras.layers.ReLU(),
        ], name = name)
        return model(x)

    def __conv_block(self, x, filters, name):
        """Apply a named pair of stride-1 3x3 conv + ReLU layers (spatial size unchanged)."""
        model = tf.keras.models.Sequential([
            tf.keras.layers.Conv2D(filters = filters, kernel_size = (3, 3), strides = (1, 1), padding = 'same'),
            tf.keras.layers.ReLU(),
            tf.keras.layers.Conv2D(filters = filters, kernel_size = (3, 3), strides = (1, 1), padding = 'same'),
            tf.keras.layers.ReLU(),
        ], name = name)
        return model(x)

    @property
    def stage1_generator(self):
        """Build the stage-1 generator G1 (a new model on every access).

        Encoder/decoder with residual blocks at every level, a fully connected
        bottleneck and skip connections (concatenation) between matching levels.

        Inputs:  [image (img_shape), pose image (pose_img_shape)].
        Output:  3-channel image passed through tanh (values in [-1, 1]).
        """
        inp_img = tf.keras.layers.Input(shape = self.img_shape, dtype = tf.float32, name = 'stage1_generator_inp_img')
        inp_pose = tf.keras.layers.Input(shape = self.pose_img_shape, dtype = tf.float32,
                                         name = 'stage1_generator_pose_img')

        x = tf.keras.layers.Concatenate(axis = -1)([inp_img, inp_pose])

        # Lift the input to 32 channels so the first identity ResBlock can add to it.
        x = tf.keras.layers.Conv2D(filters = 32, kernel_size = (3, 3), strides = (1, 1), padding = 'same')(x)
        x = tf.keras.layers.ReLU()(x)

        # Encoder: each level halves H and W; ResBlock filters match the down-sample filters.
        e0 = ResBlock(filters = 32)(x)

        e1 = self.__down_sample(x = e0, filters = 64, name = 'g1_down_sample_encoder_1')
        e1 = ResBlock(filters = 64)(e1)

        e2 = self.__down_sample(x = e1, filters = 64, name = 'g1_down_sample_encoder_2')
        e2 = ResBlock(filters = 64)(e2)

        e3 = self.__down_sample(x = e2, filters = 128, name = 'g1_down_sample_encoder_3')
        e3 = ResBlock(filters = 128)(e3)

        e4 = self.__down_sample(x = e3, filters = 128, name = 'g1_down_sample_encoder_4')
        e4 = ResBlock(filters = 128)(e4)

        e5 = self.__down_sample(x = e4, filters = 256, name = 'g1_down_sample_encoder_5')
        e5 = ResBlock(filters = 256)(e5)

        # Fully connected bottleneck: flatten the deepest features, compress them to
        # 256 units, then project back and restore e5's spatial layout. Requires
        # static H/W/C for e5.
        b = tf.keras.layers.Flatten()(e5)
        b = tf.keras.layers.Dense(units = 256)(b)
        b = tf.keras.layers.Dense(units = e5.get_shape()[1] * e5.get_shape()[2] * e5.get_shape()[3])(b)
        b = tf.keras.layers.Reshape((e5.get_shape()[1], e5.get_shape()[2], e5.get_shape()[3]))(b)

        # Decoder: each level doubles H and W and concatenates the matching encoder output.
        d0 = ResBlock(filters = 256)(b)
        d0 = tf.keras.layers.Concatenate()([d0, e5])

        d1 = self.__up_sample(x = d0, filters = 128, name = 'g1_up_sample_decoder_1')
        d1 = ResBlock(filters = 128)(d1)
        d1 = tf.keras.layers.Concatenate()([d1, e4])

        d2 = self.__up_sample(x = d1, filters = 128, name = 'g1_up_sample_decoder_2')
        d2 = ResBlock(filters = 128)(d2)
        d2 = tf.keras.layers.Concatenate()([d2, e3])

        d3 = self.__up_sample(x = d2, filters = 64, name = 'g1_up_sample_decoder_3')
        d3 = ResBlock(filters = 64)(d3)
        d3 = tf.keras.layers.Concatenate()([d3, e2])

        d4 = self.__up_sample(x = d3, filters = 64, name = 'g1_up_sample_decoder_4')
        d4 = ResBlock(filters = 64)(d4)
        d4 = tf.keras.layers.Concatenate()([d4, e1])

        d5 = self.__up_sample(x = d4, filters = 32, name = 'g1_up_sample_decoder_5')
        d5 = ResBlock(filters = 32)(d5)
        d5 = tf.keras.layers.Concatenate()([d5, e0])

        out = tf.keras.layers.Conv2D(filters = 3, kernel_size = (3, 3), strides = (1, 1), padding = 'same')(d5)
        out = tf.keras.layers.Activation('tanh')(out)

        # Fixed: this model used to be (mis)named 'stage2_generator'.
        return tf.keras.models.Model([inp_img, inp_pose], out, name = 'stage1_generator')

    @property
    def stage2_generator(self):
        """Build the stage-2 generator G2 (a new model on every access).

        Plain-convolution encoder/decoder without a fully connected bottleneck,
        with skip connections (concatenation) between matching levels.

        Inputs:  [image (img_shape), coarse stage-1 output (img_shape)].
        Output:  3-channel image passed through tanh (values in [-1, 1]).

        NOTE(review): in the PG² paper G2 predicts a difference map that is added
        to the coarse result. Here the output is used directly as the refined
        image; confirm whether that is intentional.
        """
        inp_img = tf.keras.layers.Input(shape = self.img_shape, dtype = tf.float32, name = 'stage2_generator_inp_img')
        inp_coarse = tf.keras.layers.Input(shape = self.img_shape, dtype = tf.float32,
                                           name = 'stage2_generator_inp_coarse')

        x = tf.keras.layers.Concatenate()([inp_img, inp_coarse])

        x = tf.keras.layers.Conv2D(filters = 32, kernel_size = (3, 3), strides = (1, 1), padding = 'same')(x)
        x = tf.keras.layers.ReLU()(x)

        # Encoder: four stride-2 levels, so H and W should be multiples of 16.
        e0 = self.__conv_block(x, filters = 64, name = 'g2_conv_block_encoder_0')

        e1 = self.__down_sample(e0, filters = 64, name = 'g2_down_sample_encoder_1')
        e1 = self.__conv_block(e1, filters = 64, name = 'g2_conv_block_encoder_1')

        e2 = self.__down_sample(e1, filters = 128, name = 'g2_down_sample_encoder_2')
        e2 = self.__conv_block(e2, filters = 128, name = 'g2_conv_block_encoder_2')

        e3 = self.__down_sample(e2, filters = 128, name = 'g2_down_sample_encoder_3')
        e3 = self.__conv_block(e3, filters = 128, name = 'g2_conv_block_encoder_3')

        e4 = self.__down_sample(e3, filters = 256, name = 'g2_down_sample_encoder_4')
        e4 = self.__conv_block(e4, filters = 256, name = 'g2_conv_block_encoder_4')

        # Bottleneck (convolutional, spatial size unchanged)
        b = self.__conv_block(e4, filters = 256, name = 'g2_conv_block_bottleneck')

        # Decoder: each level doubles H and W and concatenates the matching encoder output.
        d0 = self.__conv_block(b, filters = 256, name = 'g2_conv_block_decoder_0')
        d0 = tf.keras.layers.Concatenate()([d0, e4])

        d1 = self.__up_sample(d0, filters = 128, name = 'g2_up_sample_decoder_1')
        d1 = self.__conv_block(d1, filters = 128, name = 'g2_conv_block_decoder_1')
        d1 = tf.keras.layers.Concatenate()([d1, e3])

        d2 = self.__up_sample(d1, filters = 128, name = 'g2_up_sample_decoder_2')
        d2 = self.__conv_block(d2, filters = 128, name = 'g2_conv_block_decoder_2')
        d2 = tf.keras.layers.Concatenate()([d2, e2])

        d3 = self.__up_sample(d2, filters = 64, name = 'g2_up_sample_decoder_3')
        d3 = self.__conv_block(d3, filters = 64, name = 'g2_conv_block_decoder_3')
        d3 = tf.keras.layers.Concatenate()([d3, e1])

        d4 = self.__up_sample(d3, filters = 64, name = 'g2_up_sample_decoder_4')
        d4 = self.__conv_block(d4, filters = 64, name = 'g2_conv_block_decoder_4')
        d4 = tf.keras.layers.Concatenate()([d4, e0])

        out = tf.keras.layers.Conv2D(filters = 3, kernel_size = (3, 3), strides = (1, 1), padding = 'same')(d4)
        out = tf.keras.layers.Activation('tanh')(out)

        # Named explicitly, consistent with the other two models.
        return tf.keras.models.Model([inp_img, inp_coarse], out, name = 'stage2_generator')

    @property
    def discriminator(self):
        """Build the pair discriminator (a new model on every access).

        Both inputs have img_shape and are concatenated along channels. The body is
        five stride-2 3x3 conv blocks (BatchNorm on all but the first, LeakyReLU
        with alpha 0.2), then a single-unit Dense layer with no activation. Its
        output is therefore an unnormalised logit.

        NOTE(review): despite the input layer names, the Trainer in this notebook
        feeds the conditioning image as the first input and the real target or
        generated image as the second.
        """
        inp_img_a = tf.keras.layers.Input(shape = self.img_shape, dtype = tf.float32, name = 'discriminator_real_input')
        inp_img_b = tf.keras.layers.Input(shape = self.img_shape, dtype = tf.float32, name = 'discriminator_gen_input')

        x = tf.keras.layers.Concatenate()([inp_img_a, inp_img_b])

        x = tf.keras.layers.Conv2D(filters = 64, kernel_size = (3, 3), strides = (2, 2), padding = 'same')(x)
        x = tf.keras.layers.LeakyReLU(alpha = 0.2)(x)

        x = tf.keras.layers.Conv2D(filters = 64, kernel_size = (3, 3), strides = (2, 2), padding = 'same')(x)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.LeakyReLU(alpha = 0.2)(x)

        x = tf.keras.layers.Conv2D(filters = 128, kernel_size = (3, 3), strides = (2, 2), padding = 'same')(x)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.LeakyReLU(alpha = 0.2)(x)

        x = tf.keras.layers.Conv2D(filters = 128, kernel_size = (3, 3), strides = (2, 2), padding = 'same')(x)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.LeakyReLU(alpha = 0.2)(x)

        x = tf.keras.layers.Conv2D(filters = 256, kernel_size = (3, 3), strides = (2, 2), padding = 'same')(x)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.LeakyReLU(alpha = 0.2)(x)

        x = tf.keras.layers.Flatten()(x)
        x = tf.keras.layers.Dense(units = 1)(x)

        return tf.keras.models.Model([inp_img_a, inp_img_b], x, name = 'discriminator')

In [4]:
# Model builder for 256x256 3-channel images conditioned on 1-channel pose maps.
gan = GAN(img_shape=(256, 256, 3), pose_img_shape=(256, 256, 1))
# Uncomment to inspect the architectures:
# gan.stage2_generator.summary()
# gan.discriminator.summary()

In [5]:
# tf.keras.utils.plot_model(gan.stage2_generator, dpi = 64, show_shapes = True)
# tf.keras.utils.plot_model(gan.discriminator, dpi = 64, show_shapes = True)

In [6]:
class Losses(object):
    """Loss terms for the PG² generators and discriminator."""

    def __init__(self, LAMBDA=1):
        # Weighting factor for the masked L1 term. It is stored for callers;
        # none of the methods below apply it themselves.
        self.LAMBDA = LAMBDA
        # Binary cross-entropy on raw logits (no sigmoid applied by the caller).
        self.bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)

    def pose_mask_loss(self, gen_out, tar_img, pose_mask):
        """Mean absolute error with per-pixel weight (1 + pose_mask).

        Pixels inside the mask count more. pose_mask is presumably (batch, H, W, 1)
        so it broadcasts over the image channels; TODO confirm.
        """
        weighted_diff = (gen_out - tar_img) * (1 + pose_mask)
        return tf.math.reduce_mean(tf.math.abs(weighted_diff))

    def discriminator_adversarial_loss(self, disc_real_out, disc_gen_out):
        """Sum of BCE pushing real logits toward 1 and generated logits toward 0."""
        real_term = self.bce(tf.ones_like(disc_real_out), disc_real_out)
        fake_term = self.bce(tf.zeros_like(disc_gen_out), disc_gen_out)
        return real_term + fake_term

    def generator_adversarial_loss(self, disc_gen_out):
        """BCE pushing the discriminator's logits on generated images toward 1."""
        target = tf.ones_like(disc_gen_out)
        return self.bce(target, disc_gen_out)

In [8]:
class Trainer(object):
    """Train the PG² stage-1 generator, stage-2 generator and discriminator together.

    Args:
        learning_rate: Adam learning rate shared by all three optimizers.
        img_shape: per-sample image shape passed to ``GAN``.
        pose_img_shape: per-sample pose-input shape passed to ``GAN``.
    """

    def __init__(self, learning_rate = 2e-5, img_shape = (256, 256, 3), pose_img_shape = (256, 256, 1)):
        # One Adam optimizer per network (beta_1 = 0.5, beta_2 = 0.999).
        self.gen1_optimizer = tf.keras.optimizers.Adam(learning_rate = learning_rate, beta_1 = 0.5, beta_2 = 0.999)
        self.gen2_optimizer = tf.keras.optimizers.Adam(learning_rate = learning_rate, beta_1 = 0.5, beta_2 = 0.999)
        self.disc_optimizer = tf.keras.optimizers.Adam(learning_rate = learning_rate, beta_1 = 0.5, beta_2 = 0.999)

        # Each GAN property builds a new model, so read each one exactly once.
        gan = GAN(img_shape = img_shape, pose_img_shape = pose_img_shape)
        self.generator_1 = gan.stage1_generator
        self.generator_2 = gan.stage2_generator
        self.discriminator = gan.discriminator

        self.losses = Losses()

    # Fixed: this used to be decorated with @property. That made every
    # `self.train_step(...)` call fail, because a property getter receives only `self`.
    def train_step(self, inp_img, pose_inp_img, tar_img, pose_mask):
        """Run one optimization step for all three networks on a single batch.

        - Generator 1 is trained only with the pose-masked L1 loss.
        - Generator 2 is trained with LAMBDA * masked L1 plus the adversarial loss.
        - The discriminator is trained with the real/fake BCE loss on
          (inp_img, tar_img) versus (inp_img, generator-2 output) pairs.

        Returns:
            (gen1_loss, gen2_loss, disc_loss) for this batch.
        """
        # Persistent so the three gradients below can be taken from one tape.
        with tf.GradientTape(persistent = True) as tape:
            gen1_out = self.generator_1([inp_img, pose_inp_img], training = True)
            gen2_out = self.generator_2([inp_img, gen1_out], training = True)

            disc_real_out = self.discriminator([inp_img, tar_img], training = True)
            disc_gen_out = self.discriminator([inp_img, gen2_out], training = True)

            gen1_loss = self.losses.pose_mask_loss(gen1_out, tar_img, pose_mask)

            # Fixed typo: the attribute is LAMBDA (LAMDA raised AttributeError).
            gen2_loss = self.losses.pose_mask_loss(gen2_out, tar_img, pose_mask) * self.losses.LAMBDA
            gen2_loss += self.losses.generator_adversarial_loss(disc_gen_out)

            disc_loss = self.losses.discriminator_adversarial_loss(disc_real_out, disc_gen_out)

        # Each loss only updates its own network's variables.
        gen1_grads = tape.gradient(gen1_loss, self.generator_1.trainable_variables)
        self.gen1_optimizer.apply_gradients(zip(gen1_grads, self.generator_1.trainable_variables))

        gen2_grads = tape.gradient(gen2_loss, self.generator_2.trainable_variables)
        self.gen2_optimizer.apply_gradients(zip(gen2_grads, self.generator_2.trainable_variables))

        disc_grads = tape.gradient(disc_loss, self.discriminator.trainable_variables)
        self.disc_optimizer.apply_gradients(zip(disc_grads, self.discriminator.trainable_variables))

        # Release the persistent tape's recorded resources promptly.
        del tape

        return gen1_loss, gen2_loss, disc_loss

    def train(self, data, epochs = 1):
        """Train for ``epochs`` passes over ``data``.

        Args:
            data: re-iterable collection of (inp_img, pose_inp_img, tar_img, pose_mask)
                batches, e.g. a tf.data.Dataset or a list. It is iterated once per epoch.
            epochs: number of passes over ``data``.

        Returns:
            dict with keys 'gen1_losses', 'gen2_losses' and 'disc_losses'. Each list
            has one entry per epoch, holding the loss of that epoch's last batch.

        Raises:
            ValueError: if an epoch yields no batches (empty data, or a one-shot
                iterator that an earlier epoch already exhausted).
        """
        gen1_losses, gen2_losses, disc_losses = [], [], []
        for e in range(epochs):
            print(f'Epoch: {e} Starts')
            num_batches = 0
            for inp_img, pose_inp_img, tar_img, pose_mask in data:
                gen1_loss, gen2_loss, disc_loss = self.train_step(inp_img, pose_inp_img, tar_img, pose_mask)
                num_batches += 1
                print('.', end = '')

            # Without this guard, an empty epoch either crashed with
            # UnboundLocalError or silently re-recorded the previous epoch's losses.
            if num_batches == 0:
                raise ValueError(
                    f'Epoch {e}: data yielded no batches; pass a non-empty, re-iterable '
                    f'dataset (e.g. a tf.data.Dataset or a list), not a one-shot iterator.'
                )

            gen1_losses.append(gen1_loss)
            gen2_losses.append(gen2_loss)
            disc_losses.append(disc_loss)
            print(f'\nGenerator(1) Loss: {gen1_loss} \t Generator(2) Loss: {gen2_loss} \t Discriminator Loss: {disc_loss}')
            print(f'Epochs: {e} Ends \n')

        return {'gen1_losses': gen1_losses, 'gen2_losses': gen2_losses, 'disc_losses': disc_losses}

In [10]:
# Build all three networks and their optimizers with the default hyper-parameters.
trainer = Trainer()
# To train, pass an iterable of (inp_img, pose_inp_img, tar_img, pose_mask) batches:
# trainer.train(dataset, epochs=1)