# Super-Resolution metrics calculation
---

# 1. **Install dependencies**

## 1.1. **Install key dependencies**

In [None]:
#@markdown ## Install NumPy 1.26.0
#@markdown After running this cell, your session will be automatically restarted for the changes to take effect.

# Make sure the pinned NumPy version is installed. Whenever an installation is
# actually performed, the Python process is terminated right afterwards so the
# runtime restarts with the new version; if the pinned version is already
# present nothing is installed and the process keeps running.
import importlib.metadata

desired_version = "1.26.0"

try:
    # Read the version from the installed package metadata.
    installed_version = importlib.metadata.version("numpy")
    if installed_version == desired_version:
        # Nothing to do: no install, no restart.
        print(f"NumPy {desired_version} is already installed.")
    else:
        print(f"Installing NumPy {desired_version} (current: {installed_version})...")
        # `!` runs pip in a shell; {desired_version} is interpolated by IPython.
        !pip install numpy=={desired_version} --prefer-binary
        import os
        os._exit(00)  # Restart runtime for changes to take effect
except importlib.metadata.PackageNotFoundError:
    # NumPy is not installed at all: install the pinned version and terminate
    # the process in the same way as the branch above.
    print(f"NumPy is not installed. Installing {desired_version}...")
    !pip install numpy=={desired_version} --prefer-binary
    import os
    os._exit(00)

In [None]:
#@markdown ## Install the rest of packages
#@markdown After running this cell, your session will be automatically restarted for the changes to take effect.

# Remove any existing PyTorch packages and purge pip's cache before installing
# the pinned torch / torchvision / torchaudio versions.
!pip uninstall -y torch torchvision torchaudio
!pip cache purge
!pip install -q torch==2.0.0 torchvision==0.15.1 torchaudio==2.0.1

# Remaining dependencies, each pinned to an exact version (-q keeps pip quiet).
!pip install -q nbformat==5.10.4
!pip install -q plotly==6.0.1
!pip install -q torchmetrics[image]==1.4.0
!pip install -q torch-fidelity==0.3.0
!pip install -q nanopyx==1.1.0
!pip install -q pandas==2.2.2
!pip install -q matplotlib==3.8.0
!pip install -q opencv-python==4.8.0.76

# NOTE(review): this message is printed unconditionally; the exit status of
# the pip commands above is not checked.
print("Everything correctly installed, your session will be automatically restarted to activate the changes.")

# Terminate the Python process immediately so the session restarts with the
# freshly installed packages.
import os
os._exit(00)

## 1.2. **Load key dependencies**

In [None]:
#@markdown ##Load key dependencies

import os
# Store the current working directory (an absolute path) as the base path
base_path = os.getcwd()

# Load PyTorch
import torch

# Load the torchmetrics functions
from torchmetrics.regression import MeanSquaredError, MeanAbsoluteError
from torchmetrics.image import PeakSignalNoiseRatio
from torchmetrics.image import StructuralSimilarityIndexMeasure
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity

# Load the NanoPyx functions
from nanopyx.core.analysis.decorr import DecorrAnalysis
from nanopyx.core.transform import ErrorMap

# Load the functions to calculate IL-NIQE
from scipy.ndimage.filters import convolve
from scipy.signal import convolve2d
from scipy.special import gamma
from scipy.stats import exponweib
from scipy.optimize import fmin
import scipy.io

# Load other packages
from matplotlib import pyplot as plt
from skimage import io
import urllib.request
from tqdm import tqdm
import pandas as pd
import numpy as np
import tempfile
import cv2
import math
import os

# Load the packages for final plot
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Define the main functions
def print_img_info(img, description="", is_dir=False):
    """Print a short summary (shape, dtype, value range, mean, std) of image data.

    Args:
        img (list | np.ndarray | torch.Tensor): Either a list of images
            (typically used when the images do not share a shape), a single
            NumPy array, or a single PyTorch tensor.
        description (str): Heading printed before the summary.
        is_dir (bool): Only affects the wording for arrays/tensors: True
            prints "Stack of ... images", False prints "Single ... image".

    Raises:
        ValueError: If ``img`` is none of the supported container types.
    """
    print(f"{description}")

    if isinstance(img, list):
        print("Stack of images with different shape.")

        # Collect whole shape tuples / dtype names. ``np.unique`` on a list of
        # shape tuples would flatten them into individual dimension values
        # (and raises for tuples of different length on NumPy >= 1.24).
        img_shapes = sorted({tuple(image.shape) for image in img})
        img_types = sorted({str(image.dtype) for image in img})
        img_min = [float(image.min()) for image in img]
        img_max = [float(image.max()) for image in img]
        img_means = [float(image.mean()) for image in img]
        img_stds = [float(image.std()) for image in img]

        print(f"\tShapes: {img_shapes}")
        print(f"\tTypes: {img_types}")
        print(f"\tMinimum values: {np.mean(img_min)} ± {np.std(img_min)}")
        print(f"\tMaximum values: {np.mean(img_max)} ± {np.std(img_max)}")
        print(f"\tMean values: {np.mean(img_means)} ± {np.std(img_means)}")
        print(f"\tStandard deviations: {np.mean(img_stds)} ± {np.std(img_stds)}")

    elif isinstance(img, (np.ndarray, torch.Tensor)):
        is_tensor = isinstance(img, torch.Tensor)
        container = "PyTorch Tensor" if is_tensor else "NumPy"

        if is_dir:
            print(f"Stack of {container} images with the same shape.")
        else:
            print(f"Single {container} image.")

        # torch only implements mean()/std() for floating-point (and complex)
        # tensors, so integer/bool tensors are converted for the statistics.
        stats_source = img
        if is_tensor and not (img.is_floating_point() or img.is_complex()):
            stats_source = img.double()

        print(f"\tShape: {img.shape}")
        print(f"\tType: {img.dtype}")
        print(f"\tRange of values:[{img.min()},{img.max()}]")
        print(f"\tMean±Std:{stats_source.mean():.3f}±{stats_source.std():.3f}")
    else:
        raise ValueError("Image must be a list, a numpy array or a torch tensor.")

    print("------")

def load_images(path, normalize="None"):
    """Load one image file, or every image file inside a folder.

    Args:
        path (str): Path to an image file or to a folder of image files.
        normalize (str): "None" (no normalisation), "Per image" (min-max
            normalise each image on its own) or "Across all images" (min-max
            normalise with the global minimum/maximum of the folder). For a
            single file any value other than "None" normalises the image.

    Returns:
        tuple: ``(images, filenames)`` for a folder, where ``images`` is a
        NumPy array when all images share a shape and a plain list otherwise,
        and ``filenames`` is the sorted list of loaded file names. For a
        single file, ``(image, path)``.

    Raises:
        ValueError: If ``path`` is a folder and ``normalize`` is not one of
            the three supported options (previously this silently returned
            file names without any images).
    """
    if os.path.isdir(path):
        if normalize not in ("None", "Per image", "Across all images"):
            raise ValueError(
                f"Unknown normalize option {normalize!r}. Supported options are "
                "'None', 'Per image' and 'Across all images'.")

        filename_list = []
        images_list = []

        for filename in sorted(os.listdir(path)):
            file_path = os.path.join(path, filename)

            # Skip sub-directories (e.g. ".ipynb_checkpoints"); they cannot be
            # read as images and would otherwise abort the whole load.
            if not os.path.isfile(file_path):
                continue

            image = io.imread(file_path)
            if normalize == "Per image":
                image = min_max_norm_numpy(image)

            filename_list.append(filename)
            images_list.append(image)

        try:
            images_list = np.array(images_list)
        except ValueError:
            # NumPy raises ValueError when the images cannot be stacked into a
            # regular array; keep them as a list in that case.
            print("WARNING: The images in the folder are not of the same size.")

        if normalize == "Across all images":
            images_list = min_max_norm(images_list)

        return images_list, filename_list

    image = io.imread(path)
    if normalize != "None":
        image = min_max_norm_numpy(image)

    return image, path

def min_max_norm(data):
    """Min-max normalise ``data``, dispatching on its container type.

    NumPy arrays are forwarded to ``min_max_norm_numpy`` and lists to
    ``min_max_norm_list``; anything else raises ``ValueError``.
    """
    if isinstance(data, np.ndarray):
        return min_max_norm_numpy(data)

    if isinstance(data, list):
        return min_max_norm_list(data)

    raise ValueError("Input data must be a list or a numpy array.")

def min_max_norm_numpy(data):
    """Min-max normalise an array to the range [0, 1).

    Args:
        data (np.ndarray): Input values of any numeric dtype.

    Returns:
        np.ndarray: ``(data - min) / (max - min + 1e-10)``. The small constant
        avoids a division by zero for constant inputs.
    """
    # Integer (and bool) arrays are promoted to float first: subtracting the
    # minimum in a signed integer dtype can wrap around (e.g. int8
    # 127 - (-128)), and NumPy refuses to subtract bool arrays at all.
    if isinstance(data, np.ndarray) and data.dtype.kind in "biu":
        data = data.astype(np.float64)

    lowest = data.min()
    return (data - lowest) / (data.max() - lowest + 1e-10)

def min_max_norm_list(data):
    """Min-max normalise a list of images with one shared minimum and maximum.

    Args:
        data (list): Images exposing ``.min()`` / ``.max()`` (e.g. NumPy
            arrays); they may have different shapes.

    Returns:
        list: One normalised image per input image, computed as
        ``(img - global_min) / (global_max - global_min + 1e-10)``. An empty
        input yields an empty list.
    """
    if len(data) == 0:
        return []

    # Convert the extremes to Python floats so that the arithmetic below is
    # carried out in floating point; with integer images the subtraction
    # would otherwise wrap around for signed dtypes.
    min_val = float(min(img.min() for img in data))
    max_val = float(max(img.max() for img in data))
    value_range = max_val - min_val + 1e-10

    return [(img - min_val) / value_range for img in data]

#####

def get_size_from_scale(input_size, scale_factor):
    """Compute the output size obtained by scaling each input dimension.

    Args:
        input_size (tuple): Size of the input image, one entry per dimension.
        scale_factor (sequence of float): Resize factor for each dimension.
            It is paired element-wise with ``input_size``; surplus entries of
            the longer sequence are ignored.
    Returns:
        list[int]: Scaled size of each paired dimension, rounded up.
    """

    output_shape = []
    for factor, length in zip(scale_factor, input_size):
        output_shape.append(int(np.ceil(factor * length)))

    return output_shape


def get_scale_from_size(input_size, output_size):
    """Compute the per-dimension scale factor between two sizes.

    Args:
        input_size (tuple(int)): Size of the input image.
        output_size (tuple(int)): Size of the output image.
    Returns:
        list[float]: ``output / input`` for each pair of dimensions; surplus
        entries of the longer sequence are ignored.
    """

    scale = []
    for source_len, target_len in zip(input_size, output_size):
        scale.append(1.0 * target_len / source_len)

    return scale


def _cubic(x):
    """Piecewise cubic interpolation kernel evaluated in float32.

    Args:
        x (ndarray): The distance from the center position.
    Returns:
        ndarray: The weight corresponding to each distance; zero for
        distances larger than 2.
    """

    dist = np.abs(np.array(x, dtype=np.float32))
    dist_sq = dist**2
    dist_cu = dist_sq * dist

    # Masks selecting the two non-zero pieces of the kernel.
    near = dist <= 1
    far = (1 < dist) & (dist <= 2)

    # |x| <= 1:      1.5|x|^3 - 2.5|x|^2 + 1
    inner_poly = 1.5 * dist_cu - 2.5 * dist_sq + 1
    # 1 < |x| <= 2: -0.5|x|^3 + 2.5|x|^2 - 4|x| + 2
    outer_poly = -0.5 * dist_cu + 2.5 * dist_sq - 4 * dist + 2

    return inner_poly * near + outer_poly * far


def get_weights_indices(input_length, output_length, scale, kernel,
                        kernel_width):
    """Compute the interpolation weights and source indices for one axis.

    Args:
        input_length (int): Length of the input sequence.
        output_length (int): Length of the output sequence.
        scale (float): Scale factor.
        kernel (func): The kernel used for resizing.
        kernel_width (int): The width of the kernel.
    Returns:
        list[ndarray]: The weights and the indices for interpolation. Both
        arrays have shape (output_length, 1, taps), where ``taps`` is the
        number of kernel columns that carry a non-zero weight.
    """
    if scale < 1:
        # Downscaling: widen the kernel by 1/scale (antialiasing).
        def h(x):
            return scale * kernel(scale * x)

        support = 1.0 * kernel_width / scale
    else:
        h = kernel
        support = kernel_width

    # 1-based output coordinates and the input coordinates they map to.
    out_coords = np.arange(1, output_length + 1).astype(np.float32)
    in_coords = out_coords / scale + 0.5 * (1 - 1 / scale)

    first_tap = np.floor(in_coords - support / 2)  # leftmost contributing pixel
    num_taps = int(np.ceil(support)) + 2  # maximum number of contributing pixels

    # Candidate input pixels of every output pixel and their kernel weights,
    # normalised so that each row sums to one.
    indices = (first_tap[:, np.newaxis, ...] + np.arange(num_taps)).astype(np.int32)
    weights = h(in_coords[:, np.newaxis, ...] - indices - 1)
    weights = weights / np.sum(weights, axis=1)[:, np.newaxis, ...]

    # Out-of-range indices are folded back through a mirrored index table.
    mirror = np.concatenate(
        (np.arange(input_length),
         np.arange(input_length - 1, -1, step=-1))).astype(np.int32)
    indices = mirror[np.mod(indices, mirror.size)]

    # Drop the columns whose weight is zero for every output pixel.
    keep = np.nonzero(np.any(weights, axis=0))

    return weights[:, keep], indices[:, keep]


def resize_along_dim(img_in, weights, indices, dim):
    """Resize along a specific dimension.

    Args:
        img_in (ndarray): The input image. The indexing below assumes a
            3-D (h, w, c) array.
        weights (ndarray): The weights used for interpolation, computed from
            [get_weights_indices]; shape (output_length, 1, taps).
        indices (ndarray): The indices used for interpolation, computed from
            [get_weights_indices]; same shape as ``weights``.
        dim (int): Which dimension to undergo interpolation (0 or 1).
    Returns:
        ndarray: Interpolated (along one dimension) image as float64. The
        input is always converted to float32 before interpolating, so no
        rounding or clipping back to an integer dtype takes place (the
        former uint8 branch was checked after that conversion and could
        never run).
    Raises:
        ValueError: If ``dim`` is neither 0 nor 1 (previously an all-zero
            image was returned silently).
    """
    if dim not in (0, 1):
        raise ValueError(f'dim must be 0 or 1, but got {dim}.')

    img_in = img_in.astype(np.float32)
    output_length = weights.shape[0]
    output_shape = list(img_in.shape)
    output_shape[dim] = output_length
    img_out = np.zeros(output_shape)

    if dim == 0:
        for i in range(output_length):
            w = weights[i, :][np.newaxis, ...]
            ind = indices[i, :]
            # Gather the contributing rows and blend them with their weights.
            img_slice = img_in[ind, :]
            img_out[i] = np.sum(np.squeeze(img_slice, axis=0) * w.T, axis=0)
    else:
        for i in range(output_length):
            w = weights[i, :][:, :, np.newaxis]
            ind = indices[i, :]
            # Gather the contributing columns and blend them with their weights.
            img_slice = img_in[:, ind]
            img_out[:, i] = np.sum(np.squeeze(img_slice, axis=1) * w.T, axis=1)

    return img_out


class MATLABLikeResize:
    """Resize the input image using MATLAB-like downsampling.

    Currently support bicubic interpolation only. Note that the output of
    this function is slightly different from the official MATLAB function.
    Required keys are the keys in attribute "keys". Added or modified keys
    are "scale" and "output_shape", and the keys in attribute "keys".

    Args:
        keys (list[str] | None): A list of keys whose values are modified
            when the instance is called on a results dict. ``None`` means
            that no entry is resized. Default: None.
        scale (float | None, optional): The scale factor of the resize
            operation. If None, it will be determined by output_shape.
            Default: None.
        output_shape (tuple(int) | None, optional): The size of the output
            image. If None, it will be determined by scale. Note that if
            scale is provided, output_shape will not be used.
            Default: None.
        kernel (str, optional): The kernel for the resize operation.
            Currently support 'bicubic' only. Default: 'bicubic'.
        kernel_width (float): The kernel width. Currently support 4.0 only.
            Default: 4.0.

    Raises:
        ValueError: For an unsupported kernel or kernel width, or when both
            ``scale`` and ``output_shape`` are None.
    """

    def __init__(self,
                 keys=None,
                 scale=None,
                 output_shape=None,
                 kernel='bicubic',
                 kernel_width=4.0):

        if kernel.lower() != 'bicubic':
            raise ValueError('Currently support bicubic kernel only.')

        if float(kernel_width) != 4.0:
            raise ValueError('Current support only width=4 only.')

        if scale is None and output_shape is None:
            raise ValueError('"scale" and "output_shape" cannot be both None')

        self.kernel_func = _cubic
        self.keys = keys
        self.scale = scale
        self.output_shape = output_shape
        self.kernel = kernel
        self.kernel_width = kernel_width

    def resize_img(self, img):
        """Resize a single image; see ``_resize``."""
        return self._resize(img)

    def _resize(self, img):
        """Resize the first two axes of ``img`` and return an (h, w, c) array.

        A 2-D input gets a trailing channel axis of length one.
        """
        # compute scale and output_size
        if self.scale is not None:
            scale = float(self.scale)
            scale = [scale, scale]
            output_size = get_size_from_scale(img.shape, scale)
        else:
            # Only the two spatial axes are resized, so any further entries
            # (e.g. a channel entry in output_shape) are ignored.
            scale = get_scale_from_size(img.shape, self.output_shape)[:2]
            output_size = list(self.output_shape)

        # interpolation weights / indices of each of the two axes
        axis_params = [
            get_weights_indices(img.shape[k], output_size[k], scale[k],
                                self.kernel_func, self.kernel_width)
            for k in range(2)
        ]

        output = np.copy(img)
        if output.ndim == 2:  # grayscale image
            output = output[:, :, np.newaxis]

        # apply cubic interpolation along two dimensions, starting with the
        # axis that has the smaller scale factor
        for dim in np.argsort(np.array(scale)):
            weight, index = axis_params[dim]
            output = resize_along_dim(output, weight, index, dim)

        return output

    def __call__(self, results):
        """Resize every entry of ``results`` listed in ``self.keys``.

        An entry may be a single ndarray or a list of ndarrays. The keys
        "scale" and "output_shape" are written into ``results``, which is
        modified in place and returned.
        """
        # ``keys`` defaults to None; treat that as "nothing to resize"
        # instead of failing on the iteration.
        for key in (self.keys or ()):
            is_single_image = False
            if isinstance(results[key], np.ndarray):
                is_single_image = True
                results[key] = [results[key]]

            results[key] = [self._resize(img) for img in results[key]]

            if is_single_image:
                results[key] = results[key][0]

        results['scale'] = self.scale
        results['output_shape'] = self.output_shape

        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += (
            f'(keys={self.keys}, scale={self.scale}, '
            f'output_shape={self.output_shape}, '
            f'kernel={self.kernel}, kernel_width={self.kernel_width})')
        return repr_str


def reorder_image(img, input_order='HWC'):
    """Reorder images to 'HWC' order.

    If the input shape is (h, w), return (h, w, 1);
    If the input_order is (c, h, w), return (h, w, c);
    If the input_order is (h, w, c), return as it is.

    Args:
        img (ndarray): Input image.
        input_order (str): Whether the input order is 'HWC' or 'CHW'.
            If the input image shape is (h, w), input_order will not have
            effects. Default: 'HWC'.
    Returns:
        ndarray: reordered image.
    Raises:
        ValueError: If ``input_order`` is neither 'HWC' nor 'CHW'.
    """

    if input_order not in ['HWC', 'CHW']:
        raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' "'HWC' and 'CHW'")

    if len(img.shape) == 2:
        # A 2-D image carries no channel axis, so only a trailing one is
        # added. Returning here keeps 'CHW' from transposing the freshly
        # expanded (h, w, 1) array into (w, 1, h).
        return img[..., None]

    if input_order == 'CHW':
        img = img.transpose(1, 2, 0)

    return img

def fitweibull(x):
    """Fit Weibull shape and scale parameters to the samples ``x``.

    The negative log-likelihood of ``exponweib`` (first shape parameter fixed
    to 1, location fixed to 0) is minimised with ``fmin``, starting from an
    initial guess derived from the mean and standard deviation of ``log(x)``.

    Args:
        x (ndarray): Samples; the logarithm is taken, so they are presumably
            expected to be strictly positive.
    Returns:
        ndarray: The optimised ``[shape, scale]`` pair.
    """
    log_samples = np.log(x)
    initial_shape = 1.2 / np.std(log_samples)
    initial_scale = np.exp(np.mean(log_samples) + (0.572 / initial_shape))

    def negative_log_likelihood(params):
        density = exponweib.pdf(x, 1, params[0], scale=params[1], loc=0)
        return -np.sum(np.log(density))

    return fmin(negative_log_likelihood, [initial_shape, initial_scale],
                xtol=0.01, ftol=0.01, disp=0)

def _aggd_gamma_table():
    """Return the (gam, r_gam) lookup table used by ``estimate_aggd_param``.

    The table depends on constants only, so it is built on first use and
    stored on the function object instead of being recomputed for every
    image block.
    """
    table = getattr(_aggd_gamma_table, '_table', None)
    if table is None:
        gam = np.arange(0.2, 10.001, 0.001)  # len = 9801
        gam_reciprocal = np.reciprocal(gam)
        r_gam = np.square(gamma(gam_reciprocal * 2)) / (gamma(gam_reciprocal) * gamma(gam_reciprocal * 3))
        table = (gam, r_gam)
        _aggd_gamma_table._table = table
    return table


def estimate_aggd_param(block):
    """Estimate AGGD (Asymmetric Generalized Gaussian Distribution) parameters.

    Args:
        block (ndarray): 2D Image block.
    Returns:
        tuple: alpha (float), beta_l (float) and beta_r (float) for the AGGD
            distribution (Estimating the parames in Equation 7 in the paper).
    """
    block = block.flatten()
    gam, r_gam = _aggd_gamma_table()

    # Root-mean-square of the negative and of the positive samples.
    left_std = np.sqrt(np.mean(block[block < 0]**2))
    right_std = np.sqrt(np.mean(block[block > 0]**2))
    gammahat = left_std / right_std
    rhat = (np.mean(np.abs(block)))**2 / np.mean(block**2)
    rhatnorm = (rhat * (gammahat**3 + 1) * (gammahat + 1)) / ((gammahat**2 + 1)**2)

    # Pick the tabulated alpha whose ratio is closest to the measured one.
    array_position = np.argmin((r_gam - rhatnorm)**2)

    alpha = gam[array_position]
    beta_l = left_std * np.sqrt(gamma(1 / alpha) / gamma(3 / alpha))
    beta_r = right_std * np.sqrt(gamma(1 / alpha) / gamma(3 / alpha))
    return (alpha, beta_l, beta_r)


def compute_feature(feature_list, block_posi):
    """Compute the feature vector of one image block.

    Args:
        feature_list (list): Feature maps to be processed; maps 0 to 108 are
            read.
        block_posi (tuple): Location of the 2D image block as
            (row_start, row_stop, col_start, col_stop).
    Returns:
        list: Features with length of 234.
    """
    row_start, row_stop = block_posi[0], block_posi[1]
    col_start, col_stop = block_posi[2], block_posi[3]

    def crop(map_index):
        # Cut the current block out of one feature map.
        return feature_list[map_index][row_start:row_stop, col_start:col_stop]

    feat = []

    # Map 0: AGGD fit of the block itself ...
    data = crop(0)
    alpha_data, beta_l_data, beta_r_data = estimate_aggd_param(data)
    feat.extend([alpha_data, (beta_l_data + beta_r_data) / 2])

    # ... and of the products of pairs of adjacent coefficients along the
    # horizontal, vertical and the two diagonal orientations. Distortions
    # disturb the fairly regular structure of natural images, which shows in
    # the sample distribution of these products.
    for shift in [[0, 1], [1, 0], [1, 1], [1, -1]]:
        shifted_block = np.roll(data, shift, axis=(0, 1))
        alpha, beta_l, beta_r = estimate_aggd_param(data * shifted_block)
        # Eq. 8 in NIQE
        mean = (beta_r - beta_l) * (gamma(2 / alpha) / gamma(1 / alpha))
        feat.extend([alpha, mean, beta_l, beta_r])

    # Maps 1-3: Weibull fit.
    for map_index in range(1, 4):
        shape, scale = fitweibull(crop(map_index).flatten('F'))
        feat.extend([scale, shape])

    # Maps 4-6: mean and variance.
    for map_index in range(4, 7):
        data = crop(map_index)
        mu = np.mean(data)
        sigmaSquare = np.var(data.flatten('F'))
        feat.extend([mu, sigmaSquare])

    # Maps 7-84: AGGD fit.
    for map_index in range(7, 85):
        alpha_data, beta_l_data, beta_r_data = estimate_aggd_param(crop(map_index))
        feat.extend([alpha_data, (beta_l_data + beta_r_data) / 2])

    # Maps 85-108: Weibull fit.
    for map_index in range(85, 109):
        shape, scale = fitweibull(crop(map_index).flatten('F'))
        feat.extend([scale, shape])

    return feat

def matlab_fspecial(shape=(3,3),sigma=0.5):
    """
    Build a normalised 2D gaussian mask - should give the same result as
    MATLAB's fspecial('gaussian',[shape],[sigma]).

    Values below machine epsilon times the peak value are set to zero before
    the mask is divided by its sum.
    """
    rows, cols = shape
    half_rows = (rows-1.)/2.
    half_cols = (cols-1.)/2.

    yy, xx = np.ogrid[-half_rows:half_rows+1, -half_cols:half_cols+1]
    kernel = np.exp( -(xx*xx + yy*yy) / (2.*sigma*sigma) )

    negligible = kernel < np.finfo(kernel.dtype).eps*kernel.max()
    kernel[negligible] = 0

    total = kernel.sum()
    if total != 0:
        kernel /= total
    return kernel

def gauDerivative(sigma):
    """Build a pair of Gaussian-weighted coordinate kernels (x and y direction).

    The kernels are sampled on a square integer grid reaching
    ``ceil(3 * sigma)`` pixels in every direction and equal the x (resp. y)
    coordinate multiplied by ``exp(-(x^2 + y^2) / (2 * sigma^2))``.

    Returns:
        tuple: (gauDerX, gauDerY), both of shape (2*half+1, 2*half+1).
    """
    half = math.ceil(3*sigma)
    axis = np.linspace(-half, half, 2*half+1)
    x, y = np.meshgrid(axis, np.linspace(-half, half, 2*half+1))

    # The Gaussian envelope is shared by both kernels.
    envelope = np.exp(-(x**2 + y**2)/2/sigma/sigma)

    return x*envelope, y*envelope

def conv2(x, y, mode='same'):
    """2-D convolution evaluated on 180-degree rotated inputs.

    Both operands are rotated by 180 degrees, convolved with ``convolve2d``
    and the result is rotated back (presumably to reproduce the output
    alignment of MATLAB's conv2 - TODO confirm).
    """
    flipped_x = np.rot90(x, 2)
    flipped_y = np.rot90(y, 2)
    flipped_result = convolve2d(flipped_x, flipped_y, mode=mode)
    return np.rot90(flipped_result, 2)

def logGabors(rows, cols, minWaveLength, sigmaOnf, mult, dThetaOnSigma):
    """Build a bank of frequency-domain log-Gabor filters.

    Args:
        rows (int): Number of filter rows.
        cols (int): Number of filter columns.
        minWaveLength (float): Wavelength of the smallest-scale filter.
        sigmaOnf (float): Controls the radial bandwidth (its logarithm is the
            width of the radial Gaussian on the log-frequency axis).
        mult (float): Factor between the wavelengths of successive scales.
        dThetaOnSigma (float): Ratio of the angular spacing between
            orientations to the angular standard deviation.
    Returns:
        list[list[ndarray]]: ``filters[scale][orientation]`` with 3 scales and
        4 orientations; every filter has shape (rows, cols).
    """
    nscale = 3    # Number of wavelet scales.
    norient = 4   # Number of filter orientations.
    # Standard deviation of the angular Gaussian used to construct the
    # filters in the frequency plane.
    thetaSigma = math.pi/norient/dThetaOnSigma

    def frequency_axis(length):
        # Normalised coordinates along one axis (odd and even lengths are
        # normalised differently).
        if length % 2 > 0:
            return np.linspace(-(length-1)/2, (length-1)/2, length)/(length-1)
        return np.linspace(-length/2, length/2-1, length)/length

    x, y = np.meshgrid(frequency_axis(cols), frequency_axis(rows))
    radius = np.fft.ifftshift(np.sqrt(x**2 + y**2))
    theta = np.fft.ifftshift(np.arctan2(-y, x))
    radius[0,0] = 1  # avoids log(0) below; this entry is zeroed afterwards
    sintheta = np.sin(theta)
    costheta = np.cos(theta)

    # Radial components: one per scale.
    radial_parts = []
    for s in range(nscale):
        fo = 1.0/(minWaveLength*mult**(s))  # centre frequency of this scale
        radial = np.exp((-(np.log(radius/fo))**2) / (2 * np.log(sigmaOnf)**2))
        radial[0,0] = 0
        radial_parts.append(radial)

    # Angular components: one per orientation.
    angular_parts = []
    for o in range(norient):
        angl = o*math.pi/norient
        # Sine and cosine of the angular difference to the filter orientation.
        ds = sintheta * np.cos(angl) - costheta * np.sin(angl)
        dc = costheta * np.cos(angl) + sintheta * np.sin(angl)
        dtheta = abs(np.arctan2(ds,dc))
        angular_parts.append(np.exp((-dtheta**2) / (2 * thetaSigma**2)))

    # Each filter is the product of a radial and an angular component.
    return [[radial * angular for angular in angular_parts]
            for radial in radial_parts]

# @ray.remote
def ilniqe(img, mu_pris_param, cov_pris_param, gaussian_window, principleVectors, meanOfSampleData, resize=True, block_size_h=84, block_size_w=84):
    """Calculate IL-NIQE (Integrated Local Natural Image Quality Evaluator) metric.
    Ref: A Feature-Enriched Completely Blind Image Quality Evaluator.
    This implementation could produce almost the same results as the official
    MATLAB codes: https://github.com/milestonesvn/ILNIQE
    Note that we do not include block overlap height and width, since they are
    always 0 in the official implementation.
    Args:
        img (ndarray): Input image whose quality needs to be computed. It
            must be a 3-D array with shape (h, w, c) (anything else fails the
            assertion below); channels 0, 1 and 2 are read. Expected range
            [0, 255] with float type.
        mu_pris_param (ndarray): Mean of a pre-defined multivariate Gaussian
            model calculated on the pristine dataset.
        cov_pris_param (ndarray): Covariance of a pre-defined multivariate
            Gaussian model calculated on the pristine dataset.
        gaussian_window (ndarray): Gaussian window used to compute the local
            mean and deviation of the image (the caller in this file builds a
            5x5 window).
        principleVectors (ndarray): Features from official .mat file; its
            transpose projects the centred block features.
        meanOfSampleData (ndarray): Features from official .mat file; it is
            subtracted from the block features before the projection.
        resize (bool): If True, the image is first resized to 524x524 with
            MATLABLikeResize and clipped to [0, 255]. Default: True.
        block_size_h (int): Height of the blocks in to which image is divided.
            Default: 84 (the official recommended value).
        block_size_w (int): Width of the blocks in to which image is divided.
            Default: 84 (the official recommended value).
    Returns:
        float: Mean over all blocks of the distance between the block
        features and the pristine model.
    """
    assert img.ndim == 3, ('Input image must be a color image with shape (h, w, c).')
    # crop image
    # img = img.astype(np.float64)
    # NOTE(review): the two overlap constants are not referenced again in this
    # function.
    blockrowoverlap = 0
    blockcoloverlap = 0
    # Fixed algorithm parameters.
    sigmaForGauDerivative = 1.66
    KforLog = 0.00001
    normalizedWidth = 524
    minWaveLength = 2.4
    sigmaOnf = 0.55
    mult = 1.31
    dThetaOnSigma = 1.10
    scaleFactorForLoG = 0.87
    scaleFactorForGaussianDer = 0.28
    sigmaForDownsample = 0.9

    # Cap for overly large feature values / replacement for NaN feature means.
    infConst = 10000
    nanConst = 2000

    if resize:
        # img = cv2.resize(img, (normalizedWidth, normalizedWidth), interpolation=cv2.INTER_AREA)
        resize_func = MATLABLikeResize(output_shape=(normalizedWidth, normalizedWidth))
        img = resize_func.resize_img(img)
        img = np.clip(img, 0.0, 255.0)

    h, w, _ = img.shape

    # Crop the image so that it contains a whole number of blocks.
    num_block_h = math.floor(h / block_size_h)
    num_block_w = math.floor(w / block_size_w)
    img = img[0:num_block_h * block_size_h, 0:num_block_w * block_size_w]

    # Three fixed linear combinations of the input channels (presumably the
    # opponent colour space of the IL-NIQE paper - TODO confirm).
    O1 = 0.3*img[:,:,0] + 0.04*img[:,:,1] - 0.35*img[:,:,2]
    O2 = 0.34*img[:,:,0] - 0.6*img[:,:,1] + 0.17*img[:,:,2]
    O3 = 0.06*img[:,:,0] + 0.63*img[:,:,1] + 0.27*img[:,:,2]

    RChannel = img[:,:,0]
    GChannel = img[:,:,1]
    BChannel = img[:,:,2]

    distparam = []  # dist param is actually the multiscale features
    for scale in (1, 2):  # perform on two scales (1, 2)
        # Locally normalised O3: subtract the local mean and divide by the
        # local deviation (plus one).
        mu = convolve(O3, gaussian_window, mode='nearest')
        sigma = np.sqrt(np.abs(convolve(np.square(O3), gaussian_window, mode='nearest') - np.square(mu)))
        # normalize, as in Eq. 1 in the paper
        structdis = (O3 - mu) / (sigma + 1)

        # Filter each of O1, O2 and O3 with the two gauDerivative kernels in a
        # single complex convolution (real part: x kernel, imaginary part: y
        # kernel) and compute the magnitude of the two responses.
        dx, dy = gauDerivative(sigmaForGauDerivative/(scale**scaleFactorForGaussianDer));
        compRes = conv2(O1, dx + 1j*dy, 'same')
        IxO1 = np.real(compRes)
        IyO1 = np.imag(compRes)
        GMO1 = np.sqrt(IxO1**2 + IyO1**2) + np.finfo(O1.dtype).eps

        compRes = conv2(O2, dx + 1j*dy, 'same')
        IxO2 = np.real(compRes)
        IyO2 = np.imag(compRes)
        GMO2 = np.sqrt(IxO2**2 + IyO2**2) + np.finfo(O2.dtype).eps

        compRes = conv2(O3, dx + 1j*dy, 'same')
        IxO3 = np.real(compRes)
        IyO3 = np.imag(compRes)
        GMO3 = np.sqrt(IxO3**2 + IyO3**2) + np.finfo(O3.dtype).eps

        # Mean-subtracted logarithms of the three channels (KforLog keeps the
        # logarithm finite for zero-valued pixels) and three fixed
        # combinations of them.
        logR = np.log(RChannel + KforLog)
        logG = np.log(GChannel + KforLog)
        logB = np.log(BChannel + KforLog)
        logRMS = logR - np.mean(logR)
        logGMS = logG - np.mean(logG)
        logBMS = logB - np.mean(logB)

        Intensity = (logRMS + logGMS + logBMS) / np.sqrt(3)
        BY = (logRMS + logGMS - 2 * logBMS) / np.sqrt(6)
        RG = (logRMS - logGMS) / np.sqrt(2)

        # First 13 feature maps; compute_feature addresses the maps by their
        # position in this list, so the order matters.
        compositeMat = [structdis, GMO1, GMO2, GMO3, Intensity, BY, RG, IxO1, IyO1, IxO2, IyO2, IxO3, IyO3]

        # Size of the maps at the current scale (O3 is downsampled at the end
        # of every iteration).
        h, w = O3.shape

        # Log-Gabor responses of O3, computed in the frequency domain for
        # 3 scales x 4 orientations.
        LGFilters = logGabors(h,w,minWaveLength/(scale**scaleFactorForLoG),sigmaOnf,mult,dThetaOnSigma)
        fftIm = np.fft.fft2(O3)

        logResponse = []
        partialDer = []
        GM = []
        for scaleIndex in range(3):
            for oriIndex in range(4):
                response = np.fft.ifft2(LGFilters[scaleIndex][oriIndex]*fftIm)
                realRes = np.real(response)
                imagRes = np.imag(response)

                # gauDerivative-kernel responses and their magnitude for the
                # real and the imaginary part of the filter response.
                compRes = conv2(realRes, dx + 1j*dy, 'same')
                partialXReal = np.real(compRes)
                partialYReal = np.imag(compRes)
                realGM = np.sqrt(partialXReal**2 + partialYReal**2) + np.finfo(partialXReal.dtype).eps
                compRes = conv2(imagRes, dx + 1j*dy, 'same')
                partialXImag = np.real(compRes)
                partialYImag = np.imag(compRes)
                imagGM = np.sqrt(partialXImag**2 + partialYImag**2) + np.finfo(partialXImag.dtype).eps

                logResponse.append(realRes)
                logResponse.append(imagRes)
                partialDer.append(partialXReal)
                partialDer.append(partialYReal)
                partialDer.append(partialXImag)
                partialDer.append(partialYImag)
                GM.append(realGM)
                GM.append(imagGM)

        # 13 + 24 + 48 + 24 = 109 feature maps in total.
        compositeMat.extend(logResponse)
        compositeMat.extend(partialDer)
        compositeMat.extend(GM)

        feat = []
        for idx_w in range(num_block_w):
            for idx_h in range(num_block_h):
                # process each block
                # (block coordinates are divided by the scale because the maps
                # of the second scale have half the resolution)
                block_posi = [idx_h * block_size_h // scale, (idx_h + 1) * block_size_h // scale,
                                      idx_w * block_size_w // scale, (idx_w + 1) * block_size_w // scale]
                feat.append(compute_feature(compositeMat, block_posi))

        distparam.append(np.array(feat))
        # Prepare the next scale: Gaussian low-pass filter every channel and
        # keep every second pixel.
        gauForDS = matlab_fspecial([math.ceil(6*sigmaForDownsample), math.ceil(6*sigmaForDownsample)], sigmaForDownsample)
        filterResult = convolve(O1, gauForDS, mode='nearest')
        O1 = filterResult[0::2,0::2]
        filterResult = convolve(O2, gauForDS, mode='nearest')
        O2 = filterResult[0::2,0::2]
        filterResult = convolve(O3, gauForDS, mode='nearest')
        O3 = filterResult[0::2,0::2]

        filterResult = convolve(RChannel, gauForDS, mode='nearest')
        RChannel = filterResult[0::2,0::2]
        filterResult = convolve(GChannel, gauForDS, mode='nearest')
        GChannel = filterResult[0::2,0::2]
        filterResult = convolve(BChannel, gauForDS, mode='nearest')
        BChannel = filterResult[0::2,0::2]

    # Put the features of both scales side by side: one row per block.
    distparam = np.concatenate(distparam, axis=1)
    distparam = np.array(distparam)

    # fit a MVG (multivariate Gaussian) model to distorted patch features
    # (values above infConst are capped first, then the centred features are
    # projected with the transposed principleVectors)
    distparam[distparam>infConst] = infConst
    meanMatrix = np.tile(meanOfSampleData,(1,distparam.shape[0]))
    coefficientsViaPCA = np.matmul(principleVectors.T, (distparam.T - meanMatrix))

    final_features = coefficientsViaPCA.T
    mu_distparam = np.nanmean(final_features, axis=0)
    mu_distparam[np.isnan(mu_distparam)] = nanConst
    # use nancov. ref: https://ww2.mathworks.cn/help/stats/nancov.html
    # (rows containing a NaN are left out of the covariance estimate)
    distparam_no_nan = final_features[~np.isnan(final_features).any(axis=1)]
    cov_distparam = np.cov(distparam_no_nan, rowvar=False)
    # compute niqe quality, Eq. 10 in NIQE
    invcov_param = np.linalg.pinv((cov_pris_param + cov_distparam) / 2)

    # Distance of every block to the pristine model; NaN entries of a block
    # are replaced by the corresponding feature mean.
    dist = []
    for data_i in range(final_features.shape[0]):
        currentFea = final_features[data_i,:]
        currentFea = np.where(np.isnan(currentFea), mu_distparam, currentFea)
        currentFea = np.expand_dims(currentFea, axis=0)
        quality = np.matmul(
            np.matmul((currentFea - mu_pris_param), invcov_param), np.transpose((currentFea - mu_pris_param)))
        dist.append(np.sqrt(quality))
    # The final score is the mean distance over all blocks.
    score = np.mean(np.array(dist))
    return score

def calculate_ilniqe(img, crop_border, input_order='HWC', num_cpus=3, resize=True, version='python', **kwargs):
    """Calculate IL-NIQE (Integrated Local Natural Image Quality Evaluator) metric.

    Args:
        img (ndarray): Input image whose quality needs to be computed.
            The image is expected in range [0, 255] with its channels in BGR
            order; it is converted to RGB here with ``cv2.cvtColor``, so its
            dtype must be one that OpenCV accepts.
        crop_border (int): Cropped pixels in each edge of an image. These
            pixels are not involved in the metric calculation. ``0`` disables
            the cropping.
        input_order (str): Layout of ``img``, 'HWC' or 'CHW'. With 'HW' the
            reordering step is skipped, but the image must still be
            3-dimensional with 3 channels, so a plain 2-D image is rejected.
            Default: 'HWC'.
        num_cpus (int): Unused (the ray-based parallel path is disabled);
            kept so existing callers keep working.
        resize (bool): Forwarded to ``ilniqe``.
        version (str): 'python' loads the template model trained with the
            Python code, any other value loads the one trained with the
            official MATLAB code.
        **kwargs: Ignored.

    Returns:
        float: IL-NIQE result.

    Raises:
        AssertionError: If the image is not 3-dimensional with 3 channels
            once it has been reordered and colour-converted.
    """

    # we use the official params estimated from the pristine dataset.
    gaussian_window = matlab_fspecial((5,5),5/6)
    gaussian_window = gaussian_window/np.sum(gaussian_window)

    # 'python' -> model trained using python code, otherwise the official Matlab one
    model_filename = "python_templateModel.mat" if version == 'python' else "templateModel.mat"
    model_url = ("https://raw.githubusercontent.com/IceClear/IL-NIQE/"
                 "81de83e6faf0f47a8be04f58a0685d5f95ac0c52/" + model_filename)

    # Download the template model into a temporary folder that is removed as
    # soon as the file has been loaded into memory.
    with tempfile.TemporaryDirectory() as temp_dir_path:
        model_path = os.path.join(temp_dir_path, model_filename)
        urllib.request.urlretrieve(model_url, model_path)
        model_mat = scipy.io.loadmat(model_path)

    template_model = model_mat['templateModel']
    mu_pris_param = template_model[0][0]
    cov_pris_param = template_model[0][1]
    meanOfSampleData = template_model[0][2]
    principleVectors = template_model[0][3]

    # Bring the image to HWC *before* the colour conversion: OpenCV reads the
    # last axis as the channel axis, so converting a CHW array would fail.
    if input_order != 'HW':
        img = reorder_image(img, input_order=input_order)
        img = np.squeeze(img)

    assert img.ndim == 3, f"IL-NIQE needs a 3-channel colour image, got an array with shape {img.shape}."

    img = cv2.cvtColor(np.ascontiguousarray(img), cv2.COLOR_BGR2RGB)
    img = img.astype(np.float64)

    assert img.shape[2] == 3 # only for RGB image

    if crop_border != 0:
        img = img[crop_border:-crop_border, crop_border:-crop_border]

    # round is necessary for being consistent with MATLAB's result
    img = img.round()

    ilniqe_result = ilniqe(img, mu_pris_param, cov_pris_param, gaussian_window, principleVectors, meanOfSampleData, resize)

    # The square root in the score can yield a complex number with a zero
    # imaginary part; report it as a plain real value.
    if isinstance(ilniqe_result, complex) and ilniqe_result.imag == 0:
        ilniqe_result = ilniqe_result.real

    return ilniqe_result

def process_images(data, is_dir=False, torch_output=True):
    """Convert loaded images to float32, channel-first batches.

    Every image must be single-channel. The channel axis is taken to be the
    smallest non-batch axis and is moved to position 1. A copy with that
    channel repeated three times is returned as well.

    Args:
        data (list | np.ndarray): Either a list of 2-D/3-D images (each one is
            processed on its own) or a single array. For an array, ``is_dir``
            tells whether the first axis is a batch axis.
        is_dir (bool): Only used for array input. ``True``: ``data`` is a
            batch of images (3-D without channel axis or 4-D with one).
            ``False``: ``data`` is a single image (2-D or 3-D).
        torch_output (bool): Return ``torch.Tensor`` objects when ``True``,
            NumPy arrays otherwise.

    Returns:
        tuple: ``(processed_data, processed_data_3C)``. For list input both are
        lists with one ``(1, 1, H, W)`` / ``(1, 3, H, W)`` entry per image; for
        array input both are single ``(N, 1, H, W)`` / ``(N, 3, H, W)`` batches.

    Raises:
        ValueError: If ``data`` has an unsupported type or dimensionality.
        AssertionError: If the detected channel axis does not have size 1.
    """

    def _as_float32(image):
        # Cast first, then (optionally) hand the same buffer over to PyTorch
        converted = image.astype(np.float32)
        return torch.from_numpy(converted) if torch_output else converted

    def _channel_to_axis_1(batch):
        # The batch axis is left out of the search: a batch of size 1 would
        # otherwise be mistaken for the (size 1) channel axis.
        non_batch_shape = tuple(batch.shape[1:])
        channel_value = min(non_batch_shape)
        channel_axis = non_batch_shape.index(channel_value) + 1

        assert channel_value == 1, f"Channel dimension {channel_value} should be 1."

        # moveaxis returns a new view/array, so its result has to be kept
        if torch_output:
            return torch.moveaxis(batch, channel_axis, 1)
        return np.moveaxis(batch, channel_axis, 1)

    def _triplicate_channel(batch):
        if torch_output:
            return torch.cat([batch, batch, batch], dim=1)
        return np.concatenate([batch, batch, batch], axis=1)

    if isinstance(data, list):
        processed_data = []
        processed_data_3C = []
        for img in data:
            aux_img = _as_float32(img)

            if len(img.shape) == 3:
                # The image already has a channel axis: add the batch axis and
                # bring the channel axis to the second position
                aux_img = _channel_to_axis_1(aux_img[None, :, :, :])
            elif len(img.shape) == 2:
                # Plain 2-D image: add a batch and a channel axis
                aux_img = aux_img[None, None, :, :]
            else:
                raise ValueError("Unsupported data shape for single image.")

            processed_data.append(aux_img)
            processed_data_3C.append(_triplicate_channel(aux_img))

    elif isinstance(data, np.ndarray):
        aux_img = _as_float32(data)

        if is_dir:
            if len(data.shape) == 3:
                # Batch of 2-D images: add a channel axis
                aux_img = aux_img[:, None, :, :]
            elif len(data.shape) == 4:
                # Batch of images that already has a channel axis
                pass
            else:
                raise ValueError("Unsupported data shape for batch of images.")
        else:
            if len(data.shape) == 3:
                # Single image with a channel axis: add a batch axis
                aux_img = aux_img[None, :, :, :]
            elif len(data.shape) == 2:
                # Single 2-D image: add a batch and a channel axis
                aux_img = aux_img[None, None, :, :]
            else:
                raise ValueError("Unsupported data shape for single image.")

        processed_data = _channel_to_axis_1(aux_img)
        processed_data_3C = _triplicate_channel(processed_data)

    else:
        raise ValueError("Unsupported data type. Expected list or np.ndarray.")

    return processed_data, processed_data_3C

def calculate_metric(metric_func, gt_image, pred_image, detach=False, are_dir=False, non_reference=False):
    """Evaluate ``metric_func`` on one image (pair) or on every image (pair) of a set.

    Args:
        metric_func (callable): Called as ``metric_func(gt, pred)`` or, when
            ``non_reference`` is set, as ``metric_func(gt)``.
        gt_image: Image(s) passed as first argument of the metric. With
            ``are_dir`` it is indexed image by image.
        pred_image: Image(s) passed as second argument; ignored when
            ``non_reference`` is set.
        detach (bool): Kept for backward compatibility. Tensor results are
            always detached and converted to NumPy, in both modes.
        are_dir (bool): ``True`` to iterate over a set of images, ``False`` to
            evaluate a single image (pair).
        non_reference (bool): ``True`` for metrics that need no reference.

    Returns:
        With ``are_dir`` a NumPy array with one value per image; otherwise the
        single metric value (converted to NumPy when it is a tensor). A failed
        evaluation is reported on stdout and stored as ``np.nan``.
    """

    def _to_numpy(value):
        # Metrics may return tensors that still track gradients
        if isinstance(value, torch.Tensor):
            return value.detach().cpu().numpy()
        return value

    def _with_batch_axis(image):
        # Indexing a (N, C, H, W) tensor yields (C, H, W) and needs its batch
        # axis back; list entries that are already 4-D are left untouched.
        if isinstance(image, torch.Tensor) and image.dim() == 3:
            return image[None, :, :, :]
        return image

    if are_dir:

        metric_value_list = []
        for i in tqdm(range(len(gt_image)), desc="Images processed"):

            aux_gt_image = _with_batch_axis(gt_image[i])
            aux_pred_image = None if non_reference else _with_batch_axis(pred_image[i])

            try:
                if non_reference:
                    metric_value = metric_func(aux_gt_image)
                else:
                    metric_value = metric_func(aux_gt_image, aux_pred_image)

                metric_value = _to_numpy(metric_value)
            except Exception as e:
                # One failing image must not abort the whole evaluation
                print(f"Error processing image {i}: {e}")
                metric_value = np.nan

            metric_value_list.append(metric_value)
        metric_value = np.array(metric_value_list)
    else:
        try:
            if non_reference:
                metric_value = metric_func(gt_image)
            else:
                metric_value = metric_func(gt_image, pred_image)
        except Exception as e:
            print(f"Error processing image: {e}")
            metric_value = np.nan

        metric_value = _to_numpy(metric_value)

    return metric_value


# 2. **Initialise the Colab session**
---

## 2.1. **Check for GPU access**

By default, the session should be using Python 3 and GPU acceleration, but it is possible to ensure that these are set properly by doing the following:

<font size = 4>Go to **Runtime -> Change the Runtime type**

<font size = 4>**Runtime type: Python 3** *(Python 3 is the programming language in which this program is written)*

<font size = 4>**Accelerator: GPU** *(Graphics processing unit)*

In [None]:
#@markdown ##Play the cell to check if you have GPU access

# Shell one-liner: if the `nvidia-smi` command exists, print its report;
# otherwise print a hint about the runtime settings.
!if type nvidia-smi >/dev/null 2>&1; then \
    echo "You have GPU access"; nvidia-smi; \
    else \
    echo -e "You do not have GPU access.\nDid you change your runtime?\nIf the runtime setting is correct then Google did not allocate a GPU for your session\nExpect slow performance. To access GPU try reconnecting later"; fi

## 2.2. **Mount your Google Drive**

<font size = 4> To use this notebook on the data present in your Google Drive, you need to mount your Google Drive to this notebook.

<font size = 4> Play the cell below to mount your Google Drive and follow the instructions.

<font size = 4> Once this is done, your data are available in the **Files** tab at the top left of the notebook.

In [None]:
#@markdown ##Play the cell to connect your Google Drive to Colab
# Only available inside Google Colab; the Drive is mounted under /content/gdrive
from google.colab import drive
drive.mount('/content/gdrive')

# 3. **Select your parameters and paths**
---

## 3.1. **[Optional] Download some example dataset**

In [None]:
#@markdown Choose the model´s predictions you want to download

# Download the dataset
!wget https://github.com/IvanHCenalmor/SISR_Benchmark/releases/download/0.0.1/ER_test_data.zip -O ER_test_data.zip
!mkdir ER_test_data
!unzip -o ER_test_data.zip -d ER_test_data
!rm ER_test_data.zip


model = "UNet" # @param ["UNet", "RCAN", "WDSR"]

# Download the predictions from the models
!wget https://github.com/IvanHCenalmor/SISR_Benchmark/releases/download/0.0.1/ER_{model}_Predictions.zip -O Model_Predictions.zip
!mkdir Model_Predictions
!unzip -o Model_Predictions.zip -d Model_Predictions
!rm Model_Predictions.zip

## 3.2. **Setting the main parameters**


In [None]:
#@markdown ###Path to ground truth and predicted images:
#@markdown > *NOTICE*: It can be a folder or a single image.

Ground_truth_path = "/content/ER_test_data/test_gt" #@param {type:"string"}
Prediction_path = "/content/Model_Predictions" #@param {type:"string"}

normalization = "Per image" # @param ["None", "Per image", "Across all images"]

#####

# Both paths have to be of the same kind: two folders or two single files
gt_is_dir = os.path.isdir(Ground_truth_path)
pred_is_dir = os.path.isdir(Prediction_path)

if gt_is_dir != pred_is_dir:
    raise ValueError("Both paths should be either directories or files.")
are_dir = gt_is_dir

#####

# Read the ground truth and the predictions with the selected normalization
ground_truth_image, ground_truth_filename = load_images(Ground_truth_path, normalize=normalization)
predicted_image, predicted_filename = load_images(Prediction_path, normalize=normalization)

# Summary of the images as they were loaded
print_img_info(ground_truth_image, description="[Raw] Ground truth image:", is_dir=are_dir)
print_img_info(predicted_image, description="[Raw] Predicted image:", is_dir=are_dir)
print("-" * 50 + "\n")

# # Process the images into as Numpy format
# numpy_ground_truth_image, numpy_ground_truth_image_3c = process_images(ground_truth_image, is_dir=are_dir, torch_output=False)
# numpy_predicted_image, numpy_predicted_image_3c = process_images(predicted_image, is_dir=are_dir, torch_output=False)

# print_img_info(numpy_ground_truth_image, description="[Processed - NumPy] Ground truth image:", is_dir=are_dir)
# print_img_info(numpy_ground_truth_image_3c, description="[Processed - NumPy] Ground truth image (3 channels):", is_dir=are_dir)
# print_img_info(numpy_predicted_image, description="[Processed - NumPy] Predicted image:", is_dir=are_dir)
# print_img_info(numpy_predicted_image_3c, description="[Processed - NumPy] Predicted image (3 channels):", is_dir=are_dir)
# print("-" * 50 + "\n")

# PyTorch versions (single-channel and 3-channel) used by the torchmetrics metrics
torch_ground_truth_image, torch_ground_truth_image_3c = process_images(
    ground_truth_image, is_dir=are_dir, torch_output=True
)
torch_predicted_image, torch_predicted_image_3c = process_images(
    predicted_image, is_dir=are_dir, torch_output=True
)

print_img_info(torch_ground_truth_image, description="[Processed - PyTorch] Ground truth image:", is_dir=are_dir)
print_img_info(torch_ground_truth_image_3c, description="[Processed - PyTorch] Ground truth image (3 channels):", is_dir=are_dir)
print_img_info(torch_predicted_image, description="[Processed - PyTorch] Predicted image:", is_dir=are_dir)
print_img_info(torch_predicted_image_3c, description="[Processed - PyTorch] Predicted image (3 channels):", is_dir=are_dir)
print("-" * 50 + "\n")

#####

# Preview: the single image pair, or the first pair when working with folders
preview_gt = ground_truth_image[0] if are_dir else ground_truth_image
preview_pred = predicted_image[0] if are_dir else predicted_image
preview_suffix = " (first image in folder)" if are_dir else ""

fig, ax = plt.subplots(1, 2, figsize=(10, 5))
for preview_axis, preview_img, preview_title in zip(ax, (preview_gt, preview_pred), ("Real image", "Predicted image")):
    preview_axis.imshow(preview_img, 'inferno')
    preview_axis.set_title(preview_title + preview_suffix)
    preview_axis.axis('off')
plt.show()

# 4. **Calculate the metrics**
---

## 4.1. **Mean Squared Error (MSE)**

In [None]:
#@markdown ##Calculate MSE

# torchmetrics metric object, evaluated once per image (pair) by calculate_metric
MSE = MeanSquaredError()

print("Calculating MSE...")
mse_value = calculate_metric(
    MSE,
    torch_ground_truth_image,
    torch_predicted_image,
    detach=False,
    are_dir=are_dir,
)
# Mean ± standard deviation over the evaluated images
print(f"MSE: {mse_value.mean()} ± {mse_value.std()}")

## 4.2. **Mean Absolute Error (MAE)**

In [None]:
#@markdown ##Calculate MAE

# torchmetrics metric object, evaluated once per image (pair) by calculate_metric
MAE = MeanAbsoluteError()

print("Calculating MAE...")
mae_value = calculate_metric(
    MAE,
    torch_ground_truth_image,
    torch_predicted_image,
    detach=False,
    are_dir=are_dir,
)
# Mean ± standard deviation over the evaluated images
print(f"MAE: {mae_value.mean()} ± {mae_value.std()}")

## 4.3. **Peak Signal to Noise Ratio (PSNR)**

In [None]:
#@markdown ##Calculate PSNR

PSNR = PeakSignalNoiseRatio()

def psnr_ground_truth_as_target(gt_image, pred_image):
    """Evaluate PSNR with the ground truth passed as ``target``.

    torchmetrics metrics are called as ``metric(preds, target)``. PSNR is built
    without ``data_range`` here, in which case the range is taken from
    ``target``, so the ground truth has to be the second argument.
    """
    return PSNR(pred_image, gt_image)

print("Calculating PSNR...")
psnr_value = calculate_metric(psnr_ground_truth_as_target, torch_ground_truth_image, torch_predicted_image, are_dir=are_dir, detach=False)
# Mean ± standard deviation over the evaluated images
print(f"PSNR: {psnr_value.mean()} ± {psnr_value.std()}")

## 4.4. **Structural Similarity Index Measure (SSIM)**

In [None]:
#@markdown ##Calculate SSIM

# torchmetrics metric object, evaluated once per image (pair) by calculate_metric
SSIM = StructuralSimilarityIndexMeasure()

print("Calculating SSIM...")
ssim_value_array = calculate_metric(
    SSIM,
    torch_ground_truth_image,
    torch_predicted_image,
    detach=False,
    are_dir=are_dir,
)
# Mean ± standard deviation over the evaluated images
print(f"SSIM: {ssim_value_array.mean()} ± {ssim_value_array.std()}")

## 4.5. **Learned Perceptual Image Patch Similarity with AlexNet (LPIPS-Alex)**

In [None]:
#@markdown ##Calculate LPIPS-Alex

# LPIPS is computed on the 3-channel copies of the images.
# NOTE(review): with the default `normalize=False` torchmetrics expects inputs
# in [-1, 1]; confirm the value range produced by `load_images`.
LPIPS_ALEX = LearnedPerceptualImagePatchSimilarity(net_type='alex')

print("Calculating LPIPS-Alex...")
lpips_alex_value = calculate_metric(
    LPIPS_ALEX,
    torch_ground_truth_image_3c,
    torch_predicted_image_3c,
    detach=True,
    are_dir=are_dir,
)
# Mean ± standard deviation over the evaluated images
print(f"LPIPS-Alex: {lpips_alex_value.mean()} ± {lpips_alex_value.std()}")

## 4.5. **Learned Perceptual Image Patch Similarity with VGG-16 (LPIPS-VGG)**

In [None]:
#@markdown ##Calculate LPIPS-VGG

# LPIPS is computed on the 3-channel copies of the images.
# NOTE(review): with the default `normalize=False` torchmetrics expects inputs
# in [-1, 1]; confirm the value range produced by `load_images`.
LPIPS_VGG = LearnedPerceptualImagePatchSimilarity(net_type='vgg')

print("Calculating LPIPS-VGG...")
lpips_vgg_value = calculate_metric(
    LPIPS_VGG,
    torch_ground_truth_image_3c,
    torch_predicted_image_3c,
    detach=True,
    are_dir=are_dir,
)
# Mean ± standard deviation over the evaluated images
print(f"LPIPS-VGG: {lpips_vgg_value.mean()} ± {lpips_vgg_value.std()}")

## 4.6. **Integrated Local Natural Image Quality Evaluator (IL-NIQE)**

In [None]:
#@markdown ##Calculate IL-NIQE

from skimage.util import img_as_ubyte
def ilnique_function(crop_border=0, input_order='HW', resize=True, version='python', **kwargs):
    """Build a one-argument IL-NIQE callable with the given settings bound.

    The returned function takes an image and forwards it, together with
    ``crop_border``, ``input_order``, ``resize`` and ``version``, to
    ``calculate_ilniqe``. Extra keyword arguments are ignored.
    """

    def aux_function (images):
        return calculate_ilniqe(images,
                                crop_border=crop_border,
                                input_order=input_order,
                                resize=resize,
                                version=version)

    return aux_function

def infer_ilniqe_input_order(image_shape, are_dir):
    """Guess the layout ('HW', 'HWC' or 'CHW') of one image from the data shape.

    Args:
        image_shape (tuple): Shape of the loaded data.
        are_dir (bool): Whether the first axis is a batch axis.

    Returns:
        str: 'HW' for images without a channel axis, otherwise 'CHW' or 'HWC'.

    Raises:
        AssertionError: If the smallest per-image axis does not have size 3.
        ValueError: If the layout cannot be determined.
    """
    n_dims = len(image_shape)
    if n_dims == 2 or (n_dims == 3 and are_dir):
        return 'HW'

    if n_dims == 3:
        per_image_shape = tuple(image_shape)
    elif n_dims == 4:
        # Drop the batch axis so that a batch of 3 images is not mistaken
        # for the channel axis
        per_image_shape = tuple(image_shape[1:])
    else:
        raise ValueError(f"Unsupported image shape {tuple(image_shape)} for IL-NIQE.")

    # We assume the channel dimension is the one with the minimum size
    channel_value = min(per_image_shape)
    channel_axis = per_image_shape.index(channel_value)

    assert channel_value == 3, f"Channel dimension {channel_value} should be 3."

    if channel_axis == 0:
        return 'CHW'
    if channel_axis == 2:
        return 'HWC'
    raise ValueError(f"Channel axis found at position {channel_axis}; expected a CHW or HWC image.")

# Check the input order of the image (the predictions are assumed to share
# the layout of the ground truth)
# NOTE(review): calculate_ilniqe asserts a 3-channel image, so 'HW' images are
# expected to be reported as errors (NaN) by calculate_metric - confirm that
# this is the intended behaviour for single-channel data.
input_order = infer_ilniqe_input_order(ground_truth_image.shape, are_dir)

print("Calculating IL-NIQE for the ground truth images...")
ilniqe_gt_value = calculate_metric(ilnique_function(input_order=input_order), img_as_ubyte(ground_truth_image), None, are_dir=are_dir, detach=False, non_reference=True)
print(f"IL-NIQE (ground truth): {ilniqe_gt_value.mean()} ± {ilniqe_gt_value.std()}")
print("-" * 50 + "\n")

print("Calculating IL-NIQE for the predicted images...")
ilniqe_pred_value = calculate_metric(ilnique_function(input_order=input_order),  img_as_ubyte(predicted_image), None, are_dir=are_dir, detach=False, non_reference=True)
print(f"IL-NIQE (prediction): {ilniqe_pred_value.mean()} ± {ilniqe_pred_value.std()}")

## 4.7. **Decorrelation analysis**

In [None]:
#@markdown ##Calculate decorrelation analysis

DECORR_ANALYSIS = DecorrAnalysis()

def decorrelation_analysis_function(DECORR_ANALYSIS, plot_analysis=False):
    """Build a one-argument callable that runs the decorrelation analysis.

    The returned function runs ``DECORR_ANALYSIS.run_analysis`` on the given
    image, optionally shows the plot produced by ``plot_results`` and returns
    ``DECORR_ANALYSIS.resolution``.
    """

    def aux_function (images):
        DECORR_ANALYSIS.run_analysis(images)
        if plot_analysis:
            plt.figure(figsize=(10, 5))
            plt.title("Decorrelation Analysis")
            plt.imshow(DECORR_ANALYSIS.plot_results())
            plt.axis("off")
            plt.show()

        return DECORR_ANALYSIS.resolution

    return aux_function

plot_analysis = False #@param {type:"boolean"}

# The summaries below report the computed resolution values (not the pixel
# statistics of the images); np.asarray also covers the single-image case.
print("Calculating DECORR_ANALYSIS for the ground truth images...")
decorrelation_gt_value = calculate_metric(decorrelation_analysis_function(DECORR_ANALYSIS, plot_analysis=plot_analysis), ground_truth_image, None, are_dir=are_dir, detach=False, non_reference=True)
print(f"DECORR_ANALYSIS (ground truth): {np.asarray(decorrelation_gt_value).mean()} ± {np.asarray(decorrelation_gt_value).std()}")
print("-" * 50 + "\n")

print("Calculating DECORR_ANALYSIS for the predicted images...")
decorrelation_pred_value = calculate_metric(decorrelation_analysis_function(DECORR_ANALYSIS, plot_analysis=plot_analysis), predicted_image, None, are_dir=are_dir, detach=False, non_reference=True)
print(f"DECORR_ANALYSIS (prediction): {np.asarray(decorrelation_pred_value).mean()} ± {np.asarray(decorrelation_pred_value).std()}")


## 4.8. **SQUIRREL Error Map - RSP and RSE**

In [None]:
#@markdown ##Calculate error map

# We need the low resolution images to calculate the error map
low_res_path = "/content/ER_test_data/test_wf" # @param {type:"string"}

# Load the widefield images
low_res_images, _ = load_images(low_res_path, normalize=normalization)

print_img_info(low_res_images, description="Low resolution image:", is_dir=are_dir)

def calculate_error_map(low_res_images, predicted_image, are_dir=False, plot_analysis=False):
    """Compute the RSE and RSP values of the error map of image pairs.

    Args:
        low_res_images: Low resolution reference image, or a set of them when
            ``are_dir`` is set.
        predicted_image: Image (or set of images) compared against the low
            resolution reference.
        are_dir (bool): ``True`` to process the inputs pair by pair.
        plot_analysis (bool): Show the plot returned by
            ``ErrorMap.plot_results`` after each optimisation.

    Returns:
        tuple: ``(rse_values, rsp_values)`` as NumPy arrays, with one entry per
        processed pair when ``are_dir`` is set.
    """
    ERROR_MAP = ErrorMap()

    def _show_error_map():
        plt.figure(figsize=(10, 5))
        plt.title("Error Map")
        plt.imshow(ERROR_MAP.plot_results())
        plt.axis("off")
        plt.show()

    if are_dir:
        rse_values = []
        rsp_values = []

        # zip() stops at the shorter input: make a mismatch visible instead of
        # silently dropping images
        n_low_res, n_pred = len(low_res_images), len(predicted_image)
        if n_low_res != n_pred:
            print(f"Warning: {n_low_res} low resolution images but {n_pred} images to compare; only the first {min(n_low_res, n_pred)} pairs are processed.")

        # `total` is given explicitly because tqdm cannot infer it from zip()
        for low_res, pred in tqdm(zip(low_res_images, predicted_image), total=min(n_low_res, n_pred), desc="Images processed"):
            ERROR_MAP.optimise(low_res, pred)

            rse_values.append(ERROR_MAP.getRSE())
            rsp_values.append(ERROR_MAP.getRSP())

            if plot_analysis:
                _show_error_map()

    else:
        ERROR_MAP.optimise(low_res_images, predicted_image)
        if plot_analysis:
            _show_error_map()

        rse_values = ERROR_MAP.getRSE()
        rsp_values = ERROR_MAP.getRSP()

    return np.array(rse_values), np.array(rsp_values)

plot_analysis = False #@param {type:"boolean"}

print("Calculating error map for the ground truth images...")
rse_gt_value, rsp_gt_value = calculate_error_map(low_res_images, ground_truth_image, are_dir=are_dir, plot_analysis=plot_analysis)
print(f"RSE (ground truth): {rse_gt_value.mean()} ± {rse_gt_value.std()}")
print(f"RSP (ground truth): {rsp_gt_value.mean()} ± {rsp_gt_value.std()}")
print("-" * 50 + "\n")

print("Calculating error map for the predicted images...")
rse_pred_value, rsp_pred_value = calculate_error_map(low_res_images, predicted_image, are_dir=are_dir, plot_analysis=plot_analysis)
print(f"RSE (prediction): {rse_pred_value.mean()} ± {rse_pred_value.std()}")
print(f"RSP (prediction): {rsp_pred_value.mean()} ± {rsp_pred_value.std()}")



# 5. **Export the results**
---

In [None]:
#@markdown ### Load the calculated metrics

# Label used in the reports -> name of the notebook variable holding the values
metric_variable_names = {
    'MSE': 'mse_value',
    'MAE': 'mae_value',
    'PSNR': 'psnr_value',
    'SSIM': 'ssim_value_array',
    'LPIPS_ALEX': 'lpips_alex_value',
    'LPIPS_VGG': 'lpips_vgg_value',
    'IL-NIQE (GT)': 'ilniqe_gt_value',
    'IL-NIQE (Pred)': 'ilniqe_pred_value',
    'DECORR_ANALYSIS (GT)': 'decorrelation_gt_value',
    'DECORR_ANALYSIS (Pred)': 'decorrelation_pred_value',
    'RSE (GT)': 'rse_gt_value',
    'RSE (Pred)': 'rse_pred_value',
    'RSP (GT)': 'rsp_gt_value',
    'RSP (Pred)': 'rsp_pred_value',
}

# Dictionary with all the possible metrics and their values; a metric whose
# cell has not been run (variable not defined) is stored as NaN
metrics_dict = {
    metric_label: globals().get(variable_name, np.nan)
    for metric_label, variable_name in metric_variable_names.items()
}

# NaN-aware mean and standard deviation of every metric
mean_metrics = {}
std_metrics = {}
for metric_label, metric_values in metrics_dict.items():
    mean_metrics[metric_label] = np.nanmean(metric_values)
    std_metrics[metric_label] = np.nanstd(metric_values)

## 5.1. **Export the mean/std results to a CSV file**


You need to specify the path where you want to save your results CSV file.



In [None]:
# Create a DataFrame with the mean and standard deviation of every metric,
# indexed by the metric name
metrics_df = pd.DataFrame(
    {
        'Mean': list(mean_metrics.values()),
        'Std': list(std_metrics.values())
    },
    index=list(mean_metrics.keys())
)

# Specify the path to save the CSV file
csv_file_path = "/content" #@param {type:"string"}

# Save the DataFrame to a CSV file. The index is written as well: it holds the
# metric names, without it the rows of the file could not be identified.
metrics_df.to_csv(os.path.join(csv_file_path, "metrics_average.csv"), index=True, index_label="Metric")

# Display the DataFrame
metrics_df

## 5.2. **Export the results per-image to a CSV file**

You need to specify the path where you want to save your results CSV file.


In [None]:
# Create a Dataframe that stores the metrics for each image
try:
    metrics_per_image_df = pd.DataFrame(metrics_dict)
except Exception: # In case is a single image (scalar values need an explicit index)
    metrics_per_image_df = pd.DataFrame(metrics_dict, index=[0])
# One row per image, labelled with the file name of the ground truth image
metrics_per_image_df.index = ground_truth_filename if are_dir else [os.path.basename(ground_truth_filename)]

# Specify the path to save the CSV file
csv_file_path_per_image = "/content" #@param {type:"string"}

# Save the DataFrame to a CSV file. The index is written as well: it holds the
# file names, without it the rows could not be matched to the images.
metrics_per_image_df.to_csv(os.path.join(csv_file_path_per_image, "metrics_per_image.csv"), index=True, index_label="Image")

# Show the dataframe with the metrics for each image
metrics_per_image_df

In [None]:
#@markdown ### Visualize the results on interactive plots

def create_interactive_plots_all_metrics(metrics_df, filenames):
    """
    Draw one interactive scatter plot per metric, arranged in a grid with two columns.

    Args:
        metrics_df: DataFrame with one column per metric and one row per image
        filenames: Image filenames, shown in the hover text of the points
    """
    metric_names = list(metrics_df.columns)
    n_cols = 2
    # Enough rows to place every metric in the two-column grid
    n_rows = math.ceil(len(metric_names) / n_cols)
    image_indices = list(range(len(metrics_df)))

    fig = make_subplots(
        rows=n_rows, cols=n_cols,
        subplot_titles=metric_names,
        horizontal_spacing=0.15,
        vertical_spacing=0.1
    )

    for position, metric_name in enumerate(metric_names):
        # Fill the grid row by row (plotly rows and columns start at 1)
        row_offset, col_offset = divmod(position, n_cols)
        row, col = row_offset + 1, col_offset + 1

        scatter = go.Scatter(
            x=image_indices,
            y=metrics_df[metric_name],
            mode='markers',
            name=metric_name,
            text=filenames,
            hovertemplate=f"{metric_name}: %{{y:.4f}}<br>Image: %{{text}}<extra></extra>"
        )
        fig.add_trace(scatter, row=row, col=col)

        # Axis titles of this subplot
        fig.update_xaxes(title_text="Image Index", row=row, col=col)
        fig.update_yaxes(title_text=metric_name, row=row, col=col)

    # The figure grows by 300 px with every row of subplots
    fig.update_layout(
        title="Metrics Overview",
        showlegend=False,
        height=300*n_rows,
        width=1500,
        template="plotly_white"
    )

    fig.show()

# Create plots for all metrics
print("\nCreating plots for all metrics...")
create_interactive_plots_all_metrics(metrics_per_image_df, ground_truth_filename)