In [None]:
# Importing modules
import tensorflow as tf
from tensorflow import keras
import numpy as np
from matplotlib import pyplot as plt
import os
from keras.optimizers import Adam
from keras.layers import Input, Dense, Flatten, Reshape, LeakyReLU, UpSampling2D, AveragePooling2D, Layer, Add 
from keras.initializers import RandomNormal
from keras.models import Model


In [None]:
# Hyper-parameters shared by the cells below.
Batch_size = 16   # number of images per dataset batch
EPSILON = 1e-8    # added before square roots in the custom layers; also Adam's epsilon
LR = 1e-3         # learning rate given to both Adam optimizers


In [None]:
# Custom layers and helpers used later
class PixelNorm(Layer):
    """Pixel-wise feature normalization.

    Every spatial position's feature vector is divided by the square root of
    its mean squared value (plus ``EPSILON`` for numerical stability).
    """

    def __init__(self, **kwargs):
        super(PixelNorm, self).__init__(**kwargs)

    def call(self, inputs):
        """Return ``inputs / sqrt(mean(inputs**2, axis=-1) + EPSILON)``."""
        # The convolution layers in this file call tf.nn.conv2d with its
        # default (channels-last) layout, so the feature axis is the last one.
        # Reducing over axis=1 would average over image rows instead of over
        # the features of a single pixel.
        mean_square = keras.backend.mean(
            keras.backend.square(inputs), axis=-1, keepdims=True
        )
        return inputs / keras.backend.sqrt(mean_square + EPSILON)

    def compute_output_shape(self, input_shape):
        """The normalization is element-wise, so the shape is unchanged."""
        return input_shape



class Minibatchstdev(Layer):
    """Append one extra feature map holding a minibatch standard-deviation statistic."""

    def __init__(self, **kwargs):
        super(Minibatchstdev, self).__init__(**kwargs)

    def call(self, inputs):
        """Concatenate the batch-wide average standard deviation as a new last-axis channel."""
        K = keras.backend

        # Standard deviation of every element, measured across the batch axis.
        batch_mean = K.mean(inputs, axis=0, keepdims=True)
        variance = K.mean(K.square(inputs - batch_mean), axis=0, keepdims=True)
        stdev = K.sqrt(variance + EPSILON)

        # Average all of those deviations into a single value (dims kept),
        # then repeat it over the batch and both spatial axes.
        avg_stdev = K.mean(stdev, keepdims=True)
        dims = K.shape(inputs)
        extra_map = K.tile(avg_stdev, (dims[0], dims[1], dims[2], 1))

        return K.concatenate([inputs, extra_map], axis=-1)

    def compute_output_shape(self, input_shape):
        """Same as the input shape with the last dimension increased by one."""
        dims = list(input_shape)
        dims[-1] += 1
        return tuple(dims)


class WeighedSum(Add):
    """Merge layer that blends exactly two inputs: ``(1 - alpha) * first + alpha * second``."""

    def __init__(self, alpha=0.0, **kwargs):
        super(WeighedSum, self).__init__(**kwargs)
        # Stored as a backend variable so update_fadein can change it with set_value.
        self.alpha = keras.backend.variable(alpha, name="ws_Alpha")

    def _merge_function(self, inputs):
        """Return the alpha-weighted blend of the two tensors in ``inputs``."""
        assert len(inputs) == 2
        first, second = inputs
        return (1.0 - self.alpha) * first + self.alpha * second
    


def update_fadein(models, step, n_steps):
    """Set the blend factor of every ``WeighedSum`` layer found in ``models``.

    Args:
        models: iterable of Keras models whose ``layers`` are scanned.
        step: index of the current training step.
        n_steps: number of steps over which alpha ramps linearly towards 1.

    Returns:
        The alpha that was applied: ``min(step / (n_steps - 1), 1)``, or 1.0
        when ``n_steps`` is 1 or less.
    """
    if n_steps <= 1:
        # A ramp with a single step has nothing to interpolate; treat the new
        # block as fully faded in instead of dividing by zero.
        alpha = 1.0
    else:
        alpha = min(step / float(n_steps - 1), 1)

    for model in models:
        for layer in model.layers:
            if isinstance(layer, WeighedSum):
                keras.backend.set_value(layer.alpha, alpha)
    return alpha



In [None]:
# EQUALIZED LEARNING RATE
class EqualizedConv(Layer):
    """Stride-1 2-D convolution whose kernel is multiplied by ``sqrt(gain / fan_in)`` on every call.

    Args:
        out_channels: number of output feature maps.
        kernal: side length of the square kernel.
        gain: numerator of the scaling factor.
        padding: "valid" or "same" (upper-cased before being passed to TensorFlow).
    """

    def __init__(self, out_channels, kernal, gain=2, padding="valid", **kwargs):
        super().__init__(**kwargs)
        self.out_channels = out_channels
        self.kernal = kernal
        self.gain = gain
        self.pad = kernal != 1
        self.padding = padding.upper()

    def build(self, input_shape):
        """Create the N(0, 1)-initialised kernel, the zero bias and the scale factor."""
        self.in_channels = input_shape[-1]
        kernel_shape = [self.kernal, self.kernal, self.in_channels, self.out_channels]

        self.w = self.add_weight(
            shape=kernel_shape,
            initializer=keras.initializers.RandomNormal(mean=0.0, stddev=1.0),
            trainable=True,
            name="kernal",
        )
        self.b = self.add_weight(
            shape=(self.out_channels,),
            initializer="zeros",
            trainable=True,
            name="bias",
        )

        # Number of input values contributing to one output value.
        fan_in = self.kernal * self.kernal * self.in_channels
        self.scale = tf.sqrt(self.gain / fan_in)

    def call(self, inputs):
        """Convolve ``inputs`` with the rescaled kernel and add the bias."""
        scaled_kernel = self.scale * self.w
        features = tf.nn.conv2d(inputs, scaled_kernel, strides=1, padding=self.padding)
        return features + self.b


class EqualizedConvT(Layer):
    """Stride-1 transposed 2-D convolution with a kernel rescaled by ``sqrt(gain / fan_in)``.

    Args:
        out_channels: number of output feature maps.
        kernal: side length of the square kernel.
        gain: numerator of the scaling factor.
        padding: "valid" or "same" (upper-cased before being passed to TensorFlow).
    """

    def __init__(self, out_channels, kernal, gain=2, padding="valid", **kwargs):
        super().__init__(**kwargs)
        self.kernal = kernal
        self.out_channels = out_channels
        self.gain = gain
        self.pad = kernal != 1
        self.padding = padding.upper()

    def build(self, input_shape):
        """Create the N(0, 1)-initialised kernel, the zero bias and the scale factor."""
        self.in_channels = input_shape[-1]
        initializer = keras.initializers.RandomNormal(mean=0.0, stddev=1.0)
        # tf.nn.conv2d_transpose expects filters laid out as
        # [height, width, output_channels, in_channels]. The previous
        # [k, k, in, out] layout only worked when in_channels == out_channels,
        # in which case the shape is unchanged by this fix.
        self.w = self.add_weight(
            shape=[self.kernal, self.kernal, self.out_channels, self.in_channels],
            initializer=initializer,
            trainable=True,
            name="kernal",
        )
        self.b = self.add_weight(
            shape=(self.out_channels,),
            initializer="zeros",
            trainable=True,
            name="bias",
        )

        fan_in = self.kernal * self.kernal * self.in_channels
        self.scale = tf.sqrt(self.gain / fan_in)

    def call(self, inputs):
        """Apply the transposed convolution and add the bias.

        The output size is derived from the input instead of being fixed at
        4x4: with stride 1, "VALID" grows each spatial side by ``kernal - 1``
        and "SAME" keeps it. A 1x1 input with ``kernal=4`` and "VALID" padding
        therefore still yields 4x4.
        """
        static_shape = inputs.shape
        dynamic_shape = tf.shape(inputs)
        growth = self.kernal - 1 if self.padding == "VALID" else 0

        def _out_size(axis):
            # Prefer the static size so downstream layers keep a known shape.
            size = static_shape[axis]
            if size is None:
                size = dynamic_shape[axis]
            return size + growth

        output_shape = [dynamic_shape[0], _out_size(1), _out_size(2), self.out_channels]
        output = tf.nn.conv2d_transpose(
            inputs,
            self.scale * self.w,
            output_shape=output_shape,
            strides=1,
            padding=self.padding,
        )
        return output + self.b



class EqualizedDense(Layer):
    """Fully connected layer whose weight matrix is multiplied by ``sqrt(gain / fan_in)`` on every call.

    Args:
        units: number of output features.
        gain: numerator of the scaling factor.
    """

    def __init__(self, units, gain=2, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.gain = gain

    def build(self, input_shape):
        """Create the N(0, 1)-initialised weights, the zero bias and the scale factor."""
        self.in_channels = input_shape[-1]

        self.w = self.add_weight(
            shape=[self.in_channels, self.units],
            initializer=keras.initializers.RandomNormal(mean=0.0, stddev=1.0),
            trainable=True,
            name="kernal",
        )
        self.b = self.add_weight(
            shape=(self.units,),
            initializer="zeros",
            trainable=True,
            name="bias",
        )

        # For a dense layer the fan-in is simply the number of input features.
        self.scale = tf.sqrt(self.gain / self.in_channels)

    def call(self, inputs):
        """Return ``inputs @ (scale * w) + b``."""
        projected = tf.matmul(inputs, self.scale * self.w)
        return tf.add(projected, self.b)



In [None]:
# define generator model
def define_generator(latent_dim):
    """Build the base generator mapping a latent vector to a 4x4 three-channel image.

    Args:
        latent_dim: length of the latent input vector.

    Returns:
        A Keras ``Model`` taking a ``(latent_dim,)`` input.
    """
    in_latent = Input(shape=(latent_dim,))
    # Reshape to a 1x1 "image" with latent_dim channels. The size follows the
    # argument rather than a fixed 512, so other latent sizes work too.
    g = Reshape((1, 1, latent_dim))(in_latent)
    g = PixelNorm()(g)

    # 4x4 block: transposed convolution expands 1x1 -> 4x4
    g = EqualizedConvT(512, kernal=4, padding="valid")(g)
    g = LeakyReLU(alpha=0.2)(g)
    g = PixelNorm()(g)

    # 3x3 convolution block
    g = EqualizedConv(512, kernal=3, padding="same")(g)
    g = LeakyReLU(alpha=0.2)(g)
    g = PixelNorm()(g)

    # 1x1 output block producing 3 channels
    out_img = EqualizedConv(3, kernal=1, padding="same")(g)
    model = Model(inputs=[in_latent], outputs=out_img)
    return model



In [None]:
# define discriminator model
def define_discriminator(input_shape=(4,4,3)):
    """Build the base discriminator that maps an image to a single unbounded score.

    Args:
        input_shape: (height, width, channels) of the input image.

    Returns:
        A Keras ``Model`` with one output unit.
    """

    def conv_lrelu(tensor, kernal, padding):
        # 512-filter equalized convolution followed by LeakyReLU(0.2).
        tensor = EqualizedConv(512, kernal=kernal, padding=padding)(tensor)
        return LeakyReLU(alpha=0.2)(tensor)

    in_img = Input(shape=input_shape)

    # 1x1 convolution applied directly to the input image
    features = conv_lrelu(in_img, kernal=1, padding="same")

    # extra feature map carrying the minibatch standard deviation
    features = Minibatchstdev()(features)

    features = conv_lrelu(features, kernal=3, padding="same")
    # 4x4 "valid" convolution (collapses a 4x4 map to 1x1)
    features = conv_lrelu(features, kernal=4, padding="valid")

    flat = Flatten()(features)
    out_class = EqualizedDense(1)(flat)

    return Model(inputs=[in_img], outputs=out_class)

In [None]:
# Adding a generator block
def add_generator_block(old_model, filters):
    """Grow a generator by one block that doubles its output resolution.

    Args:
        old_model: generator whose last layer is its output convolution and
            whose second-to-last layer ends the final feature block.
        filters: number of filters for the two new 3x3 convolutions.

    Returns:
        ``[model1, model2]``: ``model1`` runs straight through the new block;
        ``model2`` blends the old output layer (applied to the upsampled
        features) with the new output through a ``WeighedSum`` layer.
    """
    # Features feeding the old output layer.
    block_end = old_model.layers[-2].output

    upsampling = UpSampling2D()(block_end)
    g = EqualizedConv(filters, kernal=3, padding="same")(upsampling)
    g = LeakyReLU(alpha=0.2)(g)
    g = PixelNorm()(g)
    # The second convolution must consume the first one's output. Feeding it
    # `upsampling` again would silently drop the first convolution block.
    g = EqualizedConv(filters, kernal=3, padding="same")(g)
    g = LeakyReLU(alpha=0.2)(g)
    g = PixelNorm()(g)

    # New 1x1 output layer and the straight-through model.
    out_img = EqualizedConv(3, kernal=1, padding="same")(g)
    model1 = Model(inputs=old_model.input, outputs=out_img)

    # Reuse the old output layer on the upsampled features for the fade-in path.
    out_old = old_model.layers[-1]
    out_img2 = out_old(upsampling)

    merged = WeighedSum()([out_img2, out_img])
    model2 = Model(old_model.input, merged)

    return [model1, model2]


In [None]:
# add discriminator block
def add_discrominator_block(old_model, filter1, filter2):
    """Grow a discriminator by one input block that accepts images of twice the resolution.

    Args:
        old_model: discriminator to extend; its layers from index 3 onward are reused.
        filter1: filters of the first two new convolutions.
        filter2: filters of the third new convolution.

    Returns:
        ``[model1, model2]``: ``model1`` runs straight through the new block;
        ``model2`` blends the old input path (on a downsampled image) with the
        new block through a ``WeighedSum`` layer.
    """
    old_layers = old_model.layers

    def conv_lrelu(tensor, filters):
        # 3x3 "same" equalized convolution followed by LeakyReLU(0.2).
        tensor = EqualizedConv(filters, kernal=3, padding="same")(tensor)
        return LeakyReLU(alpha=0.2)(tensor)

    def run_old_tail(tensor):
        # Skip the old model's input layer, its first convolution and that
        # convolution's activation (indices 0-2); reuse everything after them.
        for layer in old_layers[3:]:
            tensor = layer(tensor)
        return tensor

    # New input: both spatial sides doubled, channel count unchanged.
    old_shape = old_model.input_shape
    side = old_shape[-2] * 2
    in_img = Input(shape=(side, side, old_shape[-1]))

    # New block: three conv + activation stages, then 2x2 average pooling.
    d = conv_lrelu(in_img, filter1)
    d = conv_lrelu(d, filter1)
    d = conv_lrelu(d, filter2)
    block_new = AveragePooling2D((2,2))(d)

    model1 = Model(in_img, run_old_tail(block_new))

    # Fade-in path: downsample the image and pass it through the old first
    # convolution and its activation.
    downsample = AveragePooling2D((2,2))(in_img)
    block_old = old_layers[1](downsample)
    block_old = old_layers[2](block_old)

    merged = WeighedSum()([block_old, block_new])
    model2 = Model(in_img, run_old_tail(merged))

    return [model1, model2]




In [None]:
def add_blocks(gmodel, dmodel, filterg, filterd):
    """Grow a generator and a discriminator by one block each.

    Args:
        gmodel: generator passed to ``add_generator_block``.
        dmodel: discriminator passed to ``add_discrominator_block``.
        filterg: filter count for the new generator block.
        filterd: sequence whose first two items are the discriminator filter counts.

    Returns:
        A ``(generator_models, discriminator_models)`` tuple.
    """
    grown_generators = add_generator_block(gmodel, filterg)
    first_filters, second_filters = filterd[0], filterd[1]
    grown_discriminators = add_discrominator_block(dmodel, first_filters, second_filters)
    return grown_generators, grown_discriminators

In [None]:
class GAN(keras.Model):
    """Wasserstein-style GAN trainer with gradient penalty and a drift term.

    ``train_step`` trains the "fade" generator/discriminator pair, while the
    "direct" pair is used for sampling (``show_samples``) and is what
    ``get_models`` hands back.

    Args:
        disc_fade: discriminator updated by ``train_step``.
        disc_direct: discriminator returned by ``get_models``.
        gen_fade: generator updated by ``train_step``.
        gen_direct: generator used for sampling and returned by ``get_models``.
        latent_dim: length of the latent vectors fed to the generators.
    """

    def __init__(self, disc_fade, disc_direct, gen_fade, gen_direct, latent_dim):
        super().__init__()
        # The misspelled attribute name is kept: CustomCallback reads it.
        self.discrminator = disc_fade
        self.generator = gen_fade
        self.gen_direct = gen_direct
        self.disc_direct = disc_direct
        self.latent_dim = latent_dim
        # Fixed latent vectors so the 3x3 sample grid is comparable across epochs.
        self.seed = tf.random.normal([9, latent_dim])
        self.d_loss_tracker = keras.metrics.Mean(name="D_loss")
        self.g_loss_tracker = keras.metrics.Mean(name="G_loss")
        self.alpha_tracker = keras.metrics.Mean(name="Alpha")
        self.alpha = 0.0

    def compile(self, g_opt, d_opt):
        """Store the generator and discriminator optimizers."""
        super().compile()
        self.d_optimizer = d_opt
        self.g_optimizer = g_opt

    def get_models(self):
        """Return ``[gen_direct, disc_direct]``."""
        return [self.gen_direct, self.disc_direct]

    def update_alpha(self, new_alp):
        """Make the alpha metric report exactly ``new_alp``."""
        self.alpha_tracker.reset_state()
        self.alpha_tracker.update_state(new_alp)

    def generator_loss(self, fake_img):
        """Negative mean of the discriminator scores given to generated images."""
        return -tf.reduce_mean(fake_img)

    def discriminator_loss(self, real_img, fake_img, epsilion_drift=0.001):
        """Mean fake score minus mean real score, plus a drift penalty.

        The drift penalty is ``epsilion_drift * mean(real_img ** 2)`` and keeps
        the real scores from drifting far from zero.
        """
        real_loss = tf.reduce_mean(real_img)
        fake_loss = tf.reduce_mean(fake_img)
        drift_loss = epsilion_drift * tf.reduce_mean(tf.square(real_img))
        return fake_loss - real_loss + drift_loss

    def gradient_penalty(self, real, fake, disc):
        """Return ``mean((||grad disc(x)|| - 1) ** 2)`` on random real/fake interpolations.

        Args:
            real: batch of real images (rank-4 tensor).
            fake: batch of generated images with the same shape as ``real``.
            disc: discriminator to differentiate.
        """
        # Derive the batch size from the data itself so this method does not
        # depend on an attribute that only train_step used to set.
        batch_size = tf.shape(real)[0]
        alpha = tf.random.uniform([batch_size, 1, 1, 1], 0.0, 1.0)
        interpolated = real + alpha * (fake - real)

        with tf.GradientTape() as gp_tape:
            gp_tape.watch(interpolated)
            pred = disc(interpolated, training=True)

        grads = gp_tape.gradient(pred, [interpolated])[0]
        norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
        return tf.reduce_mean((norm - 1.0) ** 2)

    def show_samples(self, epoch, res, s=True):
        """Plot a 3x3 grid generated from the fixed seed by the direct generator.

        Args:
            epoch: epoch number, used only in the saved file name.
            res: resolution label, used only in the saved file name.
            s: when true, save the figure under ``Images/`` before showing it.
        """
        predictions = self.gen_direct(self.seed, training=False)
        # Map from [-1, 1] to [0, 1]. The generator output is unbounded, so
        # clip explicitly instead of relying on imshow's implicit clipping.
        predictions = tf.clip_by_value((predictions + 1) / 2, 0.0, 1.0)

        fig = plt.figure(figsize=(5, 5))

        for i in range(9):
            plt.subplot(3, 3, i + 1)
            plt.imshow(predictions[i])
            plt.axis("off")

        if s:
            plt.savefig(f"Images/image_at_{res}_at_epoch_{epoch}.png")
        plt.show()
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

    def train_step(self, real_images):
        """Run one generator update and one discriminator update on a batch.

        Returns:
            Dict with the current ``d_loss``, ``g_loss`` and ``alpha`` metric values.
        """
        if isinstance(real_images, tuple):
            real_images = real_images[0]

        batch_size = tf.shape(real_images)[0]
        random_latent_vectors = tf.random.normal([batch_size, self.latent_dim])

        # Only the forward pass and the losses are recorded on the tapes.
        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            gen_img = self.generator(random_latent_vectors, training=True)

            real_out = self.discrminator(real_images, training=True)
            fake_out = self.discrminator(gen_img, training=True)

            gen_loss = self.generator_loss(fake_out)
            d_cost = self.discriminator_loss(real_out, fake_out)
            gp = self.gradient_penalty(real_images, gen_img, self.discrminator)
            disc_loss = d_cost + 10 * gp

        # Differentiate and apply the updates outside the tape contexts so the
        # tapes do not also record the backward passes and optimizer ops.
        gen_vars = self.generator.trainable_variables
        disc_vars = self.discrminator.trainable_variables
        gradients_of_generator = gen_tape.gradient(gen_loss, gen_vars)
        gradients_of_discriminator = disc_tape.gradient(disc_loss, disc_vars)

        self.g_optimizer.apply_gradients(zip(gradients_of_generator, gen_vars))
        self.d_optimizer.apply_gradients(zip(gradients_of_discriminator, disc_vars))

        # The trackers are reset every step, so they report the latest batch only.
        self.d_loss_tracker.reset_state()
        self.g_loss_tracker.reset_state()
        self.d_loss_tracker.update_state(disc_loss)
        self.g_loss_tracker.update_state(gen_loss)

        return {
            "d_loss": self.d_loss_tracker.result(),
            "g_loss": self.g_loss_tracker.result(),
            "alpha": self.alpha_tracker.result(),
        }




In [None]:
from IPython import display
class CustomCallback(keras.callbacks.Callback):
    """Training callback that advances the fade-in alpha each batch and samples/saves each epoch.

    Args:
        step: starting value of the batch counter.
        n_step: number of steps passed to ``update_fadein`` as the ramp length.
        cr: resolution label forwarded to ``show_samples``.
        ckmanager: checkpoint manager whose ``save()`` is called after every epoch.
    """

    def __init__(self, step, n_step, cr, ckmanager):
        super().__init__()
        self.step = step
        self.n_step = n_step
        self.res = cr
        self.ckpt_manager = ckmanager

    def on_epoch_end(self, epoch, logs=None):
        """Clear the cell output, show fresh samples and write a checkpoint."""
        display.clear_output(wait=True)
        self.model.show_samples(epoch, self.res)
        self.ckpt_manager.save()
        print("Model Saved")

    def on_train_batch_begin(self, batch, logs=None):
        """Advance the step counter and push the new alpha to both trained models."""
        self.step += 1
        faded_models = [self.model.generator, self.model.discrminator]
        self.alp = update_fadein(faded_models, self.step, self.n_step)
        self.model.update_alpha(self.alp)



In [None]:
# Glob pattern for the training images (Kaggle input directory).
path="/kaggle/input/celebahq-resized-256x256/celeba_hq_256/*"
# Side length, in pixels, that `load` resizes every image to; also used in the
# checkpoint path and the sample image file names.
# NOTE(review): presumably this must equal the resolution of the models being
# trained (4x4 base, doubled by every add_blocks call) — confirm before training.
res=16

def load(image_path):
    """Read a JPEG file and return it as a float32 ``res`` x ``res`` RGB tensor scaled to [-1, 1]."""
    raw_bytes = tf.io.read_file(image_path)
    pixels = tf.io.decode_jpeg(raw_bytes, channels=3)
    pixels = tf.cast(tf.image.resize(pixels, [res, res]), tf.float32)
    # Map 0..255 pixel values onto -1..1.
    return (pixels - 127.5) / 127.5

# Input pipeline: list the files, decode each with `load`, form full batches
# only (the incomplete last batch is dropped) and prefetch.
dataset = (
    tf.data.Dataset.list_files(path)
    .map(load)
    .batch(Batch_size, drop_remainder=True)
    .prefetch(tf.data.AUTOTUNE)
)

In [None]:
epochs = 80
# Length of the alpha ramp in batches: half of all training batches.
n_step = len(dataset) * (epochs / 2)
print(len(dataset), n_step, sep="\n")


In [None]:
# Base models. There is no fade-in at this stage, so the same generator and
# discriminator serve as both the "fade" and the "direct" pair.
with tf.device("/device:GPU:0"):
    g_model = define_generator(512)
    d_model = define_discriminator()
    gan = GAN(
        disc_fade=d_model,
        disc_direct=d_model,
        gen_fade=g_model,
        gen_direct=g_model,
        latent_dim=512,
    )

D_optimizer = Adam(
    learning_rate=LR,
    beta_1=0.0,
    beta_2=0.99,
    epsilon=EPSILON,
)
G_optimizer = Adam(
    learning_rate=LR,
    beta_1=0.0,
    beta_2=0.99,
    epsilon=EPSILON,
)
gan.compile(G_optimizer, D_optimizer)

In [None]:
# Upsampler: grow the current direct models by one block and wrap the result
# in a fresh GAN with fresh optimizers.
# Each run of this cell reads the current models from `gan`, so running it
# again grows them by another block.
with tf.device("/device:GPU:0"):
    g_model, d_model = gan.get_models()
    gmodels, dmodels = add_blocks(g_model, d_model, 512, [512, 512])
    gan = GAN(
        disc_fade=dmodels[1],
        disc_direct=dmodels[0],
        gen_fade=gmodels[1],
        gen_direct=gmodels[0],
        latent_dim=512,
    )

D_optimizer = Adam(
    learning_rate=LR,
    beta_1=0.0,
    beta_2=0.99,
    epsilon=EPSILON,
)
G_optimizer = Adam(
    learning_rate=LR,
    beta_1=0.0,
    beta_2=0.99,
    epsilon=EPSILON,
)
gan.compile(G_optimizer, D_optimizer)

In [None]:
# Checkpoints: track the fade-in and the direct variants of both networks in a
# directory named after the current resolution; keep the five newest.
checkpoint_path = f"./checkpoints/Models/RES_{res}"
ckpt = tf.train.Checkpoint(
    generator=gmodels[1],
    discriminator=dmodels[1],
    generator_direct=gmodels[0],
    discriminator_direct=dmodels[0],
)

ckpt_manager = tf.train.CheckpointManager(
    ckpt,
    checkpoint_path,
    max_to_keep=5,
)


In [None]:
# Restore the most recent checkpoint, if the manager found one under
# checkpoint_path; otherwise training starts from the freshly built models.
if ckpt_manager.latest_checkpoint:
    ckpt.restore(ckpt_manager.latest_checkpoint)
    print("checkpoint restored !")


In [None]:
# Output directory for the sample grids saved by GAN.show_samples.
# exist_ok keeps the cell re-runnable: os.mkdir raises FileExistsError when the
# directory is already there.
os.makedirs("Images", exist_ok=True)

In [None]:
# `dataset` is already batched (.batch(Batch_size) in the pipeline above), and
# Keras documents that batch_size must not be specified for dataset inputs, so
# it is not passed here.
gan.fit(
    dataset,
    epochs=epochs,
    callbacks=[CustomCallback(0, n_step, res, ckpt_manager)],
)

In [None]:
# Display a 3x3 grid from the direct generator without saving it. The two
# positional arguments (epoch, res) only feed the saved file name, so their
# values do not matter when s=False.
gan.show_samples(5,5,s=False)

In [None]:
# Show one sample from the dataset
iteer = next(dataset.take(1).as_numpy_iterator())
# `load` scales pixels to [-1, 1]; map them back to [0, 1] so imshow does not
# clip every negative value to black.
plt.imshow((iteer[0] + 1) / 2)
plt.show()

In [None]:
# Sample generated by the model
# NOTE(review): `g_model` was last assigned from gan.get_models() in the
# upsampler cell, i.e. it is the generator from before the latest add_blocks
# call; gan.gen_direct (gmodels[0]) may be what was intended here — confirm.
noise=tf.random.normal([1, 512])
img=g_model(noise, training=False)
# Map the generator output from [-1, 1] to [0, 1] for display.
img=(img[0]+1)/2
plt.imshow(img)
plt.axis("off")
plt.show()