# Star Gan

This is an attempt to re-implement the paper Star GAN

Paper: https://arxiv.org/pdf/1711.09020.pdf

Other Resources: 
* https://github.com/yunjey/stargan

In [1]:
import tensorflow as tf

In [2]:
config = {
    'img_shape': (128, 128, 3),
    'num_classes': 3
}

In [3]:
class InstanceNormalization(tf.keras.layers.Layer):
    def __init__(self, epsilon = 1e-6, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon
    
    def build(self, input_shape):
        inp_chn = input_shape[-1]
        
        init = tf.keras.initializers.RandomNormal(mean = 0.0, stddev = 1.0)
        self.gamma = self.add_weight(shape = (1, 1, 1, inp_chn), initializer = init, 
                                     trainable = True, name = 'gamma')
        self.beta = self.add_weight(shape = (1, 1, 1, inp_chn), initializer = 'zeros', 
                                    trainable = True, name = 'beta')
        
    def call(self, inputs):
        mean = tf.math.reduce_mean(inputs, axis = [1, 2], keepdims = True)
        rstd = tf.math.rsqrt(tf.math.reduce_variance(inputs) + self.epsilon)
        
        out = self.gamma * ((inputs - mean) * rstd) + self.beta
        return out

In [4]:
class GAN(object):
    def __init__(self, config):
        self.img_shape = config['img_shape']
        self.num_classes = config['num_classes']
        
    def __reshape_lbl(self, x):
        x = tf.repeat(x, self.img_shape[0]*self.img_shape[1])
        x = tf.reshape(x, (-1, self.img_shape[0], self.img_shape[1], self.num_classes))
        return x
    
    def __residual_block(self, inp, filters = 256, kernel_size = (3, 3), strides = (1, 1), padding = 'same'):

        x = tf.keras.layers.Conv2D(filters = filters, kernel_size = kernel_size, strides = strides, padding = padding)(inp)
        x = InstanceNormalization()(x)
        x = tf.keras.layers.ReLU()(x)
        
        x = tf.keras.layers.Conv2D(filters = filters, kernel_size = kernel_size, strides = strides, padding = padding)(x)
        x = InstanceNormalization()(x)
        x = tf.keras.layers.ReLU()(x)
        
        x = tf.keras.layers.Add()([x, inp])
        return x
      
    @property
    def generator(self):
        inp_img = tf.keras.layers.Input(shape = self.img_shape, dtype = tf.float32, 
                                        name = f'generator_img_input')
        inp_lbl = tf.keras.layers.Input(shape = self.num_classes, dtype = tf.float32, 
                                        name = f'generator_label_input')
        
        lbl = tf.keras.layers.Lambda(self.__reshape_lbl)(inp_lbl)
        x = tf.keras.layers.Concatenate(axis = -1)([inp_img, lbl])
        
        # Downsampling
        x = tf.keras.layers.Conv2D(filters = 64, kernel_size = (7, 7), strides = (1, 1), padding = 'same')(x)
        x = InstanceNormalization()(x)
        x = tf.keras.layers.ReLU()(x)
        
        x = tf.keras.layers.Conv2D(filters = 128, kernel_size = (4, 4), strides = (2, 2), padding = 'same')(x)
        x = InstanceNormalization()(x)
        x = tf.keras.layers.ReLU()(x)
        
        x = tf.keras.layers.Conv2D(filters = 256, kernel_size = (4, 4), strides = (2, 2), padding = 'same')(x)
        x = InstanceNormalization()(x)
        x = tf.keras.layers.ReLU()(x)
        
        # Bottleneck
        for _ in range(6):
            x = self.__residual_block(x, filters = 256, kernel_size = (3, 3), strides = (1, 1), padding = 'same')
            
        
        # UpSampling
        x = tf.keras.layers.Conv2DTranspose(filters = 128, kernel_size = (4, 4), strides = (2, 2), padding = 'same')(x)
        x = InstanceNormalization()(x)
        x = tf.keras.layers.ReLU()(x)
        
        x = tf.keras.layers.Conv2DTranspose(filters = 64, kernel_size = (4, 4), strides = (2, 2), padding = 'same')(x)
        x = InstanceNormalization()(x)
        x = tf.keras.layers.ReLU()(x)
        
        
        x = tf.keras.layers.Conv2D(filters = 3, kernel_size = (7, 7), strides = (1, 1), padding = 'same')(x)
        x = tf.keras.layers.Activation('tanh')(x)
        
        return tf.keras.models.Model(inputs = [inp_img, inp_lbl], outputs = x, name = f'Generator')
    
    
    @property
    def discriminator(self):
        inp = tf.keras.layers.Input(shape = self.img_shape, dtype = tf.float32, name = f'discriminator')
        
        x = tf.keras.layers.Conv2D(filters = 64, kernel_size = (4, 4), strides = (2, 2), padding = 'same')(inp)
        x = tf.keras.layers.LeakyReLU(alpha = 0.2)(x)
        
        x = tf.keras.layers.Conv2D(filters = 128, kernel_size = (4, 4), strides = (2, 2), padding = 'same')(x)
        x = tf.keras.layers.LeakyReLU(alpha = 0.2)(x)
        
        x = tf.keras.layers.Conv2D(filters = 256, kernel_size = (4, 4), strides = (2, 2), padding = 'same')(x)
        x = tf.keras.layers.LeakyReLU(alpha = 0.2)(x)
        
        x = tf.keras.layers.Conv2D(filters = 512, kernel_size = (4, 4), strides = (2, 2), padding = 'same')(x)
        x = tf.keras.layers.LeakyReLU(alpha = 0.2)(x)
        
        x = tf.keras.layers.Conv2D(filters = 1024, kernel_size = (4, 4), strides = (2, 2), padding = 'same')(x)
        x = tf.keras.layers.LeakyReLU(alpha = 0.2)(x)
        
        x = tf.keras.layers.Conv2D(filters = 2048, kernel_size = (4, 4), strides = (2, 2), padding = 'same')(x)
        x = tf.keras.layers.LeakyReLU(alpha = 0.2)(x)
        
        D_src = tf.keras.layers.Conv2D(filters = 1, kernel_size = (3, 3), strides = (1, 1), padding = 'same')(x)
        
        D_cls = tf.keras.layers.Conv2D(filters = self.num_classes, 
                                       kernel_size = (self.img_shape[0]//64, self.img_shape[1]//64), 
                                       strides = (1, 1), padding = 'valid')(x)
        D_cls = tf.keras.layers.Reshape((self.num_classes, ))(D_cls)
        
        return tf.keras.models.Model(inp, [D_src, D_cls], name = 'Discriminator')

In [5]:
gan = GAN(config)
discriminator = gan.discriminator
generator = gan.generator

In [6]:
class Losses(object):
    def __init__(self, loss_type):
        self.bce = tf.keras.losses.BinaryCrossentropy(from_logits = True)
        if loss_type == 'sparse':
            self.cce = tf.keras.losses.SparseCategoricalCrossentropy(from_logits = True)
        if loss_type == 'categorical':
            self.cce = tf.keras.losses.CategoricalCrossentropy(from_logits = True)
    
    def gradeint_penalty(self, discriminator, real, gen, gp_weight = 10.0):
        
        epsilon = tf.random.uniform((real.shape[0], 1, 1, 1))
        interpolated = ((1 - epsilon) * real) + (epsilon * gen)
        
        with tf.GradientTape() as gp_tape:
            gp_tape.watch(interpolated)
            out = discriminator(interpolated)
        grads = gp_tape.gradient(out, [interpolated])[0]
        norm = tf.math.sqrt(tf.math.reduce_mean(tf.math.square(grads), axis = [1, 2, 3], keepdims = True))
        gp = tf.math.reduce_mean(tf.square(norm - 1.0)) * gp_weight
        return gp
    
    def disc_wgan_loss(self, disc_real_out_src, disc_gen_out_src):
        return tf.math.reduce_mean(disc_gen_out_src) - tf.math.reduce_mean(disc_real_out_src)
    
    def gen_wgan_loss(self, disc_gen_out_src):
        return -tf.math.reduce_mean(disc_gen_out_src)
    
    def adversarial_loss_disc(self, disc_real_out_src, disc_gen_out_src):
        real_loss = self.bce(tf.ones_like(disc_real_out_src), disc_real_out_src)
        gen_loss = self.bce(tf.zeros_like(disc_gen_out_src), disc_gen_out_src)
        return real_loss + gen_loss
    
    def adversarial_loss_gen(self, disc_gen_out_src):
        return self.bce(tf.ones_like(disc_gen_out_src), disc_gen_out_src)
        
    def reconstruction_loss(self, real, re_gen):
        return tf.math.reduce_mean(tf.math.abs(real - re_gen))
    
    def cls_loss(self, tar_cls, disc_out_cls):
        return self.cce(tar_cls, disc_out_cls)  

In [7]:
class Trainer(object):
    def __init__(self, gan, config = None, n_disc = 5, learning_rate = 1e-4, 
                 lambda_cls = 1.0, lambda_rec = 10.0, loss_type = 'wgan', cat_loss_type = 'categorical'):
        self.n_disc = n_disc
        self.lambda_cls = lambda_cls
        self.lambda_rec = lambda_rec
        self.loss_type = loss_type
        
        self.gen_optimizer = tf.keras.optimizers.Adam(learning_rate = learning_rate, beta_1 = 0.5, beta_2 = 0.999)
        self.disc_optimizer = tf.keras.optimizers.Adam(learning_rate = learning_rate, beta_1 = 0.5, beta_2 = 0.999)
        
        if callable(gan):
            if config is not None:
                self.gan = gan(config)
            else:
                raise Exception('missing argument `config`')
        else:
            self.gan = gan
            
        self.generator = gan.generator
        self.discriminator = gan.discriminator
        
        self.losses = Losses(cat_loss_type)
        
    def train(self, data, epochs = 1):
        gen_losses, disc_losses = [], []
        for e in range(epochs):
            print(f'Epoch: {e} Starts')
            for real_img, real_cls, tar_img, tar_cls in data:
                gen_loss, disc_loss = self.train_step(real_img, real_cls, tar_img, tar_cls)
                print('.', end = '')
                
            gen_losses.append(gen_loss)
            disc_losses.append(disc_loss)
            print(f'\nGenerator Loss: {gen_loss} \t Discriminator Loss: {disc_loss}')
            print(f'Epoch: {e} Ends\n')
            
        return {'gen_losses': gen_losses, 'disc_losses': disc_losses}
        
    def generator_loss(self, disc_gen_out_src, real_img, re_gen_img, disc_gen_out_cls, tar_cls, loss_type = 'wgan'):
        if self.loss_type == 'wgan':
            gen_loss = self.losses.gen_wgan_loss(disc_gen_out_src)
        elif self.loss_type == 'adverserial':
            gen_loss = self.losses.adversarial_loss_gen(disc_gen_out_src)
        
        recon_loss = self.losses.reconstruction_loss(real_img, re_gen_img)
        gen_cls_loss = self.losses.cls_loss(tar_cls, disc_gen_out_cls)
        
        loss = gen_loss + self.lambda_cls * gen_cls_loss + self.lambda_rec * recon_loss
        return loss
    
    def discriminator_loss(self, tar_img, real_cls, gen_out, disc_real_out_src, disc_gen_out_src, disc_real_out_cls, 
                           loss_type = 'wgan'):
        if self.loss_type == 'wgan':
            disc_loss = self.losses.disc_wgan_loss(disc_real_out_src, disc_gen_out_src)
            disc_loss += self.losses.gradeint_penalty(self.discriminator, tar_img, gen_out)
        elif self.loss_type == 'adverserial':
            disc_loss = self.losses.adversarial_loss_gen(disc_real_out_src, disc_gen_out_src)
        
        disc_cls_loss = self.losses.cls_loss(real_cls, disc_real_out_cls)
        
        loss = disc_loss + self.lambda_cls * disc_cls_loss
        return loss
        
    @tf.function
    def train_step(self, real_img, real_cls, tar_img, tar_cls):
        
        for _ in range(self.n_disc):
            with tf.GradientTape() as disc_tape:
                gen_out = self.generator([real_img, tar_cls], training = True)
                
                disc_real_out_src, disc_real_out_cls = self.discriminator(real_img, training = True)
                disc_gen_out_src, disc_gen_out_cls = self.discriminator(gen_out, training = True)
                
                disc_loss = self.discriminator_loss(tar_img, real_cls, gen_out, disc_real_out_src, 
                                                    disc_gen_out_src, disc_real_out_cls)
                
            disc_grads = disc_tape.gradient(disc_loss, self.discriminator.trainable_variables)
            self.disc_optimizer.apply_gradients(zip(disc_grads, self.discriminator.trainable_variables))
            
        with tf.GradientTape() as gen_tape:
            gen_out = self.generator([real_img, tar_cls], training = True)
            re_gen_out = self.generator([gen_out, real_cls], training = True)

            disc_gen_out_src, disc_gen_out_cls = self.discriminator(gen_out, training = True)
                
            gen_loss = self.generator_loss(disc_gen_out_src, real_img, re_gen_out, disc_gen_out_cls, tar_cls)  
        
        gen_grads = gen_tape.gradient(gen_loss, self.generator.trainable_variables)
        self.gen_optimizer.apply_gradients(zip(gen_grads, self.generator.trainable_variables))
        
        return disc_loss, gen_loss

In [8]:
ri, rl = tf.random.uniform((120, 128, 128, 3)), tf.random.uniform((120, 3))
ti, tl = tf.random.uniform((120, 128, 128, 3)), tf.random.uniform((120, 3))

In [9]:
dataset = tf.data.Dataset.from_tensor_slices((ri, rl, ti, tl)).shuffle(1200).batch(2)

In [10]:
t = Trainer(GAN(config))

In [11]:
t.train(dataset, 1)

Epoch: 0 Starts
............................................................
Generator Loss: -2172282.25 	 Discriminator Loss: 27368048.0
Epoch: 0 Ends



{'gen_losses': [<tf.Tensor: shape=(), dtype=float32, numpy=-2172282.2>],
 'disc_losses': [<tf.Tensor: shape=(), dtype=float32, numpy=27368048.0>]}