In [None]:
# Uncomment and run the lines below when running in Google Colab.
# !pip install tensorflow==2.4.3
# !git clone https://github.com/jlaihong/image-super-resolution.git
# !mv image-super-resolution/* ./

In [None]:
import os
import glob
import numpy as np
from PIL import Image, ImageOps
import tensorflow as tf

from datasets.div2k.parameters import Div2kParameters
from models.srresnet import build_srresnet
from models.pretrained import pretrained_models
from utils.prediction import get_sr_image
from utils.config import config

In [None]:
dataset_key = "bicubic_x4"

# The DIV2K data lives in "<data_path>/div2k", made absolute. When no
# data_path is configured, "" is used, so the folder resolves relative to
# the current working directory.
data_path = config.get("data_path", "")
div2k_folder = os.path.abspath(
    os.path.join(data_path, "div2k")
)

dataset_parameters = Div2kParameters(
    dataset_key,
    save_data_directory=div2k_folder,
)

In [None]:
def load_image(path):
    """Load an image file as an RGB numpy array for the super-resolution model.

    Every non-RGB mode (grayscale, grayscale + alpha, palette, RGBA, CMYK,
    YCbCr, ...) is converted to RGB, so the returned array always has three
    colour channels.

    Args:
        path: Path of the image file to open.

    Returns:
        A ``(was_grayscale, pixels)`` tuple. ``was_grayscale`` is True when the
        source image stored a single luminance channel (optionally with alpha),
        so the caller can turn the super-resolved output back into grayscale.
        ``pixels`` is ``np.array`` of the RGB-converted image.
    """
    # Modes whose colour information is a single luminance channel. Palette
    # images ("P"/"PA") also have one colour band, but that band indexes a
    # colour palette, so they must NOT be treated as grayscale.
    grayscale_modes = {"1", "L", "LA", "La", "I", "F"}

    # The context manager closes the underlying file handle once the pixel
    # data has been copied into the numpy array.
    with Image.open(path) as img:
        mode = img.mode
        # "I;16", "I;16B", "I;16L", ... are 16-bit single-channel variants.
        was_grayscale = mode in grayscale_modes or mode.startswith("I;16")

        # Convert anything that is not already RGB (previously only 1- and
        # 4-band images were converted, leaving e.g. "LA" or "YCbCr" intact).
        if mode != "RGB":
            img = img.convert("RGB")

        return was_grayscale, np.array(img)


In [None]:
# Which pretrained generator to use: "srresnet" (default) or "srgan".
model_name = "srresnet"

In [None]:
# Key used both for the local weights folder and the pretrained-model lookup.
model_key = "{}_{}".format(model_name, dataset_key)

In [None]:
# Local location of the generator weights for the selected model/dataset.
weights_directory = os.path.abspath(os.path.join("weights", model_key))

file_path = os.path.join(weights_directory, "generator.h5")

# Only download when the weights are not already present locally.
if not os.path.exists(file_path):
    os.makedirs(weights_directory, exist_ok=True)

    print("Couldn't find file: ", file_path, ", attempting to download a pretrained model")

    if model_key not in pretrained_models:
        # `.keys()` (the original `.key()` does not exist on a mapping and
        # raised AttributeError instead of printing this message).
        available_models = list(pretrained_models.keys())
        print(f"Couldn't find pretrained model with key: {model_key}, available pretrained models: {available_models}")
    else:
        download_url = pretrained_models[model_key]
        # basename is separator-aware, unlike split("/").
        file_name = os.path.basename(file_path)
        # cache_subdir is an absolute path, so the download is expected to
        # land at file_path, which is where the weights are loaded from.
        tf.keras.utils.get_file(file_name, download_url, cache_subdir=weights_directory)

In [None]:
# build_srresnet is used regardless of model_name; the "srgan" weights are
# presumably for the same generator architecture — confirm if adding models.
model = build_srresnet(scale=dataset_parameters.scale)

# Same path the download step checks and writes to (the directory is created
# there, so no makedirs is needed here).
weights_file = os.path.join(weights_directory, "generator.h5")

# Fail with a clear message instead of an opaque h5 open error.
# FileNotFoundError is an OSError, matching what a failed open raised before.
if not os.path.isfile(weights_file):
    raise FileNotFoundError(
        f"No generator weights found at {weights_file}; run the download cell "
        f"above or place the weights file there manually."
    )

model.load_weights(weights_file)

In [None]:
# Output directory (with trailing slash) for this model's super-resolved images.
results_path = "output/{}/".format(model_key)

In [None]:
# Create the output directory; exist_ok keeps re-runs of this cell harmless.
os.makedirs(name=results_path, exist_ok=True)

In [None]:
# Every regular file in input/, in a deterministic (sorted) order; glob's own
# order is unspecified, and sub-directories would crash Image.open.
image_paths = sorted(
    path for path in glob.glob("input/*") if os.path.isfile(path)
)

for image_path in image_paths:
    print(image_path)
    was_grayscale, lr = load_image(image_path)

    sr = get_sr_image(model, lr)

    # Grayscale inputs were converted to RGB by load_image before inference;
    # convert the result back so the output matches the input's colour mode.
    if was_grayscale:
        sr = ImageOps.grayscale(sr)

    # os.path.basename/join respect the platform separator; split("/") broke
    # on Windows, where glob returns "input\\name.png".
    image_name = os.path.basename(image_path)
    sr.save(os.path.join(results_path, image_name))

In [None]:
# Bundle everything under output/ into images.zip so the results can be
# downloaded from Colab as a single file (requires the `zip` CLI).

!zip -r images.zip output