### Objective: run the super-resolution model on a test image and inspect the resulting metrics (PSNR and SSIM)

## Imports

In [None]:
## Imports

from src.dataset.PatchImageTool import PatchImageTool
from src.utils.PytorchUtil import PytorchUtil as torchUtil
from src.utils.PlotUtils import PlotUtils

import os
import numpy as np
import torch
import matplotlib.pyplot as plt
import skimage.metrics as metrics
import platform  # Import the platform module to detect the OS
import cv2
import math

import platform
import os

# Choose the torch device used by the rest of the notebook.
# NOTE: force_cpu only decides between CUDA and CPU; on Windows the DirectML
# device is selected regardless of this flag.
force_cpu = True

if platform.system() == 'Windows':
    # Imported lazily so that other platforms do not need the package.
    import torch_directml
    device = torch_directml.device()
elif torch.cuda.is_available() and not force_cpu:
    device = torch.device('cuda')
    torch.cuda.empty_cache()
else:
    device = torch.device('cpu')

print(device)

## Define the model

In [None]:
from src.models.InitModel import InitModel

# --- Upscaling geometry ------------------------------------------------------
# Factor between the low-resolution input and the high-resolution output.
UPSCALE_FACTOR = 2

# Edge length of a high-resolution patch and of the matching low-resolution patch.
PATCH_SIZE = 256
PATCH_RESIZE_SIZE = PATCH_SIZE // UPSCALE_FACTOR

# --- Model / run options -----------------------------------------------------
LEARNING_RATE = 0.0001
BATCH_SIZE = 256
# True: stitch the output from per-patch predictions; False: one pass on the whole image.
USE_PREDICTION_BY_PATCH = False

IMAGE_DIM = "1920x1080"

# Channel names given to the model, and the index associated with each known channel name.
CHANNELS = list("bgr")
CHANNELS_POSITION = {name: index for index, name in enumerate("bgrds")}

# Optional RNG seed; None leaves torch/numpy unseeded.
SEED = None

# File name of the test image, looked up inside the resources folder.
IMAGE = "example.png"

# Weights location; the model name is simply the same string.
PATH = "results/weights-upscale-residual-lpips-v.2"
NAME = PATH


In [None]:
# Build the experiment wrapper; `exp.net` is the network called for inference
# in the cells below.
model_hyperparameters = {"learningRate": LEARNING_RATE, "channels": CHANNELS}
exp = InitModel.create_model_static(NAME, PATH, model_hyperparameters, UPSCALE_FACTOR, device)

# Seed torch and numpy when a seed is configured. The explicit None check
# (instead of a truthiness test) keeps 0 usable as a seed.
if SEED is not None:
    torch.manual_seed(SEED)
    np.random.seed(SEED)

## Load image and test

In [None]:
import torchvision.transforms as transforms

# Preprocessing applied to each numpy image before it reaches the model.
# Only the tensor conversion is active; the random horizontal/vertical flips
# that used to be listed here are disabled.
transform_steps = [transforms.ToTensor()]
image_transform = transforms.Compose(transform_steps)

In [None]:
# Load the high-resolution (HR) reference image and derive the low-resolution
# (LR) model input from it.
resources_path = "resources"
hr_data_path = os.path.join(resources_path, IMAGE)

hr_data_np = torchUtil.open_data(hr_data_path)

# (width, height) of the HR image.
# NOTE(review): assumes open_data returns a height-first array - confirm.
hr_img_size = (hr_data_np.shape[1], hr_data_np.shape[0])

# Convert to a tensor on the target device, then keep only the colour channels.
hr_data_tensor = image_transform(hr_data_np).to(device)
hr_img_tensor = torchUtil.filter_data(hr_data_tensor, {"b", "g", "r"}, CHANNELS_POSITION)
hr_img_np = torchUtil.tensor_to_numpy(hr_img_tensor)

# Shrink the HR data by UPSCALE_FACTOR along both axes to get the LR input.
resized_data_np = torchUtil.resize_data(
    hr_data_np,
    (hr_img_size[0] // UPSCALE_FACTOR, hr_img_size[1] // UPSCALE_FACTOR),
    CHANNELS,
    CHANNELS_POSITION,
)
resized_data_tensor = image_transform(resized_data_np).to(device)
resized_img_tensor = torchUtil.filter_data(resized_data_tensor, {"b", "g", "r"}, CHANNELS_POSITION)

resized_img_np = torchUtil.tensor_to_numpy(resized_img_tensor)
# Stored as (width, height), the same convention as hr_img_size. The patch-grid
# cell divides index 0 by the patch size to get the number of patch columns;
# the previous (shape[0], shape[1]) order swapped rows and columns for
# non-square images.
# NOTE(review): assumes tensor_to_numpy returns a height-first array, as the
# channel_axis=2 metric calls on its outputs suggest - confirm.
resized_img_size = (resized_img_np.shape[1], resized_img_np.shape[0])

print(resized_img_tensor.shape, hr_img_tensor.shape)

PlotUtils.show_high_low_res_images([resized_img_tensor], [hr_img_tensor], upscale_factor_list=[UPSCALE_FACTOR], plot_title="Example")

In [None]:
# Dimensions of the patch grid laid over the low-resolution image: index 0 of
# resized_img_size is used for the column count, index 1 for the row count.
num_patch_width, num_patch_height = (
    math.ceil(side / PATCH_RESIZE_SIZE) for side in resized_img_size
)

num_patch_total = num_patch_height * num_patch_width

# Cut the full low-resolution data tensor into patches.
resized_data_patch_tensors = PatchImageTool.get_patchs_from_image(
    resized_data_tensor, patch_size=PATCH_RESIZE_SIZE
)

# Colour-only version of every patch, used for display.
resized_img_patch_tensors = [
    torchUtil.filter_data(data_patch, {"b", "g", "r"}, CHANNELS_POSITION)
    for data_patch in resized_data_patch_tensors
]

In [None]:
# Preview at most a 5x5 corner of the patch grid.
num_cols = min(5, num_patch_width)
num_rows = min(5, num_patch_height)

# squeeze=False keeps `ax` two-dimensional even when the grid has a single row
# or column; without it `ax[i, j]` raises for such grids.
fig, ax = plt.subplots(num_rows, num_cols, figsize=(3 * num_cols, 3 * num_rows), squeeze=False)
for i in range(num_rows):
    for j in range(num_cols):
        # Position in the flat patch list.
        # NOTE(review): assumes the patches are stored row by row - confirm.
        patch_number = i * num_patch_width + j

        ax[i, j].imshow(torchUtil.tensor_to_image(resized_img_patch_tensors[patch_number]))
        ax[i, j].set_title(f"Patch {patch_number}")
        ax[i, j].axis('off')

In [None]:
# Run the model on a single low-resolution patch and compare the result with
# a plain resize of the same patch.
patch_index = 0

resized_data_patch_tensor = resized_data_patch_tensors[patch_index].to(device)
resized_img_patch_tensor = resized_img_patch_tensors[patch_index]
print("Patch tensor size", resized_data_patch_tensor.shape, "type", resized_data_patch_tensor.dtype)

# Inference only: no autograd graph is needed for the forward pass.
with torch.no_grad():
    pred_img_tensors = exp.net(resized_data_patch_tensor)

# Drop the leading dimension of the network output.
pred_img_tensor = pred_img_tensors.squeeze(0)
pred_img_np = torchUtil.tensor_to_numpy(pred_img_tensor)

print("Prediction tensor size", pred_img_np.shape, "type", pred_img_np.dtype)

# Baseline: the low-resolution patch resized to the full patch size.
bicubic_img_np = torchUtil.resize_tensor_to_numpy(resized_img_patch_tensor, (PATCH_SIZE, PATCH_SIZE))

# Normalized difference between the prediction and the baseline.
# NOTE(review): the full-image cell subtracts in the opposite order
# (bicubic - prediction) - confirm which one is intended.
subtraction_img_np = torchUtil.norm_numpy_image(pred_img_np - bicubic_img_np)

fig, ax = plt.subplots(1, 3, figsize=(10, 10))
ax[0].imshow(torchUtil.tensor_to_image(resized_img_patch_tensor))
ax[0].set_title("Low res patch")

ax[1].imshow(torchUtil.numpy_to_image(pred_img_np))
ax[1].set_title("Prediction")

ax[2].imshow(torchUtil.numpy_to_image(subtraction_img_np))
ax[2].set_title("Subtraction")

plt.show()

In [None]:
# Upscale the whole image, either by stitching per-patch predictions or with a
# single forward pass over the full low-resolution tensor.
# The forward pass runs under torch.no_grad(): this is inference only, so no
# autograd graph should be built for the full-size image.
with torch.no_grad():
    if USE_PREDICTION_BY_PATCH:
        pred_img_torch = PatchImageTool.predict_image_from_image_patches(
            exp, hr_img_size, resized_data_patch_tensors,
            device,
            patch_size=PATCH_RESIZE_SIZE, upscale_factor=UPSCALE_FACTOR)
        print("Used prediction by patch")
    else:
        # Add a leading batch dimension for the network and drop it again afterwards.
        pred_img_torch = exp.net(resized_data_tensor.unsqueeze(0)).squeeze(0)
        print("Used prediction by image")

image_to_show = torchUtil.numpy_to_image(hr_img_np)

pred_img_np = torchUtil.tensor_to_numpy(pred_img_torch)
pred_img = torchUtil.numpy_to_image(pred_img_np)

# Baseline: the low-resolution image resized to the HR size. hr_img_size is
# (width, height), so the target is passed here as (height, width).
bicubic_img_np = torchUtil.resize_tensor_to_numpy(resized_img_tensor, (hr_img_size[1], hr_img_size[0]))
bicubic_img = torchUtil.numpy_to_image(bicubic_img_np)

# Normalized difference between the baseline and the prediction.
substract_img_np = torchUtil.norm_numpy_image(bicubic_img_np - pred_img_np)
substract_img = torchUtil.numpy_to_image(substract_img_np)

In [None]:
# Side-by-side view: resized baseline, model output and their normalized difference.
# Each panel is (image, title, extra imshow keyword arguments).
panels = (
    (bicubic_img, f"Bicubic image {bicubic_img.shape}", {}),
    (pred_img, f"Upscaled image {pred_img.shape}", {}),
    (
        substract_img,
        f"Substracted image {substract_img.shape}",
        {"vmin": substract_img_np.min(), "vmax": substract_img_np.max()},
    ),
)

fig, ax = plt.subplots(1, 3, figsize=(15, 15))
for axis, (picture, title, imshow_options) in zip(ax, panels):
    axis.imshow(picture, **imshow_options)
    axis.set_title(title)

plt.show()

In [None]:
# Compute PSNR and SSIM against the high-resolution reference, first for the
# model output and then for the resized baseline.
# NOTE(review): data_range=1 assumes the arrays hold values in [0, 1] - confirm.
# Only channel_axis is passed to select the colour axis; the deprecated
# `multichannel` flag that used to accompany it is dropped.
psnr = metrics.peak_signal_noise_ratio(hr_img_np, pred_img_np, data_range=1)
ssim = metrics.structural_similarity(hr_img_np, pred_img_np,
                                     win_size=7, data_range=1, channel_axis=2)

print(f"Model PSNR: {psnr} SSIM: {ssim}")

bicubic_psnr = metrics.peak_signal_noise_ratio(hr_img_np, bicubic_img_np, data_range=1)
bicubic_ssim = metrics.structural_similarity(hr_img_np, bicubic_img_np,
                                             win_size=7, data_range=1, channel_axis=2)

print(f"Bicubic PSNR: {bicubic_psnr} SSIM: {bicubic_ssim}")

In [None]:
# Save the result images, replacing whatever a previous run left in the folder.

output_path = "results/examples/"

# Create the folder if needed; a single call covers both the missing and the
# already-existing case.
os.makedirs(output_path, exist_ok=True)

# Empty the output folder (regular files and symlinks only; sub-directories
# are left untouched). A failed deletion is reported and skipped.
for filename in os.listdir(output_path):
    file_path = os.path.join(output_path, filename)
    try:
        if os.path.isfile(file_path) or os.path.islink(file_path):
            os.unlink(file_path)
    except OSError as e:
        print('Failed to delete %s. Reason: %s' % (file_path, e))


def _to_uint8(image_np):
    """Scale an image by 255 and convert it to uint8.

    Values are clipped to [0, 255] before the cast so that out-of-range
    pixels saturate instead of wrapping around.
    """
    return np.clip(image_np * 255.0, 0, 255).astype(np.uint8)


# Save the original image
cv2.imwrite(os.path.join(output_path, "original.png"), _to_uint8(hr_img_np))
# Save the upscaled image
cv2.imwrite(os.path.join(output_path, "upscaled.png"), _to_uint8(pred_img_np))
# Save the bicubic image
cv2.imwrite(os.path.join(output_path, "bicubic.png"), _to_uint8(bicubic_img_np))
# Save the substracted image
# NOTE(review): unlike the other three, this one goes through numpy_to_image
# before saving - confirm that is intended.
cv2.imwrite(os.path.join(output_path, "substracted.png"), _to_uint8(torchUtil.numpy_to_image(substract_img_np)))
