# GeneGAN

This is an attempt to re-implement the GeneGAN paper.

Paper: https://arxiv.org/pdf/1705.04932v1.pdf

Other Resources: 
* https://github.com/Prinsphield/GeneGAN

In [1]:
import tensorflow as tf

In [40]:
class AddNoise(tf.keras.layers.Layer):
    """Add a learnable per-channel offset to the input.

    Despite the name, no random noise is injected: the layer owns one
    trainable weight per channel (the last axis), initialised to zero, and
    adds it to every position of the input via broadcasting.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        inp_filters = input_shape[-1]
        # Singleton dimensions on every axis except the last, so the weight
        # broadcasts over all leading axes. For rank-4 input this is
        # (1, 1, 1, C), the same shape the layer always used.
        bias_shape = (1,) * (len(input_shape) - 1) + (inp_filters,)
        self.B = self.add_weight(shape = bias_shape, initializer = 'zeros',
                                 trainable = True, name = 'bias')

    def call(self, inputs):
        return tf.add(inputs, self.B)

class GAN(object):
    """Factory for the GeneGAN encoder, decoder and discriminator models.

    Each sub-network is built on first access and then cached, so repeated
    attribute access (e.g. ``gan.encoder``) returns the same
    ``tf.keras.Model`` with the same weights rather than a fresh copy.

    Args:
        img_shape: shape of a single image without the batch dimension,
            e.g. ``(H, W, C)``. The encoder halves H and W three times
            (stride-2 convolutions, 'same' padding) and the decoder doubles
            them three times, so H and W must be multiples of 8 for the
            decoder output to have the same size as the input.
    """

    def __init__(self, img_shape):
        self.img_shape = img_shape
        self._encoder = None
        self._decoder = None
        self._discriminator = None
        # Shape (without batch dim) of the first encoder output; the second
        # output has the same shape, so the decoder uses it for both inputs.
        self.encoded_shape = self.encoder.output_shape[0][1:]

    @property
    def encoder(self):
        """Encoder model: image -> [background code, object code] (cached)."""
        if self._encoder is None:
            self._encoder = self._build_encoder()
        return self._encoder

    @property
    def decoder(self):
        """Decoder model: [background code, object code] -> image (cached)."""
        if self._decoder is None:
            self._decoder = self._build_decoder()
        return self._decoder

    @property
    def discriminator(self):
        """Discriminator model: image -> one unactivated logit (cached)."""
        if self._discriminator is None:
            self._discriminator = self._build_discriminator()
        return self._discriminator

    def _build_encoder(self):
        """Build the encoder: three stride-2 convs, output split on channels."""
        inp = tf.keras.layers.Input(shape = self.img_shape, dtype = tf.float32, name = 'encoder_input')

        x = tf.keras.layers.Conv2D(filters = 128, kernel_size = (4, 4), strides = (2, 2), padding = 'same')(inp)
        x = tf.keras.layers.LeakyReLU(alpha = 0.2)(x)

        x = tf.keras.layers.Conv2D(filters = 256, kernel_size = (4, 4), strides = (2, 2), padding = 'same')(x)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.LeakyReLU(alpha = 0.2)(x)

        x = tf.keras.layers.Conv2D(filters = 512, kernel_size = (4, 4), strides = (2, 2), padding = 'same')(x)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.LeakyReLU(alpha = 0.2)(x)

        # The 512 output channels are split in half: out_1 (first 256)
        # holds background info and out_2 (last 256) holds object info.
        out_1, out_2 = x[:, :, :, :256], x[:, :, :, 256:]
        return tf.keras.models.Model(inp, [out_1, out_2], name = 'encoder')

    def _build_decoder(self):
        """Build the decoder: concat both codes, three stride-2 transposed convs."""
        inp_bckgr_encoded = tf.keras.layers.Input(shape = self.encoded_shape, dtype = tf.float32,
                                                  name = 'decoder_background_info_input')
        inp_object_encoded = tf.keras.layers.Input(shape = self.encoded_shape, dtype = tf.float32,
                                                   name = 'decoder_object_info_input')

        x = tf.keras.layers.Concatenate(axis = -1)([inp_bckgr_encoded, inp_object_encoded])

        x = tf.keras.layers.Conv2DTranspose(filters = 512, kernel_size = (4, 4), strides = (2, 2), padding = 'same')(x)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.ReLU()(x)

        x = tf.keras.layers.Conv2DTranspose(filters = 256, kernel_size = (4, 4), strides = (2, 2), padding = 'same')(x)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.ReLU()(x)

        # One output channel per image channel, so the decoder reproduces
        # images with the same channel count as the encoder's input.
        x = tf.keras.layers.Conv2DTranspose(filters = self.img_shape[-1], kernel_size = (4, 4),
                                            strides = (2, 2), padding = 'same')(x)
        x = AddNoise()(x)

        return tf.keras.models.Model([inp_bckgr_encoded, inp_object_encoded], x, name = 'Decoder')

    def _build_discriminator(self):
        """Build the discriminator: four stride-2 convs and a 1-unit dense head."""
        inp = tf.keras.layers.Input(shape = self.img_shape, dtype = tf.float32, name = 'discriminator_input')

        x = tf.keras.layers.Conv2D(filters = 128, kernel_size = (4, 4), strides = (2, 2), padding = 'same')(inp)
        x = tf.keras.layers.LeakyReLU(alpha = 0.2)(x)

        x = tf.keras.layers.Conv2D(filters = 256, kernel_size = (4, 4), strides = (2, 2), padding = 'same')(x)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.LeakyReLU(alpha = 0.2)(x)

        x = tf.keras.layers.Conv2D(filters = 512, kernel_size = (4, 4), strides = (2, 2), padding = 'same')(x)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.LeakyReLU(alpha = 0.2)(x)

        x = tf.keras.layers.Conv2D(filters = 512, kernel_size = (4, 4), strides = (2, 2), padding = 'same')(x)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.LeakyReLU(alpha = 0.2)(x)

        x = tf.keras.layers.Flatten()(x)
        # No sigmoid: the output is a raw logit, intended for a cross-entropy
        # loss configured with from_logits=True.
        x = tf.keras.layers.Dense(units = 1)(x)

        return tf.keras.models.Model(inp, x, name = 'Discriminator')

In [41]:
class Losses(object):
    """Loss terms for GeneGAN training.

    Image-space terms are mean absolute (L1) errors; adversarial terms are
    binary cross-entropy on raw (unactivated) discriminator logits.
    """

    def __init__(self):
        # Discriminator outputs are logits, hence from_logits=True.
        self.bce = tf.keras.losses.BinaryCrossentropy(from_logits = True)

    def reconstruction_loss(self, real, re_gen):
        """Mean absolute error between ``real`` and its reconstruction ``re_gen``."""
        # tf.math.abs takes a single tensor (its second parameter is `name`),
        # so the difference must be formed explicitly.
        return tf.math.reduce_mean(tf.math.abs(real - re_gen))

    def nulling_loss(self, eps):
        """Mean absolute value of ``eps``; drives that code towards zero."""
        return tf.math.reduce_mean(tf.math.abs(eps))

    def gen_log_loss(self, disc_out):
        """Generator adversarial loss: BCE of ``disc_out`` against an all-ones (real) target."""
        return self.bce(tf.ones_like(disc_out), disc_out)

    def parallelogram_loss(self, au, b0, a0, bu):
        """Mean absolute value of ``au + b0 - a0 - bu``."""
        return tf.math.reduce_mean(tf.math.abs(au + b0 - a0 - bu))

    def disc_log_loss(self, disc_out_1, disc_out_2):
        """Discriminator BCE loss.

        Args:
            disc_out_1: logits for REAL samples (target 1).
            disc_out_2: logits for GENERATED samples (target 0).

        Returns:
            Sum of the real-sample and generated-sample BCE terms.
        """
        real_loss = self.bce(tf.ones_like(disc_out_1), disc_out_1)
        fake_loss = self.bce(tf.zeros_like(disc_out_2), disc_out_2)
        return real_loss + fake_loss

In [42]:
class Trainer(object):
    """Trains the GeneGAN generator (encoder + decoder) and discriminator.

    Args:
        img_shape: shape of one image without the batch dimension.
        learning_rate: RMSprop learning rate used by both optimizers.
    """

    def __init__(self, img_shape, learning_rate = 5e-5, **kwargs):
        super().__init__(**kwargs)
        self.gen_opt = tf.keras.optimizers.RMSprop(learning_rate = learning_rate)
        self.disc_opt = tf.keras.optimizers.RMSprop(learning_rate = learning_rate)

        gan = GAN(img_shape = img_shape)
        self.encoder = gan.encoder
        self.decoder = gan.decoder
        self.discriminator = gan.discriminator

        self.losses = Losses()

    def train_step(self, img_au, img_b0):
        """Run one generator update and one discriminator update.

        Args:
            img_au: batch of images with the object present (encoded as A, u).
            img_b0: batch of images without it (encoded as B, eps; eps is
                pushed to zero by the nulling loss).

        Returns:
            ``(gen_loss, disc_loss)`` for this batch.
        """
        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            # Split each image into (background code, object code).
            gen_a, gen_u = self.encoder(img_au, training = True)
            gen_b, gen_eps = self.encoder(img_b0, training = True)
            no_object = tf.zeros_like(gen_eps)

            # Cross-over: strip the object from A, transplant u onto B.
            gen_a0 = self.decoder([gen_a, no_object], training = True)
            gen_bu = self.decoder([gen_b, gen_u], training = True)

            # Reconstructions of the two inputs.
            re_gen_au = self.decoder([gen_a, gen_u], training = True)
            re_gen_b0 = self.decoder([gen_b, no_object], training = True)

            disc_au = self.discriminator(img_au, training = True)
            disc_a0 = self.discriminator(gen_a0, training = True)
            disc_b0 = self.discriminator(img_b0, training = True)
            disc_bu = self.discriminator(gen_bu, training = True)

            gen_loss = self.losses.reconstruction_loss(img_au, re_gen_au)
            gen_loss += self.losses.reconstruction_loss(img_b0, re_gen_b0)
            gen_loss += (self.losses.gen_log_loss(disc_a0) + self.losses.gen_log_loss(disc_bu))
            gen_loss += self.losses.nulling_loss(gen_eps)
            gen_loss += self.losses.parallelogram_loss(img_au, img_b0, gen_a0, gen_bu)

            # disc_log_loss(real, fake): real Au vs generated Bu, and
            # real B0 vs generated A0.
            disc_loss = self.losses.disc_log_loss(disc_au, disc_bu)
            disc_loss += self.losses.disc_log_loss(disc_b0, disc_a0)

        gen_params = self.encoder.trainable_variables + self.decoder.trainable_variables
        gen_grads = gen_tape.gradient(gen_loss, gen_params)
        self.gen_opt.apply_gradients(zip(gen_grads, gen_params))

        disc_params = self.discriminator.trainable_variables
        disc_grads = disc_tape.gradient(disc_loss, disc_params)
        self.disc_opt.apply_gradients(zip(disc_grads, disc_params))

        return gen_loss, disc_loss

    def train(self, data, epochs = 1):
        """Train for ``epochs`` passes over ``data``.

        Args:
            data: iterable of ``(img_au, img_b0)`` batches; it is iterated
                again for every epoch.
            epochs: number of passes over ``data``.

        Returns:
            ``{'gen_losses': [...], 'disc_losses': [...]}`` holding the last
            batch's losses of each epoch.

        Raises:
            ValueError: if ``data`` yields no batches in an epoch (e.g. it is
                empty, or a one-shot iterator exhausted by an earlier epoch).
        """
        gen_losses, disc_losses = [], []
        for e in range(epochs):
            print(f'Epoch: {e} Starts')
            # Reset per epoch so an empty pass is detected instead of
            # silently re-recording the previous epoch's losses.
            gen_loss = disc_loss = None
            for img_au, img_b0 in data:
                gen_loss, disc_loss = self.train_step(img_au, img_b0)
                print('.', end = '')

            if gen_loss is None:
                raise ValueError(f'data yielded no batches in epoch {e}')

            gen_losses.append(gen_loss)
            disc_losses.append(disc_loss)
            print(f'\nGenerator Loss: {gen_loss} \t Discriminator Loss: {disc_loss}')
            print(f'Epoch: {e} Ends\n')

        return {'gen_losses': gen_losses, 'disc_losses': disc_losses}
    