In [None]:
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import PIL
import PIL.Image  # `import PIL` alone does not load the Image submodule used below
import tensorflow as tf
import yaml
from tensorflow import keras

from utils.constants import HR_IMG_SIZE, DOWNSAMPLE_MODE
from utils.dataset import DIV2K_Dataset
from utils.model import create_model

In [None]:
# Experiment settings (e.g. "batch_size", used below) are read from config.yaml.
with open("config.yaml", "r", encoding="utf-8") as stream:
    config = yaml.safe_load(stream)
# safe_load returns None for an empty file; fail here rather than at config["batch_size"].
if not isinstance(config, dict):
    raise ValueError("config.yaml must contain a YAML mapping of settings.")

# TODO: set MODEL_PATH to the saved Keras model to evaluate before running the notebook.
MODEL_PATH = ""
if not MODEL_PATH:
    raise ValueError("MODEL_PATH is empty: set it to the path of a saved Keras model.")
model = keras.models.load_model(MODEL_PATH)

In [None]:
# Build the "test" split from the DIV2K HR image folder, using the batch size
# configured in config.yaml.
test_dataset = DIV2K_Dataset(set_type="test",
                             batch_size=config["batch_size"],
                             hr_image_folder="data/DIV2K_train_valid_HR/")

In [None]:
# Mean PSNR of the model's predictions against the HR targets over the test split.
# The split is traversed N_EVAL_PASSES times. This only changes the estimate if
# DIV2K_Dataset yields different samples on each pass (e.g. random crops).
# NOTE(review): confirm; for a deterministic dataset a single pass gives the same mean.
N_EVAL_PASSES = 5
psnrs = []

for _ in range(N_EVAL_PASSES):
    for batch in test_dataset:
        lr_batch, hr_batch = batch[0], batch[1]
        # verbose=0: otherwise Keras prints one progress bar per batch per pass.
        sr_batch = model.predict(lr_batch, verbose=0)
        # tf.image.psnr returns one value per image; max_val=1.0 assumes pixels in [0, 1].
        batch_psnr = tf.image.psnr(hr_batch, sr_batch, max_val=1.0)
        psnrs.extend(batch_psnr.numpy().tolist())

# np.mean([]) would silently print "nan"; surface the real problem instead.
if not psnrs:
    raise RuntimeError("test_dataset produced no batches; cannot compute PSNR.")

print("Mean PSNR: {:.3f}".format(np.mean(psnrs)))

In [None]:
# Pick one test batch to visualise and run the model on its LR inputs (batch[0]).
batch_id = 1
batch = test_dataset[batch_id]
# verbose=0 suppresses the per-call progress bar in the notebook output.
preds = model.predict(batch[0], verbose=0)

In [None]:
img_id = 19  # index of the image within the selected batch; choose any valid index

n_images = len(batch[0])
if not -n_images <= img_id < n_images:
    raise IndexError(f"img_id={img_id} is out of range for a batch of {n_images} images")

lr_img = np.asarray(batch[0][img_id])
hr_img = np.asarray(batch[1][img_id])
# The network output is not guaranteed to stay in [0, 1]; clip for display
# (imshow would clip anyway, but with a warning).
sr_img = np.clip(np.asarray(preds[img_id]), 0.0, 1.0)

# Classical-interpolation baseline: upscale the LR input with PIL.
# Round and clip before the uint8 cast. A bare cast truncates (biasing dark) and
# wraps values outside [0, 255] around.
lr_uint8 = np.clip(np.rint(lr_img * 255.0), 0, 255).astype(np.uint8)
# NOTE(review): PIL's resize expects (width, height); confirm HR_IMG_SIZE uses that
# order, and that DOWNSAMPLE_MODE is bilinear so the panel title below is accurate.
baseline_img = PIL.Image.fromarray(lr_uint8).resize(HR_IMG_SIZE, resample=DOWNSAMPLE_MODE)

panels = [
    (lr_img, "LR Image"),
    (hr_img, "HR Image"),
    (sr_img, "Restored Image"),
    (baseline_img, "Bilinear Upsampling"),
]

fig, axes = plt.subplots(2, 2, figsize=(6, 6))
for ax, (image, title) in zip(axes.flat, panels):
    ax.imshow(image)
    ax.set_title(title)
    ax.axis("off")

fig.tight_layout()
plt.show()