# Initialization

In [15]:
!rm -rf Deblur

In [16]:
!git clone https://github.com/JPBrito0528/Deblur

Cloning into 'Deblur'...
remote: Enumerating objects: 289, done.[K
remote: Counting objects: 100% (22/22), done.[K
remote: Compressing objects: 100% (18/18), done.[K
remote: Total 289 (delta 5), reused 19 (delta 2), pack-reused 267 (from 1)[K
Receiving objects: 100% (289/289), 397.39 MiB | 22.93 MiB/s, done.
Resolving deltas: 100% (8/8), done.


In [17]:
!pip install pytorch-msssim



In [18]:
!pip install numpy==1.24.3
!pip install imgaug



In [19]:
pip install torchviz



In [20]:
import sys
sys.path.append('/content/Deblur')

from pytorch_msssim import ssim
from google.colab.patches import cv2_imshow
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms.functional as TF
from torchviz import make_dot
import model
from model import common
import cv2 as cv
import numpy as np
import os
import torch.nn as nn
from tkinter import filedialog
import random
from math import ceil, floor
import pandas as pd
from imgaug import augmenters as iaa
import matplotlib.pyplot as plt

# Main

In [24]:

# --------------------------
# Helper function to display images in Colab
def imshow_cv2(img, title="Image", resize_factor=0.5):
    """Display an image using matplotlib instead of cv.imshow."""
    # Convert BGR (OpenCV format) to RGB for display with matplotlib
    if len(img.shape) == 3 and img.shape[2] == 3:
        img_disp = cv.cvtColor(img, cv.COLOR_BGR2RGB)
    else:
        img_disp = img
    # Resize image for display if desired
    if resize_factor != 1:
        img_disp = cv.resize(img_disp, (0, 0), fx=resize_factor, fy=resize_factor)

    plt.figure(figsize=(10, 6))
    plt.imshow(img_disp, cmap='gray' if len(img_disp.shape) == 2 else None)
    plt.title(title)
    plt.axis('off')
    plt.show()


def build_square_mask(corners_list, H, W, device):
    """
    corners_list: [(TL_x,TL_y,TR_x,TR_y,BL_x,BL_y,BR_x,BR_y), …] length B
    returns: mask [B,1,H,W], 1 inside the square, 0 outside
    """
    B = len(corners_list)
    masks = torch.zeros(B, 1, H, W, device=device)
    for i, c in enumerate(corners_list):
        xs = sorted(c[0::2]); ys = sorted(c[1::2])
        # drop the extreme corners → inner square
        x0 = max(int(xs[1]) + 1, 0)
        x1 = min(int(xs[-2]) - 1, W)
        y0 = max(int(ys[1]) + 1, 0)
        y1 = min(int(ys[-2]) - 1, H)
        if x1 > x0 and y1 > y0:
            masks[i, 0, y0:y1, x0:x1] = 1.0
    return masks
# ----------------------------

DEBUG=True
TEST=False
TRAIN=False
TRAIN_TYPE = 4 # 0 = weighted ROI ; 1 = black inverted ROI ; 2 = traditional MSE ; 3 - binary output; 4 - SSIM
DNN_TYPE = 0 # 0 = EDSR ; 1 = UNET

IMG_WIDTH = int(1920/2)
IMG_HEIGHT = int(1080/2)
IMG_CHANNELS = 3
input_shape = (IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS)

SOURCE_FOLDER = '/content/Deblur'
#SOURCE_FOLDER = 'C:/Users/guilherme.franco/Documents/GAF/Estagio_Rui_Brito'
training_history_filename = SOURCE_FOLDER+'/Results/log.csv'

seed=420
SPLIT = 0.895
EPOCHS = 500
BATCH_SIZE=4
LEARNING_RATE = 10e-5

### Setup Dataset ###############################################################################################

raw_df = pd.read_csv(SOURCE_FOLDER+'/Lens_Dataset/data.csv').sample(frac=1)  #(read and shuffle)

raw_size = int(len(raw_df))

train_df = raw_df.head(floor(int(raw_size*SPLIT)))
train_size = int(len(train_df))

val_df = raw_df.tail(ceil(int(raw_size*(1-SPLIT))))
val_size = int(len(val_df))

train_df = train_df.sample(frac=1)
val_df = val_df.sample(frac=1)

# ── build lists of IDs + corner‐tuples for training & val ──
all_ids         = raw_df['ID'].astype(str).tolist()
all_corners     = raw_df[[
    'cornerTL_x','cornerTL_y',
    'cornerTR_x','cornerTR_y',
    'cornerBL_x','cornerBL_y',
    'cornerBR_x','cornerBR_y'
]].values.tolist()

train_img_list     = all_ids[:train_size]
train_corners_list = all_corners[:train_size]
val_img_list       = all_ids[train_size:]
val_corners_list   = all_corners[train_size:]
# ───────────────────────────────────────────────────────

#################################################################################################################

def yield_training_batch(img_file_list, corners_list):

    batch_size = len(img_file_list)

    input_images = []
    output_images = []
    batch_corners = []

    for f in range(0,batch_size):

        # Read image from list and convert to array
        input_image_path = SOURCE_FOLDER+'/Lens_Dataset/INPUTS/'+str(img_file_list[f])+'.png'
        output_image_path = SOURCE_FOLDER+'/Lens_Dataset/OUTPUTS/'+str(img_file_list[f])+'.png'

        input_image = cv.imread(input_image_path)
        output_image = cv.imread(output_image_path)

        cornerTL_x = corners_list[f][0]
        cornerTL_y = corners_list[f][1]
        cornerTR_x = corners_list[f][2]
        cornerTR_y = corners_list[f][3]
        cornerBL_x = corners_list[f][4]
        cornerBL_y = corners_list[f][5]
        cornerBR_x = corners_list[f][6]
        cornerBR_y = corners_list[f][7]


        if DEBUG:
            debug_img = output_image.copy()
            cv.circle(debug_img,(int(cornerTL_x),int(cornerTL_y)),10,(0,0,0),-1)
            cv.circle(debug_img,(int(cornerTL_x),int(cornerTL_y)),6,(255,255,255),-1)

            cv.circle(debug_img,(int(cornerTR_x),int(cornerTR_y)),10,(0,0,0),-1)
            cv.circle(debug_img,(int(cornerTR_x),int(cornerTR_y)),6,(255,255,255),-1)

            cv.circle(debug_img,(int(cornerBL_x),int(cornerBL_y)),10,(0,0,0),-1)
            cv.circle(debug_img,(int(cornerBL_x),int(cornerBL_y)),6,(255,255,255),-1)

            cv.circle(debug_img,(int(cornerBR_x),int(cornerBR_y)),10,(0,0,0),-1)
            cv.circle(debug_img,(int(cornerBR_x),int(cornerBR_y)),6,(255,255,255),-1)

            cv.imshow('PRE DEBUG OUTPT',  cv.resize(debug_img,(960,720)))


        #generate random vector to determine which augmentations to perform
        R = []
        for r in range(0,7):
            if DEBUG:
                R.append(0)
            else:
                rn = random.random()
                R.append(rn)

        #======ZOOM AND CROP OR NOT?======#

        if R[0] < 0.3:

            aspect_ratio = input_image.shape[1]/input_image.shape[0]
            target_ratio = IMG_WIDTH/IMG_HEIGHT
            height_center = int(input_image.shape[0]/2)
            width_center = int(input_image.shape[1]/2)

            if aspect_ratio > IMG_WIDTH/IMG_HEIGHT:
                target_width = int(input_image.shape[0]*target_ratio)
                input_image = input_image[0:input_image.shape[0], width_center-floor(target_width/2):width_center+floor(target_width/2)]
                output_image = output_image[0:input_image.shape[0], width_center-floor(target_width/2):width_center+floor(target_width/2)]

            elif aspect_ratio < IMG_WIDTH/IMG_HEIGHT:
                target_height = int(input_image.shape[1]*target_ratio)
                input_image = input_image[height_center-floor(target_height/2):height_center+floor(target_height/2), 0:input_image.shape[1]]
                output_image = output_image[height_center-floor(target_height/2):height_center+floor(target_height/2), 0:input_image.shape[1]]

            rn = random.uniform(1,1)    #how much to zoom in

            crop_w = floor(IMG_WIDTH*rn)    #how much to crop
            crop_h = floor(IMG_HEIGHT*rn)

            w_space = IMG_WIDTH - crop_w    #how much space is left
            h_space = IMG_HEIGHT - crop_h

            random_center_x = int(crop_w/2) + random.randint(0,w_space)     #center of the zoom/crop
            random_center_y = int(crop_h/2) + random.randint(0,h_space)

            #crops according to random parameters and resizes to desired input/output size
            input_image = input_image[random_center_y-int(crop_h/2):random_center_y+int(crop_h/2),random_center_x-int(crop_w/2):random_center_x+int(crop_w/2)]
            input_image = cv.resize(input_image, (IMG_WIDTH, IMG_HEIGHT), interpolation = cv.INTER_CUBIC)
            output_image = output_image[random_center_y-int(crop_h/2):random_center_y+int(crop_h/2),random_center_x-int(crop_w/2):random_center_x+int(crop_w/2)]
            output_image = cv.resize(output_image, (IMG_WIDTH, IMG_HEIGHT), interpolation = cv.INTER_NEAREST)


            # TODO: estas contas provavelmente estão mal, pelo que 'desabilitei o random crop e zoom'
            cornerTL_x = (cornerTL_x-(random_center_x - (IMG_WIDTH/2)))*rn
            cornerTL_y = (cornerTL_y-(random_center_y - (IMG_HEIGHT/2)))*rn
            cornerTR_x = (cornerTR_x-(random_center_x - (IMG_WIDTH/2)))*rn
            cornerTR_y = (cornerTR_y-(random_center_y - (IMG_HEIGHT/2)))*rn
            cornerBL_x = (cornerBL_x-(random_center_x - (IMG_WIDTH/2)))*rn
            cornerBL_y = (cornerBL_y-(random_center_y - (IMG_HEIGHT/2)))*rn
            cornerBR_x = (cornerBR_x-(random_center_x - (IMG_WIDTH/2)))*rn
            cornerBR_y = (cornerBR_y-(random_center_y - (IMG_HEIGHT/2)))*rn

        else:

            input_image = cv.resize(input_image, (IMG_WIDTH, IMG_HEIGHT), interpolation = cv.INTER_CUBIC)
            output_image = cv.resize(output_image, (IMG_WIDTH, IMG_HEIGHT), interpolation = cv.INTER_CUBIC)

        #======FLIP OR NOT?======#

        #Flip Horizontaly
        if R[1] < 0.5:
            input_image = cv.flip(input_image, 1)
            output_image = cv.flip(output_image, 1)
            cornerTL_x = (IMG_WIDTH) - cornerTL_x
            cornerTR_x = (IMG_WIDTH) - cornerTR_x
            cornerBL_x = (IMG_WIDTH) - cornerBL_x
            cornerBR_x = (IMG_WIDTH) - cornerBR_x

        #Flip Verticaly
        if R[2] < 0.5:
            input_image = cv.flip(input_image, 0)
            output_image = cv.flip(output_image, 0)
            cornerTL_y = (IMG_HEIGHT) - cornerTL_y
            cornerTR_y = (IMG_HEIGHT) - cornerTR_y
            cornerBL_y = (IMG_HEIGHT) - cornerBL_y
            cornerBR_y = (IMG_HEIGHT) - cornerBR_y

        #Color Shift
        if R[3] < 0.3:
            seq = iaa.Sequential([iaa.MultiplyHueAndSaturation((0.8, 1.2), per_channel=True)])
            input_image = seq.augment_image(input_image)
            output_image = seq.augment_image(output_image)

        #Random Simplex Noise Blobs
        # if R[4] < 0.3:
        #     seq = iaa.SimplexNoiseAlpha( first=iaa.Multiply(mul = (0.6,1.4),per_channel=True),per_channel=True)
        #     input_image = seq.augment_image(input_image)
        #     output_image = seq.augment_image(output_image)

        #Now we have to work in [0,1] instead of [0,255]
        input_image = input_image/255.0
        output_image = output_image/255.0

        # Gaussian Blur
        # if R[5] < 0.3:
        #     input_image = cv.GaussianBlur(input_image,(5,5),0)

        #Gaussian Noise
        # if R[6] < 0:#0.2:
        #     mean = 0
        #     var = 0.001
        #     sigma = var**0.5
        #     gauss = np.random.normal(mean,sigma,(IMG_HEIGHT,IMG_WIDTH,IMG_CHANNELS))
        #     gauss = gauss.reshape(IMG_HEIGHT,IMG_WIDTH,IMG_CHANNELS)
        #     input_image = input_image + gauss

        if DEBUG:
            debug_img = output_image.copy()
            cv.circle(debug_img,(int(cornerTL_x),int(cornerTL_y)),10,(0,0,0),-1)
            cv.circle(debug_img,(int(cornerTL_x),int(cornerTL_y)),6,(255,255,255),-1)

            cv.circle(debug_img,(int(cornerTR_x),int(cornerTR_y)),10,(0,0,0),-1)
            cv.circle(debug_img,(int(cornerTR_x),int(cornerTR_y)),6,(255,255,255),-1)

            cv.circle(debug_img,(int(cornerBL_x),int(cornerBL_y)),10,(0,0,0),-1)
            cv.circle(debug_img,(int(cornerBL_x),int(cornerBL_y)),6,(255,255,255),-1)

            cv.circle(debug_img,(int(cornerBR_x),int(cornerBR_y)),10,(0,0,0),-1)
            cv.circle(debug_img,(int(cornerBR_x),int(cornerBR_y)),6,(255,255,255),-1)

            cv.imshow('DEBUG OUTPT',  cv.resize(debug_img,(960,720)))
            cv.waitKey()

        input_images.append(input_image)
        output_images.append(output_image)
        batch_corners.append((cornerTL_x,cornerTL_y,cornerTR_x,cornerTR_y,cornerBL_x,cornerBL_y,cornerBR_x,cornerBR_y))

    input_images_np = np.array(input_images)
    output_images_np = np.array(output_images)

    if len(input_images) == 0:
        # no valid images this batch
        return None, None, []

    # stack → (B, H, W, C)
    input_np  = np.stack(input_images,  axis=0).astype(np.float32)
    output_np = np.stack(output_images, axis=0).astype(np.float32)

    # one‑shot tensor conversion + permute → (B, C, H, W)
    input_tensor  = torch.from_numpy(input_np).permute(0, 3, 1, 2).cuda()
    output_tensor = torch.from_numpy(output_np).permute(0, 3, 1, 2).cuda()

    return input_tensor, output_tensor, batch_corners

def yield_training_batch_black_BG(img_file_list, corners_list):
    batch_size = len(img_file_list)
    input_images = []
    output_images = []
    batch_corners = []

    for f in range(batch_size):
        try:
            # Build paths and load images
            input_image_path = SOURCE_FOLDER + '/Lens_Dataset/INPUTS/' + str(img_file_list[f]) + '.png'
            output_image_path = SOURCE_FOLDER + '/Lens_Dataset/OUTPUTS/' + str(img_file_list[f]) + '.png'
            input_image = cv.imread(input_image_path)
            output_image = cv.imread(output_image_path)

            # Check if image loading failed; if so, skip this image.
            if input_image is None or output_image is None:
                print("Skipping image", img_file_list[f], "- failed to load one or both images.")
                continue

            # Extract corner coordinates
            cornerTL_x = corners_list[f][0]
            cornerTL_y = corners_list[f][1]
            cornerTR_x = corners_list[f][2]
            cornerTR_y = corners_list[f][3]
            cornerBL_x = corners_list[f][4]
            cornerBL_y = corners_list[f][5]
            cornerBR_x = corners_list[f][6]
            cornerBR_y = corners_list[f][7]

            # (Optional debugging display omitted if not needed)

            # Generate random vector for augmentation decisions
            R = []
            for r in range(7):
                # You can remove DEBUG-related conditions if not required.
                R.append(random.random())

            # ====== ZOOM AND CROP OR NOT? ====== #
            if R[0] < 0.3:
                aspect_ratio = input_image.shape[1] / input_image.shape[0]
                target_ratio = IMG_WIDTH / IMG_HEIGHT
                height_center = int(input_image.shape[0] / 2)
                width_center = int(input_image.shape[1] / 2)

                if aspect_ratio > target_ratio:
                    target_width = int(input_image.shape[0] * target_ratio)
                    input_image = input_image[:, width_center - floor(target_width / 2):width_center + floor(target_width / 2)]
                    output_image = output_image[:, width_center - floor(target_width / 2):width_center + floor(target_width / 2)]
                elif aspect_ratio < target_ratio:
                    target_height = int(input_image.shape[1] / target_ratio)
                    input_image = input_image[height_center - floor(target_height / 2):height_center + floor(target_height / 2), :]
                    output_image = output_image[height_center - floor(target_height / 2):height_center + floor(target_height / 2), :]

                rn = random.uniform(1, 1)  # how much to zoom in
                crop_w = floor(IMG_WIDTH * rn)
                crop_h = floor(IMG_HEIGHT * rn)
                w_space = IMG_WIDTH - crop_w
                h_space = IMG_HEIGHT - crop_h

                random_center_x = int(crop_w / 2) + random.randint(0, w_space)
                random_center_y = int(crop_h / 2) + random.randint(0, h_space)

                # Crop according to random parameters and resize to desired size
                input_image = input_image[random_center_y - int(crop_h / 2):random_center_y + int(crop_h / 2),
                                          random_center_x - int(crop_w / 2):random_center_x + int(crop_w / 2)]
                input_image = cv.resize(input_image, (IMG_WIDTH, IMG_HEIGHT), interpolation=cv.INTER_CUBIC)
                output_image = output_image[random_center_y - int(crop_h / 2):random_center_y + int(crop_h / 2),
                                            random_center_x - int(crop_w / 2):random_center_x + int(crop_w / 2)]
                output_image = cv.resize(output_image, (IMG_WIDTH, IMG_HEIGHT), interpolation=cv.INTER_NEAREST)

                # Adjust corner coordinates (adjust these calculations as needed)
                cornerTL_x = (cornerTL_x - (random_center_x - (IMG_WIDTH / 2))) * rn
                cornerTL_y = (cornerTL_y - (random_center_y - (IMG_HEIGHT / 2))) * rn
                cornerTR_x = (cornerTR_x - (random_center_x - (IMG_WIDTH / 2))) * rn
                cornerTR_y = (cornerTR_y - (random_center_y - (IMG_HEIGHT / 2))) * rn
                cornerBL_x = (cornerBL_x - (random_center_x - (IMG_WIDTH / 2))) * rn
                cornerBL_y = (cornerBL_y - (random_center_y - (IMG_HEIGHT / 2))) * rn
                cornerBR_x = (cornerBR_x - (random_center_x - (IMG_WIDTH / 2))) * rn
                cornerBR_y = (cornerBR_y - (random_center_y - (IMG_HEIGHT / 2))) * rn

            else:
                input_image = cv.resize(input_image, (IMG_WIDTH, IMG_HEIGHT), interpolation=cv.INTER_CUBIC)
                output_image = cv.resize(output_image, (IMG_WIDTH, IMG_HEIGHT), interpolation=cv.INTER_CUBIC)

            # ====== FLIP OR NOT? ====== #
            if R[1] < 0.5:
                input_image = cv.flip(input_image, 1)
                output_image = cv.flip(output_image, 1)
                cornerTL_x = IMG_WIDTH - cornerTL_x
                cornerTR_x = IMG_WIDTH - cornerTR_x
                cornerBL_x = IMG_WIDTH - cornerBL_x
                cornerBR_x = IMG_WIDTH - cornerBR_x

            if R[2] < 0.5:
                input_image = cv.flip(input_image, 0)
                output_image = cv.flip(output_image, 0)
                cornerTL_y = IMG_HEIGHT - cornerTL_y
                cornerTR_y = IMG_HEIGHT - cornerTR_y
                cornerBL_y = IMG_HEIGHT - cornerBL_y
                cornerBR_y = IMG_HEIGHT - cornerBR_y

            if R[3] < 0.3:
                seq = iaa.Sequential([iaa.MultiplyHueAndSaturation((0.8, 1.2), per_channel=True)])
                input_image = seq.augment_image(input_image)
                output_image = seq.augment_image(output_image)

            # Convert pixel range from [0,255] to [0,1]
            input_image = input_image / 255.0
            output_image = output_image / 255.0

            sorted_x = np.sort((cornerTL_x, cornerBL_x, cornerTR_x, cornerBR_x))
            sorted_y = np.sort((cornerTL_y, cornerBL_y, cornerTR_y, cornerBR_y))
            min_x = int(sorted_x[1]) + 10
            max_x = int(sorted_x[-2]) - 10
            min_y = int(sorted_y[1]) + 10
            max_y = int(sorted_y[-2]) - 10

            cv.rectangle(output_image, (0, 0), (IMG_WIDTH, min_y), (0, 0, 0), -1)
            cv.rectangle(output_image, (0, max_y), (IMG_WIDTH, IMG_HEIGHT), (0, 0, 0), -1)
            cv.rectangle(output_image, (0, 0), (min_x, IMG_HEIGHT), (0, 0, 0), -1)
            cv.rectangle(output_image, (max_x, 0), (IMG_WIDTH, IMG_HEIGHT), (0, 0, 0), -1)

            # (Optional debugging display omitted if not needed)

            # Append processed images and adjusted corners to lists
            input_images.append(input_image)
            output_images.append(output_image)
            batch_corners.append((cornerTL_x, cornerTL_y,
                                  cornerTR_x, cornerTR_y,
                                  cornerBL_x, cornerBL_y,
                                  cornerBR_x, cornerBR_y))

        except Exception as e:
            # If an error occurs processing an image, skip to the next image.
            continue

    # If no valid images were processed, return None for batch inputs/outputs
    if len(input_images) == 0:
        return None, None, []

    # Convert lists to torch tensors and rearrange dimensions to [B, C, H, W]
    input_images_tensor = torch.Tensor(np.array(input_images)).permute(0, 3, 1, 2)
    output_images_tensor = torch.Tensor(np.array(output_images)).permute(0, 3, 1, 2)
    return input_images_tensor.cuda(), output_images_tensor.cuda(), batch_corners





def yield_training_batch_binary(img_file_list, corners_list):

    batch_size = len(img_file_list)

    input_images = []
    output_images = []
    batch_corners = []

    for f in range(0,batch_size):

        # Read image from list and convert to array
        input_image_path = SOURCE_FOLDER+'/Lens_Dataset/INPUTS/'+str(img_file_list[f])+'.png'
        output_image_path = SOURCE_FOLDER+'/Lens_Dataset/OUTPUTS/'+str(img_file_list[f])+'.png'

        input_image = cv.imread(input_image_path)
        output_image = cv.imread(output_image_path)

        cornerTL_x = corners_list[f][0]
        cornerTL_y = corners_list[f][1]
        cornerTR_x = corners_list[f][2]
        cornerTR_y = corners_list[f][3]
        cornerBL_x = corners_list[f][4]
        cornerBL_y = corners_list[f][5]
        cornerBR_x = corners_list[f][6]
        cornerBR_y = corners_list[f][7]


        if DEBUG:
            debug_img = output_image.copy()
            cv.circle(debug_img,(int(cornerTL_x),int(cornerTL_y)),10,(0,0,0),-1)
            cv.circle(debug_img,(int(cornerTL_x),int(cornerTL_y)),6,(255,255,255),-1)

            cv.circle(debug_img,(int(cornerTR_x),int(cornerTR_y)),10,(0,0,0),-1)
            cv.circle(debug_img,(int(cornerTR_x),int(cornerTR_y)),6,(255,255,255),-1)

            cv.circle(debug_img,(int(cornerBL_x),int(cornerBL_y)),10,(0,0,0),-1)
            cv.circle(debug_img,(int(cornerBL_x),int(cornerBL_y)),6,(255,255,255),-1)

            cv.circle(debug_img,(int(cornerBR_x),int(cornerBR_y)),10,(0,0,0),-1)
            cv.circle(debug_img,(int(cornerBR_x),int(cornerBR_y)),6,(255,255,255),-1)

            cv.imshow('PRE DEBUG OUTPT',  cv.resize(debug_img,(960,720)))


        #generate random vector to determine which augmentations to perform
        R = []
        for r in range(0,7):
            if DEBUG:
                R.append(0)
            else:
                rn = random.random()
                R.append(rn)

        #======ZOOM AND CROP OR NOT?======#

        if R[0] < 0.3:

            aspect_ratio = input_image.shape[1]/input_image.shape[0]
            target_ratio = IMG_WIDTH/IMG_HEIGHT
            height_center = int(input_image.shape[0]/2)
            width_center = int(input_image.shape[1]/2)

            if aspect_ratio > IMG_WIDTH/IMG_HEIGHT:
                target_width = int(input_image.shape[0]*target_ratio)
                input_image = input_image[0:input_image.shape[0], width_center-floor(target_width/2):width_center+floor(target_width/2)]
                output_image = output_image[0:input_image.shape[0], width_center-floor(target_width/2):width_center+floor(target_width/2)]

            elif aspect_ratio < IMG_WIDTH/IMG_HEIGHT:
                target_height = int(input_image.shape[1]*target_ratio)
                input_image = input_image[height_center-floor(target_height/2):height_center+floor(target_height/2), 0:input_image.shape[1]]
                output_image = output_image[height_center-floor(target_height/2):height_center+floor(target_height/2), 0:input_image.shape[1]]

            rn = random.uniform(1,1)    #how much to zoom in

            crop_w = floor(IMG_WIDTH*rn)    #how much to crop
            crop_h = floor(IMG_HEIGHT*rn)

            w_space = IMG_WIDTH - crop_w    #how much space is left
            h_space = IMG_HEIGHT - crop_h

            random_center_x = int(crop_w/2) + random.randint(0,w_space)     #center of the zoom/crop
            random_center_y = int(crop_h/2) + random.randint(0,h_space)

            #crops according to random parameters and resizes to desired input/output size
            input_image = input_image[random_center_y-int(crop_h/2):random_center_y+int(crop_h/2),random_center_x-int(crop_w/2):random_center_x+int(crop_w/2)]
            input_image = cv.resize(input_image, (IMG_WIDTH, IMG_HEIGHT), interpolation = cv.INTER_CUBIC)
            output_image = output_image[random_center_y-int(crop_h/2):random_center_y+int(crop_h/2),random_center_x-int(crop_w/2):random_center_x+int(crop_w/2)]
            output_image = cv.resize(output_image, (IMG_WIDTH, IMG_HEIGHT), interpolation = cv.INTER_NEAREST)


            # TODO: estas contas provavelmente estão mal, pelo que 'desabilitei o random crop e zoom'
            cornerTL_x = (cornerTL_x-(random_center_x - (IMG_WIDTH/2)))*rn
            cornerTL_y = (cornerTL_y-(random_center_y - (IMG_HEIGHT/2)))*rn
            cornerTR_x = (cornerTR_x-(random_center_x - (IMG_WIDTH/2)))*rn
            cornerTR_y = (cornerTR_y-(random_center_y - (IMG_HEIGHT/2)))*rn
            cornerBL_x = (cornerBL_x-(random_center_x - (IMG_WIDTH/2)))*rn
            cornerBL_y = (cornerBL_y-(random_center_y - (IMG_HEIGHT/2)))*rn
            cornerBR_x = (cornerBR_x-(random_center_x - (IMG_WIDTH/2)))*rn
            cornerBR_y = (cornerBR_y-(random_center_y - (IMG_HEIGHT/2)))*rn

        else:

            input_image = cv.resize(input_image, (IMG_WIDTH, IMG_HEIGHT), interpolation = cv.INTER_CUBIC)
            output_image = cv.resize(output_image, (IMG_WIDTH, IMG_HEIGHT), interpolation = cv.INTER_CUBIC)



        #======FLIP OR NOT?======#

        #Flip Horizontaly
        if R[1] < 0.5:
            input_image = cv.flip(input_image, 1)
            output_image = cv.flip(output_image, 1)
            cornerTL_x = (IMG_WIDTH) - cornerTL_x
            cornerTR_x = (IMG_WIDTH) - cornerTR_x
            cornerBL_x = (IMG_WIDTH) - cornerBL_x
            cornerBR_x = (IMG_WIDTH) - cornerBR_x

        #Flip Verticaly
        if R[2] < 0.5:
            input_image = cv.flip(input_image, 0)
            output_image = cv.flip(output_image, 0)
            cornerTL_y = (IMG_HEIGHT) - cornerTL_y
            cornerTR_y = (IMG_HEIGHT) - cornerTR_y
            cornerBL_y = (IMG_HEIGHT) - cornerBL_y
            cornerBR_y = (IMG_HEIGHT) - cornerBR_y

        #Color Shift
        if R[3] < 0.3:
            seq = iaa.Sequential([iaa.MultiplyHueAndSaturation((0.8, 1.2), per_channel=True)])
            input_image = seq.augment_image(input_image)
            output_image = seq.augment_image(output_image)

        #Random Simplex Noise Blobs
        # if R[4] < 0.3:
        #     seq = iaa.SimplexNoiseAlpha( first=iaa.Multiply(mul = (0.6,1.4),per_channel=True),per_channel=True)
        #     input_image = seq.augment_image(input_image)
        #     output_image = seq.augment_image(output_image)

        #Now we have to work in [0,1] instead of [0,255]
        input_image = input_image/255.0
        output_image = output_image/255.0

        # Gaussian Blur
        # if R[5] < 0.3:
        #     input_image = cv.GaussianBlur(input_image,(5,5),0)

        #Gaussian Noise
        # if R[6] < 0:#0.2:
        #     mean = 0
        #     var = 0.001
        #     sigma = var**0.5
        #     gauss = np.random.normal(mean,sigma,(IMG_HEIGHT,IMG_WIDTH,IMG_CHANNELS))
        #     gauss = gauss.reshape(IMG_HEIGHT,IMG_WIDTH,IMG_CHANNELS)
        #     input_image = input_image + gauss

        sorted_x = np.sort((cornerTL_x,cornerBL_x,cornerTR_x,cornerBR_x))
        sorted_y = np.sort((cornerTL_y,cornerBL_y,cornerTR_y,cornerBR_y))

        # find the second largest/smallest of the x and y values, and give some manouvering margin to ensure the are no edges
        min_x = int(sorted_x[1])  + 10
        max_x = int(sorted_x[-2]) - 10
        min_y = int(sorted_y[1])  + 10
        max_y = int(sorted_y[-2]) - 10

        #min_x = int(min(cornerTL_x,cornerBL_x,cornerTR_x,cornerBR_x))
        #max_x = int(max(cornerTL_x,cornerBL_x,cornerTR_x,cornerBR_x))
        #min_y = int(min(cornerTL_y,cornerBL_y,cornerTR_y,cornerBR_y))
        #max_y = int(max(cornerTL_y,cornerBL_y,cornerTR_y,cornerBR_y))

        cv.rectangle(output_image,(0,0),(IMG_WIDTH,IMG_HEIGHT),(255,255,255),-1)
        cv.rectangle(output_image,(0,0),(IMG_WIDTH,min_y),(0,0,0),-1)
        cv.rectangle(output_image,(0,max_y),(IMG_WIDTH,IMG_HEIGHT),(0,0,0),-1)
        cv.rectangle(output_image,(0,0),(min_x,IMG_HEIGHT),(0,0,0),-1)
        cv.rectangle(output_image,(max_x,0),(IMG_WIDTH,IMG_HEIGHT),(0,0,0),-1)


        if DEBUG:
            debug_img = output_image.copy()
            cv.circle(debug_img,(int(cornerTL_x),int(cornerTL_y)),10,(0,0,0),-1)
            cv.circle(debug_img,(int(cornerTL_x),int(cornerTL_y)),6,(255,255,255),-1)

            cv.circle(debug_img,(int(cornerTR_x),int(cornerTR_y)),10,(0,0,0),-1)
            cv.circle(debug_img,(int(cornerTR_x),int(cornerTR_y)),6,(255,255,255),-1)

            cv.circle(debug_img,(int(cornerBL_x),int(cornerBL_y)),10,(0,0,0),-1)
            cv.circle(debug_img,(int(cornerBL_x),int(cornerBL_y)),6,(255,255,255),-1)

            cv.circle(debug_img,(int(cornerBR_x),int(cornerBR_y)),10,(0,0,0),-1)
            cv.circle(debug_img,(int(cornerBR_x),int(cornerBR_y)),6,(255,255,255),-1)

            cv.imshow('DEBUG OUTPT',  cv.resize(debug_img,(960,720)))
            cv.waitKey()

        output_image = cv.cvtColor(output_image.astype(np.uint8),cv.COLOR_BGR2GRAY)
        output_image = np.expand_dims(output_image,axis=2)

        input_images.append(input_image)
        output_images.append(output_image)
        batch_corners.append((cornerTL_x,cornerTL_y,cornerTR_x,cornerTR_y,cornerBL_x,cornerBL_y,cornerBR_x,cornerBR_y))

    input_images_np  = np.stack(input_images,  axis=0)
    output_images_np = np.stack(output_images, axis=0)

    if input_images_np.ndim != 4:
      print("Skipping image with ID {img_file_list[f]} because one or both files were not found.")

    else:
      input_tensor  = torch.from_numpy(input_np).permute(0, 3, 1, 2).cuda()
      output_tensor = torch.from_numpy(output_np).permute(0, 3, 1, 2).cuda()


    return input_images.cuda(), output_images.cuda(), batch_corners

class EDSR(torch.nn.Module):
    def __init__(self, conv=common.default_conv):
        super(EDSR, self).__init__()

        rgb_range = 1.0
        n_resblocks = 5
        if DNN_TYPE==0:
            n_feats=3
            n_output_feats=3
            self.n_colors = 3
        else:
            n_feats=1
            n_output_feats=1
            self.n_colors = 3

        kernel_size = 3
        scale = 1
        act = torch.nn.ReLU(True)
        self.url = None
        self.sub_mean = common.MeanShift(rgb_range)
        self.add_mean = common.MeanShift(rgb_range, sign=1)

        self.res_scale = 1

        # define head module
        m_head = [conv(self.n_colors, n_feats, kernel_size)]

        # define body module
        m_body = [
            common.ResBlock(
                conv, n_feats, kernel_size, act=act, res_scale=self.res_scale
            ) for _ in range(n_resblocks)
        ]
        m_body.append(conv(n_feats, n_feats, kernel_size))

        # define tail module
        m_tail = [
            common.Upsampler(conv, scale, n_output_feats, act=False),
            conv(n_feats, self.n_colors, kernel_size)
        ]

        self.head = torch.nn.Sequential(*m_head)
        self.body = torch.nn.Sequential(*m_body)
        self.tail = torch.nn.Sequential(*m_tail)

    def forward(self, x):
        x = self.sub_mean(x)
        x = self.head(x)

        res = self.body(x)
        res += x

        #x = self.tail(res)
        x = self.add_mean(res)

        return x

    def load_state_dict(self, state_dict, strict=True):
        own_state = self.state_dict()
        for name, param in state_dict.items():
            if name in own_state:
                if isinstance(param, torch.nn.Parameter):
                    param = param.data
                try:
                    own_state[name].copy_(param)
                except Exception:
                    if name.find('tail') == -1:
                        raise RuntimeError('While copying the parameter named {}, '
                                           'whose dimensions in the model are {} and '
                                           'whose dimensions in the checkpoint are {}.'
                                           .format(name, own_state[name].size(), param.size()))
            elif strict:
                if name.find('tail') == -1:
                    raise KeyError('unexpected key "{}" in state_dict'
                                   .format(name))

class EDSR_binary(torch.nn.Module):
    def __init__(self, conv=common.default_conv):
        super(EDSR_binary, self).__init__()

        rgb_range = 1.0
        n_resblocks = 5
        if DNN_TYPE==0:
            n_feats=3
            n_output_feats=3
            self.n_colors = 3
        else:
            n_feats=1
            n_output_feats=1
            self.n_colors = 3

        kernel_size = 3
        scale = 1
        act = torch.nn.ReLU(True)
        self.url = None
        self.sub_mean = common.MeanShift(rgb_range)
        self.add_mean = common.MeanShift(rgb_range, sign=1)
        self.collapse = torch.nn.Conv2d(in_channels=3, out_channels=1, kernel_size=1)

        self.res_scale = 1

        # define head module
        m_head = [conv(self.n_colors, n_feats, kernel_size)]

        # define body module
        m_body = [
            common.ResBlock(
                conv, n_feats, kernel_size, act=act, res_scale=self.res_scale
            ) for _ in range(n_resblocks)
        ]
        m_body.append(conv(n_feats, n_feats, kernel_size))

        # define tail module
        m_tail = [
            common.Upsampler(conv, scale, n_output_feats, act=False),
            conv(n_feats, self.n_colors, kernel_size)
        ]

        self.head = torch.nn.Sequential(*m_head)
        self.body = torch.nn.Sequential(*m_body)
        self.tail = torch.nn.Sequential(*m_tail)

    def forward(self, x):
        x = self.sub_mean(x)
        x = self.head(x)

        res = self.body(x)
        res += x

        #x = self.tail(res)
        x = self.add_mean(res)
        x=self.collapse(x)

        return x

    def load_state_dict(self, state_dict, strict=True):
        own_state = self.state_dict()
        for name, param in state_dict.items():
            if name in own_state:
                if isinstance(param, torch.nn.Parameter):
                    param = param.data
                try:
                    own_state[name].copy_(param)
                except Exception:
                    if name.find('tail') == -1:
                        raise RuntimeError('While copying the parameter named {}, '
                                           'whose dimensions in the model are {} and '
                                           'whose dimensions in the checkpoint are {}.'
                                           .format(name, own_state[name].size(), param.size()))
            elif strict:
                if name.find('tail') == -1:
                    raise KeyError('unexpected key "{}" in state_dict'
                                   .format(name))

FILTER_SIZE = 128
PADDING = 1

class DoubleConv(torch.nn.Module):
    """(Conv => BN => ReLU) * 2"""
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.double_conv = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.ReLU(inplace=True),

            torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.double_conv(x)

class UNet(torch.nn.Module):
    def __init__(self, in_channels=3, out_channels=3, features=[4,8, 16, 32]):
        super(UNet, self).__init__()

        # Downsampling path
        self.downs = torch.nn.ModuleList()
        self.pool = torch.nn.MaxPool2d(kernel_size=2, stride=2)

        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Bottleneck
        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)

        # Upsampling path
        self.ups = torch.nn.ModuleList()
        self.upconvs = torch.nn.ModuleList()

        rev_features = features[::-1]
        for feature in rev_features:
            self.upconvs.append(torch.nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2))
            self.ups.append(DoubleConv(feature * 2, feature))

        # Final output layer
        if DNN_TYPE==0:
            self.final_conv = torch.nn.Conv2d(features[0], 3, kernel_size=1)
        else:
            self.final_conv = torch.nn.Conv2d(features[0], 1, kernel_size=1)


    def forward(self, x):
        # Encoder / Down path
        x1 = self.downs[0](x)              # [B, 64, H, W]
        x2 = self.downs[1](self.pool(x1))  # [B, 128, H/2, W/2]
        x3 = self.downs[2](self.pool(x2))  # [B, 256, H/4, W/4]
        x4 = self.downs[3](self.pool(x3))  # [B, 512, H/8, W/8]

        # Bottleneck
        x5 = self.bottleneck(self.pool(x4))  # [B, 1024, H/16, W/16]

        # Decoder / Up path
        u4 = self.upconvs[0](x5)
        if u4.shape != x4.shape:
            u4 = torch.nn.functional.interpolate(u4, size=x4.shape[2:])
        u4 = self.ups[0](torch.cat([x4, u4], dim=1))

        u3 = self.upconvs[1](u4)
        if u3.shape != x3.shape:
            u3 = torch.nn.functional.interpolate(u3, size=x3.shape[2:])
        u3 = self.ups[1](torch.cat([x3, u3], dim=1))

        u2 = self.upconvs[2](u3)
        if u2.shape != x2.shape:
            u2 = torch.nn.functional.interpolate(u2, size=x2.shape[2:])
        u2 = self.ups[2](torch.cat([x2, u2], dim=1))

        u1 = self.upconvs[3](u2)
        if u1.shape != x1.shape:
            u1 = torch.nn.functional.interpolate(u1, size=x1.shape[2:])
        u1 = self.ups[3](torch.cat([x1, u1], dim=1))

        # Final output
        return self.final_conv(u1)


#======================================== LOSS FUNCTION ==============================================

if DNN_TYPE==0:
    if TRAIN_TYPE==3:
        model=EDSR_binary()
    else:
        model=EDSR()
else:
    model=UNet()

def _gaussian_window(window_size: int, sigma: float) -> torch.Tensor:
    coords = torch.arange(window_size, dtype=torch.float)
    coords -= (window_size - 1) / 2.0
    g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
    g /= g.sum()
    return g

def _create_window(window_size: int, channel: int, device, dtype) -> torch.Tensor:
    # 1D Gaussian
    _1d = _gaussian_window(window_size, sigma=1.5).to(device=device, dtype=dtype)
    _2d = _1d.unsqueeze(1) @ _1d.unsqueeze(0)              # outer product → 2D kernel
    _2d = _2d.unsqueeze(0).unsqueeze(0)                     # shape (1,1,ws,ws)
    window = _2d.expand(channel, 1, window_size, window_size).contiguous()
    return window

class MaskedDeblurLoss(nn.Module):
    def __init__(self, lambda_ssim=0.5):
        super().__init__()
        self.lambda_ssim = lambda_ssim

    def forward(self, pred, target, corners):
        # pred, target: [B,3,H,W], in [0,1]
        B, C, H, W = pred.shape
        device = pred.device

        # a) build mask
        mask = build_square_mask(corners, H, W, device)     # [B,1,H,W]

        # b) masked MSE
        diff2 = (pred - target).pow(2).mean(dim=1, keepdim=True)  # [B,1,H,W]
        mse_num = (diff2 * mask).sum()
        mse_den = mask.sum().clamp(min=1.0)
        loss_mse = mse_num / mse_den

        # c) masked SSIM
        # we multiply by mask so that SSIM only “sees” the square
        # ssim() returns one scalar per batch if size_average=True
        loss_ssim = 1 - ssim(
            pred * mask,
            target * mask,
            data_range=1.0,
            size_average=True
        )

        return loss_mse + self.lambda_ssim * loss_ssim


class MSE_Crop_Loss(torch.nn.Module):
    def __init__(self):
        super(MSE_Crop_Loss, self).__init__()
    def forward(self,pred, target, corners):

        crop_score = 0

        for b in range(len(corners)):
            corners_x = corners[b][0:8:2]
            corners_y = corners[b][1:8:2]
            min_x = min(corners_x)
            max_x = max(corners_x)
            min_y = min(corners_y)
            max_y = max(corners_y)

            if min_x<0:
                min_x=0
            if min_y<0:
                min_y=0
            if max_x<0:
                max_x=0
            if max_y<0:
                max_y=0
            if min_x>IMG_WIDTH:
                min_x=IMG_WIDTH
            if min_y>IMG_HEIGHT:
                min_y=IMG_HEIGHT
            if max_x>IMG_WIDTH:
                max_x=IMG_WIDTH
            if max_y>IMG_HEIGHT:
                max_y=IMG_HEIGHT

            cropped_tgt = torchvision.transforms.functional.crop(target[b],int(min_y),int(min_x),int(max_y-min_y), int(max_x-min_x))
            cropped_pred = torchvision.transforms.functional.crop(pred[b],int(min_y),int(min_x),int(max_y-min_y), int(max_x-min_x))
            #cropped_tgt = target[b][min_x:max_x,min_y:max_y]
            #cropped_pred = pred[b][min_x:max_x,min_y:max_y]
            crop_score += torch.nn.functional.mse_loss(cropped_pred,cropped_tgt)
            #crop_score = np.square(np.subtract(cropped_pred,cropped_tgt)).mean()

        score = torch.nn.functional.mse_loss(pred,target)

        crop_score = crop_score / len(corners)
        return 1-(0.9*crop_score + 0.1*score)

x = torch.randn(1,3,1920, 1080)
y = model(x)

aux = y[0,:,:,:].cpu().detach().numpy()*255
aux = aux.astype(np.uint8)
aux = np.transpose(aux, (1,2,0))

make_dot(y.mean(), params=dict(model.named_parameters())).render('Debulr_Model', format="png")

model_image = cv.imread('Debulr_Model.png')
cv.imwrite(SOURCE_FOLDER+'/Results/Deblur_Model.png',model_image)

model= torch.nn.DataParallel(model)
model.to('cuda')

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

if TRAIN_TYPE == 4:
    loss_function = MaskedDeblurLoss(lambda_ssim=0.5)
elif TRAIN_TYPE == 3:
    loss_function = nn.BCEWithLogitsLoss()
elif TRAIN_TYPE == 0:
    loss_function = MSE_Crop_Loss()   # your existing crop‐MSE
else:
    loss_function = nn.MSELoss()

# Create Log file
with open(training_history_filename, "w") as csv_file:
    progress_string = 'EPOCH,LOSS,TRAIN_SCORE,VAL_SCORE/n'
    csv_file.write(progress_string)


best_train_score = 0.0

if TRAIN:

    for epoch in range(EPOCHS):
        running_loss = 0.0
        running_train_score = 0.0
        running_val_score = 0.0

        # count how many batches we actually processed
        n_train_batches = 0
        n_val_batches   = 0

        # —————— training ——————
        for start in range(0, len(train_img_list), BATCH_SIZE):
            batch_ids     = train_img_list    [start : start + BATCH_SIZE]
            batch_corners = train_corners_list[start : start + BATCH_SIZE]

            batch_inputs, batch_outputs, batch_corners = yield_training_batch(
                batch_ids, batch_corners
            )
            if batch_inputs is None:
                continue

            n_train_batches += 1
            batch_preds = model(batch_inputs)
            # compute loss exactly as you already do…
            if TRAIN_TYPE in (0,4):
                loss = loss_function(batch_preds, batch_outputs.cuda(), batch_corners)
            else:
                loss = loss_function(batch_preds, batch_outputs.cuda())

            train_score = 1.0 - loss

            # backprop…
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            running_train_score += train_score.item()

        # —————— validation ——————
        for start in range(0, len(val_img_list), BATCH_SIZE):
            batch_ids     = val_img_list    [start : start + BATCH_SIZE]
            batch_corners = val_corners_list[start : start + BATCH_SIZE]

            batch_inputs, batch_outputs, batch_corners = yield_training_batch(
                batch_ids, batch_corners
            )
            if batch_inputs is None:
                continue

            n_val_batches += 1
            batch_preds = model(batch_inputs)
            if TRAIN_TYPE in (0,4):
                val_loss = loss_function(batch_preds, batch_outputs.cuda(), batch_corners)
            else:
                val_loss = loss_function(batch_preds, batch_outputs.cuda())
            val_score = 1.0 - val_loss

            running_val_score += val_score.item()

        # —————— now compute **averages** ——————
        avg_loss       = running_loss       / max(1, n_train_batches)
        avg_train_score= running_train_score/ max(1, n_train_batches)
        avg_val_score  = running_val_score  / max(1, n_val_batches)

        print(f"Epoch {epoch}  "
              f"Avg Train Loss: {avg_loss:.4f}  "
              f"Avg Train Score: {avg_train_score:.4f}  "
              f"Avg Val Score: {avg_val_score:.4f}")



        # write epoch-level **averages** to CSV
        with open(training_history_filename, "a") as csv_file:
            progress_string = (
                f"{epoch},"
                f"{avg_loss:.6f},"
                f"{avg_train_score:.6f},"
                f"{avg_val_score:.6f}\n"
            )
            csv_file.write(progress_string)

        print('=========================================================')
        print(f'Epoch {epoch}')
        print(f'Avg Train Loss:    {avg_loss:.6f}')
        print(f'Avg Train Score:   {avg_train_score:.6f}')
        print(f'Avg Val   Score:   {avg_val_score:.6f}')
        print('=========================================================')

        if avg_train_score > best_train_score:
            best_train_score = avg_train_score
            model_name = 'Best_Deblur_model'
            # rebuild a fresh model instance and load weights
            if DNN_TYPE == 0:
                model_to_save = EDSR_binary() if TRAIN_TYPE == 3 else EDSR()
            else:
                model_to_save = UNet()
            model_to_save.load_state_dict(model.module.state_dict())
            model_to_save.to('cuda')
            model_scripted = torch.jit.script(model_to_save)
            model_scripted.save(f"{SOURCE_FOLDER}/Results/{model_name}.pt")

print('DONE')


if TEST:
    # Load the best model from file (update path if necessary)
    model_path = os.path.join(SOURCE_FOLDER, 'Results', '16.04.2025.pt')
    model_instance = torch.load(model_path,weights_only=False)
    model_instance.to('cuda')
    model_instance.eval()
    test_dir = os.path.join(SOURCE_FOLDER, 'Lens_Dataset', 'INPUTS')  # Set your test directory
    filelist = [f for f in os.listdir(test_dir) if f.endswith('.png')]

    # Display predictions for up to 10 random test images
    for i in range(min(10, len(filelist))):
        r = random.randrange(0, len(filelist))
        frame = cv.imread(os.path.join(test_dir, filelist[r]))
        # Resize and normalize for network input
        frame = cv.resize(frame, (1920, 1080)) / 255.0
        frame_T = np.transpose(frame, (2, 0, 1))
        np_tensor = np.expand_dims(frame_T, axis=0).astype(np.float32)
        input_tensor = torch.tensor(np_tensor).cuda()
        with torch.no_grad():
            preds = model_instance(input_tensor)
        # For SSIM mode, apply sigmoid to force outputs in [0,1]
        if TRAIN_TYPE == 4:
            preds = torch.sigmoid(preds)
        aux = preds[0, :, :, :].cpu().detach().numpy() * 255
        aux = aux.astype(np.uint8)
        output_image = np.transpose(aux, (1, 2, 0))
        imshow_cv2(output_image, title=f"Prediction for {filelist[r]}", resize_factor=0.5)

DisabledFunctionError: cv2.imshow() is disabled in Colab, because it causes Jupyter sessions
to crash; see https://github.com/jupyter/notebook/issues/3935.
As a substitution, consider using
  from google.colab.patches import cv2_imshow


In [None]:
!ls  /content/Deblur/Results/

ls: cannot access '/content/Deblur/Results/': No such file or directory
