In [None]:
from tensorflow.keras import mixed_precision

# Enable Keras mixed precision for every layer built after this cell.
# create_model() explicitly overrides the dtype of its output layer to float32.
policy = mixed_precision.Policy('mixed_float16')
mixed_precision.set_global_policy(policy)


In [None]:
import numpy as np
import cv2
import tensorflow as tf
import keras
from keras.utils import img_to_array, Sequence
import os
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error
from math import floor
import random
from datetime import datetime
from skimage.metrics import peak_signal_noise_ratio, structural_similarity


In [None]:
# Report the TensorFlow version and the physical devices it can see, one per line.
print(tf.__version__, tf.config.list_physical_devices(), sep="\n")


In [None]:
# Tiles are padded on every side (ADDED_PADDING px for high-res, LOW_RES_PADDING px
# for low-res), which turns 252x252 into 256x256 and 126x126 into 128x128 (2^n sizes).
# NOTE(review): 2016 = 8 * 252, but 992 is not a multiple of 126 (7 * 126 = 882, so the
# last low-res tile in each row/column is only 110 px). Low-res tiles therefore do not
# cover exactly the same image region as their high-res counterparts — confirm whether
# LOW_RES_CHUNK_SIZE should be LOW_RES_IMAGE_SIZE // 8 = 124 instead.

# Expected edge length (pixels) of the square source images; asserted in prepare_images.
HIGH_RES_IMAGE_SIZE = 2016
LOW_RES_IMAGE_SIZE = 992

# Edge length (pixels) of the tiles cut out of each image, before padding.
HIGH_RES_CHUNK_SIZE = 252
LOW_RES_CHUNK_SIZE = 126

# Directories that hold the training images.
high_res_path = "../input/high_res"
low_res_path = "../input/low_res"

# Reflective border (pixels per side) added to high-res and low-res tiles respectively.
ADDED_PADDING = 2
LOW_RES_PADDING = 1
# Edge length of every tile fed to / produced by the network (252 + 2 * 2 = 256).
NETWORK_IMAGE_SIZE = HIGH_RES_CHUNK_SIZE + ADDED_PADDING * 2


def _load_square_rgb(image_file, expected_size):
    """Read ``image_file`` as a float32 RGB array scaled to [0, 1].

    Raises:
        OSError: if OpenCV cannot read or decode the file (``cv2.imread``
            signals that by returning ``None`` instead of raising).
        AssertionError: if the image is not ``expected_size`` x ``expected_size``.
    """
    image = cv2.imread(image_file)
    if image is None:
        raise OSError(f"cv2.imread could not read image: {image_file}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = image.astype("float32") / 255.0
    height, width, _ = image.shape
    assert height == width
    assert height == expected_size
    assert width == expected_size
    return img_to_array(image)


def _iter_tiles(np_image, chunk_size):
    """Yield ``chunk_size``-stepped tiles of ``np_image`` in row-major order.

    Tiles on the right/bottom edge are smaller when the image size is not a
    multiple of ``chunk_size``.
    """
    height, width = np_image.shape[:2]
    for y in range(0, height, chunk_size):
        for x in range(0, width, chunk_size):
            yield np_image[y: y + chunk_size, x: x + chunk_size]


def prepare_images(high_res_filenames, low_res_paths):
    """Load image files and cut them into padded network-sized tiles.

    Args:
        high_res_filenames: file names inside ``high_res_path``.
        low_res_paths: file names inside ``low_res_path``.

    Returns:
        ``(high_res_tiles, low_res_tiles)``: two lists of float32 RGB tiles.
        High-res tiles are cut at HIGH_RES_CHUNK_SIZE and reflect-padded by
        ADDED_PADDING. Low-res tiles are cut at LOW_RES_CHUNK_SIZE,
        reflect-padded by LOW_RES_PADDING and then resized to
        NETWORK_IMAGE_SIZE x NETWORK_IMAGE_SIZE.

    Raises:
        OSError: if any image cannot be read.
        AssertionError: if an image does not have the expected square size.
    """
    high_res_tiles = []
    for high_res_filename in high_res_filenames:
        np_image = _load_square_rgb(
            f"{high_res_path}/{high_res_filename}", HIGH_RES_IMAGE_SIZE)
        for tile in _iter_tiles(np_image, HIGH_RES_CHUNK_SIZE):
            high_res_tiles.append(cv2.copyMakeBorder(
                tile, ADDED_PADDING, ADDED_PADDING, ADDED_PADDING, ADDED_PADDING,
                cv2.BORDER_REFLECT))

    low_res_tiles = []
    for low_res_filename in low_res_paths:
        np_image = _load_square_rgb(
            f"{low_res_path}/{low_res_filename}", LOW_RES_IMAGE_SIZE)
        for tile in _iter_tiles(np_image, LOW_RES_CHUNK_SIZE):
            tile = cv2.copyMakeBorder(
                tile, LOW_RES_PADDING, LOW_RES_PADDING, LOW_RES_PADDING, LOW_RES_PADDING,
                cv2.BORDER_REFLECT)
            # Upscale so the network input has the same size as its target.
            tile = cv2.resize(
                tile, (NETWORK_IMAGE_SIZE, NETWORK_IMAGE_SIZE))
            low_res_tiles.append(img_to_array(tile))

    return high_res_tiles, low_res_tiles


In [None]:
# Fractions of the data set reserved for validation and for testing.
VALIDATION_IMAGES = 0.2
TEST_IMAGES = 0.1

# os.listdir() returns entries in arbitrary order. Both listings are sorted so that
# index i of one list corresponds to index i of the other.
# NOTE(review): this assumes paired files share a name (or at least sort identically)
# in both directories — confirm against the data set.
high_res_paths = sorted(os.listdir(high_res_path))
low_res_paths = sorted(os.listdir(low_res_path))

assert len(high_res_paths) == len(low_res_paths), (
    f"{len(high_res_paths)} high-res vs {len(low_res_paths)} low-res images"
)

data_size = len(high_res_paths)

# Explicit split points instead of negative slicing: `paths[:-0]` is empty and
# `paths[-0:]` is the whole list, which breaks the split when a count rounds to 0.
validation_count = floor(VALIDATION_IMAGES * data_size)
test_count = floor(TEST_IMAGES * data_size)
train_end = data_size - validation_count - test_count
validation_end = data_size - test_count

high_res_training = high_res_paths[:train_end]
low_res_training = low_res_paths[:train_end]
high_res_validation = high_res_paths[train_end:validation_end]
low_res_validation = low_res_paths[train_end:validation_end]
high_res_test = high_res_paths[validation_end:]
low_res_test = low_res_paths[validation_end:]

len(high_res_training), len(high_res_validation), len(high_res_test)


In [None]:
class DataGenerator(Sequence):
    """Serve (low-res, high-res) tile batches while caching only a window of images.

    Tiles are laid out image after image, so global tile ``t`` belongs to image
    ``t // items_per_image``. Only ``pre_shape_size`` consecutive images are
    decoded at a time; requesting a batch whose first tile lies outside that
    window reloads the window starting at that batch's image. Sequential access
    is therefore cheap, random access reloads on almost every batch.

    Args:
        x_paths: low-res (network input) file names.
        y_paths: high-res (target) file names, index-aligned with ``x_paths``.
        items_per_image: number of tiles produced per image.
        pre_shape_size: number of images decoded per cached window.
        batch_size: number of tiles per batch.
    """

    def __init__(self, x_paths, y_paths, items_per_image, pre_shape_size, batch_size):
        super().__init__()
        self.x_paths, self.y_paths = x_paths, y_paths
        self.pre_shape_size, self.batch_size = pre_shape_size, batch_size
        self.items_per_image = items_per_image
        # Load the first window eagerly so missing/invalid files fail at construction.
        self._reshape(0)

    def __len__(self):
        """Number of batches needed to cover every tile of every image."""
        return int(
            np.ceil(len(self.x_paths) * self.items_per_image /
                    float(self.batch_size))
        )

    def shuffle(self):
        """Shuffle both path lists in place with one shared permutation.

        The cached window is marked stale because it was built from the old
        order; it is reloaded on the next batch request.

        Raises:
            ValueError: if the two path lists differ in length, since a shared
                permutation cannot keep them paired.
        """
        if len(self.x_paths) != len(self.y_paths):
            raise ValueError("x_paths and y_paths must have the same length")
        order = list(range(len(self.x_paths)))
        # Uses the global RNG as-is; no reseeding, so other users of `random`
        # are not disturbed and the permutation is not limited to a few seeds.
        random.shuffle(order)
        self.x_paths[:] = [self.x_paths[i] for i in order]
        self.y_paths[:] = [self.y_paths[i] for i in order]
        self._stale = True

    def _paths_index(self, index):
        """Index of the image that contains the first tile of batch ``index``."""
        return floor(int(index * self.batch_size) / self.items_per_image)

    def _reshape(self, new_index):
        """Reload the cached window so it starts at the image of batch ``new_index``."""
        self.shape_index = new_index
        self.paths_index = self._paths_index(new_index)
        # Global index of the first tile held in the cache.
        self._window_first_tile = int(self.paths_index * self.items_per_image)

        # Release the previous window before decoding the next one.
        self.x_pre_shaped = None
        self.y_pre_shaped = None

        window = slice(self.paths_index, self.paths_index + self.pre_shape_size)
        high_res_tiles, low_res_tiles = prepare_images(
            self.y_paths[window],
            self.x_paths[window],
        )

        new_len = len(high_res_tiles)
        tile_shape = (new_len, NETWORK_IMAGE_SIZE, NETWORK_IMAGE_SIZE, 3)
        self.x_pre_shaped = np.reshape(low_res_tiles, tile_shape)
        self.y_pre_shaped = np.reshape(high_res_tiles, tile_shape)
        self._stale = False

    def __getitem__(self, index):
        """Return batch ``index`` as ``(low_res_batch, high_res_batch)``.

        A slice of batch indices returns two flat lists with the tiles of all
        selected batches concatenated, again as ``(low_res, high_res)``.
        """
        if isinstance(index, slice):
            results_x, results_y = [], []
            # slice.indices() resolves missing start/stop/step and clamps to len(self).
            for batch_index in range(*index.indices(len(self))):
                x, y = self[batch_index]
                results_x.extend(x)
                results_y.extend(y)
            return results_x, results_y

        paths_index = self._paths_index(index)
        in_window = (
            self.paths_index <= paths_index < self.paths_index + self.pre_shape_size
        )
        if self._stale or not in_window:
            self._reshape(index)

        # Offset relative to the first cached tile, not to the batch index of the
        # last reload: the two differ whenever a reload was not triggered exactly
        # on an image boundary (random access, or batch sizes that do not divide
        # items_per_image).
        first_tile = int(index * self.batch_size) - self._window_first_tile
        batch_x = self.x_pre_shaped[first_tile: first_tile + self.batch_size]
        batch_y = self.y_pre_shaped[first_tile: first_tile + self.batch_size]
        return batch_x, batch_y


In [None]:
# Tiles per training batch, images cached per generator window, tiles per image.
BATCH_SIZE = 4
PRE_SHAPE_SIZE = 16
TILES_PER_IMAGE = (HIGH_RES_IMAGE_SIZE / HIGH_RES_CHUNK_SIZE) ** 2

training_generator = DataGenerator(
    x_paths=low_res_training,
    y_paths=high_res_training,
    items_per_image=TILES_PER_IMAGE,
    pre_shape_size=PRE_SHAPE_SIZE,
    batch_size=BATCH_SIZE,
)
validation_generator = DataGenerator(
    x_paths=low_res_validation,
    y_paths=high_res_validation,
    items_per_image=TILES_PER_IMAGE,
    pre_shape_size=PRE_SHAPE_SIZE,
    batch_size=BATCH_SIZE,
)
# The test generator serves one tile per batch.
test_generator = DataGenerator(
    x_paths=low_res_test,
    y_paths=high_res_test,
    items_per_image=TILES_PER_IMAGE,
    pre_shape_size=PRE_SHAPE_SIZE,
    batch_size=1,
)


def regen_data():
    """Reshuffle the image order of the training, validation and test generators."""
    for generator in (training_generator, validation_generator, test_generator):
        generator.shuffle()


regen_data()


In [None]:
def create_model():
    """Build the convolutional encoder/decoder used for super-resolution.

    The network takes and returns NETWORK_IMAGE_SIZE x NETWORK_IMAGE_SIZE RGB
    tiles. Two conv+max-pool stages downsample, two transposed convolutions
    upsample again, and the encoder activations at matching resolution are
    added back in (skip connections). The output layer is forced to float32.

    Returns:
        An uncompiled ``tf.keras.models.Model``.
    """
    layers = tf.keras.layers

    def shared_options():
        # Options common to every (transposed) convolution; each layer gets its
        # own regularizer instance.
        return dict(
            padding="same",
            kernel_initializer="he_uniform",
            activation="relu",
            activity_regularizer=tf.keras.regularizers.l1(10e-10),
        )

    def conv(filters, tensor, **extra):
        # 3x3 convolution applied to `tensor`.
        return layers.Conv2D(filters, (3, 3), **shared_options(), **extra)(tensor)

    def upsample(filters, tensor):
        # 3x3 transposed convolution with stride 2 (doubles height and width).
        return layers.Conv2DTranspose(
            filters, (3, 3), strides=(2, 2), **shared_options())(tensor)

    network_input = layers.Input(
        shape=(NETWORK_IMAGE_SIZE, NETWORK_IMAGE_SIZE, 3))

    # Encoder, full resolution.
    encoder_full = conv(64, network_input)
    encoder_full = conv(64, encoder_full)
    pooled_half = layers.MaxPool2D(padding="same")(encoder_full)

    # Encoder, half resolution.
    encoder_half = conv(128, pooled_half)
    encoder_half = conv(128, encoder_half)
    pooled_quarter = layers.MaxPool2D(padding="same")(encoder_half)

    # Bottleneck, quarter resolution.
    bottleneck = conv(256, pooled_quarter)

    # Decoder, back to half resolution, merged with the encoder output.
    decoder_half = upsample(256, bottleneck)
    decoder_half = conv(128, decoder_half)
    decoder_half = conv(128, decoder_half)
    merged_half = layers.add([decoder_half, encoder_half])

    # Decoder, back to full resolution, merged with the encoder output.
    decoder_full = upsample(128, merged_half)
    decoder_full = conv(64, decoder_full)
    decoder_full = conv(64, decoder_full)
    merged_full = layers.add([decoder_full, encoder_full])

    # Final RGB projection, computed in float32 regardless of the global policy.
    decoded_image = conv(3, merged_full, dtype='float32')

    return tf.keras.models.Model(inputs=(network_input), outputs=decoded_image)


# Draw the architecture of a freshly built (untrained) model, with tensor shapes
# shown and layer names hidden.
tf.keras.utils.plot_model(
    create_model(), show_shapes=True, show_layer_names=False)


In [None]:
class Callbacks(keras.callbacks.Callback):
    """Keras callback that reshuffles the data generators before every epoch."""

    def on_epoch_begin(self, epoch, logs=None):
        # `epoch` and `logs` are part of the Keras callback signature and unused here.
        # regen_data() shuffles the training, validation and test generators.
        regen_data()


## MSE loss model


In [None]:
# Train a fresh network with pixel-wise mean squared error as the loss.
mse_model = create_model()

mse_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0005),
    loss="mean_squared_error",
    # NOTE(review): "acc" is a classification-style metric; its value is probably
    # not meaningful for image regression — confirm it is wanted.
    metrics=["acc"],
)

# 3 rounds of 2 epochs; a timestamped checkpoint is written after every round.
for _ in range(3):
    mse_model.fit(
        training_generator,
        # Keep batches in order: DataGenerator caches a window of consecutive images.
        shuffle=False,
        epochs=2,
        validation_data=validation_generator,
        # Reshuffles the image order at the start of each epoch (see Callbacks).
        callbacks=[Callbacks()]
    )

    # Milliseconds since the Unix epoch, used to make the checkpoint name unique.
    timestamp = int(datetime.timestamp(datetime.now()) * 1000)
    mse_model.save(f'mse-checkpoint-{timestamp}.h5')


In [None]:
# Optional continuation: trains the already compiled `mse_model` for another
# 3 rounds of 2 epochs, again writing a timestamped checkpoint after each round.
# This repeats the training loop of the previous cell and needs `mse_model` from it.
for _ in range(3):
    mse_model.fit(
        training_generator,
        # Keep batches in order: DataGenerator caches a window of consecutive images.
        shuffle=False,
        epochs=2,
        validation_data=validation_generator,
        # Reshuffles the image order at the start of each epoch (see Callbacks).
        callbacks=[Callbacks()]
    )

    # Milliseconds since the Unix epoch, used to make the checkpoint name unique.
    timestamp = int(datetime.timestamp(datetime.now()) * 1000)
    mse_model.save(f'mse-checkpoint-{timestamp}.h5')


In [None]:
# To evaluate a previously saved checkpoint instead of the model trained above,
# uncomment the load_model call below and drop the assignment that follows it.
# mse_best = tf.keras.models.load_model(
#     "mse-checkpoint-1655412386260.h5"
# )

# Model used by the evaluation cells further down.
mse_best = mse_model


## SSIM loss


In [None]:
def SSIMLoss(y_true, y_pred):
    """Loss equal to one minus the batch-mean SSIM of target and prediction.

    SSIM is computed with ``tf.image.ssim`` using a maximum pixel value of 1.0.
    """
    ssim_per_image = tf.image.ssim(y_true, y_pred, 1.0)
    return 1 - tf.reduce_mean(ssim_per_image)


In [None]:
# Second network with the same architecture, trained with the SSIM-based loss.
ssim_model = create_model()

ssim_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss=SSIMLoss,
    # NOTE(review): "acc" is a classification-style metric; its value is probably
    # not meaningful for image regression — confirm it is wanted.
    metrics=["acc"],
)


In [None]:
# 3 rounds of 2 epochs; a timestamped checkpoint is written after every round.
for _ in range(3):
    ssim_history = ssim_model.fit(
        training_generator,
        # Keras shuffles the batch order of a Sequence unless told otherwise.
        # DataGenerator caches a window of consecutive images, so batches must be
        # requested in order (same setting as the MSE training loop). The image
        # order is still reshuffled every epoch by Callbacks.
        shuffle=False,
        epochs=2,
        validation_data=validation_generator,
        callbacks=[Callbacks()])

    # Milliseconds since the Unix epoch, used to make the checkpoint name unique.
    timestamp = int(datetime.timestamp(datetime.now()) * 1000)
    ssim_model.save(f'ssim-checkpoint-{timestamp}.h5')


In [None]:
# To evaluate a previously saved checkpoint instead of the model trained above,
# uncomment the load_model call below and drop the assignment that follows it.
# The custom loss has to be passed via custom_objects when loading.
# ssim_best = tf.keras.models.load_model(
#     "ssim-checkpoint-1654698942576.h5", custom_objects={"SSIMLoss": SSIMLoss, "ssim_loss": SSIMLoss}
# )

# Model used by the evaluation cells further down.
ssim_best = ssim_model


## Perceptual loss


In [None]:
from lpipstf.lpips_tf import get_lpips, load_lpips

# Loaded once here; the result is passed to get_lpips() when the perceptual-loss
# model is compiled.
graph_file = load_lpips()


In [None]:
# Third network with the same architecture, trained with the LPIPS-based loss.
pl_model = create_model()

pl_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.01),
    loss=get_lpips(graph_file),
    # NOTE(review): "acc" is a classification-style metric; its value is probably
    # not meaningful for image regression — confirm it is wanted.
    metrics=["acc"],
)

pl_history = pl_model.fit(
    training_generator,
    # Keras shuffles the batch order of a Sequence unless told otherwise.
    # DataGenerator caches a window of consecutive images, so batches must be
    # requested in order (same setting as the MSE training loop).
    shuffle=False,
    epochs=5,
    validation_data=validation_generator,
)


In [None]:
# Save the perceptual-loss model under a name made unique by the current time
# in milliseconds since the Unix epoch.
timestamp = int(datetime.timestamp(datetime.now()) * 1000)
pl_model.save(f'pl-checkpoint-{timestamp}.h5')


In [None]:
# Continue training the perceptual-loss model for two more epochs, this time
# reshuffling the image order at the start of each epoch.
pl_history_2 = pl_model.fit(
    training_generator,
    # Keras shuffles the batch order of a Sequence unless told otherwise.
    # DataGenerator caches a window of consecutive images, so batches must be
    # requested in order (same setting as the MSE training loop).
    shuffle=False,
    epochs=2,
    validation_data=validation_generator,
    callbacks=[Callbacks()],
)


In [None]:
def plot_images(high, low, predicted):
    """Show the high-res, low-res and predicted images side by side in one figure."""
    panels = (
        ("High Image", "green", high),
        ("Low Image ", "black", low),
        ("Predicted Image ", "Red", predicted),
    )

    plt.figure(figsize=(50, 50))
    for position, (caption, caption_color, picture) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.title(caption, color=caption_color, fontsize=20)
        plt.imshow(picture)
    plt.show()


In [None]:
def reconstruct_my_image(index, pictures, tilings, model):
    """Rebuild one full image from generator tiles and run the model on it.

    Args:
        index: position of the image in ``tilings``.
        pictures: tile source that supports slicing by tile index and returns
            ``(low_res_tiles, high_res_tiles)``, as ``DataGenerator`` does.
            The slice is in batch indices, so this only lines up with tiles
            when the generator's batch size is 1.
        tilings: one ``((rows, cols), [(tile_height, tile_width), ...])`` entry
            per image, tiles in row-major order.
        model: Keras model used to predict high-res tiles from low-res tiles.

    Returns:
        ``(low_res_image, high_res_image, predicted_image)``. This is the order
        the function has always delivered in practice and the order the
        evaluation cell unpacks (``lr, hr, r``).
    """
    (vertical_segments, horizontal_segments), _ = tilings[index]
    start_index = sum(v * h for (v, h), _ in tilings[:index])
    end_index = start_index + vertical_segments * horizontal_segments

    def join_images(tiles, local_tiling):
        # Stitch row-major tiles back into one image and crop the fill border.
        (rows, cols), resolutions = local_tiling

        row_strips = []
        for row in range(rows):
            strip = []
            for col in range(cols):
                tile = tiles[col + row * cols]
                local_height, local_width = resolutions[col + row * cols]
                # Drop the reflective padding added when the tile was cut.
                tile = tile[ADDED_PADDING:-ADDED_PADDING,
                            ADDED_PADDING:-ADDED_PADDING]
                tile = cv2.resize(tile, (local_width, local_height))
                # Fill up to full chunk size so all tiles can be concatenated.
                tile = cv2.copyMakeBorder(tile, 0, HIGH_RES_CHUNK_SIZE - local_height,
                                          0, HIGH_RES_CHUNK_SIZE - local_width, cv2.BORDER_CONSTANT)
                strip.append(tile)
            row_strips.append(np.concatenate(strip, axis=1))
        result = np.concatenate(row_strips, axis=0)

        # Height: first tile of every row. Width: every tile of the first row
        # (i.e. `cols` tiles — the row count must not be used here).
        total_height = sum(resolutions[row * cols][0] for row in range(rows))
        total_width = sum(resolutions[col][1] for col in range(cols))

        return result[:total_height, :total_width]

    # The generator yields (x, y) = (low-res, high-res), in that order.
    lowres, highres = pictures[start_index:end_index]

    reconstructed_low_res = join_images(
        lowres, tilings[index]
    )
    reconstructed_high_res = join_images(
        highres, tilings[index]
    )
    tile_count = end_index - start_index
    batch_shape = (tile_count, NETWORK_IMAGE_SIZE, NETWORK_IMAGE_SIZE, 3)
    # Predict from the low-res tiles, never from the ground truth.
    predicted_images = np.clip(
        model.predict(np.reshape(lowres, batch_shape)),
        0.0,
        1.0,
    ).reshape(batch_shape)
    reconstructed_predicted = join_images(
        predicted_images, tilings[index]
    )

    print(reconstructed_low_res.shape, reconstructed_high_res.shape,
          reconstructed_predicted.shape)

    return reconstructed_low_res, reconstructed_high_res, reconstructed_predicted


In [None]:
def save_images(hr, lr, r, prefix=""):
    """Write the given RGB images as PNG files below ``../final-out``.

    Each non-``None`` image is min-max normalised to 0..255 (uint8), converted
    from RGB to BGR and saved as ``{prefix}-hr.png``, ``{prefix}-lr.png`` or
    ``{prefix}-r.png`` respectively. ``None`` arguments are skipped.
    """
    for suffix, picture in (("hr", hr), ("lr", lr), ("r", r)):
        if picture is None:
            continue
        as_uint8 = cv2.normalize(picture, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
        as_bgr = cv2.cvtColor(as_uint8, cv2.COLOR_RGB2BGR)
        cv2.imwrite(f"../final-out/{prefix}-{suffix}.png", as_bgr)


In [None]:
models = [mse_best, ssim_best]

# Evaluate at most this many held-out images, and never more than exist.
MAX_TEST_IMAGES = 100
num_test_images = min(MAX_TEST_IMAGES, len(high_res_test))

# Every training-set image is cut into the same regular grid of full-size tiles.
# Built once instead of once per evaluated image.
tiles_per_side = HIGH_RES_IMAGE_SIZE // HIGH_RES_CHUNK_SIZE
test_tilings = [
    (
        (tiles_per_side, tiles_per_side),
        [(HIGH_RES_CHUNK_SIZE, HIGH_RES_CHUNK_SIZE)] * tiles_per_side ** 2,
    )
    for _ in range(num_test_images)
]

# One (mean PSNR, mean SSIM) tuple per model, in the order of `models`.
test_results = []
for model in models:
    psnr = []
    ssim = []
    for i in range(num_test_images):
        lr, hr, r = reconstruct_my_image(i, test_generator, test_tilings, model)
        # Images are floats in [0, 1]; state the range explicitly so that SSIM
        # does not fall back to the float dtype range.
        # NOTE(review): `multichannel` is kept to match the installed scikit-image;
        # newer releases use `channel_axis=-1` instead — confirm the version.
        psnr.append(peak_signal_noise_ratio(
            hr, r, data_range=1.0))
        ssim.append(structural_similarity(
            hr, r, multichannel=True, data_range=1.0))

    test_results.append((sum(psnr)/len(psnr), sum(ssim)/len(ssim)))


In [None]:
from math import ceil


def _read_test_image(image_file):
    """Read ``image_file`` as a float32 RGB array scaled to [0, 1].

    Raises:
        OSError: if OpenCV cannot read or decode the file (``cv2.imread``
            signals that by returning ``None`` instead of raising).
    """
    image = cv2.imread(image_file)
    if image is None:
        raise OSError(f"cv2.imread could not read image: {image_file}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = image.astype("float32") / 255.0
    return img_to_array(image)


def prepare_test_images(hr_files, lr_files, path):
    """Tile benchmark images of arbitrary size into padded network-sized tiles.

    Edge tiles smaller than the chunk size are stretched to
    HIGH_RES_CHUNK_SIZE; their original size is recorded so the image can be
    reassembled later.

    Args:
        hr_files: high-res file names inside ``path``.
        lr_files: low-res file names inside ``path``.
        path: directory that contains both sets of files.

    Returns:
        ``(high_res_tiles, low_res_tiles, tilings)`` where ``tilings`` holds one
        ``((rows, cols), [(tile_height, tile_width), ...])`` entry per high-res
        image, tiles in row-major order.

    Raises:
        OSError: if any image cannot be read.
    """
    tilings = []
    high_res_tiles = []

    print('Preparing high res...')
    for high_res_filename in hr_files:
        np_image = _read_test_image(f"{path}/{high_res_filename}")
        height, width, _ = np_image.shape
        local_size = (ceil(height / HIGH_RES_CHUNK_SIZE),
                      ceil(width / HIGH_RES_CHUNK_SIZE))
        local_tilings = []
        for y in range(0, height, HIGH_RES_CHUNK_SIZE):
            for x in range(0, width, HIGH_RES_CHUNK_SIZE):
                tile = np_image[
                    y: y + HIGH_RES_CHUNK_SIZE, x: x + HIGH_RES_CHUNK_SIZE
                ]
                # Remember the true tile size before stretching it.
                local_height, local_width, _ = tile.shape
                local_tilings.append((local_height, local_width))
                tile = cv2.resize(
                    tile, (HIGH_RES_CHUNK_SIZE, HIGH_RES_CHUNK_SIZE))
                tile = cv2.copyMakeBorder(
                    tile, ADDED_PADDING, ADDED_PADDING, ADDED_PADDING, ADDED_PADDING, cv2.BORDER_REFLECT)
                high_res_tiles.append(tile)
        tilings.append((local_size, local_tilings))

    print('Preparing low res...')
    low_res_tiles = []
    for low_res_filename in lr_files:
        np_image = _read_test_image(f"{path}/{low_res_filename}")
        height, width, _ = np_image.shape
        for y in range(0, height, LOW_RES_CHUNK_SIZE):
            for x in range(0, width, LOW_RES_CHUNK_SIZE):
                tile = np_image[y: y + LOW_RES_CHUNK_SIZE,
                                x: x + LOW_RES_CHUNK_SIZE]
                # Upscale to the high-res chunk size, then pad like the targets.
                tile = cv2.resize(
                    tile, (HIGH_RES_CHUNK_SIZE, HIGH_RES_CHUNK_SIZE))
                tile = cv2.copyMakeBorder(
                    tile, ADDED_PADDING, ADDED_PADDING, ADDED_PADDING, ADDED_PADDING, cv2.BORDER_REFLECT)
                low_res_tiles.append(img_to_array(tile))

    return high_res_tiles, low_res_tiles, tilings


In [None]:
def reconstruct_test_image(index, high_resolution_pictures, low_resolution_pictures, tilings, model):
    """Rebuild one benchmark image from its tiles and run the model on it.

    Args:
        index: position of the image in ``tilings``.
        high_resolution_pictures: flat list of high-res tiles of all images.
        low_resolution_pictures: flat list of low-res tiles of all images, in
            the same order as the high-res tiles.
        tilings: one ``((rows, cols), [(tile_height, tile_width), ...])`` entry
            per image, tiles in row-major order.
        model: Keras model used to predict high-res tiles from low-res tiles.

    Returns:
        ``(high_res_image, low_res_image, predicted_image)``.
    """
    (vertical_segments, horizontal_segments), _ = tilings[index]
    # Tiles of all images are stored back to back; skip those of earlier images.
    start_index = sum(v * h for (v, h), _ in tilings[:index])
    end_index = start_index + vertical_segments * horizontal_segments

    def join_images(tiles, local_tiling):
        # Stitch row-major tiles back into one image and crop the fill border.
        (rows, cols), resolutions = local_tiling

        row_strips = []
        for row in range(rows):
            strip = []
            for col in range(cols):
                tile = tiles[col + row * cols]
                local_height, local_width = resolutions[col + row * cols]
                # Drop the reflective padding added when the tile was cut.
                tile = tile[ADDED_PADDING:-ADDED_PADDING,
                            ADDED_PADDING:-ADDED_PADDING]
                # Undo the stretch applied to undersized edge tiles.
                tile = cv2.resize(tile, (local_width, local_height))
                # Fill up to full chunk size so all tiles can be concatenated.
                tile = cv2.copyMakeBorder(tile, 0, HIGH_RES_CHUNK_SIZE - local_height,
                                          0, HIGH_RES_CHUNK_SIZE - local_width, cv2.BORDER_CONSTANT)
                strip.append(tile)
            row_strips.append(np.concatenate(strip, axis=1))
        result = np.concatenate(row_strips, axis=0)

        # Height: first tile of every row. Width: every tile of the first row.
        # The width must be summed over `cols` tiles; using the row count crops
        # non-square tilings to the wrong width.
        total_height = sum(resolutions[row * cols][0] for row in range(rows))
        total_width = sum(resolutions[col][1] for col in range(cols))

        return result[:total_height, :total_width]

    high_res_tiles = high_resolution_pictures[start_index:end_index]
    low_res_tiles = low_resolution_pictures[start_index:end_index]

    reconstructed_high_res = join_images(
        high_res_tiles, tilings[index]
    )
    reconstructed_low_res = join_images(
        low_res_tiles, tilings[index]
    )
    tile_count = end_index - start_index
    batch_shape = (tile_count, NETWORK_IMAGE_SIZE, NETWORK_IMAGE_SIZE, 3)
    predicted_images = np.clip(
        model.predict(np.reshape(low_res_tiles, batch_shape)),
        0.0,
        1.0,
    ).reshape(batch_shape)
    reconstructed_predicted = join_images(
        predicted_images, tilings[index]
    )

    return reconstructed_high_res, reconstructed_low_res, reconstructed_predicted


In [None]:
import re

# This expects files to have *_HR or *_LR endings to find the correct versions


def create_image_pairs(image_paths):
    """Split file names into index-aligned high-res and low-res lists.

    A file belongs to a pair when its whole name is ``<stem>HR.png`` or
    ``<stem>LR.png``. Files are paired by their common stem, so a stem that
    has only one of the two variants is dropped instead of shifting every
    following pair. Other file names are ignored.

    Args:
        image_paths: iterable of file names.

    Returns:
        ``(high_res, low_res)``: two lists of equal length, ordered by the
        high-res file name, where ``high_res[i]`` and ``low_res[i]`` share a stem.
    """
    name_pattern = re.compile(r'(.*)(HR|LR)\.png')

    variants_by_stem = {}
    for path in image_paths:
        # fullmatch: the name must end in HR.png / LR.png, not merely contain it.
        matched = name_pattern.fullmatch(path)
        if matched is None:
            continue
        stem, variant = matched.groups()
        variants_by_stem.setdefault(stem, {})[variant] = path

    complete_pairs = sorted(
        (variants["HR"], variants["LR"])
        for variants in variants_by_stem.values()
        if len(variants) == 2
    )
    high_res = [high for high, _ in complete_pairs]
    low_res = [low for _, low in complete_pairs]

    return high_res, low_res


In [None]:

# Directories of the benchmark data sets.
bsd100_path = '../dataset-tests/BSD100'
set14_path = '../dataset-tests/Set14'
set5_path = '../dataset-tests/Set5'

bsd100_high, bsd100_low = create_image_pairs(os.listdir(bsd100_path))
set14_high, set14_low = create_image_pairs(os.listdir(set14_path))
set5_high, set5_low = create_image_pairs(os.listdir(set5_path))

# One (image count, (high-res tiles, low-res tiles, tilings)) entry per data set,
# in the order BSD100, Set14, Set5. Each tuple must be followed by a comma:
# without it Python tries to call the first tuple with the second as arguments.
sets = [(len(bsd100_high), prepare_test_images(bsd100_high, bsd100_low, bsd100_path)),
        (len(set14_high), prepare_test_images(set14_high, set14_low, set14_path)),
        (len(set5_high), prepare_test_images(set5_high, set5_low, set5_path))]


In [None]:
# Reconstruct, show and save one Set5 image with the MSE-trained model.
image_num = 2
# sets[2] is the Set5 entry; its second element is the
# (high-res tiles, low-res tiles, tilings) triple expected by reconstruct_test_image.
reconstructed_high_res, reconstructed_low_res, reconstructed_predicted = reconstruct_test_image(
    image_num, *sets[2][1], mse_best)

plot_images(
    reconstructed_high_res,
    reconstructed_low_res,
    reconstructed_predicted,
)
save_images(None, None, reconstructed_predicted, f'set5-{image_num}-mse')


In [None]:
models = [mse_best, ssim_best]


# results[m][d] is the (mean PSNR, mean SSIM) of model m on data set d, with
# models in the order of `models` and data sets in the order of `sets`.
results = []
for model in models:
    model_results = []
    for dataset in sets:
        set_size, (set_high, set_low, set_tilings) = dataset
        psnr = []
        ssim = []
        for i in range(set_size):
            reconstructed_high_res, reconstructed_low_res, reconstructed_predicted = reconstruct_test_image(
                i, set_high, set_low, set_tilings, model)
            plot_images(reconstructed_high_res,
                        reconstructed_low_res, reconstructed_predicted)
            # Images are floats in [0, 1]; state the range explicitly so that
            # SSIM does not fall back to the float dtype range.
            # NOTE(review): `multichannel` is kept to match the installed
            # scikit-image; newer releases use `channel_axis=-1` — confirm.
            image_psnr = peak_signal_noise_ratio(
                reconstructed_high_res, reconstructed_predicted, data_range=1.0)
            image_ssim = structural_similarity(
                reconstructed_high_res, reconstructed_predicted,
                multichannel=True, data_range=1.0)
            # Each metric is computed once and reused for printing and averaging.
            print(image_psnr, image_ssim)
            psnr.append(image_psnr)
            ssim.append(image_ssim)

        model_results.append((sum(psnr)/len(psnr), sum(ssim)/len(ssim)))
    results.append(model_results)


In [None]:
# Display the benchmark table: one list per model, one (mean PSNR, mean SSIM) per data set.
results
