# StackGAN

This is an attempt to re-implement the StackGAN paper in TensorFlow/Keras.

Paper: https://arxiv.org/pdf/1612.03242.pdf

Other resources:
* https://github.com/hanzhanggit/StackGAN

In [1]:
import numpy as np
import tensorflow as tf

In [3]:
class Linear(tf.keras.layers.Layer):
    """Fully connected layer with runtime (equalized learning-rate) weight scaling.

    The weight matrix is initialized from N(0, 1) and multiplied by
    ``gain / sqrt(fan_in)`` on every call. With the default ``gain = sqrt(2)``
    the effective weights therefore follow He-normal statistics, while all
    stored parameters share a unit scale.

    Args:
        neurons: Number of output units.
        gain: Multiplier applied on top of ``1 / sqrt(fan_in)``.
        use_bias: Whether to add a learned bias of shape ``(1, neurons)``.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, neurons, gain = tf.sqrt(2.0), use_bias = True, **kwargs):
        super().__init__(**kwargs)
        self.neurons = neurons
        self.gain = gain
        self.use_bias = use_bias

    def build(self, input_shape):
        inp_neurons = input_shape[-1]

        init = tf.keras.initializers.RandomNormal(mean = 0.0, stddev = 1.0)
        self.W = self.add_weight(shape = (inp_neurons, self.neurons), initializer = init,
                                 trainable = True, name = 'Linear_weight')

        if self.use_bias:
            self.B = self.add_weight(shape = (1, self.neurons), initializer = 'zeros',
                                     trainable = True, name = 'Linear_bias')

        # The scale is applied in call() rather than folded into the initializer,
        # so gradient updates act on the unit-variance parameters.
        self.wscale = self.gain * tf.math.rsqrt(tf.cast(inp_neurons, tf.float32))

    def call(self, inputs):
        out = tf.matmul(inputs, self.W * self.wscale)
        if self.use_bias:
            out = tf.add(out, self.B)
        return out

    def get_config(self):
        """Return the constructor arguments so Keras can re-create this layer."""
        config = super().get_config()
        config.update({
            'neurons': self.neurons,
            # The default gain is a tf scalar tensor; store a plain float so the
            # config stays JSON-serializable.
            'gain': float(self.gain),
            'use_bias': self.use_bias,
        })
        return config

In [4]:
class Conv2D(tf.keras.layers.Layer):
    """2-D NHWC convolution with runtime (equalized learning-rate) weight scaling.

    The kernel is initialized from N(0, 1) and multiplied by
    ``gain / sqrt(kernel_h * kernel_w * in_channels)`` on every call. With the
    default ``gain = sqrt(2)`` the effective kernel follows He-normal statistics.

    Padding is ``'SAME'`` when ``(kernel_h - 1) // 2`` is non-zero (kernel height
    of 3 or more) and ``'VALID'`` otherwise.

    Args:
        filters: Number of output channels.
        kernel_size: Kernel ``(height, width)``; an int means a square kernel.
        strides: Passed straight to ``tf.nn.conv2d``, which accepts an int or a
            list/tuple of length 1, 2 or 4.
        gain: Multiplier applied on top of ``1 / sqrt(fan_in)``.
        use_bias: Whether to add a learned per-channel bias.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.

    Raises:
        ValueError: If ``kernel_size`` is neither an int nor a 2-element sequence.
    """

    def __init__(self, filters, kernel_size, strides, gain = tf.sqrt(2.0), use_bias = True, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        # build() concatenates kernel_size with a tuple, so an int or a list must be
        # normalized to a 2-tuple first.
        self.kernel_size = self._kernel_pair(kernel_size)
        self.strides = strides
        self.padding = 'SAME' if (self.kernel_size[0] - 1)//2 else 'VALID'
        self.gain = gain
        self.use_bias = use_bias

    @staticmethod
    def _kernel_pair(value):
        """Return ``value`` as a 2-tuple; a scalar is repeated for both spatial axes."""
        try:
            pair = tuple(value)
        except TypeError:
            pair = (value, value)
        if len(pair) != 2:
            raise ValueError(f'kernel_size must be an int or a pair, got {value!r}')
        return pair

    def build(self, input_shape):
        inp_filters = input_shape[-1]

        init = tf.keras.initializers.RandomNormal(mean = 0.0, stddev = 1.0)
        self.kernel = self.add_weight(shape = self.kernel_size + (inp_filters, self.filters), initializer = init,
                                      trainable = True, name = 'Conv2D_kernel')
        if self.use_bias:
            # A (1, C) bias broadcasts over (batch, H, W, C).
            self.bias = self.add_weight(shape = (1, self.filters), initializer = 'zeros',
                                        trainable = True, name = 'Conv2D_bias')

        fan_in = tf.cast(self.kernel_size[0] * self.kernel_size[1] * inp_filters, tf.float32)
        self.wscale = self.gain * tf.math.rsqrt(fan_in)

    def call(self, inputs):
        out = tf.nn.conv2d(inputs, self.kernel * self.wscale, self.strides, self.padding, 'NHWC')
        if self.use_bias:
            out = tf.add(out, self.bias)
        return out

    def get_config(self):
        """Return the constructor arguments so Keras can re-create this layer."""
        config = super().get_config()
        config.update({
            'filters': self.filters,
            'kernel_size': self.kernel_size,
            'strides': self.strides,
            # The default gain is a tf scalar tensor; store a plain float so the
            # config stays JSON-serializable.
            'gain': float(self.gain),
            'use_bias': self.use_bias,
        })
        return config

In [45]:
class ConditioningAugmentation(tf.keras.layers.Layer):
    """Draw a reparameterized sample from a diagonal Gaussian.

    Called on ``[mu, sigma]``, where ``sigma`` is interpreted as a
    log-variance. Returns ``mu + exp(0.5 * sigma) * eps`` with
    ``eps ~ N(0, 1)`` of the same shape as ``mu``. A fresh ``eps`` is drawn on
    every call; there is no separate deterministic inference path.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, inputs):
        loc, log_var = inputs
        eps = tf.random.normal(tf.shape(loc), mean=0.0, stddev=1.0)
        std = tf.exp(log_var * 0.5)
        return loc + std * eps
    
class SpatialRepLication(tf.keras.layers.Layer):
    """Tile each per-sample vector over a square ``rep_dim x rep_dim`` grid.

    The input is reshaped to ``(-1, 1, 1, C)``, where ``C`` is its last static
    dimension, and tiled to ``(-1, rep_dim, rep_dim, C)``. It presumably expects
    inputs of shape ``(batch, C)``; confirm against callers.

    Args:
        rep_dim: Height and width of the output grid.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, rep_dim, **kwargs):
        super().__init__(**kwargs)
        self.rep_dim = rep_dim

    def call(self, inputs):
        x = tf.reshape(inputs, (-1, 1, 1, inputs.shape[-1]))
        # tf.tile multiples must be non-negative; 1 leaves the batch axis unchanged.
        x = tf.tile(x, [1, self.rep_dim, self.rep_dim, 1])
        return x

    def get_config(self):
        """Return the constructor arguments so Keras can re-create this layer."""
        config = super().get_config()
        config.update({'rep_dim': self.rep_dim})
        return config

In [46]:
def stage1_generator(embed_dim = 1024, noise_dim = 100, ca_dim = 256):
    """Build the Stage-I generator.

    Steps:
    1. The text embedding feeds two ``Linear`` + ReLU heads, ``mu`` and
       ``sigma``, each with ``ca_dim // 2`` units.
    2. ``ConditioningAugmentation`` samples a conditioning vector from them,
       treating ``sigma`` as a log-variance.
    3. The sample is concatenated with the noise and projected to a
       4x4x1024 map.
    4. The map is upsampled four times, then a final 3x3 conv and tanh
       produce an RGB image in [-1, 1].

    Args:
        embed_dim: Length of the text-embedding input.
        noise_dim: Length of the noise input.
        ca_dim: Total width of the conditioning-augmentation head.

    Returns:
        A ``tf.keras.Model`` mapping ``[text_embedding, noise]`` to
        ``[mu, sigma, image]``, with image shape ``(batch, 64, 64, 3)``.
    """
    # Input shapes must be tuples; ``(embed_dim)`` alone is just an int.
    inp_text_embed = tf.keras.layers.Input(shape = (embed_dim,), dtype = tf.float32,
                                           name = 'stage1_generator_text_embed_inp')
    inp_noise = tf.keras.layers.Input(shape = (noise_dim,), dtype = tf.float32, name = 'stage1_generator_noise_inp')

    # NOTE(review): the ReLU on the sigma (log-variance) head makes it >= 0, i.e.
    # std >= 1 for the sampled condition. Confirm this is intended.
    mu = tf.keras.layers.ReLU()(Linear(neurons = ca_dim//2)(inp_text_embed))
    sigma = tf.keras.layers.ReLU()(Linear(neurons = ca_dim//2)(inp_text_embed))

    x = ConditioningAugmentation()([mu, sigma])
    x = tf.keras.layers.Concatenate()([x, inp_noise])

    x = Linear(neurons = 4 * 4 * 1024)(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.ReLU()(x)
    x = tf.keras.layers.Reshape((4, 4, 1024))(x)

    # Each stage doubles the spatial size: 4 -> 8 -> 16 -> 32 -> 64.
    for filters in (512, 256, 128, 64):
        x = tf.keras.layers.UpSampling2D(size = (2, 2))(x)
        x = Conv2D(filters = filters, kernel_size = (3, 3), strides = (1, 1), use_bias = False)(x)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.ReLU()(x)

    x = Conv2D(filters = 3, kernel_size = (3, 3), strides = (1, 1), use_bias = True)(x)
    x = tf.keras.layers.Activation('tanh')(x)

    return tf.keras.models.Model([inp_text_embed, inp_noise], [mu, sigma, x], name = 'stage1_generator')

In [47]:
def stage2_generator(embed_dim = 1024, ca_dim = 256, inp_img_dim = 64):
    """Build the Stage-II generator.

    Steps:
    1. Encode the Stage-I image with a 3x3 stride-1 conv and two 4x4 stride-2
       convs.
    2. Concatenate a spatially tiled conditioning-augmentation sample.
    3. Refine the joint map with a 3x3 conv and two residual blocks.
    4. Upsample four times (x16), then a final 3x3 conv and tanh produce an
       RGB image in [-1, 1].

    Args:
        embed_dim: Length of the text-embedding input.
        ca_dim: Total width of the conditioning-augmentation head; ``mu`` and
            ``sigma`` each get ``ca_dim // 2`` units.
        inp_img_dim: Height and width of the Stage-I image input.

    Returns:
        A ``tf.keras.Model`` mapping ``[stage1_image, text_embedding]`` to
        ``[mu, sigma, image]``. For the default ``inp_img_dim = 64`` the image
        is ``(batch, 256, 256, 3)``.
    """
    # Input shapes must be tuples; ``(embed_dim)`` alone is just an int.
    inp_text_embed = tf.keras.layers.Input(shape = (embed_dim,), dtype = tf.float32,
                                           name = 'stage2_generator_text_embed_inp')
    inp_stage1_img = tf.keras.layers.Input(shape = (inp_img_dim, inp_img_dim, 3), dtype = tf.float32,
                                           name = 'stage2_generator_stage1_img_inp')

    # Conditioning Augmentation (sigma is consumed as a log-variance).
    # NOTE(review): the ReLU on the sigma head forces std >= 1. Confirm this is intended.
    mu = tf.keras.layers.ReLU()(Linear(neurons = ca_dim//2)(inp_text_embed))
    sigma = tf.keras.layers.ReLU()(Linear(neurons = ca_dim//2)(inp_text_embed))

    # Two SAME-padded stride-2 convs below shrink the image to ceil(inp_img_dim / 4)
    # (16 for the default 64), so the condition is tiled to that grid.
    grid_dim = -(-inp_img_dim // 4)
    ca = ConditioningAugmentation()([mu, sigma])
    ca = SpatialRepLication(rep_dim = grid_dim)(ca)

    # Downsampling
    x = Conv2D(filters = 128, kernel_size = (3, 3), strides = (1, 1), use_bias = True)(inp_stage1_img)
    x = tf.keras.layers.ReLU()(x)

    for filters in (256, 512):
        x = Conv2D(filters = filters, kernel_size = (4, 4), strides = (2, 2), use_bias = False)(x)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.ReLU()(x)

    # Joint image/text encoding
    out_r = tf.keras.layers.Concatenate(axis = -1)([x, ca])

    out_r = Conv2D(filters = 512, kernel_size = (3, 3), strides = (1, 1), use_bias = False)(out_r)
    out_r = tf.keras.layers.BatchNormalization()(out_r)
    out_r = tf.keras.layers.ReLU()(out_r)

    # Residual block 1: conv-BN-ReLU-conv-BN, add the skip, then ReLU.
    out_r1 = Conv2D(filters = 512, kernel_size = (3, 3), strides = (1, 1), use_bias = False)(out_r)
    out_r1 = tf.keras.layers.BatchNormalization()(out_r1)
    out_r1 = tf.keras.layers.ReLU()(out_r1)

    out_r1 = Conv2D(filters = 512, kernel_size = (3, 3), strides = (1, 1), use_bias = False)(out_r1)
    out_r1 = tf.keras.layers.BatchNormalization()(out_r1)
    out_r1 = tf.keras.layers.ReLU()(out_r1 + out_r)

    # Residual block 2
    out_r2 = Conv2D(filters = 512, kernel_size = (3, 3), strides = (1, 1), use_bias = False)(out_r1)
    out_r2 = tf.keras.layers.BatchNormalization()(out_r2)
    out_r2 = tf.keras.layers.ReLU()(out_r2)

    out_r2 = Conv2D(filters = 512, kernel_size = (3, 3), strides = (1, 1), use_bias = False)(out_r2)
    out_r2 = tf.keras.layers.BatchNormalization()(out_r2)
    out = tf.keras.layers.ReLU()(out_r2 + out_r1)

    # Upsampling: each stage doubles the spatial size (x16 in total).
    for filters in (512, 128, 64, 32):
        out = tf.keras.layers.UpSampling2D(size = (2, 2))(out)
        out = Conv2D(filters = filters, kernel_size = (3, 3), strides = (1, 1), use_bias = False)(out)
        out = tf.keras.layers.BatchNormalization()(out)
        out = tf.keras.layers.ReLU()(out)

    out = Conv2D(filters = 3, kernel_size = (3, 3), strides = (1, 1), use_bias = True)(out)
    out = tf.keras.layers.Activation('tanh')(out)

    return tf.keras.models.Model([inp_stage1_img, inp_text_embed], [mu, sigma, out], name = 'stage2_generator')

In [48]:
def stage1_discriminator(img_inp = 64, mu_dim = 128):
    """Build the Stage-I discriminator.

    Steps:
    1. Downsample an ``img_inp x img_inp`` RGB image with four 4x4 stride-2
       convs (64 -> 128 -> 256 -> 512 channels).
    2. Tile the conditioning vector over the resulting grid and concatenate.
    3. Fuse both with a 3x3 conv.
    4. Reduce the map with a final 4x4 stride-4 conv.

    No output activation is applied.

    Args:
        img_inp: Height and width of the input image.
        mu_dim: Length of the conditioning vector fed to the second input.

    Returns:
        A ``tf.keras.Model`` mapping ``[image, mu]`` to raw scores. For the
        default ``img_inp = 64`` the output is ``(batch, 1, 1, 1)``.
    """
    inp = tf.keras.layers.Input(shape = (img_inp, img_inp, 3), dtype = tf.float32, name = 'stage1_discriminator_input')
    # Input shapes must be tuples; ``(mu_dim)`` alone is just an int.
    mu_inp = tf.keras.layers.Input(shape = (mu_dim,), dtype = tf.float32, name = 'stage1_discriminator_mu_input')

    # The first block has no batch normalization.
    x = Conv2D(filters = 64, kernel_size = (4, 4), strides = (2, 2), use_bias = True)(inp)
    x = tf.keras.layers.LeakyReLU(alpha = 0.2)(x)

    # Channel width doubles at each stride-2 step: 64 -> 128 -> 256 -> 512.
    for filters in (128, 256, 512):
        x = Conv2D(filters = filters, kernel_size = (4, 4), strides = (2, 2), use_bias = False)(x)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.LeakyReLU(alpha = 0.2)(x)

    # Four SAME-padded stride-2 convs leave a ceil(img_inp / 16) grid (4 for 64).
    grid_dim = -(-img_inp // 16)
    mu = SpatialRepLication(rep_dim = grid_dim)(mu_inp)
    out = tf.keras.layers.Concatenate(axis = -1)([x, mu])

    out = Conv2D(filters = 512, kernel_size = (3, 3), strides = (1, 1), use_bias = False)(out)
    out = tf.keras.layers.BatchNormalization()(out)
    out = tf.keras.layers.LeakyReLU(alpha = 0.2)(out)

    out = Conv2D(filters = 1, kernel_size = (4, 4), strides = (4, 4), use_bias = True)(out)
    return tf.keras.models.Model([inp, mu_inp], out, name = 'stage1_discriminator')

In [55]:
def stage2_discriminator(img_inp = 256, mu_dim = 128):
    """Build the Stage-II discriminator.

    Steps:
    1. Downsample an ``img_inp x img_inp`` RGB image with six 4x4 stride-2
       convs (64 -> ... -> 2048 channels).
    2. Narrow the channels with two 3x3 convs (1024, 512).
    3. Tile the conditioning vector over the grid and concatenate.
    4. Fuse both with a 3x3 conv.
    5. Reduce the map with a final 4x4 stride-4 conv.

    No output activation is applied.

    Args:
        img_inp: Height and width of the input image.
        mu_dim: Length of the conditioning vector fed to the second input.

    Returns:
        A ``tf.keras.Model`` mapping ``[image, mu]`` to raw scores. For the
        default ``img_inp = 256`` the output is ``(batch, 1, 1, 1)``.
    """
    inp = tf.keras.layers.Input(shape = (img_inp, img_inp, 3), dtype = tf.float32, name = 'stage2_discriminator_input')
    # Input shapes must be tuples; ``(mu_dim)`` alone is just an int.
    mu_inp = tf.keras.layers.Input(shape = (mu_dim,), dtype = tf.float32, name = 'stage2_discriminator_mu_input')

    # The first block has no batch normalization.
    x = Conv2D(filters = 64, kernel_size = (4, 4), strides = (2, 2), use_bias = True)(inp)
    x = tf.keras.layers.LeakyReLU(alpha = 0.2)(x)

    # Channel width doubles at each stride-2 step.
    for filters in (128, 256, 512, 1024, 2048):
        x = Conv2D(filters = filters, kernel_size = (4, 4), strides = (2, 2), use_bias = False)(x)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.LeakyReLU(alpha = 0.2)(x)

    # Stride-1 3x3 convs narrow the channels back to 512 without changing resolution.
    for filters in (1024, 512):
        x = Conv2D(filters = filters, kernel_size = (3, 3), strides = (1, 1), use_bias = False)(x)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.LeakyReLU(alpha = 0.2)(x)

    # Six SAME-padded stride-2 convs leave a ceil(img_inp / 64) grid (4 for 256).
    grid_dim = -(-img_inp // 64)
    mu = SpatialRepLication(rep_dim = grid_dim)(mu_inp)
    out = tf.keras.layers.Concatenate(axis = -1)([x, mu])

    out = Conv2D(filters = 512, kernel_size = (3, 3), strides = (1, 1), use_bias = False)(out)
    out = tf.keras.layers.BatchNormalization()(out)
    out = tf.keras.layers.LeakyReLU(alpha = 0.2)(out)

    # No normalization follows the logit conv, so it needs its own bias. This
    # matches the Stage-I discriminator.
    out = Conv2D(filters = 1, kernel_size = (4, 4), strides = (4, 4), use_bias = True)(out)
    return tf.keras.models.Model([inp, mu_inp], out, name = 'stage2_discriminator')