# Progressive GAN

This is an attempt to re-implement the paper "Progressive Growing of GANs for Improved Quality, Stability, and Variation" (Progressive GAN).

Paper: https://arxiv.org/pdf/1710.10196.pdf

Other Resources: 
* https://github.com/tkarras/progressive_growing_of_gans
* https://github.com/fabulousjeong/pggan-tensorflow

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

In [None]:
class PixelNormalization(tf.keras.layers.Layer):
    """Rescale the input so its mean square over the last axis is ~1.

    Args:
        epsilon: Small constant added to the mean square before the
            reciprocal square root, so an all-zero vector does not divide by zero.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, epsilon = 1e-8, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def call(self, inputs):
        """Return ``inputs`` multiplied by 1 / sqrt(mean(inputs**2, axis=-1) + epsilon)."""
        mean_square = tf.math.reduce_mean(tf.square(inputs), axis = -1, keepdims = True)
        inverse_rms = tf.math.rsqrt(mean_square + self.epsilon)
        return inputs * inverse_rms

In [None]:
class MinibatchStddev(tf.keras.layers.Layer):
    """Append one extra channel holding the average batch standard deviation.

    The standard deviation is taken over axis 0 for every remaining position,
    averaged down to a single scalar, and that scalar is broadcast into one
    additional trailing channel. The tiling uses the first three dimensions of
    the input, so the input is expected to be rank 4.

    Args:
        epsilon: Small constant added to the variance before the square root.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, epsilon = 1e-8, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def call(self, inputs):
        """Return ``inputs`` with one extra constant-valued channel concatenated."""
        # Per-position statistics across axis 0.
        batch_mean = tf.math.reduce_mean(inputs, axis = 0, keepdims = True)
        variance = tf.math.reduce_mean(tf.square(inputs - batch_mean), axis = 0, keepdims = True)
        stddev = tf.math.sqrt(variance + self.epsilon)

        # Collapse everything to one scalar, kept at the input's rank so it can be tiled.
        scalar_stat = tf.math.reduce_mean(stddev, keepdims = True)

        dims = tf.shape(inputs)
        stat_map = tf.tile(scalar_stat, (dims[0], dims[1], dims[2], 1))
        return tf.concat([inputs, stat_map], axis = -1)

In [None]:
class Dense(tf.keras.layers.Layer):
    """Fully connected layer whose kernel is rescaled at call time.

    The kernel is drawn from N(0, 1) and multiplied by
    ``gain / sqrt(fan_in)`` on every forward pass, where ``fan_in`` is the
    size of the input's last axis.

    Args:
        units: Size of the output's last axis.
        gain: Numerator of the runtime kernel scale.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, units, gain = np.sqrt(2), **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.gain = gain

    def build(self, input_shape):
        """Create the kernel and bias and precompute the kernel scale."""
        fan_in = input_shape[-1]

        self.W = self.add_weight(
            name = 'Weight',
            shape = (fan_in, self.units),
            initializer = tf.keras.initializers.RandomNormal(mean = 0.0, stddev = 1.0),
            trainable = True,
        )
        self.B = self.add_weight(
            name = 'Bias',
            shape = (self.units, ),
            initializer = 'zeros',
            trainable = True,
        )

        self.w_scale = self.gain * tf.math.rsqrt(tf.cast(fan_in, tf.float32))

    def call(self, inputs):
        """Return ``inputs @ (W * w_scale) + B``."""
        scaled_kernel = self.W * self.w_scale
        projected = tf.matmul(inputs, scaled_kernel)
        return tf.add(projected, self.B)

In [None]:
class Conv2D(tf.keras.layers.Layer):
    """NHWC 2-D convolution whose kernel is rescaled at call time.

    The kernel is drawn from N(0, 1) and multiplied by
    ``gain / sqrt(kernel_h * kernel_w * in_channels)`` on every forward pass.

    Args:
        filters: Number of output channels.
        kernel_size: ``(height, width)`` of the kernel.
        strides: Strides passed to ``tf.nn.conv2d``.
        gain: Numerator of the runtime kernel scale.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, filters, kernel_size = (3, 3), strides = (1, 1), gain = np.sqrt(2), **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        self.kernel_size = kernel_size
        self.strides = strides
        self.gain = gain
        # Kernels of height 1 or 2 give (k - 1) // 2 == 0 and use 'VALID';
        # anything taller uses 'SAME'.
        if (kernel_size[0] - 1) // 2:
            self.padding = 'SAME'
        else:
            self.padding = 'VALID'

    def build(self, input_shape):
        """Create the kernel and bias and precompute the kernel scale."""
        in_channels = input_shape[-1]
        kernel_h, kernel_w = self.kernel_size[0], self.kernel_size[1]

        self.W = self.add_weight(
            name = 'Weight',
            shape = (kernel_h, kernel_w, in_channels, self.filters),
            initializer = tf.keras.initializers.RandomNormal(mean = 0.0, stddev = 1.0),
            trainable = True,
        )
        self.B = self.add_weight(
            name = 'Bias',
            shape = (self.filters, ),
            initializer = 'zeros',
            trainable = True,
        )

        fan_in = tf.cast(kernel_h * kernel_w * in_channels, tf.float32)
        self.w_scale = self.gain * tf.math.rsqrt(fan_in)

    def call(self, inputs):
        """Return the convolution of ``inputs`` with the scaled kernel, plus bias."""
        scaled_kernel = self.W * self.w_scale
        convolved = tf.nn.conv2d(inputs, scaled_kernel, self.strides, self.padding, data_format = 'NHWC')
        return tf.add(convolved, self.B)

In [None]:
class WeightedSum(tf.keras.layers.Layer):
    """Blend two tensors as ``(1 - alpha) * inputs[0] + alpha * inputs[1]``.

    Args:
        alpha: Blend factor. May be a number or a ``tf.Variable`` shared with
            the caller so that it can be updated during training. When
            omitted, the layer creates its own variable initialised to 0.0.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, alpha = None, **kwargs):
        super().__init__(**kwargs)
        if alpha is not None:
            self.alpha = alpha
        else:
            # alpha is a schedule value, not a learned weight: keep it out of
            # the trainable variables so an optimizer never updates it.
            self.alpha = tf.Variable(0.0, trainable = False, name = 'ws_alpha')

    def call(self, inputs):
        """Return the alpha-weighted blend of the two tensors in ``inputs``.

        Raises:
            ValueError: If ``inputs`` does not contain exactly two tensors.
        """
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if len(inputs) != 2:
            raise ValueError(f'WeightedSum expects exactly 2 inputs, got {len(inputs)}')
        old_branch, new_branch = inputs
        return ((1 - self.alpha) * old_branch) + (self.alpha * new_branch)

In [None]:
# Feature-map counts per resolution block; entry ``num_block - 2`` belongs to
# resolution ``2**num_block`` (index 0 -> 4x4, index 1 -> 8x8, ...).
FILTERS = [512, 512, 512, 512, 256, 128, 64, 32, 16]
class ProgressiveGAN(tf.keras.models.Model):
    """Generator/discriminator pair that can be grown one resolution at a time.

    The model starts at 4x4 (``num_block == 2``). ``grow_model`` doubles the
    resolution and installs "fade-in" networks that blend the old and new
    branches with ``alpha``; ``stabilize_model`` then swaps in the networks
    that use the new branch only. Training uses a Wasserstein loss with
    gradient penalty and a drift term (see ``train_step``).

    Args:
        latent_dim: Number of channels of the ``(1, 1, latent_dim)`` latent input.
        d_steps: Discriminator updates per generator update (must be >= 1).
        gp_weight: Multiplier applied to the gradient penalty.
        drift_weight: Multiplier applied to the drift term.

    Raises:
        ValueError: If ``d_steps`` is smaller than 1.
    """

    def __init__(self, latent_dim = 512, d_steps = 1, gp_weight = 10, drift_weight = 0.001):
        super().__init__()
        if d_steps < 1:
            # With zero critic steps ``train_step`` would have no discriminator
            # loss to report.
            raise ValueError(f'd_steps must be >= 1, got {d_steps}')

        self.latent_dim = latent_dim
        self.gp_weight = gp_weight
        self.d_steps = d_steps
        self.drift_weight = drift_weight

        # Fade-in factor shared with every WeightedSum layer. The second
        # positional argument of tf.Variable is ``trainable``, so the name must
        # be passed by keyword; the variable is explicitly non-trainable so
        # that the optimizers never update it.
        self.alpha = tf.Variable(0.0, trainable = False, name = 'ws_alpha')
        self.num_block = 2

        self.generator = self._init_generator()
        self.discriminator = self._init_discriminator()

    def call(self, inputs):
        """Unused forward pass; training goes through ``train_step``. Returns None."""
        return

    def _init_generator(self):
        """Build and return the initial generator for resolution ``2**num_block``."""
        res = 2**self.num_block

        inp = tf.keras.layers.Input(shape = (1, 1, self.latent_dim), dtype = tf.float32, name = 'Latent_input')
        x = PixelNormalization()(inp)

        x = Dense(units = res*res*self.latent_dim, gain = np.sqrt(2)/res)(x)
        # Reshape BEFORE the pixel norm: PixelNormalization reduces over the
        # last axis, which must be the per-pixel channel axis rather than the
        # flattened res*res*latent_dim vector.
        x = tf.keras.layers.Reshape((res, res, self.latent_dim))(x)
        x = tf.keras.layers.LeakyReLU(alpha = 0.2)(x)
        x = PixelNormalization()(x)

        x = Conv2D(filters = FILTERS[self.num_block - 2], kernel_size = (4, 4), strides = (1, 1), gain = np.sqrt(2))(x)
        x = tf.keras.layers.LeakyReLU(alpha = 0.2)(x)
        x = PixelNormalization()(x)

        x = Conv2D(filters = FILTERS[self.num_block - 2], kernel_size = (3, 3), strides = (1, 1), gain = np.sqrt(2))(x)
        x = tf.keras.layers.LeakyReLU(alpha = 0.2)(x)
        x = PixelNormalization()(x)

        x = Conv2D(filters = 3, kernel_size = (1, 1), strides = (1, 1), gain = 1.0, name = f'to_rgb_conv_{res}x{res}')(x)
        x = tf.keras.layers.Activation('tanh', name = f'to_rgb_act_{res}x{res}')(x)

        return tf.keras.models.Model(inp, x, name = f'generator_{res}x{res}')

    def _grow_generator(self):
        """Add one resolution block to the generator.

        Sets ``self.generator`` to the fade-in model (old and new RGB outputs
        blended by ``alpha``) and ``self.stabilized_generator`` to the model
        that uses the new block only.

        Raises:
            ValueError: If the previous resolution's to-RGB layers are missing
                (raised by ``get_layer``).
        """
        res = 2**self.num_block
        prev_res = 2**(self.num_block - 1)

        # Look the layers up by name instead of relying on their position in
        # ``layers``; the tensor that fed the old to-RGB convolution is the
        # output of the last feature block.
        to_rgb_conv = self.generator.get_layer(f'to_rgb_conv_{prev_res}x{prev_res}')
        to_rgb_act = self.generator.get_layer(f'to_rgb_act_{prev_res}x{prev_res}')
        end_block = to_rgb_conv.input

        up_sample = tf.keras.layers.UpSampling2D()(end_block)

        # Old branch: upsampled features through the previous to-RGB layers.
        x1 = to_rgb_conv(up_sample)
        x1 = to_rgb_act(x1)

        # New branch: two 3x3 convolutions followed by a fresh to-RGB head.
        x2 = Conv2D(filters = FILTERS[self.num_block-2], kernel_size = (3, 3), strides = (1, 1), gain = np.sqrt(2))(up_sample)
        x2 = tf.keras.layers.LeakyReLU(alpha = 0.2)(x2)
        x2 = PixelNormalization()(x2)

        x2 = Conv2D(filters = FILTERS[self.num_block-2], kernel_size = (3, 3), strides = (1, 1), gain = np.sqrt(2))(x2)
        x2 = tf.keras.layers.LeakyReLU(alpha = 0.2)(x2)
        x2 = PixelNormalization()(x2)

        x2 = Conv2D(filters = 3, kernel_size = (1, 1), strides = (1, 1), gain = 1.0, name = f'to_rgb_conv_{res}x{res}')(x2)
        x2 = tf.keras.layers.Activation('tanh', name = f'to_rgb_act_{res}x{res}')(x2)

        x = WeightedSum(self.alpha)([x1, x2])

        latent_inputs = self.generator.inputs
        self.generator = tf.keras.models.Model(latent_inputs, x, name = f'generator_{res}x{res}')
        self.stabilized_generator = tf.keras.models.Model(latent_inputs, x2, name = f'stabilized_generator_{res}x{res}')

    def _init_discriminator(self):
        """Build and return the initial discriminator for resolution ``2**num_block``."""
        res = 2**self.num_block
        inp = tf.keras.layers.Input(shape = (res, res, 3), name = f'discriminator_input_{res}x{res}')

        x = Conv2D(filters = FILTERS[self.num_block - 2], kernel_size = (1, 1), strides = (1, 1), gain = np.sqrt(2),
                   name = f'from_rgb_conv_{res}x{res}')(inp)
        x = tf.keras.layers.LeakyReLU(alpha = 0.2, name = f'from_rgb_act_{res}x{res}')(x)
        x = MinibatchStddev()(x)

        x = Conv2D(filters = FILTERS[self.num_block - 2], kernel_size = (3, 3), strides = (1, 1), gain = np.sqrt(2))(x)
        x = tf.keras.layers.LeakyReLU(alpha = 0.2)(x)
        x = Conv2D(filters = FILTERS[self.num_block - 2], kernel_size = (4, 4), strides = (4, 4), gain = np.sqrt(2))(x)
        x = tf.keras.layers.LeakyReLU(alpha = 0.2)(x)

        x = tf.keras.layers.Flatten()(x)
        x = Dense(units = 1, gain = 1.0)(x)

        return tf.keras.models.Model(inp, x, name = f'discriminator_{res}x{res}')

    def _grow_discriminator(self):
        """Prepend one resolution block to the discriminator.

        Sets ``self.discriminator`` to the fade-in model and
        ``self.stabilized_discriminator`` to the model that uses the new block
        only. Every layer that followed the previous from-RGB activation is
        reused, in its existing order, by both models.

        Raises:
            ValueError: If the previous resolution's from-RGB layers are
                missing (raised by ``get_layer``).
        """
        res = 2**self.num_block
        prev_res = 2**(self.num_block - 1)

        # Same naming scheme as the initial discriminator input.
        inp = tf.keras.layers.Input(shape = (res, res, 3), dtype = tf.float32, name = f'discriminator_input_{res}x{res}')

        from_rgb_conv = self.discriminator.get_layer(f'from_rgb_conv_{prev_res}x{prev_res}')
        from_rgb_act = self.discriminator.get_layer(f'from_rgb_act_{prev_res}x{prev_res}')
        previous_layers = self.discriminator.layers
        tail_layers = previous_layers[previous_layers.index(from_rgb_act) + 1:]

        # Old branch: downsample the image and reuse the previous from-RGB layers.
        x1 = tf.keras.layers.AveragePooling2D()(inp)
        x1 = from_rgb_conv(x1)
        x1 = from_rgb_act(x1)

        # New branch: fresh from-RGB head, two 3x3 convolutions, then downsample.
        x2 = Conv2D(filters = FILTERS[self.num_block - 2], kernel_size = (1, 1), strides = (1, 1), gain = np.sqrt(2),
                    name = f'from_rgb_conv_{res}x{res}')(inp)
        x2 = tf.keras.layers.LeakyReLU(alpha = 0.2, name = f'from_rgb_act_{res}x{res}')(x2)

        x2 = Conv2D(filters = FILTERS[self.num_block - 2], kernel_size = (3, 3), strides = (1, 1), gain = np.sqrt(2))(x2)
        x2 = tf.keras.layers.LeakyReLU(alpha = 0.2)(x2)
        # Output channels match what the previous resolution's block expects.
        x2 = Conv2D(filters = FILTERS[self.num_block - 3], kernel_size = (3, 3), strides = (1, 1), gain = np.sqrt(2))(x2)
        x2 = tf.keras.layers.LeakyReLU(alpha = 0.2)(x2)

        x2 = tf.keras.layers.AveragePooling2D()(x2)

        x = WeightedSum(self.alpha)([x1, x2])

        for layer in tail_layers:
            x = layer(x)
            x2 = layer(x2)

        self.discriminator = tf.keras.models.Model(inp, x, name = f'discriminator_{res}x{res}')
        self.stabilized_discriminator = tf.keras.models.Model(inp, x2, name = f'stabilized_discriminator_{res}x{res}')

    def grow_model(self):
        """Double the working resolution and install the fade-in networks."""
        self.num_block += 1
        self._grow_generator()
        self._grow_discriminator()

    def stabilize_model(self):
        """Switch to the networks that use only the newest block.

        Must be called after ``grow_model``; the stabilized networks do not
        exist before the first growth step.
        """
        self.generator = self.stabilized_generator
        self.discriminator = self.stabilized_discriminator

    def generator_loss(self, disc_gen_out):
        """Return the negated mean discriminator score of generated images."""
        return -tf.reduce_mean(disc_gen_out)

    def gradient_penalty(self, real_img, gen_img):
        """Return the weighted gradient penalty on real/generated interpolates.

        Each interpolate mixes one real and one generated image with a
        per-sample uniform random factor. The penalty is the batch mean of
        ``(||grad||_2 - 1)**2`` where the gradient is that of the
        discriminator output with respect to the interpolate, multiplied by
        ``gp_weight``.
        """
        batch_size = tf.shape(real_img)[0]

        epsilon = tf.random.uniform((batch_size, 1, 1, 1))
        interpolated_img = ((1 - epsilon) * real_img) + (epsilon * gen_img)

        with tf.GradientTape() as gp_tape:
            gp_tape.watch(interpolated_img)
            pred = self.discriminator(interpolated_img, training = True)
        grads = gp_tape.gradient(pred, [interpolated_img])[0]
        # L2 norm per sample: sum (not mean) of squared gradients over H, W, C.
        norm = tf.math.sqrt(tf.math.reduce_sum(tf.square(grads), axis = [1, 2, 3]))
        gp = tf.math.reduce_mean(tf.square(norm - 1))
        return gp * self.gp_weight

    def discriminator_loss(self, disc_real_out, disc_gen_out):
        """Return mean score on generated images minus mean score on real images."""
        return tf.math.reduce_mean(disc_gen_out) - tf.math.reduce_mean(disc_real_out)

    def drift_loss(self, disc_real_out):
        """Return ``drift_weight`` times the mean squared score on real images."""
        return tf.math.reduce_mean(tf.square(disc_real_out)) * self.drift_weight

    def compile(self, optimizer, generator_optimizer = None):
        """Configure the optimizers used by ``train_step``.

        Args:
            optimizer: Optimizer applied to the discriminator; also stored as
                ``self.optimizer``.
            generator_optimizer: Optimizer applied to the generator. When
                omitted, a separate instance is rebuilt from ``optimizer``'s
                config. The two networks never share one optimizer instance,
                because an optimizer's state and step counter are tied to the
                variable set it is applied to.
        """
        super(ProgressiveGAN, self).compile()
        if generator_optimizer is None:
            generator_optimizer = type(optimizer).from_config(optimizer.get_config())
        self.optimizer = optimizer
        self.generator_optimizer = generator_optimizer

    def train_step(self, real_images):
        """Run ``d_steps`` discriminator updates followed by one generator update.

        Args:
            real_images: Batch of real images, or a tuple whose first element
                is that batch (e.g. ``(images, labels)``).

        Returns:
            Dict with the last discriminator loss (``'disc_loss'``) and the
            generator loss (``'gen_loss'``).
        """
        if isinstance(real_images, tuple):
            real_images = real_images[0]

        batch_size = tf.shape(real_images)[0]
        for _ in range(self.d_steps):
            # Fresh latents for every discriminator update.
            latent_input = tf.random.normal((batch_size, 1, 1, self.latent_dim))
            # The generator pass needs no tape: only discriminator variables
            # are differentiated in this step.
            gen_out = self.generator(latent_input, training = True)
            with tf.GradientTape() as disc_tape:
                disc_real_out = self.discriminator(real_images, training = True)
                disc_gen_out = self.discriminator(gen_out, training = True)

                # Computed inside the tape so the penalty's own gradient is
                # differentiable with respect to the discriminator weights.
                gp = self.gradient_penalty(real_images, gen_out)
                drf_loss = self.drift_loss(disc_real_out)
                disc_loss = self.discriminator_loss(disc_real_out, disc_gen_out) + gp + drf_loss

            disc_vars = self.discriminator.trainable_variables
            disc_grads = disc_tape.gradient(disc_loss, disc_vars)
            self.optimizer.apply_gradients(zip(disc_grads, disc_vars))

        latent_input = tf.random.normal((batch_size, 1, 1, self.latent_dim))
        with tf.GradientTape() as gen_tape:
            gen_out = self.generator(latent_input, training = True)
            disc_gen_out = self.discriminator(gen_out, training = True)
            gen_loss = self.generator_loss(disc_gen_out)

        gen_vars = self.generator.trainable_variables
        gen_grads = gen_tape.gradient(gen_loss, gen_vars)
        self.generator_optimizer.apply_gradients(zip(gen_grads, gen_vars))

        return {'disc_loss': disc_loss, 'gen_loss': gen_loss}

In [None]:
class Callback(tf.keras.callbacks.Callback):
    """Ramp the model's ``alpha`` linearly from 0 to 1 over one ``fit`` call.

    ``set_steps`` must be called before ``fit`` so the callback knows how many
    batches the run contains.
    """

    def __init__(self):
        super().__init__()
        self.n_epoch = 0

    def set_steps(self, steps_per_epoch, epochs):
        """Record the schedule length: ``steps_per_epoch * epochs`` batches in total."""
        self.steps_per_epoch = steps_per_epoch
        self.epochs = epochs
        self.steps = self.steps_per_epoch * self.epochs

    def on_epoch_begin(self, epoch, logs = None):
        """Remember the current epoch index for the per-batch alpha computation."""
        self.n_epoch = epoch

    def on_batch_begin(self, batch, logs = None):
        """Set ``alpha`` to the fraction of the run completed so far."""
        global_step = (self.n_epoch * self.steps_per_epoch) + batch
        # A one-step run would otherwise divide by zero.
        last_step = max(self.steps - 1, 1)
        alpha = min(global_step / float(last_step), 1.0)

        current = self.model.alpha
        if hasattr(current, 'assign'):
            # Update the variable in place: the WeightedSum layers hold a
            # reference to this object, so rebinding the attribute to a float
            # would leave them reading a stale value.
            current.assign(alpha)
        else:
            self.model.alpha = alpha

In [None]:
pgan = ProgressiveGAN()

data_path = r'E:\Image Datasets\Celeb A\Dataset\img_align_celeba'
# Map pixel values from [0, 255] to [-1, 1], the range of the generator's tanh output.
train_image_generator = tf.keras.preprocessing.image.ImageDataGenerator(preprocessing_function = lambda x: (tf.cast(x, tf.float32)/127.5) - 1)

# Per-resolution settings; entry i belongs to resolution 2**(i + 2).
BATCHES = [4, 4, 4, 4, 1, 1, 1, 1, 1]
EPOCHS = [1, 1, 1, 1, 1, 1, 1, 1, 1]
# steps_per_epoch = [1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000]
steps_per_epoch = [5, 5, 5, 5, 5, 5, 5, 5, 5]
# models['<res>x<res>'][phase] -> {'generator': ..., 'discriminator': ...}
models = {}

cbk = Callback()


def make_optimizer():
    """Return a fresh Adam optimizer with the training hyper-parameters."""
    return tf.keras.optimizers.Adam(learning_rate = 0.001, beta_1 = 0.0, beta_2 = 0.99, epsilon = 1e-8)


optimizer = make_optimizer()

for i in range(len(BATCHES)):
    res = 2**(i+2)
    models[f'{res}x{res}'] = {}
    train_data = train_image_generator.flow_from_directory(batch_size = BATCHES[i], directory = data_path, shuffle = True,
                                                           target_size = (res, res), class_mode = 'binary')
    if i != 0:
        pgan.grow_model()
        # The set of trainable variables changed, so start from a fresh
        # optimizer instead of reusing state built for the previous networks.
        optimizer = make_optimizer()

    # Fade-in phase (plain training at the first resolution).
    cbk.set_steps(steps_per_epoch[i], EPOCHS[i])
    pgan.compile(optimizer)
    pgan.fit(train_data, steps_per_epoch = steps_per_epoch[i], epochs = EPOCHS[i], callbacks = [cbk])

    models[f'{res}x{res}']['main'] = {}
    models[f'{res}x{res}']['main']['generator'] = pgan.generator
    models[f'{res}x{res}']['main']['discriminator'] = pgan.discriminator

    if i != 0:
        # Stabilisation phase: train the networks that use only the new block.
        pgan.stabilize_model()
        optimizer = make_optimizer()

        pgan.compile(optimizer)
        pgan.fit(train_data, steps_per_epoch = steps_per_epoch[i], epochs = EPOCHS[i], callbacks = [cbk])

        models[f'{res}x{res}']['stabilized'] = {}
        models[f'{res}x{res}']['stabilized']['generator'] = pgan.generator
        models[f'{res}x{res}']['stabilized']['discriminator'] = pgan.discriminator

In [None]:
# models

In [None]:
# Draw one sample from the 'main' generator stored for resolution r x r.
r = 512
sample_generator = models[f'{r}x{r}']['main']['generator']
sample_latent = tf.random.normal((1, 1, 1, 512))
# Map the generator output from [-1, 1] back to [0, 255] for display.
sample_image = ((sample_generator(sample_latent)[0]+1)*127.5).numpy().astype('uint8')
plt.imshow(sample_image)
plt.axis('off')
plt.show()

In [None]:
# gan = ProgressiveGAN()

In [None]:
# gan.grow_model()
# gan.stabilize_model()

In [None]:
# gan.stabilize_model()

In [None]:
# tf.keras.utils.plot_model(gan.generator, show_shapes = True, dpi = 64)
# tf.keras.utils.plot_model(gan.discriminator, 'model2.png', show_shapes = True, dpi = 64)

In [None]:
# gan.discriminator.summary()