In [5]:
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision import transforms

In [6]:
def split_into_patches(image, patch_size=256, overlap=64):
    """Split an image into overlapping square tiles of size patch_size x patch_size.

    Tiles are taken on a regular grid with stride ``patch_size - overlap``. A tile that runs
    past the bottom/right border is reflect-padded up to ``patch_size`` so all tiles share
    one shape.

    Args:
        image: array of shape (H, W) or (H, W, C).
        patch_size: side length of each square tile.
        overlap: pixels shared by neighbouring tiles; must be smaller than ``patch_size``.

    Returns:
        (patches, coordinates): ``patches`` is an array of shape
        (N, patch_size, patch_size[, C]) with the image's dtype, and ``coordinates`` is a
        list of the N ``(y, x)`` top-left corners in the original image.
        An empty image yields ``(np.array([]), [])``.

    Raises:
        ValueError: if ``overlap >= patch_size`` (the stride would not be positive).
    """
    if overlap >= patch_size:
        raise ValueError(f"overlap ({overlap}) must be smaller than patch_size ({patch_size})")

    h, w = image.shape[:2]
    patches = []
    coordinates = []
    if h == 0 or w == 0:
        return np.array(patches), coordinates

    step = patch_size - overlap
    # Trailing axes (e.g. colour channels) are never padded.
    extra_axes = ((0, 0),) * (image.ndim - 2)

    # The stop bound makes the last tile reach the border (last start + step >= dim - overlap,
    # hence last start + patch_size >= dim). max(..., 1) guarantees at least one tile when the
    # image is not larger than `overlap`; otherwise the range would be empty.
    for y in range(0, max(h - overlap, 1), step):
        for x in range(0, max(w - overlap, 1), step):
            patch = image[y:y + patch_size, x:x + patch_size]  # slicing clips at the border
            pad_h = patch_size - patch.shape[0]
            pad_w = patch_size - patch.shape[1]
            if pad_h or pad_w:
                patch = np.pad(patch, ((0, pad_h), (0, pad_w)) + extra_axes, mode='reflect')

            patches.append(patch)
            coordinates.append((y, x))

    return np.array(patches), coordinates

In [7]:
def predict_on_patches(model, patches, device, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Run the model on every patch and return predictions in the [0, 255] pixel range.

    Each patch is converted like ``torchvision.transforms.ToTensor`` (HWC -> CHW, uint8
    scaled to [0, 1]) and then normalized per channel with ``(x - mean) / std``. The model
    output is mapped back with the inverse transform ``(y * std + mean) * 255``, so the
    result can be passed straight to ``stitch_patches``, which clips to [0, 255].

    NOTE(review): the inverse mapping assumes the model was trained to predict targets in
    the same normalized space as its inputs — confirm against the training pipeline and
    adjust ``mean``/``std`` (or the de-normalization) if it was not.

    Args:
        model: a torch module mapping (1, C, H, W) -> (1, C_out, H, W).
        patches: iterable of (H, W[, C]) arrays, e.g. from ``split_into_patches``.
        device: torch device to run inference on.
        mean, std: per-channel normalization constants used for input and inverse output.

    Returns:
        float32 array of shape (N, H, W, C_out) with values in pixel units.
    """
    model.eval()
    # Shape (C, 1, 1) so the constants broadcast over a CHW tensor.
    mean_t = torch.tensor(mean, dtype=torch.float32, device=device).view(-1, 1, 1)
    std_t = torch.tensor(std, dtype=torch.float32, device=device).view(-1, 1, 1)

    results = []
    with torch.no_grad():
        for patch in patches:
            patch = np.asarray(patch)
            if patch.ndim == 2:
                patch = patch[:, :, None]  # grayscale -> single channel, as ToTensor does
            tensor = torch.from_numpy(np.ascontiguousarray(patch.transpose(2, 0, 1)))
            tensor = tensor.to(device=device, dtype=torch.float32)
            if patch.dtype == np.uint8:
                tensor = tensor / 255.0  # ToTensor only rescales 8-bit input
            tensor = ((tensor - mean_t) / std_t).unsqueeze(0)

            # Index the batch axis instead of squeeze() so a 1-channel output keeps its C axis.
            output = model(tensor)[0]
            output = (output * std_t + mean_t) * 255.0
            results.append(output.cpu().numpy().transpose(1, 2, 0))

    return np.array(results)

In [8]:
def stitch_patches(patches, coordinates, image_shape, patch_size=256, overlap=64):
    """Reassemble an image from (possibly overlapping) patches by averaging overlaps.

    Args:
        patches: sequence of (patch_size, patch_size, C) arrays in pixel units [0, 255].
        coordinates: matching ``(y, x)`` top-left corners, e.g. from ``split_into_patches``.
        image_shape: shape of the target image; only the first two entries (H, W) are used.
        patch_size: side length of each patch; the part beyond the image border is discarded.
        overlap: kept for signature compatibility; the averaging weights make it unnecessary.

    Returns:
        uint8 array of shape (H, W, C), rounded and clipped to [0, 255]. ``C`` is taken
        from the patches (3 when there are none). Pixels not covered by any patch are 0.
    """
    h, w = image_shape[:2]
    channels = patches[0].shape[-1] if len(patches) else 3
    stitched_image = np.zeros((h, w, channels), dtype=np.float64)
    weight_map = np.zeros((h, w, 1), dtype=np.float64)

    for patch, (y, x) in zip(patches, coordinates):
        y_end = min(y + patch_size, h)
        x_end = min(x + patch_size, w)
        # Drop the padded region that lies outside the image.
        stitched_image[y:y_end, x:x_end] += patch[:y_end - y, :x_end - x]
        weight_map[y:y_end, x:x_end] += 1

    # Divide only where something was written; uncovered pixels stay 0 instead of NaN.
    averaged = np.divide(stitched_image, weight_map,
                         out=np.zeros_like(stitched_image), where=weight_map > 0)
    # Round rather than truncate so the uint8 cast is unbiased.
    return np.clip(np.rint(averaged), 0, 255).astype(np.uint8)

In [9]:
def main(model_path, image_path):
    """Load a model and an image, run patch-wise inference, and plot input vs. prediction.

    Args:
        model_path: checkpoint written with ``torch.save(model)``, i.e. a whole pickled
            ``torch.nn.Module`` rather than a ``state_dict``.
        image_path: image file readable by ``cv2.imread``.

    Raises:
        TypeError: if the checkpoint does not contain a ``torch.nn.Module``.
        FileNotFoundError: if ``image_path`` cannot be read as an image.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # torch.load unpickles arbitrary objects: only load checkpoints you trust.
    # NOTE(review): torch >= 2.6 defaults to weights_only=True, which rejects whole pickled
    # modules; pass weights_only=False explicitly if this checkpoint is a full model.
    model = torch.load(model_path, map_location=device)
    if not isinstance(model, torch.nn.Module):
        raise TypeError(
            f"{model_path} does not contain a torch.nn.Module (got {type(model).__name__}); "
            "if it is a state_dict, build the model and call load_state_dict() instead."
        )

    # cv2.imread returns None instead of raising when the file is missing or unreadable.
    image_bgr = cv2.imread(image_path)
    if image_bgr is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

    patches, coordinates = split_into_patches(image)
    predictions = predict_on_patches(model, patches, device)
    stitched_image = stitch_patches(predictions, coordinates, image.shape)

    # Side-by-side comparison of the original and the reconstructed prediction.
    fig, (ax_input, ax_output) = plt.subplots(1, 2, figsize=(12, 6))
    for ax, shown, title in ((ax_input, image, "Original Image"),
                             (ax_output, stitched_image, "Predicted Image")):
        ax.imshow(shown)
        ax.set_title(title)
        ax.axis("off")

    plt.show()

In [10]:
# Input locations. Override with the NTIRE_MODEL_PATH / NTIRE_TEST_IMAGE environment
# variables to run on another machine; the defaults are the original author's local paths.
model_path = os.environ.get("NTIRE_MODEL_PATH", r"C:\Users\91909\Desktop\ML\NTIRE_2025\best_model.pth")
test_image = os.environ.get("NTIRE_TEST_IMAGE", r"C:\Users\91909\Desktop\ML\DATA\NTIRE\test\LSDIR_DIV2K_Test_Sigma50\0000089.png")

In [None]:
main(model_path, test_image)  # use the paths configured in the previous cell