In [1]:
import os
import torch
import torch.nn as nn
from torchvision.transforms.functional import resize
from torch.utils.data import DataLoader
import torch.optim as optim
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt
from torchvision import transforms

torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
torch.cuda.reset_peak_memory_stats()




In [2]:
class ResidualBlock(nn.Module):
    def __init__(self, num_channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(num_channels, num_channels, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.relu(out)
        out = self.conv2(out)
        out = out * 0.1
        return out + residual

class EDSR(nn.Module):
    def __init__(self, scale_factor, num_channels=3, num_features=64, num_residuals=32):
        super(EDSR, self).__init__()
        self.scale_factor = scale_factor
        self.conv1 = nn.Conv2d(num_channels, num_features, kernel_size=3, stride=1, padding=1)

        self.residual_blocks = nn.Sequential(
            *[ResidualBlock(num_features) for _ in range(num_residuals)]
        )

        self.conv2 = nn.Conv2d(num_features, num_features, kernel_size=3, stride=1, padding=1)

        self.upsample = nn.Sequential(
            nn.Conv2d(num_features, num_features * (scale_factor ** 2), kernel_size=3, stride=1, padding=1),
            nn.PixelShuffle(scale_factor),
            nn.Conv2d(num_features, num_channels, kernel_size=3, stride=1, padding=1)
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.residual_blocks(x)
        x = self.conv2(x)
        x = self.upsample(x)
        return x


In [3]:
import random
from torchvision.transforms.functional import gaussian_blur

def denormalize(tensor):
    """Convert [-1, 1] range back to [0, 1] for visualization."""
    return torch.clamp(tensor * 0.5 + 0.5, 0, 1)

class SuperResolutionDataset(torch.utils.data.Dataset):
    def __init__(self, root_dir, patch_size=128, scale_factor=2, max_patches=16, seed=42):
        self.root_dir = root_dir
        self.patch_size = patch_size
        self.lr_patch_size = patch_size // scale_factor
        self.image_list = os.listdir(root_dir)
        self.scale_factor = scale_factor
        self.max_patches = max_patches  
        self.normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

        random.seed(seed)
        torch.manual_seed(seed)

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        img_path = os.path.join(self.root_dir, self.image_list[idx])
        image = Image.open(img_path).convert("RGB")
        width, height = image.size

        hr_patches = []
        lr_patches = []
        for y in range(0, height, self.patch_size):
            for x in range(0, width, self.patch_size):
                hr_patch = image.crop((x, y, x + self.patch_size, y + self.patch_size))

                if hr_patch.size[0] != self.patch_size or hr_patch.size[1] != self.patch_size:
                    continue

                lr_patch = resize(hr_patch, (self.lr_patch_size, self.lr_patch_size), interpolation=Image.BILINEAR)

                hr_patch_tensor = self.normalize(transforms.ToTensor()(hr_patch))
                lr_patch_tensor = self.normalize(transforms.ToTensor()(lr_patch))
                hr_patches.append(hr_patch_tensor)
                lr_patches.append(lr_patch_tensor)

                blur_kernel_size = random.choice([3, 5])
                lr_patch_blur = gaussian_blur(lr_patch_tensor, kernel_size=blur_kernel_size)
                hr_patches.append(hr_patch_tensor) 
                lr_patches.append(lr_patch_blur)

                noise = torch.randn_like(lr_patch_tensor) * 0.1 
                lr_patch_noisy = torch.clamp(lr_patch_tensor + noise, 0, 1)
                hr_patches.append(hr_patch_tensor) 
                lr_patches.append(lr_patch_noisy)

        while len(hr_patches) < self.max_patches:
            hr_patches.append(torch.zeros((3, self.patch_size, self.patch_size)))
            lr_patches.append(torch.zeros((3, self.lr_patch_size, self.lr_patch_size)))

        hr_patches = hr_patches[:self.max_patches]
        lr_patches = lr_patches[:self.max_patches]

        if len(hr_patches) == 0 or len(lr_patches) == 0:
            raise ValueError(f"No valid patches for image {img_path}")

        return torch.stack(lr_patches), torch.stack(hr_patches)

dataset = SuperResolutionDataset(root_dir=r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\combined_largest_images_rd")#largest_images
dataloader = DataLoader(dataset, batch_size=16, shuffle=False)


In [4]:

## Some Paramters ##


scale_factor = 2
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = EDSR(scale_factor=scale_factor).to(device) 
criterion = nn.MSELoss().to(device)                
# optimizer = optim.Adam(model.parameters(), lr=1e-4)

from torch.optim import RMSprop
optimizer = RMSprop(model.parameters(), lr=1e-4, alpha=0.9, weight_decay=1e-5)
print(model.conv1.weight.device)  # Should print: cuda:0


cuda:0


In [5]:
import os
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import pad
from tqdm import tqdm
import matplotlib.pyplot as plt

os.makedirs("chkt_v6", exist_ok=True) 
os.makedirs("progress_v6", exist_ok=True)  

def pad_image(image, patch_size):
    width, height = image.size
    pad_width = (patch_size - width % patch_size) % patch_size
    pad_height = (patch_size - height % patch_size) % patch_size
    padding = (0, 0, pad_width, pad_height)  # (left, top, right, bottom)
    padded_image = pad(image, padding, fill=0)
    return padded_image, padding

num_epochs = 15
loss_values = []  
patch_size = 128  
scale_factor = 2  
lr_patch_size = patch_size // scale_factor  

test_image_path = r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\combined_largest_images_test_model\HR_output.PNG" 
test_image = Image.open(test_image_path).convert("RGB")
test_width, test_height = test_image.size

padded_test_image, padding = pad_image(test_image, patch_size)
padded_width, padded_height = padded_test_image.size

test_patches = []
for y in range(0, padded_height, patch_size):
    for x in range(0, padded_width, patch_size):
        patch = padded_test_image.crop((x, y, x + patch_size, y + patch_size))
        lr_patch = resize(patch, (lr_patch_size, lr_patch_size), interpolation=Image.BILINEAR)
        test_patches.append(transforms.ToTensor()(lr_patch).unsqueeze(0))


In [5]:
import os
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import pad
from tqdm import tqdm
import matplotlib.pyplot as plt

os.makedirs("chkt_v6", exist_ok=True) 
os.makedirs("progress_v6", exist_ok=True)  

def pad_image(image, patch_size):
    width, height = image.size
    pad_width = (patch_size - width % patch_size) % patch_size
    pad_height = (patch_size - height % patch_size) % patch_size
    padding = (0, 0, pad_width, pad_height)  # (left, top, right, bottom)
    padded_image = pad(image, padding, fill=0)
    return padded_image, padding

num_epochs = 15
loss_values = []  
patch_size = 128  
scale_factor = 2  
lr_patch_size = patch_size // scale_factor  

test_image_path = r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\combined_largest_images_test_model\HR_output.PNG" 
test_image = Image.open(test_image_path).convert("RGB")
test_width, test_height = test_image.size

padded_test_image, padding = pad_image(test_image, patch_size)
padded_width, padded_height = padded_test_image.size

test_patches = []
for y in range(0, padded_height, patch_size):
    for x in range(0, padded_width, patch_size):
        patch = padded_test_image.crop((x, y, x + patch_size, y + patch_size))
        lr_patch = resize(patch, (lr_patch_size, lr_patch_size), interpolation=Image.BILINEAR)
        test_patches.append(transforms.ToTensor()(lr_patch).unsqueeze(0))

for epoch in range(num_epochs):
    model.train()
    torch.cuda.empty_cache()

    epoch_loss = 0
    for lr_patches, hr_patches in tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        batch_size, num_patches, channels, lr_height, lr_width = lr_patches.shape
        _, _, _, hr_height, hr_width = hr_patches.shape

        lr_patches = lr_patches.view(-1, channels, lr_height, lr_width).to(device)
        hr_patches = hr_patches.view(-1, channels, hr_height, hr_width).to(device)

        optimizer.zero_grad()
        sr_patches = model(lr_patches)

        loss = criterion(sr_patches, hr_patches)

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    loss_values.append(epoch_loss)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")

    torch.save(model.state_dict(), f"chkt_v6/edsr_epoch_{epoch+1}.pth")

    model.eval()
    sr_test_patches = []
    with torch.no_grad():
        for patch in test_patches:
            patch = patch.to(device)
            sr_test_patches.append(model(patch).squeeze(0).cpu())

    sr_test_image = torch.zeros((3, padded_height, padded_width))  # Use padded dimensions
    patch_idx = 0
    for y in range(0, padded_height, patch_size):
        for x in range(0, padded_width, patch_size):
            # sr_test_image[:, y:y+patch_size, x:x+patch_size] = sr_test_patches[patch_idx]
            sr_test_image[:, y:y+patch_size, x:x+patch_size] = denormalize(sr_test_patches[patch_idx])

            patch_idx += 1

    sr_test_image = sr_test_image[:, :test_height, :test_width]

    patch_idx = 0  # Choose a specific patch to visualize (adjust as needed)
    low_res_patch = test_patches[patch_idx].squeeze(0).cpu()
    super_res_patch = sr_test_patches[patch_idx].cpu()
    high_res_patch = transforms.ToTensor()(padded_test_image.crop((0, 0, patch_size, patch_size)))

    plt.figure(figsize=(12, 4))

    plt.subplot(1, 3, 1)
    plt.title("Low-Resolution Patch")
    # plt.imshow(transforms.ToPILImage()(low_res_patch))
    plt.imshow(transforms.ToPILImage()(denormalize(low_res_patch)))
    plt.axis("off")

    plt.subplot(1, 3, 2)
    plt.title(f"Super-Resolved Patch (Epoch {epoch+1})")
    # plt.imshow(transforms.ToPILImage()(super_res_patch))
    plt.imshow(transforms.ToPILImage()(denormalize(super_res_patch)))

    plt.axis("off")

    plt.subplot(1, 3, 3)
    plt.title("High-Resolution Patch")

    plt.imshow(transforms.ToPILImage()(high_res_patch))
    # plt.imshow(transforms.ToPILImage()(denormalize(high_res_patch)))

    plt.axis("off")

    plt.tight_layout()
    plt.savefig(f"progress_v6/epoch_{epoch+1}_patch_comparison.png", dpi=300, bbox_inches='tight')
    plt.close()  # Close the figure

    plt.figure(figsize=(16, 5))

    plt.subplot(1, 3, 1)
    plt.title(f"Low-Resolution Input ({lr_patch_size}x{lr_patch_size} patches)")

    lr_test_reconstructed = torch.zeros((3, padded_height // scale_factor, padded_width // scale_factor))
    patch_idx = 0
    for y in range(0, padded_height, patch_size):
        for x in range(0, padded_width, patch_size):
            lr_patch = test_patches[patch_idx].squeeze(0).cpu()
            lr_test_reconstructed[:, y // scale_factor:(y + patch_size) // scale_factor,
                                  x // scale_factor:(x + patch_size) // scale_factor] = lr_patch
            patch_idx += 1

    lr_test_reconstructed = lr_test_reconstructed[:, :test_height // scale_factor, :test_width // scale_factor]
    # plt.imshow(transforms.ToPILImage()(lr_test_reconstructed))
    plt.imshow(transforms.ToPILImage()(denormalize(lr_test_reconstructed)))

    plt.axis("off")

    plt.subplot(1, 3, 2)
    sr_height, sr_width = sr_test_image.shape[1:] 
    plt.title(f"Super-Resolved Output (Epoch {epoch+1}, {sr_width}x{sr_height})")
    plt.imshow(transforms.ToPILImage()(denormalize(sr_test_image)))
    # plt.imshow(transforms.ToPILImage()(sr_test_image))
    plt.axis("off")

    plt.subplot(1, 3, 3)
    plt.title(f"Original High-Resolution ({test_width}x{test_height})")
    plt.imshow(test_image)
    plt.axis("off")

    plt.tight_layout()
    plt.savefig(f"progress_v6/epoch_{epoch+1}_reconstruction_comparison.png", dpi=200, bbox_inches='tight')
    plt.close() 


    plt.figure(figsize=(8, 6))
    plt.plot(range(1, len(loss_values) + 1), loss_values, marker='o', label="Training Loss")
    plt.title("Training Loss Over Epochs")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.grid(True)
    plt.legend()
    plt.savefig("progress_v6/loss_plot.png", dpi=200, bbox_inches='tight')
    plt.close()


Epoch 1/15: 100%|██████████| 81/81 [03:46<00:00,  2.79s/it]


Epoch [1/15], Loss: 5.3279


Epoch 2/15: 100%|██████████| 81/81 [03:39<00:00,  2.71s/it]


Epoch [2/15], Loss: 3.5021


Epoch 3/15: 100%|██████████| 81/81 [03:36<00:00,  2.67s/it]


Epoch [3/15], Loss: 2.6821


Epoch 4/15: 100%|██████████| 81/81 [03:37<00:00,  2.68s/it]


Epoch [4/15], Loss: 2.4325


Epoch 5/15: 100%|██████████| 81/81 [03:37<00:00,  2.68s/it]


Epoch [5/15], Loss: 2.2653


Epoch 6/15: 100%|██████████| 81/81 [2:55:29<00:00, 129.99s/it]    


Epoch [6/15], Loss: 2.1388


Epoch 7/15: 100%|██████████| 81/81 [04:04<00:00,  3.02s/it]


Epoch [7/15], Loss: 2.0593


Epoch 8/15: 100%|██████████| 81/81 [03:39<00:00,  2.71s/it]


Epoch [8/15], Loss: 1.9656


Epoch 9/15: 100%|██████████| 81/81 [04:06<00:00,  3.04s/it]


Epoch [9/15], Loss: 1.9501


Epoch 10/15: 100%|██████████| 81/81 [03:44<00:00,  2.78s/it]


Epoch [10/15], Loss: 1.8936


Epoch 11/15: 100%|██████████| 81/81 [03:43<00:00,  2.75s/it]


Epoch [11/15], Loss: 1.8531


Epoch 12/15: 100%|██████████| 81/81 [03:36<00:00,  2.67s/it]


Epoch [12/15], Loss: 1.8189


Epoch 13/15: 100%|██████████| 81/81 [03:41<00:00,  2.74s/it]


Epoch [13/15], Loss: 1.7807


Epoch 14/15: 100%|██████████| 81/81 [03:38<00:00,  2.69s/it]


Epoch [14/15], Loss: 1.7691


Epoch 15/15: 100%|██████████| 81/81 [03:38<00:00,  2.70s/it]


Epoch [15/15], Loss: 1.7326


# Inference  Below

In [17]:
import os
import torch
import math
from torchvision.transforms.functional import pad, resize, to_pil_image, to_tensor
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt

# Define normalization and denormalization
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
def denormalize(tensor):
    """Denormalize the tensor from [-1, 1] to [0, 1]."""
    return torch.clamp(tensor * 0.5 + 0.5, 0, 1)

def pad_image(image, patch_size):
    """Pad image to make dimensions divisible by patch size."""
    width, height = image.size
    pad_width = (patch_size - width % patch_size) % patch_size
    pad_height = (patch_size - height % patch_size) % patch_size
    padding = (0, 0, pad_width, pad_height)
    padded_image = pad(image, padding, fill=0)
    return padded_image, padding

def calculate_psnr(img1, img2):
    """Calculate PSNR between two images."""
    mse = torch.mean((img1 - img2) ** 2)
    if mse == 0:
        return float('inf')  # Infinite PSNR (perfect match)
    return 20 * math.log10(1.0 / math.sqrt(mse))

def psnr(pred, target):
    """Compute the PSNR (Peak Signal-to-Noise Ratio) between two tensors."""
    mse = F.mse_loss(pred, target)
    if mse == 0:
        return 5.5
    return 20 * math.log10(1.0 / math.sqrt(mse.item()))

def perform_inference_comparison_EDSR(
    model_path, image_paths, output_dir="edsr_inference_results", patch_size=128, scale_factor=2
):

    os.makedirs(output_dir, exist_ok=True)

    # Load trained EDSR model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = EDSR(scale_factor=scale_factor).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    for image_path in image_paths:
        image_name = os.path.splitext(os.path.basename(image_path))[0]
        image = Image.open(image_path).convert("RGB")
        original_width, original_height = image.size

        # Pad the image
        padded_image, padding = pad_image(image, patch_size)
        padded_width, padded_height = padded_image.size

        # Prepare patches
        lr_patches, hr_patches = [], []
        for y in range(0, padded_height, patch_size):
            for x in range(0, padded_width, patch_size):
                patch = padded_image.crop((x, y, x + patch_size, y + patch_size))
                lr_patch = resize(patch, (patch_size // scale_factor, patch_size // scale_factor), interpolation=Image.BICUBIC)
                lr_patches.append(normalize(to_tensor(lr_patch)).unsqueeze(0).to(device))
                hr_patches.append(normalize(to_tensor(patch)).unsqueeze(0).to(device))

        # Perform inference
        sr_patches, bicubic_patches = [], []
        total_sr_psnr, total_bicubic_psnr = 0.0, 0.0
        with torch.no_grad():
            for lr_patch, hr_patch in zip(lr_patches, hr_patches):
                sr_patch = model(lr_patch).squeeze(0)
                # bicubic_patch = torch.nn.functional.interpolate(lr_patch, scale_factor=scale_factor, mode="bicubic")

                bicubic_patch = F.interpolate(
                    lr_patch, scale_factor=scale_factor, mode="bilinear", align_corners=True
                ).squeeze(0)

                sr_patches.append(denormalize(sr_patch.cpu()))
                bicubic_patches.append(denormalize(bicubic_patch.squeeze(0).cpu()))
                total_sr_psnr += calculate_psnr(denormalize(sr_patch), denormalize(hr_patch.squeeze(0)))
                total_bicubic_psnr += psnr(bicubic_patch.squeeze(0),hr_patch.squeeze(0))


        avg_sr_psnr = total_sr_psnr / len(hr_patches)
        avg_bicubic_psnr = total_bicubic_psnr / len(hr_patches)

        # Reconstruct images
        sr_image = torch.zeros((3, padded_height, padded_width))
        bicubic_image = torch.zeros((3, padded_height, padded_width))
        patch_idx = 0
        for y in range(0, padded_height, patch_size):
            for x in range(0, padded_width, patch_size):
                sr_image[:, y:y+patch_size, x:x+patch_size] = sr_patches[patch_idx]
                bicubic_image[:, y:y+patch_size, x:x+patch_size] = bicubic_patches[patch_idx]
                patch_idx += 1
        sr_image = sr_image[:, :original_height, :original_width]
        bicubic_image = bicubic_image[:, :original_height, :original_width]



        # test_image = Image.open(test_image_path).convert("RGB")
        padded_test_image, padding = pad_image(test_image, patch_size)
        padded_width, padded_height = padded_test_image.size
        test_patches = []
        hr_patches = []
        for y in range(0, padded_height, patch_size):
            for x in range(0, padded_width, patch_size):
                patch = padded_test_image.crop((x, y, x + patch_size, y + patch_size))
                lr_patch = resize(patch, (patch_size // scale_factor, patch_size // scale_factor), interpolation=Image.BILINEAR)
                test_patches.append(normalize(to_tensor(lr_patch)).unsqueeze(0))  
                hr_patches.append(normalize(to_tensor(patch)).unsqueeze(0)) 

        test_patches = [patch.to(device) for patch in test_patches]
        hr_patches = [patch.to(device) for patch in hr_patches]

        bilinear_patches = []
        total_bilinear_psnr = 0.0
        with torch.no_grad():
            for lr_patch, hr_patch in zip(test_patches, hr_patches):
                bilinear_patch = F.interpolate(
                    lr_patch, scale_factor=scale_factor, mode="bilinear", align_corners=False
                ).squeeze(0)
                bilinear_patches.append(bilinear_patch)
                total_bilinear_psnr += psnr(bilinear_patch, hr_patch.squeeze(0))
        total_bicubic_psnr = total_bilinear_psnr / len(test_patches)




        # Save comparison
        plt.figure(figsize=(30, 15))

        # Low-Resolution Input
        plt.subplot(1, 4, 1)
        plt.title("Low-Resolution Input")
        lr_image = resize(image, (original_height // scale_factor, original_width // scale_factor), interpolation=Image.BICUBIC)
        plt.imshow(lr_image)
        plt.axis("off")

        # Bicubic Upscaling
        plt.subplot(1, 4, 2)
        plt.title(f"Bicubic PSNR: {avg_bicubic_psnr:.2f} dB")
        plt.imshow(to_pil_image(bicubic_image))
        plt.axis("off")

        # Super-Resolved Output (EDSR)
        plt.subplot(1, 4, 3)
        plt.title(f"EDSR PSNR: {avg_sr_psnr:.2f} dB")
        plt.imshow(to_pil_image(sr_image))
        plt.axis("off")

        # Original High-Resolution
        plt.subplot(1, 4, 4)
        plt.title("Original High-Resolution")
        plt.imshow(image)
        plt.axis("off")

        plt.tight_layout()
        output_path = os.path.join(output_dir, f"{image_name}_comparison.png")
        plt.savefig(output_path, dpi=500, bbox_inches="tight")
        plt.close()

        print(f"Saved comparison for '{image_name}' to {output_path}")
        print(f"PSNR Results: Bicubic = {avg_bicubic_psnr:.2f} dB, EDSR = {avg_sr_psnr:.2f} dB")




image_list = [
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\0745.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\0746.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\0747.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\0748.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\0749.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\HR_output.PNG", r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\waves\1702053632133310.png"]
perform_inference_comparison_EDSR("chkt_v6/edsr_epoch_15.pth", image_list, scale_factor=2,output_dir="inference_results_v6")


  model.load_state_dict(torch.load(model_path, map_location=device))


Saved comparison for '0745' to inference_results_v6\0745_comparison.png
PSNR Results: Bicubic = 24.92 dB, EDSR = 24.07 dB
Saved comparison for '0746' to inference_results_v6\0746_comparison.png
PSNR Results: Bicubic = 29.62 dB, EDSR = 25.33 dB
Saved comparison for '0747' to inference_results_v6\0747_comparison.png
PSNR Results: Bicubic = 23.89 dB, EDSR = 25.19 dB
Saved comparison for '0748' to inference_results_v6\0748_comparison.png
PSNR Results: Bicubic = 20.57 dB, EDSR = 24.03 dB
Saved comparison for '0749' to inference_results_v6\0749_comparison.png
PSNR Results: Bicubic = 30.30 dB, EDSR = 27.54 dB
Saved comparison for 'HR_output' to inference_results_v6\HR_output_comparison.png
PSNR Results: Bicubic = 21.19 dB, EDSR = 24.35 dB
Saved comparison for '1702053632133310' to inference_results_v6\1702053632133310_comparison.png
PSNR Results: Bicubic = 20.15 dB, EDSR = 23.26 dB


In [None]:
import os
import torch
import math
from torchvision.transforms.functional import pad, resize, to_pil_image, to_tensor
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import torch.nn.functional as F

# Define normalization and denormalization
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
def denormalize(tensor):
    """Denormalize the tensor from [-1, 1] to [0, 1]."""
    return torch.clamp(tensor * 0.5 + 0.5, 0, 1)

def calculate_psnr(img1, img2):
    """Calculate PSNR between two images."""
    mse = torch.mean((img1 - img2) ** 2)
    if mse == 0:
        return float('inf')  # Infinite PSNR (perfect match)
    return 20 * math.log10(1.0 / math.sqrt(mse))

def perform_inference_single_patch_EDSR(
    model_path, image_paths, output_dir="edsr_best_patch_results", patch_size=128, scale_factor=2
):

    os.makedirs(output_dir, exist_ok=True)

    # Load trained EDSR model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = EDSR(scale_factor=scale_factor).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    for image_path in image_paths:
        image_name = os.path.splitext(os.path.basename(image_path))[0]
        image = Image.open(image_path).convert("RGB")
        original_width, original_height = image.size

        # Pad the image
        padded_image, _ = pad_image(image, patch_size)
        padded_width, padded_height = padded_image.size

        # Prepare patches
        lr_patches, hr_patches = [], []
        for y in range(0, padded_height, patch_size):
            for x in range(0, padded_width, patch_size):
                patch = padded_image.crop((x, y, x + patch_size, y + patch_size))
                lr_patch = resize(patch, (patch_size // scale_factor, patch_size // scale_factor), interpolation=Image.BICUBIC)
                lr_patches.append(normalize(to_tensor(lr_patch)).unsqueeze(0).to(device))
                hr_patches.append(normalize(to_tensor(patch)).unsqueeze(0).to(device))

        # Find the patch with the highest PSNR
        highest_psnr = 0.0
        best_patch_index = 0
        bilinear_psnr_best = 0.0
        with torch.no_grad():
            for i, (lr_patch, hr_patch) in enumerate(zip(lr_patches, hr_patches)):
                sr_patch = model(lr_patch).squeeze(0)
                bilinear_patch = F.interpolate(
                    lr_patch, scale_factor=scale_factor, mode="bilinear", align_corners=False
                ).squeeze(0)
                psnr_value = calculate_psnr(denormalize(sr_patch), denormalize(hr_patch.squeeze(0)))
                bilinear_psnr = calculate_psnr(bilinear_patch, hr_patch.squeeze(0))

                if psnr_value > highest_psnr:
                    highest_psnr = psnr_value
                    bilinear_psnr_best = bilinear_psnr
                    best_patch_index = i

        # Extract the best patch
        lr_patch_best = denormalize(lr_patches[best_patch_index].squeeze(0).cpu())
        sr_patch_best = denormalize(model(lr_patches[best_patch_index]).squeeze(0).cpu())
        hr_patch_best = denormalize(hr_patches[best_patch_index].squeeze(0).cpu())

        # Bilinear Upscale
        with torch.no_grad():
            bilinear_patch = F.interpolate(
                lr_patches[best_patch_index], scale_factor=scale_factor, mode="bilinear", align_corners=True
            ).squeeze(0)
            bilinear_patch_best = denormalize(bilinear_patch.cpu())

        # Save comparison
        plt.figure(figsize=(20, 5))

        # Low-Resolution Patch
        plt.subplot(1, 4, 1)
        plt.title("Low-Resolution Patch")
        plt.imshow(to_pil_image(lr_patch_best))
        plt.axis("off")

        # Bilinear Upscale
        if bilinear_psnr_best < highest_psnr:
            plt.subplot(1, 4, 2)
            plt.title(f"Bilinear PSNR: {bilinear_psnr_best:.2f} dB")
            plt.imshow(to_pil_image(bilinear_patch_best))
            plt.axis("off")
        else:
            plt.subplot(1, 4, 2)
            plt.title("Bilinear Upscale")
            plt.imshow(to_pil_image(bilinear_patch_best))
            plt.axis("off")

        # Super-Resolved Output (EDSR)
        plt.subplot(1, 4, 3)
        plt.title(f"EDSR PSNR: {highest_psnr:.2f} dB")
        plt.imshow(to_pil_image(sr_patch_best))
        plt.axis("off")

        # Original High-Resolution Patch
        plt.subplot(1, 4, 4)
        plt.title("Original High-Resolution Patch")
        plt.imshow(to_pil_image(hr_patch_best))
        plt.axis("off")

        plt.tight_layout()
        output_path = os.path.join(output_dir, f"{image_name}_best_patch.png")
        plt.savefig(output_path, dpi=300, bbox_inches="tight")
        plt.close()

        print(f"Saved best patch comparison for '{image_name}' to {output_path}")
        print(f"Best Patch PSNR: EDSR = {highest_psnr:.2f} dB, Bilinear = {bilinear_psnr_best:.2f} dB")

def pad_image(image, patch_size):
    """Pad image to make dimensions divisible by patch size."""
    width, height = image.size
    pad_width = (patch_size - width % patch_size) % patch_size
    pad_height = (patch_size - height % patch_size) % patch_size
    padding = (0, 0, pad_width, pad_height)
    padded_image = pad(image, padding, fill=0)
    return padded_image, padding

# Run inference on a list of images

perform_inference_single_patch_EDSR("chkt_v6/edsr_epoch_15.pth", image_list, scale_factor=2, output_dir="edsr_best_patch_results")


In [None]:

image_direcotry = [
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\0845.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\0846.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\0847.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\07848.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\0859.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\0845.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\0846.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\0847.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\07848.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\0859.png",    
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\0845.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\0846.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\0847.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\07848.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\0859.png",    
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\0845.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\0846.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\0847.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\07848.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\0859.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\0845.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\0846.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\0847.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\07848.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\0859.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\0847.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\07848.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\0859.png",    
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\0845.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\0846.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\0847.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\07848.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\0859.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\0845.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\0846.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\0847.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\07848.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\0859.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\0847.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\07848.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\0859.png",    
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\0845.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\0846.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\0847.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\07848.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\0859.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\0845.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\0846.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\0847.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\07848.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\0859.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\0847.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\07848.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\0859.png",    
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\0845.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\0846.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\0847.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\07848.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\0859.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\0845.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\0846.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\0847.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\07848.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\DIV2K\0859.png",
    r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\HR_output.PNG", r"C:\Users\Turog\OneDrive\Documents\GitHub\576_DL_SuperRes\data\waves\1702053632133310.png"]
perform_inference_comparison_EDSR("chkt_v6/edsr_epoch_15.pth", image_list, scale_factor=2,output_dir="inference_results_v6")
