# Image colorization (grayscale → color) using GANs

In [1]:
import tensorflow as tf
import keras 
from keras import layers
import numpy as np
import matplotlib.pyplot as plt
import cv2 as cv
import os
from tqdm import tqdm
import re
from keras.preprocessing.image import img_to_array
import time




In [2]:
# Load every image, resize it to IMAGE_SIZE x IMAGE_SIZE (256x256) and scale pixel values to [0, 1].
IMAGE_SIZE = 256

rgb_path = "dataset/color/"
grayscale_path = "dataset/gray/"
rgb_images = []
grayscale_images = []


def _load_image(path):
    """Read one image as a float32 RGB array of shape (IMAGE_SIZE, IMAGE_SIZE, 3) in [0, 1].

    Grayscale files are read the same way (OpenCV's default color mode, then BGR -> RGB),
    so both datasets end up with three channels.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file (``cv.imread`` returned None).
    """
    image = cv.imread(path)
    if image is None:
        # cv.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {path}")
    image = cv.cvtColor(image, cv.COLOR_BGR2RGB)
    image = cv.resize(image, (IMAGE_SIZE, IMAGE_SIZE))
    return img_to_array(image.astype("float32") / 255.0)


# Only files named "<number>.<ext>" are used, ordered numerically. Color and grayscale
# images are paired by identical file names.
image_names = sorted(
    (name for name in os.listdir(rgb_path) if name.split('.')[0].isdecimal()),
    key=lambda name: int(name.split('.')[0]),
)

# NOTE(review): every image is kept in memory as float32 (~0.75 MiB each at 256x256x3),
# which is roughly 10-11 GB for the ~7k color+grayscale pairs reported by the progress bar.
for image_name in tqdm(image_names):
    rgb_images.append(_load_image(os.path.join(rgb_path, image_name)))
    grayscale_images.append(_load_image(os.path.join(grayscale_path, image_name)))


 56%|██████████████▍           | 3969/7129 [00:57<00:45, 68.89it/s]


KeyboardInterrupt: 

In [None]:
# Split the data into training and test sets: the first 80% of samples train, the rest test.
# The split is based on the images actually loaded (not on the file list), so an
# interrupted or partial load still leaves a non-empty test split.
TRAINING_SIZE = int(min(len(rgb_images), len(grayscale_images)) * 0.8)

In [None]:
# Training split of the color (target) images, in batches of 64.
# Not shuffled: it is zipped with the grayscale split later and both must keep the same order.
rgb_train_set = (
    tf.data.Dataset
    .from_tensor_slices(rgb_images[:TRAINING_SIZE])
    .batch(batch_size=64)
)

In [None]:
# Test split of the color (target) images, in batches of 8.
rgb_test_set = (
    tf.data.Dataset
    .from_tensor_slices(rgb_images[TRAINING_SIZE:])
    .batch(batch_size=8)
)

In [None]:
# Training split of the grayscale (input) images, in batches of 64.
# Not shuffled: it is zipped with the color split later and both must keep the same order.
grayscale_train_set = (
    tf.data.Dataset
    .from_tensor_slices(grayscale_images[:TRAINING_SIZE])
    .batch(batch_size=64)
)

In [None]:
# Test split of the grayscale (input) images, in batches of 8.
grayscale_test_set = (
    tf.data.Dataset
    .from_tensor_slices(grayscale_images[TRAINING_SIZE:])
    .batch(batch_size=8)
)

In [None]:
# Helpers that build the network's repeated building blocks.
def add_downsampling_layer(filters, size, normalize=True):
    """Build an encoder block: stride-2 Conv2D -> optional BatchNormalization -> LeakyReLU.

    Args:
        filters: number of convolution filters.
        size: convolution kernel size.
        normalize: whether to insert BatchNormalization after the convolution.

    Returns:
        A ``tf.keras.Sequential`` block. With "same" padding and stride 2, it halves the
        spatial resolution (rounding up).
    """
    block_layers = [
        tf.keras.layers.Conv2D(
            filters,
            size,
            strides=2,
            padding="same",
            kernel_initializer="he_normal",
            use_bias=False,
        )
    ]
    if normalize:
        block_layers.append(tf.keras.layers.BatchNormalization())
    block_layers.append(tf.keras.layers.LeakyReLU())
    return tf.keras.Sequential(block_layers)

In [None]:
def add_upsampling_layer(filters, size, dropout=False):
    """Build a decoder block: stride-2 Conv2DTranspose -> BatchNorm -> optional Dropout -> LeakyReLU.

    Args:
        filters: number of transposed-convolution filters.
        size: kernel size.
        dropout: whether to insert ``Dropout(0.5)`` after the normalization.

    Returns:
        A ``tf.keras.Sequential`` block. With "same" padding and stride 2, it doubles the
        spatial resolution.
    """
    stages = [
        tf.keras.layers.Conv2DTranspose(
            filters,
            size,
            strides=2,
            padding="same",
            kernel_initializer="he_normal",
            use_bias=False,
        ),
        tf.keras.layers.BatchNormalization(),
    ]
    if dropout:
        stages.append(tf.keras.layers.Dropout(0.5))
    stages.append(tf.keras.layers.LeakyReLU())
    return tf.keras.Sequential(stages)

In [None]:
# Generator network: a U-Net (strided-conv encoder, transposed-conv decoder, skip connections).
def make_generator(image_size=None):
    """Create the generator that maps a 3-channel grayscale image to a 3-channel color image.

    Args:
        image_size: height/width of the square input. Defaults to ``IMAGE_SIZE`` so the
            generator matches the discriminator and the loaded data. It must be a positive
            multiple of 256: the encoder halves the resolution eight times, and every decoder
            output is concatenated with the encoder output of the same size.

    Returns:
        A ``tf.keras.Model`` taking ``(batch, image_size, image_size, 3)`` inputs and producing
        outputs of the same shape. The sigmoid activation keeps values in [0, 1], the same
        range the images are scaled to when loaded.

    Raises:
        ValueError: if ``image_size`` is not a positive multiple of 256.
    """
    if image_size is None:
        image_size = IMAGE_SIZE

    # Encoder: each block halves height and width; the first block has no batch norm.
    downsampling_layers = [
        add_downsampling_layer(64, 4, normalize=False),
        add_downsampling_layer(128, 4),
        add_downsampling_layer(256, 4),
        add_downsampling_layer(512, 4),
        add_downsampling_layer(512, 4),
        add_downsampling_layer(512, 4),
        add_downsampling_layer(512, 4),
        add_downsampling_layer(512, 4),
    ]

    # Skip connections only line up if every encoder stage divides the size exactly.
    total_stride = 2 ** len(downsampling_layers)
    if image_size <= 0 or image_size % total_stride != 0:
        raise ValueError(
            f"image_size must be a positive multiple of {total_stride}, got {image_size}"
        )

    # Decoder: each block doubles height and width; the three deepest use dropout.
    upsampling_layers = [
        add_upsampling_layer(512, 4, dropout=True),
        add_upsampling_layer(512, 4, dropout=True),
        add_upsampling_layer(512, 4, dropout=True),
        add_upsampling_layer(512, 4),
        add_upsampling_layer(256, 4),
        add_upsampling_layer(128, 4),
        add_upsampling_layer(64, 4),
    ]

    # Final upsampling back to full resolution with 3 output channels. Sigmoid (not tanh)
    # because the training targets are scaled to [0, 1], not [-1, 1].
    last_layer = tf.keras.layers.Conv2DTranspose(3,
                                                 4,
                                                 strides=2,
                                                 padding="same",
                                                 kernel_initializer=tf.random_normal_initializer(0., 0.02),
                                                 activation="sigmoid")

    inputs = tf.keras.layers.Input(shape=[image_size, image_size, 3])

    x = inputs
    encoder_outputs = []
    for down in downsampling_layers:
        x = down(x)
        encoder_outputs.append(x)

    # The bottleneck (last encoder output) feeds the decoder directly; the remaining encoder
    # outputs are used deepest-first as skip connections.
    skips = reversed(encoder_outputs[:-1])

    for up, skip in zip(upsampling_layers, skips):
        x = up(x)
        x = tf.keras.layers.Concatenate()([x, skip])

    outputs = last_layer(x)

    return tf.keras.Model(inputs=inputs, outputs=outputs)

In [None]:
# Discriminator network: a conditional patch discriminator.
def make_discriminator():
    """Create a discriminator that scores (input image, candidate color image) pairs.

    The two inputs, named "image" and "target", are concatenated along channels. They then
    pass through three stride-2 encoder blocks and two zero-padded stride-1 convolutions.
    For ``IMAGE_SIZE == 256`` the output is a ``(batch, 30, 30, 1)`` map of raw scores (no
    activation), one per image patch.

    Returns:
        A ``tf.keras.Model`` with inputs ``[image, target]`` and the patch-score map as output.
    """
    image = tf.keras.layers.Input(shape=[IMAGE_SIZE, IMAGE_SIZE, 3], name="image")
    target = tf.keras.layers.Input(shape=[IMAGE_SIZE, IMAGE_SIZE, 3], name="target")

    # Condition the discriminator on the input by stacking it with the candidate output.
    x = tf.keras.layers.concatenate([image, target])

    # Three downsampling blocks; only the first skips batch normalization.
    for filters, normalize in ((64, False), (128, True), (256, True)):
        x = add_downsampling_layer(filters, 4, normalize)(x)

    x = tf.keras.layers.ZeroPadding2D()(x)
    x = tf.keras.layers.Conv2D(
        512,
        4,
        strides=1,
        kernel_initializer=tf.random_normal_initializer(0., 0.02),
        use_bias=False,
    )(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.LeakyReLU()(x)
    x = tf.keras.layers.ZeroPadding2D()(x)

    # One unactivated score per patch; the loss is computed from these logits.
    patch_scores = tf.keras.layers.Conv2D(
        1, 4, strides=1, kernel_initializer=tf.random_normal_initializer(0., 0.02)
    )(x)

    return tf.keras.Model(inputs=[image, target], outputs=patch_scores)

In [None]:
# Loss functions.
# Shared binary cross-entropy applied to raw discriminator scores (from_logits=True).
losses = tf.keras.losses.BinaryCrossentropy(from_logits=True)
# Weight of the L1 reconstruction term relative to the adversarial term.
LAMBDA = 100


def generator_loss(disc_output, gen_output, target):
    """Generator objective: fool the discriminator while staying close to the target in L1.

    Args:
        disc_output: discriminator scores for (input, generated) pairs.
        gen_output: generator predictions.
        target: ground-truth color images.

    Returns:
        ``(total, adversarial, l1)``, where ``total = adversarial + LAMBDA * l1``.
    """
    adversarial = losses(tf.ones_like(disc_output), disc_output)
    reconstruction = tf.reduce_mean(tf.abs(target - gen_output))
    return adversarial + LAMBDA * reconstruction, adversarial, reconstruction


def discriminator_loss(real, generated):
    """Discriminator objective: label real pairs as 1 and generated pairs as 0.

    Args:
        real: discriminator scores for (input, real target) pairs.
        generated: discriminator scores for (input, generated) pairs.

    Returns:
        The sum of the real-pair and generated-pair cross-entropy terms.
    """
    return (losses(tf.ones_like(real), real)
            + losses(tf.zeros_like(generated), generated))

In [None]:
# Training setup: models, optimizers and a single training step.
generator = make_generator()
discriminator = make_discriminator()

gen_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
disc_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)


@tf.function
def _train_step(image, target):
    """Compiled update of both networks on one batch; returns the loss terms."""
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        gen_output = generator(image, training=True)
        disc_real_output = discriminator([image, target], training=True)
        disc_generated_output = discriminator([image, gen_output], training=True)

        gen_loss, gan_loss, l1_loss = generator_loss(disc_generated_output, gen_output, target)
        disc_loss = discriminator_loss(disc_real_output, disc_generated_output)

    # Gradients are taken after leaving the tape context. Otherwise the still-active
    # disc_tape would also record the generator's gradient computation and the updates.
    gen_gradients = gen_tape.gradient(gen_loss, generator.trainable_variables)
    disc_gradients = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    gen_optimizer.apply_gradients(zip(gen_gradients, generator.trainable_variables))
    disc_optimizer.apply_gradients(zip(disc_gradients, discriminator.trainable_variables))

    return gen_loss, gan_loss, l1_loss, disc_loss


def training_step(image, target, epoch):
    """Run one generator + discriminator update on a batch.

    Args:
        image: batch of generator inputs (grayscale images).
        target: batch of matching color images.
        epoch: current epoch index. It is unused, and it is deliberately not passed to the
            compiled step, so a new Python value each epoch does not force a retrace.

    Returns:
        ``(gen_loss, gan_loss, l1_loss, disc_loss)`` as tensors, for optional logging.
    """
    return _train_step(image, target)

In [None]:
# Training loop.
def fit(train_dataset, num_epochs):
    """Call ``training_step`` on every batch of ``train_dataset`` for ``num_epochs`` epochs.

    Args:
        train_dataset: dataset yielding ``(input_image, target)`` batch pairs.
        num_epochs: number of full passes over ``train_dataset``.
    """
    for epoch in range(num_epochs):
        # perf_counter is monotonic, so the elapsed time is not affected by clock changes.
        start_time = time.perf_counter()
        print("Epoch: ", epoch + 1)

        for input_image, target in train_dataset:
            training_step(input_image, target, epoch)

        print(f"Time taken for epoch {epoch + 1} is {time.perf_counter() - start_time} sec\n")

In [None]:
# Pair each grayscale batch with the color batch at the same position, then train.
paired_train_set = tf.data.Dataset.zip((grayscale_train_set, rgb_train_set))
fit(paired_train_set, num_epochs=10)

In [None]:
# Evaluation. The generator is compiled here only so Keras' `evaluate` can be used; the
# weights learned above are kept (compile replaces only the optimizer, loss and metrics).
# NOTE(review): with an MSE loss, Keras' "accuracy" counts exact value matches, which is
# close to meaningless for continuous pixel outputs; a metric such as "mae" would be more informative.
generator.compile(loss="MSE", optimizer="adam", metrics=["accuracy"])


def test_and_evaluate(grayscale_test_dataset, rgb_test_dataset):
    """Evaluate the compiled generator on paired grayscale/color test batches.

    Args:
        grayscale_test_dataset: batched dataset of generator inputs.
        rgb_test_dataset: batched dataset of matching color targets, in the same order.

    Returns:
        ``(average_loss, average_accuracy)``: the compiled loss and metric, aggregated by a
        single ``evaluate`` call over the whole test set. This avoids giving a smaller final
        batch the same weight as full batches.

    Raises:
        ValueError: if the paired test dataset contains no batches.
    """
    paired = tf.data.Dataset.zip((grayscale_test_dataset, rgb_test_dataset))
    if paired.cardinality().numpy() == 0:
        raise ValueError("The test dataset is empty; check TRAINING_SIZE and the loaded images.")

    average_loss, average_accuracy = generator.evaluate(paired)
    return average_loss, average_accuracy

In [None]:
average_loss, average_accuracy = test_and_evaluate(grayscale_test_set, rgb_test_set)
# Print each aggregated evaluation value on its own line.
for metric_name, metric_value in (("loss", average_loss), ("accuracy", average_accuracy)):
    print(f"Average {metric_name}: {metric_value}")

In [None]:
# Plot a few examples for visual inspection.
def generate_examples(model, input, target):
    """Show the first sample of a batch: grayscale input, true color image, and prediction.

    Args:
        model: callable (e.g. the generator) that maps an input batch to predictions.
        input: batch of grayscale images. The name shadows the builtin but is kept for
            compatibility with existing callers.
        target: batch of matching color images.
    """
    prediction = model(input)

    # Named so it does not shadow IPython's `display`.
    display_list = [input[0], target[0], prediction[0]]
    titles = ["Grayscale", "Color", "Predicted"]

    _, axes = plt.subplots(1, 3, figsize=(15, 15))
    for ax, title, image in zip(axes, titles, display_list):
        ax.set_title(title)
        # imshow expects floats in [0, 1]; clip so out-of-range predictions render cleanly.
        ax.imshow(np.clip(np.asarray(image), 0.0, 1.0))
        ax.axis("off")

    plt.show()

In [None]:
# Show three test examples. The loop names avoid shadowing the builtin `input` in the
# notebook's global namespace.
for gray_batch, color_batch in tf.data.Dataset.zip((grayscale_test_set, rgb_test_set)).take(3):
    generate_examples(generator, gray_batch, color_batch)