In [None]:
!pip install -q tensorflow==2.4.1

In [None]:
import os
import cv2
import time
import random
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow.keras import layers, losses
from tensorflow.keras.models import Model
from tensorflow.keras.utils import plot_model

from skimage.util import random_noise
from skimage.filters import gaussian
from skimage.transform import rescale, resize
from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity

In [None]:
ffhq_path = "../input/flickrfaceshq-dataset-ffhq/*"
celeba_path = "../input/celeba-dataset/img_align_celeba/img_align_celeba/*"

# Every model below takes Input(shape=(None, None, CHANNELS)), so no fixed
# HEIGHT/WIDTH is needed; only the channel count is fixed.
CHANNELS = 3

# Base filter count; layers use FILTER // k for narrower stages.
FILTER = 64

# Downscale factor used to build low-res training inputs and in evaluation.
SCALE = 8
# Convolution kernel size shared by the models.
ksize = 3

# Seed Python, NumPy and TensorFlow RNGs so weight initialisation and the
# Gaussian noise used in training/evaluation are reproducible on Run All
# (GPU kernels may still introduce some nondeterminism).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [None]:
# U-Net-style encoder/decoder ("AE") with additive (not concatenated) skip
# connections. It is built here but the training cells below train `resnet`,
# not this model.
# Although the input shape is (None, None, CHANNELS), the three 2x2 max-pools
# and three 2x upsamplings mean H and W must be divisible by 8, otherwise the
# skip `add`s receive mismatched spatial shapes.
inputs = tf.keras.Input(shape=(None, None, CHANNELS))

# Encoder stage 1: full resolution, FILTER // 8 channels.
b1c1 = layers.Conv2D(FILTER // 8, (ksize, ksize), padding = "same", activation = "relu")(inputs)
b1c2 = layers.Conv2D(FILTER // 8, (ksize, ksize), padding = "same", activation = "relu")(b1c1)
b1c3 = layers.Conv2D(FILTER // 8, (ksize, ksize), padding = "same", activation = "relu")(b1c2)
b1c4 = layers.Conv2D(FILTER // 8, (ksize, ksize), padding = "same", activation = "relu")(b1c3)

mp1 = layers.MaxPool2D((2, 2))(b1c4)

# Encoder stage 2: 1/2 resolution, FILTER // 4 channels.
b2c1 = layers.Conv2D(FILTER // 4, (ksize, ksize), padding = "same", activation = "relu")(mp1)
b2c2 = layers.Conv2D(FILTER // 4, (ksize, ksize), padding = "same", activation = "relu")(b2c1)
b2c3 = layers.Conv2D(FILTER // 4, (ksize, ksize), padding = "same", activation = "relu")(b2c2)
b2c4 = layers.Conv2D(FILTER // 4, (ksize, ksize), padding = "same", activation = "relu")(b2c3)

mp2 = layers.MaxPool2D((2, 2))(b2c4)

# Encoder stage 3: 1/4 resolution, FILTER // 2 channels.
b3c1 = layers.Conv2D(FILTER // 2, (ksize, ksize), padding = "same", activation = "relu")(mp2)
b3c2 = layers.Conv2D(FILTER // 2, (ksize, ksize), padding = "same", activation = "relu")(b3c1)
b3c3 = layers.Conv2D(FILTER // 2, (ksize, ksize), padding = "same", activation = "relu")(b3c2)
b3c4 = layers.Conv2D(FILTER // 2, (ksize, ksize), padding = "same", activation = "relu")(b3c3)

mp3 = layers.MaxPool2D((2, 2))(b3c4)

# Bottleneck: 1/8 resolution, FILTER channels.
b4c1 = layers.Conv2D(FILTER, (ksize, ksize), padding = "same", activation = "relu")(mp3)
b4c2 = layers.Conv2D(FILTER, (ksize, ksize), padding = "same", activation = "relu")(b4c1)

# Conv2DTranspose with the default stride of 1 and "same" padding keeps the
# spatial size, so here it behaves like a plain convolution; each result is
# summed with an earlier feature map of the same shape.
x = layers.Conv2DTranspose(FILTER, (ksize, ksize), padding = "same", activation = "relu")(b4c2)
b4c3 = layers.add([x, b4c2])

x = layers.Conv2DTranspose(FILTER, (ksize, ksize), padding = "same", activation = "relu")(b4c3)
b4c4 = layers.add([x, b4c1])

us1 = layers.UpSampling2D((2, 2))(b4c4)

# Decoder stage 3: back at 1/4 resolution; skips from encoder stage 3 are
# added in reverse order (b3c4 ... b3c1).
x = layers.Conv2DTranspose(FILTER // 2, (ksize, ksize), padding = "same", activation = "relu")(us1)
b5c1 = layers.add([x, b3c4])
x = layers.Conv2DTranspose(FILTER // 2, (ksize, ksize), padding = "same", activation = "relu")(b5c1)
b5c2 = layers.add([x, b3c3])
x = layers.Conv2DTranspose(FILTER // 2, (ksize, ksize), padding = "same", activation = "relu")(b5c2)
b5c3 = layers.add([x, b3c2])
x = layers.Conv2DTranspose(FILTER // 2, (ksize, ksize), padding = "same", activation = "relu")(b5c3)
b5c4 = layers.add([x, b3c1])

us2 = layers.UpSampling2D((2, 2))(b5c4)

# Decoder stage 2: 1/2 resolution; skips from encoder stage 2 (reverse order).
x = layers.Conv2DTranspose(FILTER // 4, (ksize, ksize), padding = "same", activation = "relu")(us2)
b6c1 = layers.add([x, b2c4])
x = layers.Conv2DTranspose(FILTER // 4, (ksize, ksize), padding = "same", activation = "relu")(b6c1)
b6c2 = layers.add([x, b2c3])
x = layers.Conv2DTranspose(FILTER // 4, (ksize, ksize), padding = "same", activation = "relu")(b6c2)
b6c3 = layers.add([x, b2c2])
x = layers.Conv2DTranspose(FILTER // 4, (ksize, ksize), padding = "same", activation = "relu")(b6c3)
b6c4 = layers.add([x, b2c1])

us3 = layers.UpSampling2D((2, 2))(b6c4)

# Decoder stage 1: full resolution; skips from encoder stage 1 (reverse order).
x = layers.Conv2DTranspose(FILTER // 8, (ksize, ksize), padding = "same", activation = "relu")(us3)
b7c1 = layers.add([x, b1c4])
x = layers.Conv2DTranspose(FILTER // 8, (ksize, ksize), padding = "same", activation = "relu")(b7c1)
b7c2 = layers.add([x, b1c3])
x = layers.Conv2DTranspose(FILTER // 8, (ksize, ksize), padding = "same", activation = "relu")(b7c2)
b7c3 = layers.add([x, b1c2])
x = layers.Conv2DTranspose(FILTER // 8, (ksize, ksize), padding = "same", activation = "relu")(b7c3)
b7c4 = layers.add([x, b1c1])

# Output head: narrow to FILTER // 16 channels, then project to CHANNELS.
# NOTE(review): the final ReLU is unbounded above while the data pipeline
# scales images to [0, 1]; a sigmoid or clipping may be intended — confirm.
b8c1 = layers.Conv2DTranspose(FILTER // 16, (ksize, ksize), padding = "same", activation = "relu")(b7c4)
b8c2 = layers.Conv2DTranspose(FILTER // 16, (ksize, ksize), padding = "same", activation = "relu")(b8c1)
outputs = layers.Conv2DTranspose(CHANNELS, (ksize, ksize), padding = "same", activation = "relu")(b8c2)

autoencoder = Model(inputs=inputs, outputs=outputs, name="AE")

# NOTE: `opt` used in the commented-out compile below is only defined in a
# later cell (next to the discriminator).
#autoencoder.compile(optimizer=opt, loss=losses.MeanAbsoluteError())

#autoencoder.summary()

In [None]:
# Residual CNN ("RN") trained and evaluated by the cells below.
# Fully convolutional with "same" padding throughout, so the output has the
# same height and width as the input.
inputs = tf.keras.Input(shape=(None, None, CHANNELS))

# Stem: three chained 3x3 convs; the first and third are summed as a long skip.
stem = inputs
stem_feats = []
for _ in range(3):
    stem = layers.Conv2D(FILTER // 4, (ksize, ksize), padding="same", activation="relu")(stem)
    stem_feats.append(stem)

residual = layers.add([stem_feats[0], stem_feats[2]])

# Three residual blocks: two convs whose result is added back to the block input.
for _block in range(3):
    branch = layers.Conv2D(FILTER // 4, (ksize, ksize), padding="same", activation="relu")(residual)
    branch = layers.Conv2D(FILTER // 4, (ksize, ksize), padding="same", activation="relu")(branch)
    residual = layers.add([residual, branch])

# Head: one more feature conv, then projection to CHANNELS.
head = layers.Conv2D(FILTER // 4, (ksize, ksize), padding="same", activation="relu")(residual)
outputs = layers.Conv2D(CHANNELS, (ksize, ksize), padding="same", activation="relu")(head)

resnet = Model(inputs=inputs, outputs=outputs, name="RN")

In [None]:
def inception_module(x, filters):
    """Apply one Inception-style block to ``x``.

    Three parallel branches are concatenated along the channel axis, giving
    3 * filters output channels:
      * a 1x1 conv,
      * a 1x1 conv followed by a 3x3 conv,
      * a 1x1 conv followed by a 5x5 conv.
    Every conv has ``filters`` outputs, "same" padding and ReLU.
    """
    branch_1x1 = layers.Conv2D(filters, (1, 1), padding="same", activation="relu")(x)

    branch_3x3 = layers.Conv2D(filters, (1, 1), padding="same", activation="relu")(x)
    branch_3x3 = layers.Conv2D(filters, (3, 3), padding="same", activation="relu")(branch_3x3)

    branch_5x5 = layers.Conv2D(filters, (1, 1), padding="same", activation="relu")(x)
    branch_5x5 = layers.Conv2D(filters, (5, 5), padding="same", activation="relu")(branch_5x5)

    return layers.concatenate([branch_1x1, branch_3x3, branch_5x5])


# Inception-style network ("IN"): four stacked modules, then a projection to CHANNELS.
inputs = tf.keras.Input(shape=(None, None, CHANNELS))

concat1 = inception_module(inputs, 4)
concat2 = inception_module(concat1, 8)
concat3 = inception_module(concat2, 8)
concat4 = inception_module(concat3, 4)

# The head must read from the last module; feeding it concat3 would leave the
# fourth module computed but disconnected from the model.
outputs = layers.Conv2D(CHANNELS, (ksize, ksize), padding = "same", activation = "relu")(concat4)

inception = Model(inputs=inputs, outputs=outputs, name="IN")

In [None]:
# Convolutional discriminator ("DC"): four stages of two 3x3 convs plus a 2x2
# max-pool, widening from FILTER // 8 to FILTER channels, then a sigmoid score.
# NOTE(review): Dense applied to a 4-D feature map acts per spatial position,
# so the output is a (batch, ~H/16, ~W/16, 1) map rather than one score per
# image. Confirm that a patch-level output is intended.
inputs = tf.keras.Input(shape=(None, None, CHANNELS))

features = inputs
for divisor in (8, 4, 2, 1):
    width = FILTER // divisor
    features = layers.Conv2D(width, (ksize, ksize), padding="same", activation="relu")(features)
    features = layers.Conv2D(width, (ksize, ksize), padding="same", activation="relu")(features)
    features = layers.MaxPool2D((2, 2))(features)

outputs = layers.Dense(1, activation="sigmoid")(features)

discriminator = Model(inputs=inputs, outputs=outputs, name="DC")

# Optimizer kept for the (currently disabled) compile calls on AE/DC.
opt = tf.keras.optimizers.Adam()

In [None]:
#plot_model(inception, show_shapes=True, show_layer_names=False)

In [None]:
# Enumerate the FFHQ image files. list_files shuffles by default;
# shuffle=False makes the order deterministic.
# NOTE(review): training takes the leading files of this listing and the
# evaluation cells read files numbered from 40000. That is only a clean
# hold-out if the order is sorted by filename — confirm.
ds = tf.data.Dataset.list_files(file_pattern=ffhq_path, shuffle=False)

In [None]:
def process_image_downscale(file_path, scale=None):
    """Load an image file and build a (low-res input, full-res target) pair.

    The low-res input keeps every ``scale``-th pixel in each spatial direction.
    It is then bilinearly resized back to the source image's height and width,
    so input and target have the same shape.

    Args:
        file_path: scalar string tensor with the path of an image file.
        scale: decimation factor; defaults to the global SCALE.

    Returns:
        (img_d, img): two float32 tensors in [0, 1] of shape (H, W, CHANNELS).
    """
    if scale is None:
        scale = SCALE

    raw = tf.io.read_file(file_path)
    # decode_image handles PNG as well as JPEG. Forcing CHANNELS drops any
    # alpha channel and expands grayscale, so every example matches the models'
    # CHANNELS-channel input.
    img = tf.io.decode_image(raw, channels=CHANNELS, expand_animations=False)
    img.set_shape([None, None, CHANNELS])
    img = tf.cast(img, tf.float32) / 255.0

    img_d = img[::scale, ::scale]
    # Resize back to the source size rather than a hardcoded 512x512, so the
    # input always lines up with the target.
    img_d = tf.image.resize(img_d, tf.shape(img)[:2])

    return img_d, img

In [None]:
def process_image_noisy(file_path):
    """Load one image file as a clean float32 target in [0, 1].

    Despite the name, no noise is added here. The denoising training loop adds
    Gaussian noise per batch, so this returns only the clean image.

    Args:
        file_path: scalar string tensor with the path of an image file.

    Returns:
        float32 tensor of shape (H, W, CHANNELS).
    """
    raw = tf.io.read_file(file_path)
    # decode_image handles PNG as well as JPEG. Forcing CHANNELS drops any
    # alpha channel so every image matches the models' CHANNELS-channel input.
    img = tf.io.decode_image(raw, channels=CHANNELS, expand_animations=False)
    img.set_shape([None, None, CHANNELS])
    return tf.cast(img, tf.float32) / 255.0

In [None]:
# Training pipeline for the upscaling task: batched (low-res input, full-res
# target) pairs. Decoding runs in parallel (output order is still preserved)
# and the next batch is prefetched, so loading overlaps with the training step.
BATCH_SIZE = 32
N_TRAIN_BATCHES = 1000  # 1000 batches * 32 = 32,000 images

ds2 = (
    ds.map(process_image_downscale, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .batch(BATCH_SIZE)
    .take(N_TRAIN_BATCHES)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

In [None]:
# Sanity-check the pipeline: show the low-res input, then the full-res target,
# of the first example in the first batch.
x, y = next(iter(ds2))

for preview in (x[0], y[0]):
    plt.imshow(preview)
    plt.show()

In [None]:
# Shared by both training loops below: per-pixel L1 loss (also reused as the
# MAE metric in evaluation) and Adam with learning rate 1e-3.
loss_fn = losses.MeanAbsoluteError()
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

In [None]:
epochs = 3

# Denoising phase. `resnet` sees the clean image plus Gaussian noise
# (stddev 0.4, clipped to [0, 1]) and learns to recover the clean image.
# `ds2` yields (low_res, high_res) tuples for the upscaling task, which have no
# `.shape`, so this loop builds its own dataset of clean images.
ds_denoise = ds.map(process_image_noisy).batch(32).take(1000)

for epoch in range(epochs):
    print(f"Start of epoch {epoch + 1}")
    start_time = time.perf_counter()

    for step, x_batch_train in enumerate(ds_denoise):
        # Corrupt the batch outside the tape: gradients are only needed with
        # respect to the weights, not the input.
        x_noisy = tf.clip_by_value(
            x_batch_train + tf.random.normal(tf.shape(x_batch_train), stddev=0.4),
            0, 1)

        with tf.GradientTape() as tape:
            logits = resnet(x_noisy, training=True)
            loss_value = loss_fn(x_batch_train, logits)

        grads = tape.gradient(loss_value, resnet.trainable_weights)
        optimizer.apply_gradients(zip(grads, resnet.trainable_weights))

        if (step + 1) % 100 == 0:
            print("Training loss at step %d: %.4f"% (step + 1, float(loss_value)))
            print(f"{int(time.perf_counter() - start_time)} s")
            start_time = time.perf_counter()

In [None]:
epochs = 3

# Upscaling phase: train `resnet` to map the resized low-res input to the
# full-res image with L1 loss. It continues from the weights and optimizer
# state that earlier cells left behind.
for epoch in range(epochs):
    print(f"Start of epoch {epoch + 1}")
    tic = time.perf_counter()

    for batch_no, (low_res, high_res) in enumerate(ds2, start=1):
        with tf.GradientTape() as tape:
            prediction = resnet(low_res, training=True)
            loss_value = loss_fn(high_res, prediction)

        weights = resnet.trainable_weights
        gradients = tape.gradient(loss_value, weights)
        optimizer.apply_gradients(zip(gradients, weights))

        # Report every 100 batches, with wall time since the last report.
        if batch_no % 100 == 0:
            print("Training loss at step %d: %.4f" % (batch_no, float(loss_value)))
            print(f"{int(time.perf_counter() - tic)} s")
            tic = time.perf_counter()

In [None]:
count = 5

plt.figure(figsize=(15, 7 * count))

start = 40000

# NOTE(review): in top-to-bottom order this cell runs after the upscaling
# training cell, so `resnet` has by now also been trained on upscaling.
for n in range(start, start + count):
    row = n - start

    test_path = f"{ffhq_path[:-1]}{n:05d}.png"

    # cv2.imread returns None (rather than raising) for a missing/unreadable file.
    test_bgr = cv2.imread(test_path)
    if test_bgr is None:
        raise FileNotFoundError(f"Could not read test image: {test_path}")

    # BGR -> RGB, add a batch axis, scale to [0, 1]. Works for any image size.
    test_img = test_bgr[np.newaxis, :, :, ::-1] / 255.

    test_img_noisy = tf.clip_by_value(test_img + tf.random.normal(test_img.shape, stddev=0.4), 0, 1).numpy()

    predict = resnet(test_img_noisy).numpy()

    # Images are floats in [0, 1], so pass data_range explicitly. Depending on
    # the scikit-image version, SSIM on float inputs otherwise either assumes a
    # range of 2 (-1..1), overstating similarity, or refuses to run.
    psnr_down = peak_signal_noise_ratio(test_img, test_img_noisy.astype('float32'), data_range=1.0)
    mae_down = loss_fn(test_img, test_img_noisy)
    ssim_down = structural_similarity(test_img[0], test_img_noisy[0], multichannel=True, data_range=1.0)

    psnr = peak_signal_noise_ratio(test_img, predict, data_range=1.0)
    mae = loss_fn(test_img, predict)
    ssim = structural_similarity(test_img[0], predict[0], multichannel=True, data_range=1.0)

    plt.subplot(count, 3, 1 + 3 * row)
    plt.title(f"Noisy\nPSNR: {psnr_down:.2f} db\nMAE: {mae_down:.5f}\nSSIM: {ssim_down:.2f}")
    plt.axis("off")
    plt.imshow(test_img_noisy[0])

    plt.subplot(count, 3, 2 + 3 * row)
    plt.title(f"Denoised\nPSNR: {psnr:.2f} db\nMAE: {mae:.5f}\nSSIM: {ssim:.2f}")
    plt.axis("off")
    plt.imshow(predict[0])

    plt.subplot(count, 3, 3 + 3 * row)
    plt.title("Ground Truth")
    plt.axis("off")
    plt.imshow(test_img[0])

In [None]:
count = 5

plt.figure(figsize=(15, 7 * count))

start = 40000

# Uses the global SCALE from the config cell (not a local reassignment), so
# evaluation matches the downscale factor the training data was built with.
for n in range(start, start + count):
    row = n - start

    test_path = f"{ffhq_path[:-1]}{n:05d}.png"

    # cv2.imread returns None (rather than raising) for a missing/unreadable file.
    test_bgr = cv2.imread(test_path)
    if test_bgr is None:
        raise FileNotFoundError(f"Could not read test image: {test_path}")

    # BGR -> RGB, add a batch axis, scale to [0, 1]. Works for any image size.
    test_img = test_bgr[np.newaxis, :, :, ::-1] / 255.

    # Low-res input: keep every SCALE-th pixel (as process_image_downscale
    # does), then resize back to full resolution for the model.
    test_img_down_pix = test_img[:, ::SCALE, ::SCALE, :]
    test_img_down = resize(test_img_down_pix, test_img.shape)

    predict = resnet(test_img_down).numpy()

    # Images are floats in [0, 1], so pass data_range explicitly. Depending on
    # the scikit-image version, SSIM on float inputs otherwise either assumes a
    # range of 2 (-1..1), overstating similarity, or refuses to run.
    psnr_down = peak_signal_noise_ratio(test_img, test_img_down.astype('float32'), data_range=1.0)
    mae_down = loss_fn(test_img, test_img_down)
    ssim_down = structural_similarity(test_img[0], test_img_down[0], multichannel=True, data_range=1.0)

    psnr = peak_signal_noise_ratio(test_img, predict, data_range=1.0)
    mae = loss_fn(test_img, predict)
    ssim = structural_similarity(test_img[0], predict[0], multichannel=True, data_range=1.0)

    plt.subplot(count, 3, 1 + 3 * row)
    plt.title(f"Low Res\nPSNR: {psnr_down:.2f} db\nMAE: {mae_down:.5f}\nSSIM: {ssim_down:.2f}")
    plt.axis("off")
    plt.imshow(test_img_down_pix[0])

    plt.subplot(count, 3, 2 + 3 * row)
    plt.title(f"{SCALE}x Upscaled\nPSNR: {psnr:.2f} db\nMAE: {mae:.5f}\nSSIM: {ssim:.2f}")
    plt.axis("off")
    plt.imshow(predict[0])

    plt.subplot(count, 3, 3 + 3 * row)
    plt.title("Ground Truth")
    plt.axis("off")
    plt.imshow(test_img[0])

In [None]:
# Persist the trained model. Keras infers the save format from the filename,
# and the ".h5" suffix selects HDF5.
resnet.save(filepath="Resnet Upscale.h5")