# SA-GAN (Self-Attention GAN)

This is an attempt to re-implement the SA-GAN paper, "Self-Attention Generative Adversarial Networks".

Paper: https://arxiv.org/pdf/1805.08318.pdf

Other Resources: 
* https://github.com/taki0112/Self-Attention-GAN-Tensorflow
* https://github.com/brain-research/self-attention-gan

In [None]:
from glob import glob
import tensorflow as tf
from tensorflow.python.eager import def_function

In [None]:
# Dataset location and loader settings.
# Written as a raw string: the original literal contained the invalid escape
# sequence "\D", which newer Python versions warn about. The runtime value of
# the string is unchanged.
path = r'E://Image Datasets//Celeb A\Dataset//img_align_celeba//img_align_celeba'
num_examples = 1000  # maximum number of image files taken from the folder
batch_size = 64      # number of images per batch
img_size = 128       # images are resized to img_size x img_size

In [None]:
def load_files(file):
    """Read one image file and return it as a float32 tensor scaled to [-1, 1].

    Args:
        file: Path of the image file (string or scalar string tensor).

    Returns:
        A float32 tensor of shape (img_size, img_size, 3).
    """
    # The loader only globs '*.jpg' files, so decode with the JPEG decoder
    # rather than the PNG one.
    images = tf.io.decode_jpeg(tf.io.read_file(file), channels = 3)

    images = tf.cast(images, tf.float32)
    images = tf.image.resize(images, [img_size, img_size], tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    # Map pixel values from [0, 255] to [-1, 1], the range of the generator's
    # tanh output.
    images = (images/127.5) - 1

    return images

def load_data(path):
    """Build a shuffled, batched tf.data pipeline over the JPEGs in a folder.

    Args:
        path: Directory containing the '*.jpg' image files.

    Returns:
        A tf.data.Dataset yielding float32 image batches with at most
        batch_size images each (the last batch may be smaller).

    Raises:
        ValueError: If no '*.jpg' file is found under path.
    """
    files = glob(path + '//*.jpg')[:num_examples]
    if not files:
        raise ValueError(f'No .jpg files found in {path!r}')

    # from_tensor_slices takes the already-resolved paths literally (list_files
    # would re-interpret each of them as a glob pattern). Shuffling happens on
    # the filenames, before decoding, so the shuffle buffer holds short strings
    # instead of decoded images.
    file_ds = tf.data.Dataset.from_tensor_slices(files).shuffle(len(files))
    return file_ds.map(load_files).batch(batch_size)

# Build the input pipeline once; it is passed to trainer.train(...) further down.
dataset = load_data(path)

In [None]:
# Spectral normalization wrapper, adapted from the following references:
# https://gist.github.com/FloydHsiu/828eea345e1ca6950e05bb42f0a75b50
# https://gist.github.com/FloydHsiu/ab33c7d98d78f9873757810e2f8db50d
class SpectralNormalization(tf.keras.layers.Wrapper):
    """Wrapper that applies spectral normalization to the wrapped layer's kernel.

    On every training-mode call, one power-iteration step estimates the largest
    singular value ``sigma`` of the kernel (flattened to 2-D) and the kernel
    variable is divided by ``sigma`` in place before the wrapped layer runs.

    The wrapped layer must expose its kernel variable as ``weight`` (the custom
    layers in this file) or ``kernel`` (the stock Keras naming).
    """

    def __init__(self, layer, **kwargs):
        super().__init__(layer, **kwargs)

    def build(self, input_shape):
        """Build the wrapped layer and create the power-iteration vector ``u``.

        Args:
            input_shape: Shape of the inputs the wrapped layer will receive.

        Raises:
            ValueError: If the wrapped layer has neither a ``weight`` nor a
                ``kernel`` variable.
        """
        if not self.layer.built:
            self.layer.build(input_shape)
            # Mark the layer as built; otherwise its first real call may run
            # build() again and replace the variable referenced below.
            self.layer.built = True

        kernel = getattr(self.layer, 'weight', None)
        if kernel is None:
            kernel = getattr(self.layer, 'kernel', None)
        if kernel is None:
            raise ValueError(
                f'{type(self.layer).__name__} exposes neither a `weight` nor a '
                '`kernel` variable to normalize')

        # Keep a reference to the live variable, not a copy of its initial
        # value: normalizing a frozen snapshot would overwrite the kernel with
        # its initial weights on every step and discard the optimizer updates.
        self.w = kernel
        self.w_shape = self.w.shape.as_list()

        init = tf.keras.initializers.RandomNormal(mean = 0.0, stddev = 0.02)
        self.u = self.add_weight(shape = (1, self.w_shape[-1]), initializer = init,
                                 trainable = False, name = 'specrtal_norm_u')

    @tf.function
    def call(self, inputs, training = None):
        """Normalize the kernel (training mode only) and run the wrapped layer."""
        if training is None:
            training = tf.keras.backend.learning_phase()

        # Kept as an explicit equality test because learning_phase() is not
        # guaranteed to return a plain Python bool.
        if training == True:
            self._compute_weights()

        return self.layer(inputs)

    def _compute_weights(self):
        """Run one power-iteration step and divide the kernel by sigma in place.

        Gradients do not flow through sigma: ``u`` and ``v`` are wrapped in
        stop_gradient and the rescaled kernel is written back with ``assign``.
        """
        w_reshaped = tf.reshape(self.w, (-1, self.w_shape[-1]))
        eps = 1e-12  # guards the L2 normalizations against division by zero

        _u = tf.identity(self.u)
        _v = tf.matmul(_u, tf.transpose(w_reshaped, perm = [1, 0]))
        _v /= tf.maximum(tf.math.reduce_sum(_v ** 2) ** 0.5, eps)
        _u = tf.matmul(_v, w_reshaped)
        _u /= tf.maximum(tf.math.reduce_sum(_u ** 2) ** 0.5, eps)

        _u = tf.stop_gradient(_u)
        _v = tf.stop_gradient(_v)

        # Persist u so the next call continues the power iteration.
        self.u.assign(_u)
        # sigma = v W u^T, shape (1, 1); broadcasts against the kernel.
        sigma = tf.matmul(tf.matmul(_v, w_reshaped), tf.transpose(_u, perm = [1, 0]))

        self.w.assign(self.w / sigma)

In [None]:
class Linear(tf.keras.layers.Layer):
    """Fully connected layer computing ``matmul(inputs, weight) + bias``."""

    def __init__(self, neurons, **kwargs):
        super().__init__(**kwargs)
        # Number of output features.
        self.neurons = neurons

    def build(self, input_shape):
        """Create the (in_features, neurons) weight and the (1, neurons) bias."""
        fan_in = input_shape[-1]
        normal_init = tf.keras.initializers.RandomNormal(mean = 0.0, stddev = 0.02)

        self.weight = self.add_weight(
            name = 'weight',
            shape = (fan_in, self.neurons),
            initializer = normal_init,
            trainable = True,
        )
        self.bias = self.add_weight(
            name = 'bias',
            shape = (1, self.neurons),
            initializer = 'zeros',
            trainable = True,
        )

    def call(self, inputs):
        """Apply the affine transform to ``inputs``."""
        projected = tf.matmul(inputs, self.weight)
        return tf.add(projected, self.bias)
    
class SNLinear(tf.keras.layers.Layer):
    """Linear layer wrapped in SpectralNormalization."""

    def __init__(self, neurons, **kwargs):
        super().__init__(**kwargs)
        dense = Linear(neurons = neurons)
        self.linear_sn = SpectralNormalization(dense)

    def call(self, inputs):
        """Forward ``inputs`` through the wrapped linear layer."""
        outputs = self.linear_sn(inputs)
        return outputs

In [None]:
class Conv2D(tf.keras.layers.Layer):
    """2-D convolution with bias, built directly on ``tf.nn.conv2d``.

    Padding is chosen from the first kernel dimension: 'SAME' when
    ``(kernel_size[0] - 1)//2`` is non-zero, 'VALID' otherwise.
    """

    def __init__(self, filters, kernel_size, strides, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        self.kernel_size = kernel_size
        self.strides = strides

        half_width = (kernel_size[0] - 1)//2
        if half_width:
            self.padding = 'SAME'
        else:
            self.padding = 'VALID'

    def build(self, input_shape):
        """Create the kernel and a broadcastable (1, 1, 1, filters) bias."""
        in_channels = input_shape[-1]
        # kernel_size is concatenated with a tuple, so it must itself be a tuple.
        kernel_shape = self.kernel_size + (in_channels, self.filters)
        normal_init = tf.keras.initializers.RandomNormal(mean = 0.0, stddev = 0.02)

        self.weight = self.add_weight(
            name = 'kernel',
            shape = kernel_shape,
            initializer = normal_init,
            trainable = True,
        )
        self.bias = self.add_weight(
            name = 'bias',
            shape = (1, 1, 1, self.filters),
            initializer = 'zeros',
            trainable = True,
        )

    def call(self, inputs):
        """Convolve ``inputs`` with the kernel and add the bias."""
        convolved = tf.nn.conv2d(inputs, self.weight, self.strides, self.padding)
        return tf.add(convolved, self.bias)
    
    
class SNConv2D(tf.keras.layers.Layer):
    """Conv2D layer wrapped in SpectralNormalization."""

    def __init__(self, filters, kernel_size, strides, **kwargs):
        super().__init__(**kwargs)
        conv = Conv2D(filters = filters, kernel_size = kernel_size, strides = strides)
        self.conv_sn = SpectralNormalization(conv)

    def call(self, inputs):
        """Forward ``inputs`` through the wrapped convolution."""
        outputs = self.conv_sn(inputs)
        return outputs

In [None]:
class SelfAttention(tf.keras.layers.Layer):
    """Self-attention block over all spatial locations of a feature map.

    Computes ``gamma * v(softmax(g(x) f(x)^T) h(x)) + x`` where f, g, h and v
    are spectrally normalized 1x1 convolutions and ``gamma`` is a learnable
    scalar initialized to zero. f and g reduce the channels by a factor ``k``.
    """

    def __init__(self, k = 8, **kwargs):
        super().__init__(**kwargs)
        self.k = k

    def build(self, input_shape):
        """Create ``gamma`` and the four 1x1 projection layers."""
        self.inp_shp = input_shape

        self.gamma = self.add_weight(name = 'gamma', shape = (1, ),
                                     initializer = 'zeros', trainable = True)

        channels = self.inp_shp[-1]
        reduced = channels//self.k
        unit = (1, 1)
        self.f = SNConv2D(filters = reduced, kernel_size = unit, strides = unit)
        self.g = SNConv2D(filters = reduced, kernel_size = unit, strides = unit)
        self.h = SNConv2D(filters = channels, kernel_size = unit, strides = unit)
        self.v = SNConv2D(filters = channels, kernel_size = unit, strides = unit)

    def call(self, inputs):
        """Return ``inputs`` plus the gamma-scaled attention output."""
        height, width = self.inp_shp[1], self.inp_shp[2]
        location_num = height * width

        def to_sequence(feature_map):
            # (batch, H, W, C) -> (batch, H*W, C)
            return tf.reshape(feature_map, (-1, location_num, feature_map.shape[-1]))

        fx = self.f(inputs)
        gx = self.g(inputs)
        hx = self.h(inputs)

        queries = to_sequence(gx)
        keys = to_sequence(fx)
        values = to_sequence(hx)

        # Attention weights over all locations, normalized along the last axis.
        scores = tf.matmul(queries, keys, transpose_b = True)
        beta = tf.nn.softmax(scores, axis = -1)
        attended = tf.matmul(beta, values)

        attended = tf.reshape(attended, (-1, height, width, self.inp_shp[3]))
        attended = self.v(attended)

        return self.gamma * attended + inputs

class NonLocalBlock(tf.keras.layers.Layer):
    """Self-attention block with max-pooled keys and values.

    Queries come from every spatial location (``f``); keys (``g``) and values
    (``h``) are 2x2 max-pooled first, which shrinks the attention map. The
    attended features pass through a 1x1 convolution ``v`` and are added to the
    input, scaled by the learnable scalar ``sigma`` (initialized to zero).
    All four projections are spectrally normalized 1x1 convolutions.
    """

    def __init__(self, k = 8, **kwargs):
        super().__init__(**kwargs)
        # Channel reduction factor for the query/key projections.
        self.k = k

    def build(self, input_shape):
        """Create ``sigma``, the four 1x1 projections and the two pooling layers."""
        self.inp_shp = input_shape

        self.sigma = self.add_weight(shape = (1, ), initializer = 'zeros',
                                     trainable = True, name = 'sigma')

        chn = self.inp_shp[-1]
        self.f = SNConv2D(filters = chn//self.k, kernel_size = (1, 1), strides = (1, 1))
        self.g = SNConv2D(filters = chn//self.k, kernel_size = (1, 1), strides = (1, 1))
        self.h = SNConv2D(filters = chn//2, kernel_size = (1, 1), strides = (1, 1))
        self.v = SNConv2D(filters = chn, kernel_size = (1, 1), strides = (1, 1))

        self.max_pool_1 = tf.keras.layers.MaxPool2D()
        self.max_pool_2 = tf.keras.layers.MaxPool2D()

    def call(self, inputs):
        """Return ``inputs`` plus the sigma-scaled attention output."""
        height, width = self.inp_shp[1], self.inp_shp[2]
        location_num = height * width

        fx = self.f(inputs)
        fx = tf.reshape(fx, (-1, location_num, fx.shape[-1]))

        # The number of pooled locations is read from the pooled tensors
        # themselves: for odd heights/widths the pooling layer floors each axis,
        # so `location_num // 4` would not match and the reshape would fail.
        gx = self.max_pool_1(self.g(inputs))
        gx = tf.reshape(gx, (-1, gx.shape[1] * gx.shape[2], gx.shape[-1]))

        attn = tf.nn.softmax(tf.matmul(fx, gx, transpose_b = True), axis = -1)

        hx = self.max_pool_2(self.h(inputs))
        hx = tf.reshape(hx, (-1, hx.shape[1] * hx.shape[2], hx.shape[-1]))

        out = tf.matmul(attn, hx)
        # hx carries chn//2 channels (the filter count of self.h).
        out = tf.reshape(out, (-1, height, width, hx.shape[-1]))
        out = self.v(out)

        return self.sigma * out + inputs

In [None]:
class SAGAN(object):
    """Builds the SA-GAN generator and discriminator as Keras functional models."""

    def __init__(self, img_res = 128, latent_dim = 1042):
        # NOTE(review): Trainer defaults latent_dim to 1024, so 1042 here looks
        # like a transposition typo -- confirm before relying on this default.
        self.img_res = img_res
        self.latent_dim = latent_dim

    def up_sample_res_block(self, inp, filters):
        """Residual block that doubles the spatial resolution.

        Main path: BN -> ReLU -> 2x nearest upsample -> 3x3 SN conv
        -> BN -> ReLU -> 3x3 SN conv.
        Shortcut: 2x nearest upsample -> 1x1 SN conv.
        """
        layers = tf.keras.layers

        main = layers.BatchNormalization()(inp)
        main = layers.ReLU()(main)
        main = layers.UpSampling2D(size = (2, 2), interpolation = 'nearest')(main)
        main = SNConv2D(filters = filters, kernel_size = (3, 3), strides = (1, 1))(main)

        main = layers.BatchNormalization()(main)
        main = layers.ReLU()(main)
        main = SNConv2D(filters = filters, kernel_size = (3, 3), strides = (1, 1))(main)

        shortcut = layers.UpSampling2D(size = (2, 2), interpolation = 'nearest')(inp)
        shortcut = SNConv2D(filters = filters, kernel_size = (1, 1), strides = (1, 1))(shortcut)

        return shortcut + main

    def down_sample_res_block(self, inp, filters, down_sample = True):
        """Residual block that optionally halves the spatial resolution.

        Main path: LeakyReLU -> 3x3 SN conv -> LeakyReLU -> 3x3 SN conv
        (-> average pool when down_sample).
        Shortcut: identity, unless down-sampling or the channel count changes,
        in which case a 1x1 SN conv (-> average pool when down_sample) is used.
        """
        layers = tf.keras.layers

        main = layers.LeakyReLU(alpha = 0.2)(inp)
        main = SNConv2D(filters = filters, kernel_size = (3, 3), strides = (1, 1))(main)

        main = layers.LeakyReLU(alpha = 0.2)(main)
        main = SNConv2D(filters = filters, kernel_size = (3, 3), strides = (1, 1))(main)

        if down_sample:
            main = layers.AveragePooling2D()(main)

        shortcut = inp
        needs_projection = down_sample or (inp.shape[-1] != filters)
        if needs_projection:
            shortcut = SNConv2D(filters = filters, kernel_size = (1, 1), strides = (1, 1))(shortcut)
            if down_sample:
                shortcut = layers.AveragePooling2D()(shortcut)

        return main + shortcut

    def optimized_block(self, inp, filters):
        """First discriminator block: no activation before the first conv.

        Main path: 3x3 SN conv -> LeakyReLU -> 3x3 SN conv -> average pool.
        Shortcut: average pool -> 1x1 SN conv.
        """
        layers = tf.keras.layers

        main = SNConv2D(filters = filters, kernel_size = (3, 3), strides = (1, 1))(inp)
        main = layers.LeakyReLU(alpha = 0.2)(main)

        main = SNConv2D(filters = filters, kernel_size = (3, 3), strides = (1, 1))(main)
        main = layers.AveragePooling2D()(main)

        shortcut = layers.AveragePooling2D()(inp)
        shortcut = SNConv2D(filters = filters, kernel_size = (1, 1), strides = (1, 1))(shortcut)

        return main + shortcut

    def generator(self):
        """Return the generator model: latent vector -> tanh image.

        The latent vector is projected to a 4x4x1024 map and upsampled five
        times, so the output is 128x128x3 regardless of ``img_res``.
        """
        layers = tf.keras.layers
        latent = layers.Input(shape = (self.latent_dim, ), dtype = tf.float32, name = 'generator_input')

        feat = SNLinear(neurons = 4*4*1024)(latent)
        feat = layers.Reshape((4, 4, 1024))(feat)

        for width in (1024, 512, 256):
            feat = self.up_sample_res_block(inp = feat, filters = width)

        # Attention layer. SelfAttention() is the alternative the original
        # author describes as similar to NonLocalBlock().
        feat = NonLocalBlock()(feat)

        for width in (128, 64):
            feat = self.up_sample_res_block(inp = feat, filters = width)

        feat = layers.BatchNormalization()(feat)
        feat = layers.ReLU()(feat)
        feat = SNConv2D(filters = 3, kernel_size = (3, 3), strides = (1, 1))(feat)
        image = layers.Activation('tanh')(feat)

        return tf.keras.models.Model(latent, image, name = 'Generator')

    def discriminator(self):
        """Return the discriminator model: image -> one unbounded score per sample."""
        layers = tf.keras.layers
        image = layers.Input(shape = (self.img_res, self.img_res, 3), dtype = tf.float32,
                             name = 'discriminator_input')

        feat = self.optimized_block(inp = image, filters = 64)
        feat = self.down_sample_res_block(inp = feat, filters = 128)

        # Attention layer. SelfAttention() is the alternative the original
        # author describes as similar to NonLocalBlock().
        feat = NonLocalBlock()(feat)

        for width in (256, 512, 1024):
            feat = self.down_sample_res_block(inp = feat, filters = width)
        feat = self.down_sample_res_block(inp = feat, filters = 1024, down_sample = False)

        feat = layers.LeakyReLU(alpha = 0.2)(feat)
        # Global average over the two spatial axes.
        feat = tf.math.reduce_mean(feat, axis = [1, 2])
        score = SNLinear(neurons = 1)(feat)

        return tf.keras.models.Model(image, score, name = 'Discriminator')

In [None]:
class Losses(object):
    """Hinge losses for the discriminator and the generator."""

    def discriminator_loss(self, disc_real_out, disc_gen_out):
        """Penalize real scores below +1 and generated scores above -1."""
        real_margin = disc_real_out - 1.0
        gen_margin = -1.0 - disc_gen_out

        real_loss = -tf.math.reduce_mean(tf.minimum(0.0, real_margin))
        gen_loss = -tf.math.reduce_mean(tf.minimum(0.0, gen_margin))

        total = real_loss + gen_loss
        return total

    def generator_loss(self, disc_gen_out):
        """Negative mean discriminator score of the generated samples."""
        mean_score = tf.math.reduce_mean(disc_gen_out)
        return -mean_score

In [None]:
class Trainer(object):
    """Owns the SA-GAN models, their optimizers and the training loop.

    Args:
        img_res: Side length of the images fed to the discriminator.
        latent_dim: Size of the generator's latent vector.
        gen_lr: Learning rate of the generator's Adam optimizer.
        disc_lr: Learning rate of the discriminator's Adam optimizer.
    """

    def __init__(self, img_res = 128, latent_dim = 1024, gen_lr = 1e-4, disc_lr = 4e-4):
        self.latent_dim = latent_dim

        self.gen_opt = tf.keras.optimizers.Adam(learning_rate = gen_lr, beta_1 = 0.0, beta_2 = 0.9)
        self.disc_opt = tf.keras.optimizers.Adam(learning_rate = disc_lr, beta_1 = 0.0, beta_2 = 0.9)

        gan = SAGAN(img_res = img_res, latent_dim = latent_dim)
        self.generator = gan.generator()
        self.discriminator = gan.discriminator()

        self.losses = Losses()

    def generate_latent_noise(self, batch_size):
        """Sample a (batch_size, latent_dim) batch of standard-normal noise."""
        return tf.random.normal((batch_size, self.latent_dim))

    @tf.function
    def train_step(self, real_img):
        """Run one simultaneous generator/discriminator update.

        Args:
            real_img: Batch of real images.

        Returns:
            Tuple ``(gen_loss, disc_loss)`` for this batch.
        """
        # tf.shape gives the batch size even when the static batch dimension
        # is unknown (real_img.shape[0] would then be None).
        latent_inp = self.generate_latent_noise(tf.shape(real_img)[0])

        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            gen_out = self.generator(latent_inp, training = True)

            disc_real_out = self.discriminator(real_img, training = True)
            disc_gen_out = self.discriminator(gen_out, training = True)

            gen_loss = self.losses.generator_loss(disc_gen_out)
            disc_loss = self.losses.discriminator_loss(disc_real_out, disc_gen_out)

        gen_grads = gen_tape.gradient(gen_loss, self.generator.trainable_variables)
        self.gen_opt.apply_gradients(zip(gen_grads, self.generator.trainable_variables))

        disc_grads = disc_tape.gradient(disc_loss, self.discriminator.trainable_variables)
        self.disc_opt.apply_gradients(zip(disc_grads, self.discriminator.trainable_variables))

        return gen_loss, disc_loss

    def train(self, data, epochs = 1):
        """Train for ``epochs`` passes over ``data``.

        Args:
            data: Iterable of real-image batches (e.g. a tf.data.Dataset).
            epochs: Number of passes over ``data``.

        Returns:
            Dict with keys 'gen_losses' and 'disc_losses'; each list holds the
            loss of the last batch of every epoch.

        Raises:
            ValueError: If ``data`` yields no batch during an epoch.
        """
        gen_losses, disc_losses = [], []
        for e in range(epochs):
            print(f'Epoch: {e} Starts')
            # Reset per epoch so that an empty (or exhausted) dataset is
            # detected instead of hitting unbound or stale loss values.
            gen_loss = disc_loss = None
            for img in data:
                gen_loss, disc_loss = self.train_step(img)
                print('.', end = '')

            if gen_loss is None:
                raise ValueError(f'Epoch {e}: the dataset yielded no batches')

            gen_losses.append(gen_loss)
            disc_losses.append(disc_loss)
            print(f'\nGenerator Loss: {gen_loss} \t Discriminator Loss: {disc_loss}')
            print(f'Epoch: {e} Ends\n')

        return {'gen_losses': gen_losses, 'disc_losses': disc_losses}

In [None]:
# Build the models and optimizers with Trainer's defaults
# (img_res = 128, latent_dim = 1024, gen_lr = 1e-4, disc_lr = 4e-4).
trainer = Trainer()

In [None]:
# Train for the default single epoch; the result is a dict with the keys
# 'gen_losses' and 'disc_losses'.
training_losses = trainer.train(dataset)