# 1) Preparations

Before starting, make sure that you choose the following settings:
* Runtime Type = Python 3
* Hardware Accelerator = GPU
* Browser: anything except Firefox (Firefox cannot upload images in step 2)

in the **Runtime** menu -> **Change runtime type**

Then, we clone the repositories, set up the environment, and download the pre-trained models.

Clone the required repositories and download the pre-trained models into their expected paths.

In [None]:
# Clone Real-ESRGAN; the %cd below makes it the working directory, so the
# other repositories are cloned inside it.
!git clone https://github.com/xinntao/Real-ESRGAN.git
%cd Real-ESRGAN
# Set up the environment (Real-ESRGAN's dependencies, then Real-ESRGAN itself in develop mode)
!pip install basicsr
!pip install facexlib
!pip install gfpgan
!pip install -r requirements.txt
!python setup.py develop

# Clone BSRGAN (into Real-ESRGAN/BSRGAN)
!git clone https://github.com/cszn/BSRGAN.git

# Remove any earlier SwinIR checkout so the clone below starts clean
# (rm reports an error if the folder does not exist yet; that is harmless).
!rm -r SwinIR
# Clone SwinIR
!git clone https://github.com/JingyunLiang/SwinIR.git
!pip install timm

# Download the pre-trained models
!wget https://github.com/cszn/KAIR/releases/download/v1.0/BSRGAN.pth -P BSRGAN/model_zoo
!wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P experiments/pretrained_models
# Alternative SwinIR-M checkpoint, unused (the inference cell loads the SwinIR-L one below):
#!wget https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFO_s64w8_SwinIR-M_x4_GAN.pth -P experiments/pretrained_models
!wget https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth -P experiments/pretrained_models

**Upload your images**

In [None]:
import os
import glob
from google.colab import files
import shutil
print(' Note1: You can find an image on the web or download images from the RealSRSet (proposed in BSRGAN, ICCV2021) at https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/RealSRSet+5images.zip.\n Note2: You may need Chrome to enable file uploading!\n Note3: If out-of-memory, set test_patch_wise = True.\n')

# test SwinIR (and Real-ESRGAN) by partitioning the image into patches
test_patch_wise = False

# Upload into BSRGAN's own test-set folder, to be compatible with BSRGAN's test script
upload_folder = 'BSRGAN/testsets/RealSRSet'
result_folder = 'results'

# Start every run from empty upload/result folders. shutil.rmtree handles the
# cleanup (no shell `rm` needed), and makedirs also creates missing parents.
for folder in (upload_folder, result_folder):
  if os.path.isdir(folder):
    shutil.rmtree(folder)
  os.makedirs(folder)

# upload images (Colab saves them in the working directory), then move them
# into the upload folder; basename() keeps the destination inside that folder.
uploaded = files.upload()
for filename in uploaded.keys():
  dst_path = os.path.join(upload_folder, os.path.basename(filename))
  print(f'move {filename} to {dst_path}')
  shutil.move(filename, dst_path)

In [None]:
# Release GPU memory that PyTorch's caching allocator holds but no longer
# uses, leaving more room for the models run in the inference cell.
import torch
torch.cuda.empty_cache()

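Small patch for an import error in basicsr: point its `rgb_to_grayscale` import at torchvision's `transforms._functional_tensor` module.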

In [None]:
import importlib.util
import os

# basicsr's data/degradations.py imports rgb_to_grayscale from
# torchvision.transforms.functional_tensor, which newer torchvision releases
# only provide as the private torchvision.transforms._functional_tensor.
_OLD_IMPORT = "from torchvision.transforms.functional_tensor import rgb_to_grayscale"
_NEW_IMPORT = "from torchvision.transforms._functional_tensor import rgb_to_grayscale"


def patch_basicsr_degradations(file_path):
    """Rewrite the torchvision import in *file_path* in place.

    Parameters:
    file_path: path of basicsr's data/degradations.py.

    Returns:
    True if the file was modified; False if it needed no change (already
    patched, or the import line is not present).
    """
    with open(file_path, "r", encoding="utf-8") as file:
        file_content = file.read()
    new_content = file_content.replace(_OLD_IMPORT, _NEW_IMPORT)
    if new_content == file_content:
        # Nothing to do; avoid rewriting the file.
        return False
    with open(file_path, "w", encoding="utf-8") as file:
        file.write(new_content)
    return True


# Locate basicsr without importing it: importing the package would execute the
# very import this cell is trying to fix. find_spec on a top-level name does
# not run the package's code.
_spec = importlib.util.find_spec("basicsr")
if _spec is None or not _spec.submodule_search_locations:
    print("Failed to find the installation location for 'basicsr'. Please check the package installation.")
else:
    file_path = os.path.join(list(_spec.submodule_search_locations)[0], "data", "degradations.py")
    if not os.path.exists(file_path):
        print("The specified file does not exist:", file_path)
    elif patch_basicsr_degradations(file_path):
        print("The file has been updated successfully.")
    else:
        print("The file needed no changes:", file_path)


Run inference with each model

In [None]:
# BSRGAN
# Start from an empty results folder; each model below writes into its own subfolder.
!rm -r results
if not test_patch_wise:
  # BSRGAN has no patch-wise mode here, so it only runs when test_patch_wise is False.
  %cd BSRGAN
  !python main_test_bsrgan.py
  %cd ..
  # BSRGAN writes next to its test set; move its results under results/.
  shutil.move('BSRGAN/testsets/RealSRSet_results_x4', 'results/BSRGAN')

# realESRGAN
# 4x upscaling with face enhancement; --tile 800 processes the image in tiles
# (the out-of-memory option mentioned in the upload cell).
if test_patch_wise:
  !python inference_realesrgan.py -n RealESRGAN_x4plus --input BSRGAN/testsets/RealSRSet -s 4 --output results/realESRGAN --tile 800 --face_enhance
else:
  !python inference_realesrgan.py -n RealESRGAN_x4plus --input BSRGAN/testsets/RealSRSet -s 4 --output results/realESRGAN --face_enhance

# SwinIR-Large
# Same checkpoint in both modes; --tile 640 enables tiled inference for patch-wise mode.
if test_patch_wise:
  !python SwinIR/main_test_swinir.py --task real_sr --model_path experiments/pretrained_models/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth --folder_lq BSRGAN/testsets/RealSRSet --scale 4 --large_model --tile 640
else:
  !python SwinIR/main_test_swinir.py --task real_sr --model_path experiments/pretrained_models/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth --folder_lq BSRGAN/testsets/RealSRSet --scale 4 --large_model
# Rename SwinIR's output folder to the name the display cells look for.
shutil.move('results/swinir_real_sr_x4_large', 'results/SwinIR_large')
for path in sorted(glob.glob(os.path.join('results/SwinIR_large', '*.png'))):
  os.rename(path, path.replace('SwinIR.png', 'SwinIR_large.png')) # unique per-model file names: Colab downloads cannot hold two files with the same name





**Enhance Images**

In [None]:
# utils for visualization
import cv2
import matplotlib.pyplot as plt
def display(img1, img2):
  """Show the input beside the BSRGAN, Real-ESRGAN and SwinIR-Large outputs.

  img1 is the input image; img2 maps 'BSRGAN', 'realESRGAN' and 'SwinIR-L'
  to the corresponding model outputs. The four images are drawn in one row.
  """
  panels = [
    ('Input image', None),
    ('BSRGAN (ICCV2021) output', 'BSRGAN'),
    ('Real-ESRGAN output', 'realESRGAN'),
    ('SwinIR-Large output', 'SwinIR-L'),
  ]
  n_panels = len(panels)
  fig = plt.figure(figsize=(n_panels*12, 14))

  # Create and label every axis first ...
  axes = []
  for position, (title, _key) in enumerate(panels, start=1):
    axis = fig.add_subplot(1, n_panels, position)
    plt.title(title, fontsize=30)
    axis.axis('off')
    axes.append(axis)

  # ... then draw the images, input first.
  for axis, (_title, key) in zip(axes, panels):
    axis.imshow(img1 if key is None else img2[key])

def imread(img_path):
  """Load *img_path* and return it as an RGB array.

  Raises ValueError if OpenCV cannot read the file (missing, unreadable or not
  an image). cv2.imread signals that by returning None, which would otherwise
  surface later as an opaque cv2.cvtColor error.
  """
  img = cv2.imread(img_path)
  if img is None:
    raise ValueError(f"Image not found or corrupted: {img_path}")
  # OpenCV loads channels as BGR; matplotlib expects RGB.
  return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

# display each image in the upload folder
print('Note: BSRGAN may be better at face restoration, but worse at building restoration because it uses different datasets in training.')
if test_patch_wise:
  print('BSRGAN does not support "test_patch_wise" mode for now. Set test_patch_wise = False to see its results.\n')
else:
  print('\n')

input_folder = upload_folder
result_folder = 'results/SwinIR_large'
input_list = sorted(glob.glob(os.path.join(input_folder, '*')))
output_list = sorted(glob.glob(os.path.join(result_folder, '*')))

for input_path in input_list:
  # Derive every output path from the input's own name. Zipping two separately
  # sorted listings can pair the wrong files: e.g. 'a.png' < 'a1.png', but
  # 'a1_SwinIR_large.png' < 'a_SwinIR_large.png'.
  name, ext = os.path.splitext(os.path.basename(input_path))
  swinir_path = os.path.join(result_folder, f'{name}_SwinIR_large.png')
  if not os.path.exists(swinir_path):
    print(f'SwinIR-L output not found for {input_path}, skipping.')
    continue

  img_input = imread(input_path)
  img_output = {'SwinIR-L': imread(swinir_path)}
  # White image (same size as the SwinIR output) for any missing result.
  white = img_output['SwinIR-L'] * 0 + 255

  bsrgan_path = os.path.join('results', 'BSRGAN', f'{name}_BSRGAN.png')
  if test_patch_wise or not os.path.exists(bsrgan_path):
    # BSRGAN does not run in patch-wise mode (or produced no output).
    img_output['BSRGAN'] = white
  else:
    img_output['BSRGAN'] = imread(bsrgan_path)

  # Real-ESRGAN saves '<name>_out<input ext>'; give it a model-specific name
  # like the other models so downloaded files do not collide.
  realesrgan_dir = os.path.join('results', 'realESRGAN')
  raw_path = os.path.join(realesrgan_dir, f'{name}_out{ext}')
  renamed_path = os.path.join(realesrgan_dir, f'{name}_realESRGAN{ext}')
  if os.path.exists(raw_path):
    shutil.move(raw_path, renamed_path)
  if os.path.exists(renamed_path):
    img_output['realESRGAN'] = imread(renamed_path)
  else:
    print(f'Real-ESRGAN output not found for {input_path}')
    img_output['realESRGAN'] = white

  display(img_input, img_output)


Download your new enhanced images

In [None]:
# Download the results as a single zip archive
zip_filename = 'Real-ESRGAN_result.zip'
# `zip` updates an existing archive in place, so remove any archive from a
# previous run to avoid carrying stale entries over.
if os.path.exists(zip_filename):
  os.remove(zip_filename)
# -j stores files without their folder names; each model's file names carry a
# model-specific suffix, which keeps the flattened entries distinct.
zip_status = os.system(f"zip -r -j {zip_filename} results/*")
if zip_status == 0 and os.path.exists(zip_filename):
  files.download(zip_filename)
else:
  # e.g. an empty results folder: report it instead of failing inside files.download.
  print(f"Could not create {zip_filename} (zip exit status {zip_status}); nothing to download.")

**Additional Features:**

Zoom into a region of each image to inspect the fine details of every model's output.

In [None]:
# Add this code after your existing display function
import numpy as np
# Add this improved code after your existing display function

def _clamp_region(x, y, w, h, height, width):
    """Clamp region (x, y, w, h) to a height x width image, keeping it at least 1x1."""
    x = max(0, min(x, width - 1))
    y = max(0, min(y, height - 1))
    w = max(1, min(w, width - x))
    h = max(1, min(h, height - y))
    return x, y, w, h


def _scale_region(region, src_shape, dst_shape):
    """Map region (x, y, w, h) on an image of *src_shape* onto one of *dst_shape*.

    The scale factors come from the two (height, width) shapes. They are 4 for
    the 4x models used here, but any output size maps correctly. The result is
    clamped to the destination image.
    """
    x, y, w, h = region
    scale_y = dst_shape[0] / src_shape[0]
    scale_x = dst_shape[1] / src_shape[1]
    return _clamp_region(int(round(x * scale_x)), int(round(y * scale_y)),
                         int(round(w * scale_x)), int(round(h * scale_y)),
                         dst_shape[0], dst_shape[1])


def _has_output(outputs, key):
    """True if outputs[key] exists and is a non-empty array."""
    img = outputs.get(key)
    return img is not None and img.size > 0


def _show_placeholder(ax, message):
    """Write *message* in the centre of *ax* in place of an image."""
    ax.text(0.5, 0.5, message,
            horizontalalignment='center',
            verticalalignment='center',
            transform=ax.transAxes,
            fontsize=20)


def display_with_zoom(img1, img2, zoom_region=None):
    """
    Display input and enhanced images with an option to zoom into a specific region.

    Parameters:
    img1: Input image (array indexed as [y, x, ...]).
    img2: Dictionary of enhanced images keyed 'BSRGAN', 'realESRGAN' and
          'SwinIR-L'. Missing or empty entries are shown as a text placeholder.
    zoom_region: Tuple (x, y, width, height) in input-image pixels. It is
          clamped to the input, then mapped onto each output using that
          output's actual size relative to the input.
    """
    total_figs = 4
    fig = plt.figure(figsize=(total_figs*12, 20))
    # (key in img2, title of the full-size panel)
    models = [('BSRGAN', 'BSRGAN output'),
              ('realESRGAN', 'Real-ESRGAN output'),
              ('SwinIR-L', 'SwinIR-Large output')]

    # First row: the full images.
    ax = fig.add_subplot(2, total_figs, 1)
    ax.set_title('Input image', fontsize=30)
    ax.axis('off')
    ax.imshow(img1)

    for col, (key, title) in enumerate(models, start=2):
        ax = fig.add_subplot(2, total_figs, col)
        ax.set_title(title, fontsize=30)
        ax.axis('off')
        if _has_output(img2, key):
            ax.imshow(img2[key])
        else:
            _show_placeholder(ax, f"No {key} output available")

    # Second row: the same region cropped from every image.
    if zoom_region is not None:
        h_input, w_input = img1.shape[:2]
        region = _clamp_region(*zoom_region, h_input, w_input)
        x, y, w, h = region

        ax = fig.add_subplot(2, total_figs, total_figs + 1)
        ax.set_title('Input image (zoomed)', fontsize=30)
        ax.axis('off')
        ax.imshow(img1[y:y+h, x:x+w])

        for col, (key, _title) in enumerate(models, start=total_figs + 2):
            ax = fig.add_subplot(2, total_figs, col)
            ax.set_title(f'{key} output (zoomed)', fontsize=30)
            ax.axis('off')
            if not _has_output(img2, key):
                _show_placeholder(ax, f"No {key} output available")
                continue
            output = img2[key]
            xo, yo, wo, ho = _scale_region(region, img1.shape[:2], output.shape[:2])
            ax.imshow(output[yo:yo+ho, xo:xo+wo])

    plt.tight_layout()
    plt.show()

# Interactive region selection
from IPython.display import display as ipydisplay
import ipywidgets as widgets
import numpy as np

def _white_placeholder(img_input):
    """Return an all-white uint8 RGB image four times the size of *img_input*.

    Stands in for a missing or unreadable model output, so the display code
    always receives an image of the expected 4x size.
    """
    return np.full((img_input.shape[0]*4, img_input.shape[1]*4, 3), 255, dtype=np.uint8)


def _load_model_output(candidates, label, img_name, img_input):
    """Load the first existing path in *candidates*, else return a white placeholder.

    Parameters:
    candidates: paths to try, in order of preference
    label: model name used in the printed messages
    img_name: input file name used in the printed messages
    img_input: input image, used to size the placeholder
    """
    for path in candidates:
        if not os.path.exists(path):
            continue
        try:
            return imread(path)
        except Exception as e:
            print(f"Error loading {label} output for {img_name}: {e}")
            return _white_placeholder(img_input)
    print(f"{label} output not found for {img_name}")
    return _white_placeholder(img_input)


def process_and_display_images(input_folder, result_folders, zoom_enabled=True):
    """
    Process and display images with interactive zoom controls.

    Parameters:
    input_folder: Path to folder containing original images
    result_folders: Dictionary with paths to model output folders, keyed
        'BSRGAN', 'realESRGAN' and 'SwinIR-L'
    zoom_enabled: Whether to enable interactive zooming

    Missing or unreadable model outputs are replaced by white 4x placeholders.
    Returns None; prints a message and returns early if no input loads.
    """
    input_list = sorted(glob.glob(os.path.join(input_folder, '*')))

    all_inputs = {}
    all_outputs = {}

    for input_path in input_list:
        img_name = os.path.basename(input_path)
        name_without_ext, ext = os.path.splitext(img_name)

        try:
            img_input = imread(input_path)
        except Exception as e:
            print(f"Error loading input image {img_name}: {e}")
            continue
        all_inputs[img_name] = img_input

        # Real-ESRGAN saves '<name>_out<input ext>'; the display cell renames
        # that to '<name>_realESRGAN<input ext>'. Try the renamed forms first
        # (input extension, then .jpg and .png), then the raw name.
        realesrgan_dir = result_folders['realESRGAN']
        realesrgan_candidates = [
            os.path.join(realesrgan_dir, f"{name_without_ext}_realESRGAN{suffix}")
            for suffix in (ext, '.jpg', '.png')
        ]
        realesrgan_candidates.append(os.path.join(realesrgan_dir, f"{name_without_ext}_out{ext}"))

        all_outputs[img_name] = {
            'BSRGAN': _load_model_output(
                [os.path.join(result_folders['BSRGAN'], f"{name_without_ext}_BSRGAN.png")],
                'BSRGAN', img_name, img_input),
            'realESRGAN': _load_model_output(
                realesrgan_candidates, 'Real-ESRGAN', img_name, img_input),
            'SwinIR-L': _load_model_output(
                [os.path.join(result_folders['SwinIR-L'], f"{name_without_ext}_SwinIR_large.png")],
                'SwinIR-L', img_name, img_input),
        }

    if not all_inputs:
        print("No input images were loaded successfully. Check your input folder path.")
        return

    if not zoom_enabled:
        # Just display all images without zoom
        for img_name in all_inputs:
            print(f"Displaying results for {img_name}")
            display(all_inputs[img_name], all_outputs[img_name])
        return

    # Create widgets for interactive display
    image_dropdown = widgets.Dropdown(
        options=list(all_inputs.keys()),
        description='Image:',
        style={'description_width': 'initial'}
    )

    # Initial slider ranges come from the first image
    first_img = next(iter(all_inputs.values()))
    init_height, init_width = first_img.shape[:2]

    x_slider = widgets.IntSlider(
        value=init_width//4, min=0, max=init_width-1,
        description='X position:',
        style={'description_width': 'initial'}
    )

    y_slider = widgets.IntSlider(
        value=init_height//4, min=0, max=init_height-1,
        description='Y position:',
        style={'description_width': 'initial'}
    )

    width_slider = widgets.IntSlider(
        value=min(100, init_width//2), min=1, max=init_width,
        description='Width:',
        style={'description_width': 'initial'}
    )

    height_slider = widgets.IntSlider(
        value=min(100, init_height//2), min=1, max=init_height,
        description='Height:',
        style={'description_width': 'initial'}
    )

    output = widgets.Output()

    # True while update_image_dimensions adjusts the sliders. The slider
    # observers skip redrawing during that time, so one image change costs one
    # redraw instead of one per slider that changed.
    syncing = {'active': False}

    def update_image_dimensions(*args):
        """Update slider ranges (and out-of-range values) for the selected image."""
        img = all_inputs[image_dropdown.value]
        height, width = img.shape[:2]

        syncing['active'] = True
        try:
            x_slider.max = width - 1
            y_slider.max = height - 1
            width_slider.max = width
            height_slider.max = height

            if x_slider.value >= width:
                x_slider.value = width // 4
            if y_slider.value >= height:
                y_slider.value = height // 4
            if width_slider.value > width:
                width_slider.value = min(100, width // 2)
            if height_slider.value > height:
                height_slider.value = min(100, height // 2)
        finally:
            syncing['active'] = False

    def update_display(*args):
        """Redraw the comparison for the current image and zoom region."""
        if syncing['active']:
            return
        img_name = image_dropdown.value
        zoom_region = (x_slider.value, y_slider.value, width_slider.value, height_slider.value)

        with output:
            output.clear_output(wait=True)
            print(f"Displaying results for {img_name} with zoom at region {zoom_region}")
            try:
                display_with_zoom(all_inputs[img_name], all_outputs[img_name], zoom_region)
            except Exception as e:
                print(f"Error displaying images: {e}")
                import traceback
                traceback.print_exc()

    def on_image_change(*args):
        """Fit the sliders to the newly selected image, then redraw once."""
        update_image_dimensions()
        update_display()

    image_dropdown.observe(on_image_change, names='value')
    for slider in (x_slider, y_slider, width_slider, height_slider):
        slider.observe(update_display, names='value')

    # Create UI layout
    controls = widgets.VBox([image_dropdown,
                            widgets.HBox([x_slider, y_slider]),
                            widgets.HBox([width_slider, height_slider])])

    ui = widgets.VBox([controls, output])
    ipydisplay(ui)

    # Initialize display
    on_image_change()

# Model output folders, keyed by the names display_with_zoom expects
result_folders = {
    'BSRGAN': 'results/BSRGAN',
    'realESRGAN': 'results/realESRGAN',
    'SwinIR-L': 'results/SwinIR_large'
}

# Read the originals from upload_folder (set in the upload cell), so this
# viewer follows wherever the images were actually uploaded.
process_and_display_images(upload_folder, result_folders, zoom_enabled=True)

In [None]:
import os

# Quick check of what BSRGAN produced.
bsrgan_results_dir = '/content/Real-ESRGAN/results/BSRGAN'
print(os.path.exists(bsrgan_results_dir))
# os.listdir raises for a missing folder; report that instead of crashing.
if os.path.isdir(bsrgan_results_dir):
    print(sorted(os.listdir(bsrgan_results_dir)))
else:
    print(f'{bsrgan_results_dir} is not a directory')

In [None]:
# Quick check of what Real-ESRGAN produced.
realesrgan_results_dir = '/content/Real-ESRGAN/results/realESRGAN'
print(os.path.exists(realesrgan_results_dir))
# os.listdir raises for a missing folder; report that instead of crashing.
if os.path.isdir(realesrgan_results_dir):
    print(sorted(os.listdir(realesrgan_results_dir)))
else:
    print(f'{realesrgan_results_dir} is not a directory')

In [None]:
# Quick check of what SwinIR-Large produced.
swinir_results_dir = '/content/Real-ESRGAN/results/SwinIR_large'
print(os.path.exists(swinir_results_dir))
# os.listdir raises for a missing folder; report that instead of crashing.
if os.path.isdir(swinir_results_dir):
    print(sorted(os.listdir(swinir_results_dir)))
else:
    print(f'{swinir_results_dir} is not a directory')

**Evaluation Metrics**:

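PSNR, SSIM. No ground-truth high-resolution images are available, so each output is resized back to the input's size and compared with the low-resolution input. The scores measure consistency with the input, not absolute restoration quality.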

In [None]:
# --- Step 1: Import libraries
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
import cv2
import pandas as pd
import glob
import os
import matplotlib.pyplot as plt
from IPython.display import display

# --- Step 2: Setup base directory
# Result and input paths below are resolved relative to the current working
# directory (presumably the Real-ESRGAN checkout after the setup cell's %cd).
base_folder = os.getcwd()  # Current working directory

# --- Step 3: Helper Functions
def imread(img_path):
    """Read *img_path* with OpenCV and return the pixels in RGB order.

    Raises ValueError when OpenCV returns nothing for the path (missing,
    unreadable or not an image).
    """
    bgr = cv2.imread(img_path)
    if bgr is not None:
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    raise ValueError(f"Image not found or corrupted: {img_path}")

def resize_to_input(img_restored, img_input):
    """Bicubically resize *img_restored* to the height and width of *img_input*."""
    target_height, target_width = img_input.shape[:2]
    # cv2.resize takes the target size as (width, height).
    return cv2.resize(img_restored, (target_width, target_height), interpolation=cv2.INTER_CUBIC)

def evaluate_metrics(img_input, img_restored):
    """Return (PSNR, SSIM) of *img_restored* measured against *img_input*.

    Both images must have the same shape. The data range is fixed at 255 and
    channels are taken from the last axis. The SSIM window is 7, or the
    largest odd size that fits the shorter image side when that side is under 7.
    """
    psnr_score = psnr(img_input, img_restored, data_range=255)

    shortest_side = min(img_input.shape[:2])
    if shortest_side >= 7:
        win_size = 7
    elif shortest_side % 2:
        win_size = shortest_side
    else:
        win_size = shortest_side - 1

    ssim_score = ssim(img_input, img_restored, data_range=255, channel_axis=-1, win_size=win_size)
    return psnr_score, ssim_score

def display_image_with_metrics(img_input, img_restored, model_name, psnr_score, ssim_score):
    """Plot the input and a model's output side by side.

    The output panel's title shows the model name with its PSNR and SSIM scores.
    """
    plt.figure(figsize=(16, 8))

    panels = (
        (img_input, "Input Image"),
        (img_restored, f"{model_name} Output\nPSNR: {psnr_score:.2f} dB, SSIM: {ssim_score:.4f}"),
    )
    for position, (image, caption) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.imshow(image)
        plt.title(caption)
        plt.axis('off')

    plt.tight_layout()
    plt.show()

# --- Step 4: Find output files directly
results_folder = os.path.join(base_folder, 'results')

# Find all available output images (sorted so the processing order is stable)
bsrgan_outputs = sorted(glob.glob(os.path.join(results_folder, 'BSRGAN', '*_BSRGAN.png')))
realesrgan_outputs = sorted(glob.glob(os.path.join(results_folder, 'realESRGAN', '*_realESRGAN.*')))
swinir_outputs = sorted(glob.glob(os.path.join(results_folder, 'SwinIR_large', '*_SwinIR_large.png')))

print(f"Found {len(bsrgan_outputs)} BSRGAN outputs")
print(f"Found {len(realesrgan_outputs)} Real-ESRGAN outputs")
print(f"Found {len(swinir_outputs)} SwinIR outputs")

# Display names for the internal model keys used in grouped_outputs
MODEL_NAMES = {'bsrgan': 'BSRGAN', 'realesrgan': 'Real-ESRGAN', 'swinir': 'SwinIR-Large'}
# Input image extensions searched for (compared case-insensitively)
INPUT_EXTENSIONS = ('.jpg', '.jpeg', '.png')


def _strip_model_suffix(path, suffix):
    """Return the file name of *path* without its extension and trailing *suffix*."""
    stem = os.path.splitext(os.path.basename(path))[0]
    return stem[:-len(suffix)] if stem.endswith(suffix) else stem


# Group outputs by base image name (without model suffix). Dropping the
# extension generically also handles Real-ESRGAN outputs saved as e.g. .jpeg.
grouped_outputs = {}
for model_key, suffix, paths in (
    ('bsrgan', '_BSRGAN', bsrgan_outputs),
    ('realesrgan', '_realESRGAN', realesrgan_outputs),
    ('swinir', '_SwinIR_large', swinir_outputs),
):
    for output_path in paths:
        base_name = _strip_model_suffix(output_path, suffix)
        entry = grouped_outputs.setdefault(base_name, {key: None for key in MODEL_NAMES})
        entry[model_key] = output_path

# --- Step 5: Find input files that correspond to our outputs
# The upload cell stores the originals in BSRGAN/testsets/RealSRSet (relative
# to the working directory), so search there first. The other folders are
# fallbacks; when a name appears in several folders, the earliest folder wins.
input_folder = os.path.join(base_folder, 'BSRGAN', 'testsets', 'RealSRSet')
input_search_folders = [
    input_folder,
    '/content/inputs',
    os.path.join(base_folder, 'inputs'),
]

input_files = []
for folder in input_search_folders:
    for candidate in sorted(glob.glob(os.path.join(folder, '*'))):
        if os.path.splitext(candidate)[1].lower() in INPUT_EXTENSIONS:
            input_files.append(candidate)

print(f"Found {len(input_files)} input files")

# Map input files to base names
input_map = {}
for input_path in input_files:
    input_map.setdefault(os.path.splitext(os.path.basename(input_path))[0], input_path)

# --- Step 6: Process each set of outputs
# NOTE: there are no ground-truth high-resolution images, so each output is
# resized to the input's size and compared with the low-resolution input.
# The scores measure consistency with the input, not restoration quality.
results_list = []
successful_evaluations = 0

print("\n--- Processing each image ---")

for base_name, outputs in grouped_outputs.items():
    print(f"\nProcessing outputs for base name: {base_name}")

    # Exact match first. A substring match is used only when it is unambiguous;
    # otherwise 'img1' could silently pick 'img10'.
    input_path = input_map.get(base_name)
    if input_path is None:
        partial_matches = [b for b in input_map if base_name in b or b in base_name]
        if len(partial_matches) == 1:
            input_path = input_map[partial_matches[0]]
            print(f"Found partial match: {partial_matches[0]} for {base_name}")
        elif partial_matches:
            print(f" Ambiguous partial matches for {base_name}: {', '.join(partial_matches)}")

    if input_path is None:
        print(f" No matching input file found for {base_name}. Creating a reference from BSRGAN output...")
        # Fallback: a degraded copy of the BSRGAN output (4x down, then back up)
        # stands in for the input. Not ideal, and it favours BSRGAN in the scores.
        if not outputs['bsrgan']:
            print(" Cannot create reference. Skipping...")
            continue
        try:
            reference_img = imread(outputs['bsrgan'])
        except Exception as e:
            print(f" Error reading BSRGAN output: {e}")
            continue
        input_img = cv2.resize(reference_img,
                               (reference_img.shape[1]//4, reference_img.shape[0]//4),
                               interpolation=cv2.INTER_AREA)
        input_img = cv2.resize(input_img,
                               (reference_img.shape[1], reference_img.shape[0]),
                               interpolation=cv2.INTER_CUBIC)
    else:
        print(f"Using input file: {input_path}")
        try:
            input_img = imread(input_path)
        except Exception as e:
            print(f" Error reading input file: {e}")
            continue

    missing = [k for k, v in outputs.items() if v is None]
    if missing:
        print(f" Missing outputs: {', '.join(missing)}. Evaluating the available outputs only...")

    # Process each model's output if available
    for model, output_path in outputs.items():
        if output_path is None:
            continue

        try:
            output_img = imread(output_path)
            resized_output = resize_to_input(output_img, input_img)

            psnr_score, ssim_score = evaluate_metrics(input_img, resized_output)
            model_name = MODEL_NAMES[model]

            results_list.append({
                "Image": base_name,
                "Model": model_name,
                "PSNR (dB)": psnr_score,
                "SSIM": ssim_score
            })

            # Display individual result
            print(f"{model_name} - PSNR: {psnr_score:.2f} dB, SSIM: {ssim_score:.4f}")
            display_image_with_metrics(input_img, resized_output, model_name, psnr_score, ssim_score)

            successful_evaluations += 1

        except Exception as e:
            print(f" Error processing {model} output for {base_name}: {e}")

# --- Step 7: Create and display results table
df_results = pd.DataFrame(results_list)

if df_results.empty:
    print("\n No evaluation results found! Please check your output files.")
    # Define the summary tables as empty so later cells see that there is
    # nothing to plot, instead of failing or reusing results from an earlier run.
    avg_results = pd.DataFrame(columns=['Model', 'PSNR (dB)', 'SSIM'])
    img_results = pd.DataFrame(columns=['Image', 'PSNR (dB)', 'SSIM'])
else:
    # successful_evaluations counts (image, model) pairs, not images.
    print(f"\n Successfully completed {successful_evaluations} model evaluations across {df_results['Image'].nunique()} images.")
    print("\n Quantitative Evaluation Metrics (per image and per model):\n")
    display(df_results.style.format({"PSNR (dB)": "{:.2f}", "SSIM": "{:.4f}"}).set_caption("Per-Image Metrics"))

    # Calculate and display average metrics per model
    avg_results = df_results.groupby('Model')[['PSNR (dB)', 'SSIM']].mean().reset_index()
    print("\n Average Quantitative Metrics (per model):\n")
    display(avg_results.style.format({"PSNR (dB)": "{:.2f}", "SSIM": "{:.4f}"}).set_caption("Average Metrics"))

    # Calculate and display metrics per image (all models)
    img_results = df_results.groupby('Image')[['PSNR (dB)', 'SSIM']].mean().reset_index()
    print("\n Average Metrics (per image across all models):\n")
    display(img_results.style.format({"PSNR (dB)": "{:.2f}", "SSIM": "{:.4f}"}).set_caption("Per-Image Average Metrics"))

    # Save results to CSV
    df_results.to_csv('full_image_restoration_metrics.csv', index=False)
    avg_results.to_csv('average_metrics_per_model.csv', index=False)
    img_results.to_csv('average_metrics_per_image.csv', index=False)

    print("\n Results saved to CSV files.")

In [None]:
import matplotlib.pyplot as plt

# --- Bar charts of the per-model average PSNR and SSIM
# avg_results comes from the evaluation cell. It is undefined if that cell has
# not run, and empty if it found nothing to evaluate.
if 'avg_results' not in globals() or avg_results.empty:
    print("No average metrics to plot. Run the evaluation cell first.")
else:
    for column, title in (('PSNR (dB)', 'Average PSNR Comparison'),
                          ('SSIM', 'Average SSIM Comparison')):
        plt.figure(figsize=(8,6))
        plt.bar(avg_results['Model'], avg_results[column])
        plt.title(title, fontsize=16)
        plt.ylabel(column, fontsize=14)
        plt.xlabel('Model', fontsize=14)
        plt.xticks(rotation=20)
        plt.grid(axis='y')
        plt.show()


**Histograms**

In [None]:
def plot_histograms(img1, img2):
    """
    Draw per-channel (R, G, B) pixel-count histograms in a 2x2 grid.

    Parameters:
    - img1: Input image (first panel)
    - img2: Dictionary of model outputs; its 'BSRGAN', 'realESRGAN' and
      'SwinIR-L' entries fill the remaining panels in that order
    """
    fig = plt.figure(figsize=(20, 15))

    labelled_images = (
        ('Input', img1),
        ('BSRGAN', img2['BSRGAN']),
        ('realESRGAN', img2['realESRGAN']),
        ('SwinIR-L', img2['SwinIR-L']),
    )
    channels = (('r', 'Red'), ('g', 'Green'), ('b', 'Blue'))

    for panel, (label, image) in enumerate(labelled_images, start=1):
        fig.add_subplot(2, 2, panel)
        plt.title(f'{label} RGB Histogram', fontsize=20)

        # Channel index i of the RGB image gets its matching colour.
        for channel_index, (colour, channel_name) in enumerate(channels):
            counts = cv2.calcHist([image], [channel_index], None, [256], [0, 256])
            plt.plot(counts, color=colour, label=channel_name)

        plt.xlim([0, 256])
        plt.xlabel('Pixel Intensity', fontsize=14)
        plt.ylabel('Number of Pixels', fontsize=14)
        plt.legend(fontsize=12)
        plt.grid(alpha=0.3)

    plt.tight_layout()
    plt.show()

def plot_cumulative_histograms(img1, img2):
    """
    Draw cumulative per-channel histograms (as % of pixels) in a 2x2 grid.

    Parameters:
    - img1: Input image (first panel)
    - img2: Dictionary of model outputs; its 'BSRGAN', 'realESRGAN' and
      'SwinIR-L' entries fill the remaining panels in that order
    """
    fig = plt.figure(figsize=(20, 15))

    labelled_images = (
        ('Input', img1),
        ('BSRGAN', img2['BSRGAN']),
        ('realESRGAN', img2['realESRGAN']),
        ('SwinIR-L', img2['SwinIR-L']),
    )
    channels = (('r', 'Red'), ('g', 'Green'), ('b', 'Blue'))

    for panel, (label, image) in enumerate(labelled_images, start=1):
        fig.add_subplot(2, 2, panel)
        plt.title(f'{label} Cumulative Histogram', fontsize=20)

        for channel_index, (colour, channel_name) in enumerate(channels):
            counts = cv2.calcHist([image], [channel_index], None, [256], [0, 256])
            running_total = counts.cumsum()
            # Scale to 0-100 % so images of different sizes are directly comparable.
            plt.plot(running_total * 100 / running_total.max(), color=colour, label=channel_name)

        plt.xlim([0, 256])
        plt.ylim([0, 100])
        plt.xlabel('Pixel Intensity', fontsize=14)
        plt.ylabel('Cumulative %', fontsize=14)
        plt.legend(fontsize=12)
        plt.grid(alpha=0.3)

    plt.tight_layout()
    plt.show()

def plot_brightness_distribution(img1, img2):
    """
    Overlay the grayscale (brightness) histograms of the input and every model output.

    Each histogram is divided by its own pixel count, so images of different
    resolutions (the input vs. the upscaled outputs) are drawn on one scale.

    Parameters:
    - img1: Input image (RGB channel order)
    - img2: Dictionary of output images keyed 'BSRGAN', 'realESRGAN', 'SwinIR-L'
    """
    figure = plt.figure(figsize=(12, 8))
    axis = figure.gca()
    axis.set_title('Brightness Distribution Comparison', fontsize=20)

    series = (
        ('Input', img1, 'black'),
        ('BSRGAN', img2['BSRGAN'], 'blue'),
        ('realESRGAN', img2['realESRGAN'], 'green'),
        ('SwinIR-L', img2['SwinIR-L'], 'red'),
    )

    for label, image, line_color in series:
        luminance = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        counts = cv2.calcHist([luminance], [0], None, [256], [0, 256])
        axis.plot(counts / counts.sum(), color=line_color, label=label,
                  alpha=0.7, linewidth=2)

    axis.set_xlim([0, 256])
    axis.set_xlabel('Brightness Level', fontsize=14)
    axis.set_ylabel('Normalized Frequency', fontsize=14)
    axis.legend(fontsize=12)
    axis.grid(alpha=0.3)
    plt.tight_layout()
    plt.show()

def process_and_analyze_histograms():
    """
    Process images and analyze their histograms.

    For every file in ``upload_folder`` the restored outputs are located from
    the input's file stem:
    - SwinIR-L:   results/SwinIR_large/<stem>_SwinIR_large.png
    - BSRGAN:     results/BSRGAN/<stem>_BSRGAN.png (a white placeholder is used
                  instead when ``test_patch_wise`` is set)
    - realESRGAN: results/realESRGAN/<stem>_out<ext>, renamed once to
                  results/realESRGAN/<stem>_realESRGAN<ext>
    The RGB, cumulative and brightness histogram plots are then drawn. Inputs
    whose outputs are missing are reported and skipped.

    Outputs are matched by stem rather than by zipping two independently sorted
    listings. The appended suffix can change the sort order. For example, the
    inputs sort as a.png < aB.png, but the outputs sort as
    aB_SwinIR_large.png < a_SwinIR_large.png. A single missing output would
    also shift every later pair.
    """
    input_folder = upload_folder
    result_folder = 'results/SwinIR_large'
    input_list = sorted(p for p in glob.glob(os.path.join(input_folder, '*'))
                        if os.path.isfile(p))

    for input_path in input_list:
        input_name = os.path.basename(input_path)
        stem, ext = os.path.splitext(input_name)

        output_path = os.path.join(result_folder, f'{stem}_SwinIR_large.png')
        bsrgan_path = os.path.join('results', 'BSRGAN', f'{stem}_BSRGAN.png')
        realesrgan_raw_path = os.path.join('results', 'realESRGAN', f'{stem}_out{ext}')
        realesrgan_path = os.path.join('results', 'realESRGAN', f'{stem}_realESRGAN{ext}')

        # Real-ESRGAN's output is named <stem>_out<ext>; rename it once so later
        # runs (and other cells) find it under a model-specific name.
        if os.path.exists(realesrgan_raw_path):
            shutil.move(realesrgan_raw_path, realesrgan_path)

        required = [output_path, realesrgan_path]
        if not test_patch_wise:
            required.append(bsrgan_path)
        missing = [p for p in required if not os.path.exists(p)]
        if missing:
            print(f"Skipping {input_name}: missing {', '.join(missing)}")
            continue

        img_input = imread(input_path)
        img_output = {'SwinIR-L': imread(output_path)}
        if test_patch_wise:
            # A white image of the SwinIR output's shape stands in for BSRGAN
            # (presumably no BSRGAN output is produced in patch-wise mode).
            img_output['BSRGAN'] = img_output['SwinIR-L'] * 0 + 255
        else:
            img_output['BSRGAN'] = imread(bsrgan_path)
        img_output['realESRGAN'] = imread(realesrgan_path)

        # Plot different histogram visualizations
        print(f"Analyzing histograms for {input_name}")
        plot_histograms(img_input, img_output)
        plot_cumulative_histograms(img_input, img_output)
        plot_brightness_distribution(img_input, img_output)

In [None]:
# Analyze histograms: draws the RGB, cumulative and brightness histogram plots
# for each uploaded image together with its model outputs under results/.
process_and_analyze_histograms()

In [None]:
import matplotlib.pyplot as plt
import cv2

def combined_histogram_plot(img_input, img_output, save_path=None, dpi=300):
    """
    Create a single figure combining RGB histograms, cumulative histograms, and brightness histograms
    for Input image and three restored outputs.

    Layout: rows are (1) per-channel pixel counts, (2) per-channel cumulative
    percentage, (3) grayscale histogram normalized to sum to 1; columns are
    Input, BSRGAN, realESRGAN, SwinIR-L.

    Parameters:
    - img_input: 3-channel image in RGB order (converted with COLOR_RGB2GRAY)
    - img_output: dict with keys 'BSRGAN', 'realESRGAN', 'SwinIR-L' holding
      images in the same format
    - save_path: optional output file. The figure is saved *before* plt.show(),
      because with Jupyter's inline backend showing a figure typically closes it,
      and saving afterwards would write a blank image.
    - dpi: resolution used when ``save_path`` is given

    Returns:
    - the matplotlib Figure
    """
    models = ['Input', 'BSRGAN', 'realESRGAN', 'SwinIR-L']
    images = [img_input, img_output['BSRGAN'], img_output['realESRGAN'], img_output['SwinIR-L']]
    colors = ('r', 'g', 'b')

    # Every panel spans intensities 0-255, so the x axis is shared (which also
    # hides the redundant inner x tick labels). The y axes are NOT shared and
    # keep their tick labels: raw counts scale with image size (the input is
    # smaller than the upscaled outputs), so each panel has its own scale.
    fig, axs = plt.subplots(3, 4, figsize=(24, 18), sharex=True)
    fig.suptitle('Histogram Analysis of Input vs Restored Images', fontsize=24)

    for col, (model, img) in enumerate(zip(models, images)):
        # Per-channel histograms are computed once and reused by rows 1 and 2.
        channel_hists = [cv2.calcHist([img], [i], None, [256], [0, 256])
                         for i in range(len(colors))]

        # --- Row 1: RGB Histograms
        for hist, color in zip(channel_hists, colors):
            axs[0, col].plot(hist, color=color)
        axs[0, col].set_title(f'{model} RGB Histogram', fontsize=16)
        axs[0, col].set_ylim(bottom=0)

        # --- Row 2: Cumulative Histograms (percentage of pixels <= intensity)
        for hist, color in zip(channel_hists, colors):
            cumulative = hist.cumsum()
            cumulative_normalized = cumulative * 100 / cumulative.max()
            axs[1, col].plot(cumulative_normalized, color=color)
        axs[1, col].set_title(f'{model} Cumulative Histogram', fontsize=16)
        axs[1, col].set_ylim([0, 100])

        # --- Row 3: Brightness (Grayscale) Histograms, normalized so images of
        # different resolutions are comparable
        gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        brightness_hist = cv2.calcHist([gray], [0], None, [256], [0, 256])
        brightness_hist = brightness_hist / brightness_hist.sum()
        axs[2, col].plot(brightness_hist, color='black')
        axs[2, col].set_title(f'{model} Brightness Histogram', fontsize=16)
        axs[2, col].set_ylim(bottom=0)

    # --- Layout
    for ax in axs.flat:
        ax.set_xlim([0, 256])
        ax.grid(alpha=0.3)
    for ax in axs[-1]:
        ax.set_xlabel('Pixel Intensity', fontsize=14)
    axs[0, 0].set_ylabel('Number of Pixels', fontsize=14)
    axs[1, 0].set_ylabel('Cumulative %', fontsize=14)
    axs[2, 0].set_ylabel('Normalized Frequency', fontsize=14)

    fig.tight_layout(rect=[0, 0, 1, 0.95])  # Leave space for the main title
    if save_path is not None:
        fig.savefig(save_path, bbox_inches='tight', dpi=dpi)
    plt.show()
    return fig


In [None]:
# combined_histogram_plot() calls plt.show() itself. With Jupyter's inline
# backend, showing a figure typically closes it, so a plt.savefig() issued
# afterwards saves a new, blank figure. Suppress that internal show, save the
# real figure, then display it once.
# NOTE(review): img_input / img_output must already exist as globals from an
# earlier cell; process_and_analyze_histograms() keeps its images local.
_original_show = plt.show
plt.show = lambda *args, **kwargs: None
try:
    combined_histogram_plot(img_input, img_output)
    combined_fig = plt.gcf()
finally:
    plt.show = _original_show
combined_fig.savefig('combined_histogram_analysis.png', bbox_inches='tight', dpi=300)
plt.show()


In [None]:
%matplotlib inline

**Edge Detection and Preservation**

In [None]:
# --- Import necessary libraries
import cv2
import matplotlib.pyplot as plt
import glob
import os
import numpy as np

# --- Setup base directory
# All paths used by the edge-analysis helpers below are anchored here.
# NOTE(review): the first cell runs `%cd Real-ESRGAN`, so this is presumably the
# Real-ESRGAN checkout that also holds results/ and BSRGAN/ — confirm.
base_folder = os.getcwd()  # Current working directory
results_folder = os.path.join(base_folder, 'results')  # per-model outputs: results/<model>/

def imread(img_path):
    """
    Load an image from disk and return it in RGB channel order.

    cv2.imread reports failure by returning None rather than raising, so that
    case is turned into a ValueError (missing path or undecodable file).
    """
    bgr_image = cv2.imread(img_path)
    if bgr_image is not None:
        return cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
    raise ValueError(f"Image not found or corrupted: {img_path}")

def find_input_file(base_name, search_dirs=None):
    """
    Try to find an input file that matches the base name.

    Parameters:
    - base_name: file stem to look for; ``'foo'`` matches ``foo.png``, ``foo.jpg``, ...
    - search_dirs: directories searched in order. Defaults to ``/content/inputs``,
      ``<base_folder>/inputs`` and ``<base_folder>/BSRGAN/testsets/RealSRSet``
      (the folder the upload cell writes to, relative to the working directory).

    Returns:
    - the path of the first match (alphabetically first within the first
      directory that has any), or None when nothing matches.
    """
    if search_dirs is None:
        search_dirs = [
            '/content/inputs',
            os.path.join(base_folder, 'inputs'),
            # upload_folder = 'BSRGAN/testsets/RealSRSet' in the upload cell
            os.path.join(base_folder, 'BSRGAN', 'testsets', 'RealSRSet'),
            # Add more potential input locations as needed
        ]

    # Escape so stems (and directories) containing glob metacharacters such as
    # '[' or '*' are matched literally instead of being treated as a pattern.
    stem_pattern = glob.escape(base_name) + '.*'
    for directory in search_dirs:
        # sorted(): glob order is arbitrary, so pick deterministically.
        matches = sorted(glob.glob(os.path.join(glob.escape(directory), stem_pattern)))
        if matches:
            return matches[0]
    return None

def plot_edge_detection_multiple():
    """
    Process all images in the results folder and visualize edge detection comparison.

    For each ``results/BSRGAN/<name>_BSRGAN.png``, the matching Real-ESRGAN
    (``_realESRGAN.jpg`` if present, else ``.png``) and SwinIR-large outputs are
    loaded. They are resized to the reference image's size, converted to
    grayscale and run through Canny (thresholds 100/200), and the four edge
    maps are shown side by side. Images with missing outputs are skipped. A
    read or processing error for one image is printed and the remaining
    images are still processed.

    The reference is the input returned by ``find_input_file``. When none is
    found, it is synthesized by downscaling the BSRGAN output 4x and upscaling
    it back. NOTE(review): a BSRGAN-derived reference favors BSRGAN in this
    comparison.
    """
    def _resize_to(img, width, height):
        """Resize img to (width, height): INTER_AREA when shrinking, INTER_CUBIC otherwise."""
        # Decimating with cubic/linear interpolation can alias fine texture into
        # spurious edges; OpenCV recommends INTER_AREA for shrinking.
        shrinking = img.shape[0] * img.shape[1] > width * height
        interpolation = cv2.INTER_AREA if shrinking else cv2.INTER_CUBIC
        return cv2.resize(img, (width, height), interpolation=interpolation)

    # Find all base names from results (sorted: glob order is arbitrary)
    bsrgan_outputs = sorted(glob.glob(os.path.join(results_folder, 'BSRGAN', '*_BSRGAN.png')))
    base_names = [os.path.basename(f).replace('_BSRGAN.png', '') for f in bsrgan_outputs]

    print(f"Found {len(base_names)} base images to process")

    models = ['Input', 'BSRGAN', 'realESRGAN', 'SwinIR-L']

    for base_name in base_names:
        print(f"\nProcessing edge detection for: {base_name}")

        # Find paths for all versions of this image
        bsrgan_path = os.path.join(results_folder, 'BSRGAN', f"{base_name}_BSRGAN.png")

        # Try both .jpg and .png for realESRGAN
        realesrgan_path_jpg = os.path.join(results_folder, 'realESRGAN', f"{base_name}_realESRGAN.jpg")
        realesrgan_path_png = os.path.join(results_folder, 'realESRGAN', f"{base_name}_realESRGAN.png")
        realesrgan_path = realesrgan_path_jpg if os.path.exists(realesrgan_path_jpg) else realesrgan_path_png

        swinir_path = os.path.join(results_folder, 'SwinIR_large', f"{base_name}_SwinIR_large.png")

        input_path = find_input_file(base_name)

        if not (os.path.exists(bsrgan_path) and os.path.exists(realesrgan_path) and
                os.path.exists(swinir_path)):
            print(f" Missing output files for {base_name}. Skipping...")
            continue

        if input_path:
            print(f"Using input file: {input_path}")
        else:
            print(f" Input file not found for {base_name}. Using BSRGAN output to create reference...")

        # Every read and transform is inside the try, so one unreadable or
        # degenerate image is reported and skipped instead of aborting the loop.
        try:
            img_bsrgan = imread(bsrgan_path)
            if input_path:
                img_input = imread(input_path)
            else:
                # Synthetic reference: 4x area-downscale, then cubic-upscale back.
                height, width = img_bsrgan.shape[:2]
                img_input = cv2.resize(img_bsrgan, (width // 4, height // 4),
                                       interpolation=cv2.INTER_AREA)
                img_input = cv2.resize(img_input, (width, height),
                                       interpolation=cv2.INTER_CUBIC)
            img_realesrgan = imread(realesrgan_path)
            img_swinir = imread(swinir_path)

            # Compare all edge maps at the reference image's resolution.
            target_h, target_w = img_input.shape[:2]
            images = [img_input] + [
                _resize_to(img, target_w, target_h)
                for img in (img_bsrgan, img_realesrgan, img_swinir)
            ]
            edge_images = [
                cv2.Canny(cv2.cvtColor(img, cv2.COLOR_RGB2GRAY), threshold1=100, threshold2=200)
                for img in images
            ]

            # Plot edge detection comparison
            fig, axs = plt.subplots(1, 4, figsize=(24, 6))
            fig.suptitle(f'Edge Detection Comparison for {base_name} (Canny)', fontsize=24)

            for ax, model, edge_img in zip(axs, models, edge_images):
                ax.imshow(edge_img, cmap='gray')
                ax.set_title(model, fontsize=18)
                ax.axis('off')

            plt.tight_layout()
            # To keep the figure, call fig.savefig(...) here, BEFORE plt.show(),
            # e.g. fig.savefig(f'edge_detection_{base_name}.png', dpi=300, bbox_inches='tight')
            plt.show()

        except Exception as e:  # per-image best effort: report and continue
            print(f" Error processing {base_name}: {e}")

# Run for all available images: one Canny edge-map comparison figure per
# results/BSRGAN/*_BSRGAN.png that has matching Real-ESRGAN and SwinIR outputs.
plot_edge_detection_multiple()



def analyze_edge_preservation(base_name=None):
    """
    Quantitatively analyze edge preservation by measuring edge density and similarity.

    Metrics are computed on Canny edge maps (thresholds 100/200) at the
    reference image's resolution:
    - Edge Density: fraction of pixels marked as edges.
    - Edge Similarity: intersection-over-union of a model's edge pixels with the
      reference (input) edge pixels; 1.0 when neither map has any edge.

    The reference is the input returned by ``find_input_file``. When none is
    found, it is synthesized from the BSRGAN output (4x down- then up-scaled).
    NOTE(review): a BSRGAN-derived reference favors BSRGAN's similarity score.

    Parameters:
    - base_name: stem of a single image to analyze; when None, every
      ``results/BSRGAN/*_BSRGAN.png`` is analyzed.

    Returns:
    - DataFrame with columns Image, Model, Edge Density, Edge Similarity (also
      written to 'edge_preservation_metrics.csv'), or None if nothing was analyzed.
    """
    def _resize_to(img, width, height):
        """Resize img to (width, height): INTER_AREA when shrinking, INTER_CUBIC otherwise."""
        # Decimating with linear/cubic interpolation can alias texture into
        # spurious edges; OpenCV recommends INTER_AREA for shrinking.
        shrinking = img.shape[0] * img.shape[1] > width * height
        interpolation = cv2.INTER_AREA if shrinking else cv2.INTER_CUBIC
        return cv2.resize(img, (width, height), interpolation=interpolation)

    def _edge_map(img):
        """Canny edge map (thresholds 100/200) of an RGB image."""
        return cv2.Canny(cv2.cvtColor(img, cv2.COLOR_RGB2GRAY), 100, 200)

    def _edge_density(edges):
        """Fraction of pixels that are edges."""
        return np.count_nonzero(edges) / edges.size

    def _edge_iou(reference, candidate):
        """Intersection over union of the edge pixels of two equally sized maps."""
        ref_mask = reference > 0
        cand_mask = candidate > 0
        union = np.count_nonzero(ref_mask | cand_mask)
        if union == 0:
            # Neither map has any edge: the maps agree exactly, so report 1.0
            # instead of 0/0 = nan.
            return 1.0
        return np.count_nonzero(ref_mask & cand_mask) / union

    # If no base_name provided, find all images (sorted: glob order is arbitrary)
    if base_name is None:
        bsrgan_outputs = sorted(glob.glob(os.path.join(results_folder, 'BSRGAN', '*_BSRGAN.png')))
        base_names = [os.path.basename(f).replace('_BSRGAN.png', '') for f in bsrgan_outputs]
    else:
        base_names = [base_name]

    results = []

    for name in base_names:
        print(f"\nAnalyzing edge preservation for: {name}")

        # Find paths for all versions of this image
        bsrgan_path = os.path.join(results_folder, 'BSRGAN', f"{name}_BSRGAN.png")

        # Try both .jpg and .png for realESRGAN
        realesrgan_path_jpg = os.path.join(results_folder, 'realESRGAN', f"{name}_realESRGAN.jpg")
        realesrgan_path_png = os.path.join(results_folder, 'realESRGAN', f"{name}_realESRGAN.png")
        realesrgan_path = realesrgan_path_jpg if os.path.exists(realesrgan_path_jpg) else realesrgan_path_png

        swinir_path = os.path.join(results_folder, 'SwinIR_large', f"{name}_SwinIR_large.png")

        input_path = find_input_file(name)

        if not (os.path.exists(bsrgan_path) and os.path.exists(realesrgan_path) and
                os.path.exists(swinir_path)):
            print(f" Missing output files for {name}. Skipping...")
            continue

        if not input_path:
            print(f" Input file not found for {name}. Using BSRGAN output to create reference...")

        # All loading happens inside the try, so a bad file skips this image
        # only. Rows are collected locally and committed in `else`, so an image
        # that fails part-way contributes no partial results.
        try:
            img_bsrgan = imread(bsrgan_path)
            if input_path:
                img_input = imread(input_path)
            else:
                height, width = img_bsrgan.shape[:2]
                img_input = cv2.resize(img_bsrgan, (width // 4, height // 4),
                                       interpolation=cv2.INTER_AREA)
                img_input = cv2.resize(img_input, (width, height),
                                       interpolation=cv2.INTER_CUBIC)
            outputs = {
                'BSRGAN': img_bsrgan,
                'Real-ESRGAN': imread(realesrgan_path),
                'SwinIR-Large': imread(swinir_path),
            }

            target_h, target_w = img_input.shape[:2]
            edges_input = _edge_map(img_input)
            rows = [{
                'Image': name,
                'Model': 'Input',
                'Edge Density': _edge_density(edges_input),
                'Edge Similarity': 1.0,  # Self-similarity is 1.0
            }]
            for model, img in outputs.items():
                edges = _edge_map(_resize_to(img, target_w, target_h))
                rows.append({
                    'Image': name,
                    'Model': model,
                    'Edge Density': _edge_density(edges),
                    'Edge Similarity': _edge_iou(edges_input, edges),
                })
        except Exception as e:  # per-image best effort: report and continue
            print(f" Error analyzing {name}: {e}")
        else:
            results.extend(rows)

    # Create DataFrame and display results
    import pandas as pd
    df_edges = pd.DataFrame(results)

    if df_edges.empty:
        print("No edge analysis results found!")
        return None

    # Display per-image results
    print("\nEdge Preservation Analysis (per image and per model):")
    edge_table = df_edges.pivot_table(
        index='Image', columns='Model',
        values=['Edge Density', 'Edge Similarity']
    )
    display(edge_table.style.format("{:.4f}").set_caption("Edge Preservation Metrics"))

    # Display average results per model
    print("\nAverage Edge Preservation (per model):")
    avg_edges = df_edges.groupby('Model')[['Edge Density', 'Edge Similarity']].mean()
    display(avg_edges.style.format("{:.4f}").set_caption("Average Edge Preservation"))

    # Save results
    df_edges.to_csv('edge_preservation_metrics.csv', index=False)
    print("\nResults saved to 'edge_preservation_metrics.csv'")

    return df_edges

# Run edge preservation analysis over every results/BSRGAN/*_BSRGAN.png: prints
# per-image and per-model tables, writes edge_preservation_metrics.csv, and
# stores the DataFrame (or None when nothing could be analyzed).
edge_results = analyze_edge_preservation()