In [None]:
import tensorflow as tf
import numpy as np


def round(x):
    """Render a scalar tensor as a string with exactly two decimal places.

    ``x`` must expose ``.numpy()`` (e.g. an eager TensorFlow tensor). The
    value is rounded with ``np.round`` before being %-formatted.

    NOTE: this definition shadows the built-in ``round`` for the rest of the
    notebook.
    """
    as_numpy = x.numpy()
    two_places = np.round(as_numpy, 2)
    return "%.2f" % two_places


def SSIM(a, b):
    """Return the SSIM between images ``a`` and ``b`` as a 2-decimal string.

    ``max_val`` is fixed at 255, i.e. both images are expected to use the
    0-255 pixel range.
    """
    score = tf.image.ssim(a, b, max_val=255)
    return round(score)


def PSNR(a, b):
    """Return the PSNR between images ``a`` and ``b`` as a 2-decimal string.

    ``max_val`` is fixed at 255, i.e. both images are expected to use the
    0-255 pixel range.
    """
    score = tf.image.psnr(a, b, max_val=255)
    return round(score)

In [None]:
# Get validation images

# NOTE(review): the star import presumably supplies `losses` and `upscale`
# used in later cells — TODO import those names explicitly.
from dlsr import *
from dlsr.data import DIV2K

image_size = 400  # crop size passed to the loader; also the target size of the interpolation cell
N_IMAGES = 7  # number of validation images to evaluate

loader = DIV2K(type="valid")
ds = loader.dataset(
    batch_size=1, random_transform=False, crop_images=True, image_size=image_size
)

# `lr` and `hr` are iterated several times below (upscaling, resizing, saving,
# metrics), and each pass re-runs the input pipeline. If the loader crops or
# shuffles at random, separate passes would yield different, mismatched
# images. Cache the first N_IMAGES (lr, hr) pairs once so that every pass sees
# the same pairs. The cache is filled on the first *complete* iteration.
pairs = ds.take(N_IMAGES).cache()
hr = pairs.map(lambda lr_img, hr_img: hr_img)
lr = pairs.map(lambda lr_img, hr_img: lr_img)

upscaled = {}

In [None]:
# Upscale with my SRGAN
#
# The generator's custom objects are built from the restored discriminator,
# so the discriminator has to be loaded first.
discriminator = tf.keras.models.load_model(
    "./saved-models/srgan/discriminator.h5",
    custom_objects=losses.get_custom_objects(),
)

generator_custom_objects = losses.get_custom_objects(discriminator)
model = tf.keras.models.load_model(
    "./saved-models/srgan/generator.h5", custom_objects=generator_custom_objects
)

upscaled["srgan"] = upscale(model, lr)

In [None]:
# Upscale with interpolation-based methods

# Strip the batch dimension (batch_size=1) and cast every LR image to uint8.
lr_transformed = [np.uint8(x)[0] for x in lr]

new_size = [image_size, image_size]

algs = [
    tf.image.ResizeMethod.NEAREST_NEIGHBOR,
    tf.image.ResizeMethod.BILINEAR,
    tf.image.ResizeMethod.BICUBIC,
    tf.image.ResizeMethod.LANCZOS3,
]

for alg in algs:
    resized = tf.image.resize(lr_transformed, new_size, method=alg)
    # Except for NEAREST_NEIGHBOR, tf.image.resize returns float32, and the
    # bicubic/Lanczos kernels can overshoot [0, 255]. A later bare uint8 cast
    # would wrap those values (e.g. -2.0 -> 254), so round and clip to valid
    # pixel values here.
    upscaled[alg] = np.clip(
        np.rint(np.asarray(resized, dtype=np.float32)), 0, 255
    ).astype(np.uint8)

In [None]:
# Save the LR, HR and upscaled images, then print SSIM/PSNR of every upscaled
# image against its HR reference.


def to_pixels(image):
    """Convert an image (tensor or array) to uint8 pixel values in [0, 255].

    Values are rounded and clipped first. A bare ``np.uint8`` cast truncates,
    and it wraps out-of-range values (or leaves them undefined), e.g. the
    overshoot of bicubic/Lanczos resizing or of a model's output.
    """
    return np.clip(np.rint(np.asarray(image, dtype=np.float32)), 0, 255).astype(
        np.uint8
    )


def save_pngs(directory, images):
    """Save each image of ``images`` as ``{directory}/image{i}.png``.

    The directory is created if missing. ``scale=False`` writes the pixel
    values unchanged. The default ``scale=True`` min-max stretches every image
    independently, so the saved files would not match the pixels the metrics
    are computed on.
    """
    tf.io.gfile.makedirs(directory)  # save_img does not create missing directories
    for i, image in enumerate(images):
        tf.keras.utils.save_img(
            path=f"{directory}/image{i}.png",
            x=to_pixels(image),
            file_format="png",
            scale=False,
        )


# save LR / HR images (drop the batch dimension of size 1)
save_pngs("output/lr", (x[0] for x in lr))
save_pngs("output/hr", (x[0] for x in hr))

# save upscaled images
for key, images in upscaled.items():
    save_pngs(f"output/{key}", images)

# calculate SSIM/PSNR
ROW = "| {:<4} | {:<10} | {:<4} = {:<6} |"
SEPARATOR = "-------------------------------------"
for i, x in enumerate(hr):
    real = to_pixels(x)[0]
    print(SEPARATOR)
    for j, key in enumerate(upscaled.keys()):
        fake = to_pixels(upscaled[key][i])
        label = f"img{i}" if j == 0 else ""  # only label the first row of each image
        print(ROW.format(label, f"{key}", "SSIM", SSIM(real, fake)))
        print(ROW.format("", "", "PSNR", PSNR(real, fake)))
print(SEPARATOR)  # close the table