# Pix2PixHD - Image Translation

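This notebook is an attempt to re-implement the pix2pixHD paper ("High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs") in TensorFlow / Keras.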

Paper: https://arxiv.org/pdf/1711.11585.pdf

Other resources:
* https://tcwang0509.github.io/pix2pixHD

In [1]:
import tensorflow as tf

In [2]:
class InstanceNormalization(tf.keras.layers.Layer):
    """Instance normalization over the spatial axes of an NHWC tensor.

    Every (sample, channel) slice is standardized with its own mean and
    variance, computed over axes 1 and 2. The result is then scaled and
    shifted by the learnable per-channel ``gamma`` and ``beta``.

    Args:
        epsilon: Small constant added to the variance before ``rsqrt`` so
            that constant feature maps do not divide by zero.
    """

    def __init__(self, epsilon = 1e-8, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def build(self, input_shape):
        inp_filters = input_shape[-1]

        # Start from the identity affine transform (gamma = 1, beta = 0).
        # Drawing both from N(0, 1) would randomly rescale, sign-flip and
        # shift every normalized channel before any training has happened.
        self.gamma = self.add_weight(shape = (1, 1, 1, inp_filters), initializer = 'ones',
                                     trainable = True, name = 'gamma')
        self.beta = self.add_weight(shape = (1, 1, 1, inp_filters), initializer = 'zeros',
                                    trainable = True, name = 'beta')

    def call(self, inputs):
        # Per-sample, per-channel statistics over height and width.
        mean = tf.math.reduce_mean(inputs, axis = [1, 2], keepdims = True)
        rstd = tf.math.rsqrt(tf.math.reduce_variance(inputs, axis = [1, 2], keepdims = True) + self.epsilon)
        return self.gamma * ((inputs - mean) * rstd) + self.beta

    def get_config(self):
        """Include ``epsilon`` so the layer can be re-created from its config."""
        config = super().get_config()
        config.update({'epsilon': self.epsilon})
        return config

In [3]:
class ReflectionPadding2D(tf.keras.layers.Layer):
    """Reflection-pad the height and width axes of an NHWC tensor.

    Args:
        padding: One of
            * ``int``: the same amount on all four sides,
            * ``(pad_h, pad_w)``: symmetric padding per spatial axis,
            * ``((top, bottom), (left, right))``: explicit per-side padding.
            Lists are accepted wherever tuples are.

    Raises:
        ValueError: if ``padding`` is not one of the forms above.
    """

    _INVALID_PADDING_MSG = ('Invalid padding input:\nValid inputs are\n(for examples)-'
                            '\n-> 1\n-> (1, 1)\n->((1, 1), (1, 1))')

    def __init__(self, padding = ((1, 1), (1, 1)), **kwargs):
        super().__init__(**kwargs)
        self.padding = self._normalize_padding(padding)

    @classmethod
    def _normalize_padding(cls, padding):
        """Convert any accepted ``padding`` form to ``((top, bottom), (left, right))``."""
        if isinstance(padding, int):
            return (padding, padding), (padding, padding)

        if isinstance(padding, (tuple, list)) and len(padding) == 2:
            pad_h, pad_w = padding
            if isinstance(pad_h, int) and isinstance(pad_w, int):
                return (pad_h, pad_h), (pad_w, pad_w)
            if (isinstance(pad_h, (tuple, list)) and isinstance(pad_w, (tuple, list))
                    and len(pad_h) == 2 and len(pad_w) == 2):
                return (pad_h[0], pad_h[1]), (pad_w[0], pad_w[1])

        # Mixed or malformed specs such as (1, (2, 3)) are rejected here
        # instead of failing later inside tf.pad.
        raise ValueError(cls._INVALID_PADDING_MSG)

    def call(self, inputs):
        (pad_h0, pad_h1), (pad_w0, pad_w1) = self.padding
        # The batch and channel axes are never padded.
        return tf.pad(inputs, [[0, 0], [pad_h0, pad_h1], [pad_w0, pad_w1], [0, 0]], mode = 'REFLECT')

    def get_config(self):
        """Include the normalized ``padding`` so the layer can be re-created."""
        config = super().get_config()
        config.update({'padding': self.padding})
        return config

In [4]:
class ConvBlock(tf.keras.layers.Layer):
    """Convolution followed by optional normalization and activation (NHWC).

    Args:
        filters: Number of output channels.
        kernel_size: ``int`` (square kernel) or ``(kh, kw)`` tuple/list.
        strides: Forwarded unchanged to ``tf.nn.conv2d``.
        padding: ``'same'`` / ``'valid'`` (case-insensitive) for the
            convolution's own padding. Alternatively, any padding spec
            accepted by ``ReflectionPadding2D``; the input is then
            reflection-padded and the convolution uses ``'VALID'``.
        norm: ``'inst_norm'`` for ``InstanceNormalization``, ``None`` for no
            normalization, or a callable (e.g. a Keras layer).
        activation: ``'relu'``, ``'leaky_relu'`` (slope 0.2), ``'tanh'``, any
            other name accepted by ``tf.keras.layers.Activation``, ``None``
            for no activation, or a callable.

    Raises:
        ValueError: on an unknown padding string or an unknown ``norm`` string.
    """

    def __init__(self, filters, kernel_size, strides, padding, norm = 'inst_norm', activation = 'relu', **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        # Store a tuple so ``build`` can append the channel dims to it even
        # when a bare int was given.
        self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
        self.strides = strides

        if isinstance(padding, str):
            if padding.upper() in ['SAME', 'VALID']:
                self.padding = padding.upper()
            else:
                raise ValueError('Invalid padding, available string paddings are : `SAME` & `VALID`.')
        else:
            self.padding = ReflectionPadding2D(padding = padding)

        if norm is None:
            self.norm = None
        elif isinstance(norm, str):
            # Fail here rather than with "'str' object is not callable" in call().
            if norm != 'inst_norm':
                raise ValueError(f"Unknown norm {norm!r}; use 'inst_norm', None or a callable layer.")
            self.norm = InstanceNormalization()
        else:
            self.norm = norm

        if activation is None:
            self.activation = None
        elif activation == 'relu':
            self.activation = tf.keras.layers.ReLU()
        elif activation == 'leaky_relu':
            self.activation = tf.keras.layers.LeakyReLU(alpha = 0.2)
        elif isinstance(activation, str):
            # 'tanh' and any other Keras activation name (e.g. 'sigmoid').
            self.activation = tf.keras.layers.Activation(activation)
        else:
            self.activation = activation

    def build(self, input_shape):
        input_filters = input_shape[-1]

        init = tf.keras.initializers.RandomNormal(mean = 0.0, stddev = 0.02)
        self.W = self.add_weight(shape = self.kernel_size + (input_filters, self.filters), initializer = init,
                                 trainable = True, name = 'weight')
        self.B = self.add_weight(shape = (1, 1, 1, self.filters), initializer = 'zeros',
                                 trainable = True, name = 'bias')

    def call(self, inputs):
        if isinstance(self.padding, str):
            x = tf.add(tf.nn.conv2d(inputs, self.W, self.strides, self.padding, 'NHWC'), self.B)
        else:
            # Reflection padding was applied explicitly, so the conv itself adds none.
            x = self.padding(inputs)
            x = tf.add(tf.nn.conv2d(x, self.W, self.strides, 'VALID', data_format = 'NHWC'), self.B)

        if self.norm is not None:
            x = self.norm(x)

        if self.activation is not None:
            x = self.activation(x)

        return x

In [5]:
class ResidualBlock(tf.keras.layers.Layer):
    """Two reflection-padded 3x3 convolutions with an identity skip connection.

    Output is ``IN(conv(pad(ReLU(IN(conv(pad(inputs))))))) + inputs``, where
    IN is ``InstanceNormalization``. The skip is a plain addition, so
    ``filters`` should match the channel count of ``inputs``.
    """

    def __init__(self, filters, **kwargs):
        super().__init__(**kwargs)

        def make_conv():
            # 'valid' because the input is reflection-padded just before.
            return tf.keras.layers.Conv2D(filters = filters, kernel_size = (3, 3), strides = (1, 1), padding = 'valid')

        # First branch: pad -> conv -> norm -> ReLU.
        self.reflect_pad_1 = ReflectionPadding2D(((1, 1), (1, 1)))
        self.conv_1 = make_conv()
        self.norm_1 = InstanceNormalization()
        self.act_1 = tf.keras.layers.ReLU()

        # Second branch: pad -> conv -> norm (no activation before the skip).
        self.reflect_pad_2 = ReflectionPadding2D(((1, 1), (1, 1)))
        self.conv_2 = make_conv()
        self.norm_2 = InstanceNormalization()

    def call(self, inputs):
        h = self.reflect_pad_1(inputs)
        h = self.conv_1(h)
        h = self.norm_1(h)
        h = self.act_1(h)

        h = self.reflect_pad_2(h)
        h = self.conv_2(h)
        h = self.norm_2(h)
        return h + inputs

In [6]:
class DownSampleBlock(tf.keras.layers.Layer):
    """Stride-2 3x3 convolution ('same') -> InstanceNormalization -> ReLU.

    Halves height and width (rounding up) and outputs ``filters`` channels.
    """

    def __init__(self, filters, **kwargs):
        super().__init__(**kwargs)
        conv_kwargs = dict(filters = filters, kernel_size = (3, 3), strides = (2, 2), padding = 'same')
        self.conv = tf.keras.layers.Conv2D(**conv_kwargs)
        self.norm = InstanceNormalization()
        self.act = tf.keras.layers.ReLU()

    def call(self, inputs):
        features = self.conv(inputs)
        features = self.norm(features)
        return self.act(features)

In [7]:
class UpSampleBlock(tf.keras.layers.Layer):
    """Stride-2 3x3 transposed convolution ('same') -> InstanceNormalization -> ReLU.

    Doubles height and width and outputs ``filters`` channels.
    """

    def __init__(self, filters, **kwargs):
        super().__init__(**kwargs)
        deconv_kwargs = dict(filters = filters, kernel_size = (3, 3), strides = (2, 2), padding = 'same')
        self.conv_transpose = tf.keras.layers.Conv2DTranspose(**deconv_kwargs)
        self.norm = InstanceNormalization()
        self.act = tf.keras.layers.ReLU()

    def call(self, inputs):
        upsampled = self.conv_transpose(inputs)
        upsampled = self.norm(upsampled)
        return self.act(upsampled)

In [8]:
class InstanceWiseAveragePooling(tf.keras.layers.Layer):
    """Replace every pixel's features by the mean feature of its instance.

    Inputs are ``[features, inst]``:
        features: ``(batch, h, w, chn)`` feature map.
        inst: ``(batch, h, w, no. of instances)`` one-hot instance map.

    For each sample and instance ``k``, the features are averaged over the
    pixels where ``inst[..., k]`` is set. That mean is then written back to
    exactly those pixels.

    An instance with no pixels in a sample gets an average of 0 instead of
    NaN, via ``divide_no_nan``. It contributes nothing, because no pixel
    belongs to it. Pixels not covered by any instance come out as 0.

    Raises:
        ValueError: if ``inst`` has a single channel (not a one-hot map).
    """

    def call(self, inputs):
        features, inst = inputs
        if inst.shape[-1] == 1:
            raise ValueError('`inst` must be a one-hot map with one channel per instance, '
                             'got a single-channel tensor.')
        inst = tf.cast(inst, features.dtype)

        # (batch, num_instances, chn): sum of features over each instance's pixels.
        inst_feature_sum = tf.einsum('bhwk,bhwc->bkc', inst, features)
        # (batch, num_instances): number of pixels in each instance.
        inst_pixel_count = tf.math.reduce_sum(inst, axis = [1, 2])
        # Absent instances have count 0; divide_no_nan keeps them at 0 instead
        # of NaN, which would otherwise spread to every pixel.
        inst_mean = tf.math.divide_no_nan(inst_feature_sum, inst_pixel_count[..., tf.newaxis])

        # Scatter each instance's mean back onto its pixels -> (batch, h, w, chn).
        return tf.einsum('bhwk,bkc->bhwc', inst, inst_mean)

In [33]:
class Pix2PixHD(object):
    """Builders for the pix2pixHD sub-networks: generator, feature encoder and discriminator.

    Args:
        img_shape: ``(height, width, channels)`` of the image. The label and
            boundary inputs share its height and width.
        num_labels: Number of channels of the one-hot semantic label input.
    """

    def __init__(self, img_shape, num_labels):
        self.img_shape = img_shape
        self.num_labels = num_labels

    @staticmethod
    def _global_generator_g1(x, n_chn, up_down_block, res_blocks):
        """Apply the global generator G1, without its RGB output head, to ``x``.

        Returns G1's last feature map. It has ``n_chn`` channels, and the same
        height/width as ``x`` when those are divisible by ``2 ** up_down_block``.
        Only this feature map is fused into the local enhancer, so the 3-channel
        output layer of G1 is not built.
        """
        x = ConvBlock(filters = n_chn, kernel_size = (7, 7), strides = (1, 1), padding = ((3, 3), (3, 3)),
                      norm = 'inst_norm', activation = 'relu')(x)

        for _ in range(up_down_block):
            n_chn *= 2
            x = DownSampleBlock(n_chn)(x)

        for _ in range(res_blocks):
            x = ResidualBlock(n_chn)(x)

        for _ in range(up_down_block):
            n_chn //= 2
            x = UpSampleBlock(n_chn)(x)

        return x

    def generator(self, n_chn = 64, up_down_block = 4, res_blocks = 9):
        """Build the coarse-to-fine generator: global generator G1 plus local enhancer G2.

        G1 runs on a 3x3 / stride-2 average-pooled copy of the concatenated
        inputs. Its last feature map is added to the output of G2's front-end,
        which downsamples the full-resolution inputs by 2. G2's back-end then
        produces the output image: 3 residual blocks, one upsampling block,
        and a 7x7 conv with tanh.

        Args:
            n_chn: Channel width of G1's first conv and of the G1/G2 fusion
                point. G2's full-resolution layers use ``n_chn // 2``.
            up_down_block: Number of stride-2 down/up-sampling stages in G1.
            res_blocks: Number of residual blocks in G1's bottleneck.

        Returns:
            ``tf.keras.Model`` mapping ``[seg, boundary_map, encoded]`` to a
            3-channel tanh-activated image.
        """
        inp_seg = tf.keras.layers.Input(shape = self.img_shape[:-1] + (self.num_labels, ), dtype = tf.float32,
                                        name = 'generator_seg_input')
        inp_boundary_map = tf.keras.layers.Input(shape = self.img_shape[:-1] + (1, ), dtype = tf.float32,
                                                 name = 'generator_boundary_map')
        inp_encoded = tf.keras.layers.Input(shape = self.img_shape, dtype = tf.float32,
                                            name = 'generator_encoded_input')

        x = tf.keras.layers.Concatenate(axis = -1)([inp_seg, inp_boundary_map, inp_encoded])
        down_sampled_g1_inp = tf.keras.layers.AveragePooling2D(pool_size = (3, 3), strides = (2, 2), padding = 'same')(x)

        # Local enhancer G2, front-end: full resolution -> half resolution.
        # Its last width must equal G1's output width (n_chn) for the Add below.
        x = ConvBlock(filters = n_chn // 2, kernel_size = (7, 7), strides = (1, 1), padding = ((3, 3), (3, 3)),
                      norm = 'inst_norm', activation = 'relu')(x)
        x = ConvBlock(filters = n_chn, kernel_size = (3, 3), strides = (2, 2), padding = 'same',
                      norm = 'inst_norm', activation = 'relu')(x)

        # Global generator G1 on the half-resolution input. It is a separate
        # function taking n_chn as a parameter; the former closure rebound
        # n_chn and raised UnboundLocalError.
        g1_x = self._global_generator_g1(down_sampled_g1_inp, n_chn, up_down_block, res_blocks)

        # Local enhancer G2, back-end: fuse with G1, refine, back to full resolution.
        x = tf.keras.layers.Add()([x, g1_x])

        for _ in range(3):
            x = ResidualBlock(filters = n_chn)(x)

        x = UpSampleBlock(filters = n_chn // 2)(x)
        # No normalization on the output layer. Instance norm here would give
        # every generated image the same per-channel mean and std before tanh.
        x = ConvBlock(filters = 3, kernel_size = (7, 7), strides = (1, 1), padding = ((3, 3), (3, 3)),
                      norm = None, activation = 'tanh')(x)

        return tf.keras.models.Model([inp_seg, inp_boundary_map, inp_encoded], x, name = 'Generator')

    def encoder(self, n_chn = 32, num_blocks = 4):
        """Build the feature encoder.

        The image goes through a 7x7 conv, ``num_blocks`` downsampling blocks,
        ``num_blocks`` upsampling blocks and a 3-channel 7x7 tanh conv. The
        result is then pooled per instance of ``inst_map`` by
        ``InstanceWiseAveragePooling``.

        Args:
            n_chn: Width of the first conv; doubled at each downsampling block.
            num_blocks: Number of down- and up-sampling blocks.

        Returns:
            ``tf.keras.Model`` mapping ``[image, inst_map]`` to a 3-channel
            feature map at image resolution.
        """
        inp = tf.keras.layers.Input(shape = self.img_shape, dtype = tf.float32, name = 'encoder_input')

        x = ConvBlock(filters = n_chn, kernel_size = (7, 7), strides = (1, 1), padding = ((3, 3), (3, 3)),
                      norm = 'inst_norm', activation = 'relu')(inp)

        # down_sample
        for _ in range(num_blocks):
            n_chn *= 2
            x = DownSampleBlock(filters = n_chn)(x)

        # up_sample: halve the width *before* each block, mirroring the
        # downsampling path and the generator's decoder, so the last block
        # ends at the first conv's width.
        for _ in range(num_blocks):
            n_chn //= 2
            x = UpSampleBlock(filters = n_chn)(x)

        # Output layer without normalization, so feature statistics can differ per image.
        x = ConvBlock(filters = 3, kernel_size = (7, 7), strides = (1, 1), padding = ((3, 3), (3, 3)),
                      norm = None, activation = 'tanh')(x)

        inst_inp = tf.keras.layers.Input(shape = self.img_shape[:-1] + (self.num_labels, ), dtype = tf.float32,
                                         name = 'inst_input')
        x = InstanceWiseAveragePooling()([x, inst_inp])

        return tf.keras.models.Model([inp, inst_inp], x, name = 'encoder')

    def discriminator(self, img_shape):
        """Build one fully-convolutional discriminator for inputs of size ``img_shape``.

        Returns:
            ``tf.keras.Model`` mapping ``[seg, boundary_map, image]`` to
            ``[x1, x2, x3, x4, x5]``. ``x1``..``x4`` are intermediate
            LeakyReLU feature maps (usable for feature matching). ``x5`` is
            the 1-channel spatial score map with no activation (logits).
        """
        inp_seg = tf.keras.layers.Input(shape = img_shape[:-1] + (self.num_labels, ), dtype = tf.float32,
                                    name = f'discriminator_seg_input_{img_shape[0]}x{img_shape[1]}x{self.num_labels}')
        inp_boundary_map = tf.keras.layers.Input(shape = img_shape[:-1] + (1, ), dtype = tf.float32,
                                                 name = f'discriminator_boundary_map_{img_shape[0]}x{img_shape[1]}x1')
        inp_img = tf.keras.layers.Input(shape = img_shape, dtype = tf.float32,
                                        name = f'discriminator_img_inp_{img_shape[0]}x{img_shape[1]}x{img_shape[2]}')

        x = tf.keras.layers.Concatenate()([inp_seg, inp_boundary_map, inp_img])

        # The first layer is left unnormalized.
        x1 = ConvBlock(filters = 64, kernel_size = (4, 4), strides = (2, 2), padding = 'same',
                       norm = None, activation = 'leaky_relu')(x)
        x2 = ConvBlock(filters = 128, kernel_size = (4, 4), strides = (2, 2), padding = 'same',
                       norm = 'inst_norm', activation = 'leaky_relu')(x1)
        x3 = ConvBlock(filters = 256, kernel_size = (4, 4), strides = (2, 2), padding = 'same',
                       norm = 'inst_norm', activation = 'leaky_relu')(x2)
        x4 = ConvBlock(filters = 512, kernel_size = (4, 4), strides = (2, 2), padding = 'same',
                       norm = 'inst_norm', activation = 'leaky_relu')(x3)

        # Raw scores; the losses apply from_logits BCE or LSGAN MSE on these.
        x5 = ConvBlock(filters = 1, kernel_size = (4, 4), strides = (1, 1), padding = 'same',
                       norm = None, activation = None)(x4)

        return tf.keras.models.Model([inp_seg, inp_boundary_map, inp_img], [x1, x2, x3, x4, x5],
                                     name = f'discriminator_{img_shape[0]}x{img_shape[1]}x{img_shape[2]}')

In [32]:
# Build each sub-network once at full resolution to check that the graphs construct.
gan = Pix2PixHD(num_labels = 35, img_shape = (2048, 1024, 3))
generator, encoder = gan.generator(), gan.encoder()
# Full-resolution discriminator; the coarser scales of the multi-scale
# setup would be (1024, 512, 3) and (512, 256, 3).
discriminator = gan.discriminator((2048, 1024, 3))

In [27]:
class Losses(object):
    """Loss terms for pix2pixHD training.

    Provides:
        * the GAN loss: ``'adversarial'`` (BCE on logits) or ``'lsgan'`` (MSE),
        * the discriminator feature-matching loss,
        * a VGG19 perceptual loss.
    """

    def __init__(self):
        self.bce = tf.keras.losses.BinaryCrossentropy(from_logits = True)
        # An explicit call, instead of evaluating a side-effecting @property.
        self._set_vgg()

    def _set_vgg(self):
        """Build the frozen VGG19 feature extractor used by ``vgg_loss``.

        Sets ``self.our_model``: ImageNet VGG19, truncated to the outputs of
        the five ``blockN_conv1`` layers. Also sets the per-layer
        ``self.vgg_weights`` and ``self.preprocess_input``.
        """
        vgg_model = tf.keras.applications.VGG19(include_top = False, weights = 'imagenet')
        # VGG is a fixed feature extractor and is never trained.
        vgg_model.trainable = False
        vgg_layers = ['block1_conv1', 'block2_conv1', 'block3_conv1', 'block4_conv1', 'block5_conv1']
        self.vgg_weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]
        self.preprocess_input = tf.keras.applications.vgg19.preprocess_input

        outs = [vgg_model.get_layer(layer).output for layer in vgg_layers]
        self.our_model = tf.keras.models.Model(vgg_model.inputs, outs)

    def vgg_loss(self, real, gen):
        """Weighted sum over VGG layers of the mean absolute feature difference.

        Assumes ``real`` and ``gen`` are images in ``[-1, 1]`` (e.g. tanh
        outputs). ``(x + 1) * 127.5`` maps them to ``[0, 255]`` for VGG
        preprocessing.
        """
        real_processed = self.preprocess_input((real + 1) * 127.5)
        gen_processed = self.preprocess_input((gen + 1) * 127.5)
        real_outs = self.our_model(real_processed)
        gen_outs = self.our_model(gen_processed)

        loss = 0
        for weight, real_feat, gen_feat in zip(self.vgg_weights, real_outs, gen_outs):
            loss += tf.math.reduce_mean(tf.abs(real_feat - gen_feat)) * weight
        return loss

    def gan_loss(self, true, pred, gan_loss_type):
        """GAN loss between target labels ``true`` and discriminator output ``pred``.

        Raises:
            ValueError: if ``gan_loss_type`` is not ``'adversarial'`` or ``'lsgan'``.
        """
        if gan_loss_type == 'adversarial':
            return self.bce(true, pred)
        if gan_loss_type == 'lsgan':
            return tf.math.reduce_mean(tf.math.square(true - pred))
        # Previously fell through and returned None, which surfaced later as a
        # confusing TypeError on `loss += None`.
        raise ValueError(f"Unknown gan_loss_type {gan_loss_type!r}; expected 'adversarial' or 'lsgan'.")

    def feature_matching_loss(self, real, pred):
        """Sum over paired feature maps of their mean absolute difference."""
        loss = 0
        for r, p in zip(real, pred):
            loss += tf.math.reduce_mean(tf.abs(r - p))
        return loss

In [None]:
class Trainer(object):
    """Jointly trains the pix2pixHD encoder, generator and three-scale discriminators.

    Args:
        img_shape: ``(height, width, channels)`` of the training images.
        num_labels: Channels of the one-hot semantic label map.
        learning_rate: Adam learning rate for both optimizers.
        gan_loss_type: ``'lsgan'`` or ``'adversarial'``, forwarded to ``Losses.gan_loss``.
    """

    def __init__(self, img_shape = (2048, 1024, 3), num_labels = 35, learning_rate = 2e-4, gan_loss_type = 'lsgan'):
        self.gan_loss_type = gan_loss_type

        self.gen_optimizer = tf.keras.optimizers.Adam(learning_rate = learning_rate, beta_1 = 0.5, beta_2 = 0.999)
        self.disc_optimizer = tf.keras.optimizers.Adam(learning_rate = learning_rate, beta_1 = 0.5, beta_2 = 0.999)

        gan = Pix2PixHD(img_shape = img_shape, num_labels = num_labels)
        self.encoder = gan.encoder(n_chn = 32, num_blocks = 4)
        self.generator = gan.generator(n_chn = 64, up_down_block = 4, res_blocks = 9)

        # Discriminators at full, 1/2 and 1/4 resolution.
        self.discriminator_D1 = gan.discriminator(img_shape = img_shape)
        self.discriminator_D2 = gan.discriminator(img_shape = (img_shape[0]//2, img_shape[1]//2, img_shape[2]))
        self.discriminator_D3 = gan.discriminator(img_shape = (img_shape[0]//4, img_shape[1]//4, img_shape[2]))

        self.losses = Losses()

    def down_scale_img(self, imgs, scale):
        """Resize every NHWC tensor in ``imgs`` to ``1 / scale`` of its height and width (nearest neighbour)."""
        return [tf.image.resize(img, [img.shape[1]//scale, img.shape[2]//scale],
                                tf.image.ResizeMethod.NEAREST_NEIGHBOR)
                for img in imgs]

    def _discriminate(self, inp_seg, inp_boundary_map, img):
        """Run D1/D2/D3 on the inputs at full, 1/2 and 1/4 resolution; return their output lists."""
        inputs = [inp_seg, inp_boundary_map, img]
        return [
            self.discriminator_D1(inputs, training = True),
            self.discriminator_D2(self.down_scale_img(inputs, 2), training = True),
            self.discriminator_D3(self.down_scale_img(inputs, 4), training = True),
        ]

    def generator_loss(self, real_img, gen_img, discs_real_outs, discs_gen_outs, lambda_ = 10):
        """Generator objective: ``lambda_ * (vgg + feature matching) + gan``.

        Each entry of ``discs_*_outs`` is one discriminator's output list. Its
        last element is the score map; the others are intermediate features
        used for feature matching. The GAN and feature-matching terms are
        averaged over the discriminators.

        The lists are only read, never modified. ``discriminator_loss`` can
        therefore reuse the same outputs; the former ``.pop()`` calls made it
        read the wrong tensors.
        """
        vgg_loss = self.losses.vgg_loss(real_img, gen_img)
        gan_loss, fm_loss = 0, 0
        num_discs = 0
        for dr, dg in zip(discs_real_outs, discs_gen_outs):
            gen_pred = dg[-1]
            gan_loss += self.losses.gan_loss(tf.ones_like(gen_pred), gen_pred, self.gan_loss_type)
            fm_loss += self.losses.feature_matching_loss(dr[:-1], dg[:-1])
            num_discs += 1

        if num_discs:
            gan_loss /= num_discs
            fm_loss /= num_discs

        return lambda_ * (vgg_loss + fm_loss) + gan_loss

    def discriminator_loss(self, discs_real_outs, discs_gen_outs):
        """Discriminator objective, averaged over discriminators.

        For each discriminator: real score maps are pushed toward 1 and
        generated ones toward 0. Only the last element (the score map) of
        each output list is used; the lists are not modified.
        """
        loss = 0
        for dr, dg in zip(discs_real_outs, discs_gen_outs):
            real_pred, gen_pred = dr[-1], dg[-1]
            loss += self.losses.gan_loss(tf.ones_like(real_pred), real_pred, self.gan_loss_type)
            loss += self.losses.gan_loss(tf.zeros_like(gen_pred), gen_pred, self.gan_loss_type)
        return loss / max(len(discs_real_outs), 1)

    @tf.function
    def train_step(self, inp_seg, inp_boundary_map, real_img):
        """One simultaneous update of (encoder + generator) and the discriminators.

        Returns:
            ``(gen_loss, disc_loss)`` for this batch.
        """
        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            enc_out = self.encoder([real_img, inp_seg], training = True)
            gen_out = self.generator([inp_seg, inp_boundary_map, enc_out], training = True)

            discs_real_outs = self._discriminate(inp_seg, inp_boundary_map, real_img)
            discs_gen_outs = self._discriminate(inp_seg, inp_boundary_map, gen_out)

            gen_loss = self.generator_loss(real_img, gen_out, discs_real_outs, discs_gen_outs)
            disc_loss = self.discriminator_loss(discs_real_outs, discs_gen_outs)

        gen_params = self.encoder.trainable_variables + self.generator.trainable_variables
        gen_grads = gen_tape.gradient(gen_loss, gen_params)
        self.gen_optimizer.apply_gradients(zip(gen_grads, gen_params))

        disc_params = self.discriminator_D1.trainable_variables + self.discriminator_D2.trainable_variables \
                       + self.discriminator_D3.trainable_variables
        disc_grads = disc_tape.gradient(disc_loss, disc_params)
        self.disc_optimizer.apply_gradients(zip(disc_grads, disc_params))

        return gen_loss, disc_loss

    def train(self, data, epochs = 1):
        """Train for ``epochs`` passes over ``data``.

        Args:
            data: Iterable of ``(inp_seg, inp_boundary_map, real_img)`` batches.
            epochs: Number of passes over ``data``.

        Returns:
            A dict containing:
                * ``'gen_loss'`` / ``'disc_loss'``: the last batch's losses
                  (``None`` if ``data`` yielded nothing),
                * ``'gen_losses'`` / ``'disc_losses'``: the last-batch losses of
                  every epoch.
        """
        gen_losses, disc_losses = [], []
        # Defined up front so an empty ``data`` does not raise UnboundLocalError.
        gen_loss = disc_loss = None
        for e in range(epochs):
            print(f'Epoch: {e} Starts')
            for inp_seg, inp_boundary_map, real_img in data:
                gen_loss, disc_loss = self.train_step(inp_seg, inp_boundary_map, real_img)
                print('.', end = '')

            gen_losses.append(gen_loss)
            disc_losses.append(disc_loss)
            print(f'\nGenerator Loss: {gen_loss} \t Discriminator Loss: {disc_loss}')
            print(f'Epoch: {e} Ends\n')
        return {'gen_loss': gen_loss, 'disc_loss': disc_loss,
                'gen_losses': gen_losses, 'disc_losses': disc_losses}