# Pose Transfer

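This notebook is an attempt to re-implement the paper *Progressive Pose Attention Transfer for Person Image Generation* (Zhu et al., CVPR 2019) in TensorFlow / Keras.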

Paper: https://arxiv.org/pdf/1904.03349.pdf

Other resources:
* https://github.com/tengteng95/Pose-Transfer (the authors' official PyTorch implementation)

In [1]:
import tensorflow as tf

In [2]:
class InstanceNormalization(tf.keras.layers.Layer):
    '''
        Instance normalization with a learned per-channel scale (gamma) and offset (beta).

        For every sample and channel the input is normalized over axes 1 and 2:
            out = gamma * (x - mean) / sqrt(var + epsilon) + beta
        The (1, 1, 1, C) weight shapes assume 4-D channels-last inputs
        (presumably NHWC feature maps -- confirm against callers).

        Args:
            epsilon: small constant added to the variance for numerical stability.
    '''
    def __init__(self, epsilon = 1e-8, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def build(self, input_shape):
        inp_filters = input_shape[-1]

        # The scale starts near 1 (an identity transform). A zero-mean N(0, 1) init
        # would randomly flip channel signs and nearly zero out some channels at the
        # start of training. N(1, 0.02) is the usual GAN instance-norm recipe.
        init = tf.keras.initializers.RandomNormal(mean = 1.0, stddev = 0.02)
        self.gamma = self.add_weight(shape = (1, 1, 1, inp_filters), initializer = init,
                                     trainable = True, name = 'gamma')
        self.beta = self.add_weight(shape = (1, 1, 1, inp_filters), initializer = 'zeros',
                                    trainable = True, name = 'beta')

    def call(self, inputs):
        mean_x = tf.math.reduce_mean(inputs, axis = [1, 2], keepdims = True)
        rstd_x = tf.math.rsqrt(tf.math.reduce_variance(inputs, axis = [1, 2], keepdims = True) + self.epsilon)
        norm = (inputs - mean_x) * rstd_x
        return self.gamma * norm + self.beta

    def get_config(self):
        # Lets Keras re-create the layer when a model using it is saved and loaded.
        config = super().get_config()
        config.update({'epsilon': self.epsilon})
        return config

In [3]:
class ReflectionPadding2D(tf.keras.layers.Layer):
    '''
        Reflection-pads axes 1 and 2 (the spatial axes of a 4-D channels-last input).

        Args:
            padding: one of
                * int `p`                                 -> ((p, p), (p, p))
                * pair `(ph, pw)`                         -> ((ph, ph), (pw, pw))
                * pair of pairs `((top, bottom), (left, right))`
              Tuples and lists are both accepted, at either nesting level.

        Raises:
            TypeError: if `padding` is a string or has an unsupported type/structure.
            ValueError: if a sequence `padding` does not have exactly two entries.
    '''
    def __init__(self, padding, **kwargs):
        super().__init__(**kwargs)
        self.padding = self._normalize_padding(padding)

    @staticmethod
    def _normalize_padding(padding):
        '''Convert any accepted `padding` form to ((top, bottom), (left, right)).'''
        if isinstance(padding, str):
            raise TypeError('padding must not be a `string`')
        if isinstance(padding, int):
            return ((padding, padding), (padding, padding))
        if not isinstance(padding, (tuple, list)):
            raise TypeError(f'padding must be an int, tuple or list, got {type(padding).__name__}')
        if len(padding) != 2:
            raise ValueError(f'padding must have exactly 2 entries (height, width), got {padding!r}')

        if all(isinstance(p, int) for p in padding):
            pad_h, pad_w = padding
            return ((pad_h, pad_h), (pad_w, pad_w))
        if all(isinstance(p, (tuple, list)) and len(p) == 2 for p in padding):
            (top, bottom), (left, right) = padding
            return ((top, bottom), (left, right))
        raise TypeError(f'padding entries must be all ints or all (before, after) pairs, got {padding!r}')

    def call(self, inputs):
        # Batch (axis 0) and channel (axis 3) axes are never padded.
        return tf.pad(inputs, ((0, 0), self.padding[0], self.padding[1], (0, 0)), mode = 'REFLECT')

    def get_config(self):
        config = super().get_config()
        config.update({'padding': self.padding})
        return config

In [4]:
class Linear(tf.keras.layers.Layer):
    '''
        Fully connected layer computing `inputs @ W + B`.

        W has shape (input_features, neurons) and is drawn from N(0, 0.02).
        B has shape (1, neurons) and starts at zero. `input_features` is read from
        the last axis of the input shape when the layer is built.
    '''
    def __init__(self, neurons, **kwargs):
        super().__init__(**kwargs)
        self.neurons = neurons

    def build(self, input_shape):
        kernel_shape = (input_shape[-1], self.neurons)
        kernel_init = tf.keras.initializers.RandomNormal(mean = 0.0, stddev = 0.02)

        self.W = self.add_weight(
            shape = kernel_shape,
            initializer = kernel_init,
            trainable = True,
            name = 'weight',
        )
        self.B = self.add_weight(
            shape = (1, self.neurons),
            initializer = 'zeros',
            trainable = True,
            name = 'bias',
        )

    def call(self, inputs):
        projected = tf.matmul(inputs, self.W)
        return tf.add(projected, self.B)

In [5]:
class Conv2D(tf.keras.layers.Layer):
    '''
        2-D convolution with N(0, 0.02) weights, a zero-initialized bias and optional
        reflection padding.

        Args:
            filters: number of output channels.
            kernel_size: int `k` (square kernel) or a pair `(kh, kw)` (tuple or list).
            strides: passed straight to `tf.nn.conv2d`.
            padding: 'same' / 'valid' (case-insensitive) to use TF's built-in padding.
                Alternatively an int / tuple / list: the input is then reflection-padded
                by `ReflectionPadding2D(padding)` and the convolution itself is 'VALID'.

        Raises:
            Exception: if `padding` is neither a supported string nor an int/tuple/list.
    '''
    def __init__(self, filters, kernel_size, strides, padding, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        # build() concatenates kernel_size with a tuple, which fails for an int or a
        # list, so both are normalized to a tuple here.
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        self.kernel_size = tuple(kernel_size)
        self.strides = strides

        if isinstance(padding, str):
            if padding.upper() not in ('SAME', 'VALID'):
                raise Exception('invalid padding type.')
            self.padding = padding.upper()
        elif isinstance(padding, (tuple, list, int)):
            self.padding = ReflectionPadding2D(padding)
        else:
            raise Exception('invalid padding type.')

    def build(self, input_shape):
        inp_filters = input_shape[-1]

        init = tf.keras.initializers.RandomNormal(mean = 0.0, stddev = 0.02)
        self.W = self.add_weight(shape = self.kernel_size + (inp_filters, self.filters), initializer = init,
                                 trainable = True, name = 'weight')
        self.B = self.add_weight(shape = (1, 1, 1, self.filters), initializer = 'zeros',
                                 trainable = True, name = 'bias')

    def call(self, inputs):
        if isinstance(self.padding, str):
            return tf.add(tf.nn.conv2d(inputs, self.W, self.strides, self.padding), self.B)
        # Explicit reflection padding, so the convolution itself adds none.
        return tf.add(tf.nn.conv2d(self.padding(inputs), self.W, self.strides, 'VALID'), self.B)

In [6]:
class Normalization(tf.keras.layers.Layer):
    '''
        Resolves a normalization spec to a layer and applies it.

        Args:
            norm: 'batch_norm' -> tf.keras.layers.BatchNormalization()
                  'inst_norm'  -> InstanceNormalization()
                  any non-string value is used as-is and must be callable
                  (e.g. an already constructed layer).

        Raises:
            ValueError: if `norm` is any other string. Such a string was previously
                stored and only failed later when the layer was called.
    '''
    def __init__(self, norm, **kwargs):
        super().__init__(**kwargs)
        if norm == 'batch_norm':
            self.norm = tf.keras.layers.BatchNormalization()
        elif norm == 'inst_norm':
            self.norm = InstanceNormalization()
        elif isinstance(norm, str):
            raise ValueError(f"unknown norm {norm!r}; expected 'batch_norm' or 'inst_norm'")
        else:
            self.norm = norm

    def call(self, inputs):
        return self.norm(inputs)
    
    
class Activation(tf.keras.layers.Layer):
    '''
        Resolves an activation spec to a layer and applies it.

        Args:
            activation: 'relu' -> ReLU(), 'leaky_relu' -> LeakyReLU() (Keras defaults),
                'tanh' -> Activation('tanh'). Any other string is handed to
                tf.keras.layers.Activation (e.g. 'sigmoid'). Non-string values are used
                as-is and must be callable.
    '''
    def __init__(self, activation, **kwargs):
        super().__init__(**kwargs)
        if activation == 'relu':
            self.activation = tf.keras.layers.ReLU()
        elif activation == 'leaky_relu':
            self.activation = tf.keras.layers.LeakyReLU()
        elif activation == 'tanh':
            self.activation = tf.keras.layers.Activation('tanh')
        elif isinstance(activation, str):
            # Any other Keras activation name. Keras resolves the name here, so a typo
            # fails at construction instead of as "'str' object is not callable".
            self.activation = tf.keras.layers.Activation(activation)
        else:
            self.activation = activation

    def call(self, inputs):
        return self.activation(inputs)

In [7]:
class Conv2DBlock(tf.keras.layers.Layer):
    '''
        Conv2D -> optional normalization -> optional activation.

        A `norm` or `activation` of None skips that stage (identity). Otherwise the spec
        is resolved by `Normalization` / `Activation`. The remaining arguments are
        forwarded to `Conv2D`.
    '''
    def __init__(self, filters, kernel_size, strides, padding, norm, activation, **kwargs):
        super().__init__(**kwargs)
        identity = lambda x: x

        self.conv = Conv2D(filters = filters, kernel_size = kernel_size, strides = strides, padding = padding)

        if norm is None:
            self.norm = identity
        else:
            self.norm = Normalization(norm = norm)

        if activation is None:
            self.act = identity
        else:
            self.act = Activation(activation = activation)

    def call(self, inputs):
        features = self.conv(inputs)
        features = self.norm(features)
        return self.act(features)

In [8]:
class PATBlock(tf.keras.layers.Layer):
    '''
        Pose-Attentional Transfer Block

        Given image features `img` (F_p) and pose features `pose` (F_s):
            img_out  = conv_p(img)
            pose_out = conv_s(pose)                    # `filters` channels
            fp = img_out * sigmoid(pose_out) + img     # attention-masked residual update
            fs = concat([pose_out, fp], axis = -1)     # 2 * filters channels
        and returns (fp, fs). Because fs has 2 * filters channels, every block after the
        first receives a pose input twice as wide as its image input.

        Args:
            filters: output width of conv_p and conv_s. `img` must already have this
                many channels for the residual sum.
            norm, activation: specs forwarded to Conv2DBlock.
            use_dropout: insert Dropout(0.5) between the two convs of each pathway.
            first_block: True for the first block of the chain, whose pose input has
                `filters` channels rather than 2 * filters.
    '''
    def __init__(self, filters, norm, activation, use_dropout, first_block, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        self.norm = norm
        self.activation = activation
        self.use_dropout = use_dropout
        self.first_block = first_block

        self.image_pathway = self.image_pathway_block()
        self.pose_pathway = self.pose_pathway_block()

    def image_pathway_block(self):
        '''
            conv_p: two 3x3 convs of width `filters`. Both are normalized; only the
            first is activated.
        '''
        model = tf.keras.models.Sequential()
        model.add(Conv2DBlock(filters = self.filters, kernel_size = (3, 3), strides = (1, 1), padding = (1, 1),
                              norm = self.norm, activation = self.activation))
        if self.use_dropout:
            model.add(tf.keras.layers.Dropout(0.5))

        model.add(Conv2DBlock(filters = self.filters, kernel_size = (3, 3), strides = (1, 1), padding = (1, 1),
                              norm = self.norm, activation = None))
        return model

    def pose_pathway_block(self):
        '''
            conv_s: two 3x3 convs. The second conv (no norm, no activation) outputs
            `filters` channels, whose sigmoid is the attention mask.
        '''
        # After the first block the pose input is concat([pose_out, fp]) with
        # 2 * filters channels. The first conv keeps that width and the second conv
        # projects back to `filters`, as in the reference PATBlock (`cated_stream2`).
        dim = self.filters if self.first_block else self.filters * 2
        model = tf.keras.models.Sequential()

        model.add(Conv2DBlock(filters = dim, kernel_size = (3, 3), strides = (1, 1), padding = (1, 1),
                              norm = self.norm, activation = self.activation))

        if self.use_dropout:
            model.add(tf.keras.layers.Dropout(0.5))

        model.add(Conv2DBlock(filters = self.filters, kernel_size = (3, 3), strides = (1, 1), padding = (1, 1),
                              norm = None, activation = None))
        return model

    def call(self, inputs):
        assert len(inputs) == 2
        img, pose = inputs

        img_out = self.image_pathway(img)
        pose_out = self.pose_pathway(pose)

        m = tf.nn.sigmoid(pose_out)

        fp = (img_out * m) + img
        fs = tf.concat([pose_out, fp], axis = -1)

        return fp, fs

In [9]:
class PATNetwork(tf.keras.layers.Layer):
    '''
        Pose-Attentional Transfer Network

        A chain of `n_patb` PATBlocks of width `dim`. Only the first is built with
        `first_block = True`. Each block's (fp, fs) outputs are the next block's inputs,
        and the last block's pair is returned.
    '''
    def __init__(self, dim = 256, norm = 'inst_norm', activation = 'relu', n_patb = 9, use_dropout = False, **kwargs):
        super().__init__(**kwargs)

        self.pat_blocks = [
            PATBlock(filters = dim, norm = norm, activation = activation,
                     use_dropout = use_dropout, first_block = (index == 0))
            for index in range(n_patb)
        ]

    def call(self, inputs):
        assert len(inputs) == 2
        fp, fs = inputs

        for block in self.pat_blocks:
            fp, fs = block([fp, fs])

        return fp, fs

In [10]:
class DownSamplingBlock(tf.keras.layers.Layer):
    '''
        Stack of `n_blocks` stride-2 3x3 Conv2DBlocks (reflection padding 1). Each block
        halves the spatial size, rounding up.

        Block `i` outputs `dim * 2**i` channels (dim, 2*dim, 4*dim, ...), so the final
        width is `dim * 2**(n_blocks - 1)`. With `dim = base * 2` this equals
        `base * 2**n_blocks`, the width the Generator gives its PATNetwork. A
        `dim * (i + 1)` schedule agrees with this only for n_blocks <= 2.

        Args:
            dim: width of the first block.
            norm, activation: specs forwarded to Conv2DBlock.
            n_blocks: number of downsampling steps.
    '''
    def __init__(self, dim, norm, activation, n_blocks = 2, **kwargs):
        super().__init__(**kwargs)

        self.model = []
        for i in range(n_blocks):
            self.model.append(Conv2DBlock(filters = dim * (2**i), kernel_size = (3, 3), strides = (2, 2),
                                          padding = (1, 1), norm = norm, activation = activation))

    def call(self, x):
        for m in self.model:
            x = m(x)
        return x

In [11]:
class UpSamplingBlock(tf.keras.layers.Layer):
    '''
        Stack of `n_blocks` stride-2 3x3 transposed convolutions ('same' padding, so each
        doubles the spatial size). Each is optionally followed by normalization and
        activation.

        Block `i` outputs `dim // 2**i` channels (dim, dim/2, dim/4, ...), halving the
        width at every doubling of resolution and mirroring DownSamplingBlock. A
        `dim // (i + 1)` schedule coincides with this only for n_blocks <= 2.

        Args:
            dim: width of the first block.
            norm, activation: specs for Normalization / Activation; None skips the stage.
            n_blocks: number of upsampling steps.
    '''
    def __init__(self, dim, norm, activation, n_blocks = 2, **kwargs):
        super().__init__(**kwargs)

        self.model = []
        for i in range(n_blocks):
            # max(1, ...) keeps the layer valid if dim is smaller than 2**i.
            filters = max(1, dim // (2**i))
            self.model.append(tf.keras.layers.Conv2DTranspose(filters = filters, kernel_size = (3, 3),
                                                              strides = (2, 2), padding = 'same'))
            if norm is not None:
                self.model.append(Normalization(norm = norm))
            if activation is not None:
                self.model.append(Activation(activation = activation))

    def call(self, x):
        for m in self.model:
            x = m(x)
        return x

In [12]:
class ResidualBlock(tf.keras.layers.Layer):
    '''
        Residual block computing conv_2(conv_1(x)) + x.

        Both convs are 3x3, stride 1, reflection-padded by 1 and normalized by `norm`.
        Only conv_1 applies ReLU.

        Args:
            filters: width of both convs. None means "same as the input's last axis".
                The skip connection adds the input unchanged, so any explicit value
                must equal the input channel count.
            norm: normalization spec forwarded to Conv2DBlock.
    '''
    def __init__(self, filters = None, norm = 'batch_norm', **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        self.norm = norm

    def build(self, input_shape):
        width = self.filters if self.filters is not None else input_shape[-1]
        shared = dict(filters = width, kernel_size = (3, 3), strides = (1, 1), padding = (1, 1), norm = self.norm)

        self.conv_1 = Conv2DBlock(activation = 'relu', **shared)
        self.conv_2 = Conv2DBlock(activation = None, **shared)

    def call(self, inputs):
        residual = self.conv_2(self.conv_1(inputs))
        return residual + inputs

In [13]:
class Generator(object):
    '''
        Builder for the Pose-Attentional Transfer generator.

        Pipeline:
            F_p0 = 7x7 conv(condition image)
            F_s0 = 7x7 conv(concat(condition pose, target pose))
            both -> `n_up_down` stride-2 downsampling steps
                 -> PATNetwork of `n_patb` blocks at width dim * 2**n_up_down
            F_p  -> `n_up_down` stride-2 upsampling steps
                 -> 7x7 conv to 3 channels with tanh (output in [-1, 1]).

        Args:
            n_up_down: number of downsampling steps, mirrored by as many upsampling
                steps. The output therefore has the input resolution when the input
                size is divisible by 2**n_up_down.
            dim: width of the first convolutions.
            norm, activation: specs forwarded to the building blocks.
            n_patb: number of PAT blocks.
            use_dropout: enable Dropout(0.5) inside the PAT blocks.
    '''
    def __init__(self, n_up_down = 2, dim = 64, norm = 'inst_norm', activation = 'relu', n_patb = 9, use_dropout = False):
        self.n_up_down = n_up_down
        self.dim = dim
        self.norm = norm
        self.activation = activation
        self.n_patb = n_patb
        self.use_dropout = use_dropout

    def generator(self, condition_img_shape, pose_shape):
        '''
            Build the Keras model.

            Inputs are [condition image P_c, condition pose S_c, target pose S_t], with
            per-sample shapes `condition_img_shape` and `pose_shape`. Output is the
            generated 3-channel image.
        '''
        inp_condition_img = tf.keras.layers.Input(shape = condition_img_shape, dtype = tf.float32,
                                                  name = 'condition_image_Pc')
        inp_condition_pose = tf.keras.layers.Input(shape = pose_shape, dtype = tf.float32,
                                                   name = 'condition_pose_Sc')
        inp_target_pose = tf.keras.layers.Input(shape = pose_shape, dtype = tf.float32,
                                                name = 'target_pose_St')

        # Width of the features inside the PAT network (after all downsampling).
        pat_dim = self.dim * (2**self.n_up_down)

        # conv Fp_0
        fp = Conv2DBlock(filters = self.dim, kernel_size = (7, 7), strides = (1, 1), padding = (3, 3),
                         norm = self.norm, activation = self.activation)(inp_condition_img)

        # conv Fs_0
        fs = tf.keras.layers.Concatenate()([inp_condition_pose, inp_target_pose])
        fs = Conv2DBlock(filters = self.dim, kernel_size = (7, 7), strides = (1, 1), padding = (3, 3),
                         norm = self.norm, activation = self.activation)(fs)

        # Downsampling Blocks (inputs)
        fp = DownSamplingBlock(dim = self.dim * 2, norm = self.norm, activation = self.activation,
                               n_blocks = self.n_up_down)(fp)
        fs = DownSamplingBlock(dim = self.dim * 2, norm = self.norm, activation = self.activation,
                               n_blocks = self.n_up_down)(fs)

        # Pose-Attentional Transfer Network
        fp, fs = PATNetwork(dim = pat_dim, norm = self.norm, activation = self.activation,
                            n_patb = self.n_patb, use_dropout = self.use_dropout)([fp, fs])

        # Upsampling Blocks: one per downsampling step. Relying on UpSamplingBlock's
        # default of 2 would give the wrong output resolution whenever n_up_down != 2.
        fp = UpSamplingBlock(dim = pat_dim, norm = self.norm, activation = self.activation,
                             n_blocks = self.n_up_down)(fp)

        # Converting into image (channels of 3)
        out = Conv2DBlock(filters = 3, kernel_size = (7, 7), strides = (1, 1), padding = (3, 3),
                          norm = None, activation = 'tanh')(fp)

        return tf.keras.models.Model([inp_condition_img, inp_condition_pose, inp_target_pose], out, name = 'Generator')

In [14]:
class Discriminator(object):
    '''
        Builder for the ResNet-style discriminator (used for both D_A and D_S).

        Architecture:
            concat(inputA, inputB)
            -> 7x7 conv (reflection-padded by 3), BatchNorm, LeakyReLU
            -> `n_down` stride-2 3x3 convs (channels doubled for the first two only)
            -> `n_blocks` ResidualBlocks.
        The output keeps the last conv width as its channel count; there is no final
        1-channel projection.

        Args:
            dim: width of the first conv.
            n_down: number of stride-2 downsampling convs.
            n_blocks: number of residual blocks.
            logits: NOTE(review): despite its name, True *appends* a sigmoid, so the
                model outputs probabilities. The default False returns raw scores, which
                is what GANLoss('adversarial') (from_logits = True) expects.
    '''
    def __init__(self, dim = 64, n_down = 2, n_blocks = 6, logits = False):
        self.dim = dim
        self.n_down = n_down
        self.n_blocks = n_blocks
        self.logits = logits

    def discriminator(self, inpA_shape, inpB_shape):
        '''Build the Keras model taking [inputA, inputB] with the given per-sample shapes.'''
        inpA = tf.keras.layers.Input(shape = inpA_shape, dtype = tf.float32, name = 'inputA')
        inpB = tf.keras.layers.Input(shape = inpB_shape, dtype = tf.float32, name = 'inputB')

        x = tf.keras.layers.Concatenate()([inpA, inpB])

        # Reflection-pad by 3 so the 7x7 conv keeps the spatial size, like the
        # Generator's 7x7 convs. Zero padding here would shrink the map by 6.
        x = Conv2DBlock(filters = self.dim, kernel_size = (7, 7), strides = (1, 1), padding = (3, 3), norm = 'batch_norm',
                        activation = 'leaky_relu')(x)

        dim = self.dim
        for i in range(self.n_down):
            # Width stops growing after two doublings.
            if i < 2:
                dim *= 2
            x = Conv2DBlock(filters = dim, kernel_size = (3, 3), strides = (2, 2), padding = (1, 1), norm = 'batch_norm',
                            activation = 'leaky_relu')(x)

        for _ in range(self.n_blocks):
            x = ResidualBlock(filters = dim, norm = 'batch_norm')(x)

        if self.logits:
            x = tf.keras.layers.Activation('sigmoid')(x)

        return tf.keras.models.Model([inpA, inpB], x, name = 'Discriminator')
    

In [15]:
class PerceptualLoss(object):
    '''
        Perceptual (feature-matching) loss between two image batches.

        Both images are expected in [-1, 1] (e.g. a tanh generator output). Each is
        mapped to [0, 255], run through `tf.keras.applications.vgg19.preprocess_input`
        and then through a frozen feature extractor. The features are compared with a
        mean absolute (p = 1) or mean squared (p = 2) difference. When several layers
        are used, the per-layer losses are summed.

        Args:
            pt_model: pretrained Keras model to extract features from. Defaults to VGG19
                with ImageNet weights and no top. It is set to non-trainable.
            pt_layers: names of layers in `pt_model` whose outputs are compared.
                Empty or None means ['block1_conv2'].
            p: 1 for L1, 2 for L2.

        Raises:
            ValueError: if `p` is not 1 or 2.
    '''
    def __init__(self, pt_model = None, pt_layers = None, p = 1):
        # Validate p before loading any model, so a bad argument fails fast.
        if p == 1:
            self.loss_object = lambda x, y: tf.math.reduce_mean(tf.math.abs(x - y))
        elif p == 2:
            self.loss_object = lambda x, y: tf.math.reduce_mean(tf.math.square(x - y))
        else:
            raise ValueError(f'p must be 1 or 2, got {p!r}')

        model = tf.keras.applications.VGG19(include_top = False, weights = 'imagenet') if pt_model is None else pt_model
        model.trainable = False

        layers = list(pt_layers) if pt_layers else ['block1_conv2']
        outs = [model.get_layer(layer).output for layer in layers]

        self.model = tf.keras.models.Model(model.inputs, outs)
        self.preprocess_input = lambda x: tf.keras.applications.vgg19.preprocess_input(x)

    def __call__(self, real, pred):
        # Convert the pixel range from [-1, 1] to [0, 255] before VGG preprocessing.
        real_outs = self.model(self.preprocess_input((real + 1) * 127.5))
        pred_outs = self.model(self.preprocess_input((pred + 1) * 127.5))

        # A multi-output model returns a list of feature maps; a single one, a tensor.
        if isinstance(real_outs, list):
            loss = 0
            for r, g in zip(real_outs, pred_outs):
                loss += self.loss_object(r, g)
        else:
            loss = self.loss_object(real_outs, pred_outs)
        return loss
    
class L_Loss(object):
    '''
        Mean L1 (p = 1) or mean squared L2 (p = 2) distance between two tensors.

        Raises:
            ValueError: if `p` is not 1 or 2.
    '''
    def __init__(self, p = 1):
        if p == 1:
            self.loss = lambda x, y: tf.math.reduce_mean(tf.math.abs(x - y))
        elif p == 2:
            self.loss = lambda x, y: tf.math.reduce_mean(tf.math.square(x - y))
        else:
            raise ValueError(f'p must be 1 or 2, got {p!r}')

    def __call__(self, real, pred):
        return self.loss(real, pred)
    
class GANLoss(object):
    '''
        Adversarial losses for the two discriminators (appearance A and shape S).

        The outputs of both discriminators are multiplied element-wise. The product is
        scored against an all-ones target (real) or an all-zeros target (fake).

        Args:
            loss_type: 'adversarial' -> binary cross-entropy with from_logits = True,
                       'lsgan'       -> mean squared error.

        Raises:
            ValueError: for any other `loss_type`.
    '''
    def __init__(self, loss_type = 'adversarial'):
        if loss_type == 'adversarial':
            self.loss = tf.keras.losses.BinaryCrossentropy(from_logits = True)
        elif loss_type == 'lsgan':
            self.loss = tf.keras.losses.MeanSquaredError()
        else:
            raise ValueError(f"unknown loss_type {loss_type!r}; expected 'adversarial' or 'lsgan'")

    def discriminator_loss(self, disc_real_out_A, disc_gen_out_A, disc_real_out_S, disc_gen_out_S):
        '''Return loss(1, D_A(real) * D_S(real)) + loss(0, D_A(fake) * D_S(fake)).'''
        assert disc_real_out_A.shape == disc_real_out_S.shape
        assert disc_gen_out_A.shape == disc_gen_out_S.shape
        # NOTE(review): with from_logits = True this multiplies two *logits* before
        # the sigmoid, which differs from multiplying the discriminators'
        # probabilities. Confirm this is intended.
        real = self.loss(tf.ones_like(disc_real_out_A), disc_real_out_A*disc_real_out_S)
        gen = self.loss(tf.zeros_like(disc_gen_out_A), disc_gen_out_A*disc_gen_out_S)
        return real + gen

    def generator_loss(self, disc_gen_out_A, disc_gen_out_S):
        '''
            Return loss(1, D_A(fake) * D_S(fake)): the generator is rewarded when both
            discriminators score its output as real.
        '''
        assert disc_gen_out_A.shape == disc_gen_out_S.shape
        combined = disc_gen_out_A * disc_gen_out_S
        return self.loss(tf.ones_like(combined), combined)

In [16]:
class Trainer(object):
    '''
        Trains the pose-transfer generator against two discriminators.

        Discriminator A sees (condition image, candidate image) pairs. Discriminator S
        sees (target pose, candidate image) pairs. The generator loss is
            lambda_2 * perceptual + lambda_1 * L1 + alpha * adversarial.
        Both discriminators are updated together by a single Adam optimizer.

        Args:
            generator: object whose `.generator(img_shape, pose_shape)` returns a Keras
                model (e.g. `Generator`).
            discriminator: object whose `.discriminator(shape_a, shape_b)` returns a
                Keras model (e.g. `Discriminator`). It is called twice to build A and S.
            img_shape, pose_shape: per-sample input shapes passed to the builders.
            learning_rate: Adam learning rate for both optimizers (beta_1 = 0.5).
            loss_type: forwarded to `GANLoss`.
            alpha, lambda_1, lambda_2: loss weights described above.
    '''
    def __init__(self, generator, discriminator, img_shape, pose_shape, learning_rate = 2e-4, loss_type = 'adversarial',
                 alpha = 5, lambda_1 = 1, lambda_2 = 1):
        self.alpha = alpha
        self.lambda_1 = lambda_1
        self.lambda_2 = lambda_2

        self.gen_optimizer = tf.keras.optimizers.Adam(learning_rate = learning_rate, beta_1 = 0.5, beta_2 = 0.999)
        self.disc_optimizer = tf.keras.optimizers.Adam(learning_rate = learning_rate, beta_1 =0.5, beta_2 = 0.999)

        self.generator = generator.generator(img_shape, pose_shape)
        self.discriminator_A = discriminator.discriminator(img_shape, img_shape)
        self.discriminator_S = discriminator.discriminator(pose_shape, img_shape)

        self.perceptual_loss = PerceptualLoss(p = 1)
        self.l1_loss = L_Loss(p = 1)
        self.gan_loss = GANLoss(loss_type = loss_type)

    @tf.function
    def train_step(self, condition_img, condition_pose, target_img, target_pose):
        '''
            Run one optimization step for the generator and both discriminators.

            Returns:
                (gen_loss, disc_loss) for this batch.
        '''
        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            gen_out = self.generator([condition_img, condition_pose, target_pose], training = True)

            disc_real_out_A = self.discriminator_A([condition_img, target_img], training = True)
            disc_gen_out_A = self.discriminator_A([condition_img, gen_out], training = True)
            disc_real_out_S = self.discriminator_S([target_pose, target_img], training = True)
            disc_gen_out_S = self.discriminator_S([target_pose, gen_out], training = True)

            gen_loss = self.perceptual_loss(target_img, gen_out) * self.lambda_2
            gen_loss += self.l1_loss(target_img, gen_out) * self.lambda_1
            gen_loss += self.alpha * self.gan_loss.generator_loss(disc_gen_out_A, disc_gen_out_S)

            disc_loss = self.gan_loss.discriminator_loss(disc_real_out_A, disc_gen_out_A, disc_real_out_S, disc_gen_out_S)

        gen_grads = gen_tape.gradient(gen_loss, self.generator.trainable_variables)
        self.gen_optimizer.apply_gradients(zip(gen_grads, self.generator.trainable_variables))

        # One optimizer updates both discriminators from the shared loss.
        disc_params = self.discriminator_A.trainable_variables + self.discriminator_S.trainable_variables
        disc_grads = disc_tape.gradient(disc_loss, disc_params)
        self.disc_optimizer.apply_gradients(zip(disc_grads, disc_params))

        return gen_loss, disc_loss

    def train(self, data, epochs = 1):
        '''
            Run `epochs` passes over `data`, calling `train_step` on every batch.

            Args:
                data: re-iterable collection of (condition_img, condition_pose,
                    target_img, target_pose) batches. A one-shot iterator or generator
                    is exhausted after the first epoch.
                epochs: number of passes.

            Returns:
                dict with 'gen_losses' and 'disc_losses', holding one entry per epoch:
                the losses of that epoch's LAST batch.

            Raises:
                ValueError: if an epoch yields no batches (empty data, or an exhausted
                    one-shot iterator). Without this check the first case fails with an
                    unbound name, and the second silently re-reports the previous
                    epoch's losses.
        '''
        gen_losses, disc_losses = [], []
        for e in range(epochs):
            print(f'Epoch: {e} Starts')
            gen_loss = disc_loss = None
            for condition_img, condition_pose, target_img, target_pose in data:
                gen_loss, disc_loss = self.train_step(condition_img, condition_pose, target_img, target_pose)
                print('.', end = '')

            if gen_loss is None:
                raise ValueError(f'epoch {e}: `data` yielded no batches')

            gen_losses.append(gen_loss)
            disc_losses.append(disc_loss)
            print(f'\nGenerator Loss: {gen_loss} \t Discriminator Loss: {disc_loss}')
            print(f'Epoch: {e} Ends.\n')

        return {'gen_losses': gen_losses, 'disc_losses': disc_losses}