# MSG-GAN (Multi-Scale Gradient GAN)

This notebook is an attempt to re-implement the MSG-GAN paper using TensorFlow 2 / Keras.

Paper: https://arxiv.org/pdf/1903.06048.pdf

Other Resources: 
* https://github.com/akanimax/msg-gan-v1

In [None]:
from glob import glob
import tensorflow as tf

In [None]:
# Side length, in pixels, that load_files resizes every training image to.
res_img = 1024
# Number of images per batch produced by load_data.
batch_size = 1

In [None]:
def load_files(file):
    """Load one image file as a float32 tensor scaled for a tanh generator.

    The file is read from disk, decoded to 3 channels, resized to
    ``res_img x res_img`` with nearest-neighbour sampling, cast to float32
    and mapped with ``x / 127.5 - 1`` (0..255 pixel values land in [-1, 1],
    assuming the decoder's default 8-bit output).

    NOTE(review): decoding goes through ``tf.io.decode_png`` while the
    dataset pattern used in this notebook matches ``*.jpg``; TensorFlow
    documents that this op also accepts JPEG data -- confirm for the
    installed version.

    Args:
        file: Path of the image file (string or scalar string tensor).

    Returns:
        A ``(res_img, res_img, 3)`` float32 tensor.
    """
    raw_bytes = tf.io.read_file(file)
    decoded = tf.io.decode_png(raw_bytes, channels=3)

    # Nearest-neighbour keeps the original pixel values (no interpolation).
    resized = tf.image.resize(
        decoded,
        size=[res_img, res_img],
        method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
    )
    as_float = tf.cast(resized, tf.float32)

    return (as_float / 127.5) - 1

In [None]:
def load_data(path, num_examples = 1000):
    """Build a shuffled, batched ``tf.data`` pipeline of decoded images.

    Args:
        path: Glob pattern selecting the image files.
        num_examples: Maximum number of matched files to use.

    Returns:
        A ``tf.data.Dataset`` yielding batches of ``batch_size`` images as
        produced by ``load_files``.

    Raises:
        ValueError: If the pattern (after truncation to ``num_examples``)
            selects no files.
    """
    files = glob(path)[:num_examples]
    if not files:
        raise ValueError(f'No image files selected for pattern: {path!r}')

    # The paths are already expanded, so feed them directly instead of
    # re-globbing each one through Dataset.list_files.
    paths = tf.data.Dataset.from_tensor_slices(files)

    # Shuffle the (tiny) path strings *before* decoding: shuffling after the
    # map would keep up to `num_examples` full-resolution float images in the
    # shuffle buffer.  A buffer as large as the file list gives a uniform
    # shuffle that is redrawn on every pass over the dataset.
    return paths.shuffle(len(files)).map(load_files).batch(batch_size)

In [None]:
# Training data: images matched by the local CelebA glob below (load_data caps the file count at its default num_examples).
dataset = load_data('E://Image Datasets//Celeb A//Dataset//img_align_celeba//img_align_celeba//*.jpg')

In [None]:
class MinibatchStddev(tf.keras.layers.Layer):
    """Append a minibatch standard-deviation statistic as one extra channel.

    The per-position standard deviation is taken across the batch axis,
    averaged down to a single scalar, and that scalar is broadcast into a
    one-channel map concatenated onto the input.

    Args:
        epsilon: Small constant added to the variance before the square root.
    """

    def __init__(self, epsilon = 1e-8, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def call(self, inputs):
        """Return ``inputs`` with one extra trailing channel (4-D input)."""
        dims = tf.shape(inputs)

        # Variance of every feature position, measured across the batch.
        batch_mean = tf.math.reduce_mean(inputs, axis=0, keepdims=True)
        squared_dev = tf.math.square(inputs - batch_mean)
        variance = tf.math.reduce_mean(squared_dev, axis=0, keepdims=True)
        stddev = tf.math.sqrt(variance + self.epsilon)

        # Collapse to one scalar; keepdims leaves it with one entry per axis
        # so it can be tiled straight into a feature map.
        scalar_std = tf.math.reduce_mean(stddev, keepdims=True)
        multiples = (dims[0], dims[1], dims[2], 1)
        stat_map = tf.tile(scalar_std, multiples)

        return tf.concat([inputs, stat_map], axis=-1)

In [None]:
class ToRGB(tf.keras.layers.Layer):
    """Project a feature map to a 3-channel image squashed by ``tanh``."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # 1x1 convolution: mixes channels per pixel, spatial size unchanged.
        self.conv = tf.keras.layers.Conv2D(
            filters=3,
            kernel_size=(1, 1),
            strides=(1, 1),
            padding='same',
        )
        self.act = tf.keras.layers.Activation('tanh')

    def call(self, inputs):
        """Apply the 1x1 convolution followed by tanh."""
        projected = self.conv(inputs)
        return self.act(projected)
    
    
class FromRGB(tf.keras.layers.Layer):
    """Lift an input to ``filters`` channels with a 1x1 conv + LeakyReLU(0.2).

    Args:
        filters: Number of output channels of the 1x1 convolution.
    """

    def __init__(self, filters, **kwargs):
        super().__init__(**kwargs)
        # 1x1 convolution: mixes channels per pixel, spatial size unchanged.
        self.conv = tf.keras.layers.Conv2D(
            filters=filters,
            kernel_size=(1, 1),
            strides=(1, 1),
            padding='same',
        )
        self.act = tf.keras.layers.LeakyReLU(alpha=0.2)

    def call(self, inputs):
        """Apply the 1x1 convolution followed by the leaky ReLU."""
        lifted = self.conv(inputs)
        return self.act(lifted)

In [None]:
class GenBlock(tf.keras.layers.Layer):
    """Two 3x3 conv + LeakyReLU(0.2) stages with optional resampling.

    Args:
        filters: Number of output channels of both convolutions.
        sampling: ``'up'`` applies 2x bilinear upsampling *before* the
            convolutions, ``'down'`` applies average pooling (Keras default
            pool size) *after* them, ``None`` leaves the resolution unchanged.

    Raises:
        ValueError: If ``sampling`` is not ``None``, ``'up'`` or ``'down'``.
            Previously an unknown value silently produced a block without
            any resampling.
    """

    _SAMPLING_MODES = (None, 'up', 'down')

    def __init__(self, filters, sampling = None, **kwargs):
        super().__init__(**kwargs)
        if sampling not in self._SAMPLING_MODES:
            raise ValueError(
                f"sampling must be None, 'up' or 'down', got {sampling!r}"
            )
        self.sampling = sampling

        if sampling == 'up':
            self.up_sample = tf.keras.layers.UpSampling2D(size = (2, 2), interpolation = 'bilinear')
        elif sampling == 'down':
            self.down_sample = tf.keras.layers.AveragePooling2D()

        self.conv_1 = tf.keras.layers.Conv2D(filters = filters, kernel_size = (3, 3), strides = (1, 1), padding = 'same')
        self.act_1 = tf.keras.layers.LeakyReLU(alpha = 0.2)
        self.conv_2 = tf.keras.layers.Conv2D(filters = filters, kernel_size = (3, 3), strides = (1, 1), padding = 'same')
        self.act_2 = tf.keras.layers.LeakyReLU(alpha = 0.2)

    def call(self, x):
        """Run (upsample) -> conv/act -> conv/act -> (downsample)."""
        if self.sampling == 'up':
            x = self.up_sample(x)
        x = self.act_1(self.conv_1(x))
        x = self.act_2(self.conv_2(x))
        if self.sampling == 'down':
            x = self.down_sample(x)
        return x

In [None]:
def generator(latent_dim = 512, filters = None):
    """Build the multi-scale generator as a Keras functional model.

    The latent input of shape ``(1, 1, latent_dim)`` goes through a 4x4
    transposed convolution and a 3x3 convolution; every further entry of
    ``filters`` adds an upsampling ``GenBlock``.  A ``ToRGB`` head is attached
    after the initial stage and after every block, so the model returns one
    image per scale, ordered from the lowest to the highest resolution.

    Args:
        latent_dim: Channel count of the latent input.  It is also used as
            the filter count of the initial 3x3 convolution.
        filters: Filter count per stage; defaults to
            ``[512, 512, 512, 512, 256, 128, 64, 32, 16]``.

    Returns:
        A ``tf.keras.models.Model`` mapping the latent tensor to a list of
        RGB tensors, one per stage.
    """
    if filters is None:
        filters = [512, 512, 512, 512, 256, 128, 64, 32, 16]

    latent = tf.keras.layers.Input(
        shape=(1, 1, latent_dim),
        dtype=tf.float32,
        name='generator_latent_input',
    )

    # Initial stage: transposed conv on the 1x1 latent, then a 3x3 conv.
    features = tf.keras.layers.Conv2DTranspose(filters=filters[0], kernel_size=(4, 4))(latent)
    features = tf.keras.layers.LeakyReLU(alpha=0.2)(features)
    features = tf.keras.layers.Conv2D(
        filters=latent_dim, kernel_size=(3, 3), strides=(1, 1), padding='same',
    )(features)
    features = tf.keras.layers.LeakyReLU(alpha=0.2)(features)
    images = [ToRGB()(features)]

    # Every remaining stage upsamples and emits its own RGB image.
    for width in filters[1:]:
        features = GenBlock(width, sampling='up')(features)
        images.append(ToRGB()(features))

    return tf.keras.models.Model(latent, images)

def discriminator(n_inps = 9, combine_function = 'simple', logits = True):# np.log2(1024) - 1
    """Build the multi-scale discriminator as a Keras functional model.

    The model takes ``n_inps`` RGB images with resolutions 4x4, 8x8, ...,
    ``2**(n_inps + 1)`` squared (lowest first).  Processing starts at the
    highest resolution; after each downsampling block the next lower
    resolution image is merged in according to ``combine_function``.

    Args:
        n_inps: Number of image scales (1 to 9, limited by the filter table).
        combine_function: How a lower-resolution image is merged with the
            running features:
            ``'simple'``  - concatenate the raw image;
            ``'lin_cat'`` - pass the image through ``FromRGB`` first, then
            concatenate;
            ``'cat_lin'`` - concatenate, then pass the result through
            ``FromRGB``.
        logits: When True (the default) a sigmoid is applied to the final
            one-channel output; when False the raw convolution output is
            returned.  NOTE: this reads inverted relative to the usual
            meaning of "logits"; the behaviour is kept as-is because existing
            callers depend on it.

    Returns:
        A ``tf.keras.models.Model`` mapping the list of images to a
        ``(batch, 1, 1, 1)`` score tensor.

    Raises:
        ValueError: If ``n_inps`` is outside 1..9 or ``combine_function`` is
            not one of the supported names (an unknown name used to skip the
            multi-scale merging silently).
    """
    filters = [512, 512, 512, 512, 256, 128, 64, 32, 16]
    combine_modes = ('simple', 'lin_cat', 'cat_lin')

    # Validate before any layer is created so mistakes fail fast and clearly.
    if not 1 <= n_inps <= len(filters):
        raise ValueError(
            f'n_inps must be between 1 and {len(filters)}, got {n_inps!r}'
        )
    if combine_function not in combine_modes:
        raise ValueError(
            f'combine_function must be one of {combine_modes}, got {combine_function!r}'
        )

    inputs = []
    for level in range(n_inps):
        res = 2**(level+2)
        inputs.append(tf.keras.layers.Input(shape = (res, res, 3), dtype = tf.float32,
                                            name = f'discriminator_input_{res}x{res}x3'))

    # Walk from the highest resolution down to the 4x4 level.
    x = inputs[-1]
    for i in range(n_inps - 1, 0, -1):
        x = FromRGB(filters[i])(x)
        x = GenBlock(filters = filters[i], sampling = 'down')(x)

        lower_res = inputs[i - 1]
        if combine_function == 'simple':
            x = tf.keras.layers.Concatenate()([lower_res, x])
        elif combine_function == 'lin_cat':
            x = tf.keras.layers.Concatenate()([FromRGB(filters[i])(lower_res), x])
        else:  # 'cat_lin'
            x = FromRGB(filters[i])(tf.keras.layers.Concatenate()([lower_res, x]))

    # Final 4x4 head.  The width is taken explicitly from the filter table
    # instead of from a loop variable left over after the loops above (both
    # candidate entries of the table hold the same value, so the resulting
    # model is unchanged).
    head_filters = filters[0]
    x = MinibatchStddev()(x)
    x = tf.keras.layers.Conv2D(filters = head_filters, kernel_size = (3, 3), strides = (1, 1), padding = 'same')(x)
    x = tf.keras.layers.LeakyReLU(alpha = 0.2)(x)
    # 'valid' 4x4 convolution collapses the 4x4 map to a single position.
    x = tf.keras.layers.Conv2D(filters = head_filters, kernel_size = (4, 4), strides = (1, 1), padding = 'valid')(x)
    x = tf.keras.layers.LeakyReLU(alpha = 0.2)(x)

    x = tf.keras.layers.Conv2D(filters = 1, kernel_size = (1, 1), strides = (1, 1), padding = 'same')(x)

    if logits:
        x = tf.keras.layers.Activation('sigmoid')(x)

    return tf.keras.models.Model(inputs, x)

In [None]:
class Losses(object):
    """Collection of GAN objectives selected by name.

    Recognised names: any name containing ``'wgan'`` (Wasserstein critic
    loss), ``'lsgan'`` (least squares) and ``'adversarial'`` (log loss on
    probabilities).  When several names are configured their terms are
    summed.  Names matching none of these contribute nothing.

    Args:
        loss_types: A single name, or a list/tuple of names.

    Raises:
        TypeError: If ``loss_types`` is not a string, list or tuple.
    """

    # Lower bound applied before taking logs so a saturated sigmoid
    # (exactly 0 or 1) cannot produce -inf / NaN losses.
    _LOG_FLOOR = 1e-8

    def __init__(self, loss_types = ('wgan_gp', 'adversarial')):
        if isinstance(loss_types, str):
            self.loss_types = [loss_types]
        elif isinstance(loss_types, (list, tuple)):
            # Copy so instances never share (or mutate) the caller's list or
            # the default value.
            self.loss_types = list(loss_types)
        else:
            raise TypeError('Invalid Losses.')

    @classmethod
    def _safe_log(cls, values):
        """Natural log with the argument floored at ``_LOG_FLOOR``."""
        return tf.math.log(tf.math.maximum(values, cls._LOG_FLOOR))

    def gradient_penalty(self, disc, real, gen, lambda_ = 10.0):
        """WGAN-GP penalty evaluated on random real/generated interpolates.

        Args:
            disc: Discriminator callable taking a list of image tensors.
            real: List of real image tensors, one per scale.
            gen: List of generated image tensors, matching ``real``.
            lambda_: Penalty weight.

        Returns:
            Scalar: sum over scales of
            ``lambda_ * mean((||grad||_2 - 1) ** 2)``.
        """
        interpolated = []
        for r, g in zip(real, gen):
            # Dynamic batch size so this also works when the static batch
            # dimension is unknown.  A fresh epsilon is drawn per scale.
            epsilon = tf.random.uniform((tf.shape(r)[0], 1, 1, 1), 0.0, 1.0)
            interpolated.append(((1.0 - epsilon) * r) + (epsilon * g))

        with tf.GradientTape() as gp_tape:
            gp_tape.watch(interpolated)
            out = disc(interpolated)
        grads = gp_tape.gradient(out, interpolated)

        gp = 0.0
        for grad in grads:
            if grad is None:
                # This scale does not influence the discriminator output.
                continue
            # Per-sample L2 norm: the *sum* of squares over all non-batch
            # axes.  (Averaging instead would give an RMS value that shrinks
            # with image size, which is not the WGAN-GP penalty.)
            norm = tf.math.sqrt(tf.math.reduce_sum(tf.math.square(grad), axis = [1, 2, 3]) + 1e-8)
            gp += tf.math.reduce_mean(tf.square(norm - 1.0)) * lambda_

        return gp

    def discriminator_loss(self, disc_real_out, disc_gen_out):
        """Sum the discriminator terms of every configured loss type.

        Args:
            disc_real_out: Discriminator output for real images.
            disc_gen_out: Discriminator output for generated images.

        Returns:
            Scalar loss (``0`` if no configured type is recognised).
        """
        loss = 0
        for loss_type in self.loss_types:
            if 'wgan' in loss_type:
                loss += tf.math.reduce_mean(disc_gen_out) - tf.math.reduce_mean(disc_real_out)

            elif loss_type == 'lsgan':
                loss += tf.math.reduce_mean(tf.math.square(tf.ones_like(disc_real_out) - disc_real_out))
                loss += tf.math.reduce_mean(tf.math.square(tf.zeros_like(disc_gen_out) - disc_gen_out))

            elif loss_type == 'adversarial':
                # -(E[log D(real)] + E[log(1 - D(fake))])
                loss -= (tf.math.reduce_mean(self._safe_log(disc_real_out))
                         + tf.math.reduce_mean(self._safe_log(1.0 - disc_gen_out)))

        return loss

    def generator_loss(self, disc_gen_out):
        """Sum the generator terms of every configured loss type.

        Args:
            disc_gen_out: Discriminator output for generated images.

        Returns:
            Scalar loss (``0`` if no configured type is recognised).
        """
        loss = 0
        for loss_type in self.loss_types:
            if 'wgan' in loss_type:
                loss -= tf.math.reduce_mean(disc_gen_out)

            elif loss_type == 'lsgan':
                loss += tf.math.reduce_mean(tf.math.square(tf.ones_like(disc_gen_out) - disc_gen_out))

            elif loss_type == 'adversarial':
                loss -= tf.math.reduce_mean(self._safe_log(disc_gen_out))

        return loss
In [None]:
class Trainer(object):
    """Owns the generator, discriminator, optimizers and the training loop.

    Args:
        latent_dim: Channel count of the generator's latent input.
        learning_rate: Learning rate of both RMSprop optimizers.
        loss_types: A loss name or list/tuple of names, forwarded to
            ``Losses``.  Any name containing ``'gp'`` additionally enables
            the gradient penalty on the discriminator loss.
    """

    def __init__(self, latent_dim = 512, learning_rate = 0.003, loss_types = ('wgan_gp', 'adversarial')):
        self.latent_dim = latent_dim
        self.gen_opt = tf.keras.optimizers.RMSprop(learning_rate = learning_rate)
        self.disc_opt = tf.keras.optimizers.RMSprop(learning_rate = learning_rate)

        self.generator = generator(latent_dim = latent_dim)
        self.discriminator = discriminator()

        self.losses = Losses(loss_types = loss_types)
        # Use the sequence normalised by Losses: a bare string such as
        # 'wgan_gp' would otherwise be iterated character by character in
        # train_step and the gradient penalty would never be applied.
        self.loss_types = self.losses.loss_types

    def generate_diff_dim_imgs(self, img, till = 4):
        """Build an image pyramid by repeated 2x average pooling.

        Args:
            img: Batch of square images, ``(batch, size, size, channels)``.
            till: Side length of the smallest image to produce.

        Returns:
            List of tensors from the input resolution down to ``till``.

        Raises:
            ValueError: If the images are not square, ``till`` is smaller
                than 1, or halving the size never lands exactly on ``till``.
        """
        if img.shape[1] != img.shape[2]:
            raise ValueError(f'Expected square images, got shape {img.shape}')
        if till < 1:
            raise ValueError(f'till must be at least 1, got {till!r}')

        outs = [img]
        # '>' rather than '!=' so the loop always terminates, even for sizes
        # that never hit `till` exactly.
        while outs[-1].shape[1] > till:
            outs.append(tf.keras.layers.AveragePooling2D()(outs[-1]))

        if outs[-1].shape[1] != till:
            raise ValueError(
                f'Image size {img.shape[1]} cannot be halved down to exactly {till}'
            )
        return outs

    def generate_latent_noise(self, batch_size):
        """Sample standard-normal latents of shape (batch_size, 1, 1, latent_dim)."""
        return tf.random.normal((batch_size, 1, 1, self.latent_dim))

    @tf.function
    def train_step(self, img):
        """Run one simultaneous generator / discriminator update.

        Args:
            img: Batch of real images at full resolution.

        Returns:
            Tuple ``(gen_loss, disc_loss)`` for this batch.
        """
        # Lowest resolution first, matching the generator's output order.
        real_imgs = self.generate_diff_dim_imgs(img)[::-1]
        latent_inp = self.generate_latent_noise(img.shape[0])

        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            gen_outs = self.generator(latent_inp, training = True)

            if len(real_imgs) != len(gen_outs):
                raise ValueError(
                    f'Scale count mismatch: {len(real_imgs)} real vs {len(gen_outs)} generated'
                )
            disc_real_outs = self.discriminator(real_imgs, training = True)
            disc_gen_outs = self.discriminator(gen_outs, training = True)

            gen_loss = self.losses.generator_loss(disc_gen_outs)
            disc_loss = self.losses.discriminator_loss(disc_real_outs, disc_gen_outs)

            # Computed inside the tape so the penalty itself is differentiated
            # with respect to the discriminator weights.  Applied once, however
            # many configured names mention 'gp'.
            if any('gp' in loss_type for loss_type in self.loss_types):
                disc_loss += self.losses.gradient_penalty(self.discriminator, real_imgs, gen_outs)

        gen_grads = gen_tape.gradient(gen_loss, self.generator.trainable_variables)
        self.gen_opt.apply_gradients(zip(gen_grads, self.generator.trainable_variables))

        disc_grads = disc_tape.gradient(disc_loss, self.discriminator.trainable_variables)
        self.disc_opt.apply_gradients(zip(disc_grads, self.discriminator.trainable_variables))

        return gen_loss, disc_loss

    def train(self, data, epochs = 1):
        """Train for ``epochs`` passes over ``data``.

        Args:
            data: Iterable of image batches.
            epochs: Number of passes over ``data``.

        Returns:
            Dict with keys ``'gen_losses'`` and ``'disc_losses'``; each list
            holds the loss of the *last* batch of every epoch.

        Raises:
            ValueError: If a pass over ``data`` yields no batches.
        """
        gen_losses, disc_losses = [], []
        for e in range(epochs):
            print(f'Epoch: {e} Starts')
            # Reset per epoch so an empty pass is detected instead of hitting
            # an unbound name or silently re-recording a stale loss.
            gen_loss = disc_loss = None
            for img in data:
                gen_loss, disc_loss = self.train_step(img)
                print('.', end = '')

            if gen_loss is None:
                raise ValueError('The dataset yielded no batches; nothing to train on.')

            gen_losses.append(gen_loss)
            disc_losses.append(disc_loss)
            print(f'\nGenerator Loss: {gen_loss} \t Discriminator Loss: {disc_loss}')
            print(f'Epoch: {e} Ends\n')

        return {'gen_losses': gen_losses, 'disc_losses': disc_losses}

In [None]:
# Build the trainer with its default latent size, learning rate and loss types.
trainer = Trainer()

In [None]:
# Train for one epoch; the result is a dict with 'gen_losses' and 'disc_losses' lists.
training_losses = trainer.train(dataset, 1)