In [2]:
import os
import glob
import random
import math
from PIL import Image
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms
from tqdm import tqdm
from torchmetrics import StructuralSimilarityIndexMeasure

# (PSNR and EDSR definitions remain the same)
def calculate_psnr(sr, hr, max_val=1.0):
    mse = F.mse_loss(sr, hr)
    if mse == 0:
        return 100
    psnr = 10 * torch.log10((max_val ** 2) / mse)
    return psnr.item()

class ResidualBlock(nn.Module):
    def __init__(self, n_feats, res_scale=0.1):
        super(ResidualBlock, self).__init__()
        self.res_scale = res_scale
        self.conv1 = nn.Conv2d(n_feats, n_feats, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(n_feats, n_feats, kernel_size=3, padding=1)
    def forward(self, x):
        res = self.conv1(x)
        res = self.relu(res)
        res = self.conv2(res)
        return x + res * self.res_scale
class EDSR(nn.Module):
    def __init__(self, scale=4, n_resblocks=32, n_feats=64, res_scale=0.1, in_channels=3):
        super(EDSR, self).__init__()
        self.scale = scale
        self.conv_in = nn.Conv2d(in_channels, n_feats, kernel_size=3, padding=1)
        self.res_blocks = nn.Sequential(*[ResidualBlock(n_feats, res_scale) for _ in range(n_resblocks)])
        self.conv_mid = nn.Conv2d(n_feats, n_feats, kernel_size=3, padding=1)
        upscaling = []
        if scale in [2, 3]:
            upscaling.append(nn.Conv2d(n_feats, n_feats * (scale ** 2), kernel_size=3, padding=1))
            upscaling.append(nn.PixelShuffle(scale))
        elif scale == 4:
            for _ in range(2):
                upscaling.append(nn.Conv2d(n_feats, n_feats * 4, kernel_size=3, padding=1))
                upscaling.append(nn.PixelShuffle(2))
        else:
            raise NotImplementedError("Scale factor {} not supported.".format(scale))
        self.upscale = nn.Sequential(*upscaling)
        self.conv_out = nn.Conv2d(n_feats, in_channels, kernel_size=3, padding=1)
    def forward(self, x):
        x = self.conv_in(x)
        residual = x
        x = self.res_blocks(x)
        x = self.conv_mid(x)
        x = x + residual
        x = self.upscale(x)
        x = self.conv_out(x)
        return x

def validate_model(model, dataloader, device):
    model.eval()
    total_psnr = 0.0
    total_ssim = 0.0
    num_batches = 0
    ssim_metric = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)

    with torch.no_grad():
        pbar = tqdm(dataloader, desc="Validation")
        for batch in pbar:
            hr = batch['hr'].to(device)
            lr = batch['lr'].to(device)
            sr = model(lr).clamp(0, 1) # Ensure output is in [0, 1] for metrics

            # Ensure HR has the same size as SR for PSNR calculation
            if hr.size()[-2:] != sr.size()[-2:]:
                hr = F.interpolate(hr, size=sr.size()[-2:], mode='bicubic', align_corners=False)

            # Calculate PSNR
            psnr = calculate_psnr(sr, hr)
            total_psnr += psnr

            # Calculate SSIM
            total_ssim += ssim_metric(sr, hr).item()
            num_batches += 1
            pbar.set_postfix(psnr=f"{psnr:.2f}")

    avg_psnr = total_psnr / num_batches
    avg_ssim = total_ssim / num_batches
    return avg_psnr, avg_ssim

def main():
    # Define the device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Initialize the EDSR model
    scale = 4
    model = EDSR(scale=scale, n_resblocks=32, n_feats=64, res_scale=0.1, in_channels=3).to(device)

    # Load the pretrained model weights
    pretrained_model_path = "/kaggle/input/edsr_model/pytorch/default/1/EDSR_25.35.pth"
    model.load_state_dict(torch.load(pretrained_model_path, map_location=device)) # Load directly to the detected device
    model.eval()

    # --- Dataset Definitions (Ensure these are consistent with your previous code) ---
    val_transform_hr = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
    ])
    val_transform_lr = transforms.Compose([
        transforms.ToTensor(),
    ])

    class ValDownsampleDataset(Dataset):
        def __init__(self, hr_dir, transform_hr=None):
            self.hr_dir = hr_dir
            self.hr_images = sorted([os.path.join(hr_dir, f) for f in os.listdir(hr_dir) if f.endswith(('.png', '.jpg', '.jpeg'))],
                                    key=lambda path: int(os.path.splitext(os.path.basename(path))[0]) if os.path.splitext(os.path.basename(path))[0].isdigit() else os.path.splitext(os.path.basename(path))[0])
            self.transform_hr = transform_hr
        def __len__(self):
            return len(self.hr_images)
        def __getitem__(self, idx):
            hr_img = Image.open(self.hr_images[idx]).convert('RGB')
            if self.transform_hr:
                hr = self.transform_hr(hr_img)
            else:
                hr = transforms.ToTensor()(hr_img)
            c, h, w = hr.shape
            lr_width, lr_height = w // 4, h // 4
            hr_pil = transforms.ToPILImage()(hr)
            lr_pil = hr_pil.resize((lr_width, lr_height), Image.BICUBIC)
            lr = transforms.ToTensor()(lr_pil)
            return {'hr': hr, 'lr': lr}

    class PairedImageDataset(Dataset):
        def __init__(self, hr_dir, lr_dir, transform_hr=None, transform_lr=None):
            self.hr_dir = hr_dir
            self.lr_dir = lr_dir
            self.hr_images = sorted([os.path.join(hr_dir, f) for f in os.listdir(hr_dir) if f.endswith(('.png', '.jpg', '.jpeg'))],
                                    key=lambda path: os.path.splitext(os.path.basename(path))[0])
            self.lr_images = sorted([os.path.join(lr_dir, f) for f in os.listdir(lr_dir) if f.endswith(('.png', '.jpg', '.jpeg'))],
                                    key=lambda path: os.path.splitext(os.path.basename(path))[0])
            if len(self.hr_images) != len(self.lr_images):
                raise ValueError(f"Number of HR images in {hr_dir} ({len(self.hr_images)}) and LR images in {lr_dir} ({len(self.lr_images)}) must be the same.")
            hr_basenames = [os.path.splitext(os.path.basename(path))[0] for path in self.hr_images]
            lr_basenames = [os.path.splitext(os.path.basename(path))[0] for path in self.lr_images]
            if hr_basenames != lr_basenames:
                raise ValueError(f"HR and LR image filenames in {hr_dir} and {lr_dir} must match.")
            self.transform_hr = transform_hr
            self.transform_lr = transform_lr
        def __len__(self):
            return len(self.hr_images)
        def __getitem__(self, idx):
            hr_path = self.hr_images[idx]
            lr_path = self.lr_images[idx]
            hr_img = Image.open(hr_path).convert('RGB')
            lr_img = Image.open(lr_path).convert('RGB')
            if self.transform_hr:
                hr = self.transform_hr(hr_img)
            else:
                hr = transforms.ToTensor()(hr_img)
            if self.transform_lr:
                lr = self.transform_lr(lr_img)
            else:
                lr = transforms.ToTensor()(lr_img)
            return {'hr': hr, 'lr': lr}

    class Urban100PairedDataset(Dataset):
        def __init__(self, hr_dir, lr_dir, transform_hr=None, transform_lr=None):
            self.hr_dir = hr_dir
            self.lr_dir = lr_dir
            self.hr_images = sorted([f for f in os.listdir(hr_dir) if f.endswith(('.png', '.jpg', '.jpeg')) and '_HR' in f])
            self.lr_images = sorted([f for f in os.listdir(lr_dir) if f.endswith(('.png', '.jpg', '.jpeg')) and '_LR' in f])
            self.transform_hr = transform_hr
            self.transform_lr = transform_lr
            self.pairs = []
            for hr_file in self.hr_images:
                base_name = hr_file.replace('_HR', '')
                lr_file = base_name.replace('.png', '_LR.png')
                if lr_file in self.lr_images:
                    hr_path = os.path.join(hr_dir, hr_file)
                    lr_path = os.path.join(lr_dir, lr_file)
                    self.pairs.append((hr_path, lr_path))
        def __len__(self):
            return len(self.pairs)
        def __getitem__(self, idx):
            hr_path, lr_path = self.pairs[idx]
            hr_img = Image.open(hr_path).convert('RGB')
            lr_img = Image.open(lr_path).convert('RGB')
            if self.transform_hr:
                hr = self.transform_hr(hr_img)
            else:
                hr = transforms.ToTensor()(hr_img)
            if self.transform_lr:
                lr = self.transform_lr(lr_img)
            else:
                lr = transforms.ToTensor()(lr_img)
            return {'hr': hr, 'lr': lr}

    # --- Load Validation Datasets ---
    root_dir = "/kaggle/input"
    val_datasets = {}

    # DIV2K
    div2k_val_hr_directory = os.path.join(root_dir, "div2k-dataset/DIV2K_valid_HR/DIV2K_valid_HR")
    div2k_dataset = ValDownsampleDataset(hr_dir=div2k_val_hr_directory, transform_hr=val_transform_hr)
    print(f"DIV2K Dataset size: {len(div2k_dataset)}")
    val_datasets['DIV2K'] = DataLoader(div2k_dataset, batch_size=1, shuffle=False, num_workers=2)

    # BSD100
    bsd100_hr_dir = os.path.join(root_dir, "bsd100/bsd100/bicubic_4x/train/HR")
    bsd100_lr_dir_bicubic_4x = os.path.join(root_dir, "bsd100/bsd100/bicubic_4x/train/LR")
    if os.path.exists(bsd100_lr_dir_bicubic_4x) and len(os.listdir(bsd100_lr_dir_bicubic_4x)) > 0:
        bsd100_dataset = PairedImageDataset(hr_dir=bsd100_hr_dir, lr_dir=bsd100_lr_dir_bicubic_4x, transform_hr=val_transform_hr, transform_lr=val_transform_lr)
        print(f"BSD100 Paired Dataset size: {len(bsd100_dataset)}")
        val_datasets['BSD100'] = DataLoader(bsd100_dataset, batch_size=1, shuffle=False, num_workers=2)
    else:
        bsd100_dataset = ValDownsampleDataset(hr_dir=bsd100_hr_dir, transform_hr=val_transform_hr)
        print(f"BSD100 Downsample Dataset size: {len(bsd100_dataset)}")
        val_datasets['BSD100'] = DataLoader(bsd100_dataset, batch_size=1, shuffle=False, num_workers=2)

    # Set14
    set14_hr_dir = os.path.join(root_dir, "set-5-14-super-resolution-dataset/Set14/Set14")
    set14_dataset = ValDownsampleDataset(hr_dir=set14_hr_dir, transform_hr=val_transform_hr)
    print(f"Set14 Dataset size: {len(set14_dataset)}")
    val_datasets['Set14'] = DataLoader(set14_dataset, batch_size=1, shuffle=False, num_workers=2)

    # Set5
    set5_hr_dir = os.path.join(root_dir, "set-5-14-super-resolution-dataset/Set5/Set5")
    set5_dataset = ValDownsampleDataset(hr_dir=set5_hr_dir, transform_hr=val_transform_hr)
    print(f"Set5 Dataset size: {len(set5_dataset)}")
    val_datasets['Set5'] = DataLoader(set5_dataset, batch_size=1, shuffle=False, num_workers=2)

    # Urban100
    urban100_hr_dir = os.path.join(root_dir, "urban100/Urban 100/X4 Urban100/X4/HIGH x4 URban100")
    urban100_lr_dir = os.path.join(root_dir, "urban100/Urban 100/X4 Urban100/X4/LOW x4 URban100")
    if os.path.exists(urban100_lr_dir) and len(os.listdir(urban100_lr_dir)) > 0:
        urban100_dataset = Urban100PairedDataset(hr_dir=urban100_hr_dir, lr_dir=urban100_lr_dir, transform_hr=val_transform_hr, transform_lr=val_transform_lr)
        print(f"Urban100 Paired Dataset size: {len(urban100_dataset)}")
        val_datasets['Urban100'] = DataLoader(urban100_dataset, batch_size=1, shuffle=False, num_workers=2)
    else:
        urban100_dataset = ValDownsampleDataset(hr_dir=urban100_hr_dir, transform_hr=val_transform_hr)
        print(f"Urban100 Downsample Dataset size: {len(urban100_dataset)}")
        val_datasets['Urban100'] = DataLoader(urban100_dataset, batch_size=1, shuffle=False, num_workers=2)

    print("--- Validation Results (GPU) ---")
    for dataset_name, dataloader in val_datasets.items():
        avg_psnr, avg_ssim = validate_model(model, dataloader, device)
        print(f"Dataset: {dataset_name}, Avg PSNR: {avg_psnr:.2f} dB, Avg SSIM: {avg_ssim:.4f}")

if __name__ == "__main__":
    main()

Using device: cuda


  model.load_state_dict(torch.load(pretrained_model_path, map_location=device)) # Load directly to the detected device


DIV2K Dataset size: 100
BSD100 Paired Dataset size: 80
Set14 Dataset size: 14
Set5 Dataset size: 5
Urban100 Paired Dataset size: 100
--- Validation Results (GPU) ---


Validation: 100%|██████████| 100/100 [00:08<00:00, 12.20it/s, psnr=21.14]


Dataset: DIV2K, Avg PSNR: 23.26 dB, Avg SSIM: 0.6643


Validation: 100%|██████████| 80/80 [00:01<00:00, 48.35it/s, psnr=30.59]


Dataset: BSD100, Avg PSNR: 32.82 dB, Avg SSIM: 0.9123


Validation: 100%|██████████| 14/14 [00:00<00:00, 56.45it/s, psnr=19.80]


Dataset: Set14, Avg PSNR: 23.13 dB, Avg SSIM: 0.6697


Validation: 100%|██████████| 5/5 [00:00<00:00, 35.56it/s, psnr=23.69]


Dataset: Set5, Avg PSNR: 24.75 dB, Avg SSIM: 0.7436


Validation: 100%|██████████| 100/100 [00:09<00:00, 10.79it/s, psnr=20.60]

Dataset: Urban100, Avg PSNR: 23.31 dB, Avg SSIM: 0.6748





In [2]:
import os
import glob
import random
import math
from PIL import Image
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from tqdm import tqdm
from torchmetrics import StructuralSimilarityIndexMeasure

#######################################
# 2. PSNR Calculation Function (Reused)
#######################################

def calculate_psnr(sr, hr, max_val=1.0):
    mse = F.mse_loss(sr, hr)
    if mse == 0:
        return 100.0
    psnr = 10 * torch.log10((max_val ** 2) / mse)
    return psnr.item()

#######################################
# 4. ESRGAN Generator (Reused)
#######################################

class ResidualDenseBlock(nn.Module):
    def __init__(self, channels, growth_channels=32):
        super(ResidualDenseBlock, self).__init__()
        self.conv1 = nn.Conv2d(channels, growth_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(channels + growth_channels, growth_channels, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(channels + 2 * growth_channels, growth_channels, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(channels + 3 * growth_channels, growth_channels, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(channels + 4 * growth_channels, channels, kernel_size=3, padding=1)
        self.lrelu = nn.LeakyReLU(0.2, inplace=True)
        self.res_scale = 0.2

    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat([x, x1], 1)))
        x3 = self.lrelu(self.conv3(torch.cat([x, x1, x2], 1)))
        x4 = self.lrelu(self.conv4(torch.cat([x, x1, x2, x3], 1)))
        x5 = self.conv5(torch.cat([x, x1, x2, x3, x4], 1))
        return x + x5 * self.res_scale

class RRDB(nn.Module):
    def __init__(self, channels, growth_channels=32):
        super(RRDB, self).__init__()
        self.rdb1 = ResidualDenseBlock(channels, growth_channels)
        self.rdb2 = ResidualDenseBlock(channels, growth_channels)
        self.rdb3 = ResidualDenseBlock(channels, growth_channels)
        self.res_scale = 0.2

    def forward(self, x):
        out = self.rdb1(x)
        out = self.rdb2(out)
        out = self.rdb3(out)
        return x + out * self.res_scale

class ESRGANGenerator(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, num_feat=64, num_rrdb=23, growth_channels=32, scale=4):
        super(ESRGANGenerator, self).__init__()
        self.conv_first = nn.Conv2d(in_channels, num_feat, kernel_size=3, padding=1)
        rrdb_blocks = [RRDB(num_feat, growth_channels) for _ in range(num_rrdb)]
        self.RRDB_trunk = nn.Sequential(*rrdb_blocks)
        self.trunk_conv = nn.Conv2d(num_feat, num_feat, kernel_size=3, padding=1)
        # Upsampling: for 4x scaling, use two PixelShuffle blocks.
        upsample_layers = []
        num_upsample = int(math.log(scale, 2))
        for _ in range(num_upsample):
            upsample_layers += [
                nn.Conv2d(num_feat, num_feat * 4, kernel_size=3, padding=1),
                nn.PixelShuffle(2),
                nn.LeakyReLU(0.2, inplace=True)
            ]
        self.upsampling = nn.Sequential(*upsample_layers)
        self.conv_last = nn.Conv2d(num_feat, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        fea = self.conv_first(x)
        trunk = self.trunk_conv(self.RRDB_trunk(fea))
        fea = fea + trunk
        out = self.upsampling(fea)
        out = self.conv_last(out)
        return out

#######################################
# Validation Function (Adapted from EDSR)
#######################################

def validate_model(model, dataloader, device):
    model.eval()
    total_psnr = 0.0
    total_ssim = 0.0
    num_batches = 0
    ssim_metric = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)

    with torch.no_grad():
        pbar = tqdm(dataloader, desc="Validation")
        for batch in pbar:
            hr = batch['hr'].to(device)
            lr = batch['lr'].to(device)
            sr = model(lr).clamp(0, 1) # Ensure output is in [0, 1] for metrics

            # Ensure HR has the same size as SR for PSNR calculation
            if hr.size()[-2:] != sr.size()[-2:]:
                hr = F.interpolate(hr, size=sr.size()[-2:], mode='bicubic', align_corners=False)

            # Calculate PSNR
            psnr = calculate_psnr(sr, hr)
            total_psnr += psnr

            # Calculate SSIM
            total_ssim += ssim_metric(sr, hr).item()
            num_batches += 1
            pbar.set_postfix(psnr=f"{psnr:.2f}")

    avg_psnr = total_psnr / num_batches
    avg_ssim = total_ssim / num_batches
    return avg_psnr, avg_ssim

#######################################
# Dataset Definitions (Reused and potentially adapted)
#######################################

val_transform_hr = transforms.Compose([
    transforms.Resize((128, 128)), # Or your desired HR validation size
    transforms.ToTensor(),
])
val_transform_lr = transforms.Compose([
    transforms.ToTensor(),
])

class ValDownsampleDataset(Dataset):
    def __init__(self, hr_dir, transform_hr=None, scale=4):
        self.hr_dir = hr_dir
        self.hr_images = sorted([os.path.join(hr_dir, f) for f in os.listdir(hr_dir) if f.endswith(('.png', '.jpg', '.jpeg'))],
                                key=lambda path: int(os.path.splitext(os.path.basename(path))[0]) if os.path.splitext(os.path.basename(path))[0].isdigit() else os.path.splitext(os.path.basename(path))[0])
        self.transform_hr = transform_hr
        self.scale = scale
    def __len__(self):
        return len(self.hr_images)
    def __getitem__(self, idx):
        hr_img = Image.open(self.hr_images[idx]).convert('RGB')
        if self.transform_hr:
            hr = self.transform_hr(hr_img)
        else:
            hr = transforms.ToTensor()(hr_img)
        c, h, w = hr.shape
        lr_width, lr_height = w // self.scale, h // self.scale
        hr_pil = transforms.ToPILImage()(hr)
        lr_pil = hr_pil.resize((lr_width, lr_height), Image.BICUBIC)
        lr = transforms.ToTensor()(lr_pil)
        return {'hr': hr, 'lr': lr}

class PairedImageDataset(Dataset):
    def __init__(self, hr_dir, lr_dir, transform_hr=None, transform_lr=None):
        self.hr_dir = hr_dir
        self.lr_dir = lr_dir
        self.hr_images = sorted([os.path.join(hr_dir, f) for f in os.listdir(hr_dir) if f.endswith(('.png', '.jpg', '.jpeg'))],
                                key=lambda path: os.path.splitext(os.path.basename(path))[0])
        self.lr_images = sorted([os.path.join(lr_dir, f) for f in os.listdir(lr_dir) if f.endswith(('.png', '.jpg', '.jpeg'))],
                                key=lambda path: os.path.splitext(os.path.basename(path))[0])
        if len(self.hr_images) != len(self.lr_images):
            raise ValueError(f"Number of HR images in {hr_dir} ({len(self.hr_images)}) and LR images in {lr_dir} ({len(self.lr_images)}) must be the same.")
        hr_basenames = [os.path.splitext(os.path.basename(path))[0] for path in self.hr_images]
        lr_basenames = [os.path.splitext(os.path.basename(path))[0] for path in self.lr_images]
        if hr_basenames != lr_basenames:
            raise ValueError(f"HR and LR image filenames in {hr_dir} and {lr_dir} must match.")
        self.transform_hr = transform_hr
        self.transform_lr = transform_lr
    def __len__(self):
        return len(self.hr_images)
    def __getitem__(self, idx):
        hr_path = self.hr_images[idx]
        lr_path = self.lr_images[idx]
        hr_img = Image.open(hr_path).convert('RGB')
        lr_img = Image.open(lr_path).convert('RGB')
        if self.transform_hr:
            hr = self.transform_hr(hr_img)
        else:
            hr = transforms.ToTensor()(hr_img)
        if self.transform_lr:
            lr = self.transform_lr(lr_img)
        else:
            lr = transforms.ToTensor()(lr_img)
        return {'hr': hr, 'lr': lr}

class Urban100PairedDataset(Dataset):
    def __init__(self, hr_dir, lr_dir, transform_hr=None, transform_lr=None):
        self.hr_dir = hr_dir
        self.lr_dir = lr_dir
        self.hr_images = sorted([f for f in os.listdir(hr_dir) if f.endswith(('.png', '.jpg', '.jpeg')) and '_HR' in f])
        self.lr_images = sorted([f for f in os.listdir(lr_dir) if f.endswith(('.png', '.jpg', '.jpeg')) and '_LR' in f])
        self.transform_hr = transform_hr
        self.transform_lr = transform_lr
        self.pairs = []
        for hr_file in self.hr_images:
            base_name = hr_file.replace('_HR', '')
            lr_file = base_name.replace('.png', '_LR.png')
            if lr_file in self.lr_images:
                hr_path = os.path.join(hr_dir, hr_file)
                lr_path = os.path.join(lr_dir, lr_file)
                self.pairs.append((hr_path, lr_path))
    def __len__(self):
        return len(self.pairs)
    def __getitem__(self, idx):
        hr_path, lr_path = self.pairs[idx]
        hr_img = Image.open(hr_path).convert('RGB')
        lr_img = Image.open(lr_path).convert('RGB')
        if self.transform_hr:
            hr = self.transform_hr(hr_img)
        else:
            hr = transforms.ToTensor()(hr_img)
        if self.transform_lr:
            lr = self.transform_lr(lr_img)
        else:
            lr = transforms.ToTensor()(lr_img)
        return {'hr': hr, 'lr': lr}

#######################################
# Main Function for Validation
#######################################

def main():
    # Define the device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Initialize the ESRGAN model
    scale = 4
    model = ESRGANGenerator(scale=scale, num_rrdb=16, num_feat=64, growth_channels=32).to(device) # Use the same architecture as in training

    # Load the pretrained model weights
    pretrained_model_path = "/kaggle/input/esrgan_model/pytorch/default/1/ESRGAN_26.56.pth" # Update this path to your pretrained model file
    if os.path.exists(pretrained_model_path):
        model.load_state_dict(torch.load(pretrained_model_path, map_location=device))
        print(f"Loaded pretrained model from: {pretrained_model_path}")
    else:
        print(f"Pretrained model not found at: {pretrained_model_path}")
        return

    model.eval()

    # --- Load Validation Datasets ---
    root_dir = "/kaggle/input"
    val_datasets = {}

    # DIV2K (Using downsampling as the training data was generated this way)
    div2k_val_hr_directory = os.path.join(root_dir, "div2k-dataset/DIV2K_valid_HR/DIV2K_valid_HR")
    div2k_dataset = ValDownsampleDataset(hr_dir=div2k_val_hr_directory, transform_hr=val_transform_hr, scale=scale)
    print(f"DIV2K Dataset size: {len(div2k_dataset)}")
    val_datasets['DIV2K'] = DataLoader(div2k_dataset, batch_size=1, shuffle=False, num_workers=2)

    # BSD100 (Using downsampling)
    bsd100_hr_dir = os.path.join(root_dir, "bsd100/bsd100/bicubic_4x/train/HR")
    bsd100_dataset = ValDownsampleDataset(hr_dir=bsd100_hr_dir, transform_hr=val_transform_hr, scale=scale)
    print(f"BSD100 Downsample Dataset size: {len(bsd100_dataset)}")
    val_datasets['BSD100'] = DataLoader(bsd100_dataset, batch_size=1, shuffle=False, num_workers=2)

    # Set14 (Using downsampling)
    set14_hr_dir = os.path.join(root_dir, "set-5-14-super-resolution-dataset/Set14/Set14")
    set14_dataset = ValDownsampleDataset(hr_dir=set14_hr_dir, transform_hr=val_transform_hr, scale=scale)
    print(f"Set14 Dataset size: {len(set14_dataset)}")
    val_datasets['Set14'] = DataLoader(set14_dataset, batch_size=1, shuffle=False, num_workers=2)

    # Set5 (Using downsampling)
    set5_hr_dir = os.path.join(root_dir, "set-5-14-super-resolution-dataset/Set5/Set5")
    set5_dataset = ValDownsampleDataset(hr_dir=set5_hr_dir, transform_hr=val_transform_hr, scale=scale)
    print(f"Set5 Dataset size: {len(set5_dataset)}")
    val_datasets['Set5'] = DataLoader(set5_dataset, batch_size=1, shuffle=False, num_workers=2)

    # Urban100 (Using downsampling)
    urban100_hr_dir = os.path.join(root_dir, "urban100/Urban 100/X4 Urban100/X4/HIGH x4 URban100")
    urban100_dataset = ValDownsampleDataset(hr_dir=urban100_hr_dir, transform_hr=val_transform_hr, scale=scale)
    print(f"Urban100 Downsample Dataset size: {len(urban100_dataset)}")
    val_datasets['Urban100'] = DataLoader(urban100_dataset, batch_size=1, shuffle=False, num_workers=2)

    print("--- ESRGAN Validation Results ---")
    for dataset_name, dataloader in val_datasets.items():
        avg_psnr, avg_ssim = validate_model(model, dataloader, device)
        print(f"Dataset: {dataset_name}, Avg PSNR: {avg_psnr:.2f} dB, Avg SSIM: {avg_ssim:.4f}")

if __name__ == "__main__":
    main()

Using device: cuda


  model.load_state_dict(torch.load(pretrained_model_path, map_location=device))


Loaded pretrained model from: /kaggle/input/esrgan_model/pytorch/default/1/ESRGAN_26.56.pth
DIV2K Dataset size: 100
BSD100 Downsample Dataset size: 80
Set14 Dataset size: 14
Set5 Dataset size: 5
Urban100 Downsample Dataset size: 100
--- ESRGAN Validation Results ---


Validation: 100%|██████████| 100/100 [00:06<00:00, 15.04it/s, psnr=21.34]


Dataset: DIV2K, Avg PSNR: 23.56 dB, Avg SSIM: 0.6803


Validation: 100%|██████████| 80/80 [00:02<00:00, 29.67it/s, psnr=22.74]


Dataset: BSD100, Avg PSNR: 25.46 dB, Avg SSIM: 0.7118


Validation: 100%|██████████| 14/14 [00:00<00:00, 25.46it/s, psnr=20.00]


Dataset: Set14, Avg PSNR: 23.56 dB, Avg SSIM: 0.6911


Validation: 100%|██████████| 5/5 [00:00<00:00, 18.88it/s, psnr=24.22]


Dataset: Set5, Avg PSNR: 25.47 dB, Avg SSIM: 0.7762


Validation: 100%|██████████| 100/100 [00:03<00:00, 30.42it/s, psnr=17.01]

Dataset: Urban100, Avg PSNR: 21.83 dB, Avg SSIM: 0.6081





In [1]:
import os
import glob
import random
import math
from PIL import Image
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from tqdm import tqdm
from torchmetrics import StructuralSimilarityIndexMeasure

#######################################
# 2. PSNR Calculation (Reused)
#######################################

def calculate_psnr(sr, hr, max_val=1.0):
    mse = F.mse_loss(sr, hr)
    if mse == 0:
        return 100.0
    psnr = 10 * torch.log10((max_val ** 2) / mse)
    return psnr.item()

#######################################
# 4. RGT-S Generator Architecture (Reused)
#######################################

class RGTransformerBlock(nn.Module):
    """
    An improved recursive-generalization transformer block.
    It flattens spatial features into tokens, applies global self-attention
    followed by a local self-attention branch, then an MLP.
    """
    def __init__(self, dim, num_heads, mlp_ratio=4.0, dropout=0.1):
        super(RGTransformerBlock, self).__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn_global = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, dropout=dropout)
        self.norm2 = nn.LayerNorm(dim)
        self.attn_local = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, dropout=dropout)
        self.norm3 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, int(dim * mlp_ratio)),
            nn.GELU(),
            nn.Linear(int(dim * mlp_ratio), dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        # x: (B, C, H, W). Flatten spatial dims to (N, B, C) where N=H*W.
        B, C, H, W = x.shape
        N = H * W
        x_flat = x.view(B, C, N).permute(2, 0, 1)  # shape: (N, B, C)

        # Global self-attention
        x_norm = self.norm1(x_flat)
        attn_global, _ = self.attn_global(x_norm, x_norm, x_norm)
        x_global = x_flat + attn_global

        # Local self-attention (could use windowing; here we use full tokens for simplicity)
        x_norm2 = self.norm2(x_global)
        attn_local, _ = self.attn_local(x_norm2, x_norm2, x_norm2)
        x_local = x_global + attn_local

        # MLP block
        x_norm3 = self.norm3(x_local)
        mlp_out = self.mlp(x_norm3)
        x_out = x_local + mlp_out

        # Reshape back to (B, C, H, W)
        x_out = x_out.permute(1, 2, 0).view(B, C, H, W)
        return x_out

class RGT_S_Generator(nn.Module):
    """
    Improved RGT-S Generator for 4x image super-resolution.
    Architecture:
      - Shallow convolutional embedding.
      - A stack of transformer blocks (RGTransformerBlock) with increased capacity.
      - A convolutional trunk (plus an extra refinement conv) with a skip connection.
      - Upsampling via PixelShuffle.
      - Final convolution to output the image.
    """
    def __init__(self, in_channels=3, out_channels=3, embed_dim=128, depth=12, num_heads=8, mlp_ratio=4.0, scale=4, dropout=0.1):
        super(RGT_S_Generator, self).__init__()
        self.conv_first = nn.Conv2d(in_channels, embed_dim, kernel_size=3, padding=1)
        self.transformer_blocks = nn.Sequential(*[
            RGTransformerBlock(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, dropout=dropout)
            for _ in range(depth)
        ])
        # Convolutional trunk with extra refinement
        self.trunk_conv = nn.Conv2d(embed_dim, embed_dim, kernel_size=3, padding=1)
        self.refine_conv = nn.Conv2d(embed_dim, embed_dim, kernel_size=3, padding=1)

        # Upsampling module: for 4x scaling, use two PixelShuffle blocks.
        upsample_layers = []
        num_upsample = int(math.log(scale, 2))
        for _ in range(num_upsample):
            upsample_layers += [
                nn.Conv2d(embed_dim, embed_dim * 4, kernel_size=3, padding=1),
                nn.PixelShuffle(2),
                nn.GELU()
            ]
        self.upsampling = nn.Sequential(*upsample_layers)
        self.conv_last = nn.Conv2d(embed_dim, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        fea = self.conv_first(x)
        trans_out = self.transformer_blocks(fea)
        trunk = self.trunk_conv(trans_out)
        refined = self.refine_conv(trunk)
        fea = fea + refined  # global skip connection
        out = self.upsampling(fea)
        out = self.conv_last(out)
        return out

#######################################
# Validation Function (Reused)
#######################################

def validate_model(model, dataloader, device):
    model.eval()
    total_psnr = 0.0
    total_ssim = 0.0
    num_batches = 0
    ssim_metric = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)

    with torch.no_grad():
        pbar = tqdm(dataloader, desc="Validation")
        for batch in pbar:
            hr = batch['hr'].to(device)
            lr = batch['lr'].to(device)
            sr = model(lr).clamp(0, 1) # Ensure output is in [0, 1] for metrics

            # Ensure HR has the same size as SR for PSNR calculation
            if hr.size()[-2:] != sr.size()[-2:]:
                hr = F.interpolate(hr, size=sr.size()[-2:], mode='bicubic', align_corners=False)

            # Calculate PSNR
            psnr = calculate_psnr(sr, hr)
            total_psnr += psnr

            # Calculate SSIM
            total_ssim += ssim_metric(sr, hr).item()
            num_batches += 1
            pbar.set_postfix(psnr=f"{psnr:.2f}")

    avg_psnr = total_psnr / num_batches
    avg_ssim = total_ssim / num_batches
    return avg_psnr, avg_ssim

#######################################
# Dataset Definitions (Modified to limit HR size)
#######################################

val_transform_hr = transforms.Compose([
    transforms.ToTensor(),
])
val_transform_lr = transforms.Compose([
    transforms.ToTensor(),
])

class ValDownsampleDataset(Dataset):
    def __init__(self, hr_dir, transform_hr=None, scale=4, max_hr_size=256): # Added max_hr_size
        self.hr_dir = hr_dir
        self.hr_images = sorted([os.path.join(hr_dir, f) for f in os.listdir(hr_dir) if f.endswith(('.png', '.jpg', '.jpeg'))],
                                key=lambda path: int(os.path.splitext(os.path.basename(path))[0]) if os.path.splitext(os.path.basename(path))[0].isdigit() else os.path.splitext(os.path.basename(path))[0])
        self.transform_hr = transform_hr
        self.scale = scale
        self.max_hr_size = max_hr_size # Maximum height or width for HR image

    def __len__(self):
        return len(self.hr_images)

    def __getitem__(self, idx):
        hr_img = Image.open(self.hr_images[idx]).convert('RGB')

        # Resize HR image if it's larger than max_hr_size
        if max(hr_img.size) > self.max_hr_size:
            width, height = hr_img.size
            if width > height:
                new_width = self.max_hr_size
                new_height = int(height * (new_width / width))
            else:
                new_height = self.max_hr_size
                new_width = int(width * (new_height / height))
            hr_img = hr_img.resize((new_width, new_height), Image.LANCZOS) # Use LANCZOS for better quality

        if self.transform_hr:
            hr = self.transform_hr(hr_img)
        else:
            hr = transforms.ToTensor()(hr_img)

        c, h, w = hr.shape
        lr_width, lr_height = w // self.scale, h // self.scale
        hr_pil = transforms.ToPILImage()(hr)
        lr_pil = hr_pil.resize((lr_width, lr_height), Image.BICUBIC)
        lr = transforms.ToTensor()(lr_pil)
        return {'hr': hr, 'lr': lr}

#######################################
# Main Function for Validation
#######################################

def main():
    # Define the device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Initialize the RGT-S Generator model
    scale = 4
    model = RGT_S_Generator(scale=scale, embed_dim=128, depth=12, num_heads=8, mlp_ratio=4.0, dropout=0.1).to(device) # Use the same architecture as in training

    # Load the pretrained model weights
    pretrained_model_path = "/kaggle/input/best_rgt/pytorch/default/1/RGT_S_27.57.pth" # Updated path
    if os.path.exists(pretrained_model_path):
        model.load_state_dict(torch.load(pretrained_model_path, map_location=device))
        print(f"Loaded pretrained model from: {pretrained_model_path}")
    else:
        print(f"Pretrained model not found at: {pretrained_model_path}")
        return

    model.eval()

    # --- Load Validation Datasets ---
    root_dir = "/kaggle/input"
    val_datasets = {}

    # DIV2K
    div2k_val_hr_directory = os.path.join(root_dir, "div2k-dataset/DIV2K_valid_HR/DIV2K_valid_HR")
    div2k_dataset = ValDownsampleDataset(hr_dir=div2k_val_hr_directory, transform_hr=val_transform_hr, scale=scale, max_hr_size=256) # Added max_hr_size
    print(f"DIV2K Dataset size: {len(div2k_dataset)}")
    val_datasets['DIV2K'] = DataLoader(div2k_dataset, batch_size=1, shuffle=False, num_workers=2)

    # BSD100
    bsd100_hr_dir = os.path.join(root_dir, "bsd100/bsd100/bicubic_4x/train/HR")
    bsd100_dataset = ValDownsampleDataset(hr_dir=bsd100_hr_dir, transform_hr=val_transform_hr, scale=scale, max_hr_size=256) # Added max_hr_size
    print(f"BSD100 Downsample Dataset size: {len(bsd100_dataset)}")
    val_datasets['BSD100'] = DataLoader(bsd100_dataset, batch_size=1, shuffle=False, num_workers=2)

    # Set14
    set14_hr_dir = os.path.join(root_dir, "set-5-14-super-resolution-dataset/Set14/Set14")
    set14_dataset = ValDownsampleDataset(hr_dir=set14_hr_dir, transform_hr=val_transform_hr, scale=scale, max_hr_size=256) # Added max_hr_size
    print(f"Set14 Dataset size: {len(set14_dataset)}")
    val_datasets['Set14'] = DataLoader(set14_dataset, batch_size=1, shuffle=False, num_workers=2)

    # Set5
    set5_hr_dir = os.path.join(root_dir, "set-5-14-super-resolution-dataset/Set5/Set5")
    set5_dataset = ValDownsampleDataset(hr_dir=set5_hr_dir, transform_hr=val_transform_hr, scale=scale, max_hr_size=256) # Added max_hr_size
    print(f"Set5 Dataset size: {len(set5_dataset)}")
    val_datasets['Set5'] = DataLoader(set5_dataset, batch_size=1, shuffle=False, num_workers=2)

    # Urban100
    urban100_hr_dir = os.path.join(root_dir, "urban100/Urban 100/X4 Urban100/X4/HIGH x4 URban100")
    urban100_dataset = ValDownsampleDataset(hr_dir=urban100_hr_dir, transform_hr=val_transform_hr, scale=scale, max_hr_size=256) # Added max_hr_size
    print(f"Urban100 Downsample Dataset size: {len(urban100_dataset)}")
    val_datasets['Urban100'] = DataLoader(urban100_dataset, batch_size=1, shuffle=False, num_workers=2)

    print("--- RGT-S Model Validation Results ---")
    for dataset_name, dataloader in val_datasets.items():
        avg_psnr, avg_ssim = validate_model(model, dataloader, device)
        print(f"Dataset: {dataset_name}, Avg PSNR: {avg_psnr:.2f} dB, Avg SSIM: {avg_ssim:.4f}")

if __name__ == "__main__":
    main()

Using device: cuda


  model.load_state_dict(torch.load(pretrained_model_path, map_location=device))


Loaded pretrained model from: /kaggle/input/best_rgt/pytorch/default/1/RGT_S_27.57.pth
DIV2K Dataset size: 100
BSD100 Downsample Dataset size: 80
Set14 Dataset size: 14
Set5 Dataset size: 5
Urban100 Downsample Dataset size: 100
--- RGT-S Model Validation Results ---


Validation: 100%|██████████| 100/100 [00:21<00:00,  4.69it/s, psnr=20.17]


Dataset: DIV2K, Avg PSNR: 22.27 dB, Avg SSIM: 0.6315


Validation: 100%|██████████| 80/80 [00:14<00:00,  5.44it/s, psnr=21.40]


Dataset: BSD100, Avg PSNR: 23.94 dB, Avg SSIM: 0.6482


Validation: 100%|██████████| 14/14 [00:04<00:00,  3.39it/s, psnr=18.71]


Dataset: Set14, Avg PSNR: 22.82 dB, Avg SSIM: 0.6433


Validation: 100%|██████████| 5/5 [00:01<00:00,  2.78it/s, psnr=23.90]


Dataset: Set5, Avg PSNR: 25.73 dB, Avg SSIM: 0.7561


Validation: 100%|██████████| 100/100 [00:22<00:00,  4.44it/s, psnr=17.29]

Dataset: Urban100, Avg PSNR: 20.25 dB, Avg SSIM: 0.5560





In [4]:
import os
import glob
import random
import math
from PIL import Image
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from tqdm import tqdm
from torchmetrics import StructuralSimilarityIndexMeasure
import torchvision.models as models

################################################################################
# MODEL: TRANSFORMER-BASED SR (Smaller RGT-like)
################################################################################

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, dim, num_heads=4):
        super(MultiHeadSelfAttention, self).__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.query = nn.Linear(dim, dim)
        self.key   = nn.Linear(dim, dim)
        self.value = nn.Linear(dim, dim)
        self.out   = nn.Linear(dim, dim)

    def forward(self, x):
        B, N, C = x.size()
        q = self.query(x).view(B, N, self.num_heads, self.head_dim)
        k = self.key(x).view(B, N, self.num_heads, self.head_dim)
        v = self.value(x).view(B, N, self.num_heads, self.head_dim)

        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attn = torch.softmax(scores, dim=-1)
        out = torch.matmul(attn, v)
        out = out.permute(0, 2, 1, 3).contiguous().view(B, N, C)
        out = self.out(out)
        return out


class FeedForward(nn.Module):
    def __init__(self, dim, expansion=2):
        super(FeedForward, self).__init__()
        hidden_dim = dim * expansion
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, dim)

    def forward(self, x):
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.fc2(x)
        return x


class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads=4, expansion=2):
        super(TransformerBlock, self).__init__()
        self.ln1 = nn.LayerNorm(dim)
        self.attn = MultiHeadSelfAttention(dim, num_heads)
        self.ln2 = nn.LayerNorm(dim)
        self.ffn = FeedForward(dim, expansion)

    def forward(self, x):
        x_ = self.ln1(x)
        x = x + self.attn(x_)
        x_ = self.ln2(x)
        x = x + self.ffn(x_)
        return x


class ResidualGroup(nn.Module):
    def __init__(self, dim, num_blocks=2, num_heads=4):
        super(ResidualGroup, self).__init__()
        blocks = []
        for _ in range(num_blocks):
            blocks.append(TransformerBlock(dim, num_heads=num_heads))
        self.blocks = nn.Sequential(*blocks)
        self.layer_norm = nn.LayerNorm(dim)

    def forward(self, x):
        res = x
        x = self.blocks(x)
        x = self.layer_norm(x)
        return x + res


class UpsampleBlock(nn.Module):
    def __init__(self, dim, scale=4):
        super(UpsampleBlock, self).__init__()
        self.conv = nn.Conv2d(dim, dim * (scale**2), kernel_size=3, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(scale)
        self.final_conv = nn.Conv2d(dim, 3, kernel_size=3, padding=1)

    def forward(self, x):
        x = self.conv(x)
        x = self.pixel_shuffle(x)
        x = self.final_conv(x)
        return x


class RGTNet(nn.Module):
    """
    A smaller version of an RGT-like Transformer for 4× SR.
    """
    def __init__(self, dim=48, num_groups=3, num_blocks=2, heads=4):
        super(RGTNet, self).__init__()
        self.shallow = nn.Conv2d(3, dim, kernel_size=3, padding=1)
        self.groups = nn.ModuleList([
            ResidualGroup(dim, num_blocks=num_blocks, num_heads=heads)
            for _ in range(num_groups)
        ])
        self.ln = nn.LayerNorm(dim)
        self.upsample = UpsampleBlock(dim, scale=4)

    def forward(self, x):
        fea = self.shallow(x)
        B, C, H, W = fea.shape
        fea_seq = fea.permute(0, 2, 3, 1).contiguous().view(B, H*W, C)

        for g in self.groups:
            fea_seq = g(fea_seq)

        fea_seq = self.ln(fea_seq)
        fea = fea_seq.view(B, H, W, C).permute(0, 3, 1, 2).contiguous()
        out = self.upsample(fea)
        return out

################################################################################
# LOSS FUNCTIONS & METRICS - Reused PSNR
################################################################################

def calculate_psnr(sr, hr, shave_border=4, only_y=True):
    sr = sr.clamp(0,1)
    hr = hr.clamp(0,1)
    if shave_border > 0:
        sr = sr[..., shave_border:-shave_border, shave_border:-shave_border]
        hr = hr[..., shave_border:-shave_border, shave_border:-shave_border]
    if only_y:
        sr_y = rgb_to_y(sr)
        hr_y = rgb_to_y(hr)
        mse = F.mse_loss(sr_y, hr_y)
    else:
        mse = F.mse_loss(sr, hr)
    if mse == 0:
        return 999.0
    return -10 * math.log10(mse.item())

def rgb_to_y(tensor_rgb):
    # Weighted sum for Y channel
    r = tensor_rgb[:,0,:,:]
    g = tensor_rgb[:,1,:,:]
    b = tensor_rgb[:,2,:,:]
    y = 0.299*r + 0.587*g + 0.114*b
    return y.unsqueeze(1)

#######################################
# Validation Function (Adapted with error handling)
#######################################

def validate_model(model, dataloader, device):
    model.eval()
    total_psnr = 0.0
    total_ssim = 0.0
    num_batches = 0
    ssim_metric = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)

    with torch.no_grad():
        pbar = tqdm(dataloader, desc="Validation")
        for batch in pbar:
            hr = batch['hr'].to(device)
            lr = batch['lr'].to(device)

            try:
                sr = model(lr).clamp(0, 1) # Ensure output is in [0, 1]

                # Ensure HR has the same size as SR for PSNR calculation
                if hr.size()[-2:] != sr.size()[-2:]:
                    hr = F.interpolate(hr, size=sr.size()[-2:], mode='bicubic', align_corners=False)

                # Calculate PSNR (using the training script's function)
                psnr = calculate_psnr(sr, hr, shave_border=4, only_y=True)
                total_psnr += psnr

                # Calculate SSIM (on RGB)
                total_ssim += ssim_metric(sr, hr).item()
                num_batches += 1
                pbar.set_postfix(psnr=f"{psnr:.2f}")

            except Exception as e:
                print(f"Error during model forward pass: {e}")
                return -1, -1 # Indicate an error

    avg_psnr = total_psnr / num_batches
    avg_ssim = total_ssim / num_batches
    return avg_psnr, avg_ssim

#######################################
# Dataset Definitions (Adapted for direct LR/HR loading or downsampling)
#######################################

val_transform_hr = transforms.Compose([
    transforms.ToTensor(),
])
val_transform_lr = transforms.Compose([
    transforms.ToTensor(),
])

class ValDataset(Dataset):
    def __init__(self, hr_dir, scale=4, transform_hr=None, transform_lr=None, load_lr_directly=False, lr_dir=None, max_hr_size=None): # Added max_hr_size
        self.hr_dir = hr_dir
        self.scale = scale
        self.transform_hr = transform_hr
        self.transform_lr = transform_lr
        self.load_lr_directly = load_lr_directly
        self.lr_dir = lr_dir
        self.hr_images = sorted([os.path.join(hr_dir, f) for f in os.listdir(hr_dir) if f.endswith(('.png', '.jpg', '.jpeg'))],
                                key=lambda path: int(os.path.splitext(os.path.basename(path))[0]) if os.path.splitext(os.path.basename(path))[0].isdigit() else os.path.splitext(os.path.basename(path))[0])
        if load_lr_directly and lr_dir:
            self.lr_images = sorted([os.path.join(lr_dir, f) for f in os.listdir(lr_dir) if f.endswith(('.png', '.jpg', '.jpeg'))],
                                    key=lambda path: int(os.path.splitext(os.path.basename(path))[0]) if os.path.splitext(os.path.basename(path))[0].isdigit() else os.path.splitext(os.path.basename(path))[0])
            assert len(self.hr_images) == len(self.lr_images), "Number of HR and LR images mismatch in direct load mode."
        self.max_hr_size = max_hr_size

    def __len__(self):
        return len(self.hr_images)

    def __getitem__(self, idx):
        hr_path = self.hr_images[idx]
        hr_img = Image.open(hr_path).convert('RGB')

        if self.max_hr_size is not None and max(hr_img.size) > self.max_hr_size:
            width, height = hr_img.size
            if width > height:
                new_width = self.max_hr_size
                new_height = int(height * (new_width / width))
            else:
                new_height = self.max_hr_size
                new_width = int(width * (new_height / height))
            hr_img = hr_img.resize((new_width, new_height), Image.LANCZOS)

        if self.transform_hr:
            hr = self.transform_hr(hr_img)
        else:
            hr = transforms.ToTensor()(hr_img)

        if self.load_lr_directly and self.lr_dir:
            lr_path = self.lr_images[idx]
            lr_img = Image.open(lr_path).convert('RGB')
            if self.transform_lr:
                lr = self.transform_lr(lr_img)
            else:
                lr = transforms.ToTensor()(lr_img)
        else:
            lr_width, lr_height = hr_img.width // self.scale, hr_img.height // self.scale
            lr_img = hr_img.resize((lr_width, lr_height), Image.BICUBIC)
            if self.transform_lr:
                lr = self.transform_lr(lr_img)
            else:
                lr = transforms.ToTensor()(lr_img)

        return {'hr': hr, 'lr': lr}

#######################################
# Main Function for Validation
#######################################

def main():
    # Define the device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Initialize the RGTNet model (using the same parameters as in training)
    model = RGTNet(dim=48, num_groups=3, num_blocks=2, heads=4).to(device)

    # Load the pretrained model weights
    pretrained_model_path = "/kaggle/input/sr_model/pytorch/default/1/SR_30.38.pth" # Update this path if necessary
    if os.path.exists(pretrained_model_path):
        model.load_state_dict(torch.load(pretrained_model_path, map_location=device))
        print(f"Loaded pretrained model from: {pretrained_model_path}")
    else:
        print(f"Pretrained model not found at: {pretrained_model_path}")
        return

    model.eval()

    # --- Load Validation Datasets ---
    root_dir = "/kaggle/input"
    val_datasets = {}
    scale = 4
    max_hr_size = 256 # Keeping the maximum HR size

    # DIV2K (HR only, will downsample)
    div2k_val_hr_directory = os.path.join(root_dir, "div2k-dataset/DIV2K_valid_HR/DIV2K_valid_HR")
    div2k_dataset = ValDataset(hr_dir=div2k_val_hr_directory, scale=scale, transform_hr=val_transform_hr, max_hr_size=max_hr_size)
    print(f"DIV2K Dataset size: {len(div2k_dataset)}")
    val_datasets['DIV2K'] = DataLoader(div2k_dataset, batch_size=1, shuffle=False, num_workers=2)

    # BSD100 (HR only, will downsample)
    bsd100_hr_dir = os.path.join(root_dir, "bsd100/bsd100/bicubic_4x/train/HR")
    bsd100_dataset = ValDataset(hr_dir=bsd100_hr_dir, scale=scale, transform_hr=val_transform_hr, max_hr_size=max_hr_size)
    print(f"BSD100 Dataset size: {len(bsd100_dataset)}")
    val_datasets['BSD100'] = DataLoader(bsd100_dataset, batch_size=1, shuffle=False, num_workers=2)

    # Set14 (HR only, will downsample)
    set14_hr_dir = os.path.join(root_dir, "set-5-14-super-resolution-dataset/Set14/Set14")
    set14_dataset = ValDataset(hr_dir=set14_hr_dir, scale=scale, transform_hr=val_transform_hr, max_hr_size=max_hr_size)
    print(f"Set14 Dataset size: {len(set14_dataset)}")
    val_datasets['Set14'] = DataLoader(set14_dataset, batch_size=1, shuffle=False, num_workers=2)

    # Set5 (HR only, will downsample)
    set5_hr_dir = os.path.join(root_dir, "set-5-14-super-resolution-dataset/Set5/Set5")
    set5_dataset = ValDataset(hr_dir=set5_hr_dir, scale=scale, transform_hr=val_transform_hr, max_hr_size=max_hr_size)
    print(f"Set5 Dataset size: {len(set5_dataset)}")
    val_datasets['Set5'] = DataLoader(set5_dataset, batch_size=1, shuffle=False, num_workers=2)

    # Urban100 (HR only, will downsample)
    urban100_hr_dir = os.path.join(root_dir, "urban100/Urban 100/X4 Urban100/X4/HIGH x4 URban100")
    urban100_dataset = ValDataset(hr_dir=urban100_hr_dir, scale=scale, transform_hr=val_transform_hr, max_hr_size=max_hr_size)
    print(f"Urban100 Dataset size: {len(urban100_dataset)}")
    val_datasets['Urban100'] = DataLoader(urban100_dataset, batch_size=1, shuffle=False, num_workers=2)

    print("--- RGTNet Model Validation Results ---")
    for dataset_name, dataloader in val_datasets.items():
        avg_psnr, avg_ssim = validate_model(model, dataloader, device)
        if avg_psnr != -1:
            print(f"Dataset: {dataset_name}, Avg PSNR: {avg_psnr:.2f} dB, Avg SSIM: {avg_ssim:.4f}")

if __name__ == "__main__":
    main()

  model.load_state_dict(torch.load(pretrained_model_path, map_location=device))


Using device: cuda
Loaded pretrained model from: /kaggle/input/sr_model/pytorch/default/1/SR_30.38.pth
DIV2K Dataset size: 100
BSD100 Dataset size: 80
Set14 Dataset size: 14
Set5 Dataset size: 5
Urban100 Dataset size: 100
--- RGTNet Model Validation Results ---


Validation: 100%|██████████| 100/100 [00:07<00:00, 13.07it/s, psnr=20.19]


Dataset: DIV2K, Avg PSNR: 22.48 dB, Avg SSIM: 0.6289


Validation: 100%|██████████| 80/80 [00:02<00:00, 31.58it/s, psnr=21.25]


Dataset: BSD100, Avg PSNR: 23.99 dB, Avg SSIM: 0.6408


Validation: 100%|██████████| 14/14 [00:00<00:00, 19.18it/s, psnr=18.46]


Dataset: Set14, Avg PSNR: 23.34 dB, Avg SSIM: 0.6436


Validation: 100%|██████████| 5/5 [00:00<00:00, 13.52it/s, psnr=23.86]


Dataset: Set5, Avg PSNR: 26.20 dB, Avg SSIM: 0.7551


Validation: 100%|██████████| 100/100 [00:03<00:00, 26.77it/s, psnr=17.31]

Dataset: Urban100, Avg PSNR: 20.40 dB, Avg SSIM: 0.5514





In [5]:
import os
import glob
import random
import math
from PIL import Image
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from tqdm import tqdm
from torchmetrics import StructuralSimilarityIndexMeasure
from torchvision.models import vgg19

################################################################################
# MODEL: SRResNet - Reused from training script
################################################################################

def conv(in_c, out_c, k, s=1, p=0):
    return nn.Conv2d(in_c, out_c, kernel_size=k, stride=s, padding=p)

class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = conv(channels, channels, 3, p=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv(channels, channels, 3, p=1)
    def forward(self, x):
        return x + self.conv2(self.relu(self.conv1(x)))

class SRResNet(nn.Module):
    def __init__(self, num_blocks=16, upscale_factor=4):
        super().__init__()
        self.conv1 = conv(3, 64, 9, p=4)
        self.relu = nn.ReLU(inplace=True)
        self.res_blocks = nn.Sequential(*[ResidualBlock(64) for _ in range(num_blocks)])
        self.conv2 = conv(64, 64, 3, p=1)
        self.upsample = nn.Sequential(
            conv(64, 64*4, 3, p=1), nn.PixelShuffle(2), nn.ReLU(inplace=True),
            conv(64, 64*4, 3, p=1), nn.PixelShuffle(2), nn.ReLU(inplace=True)
        )
        self.conv3 = conv(64, 3, 9, p=4)
    def forward(self, x):
        x1 = self.relu(self.conv1(x))
        x2 = self.res_blocks(x1)
        x = x1 + self.conv2(x2)
        x = self.upsample(x)
        return self.conv3(x)

################################################################################
# LOSS FUNCTIONS & METRICS - Reused PSNR
################################################################################

def calc_psnr(sr, hr, max_val=1.0):
    mse = nn.functional.mse_loss(sr, hr)
    return 20 * math.log10(max_val) - 10 * math.log10(mse.item() + 1e-10)

#######################################
# Validation Function (Adapted)
#######################################

def validate_model(model, dataloader, device):
    model.eval()
    total_psnr = 0.0
    total_ssim = 0.0
    num_batches = 0
    ssim_metric = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)

    with torch.no_grad():
        pbar = tqdm(dataloader, desc="Validation")
        for batch in pbar:
            hr = batch['hr'].to(device)
            lr = batch['lr'].to(device)
            sr = model(lr).clamp(0, 1) # Ensure output is in [0, 1]

            # Ensure HR has the same size as SR for PSNR calculation
            if hr.size()[-2:] != sr.size()[-2:]:
                hr = F.interpolate(hr, size=sr.size()[-2:], mode='bicubic', align_corners=False)

            # Calculate PSNR (using the training script's function)
            psnr = calc_psnr(sr, hr, max_val=1.0)
            total_psnr += psnr

            # Calculate SSIM (on RGB)
            total_ssim += ssim_metric(sr, hr).item()
            num_batches += 1
            pbar.set_postfix(psnr=f"{psnr:.2f}")

    avg_psnr = total_psnr / num_batches
    avg_ssim = total_ssim / num_batches
    return avg_psnr, avg_ssim

#######################################
# Dataset Definitions (Adapted for direct LR/HR loading or downsampling)
#######################################

val_transform_hr = transforms.Compose([
    transforms.ToTensor(),
])
val_transform_lr = transforms.Compose([
    transforms.ToTensor(),
])

class ValDataset(Dataset):
    def __init__(self, hr_dir, scale=4, transform_hr=None, transform_lr=None, load_lr_directly=False, lr_dir=None, max_hr_size=None): # Added max_hr_size
        self.hr_dir = hr_dir
        self.scale = scale
        self.transform_hr = transform_hr
        self.transform_lr = transform_lr
        self.load_lr_directly = load_lr_directly
        self.lr_dir = lr_dir
        self.hr_images = sorted([os.path.join(hr_dir, f) for f in os.listdir(hr_dir) if f.endswith(('.png', '.jpg', '.jpeg'))],
                                key=lambda path: int(os.path.splitext(os.path.basename(path))[0]) if os.path.splitext(os.path.basename(path))[0].isdigit() else os.path.splitext(os.path.basename(path))[0])
        if load_lr_directly and lr_dir:
            self.lr_images = sorted([os.path.join(lr_dir, f) for f in os.listdir(lr_dir) if f.endswith(('.png', '.jpg', '.jpeg'))],
                                    key=lambda path: int(os.path.splitext(os.path.basename(path))[0]) if os.path.splitext(os.path.basename(path))[0].isdigit() else os.path.splitext(os.path.basename(path))[0])
            assert len(self.hr_images) == len(self.lr_images), "Number of HR and LR images mismatch in direct load mode."
        self.max_hr_size = max_hr_size

    def __len__(self):
        return len(self.hr_images)

    def __getitem__(self, idx):
        hr_path = self.hr_images[idx]
        hr_img = Image.open(hr_path).convert('RGB')

        if self.max_hr_size is not None and max(hr_img.size) > self.max_hr_size:
            width, height = hr_img.size
            if width > height:
                new_width = self.max_hr_size
                new_height = int(height * (new_width / width))
            else:
                new_height = self.max_hr_size
                new_width = int(width * (new_height / height))
            hr_img = hr_img.resize((new_width, new_height), Image.LANCZOS)

        if self.transform_hr:
            hr = self.transform_hr(hr_img)
        else:
            hr = transforms.ToTensor()(hr_img)

        if self.load_lr_directly and self.lr_dir:
            lr_path = self.lr_images[idx]
            lr_img = Image.open(lr_path).convert('RGB')
            if self.transform_lr:
                lr = self.transform_lr(lr_img)
            else:
                lr = transforms.ToTensor()(lr_img)
        else:
            lr_width, lr_height = hr_img.width // self.scale, hr_img.height // self.scale
            lr_img = hr_img.resize((lr_width, lr_height), Image.BICUBIC)
            if self.transform_lr:
                lr = self.transform_lr(lr_img)
            else:
                lr = transforms.ToTensor()(lr_img)

        return {'hr': hr, 'lr': lr}

#######################################
# Main Function for Validation
#######################################

def main():
    # Define the device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Initialize the SRResNet model (using the same parameters as in training)
    model = SRResNet().to(device)

    # Load the pretrained model weights
    pretrained_model_path = "/kaggle/input/srresnet_model/pytorch/default/1/SRResnet_26.01.pth" # Path from the training script
    if os.path.exists(pretrained_model_path):
        model.load_state_dict(torch.load(pretrained_model_path, map_location=device))
        print(f"Loaded pretrained model from: {pretrained_model_path}")
    else:
        print(f"Pretrained model not found at: {pretrained_model_path}")
        return

    model.eval()

    # --- Load Validation Datasets ---
    root_dir = "/kaggle/input"
    val_datasets = {}
    scale = 4
    max_hr_size = 512 # You can adjust this based on your GPU memory

    # DIV2K (HR only, will downsample)
    div2k_val_hr_directory = os.path.join(root_dir, "div2k-dataset/DIV2K_valid_HR/DIV2K_valid_HR")
    div2k_dataset = ValDataset(hr_dir=div2k_val_hr_directory, scale=scale, transform_hr=val_transform_hr, max_hr_size=max_hr_size)
    print(f"DIV2K Dataset size: {len(div2k_dataset)}")
    val_datasets['DIV2K'] = DataLoader(div2k_dataset, batch_size=1, shuffle=False, num_workers=2)

    # BSD100 (HR only, will downsample)
    bsd100_hr_dir = os.path.join(root_dir, "bsd100/bsd100/bicubic_4x/train/HR")
    bsd100_dataset = ValDataset(hr_dir=bsd100_hr_dir, scale=scale, transform_hr=val_transform_hr, max_hr_size=max_hr_size)
    print(f"BSD100 Dataset size: {len(bsd100_dataset)}")
    val_datasets['BSD100'] = DataLoader(bsd100_dataset, batch_size=1, shuffle=False, num_workers=2)

    # Set14 (HR only, will downsample)
    set14_hr_dir = os.path.join(root_dir, "set-5-14-super-resolution-dataset/Set14/Set14")
    set14_dataset = ValDataset(hr_dir=set14_hr_dir, scale=scale, transform_hr=val_transform_hr, max_hr_size=max_hr_size)
    print(f"Set14 Dataset size: {len(set14_dataset)}")
    val_datasets['Set14'] = DataLoader(set14_dataset, batch_size=1, shuffle=False, num_workers=2)

    # Set5 (HR only, will downsample)
    set5_hr_dir = os.path.join(root_dir, "set-5-14-super-resolution-dataset/Set5/Set5")
    set5_dataset = ValDataset(hr_dir=set5_hr_dir, scale=scale, transform_hr=val_transform_hr, max_hr_size=max_hr_size)
    print(f"Set5 Dataset size: {len(set5_dataset)}")
    val_datasets['Set5'] = DataLoader(set5_dataset, batch_size=1, shuffle=False, num_workers=2)

    # Urban100 (HR only, will downsample)
    urban100_hr_dir = os.path.join(root_dir, "urban100/Urban 100/X4 Urban100/X4/HIGH x4 URban100")
    urban100_dataset = ValDataset(hr_dir=urban100_hr_dir, scale=scale, transform_hr=val_transform_hr, max_hr_size=max_hr_size)
    print(f"Urban100 Dataset size: {len(urban100_dataset)}")
    val_datasets['Urban100'] = DataLoader(urban100_dataset, batch_size=1, shuffle=False, num_workers=2)

    print("--- SRResNet Model Validation Results ---")
    for dataset_name, dataloader in val_datasets.items():
        avg_psnr, avg_ssim = validate_model(model, dataloader, device)
        if avg_psnr != -1:
            print(f"Dataset: {dataset_name}, Avg PSNR: {avg_psnr:.2f} dB, Avg SSIM: {avg_ssim:.4f}")

if __name__ == "__main__":
    main()

Using device: cuda
Loaded pretrained model from: /kaggle/input/srresnet_model/pytorch/default/1/SRResnet_26.01.pth
DIV2K Dataset size: 100


  model.load_state_dict(torch.load(pretrained_model_path, map_location=device))


BSD100 Dataset size: 80
Set14 Dataset size: 14
Set5 Dataset size: 5
Urban100 Dataset size: 100
--- SRResNet Model Validation Results ---


Validation: 100%|██████████| 100/100 [00:08<00:00, 12.00it/s, psnr=21.59]


Dataset: DIV2K, Avg PSNR: 24.09 dB, Avg SSIM: 0.7027


Validation: 100%|██████████| 80/80 [00:02<00:00, 28.96it/s, psnr=22.46]


Dataset: BSD100, Avg PSNR: 25.62 dB, Avg SSIM: 0.7030


Validation: 100%|██████████| 14/14 [00:00<00:00, 20.62it/s, psnr=23.94]


Dataset: Set14, Avg PSNR: 25.06 dB, Avg SSIM: 0.6999


Validation: 100%|██████████| 5/5 [00:00<00:00, 21.59it/s, psnr=27.21]


Dataset: Set5, Avg PSNR: 28.29 dB, Avg SSIM: 0.8169


Validation: 100%|██████████| 100/100 [00:04<00:00, 23.20it/s, psnr=21.03]

Dataset: Urban100, Avg PSNR: 21.61 dB, Avg SSIM: 0.6440





In [7]:
import os
import glob
import random
import math
from PIL import Image
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from tqdm import tqdm
from torchmetrics import StructuralSimilarityIndexMeasure

##############################
# 3. SwinIR Model Definition (Reused from training script)
##############################

# --- Helper Functions for Window Partitioning ---

def window_partition(x, window_size):
    # x: (B, H, W, C)
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0,1,3,2,4,5).contiguous().view(-1, window_size, window_size, C)
    return windows

def window_reverse(windows, window_size, H, W):
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0,1,3,2,4,5).contiguous().view(B, H, W, -1)
    return x

# --- Window Attention Module ---

class WindowAttention(nn.Module):
    def __init__(self, dim, window_size, num_heads):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # e.g. 8
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.proj = nn.Linear(dim, dim)

        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(((2 * window_size - 1) * (2 * window_size - 1), num_heads))
        )
        coords = torch.stack(torch.meshgrid(torch.arange(window_size), torch.arange(window_size)))
        coords_flatten = torch.flatten(coords, 1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1,2,0).contiguous()
        relative_coords[:,:,0] += window_size - 1
        relative_coords[:,:,1] += window_size - 1
        relative_coords[:,:,0] *= 2 * window_size - 1
        relative_position_index = relative_coords.sum(-1)
        self.register_buffer("relative_position_index", relative_position_index)
        nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)

    def forward(self, x):
        # x: (num_windows*B, N, C) where N = window_size*window_size
        B_, N, C = x.shape
        qkv = self.qkv(x)
        qkv = qkv.reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size * self.window_size, self.window_size * self.window_size, -1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).unsqueeze(0)
        attn = attn + relative_position_bias
        attn = attn.softmax(dim=-1)
        x = (attn @ v)
        x = x.transpose(1,2).reshape(B_, N, C)
        x = self.proj(x)
        return x

# --- Swin Transformer Block ---

class SwinTransformerBlock(nn.Module):
    def __init__(self, dim, input_resolution, num_heads, window_size=8, shift_size=0, mlp_ratio=4.):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size

        self.norm1 = nn.LayerNorm(dim)
        self.attn = WindowAttention(dim, window_size, num_heads)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, int(dim * mlp_ratio)),
            nn.GELU(),
            nn.Linear(int(dim * mlp_ratio), dim)
        )

    def forward(self, x):
        # x: (B, H, W, C)
        B, H, W, C = x.shape
        shortcut = x
        x = self.norm1(x.view(B * H * W, C)).view(B, H, W, C)
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x
        x_windows = window_partition(shifted_x, self.window_size)
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
        attn_windows = self.attn(x_windows)
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1,2))
        else:
            x = shifted_x
        x = shortcut + x
        x = x + self.mlp(self.norm2(x.view(B * H * W, C))).view(B, H, W, C)
        return x

# --- Residual Swin Transformer Block (RSTB) ---

class RSTB(nn.Module):
    def __init__(self, dim, input_resolution, depth, num_heads, window_size=8, mlp_ratio=4.):
        super().__init__()
        self.blocks = nn.Sequential(*[
            SwinTransformerBlock(dim, input_resolution, num_heads, window_size,
                                    shift_size=0 if (i % 2 == 0) else window_size // 2, mlp_ratio=mlp_ratio)
            for i in range(depth)
        ])
        self.conv = nn.Conv2d(dim, dim, kernel_size=3, padding=1)

    def forward(self, x):
        # x: (B, C, H, W)
        res = x
        B, C, H, W = x.shape
        x = x.permute(0, 2, 3, 1).contiguous()
        x = self.blocks(x)
        x = x.permute(0, 3, 1, 2).contiguous()
        return x + self.conv(res)

# --- SwinIR Main Model ---

class SwinIR(nn.Module):
    def __init__(self, upscale=4, in_channels=3, embed_dim=96, depths=[8,8,8,8],
                 num_heads=[6,6,6,6], window_size=8, mlp_ratio=4.):
        super(SwinIR, self).__init__()
        self.upscale = upscale
        self.conv_first = nn.Conv2d(in_channels, embed_dim, kernel_size=3, padding=1)
        self.feature_extraction = nn.Conv2d(embed_dim, embed_dim, kernel_size=3, padding=1)
        # Assume HR crop size is 192x192.
        input_resolution = (192, 192)
        self.blocks = nn.ModuleList()
        for i in range(len(depths)):
            self.blocks.append(RSTB(dim=embed_dim, input_resolution=input_resolution,
                                     depth=depths[i], num_heads=num_heads[i],
                                     window_size=window_size, mlp_ratio=mlp_ratio))
        self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, kernel_size=3, padding=1)
        if upscale == 4:
            self.upsample = nn.Sequential(
                nn.Conv2d(embed_dim, embed_dim * 4, kernel_size=3, padding=1),
                nn.PixelShuffle(2),
                nn.Conv2d(embed_dim, embed_dim * 4, kernel_size=3, padding=1),
                nn.PixelShuffle(2),
            )
        elif upscale == 2:
            self.upsample = nn.Sequential(
                nn.Conv2d(embed_dim, embed_dim * 4, kernel_size=3, padding=1),
                nn.PixelShuffle(2),
            )
        else:
            raise NotImplementedError(f"Upscale factor {upscale} not supported.")
        self.conv_last = nn.Conv2d(embed_dim, in_channels, kernel_size=3, padding=1)

    def forward(self, x):
        x = self.conv_first(x)
        x = self.feature_extraction(x)
        residual = x
        for block in self.blocks:
            x = block(x)
        x = self.conv_after_body(x)
        x = x + residual
        x = self.upsample(x)
        x = self.conv_last(x)
        return x

##############################
# 2. PSNR Calculation Function (Reused from training script)
##############################

def calculate_psnr(sr, hr, max_val=1.0):
    mse = F.mse_loss(sr, hr)
    if mse == 0:
        return 100
    psnr = 10 * torch.log10((max_val ** 2) / mse)
    return psnr.item()

#######################################
# Validation Function (Adapted)
#######################################

def validate_model(model, dataloader, device):
    model.eval()
    total_psnr = 0.0
    total_ssim = 0.0
    num_batches = 0
    ssim_metric = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)

    with torch.no_grad():
        pbar = tqdm(dataloader, desc="Validation")
        for batch in pbar:
            hr = batch['hr'].to(device)
            lr = batch['lr'].to(device)
            sr = model(lr).clamp(0, 1) # Ensure output is in [0, 1]

            # Ensure HR has the same size as SR for PSNR calculation
            if hr.size()[-2:] != sr.size()[-2:]:
                hr = F.interpolate(hr, size=sr.size()[-2:], mode='bicubic', align_corners=False)

            # Calculate PSNR (using the training script's function)
            psnr = calculate_psnr(sr, hr, max_val=1.0)
            total_psnr += psnr

            # Calculate SSIM (on RGB)
            total_ssim += ssim_metric(sr, hr).item()
            num_batches += 1
            pbar.set_postfix(psnr=f"{psnr:.2f}")

    avg_psnr = total_psnr / num_batches
    avg_ssim = total_ssim / num_batches
    return avg_psnr, avg_ssim

#######################################
# Dataset Definitions (Adapted for direct LR/HR loading or downsampling)
#######################################

val_transform_hr = transforms.Compose([
    transforms.ToTensor(),
])
val_transform_lr = transforms.Compose([
    transforms.ToTensor(),
])

class ValDataset(Dataset):
    def __init__(self, hr_dir, scale=4, transform_hr=None, transform_lr=None, load_lr_directly=False, lr_dir=None, max_hr_size=None): # Added max_hr_size
        self.hr_dir = hr_dir
        self.scale = scale
        self.transform_hr = transform_hr
        self.transform_lr = transform_lr
        self.load_lr_directly = load_lr_directly
        self.lr_dir = lr_dir
        self.hr_images = sorted([os.path.join(hr_dir, f) for f in os.listdir(hr_dir) if f.endswith(('.png', '.jpg', '.jpeg'))],
                                key=lambda path: int(os.path.splitext(os.path.basename(path))[0]) if os.path.splitext(os.path.basename(path))[0].isdigit() else os.path.splitext(os.path.basename(path))[0])
        if load_lr_directly and lr_dir:
            self.lr_images = sorted([os.path.join(lr_dir, f) for f in os.listdir(lr_dir) if f.endswith(('.png', '.jpg', '.jpeg'))],
                                    key=lambda path: int(os.path.splitext(os.path.basename(path))[0]) if os.path.splitext(os.path.basename(path))[0].isdigit() else os.path.splitext(os.path.basename(path))[0])
            assert len(self.hr_images) == len(self.lr_images), "Number of HR and LR images mismatch in direct load mode."
        self.max_hr_size = max_hr_size

    def __len__(self):
        return len(self.hr_images)

    def __getitem__(self, idx):
        hr_path = self.hr_images[idx]
        hr_img = Image.open(hr_path).convert('RGB')
        width, height = hr_img.size

        # Resize HR image to the nearest multiple of 32
        if self.max_hr_size is not None:
            max_dim = max(width, height)
            if max_dim > self.max_hr_size:
                if width > height:
                    new_width = (self.max_hr_size // 32) * 32
                    new_height = int(height * (new_width / width))
                    new_height = (new_height // 32) * 32
                else:
                    new_height = (self.max_hr_size // 32) * 32
                    new_width = int(width * (new_height / height))
                    new_width = (new_width // 32) * 32
                hr_img = hr_img.resize((new_width, new_height), Image.LANCZOS)
                width, height = hr_img.size # Update width and height
            elif width % 32 != 0 or height % 32 != 0:
                new_width = (width // 32) * 32
                new_height = (height // 32) * 32
                if new_width == 0 or new_height == 0: # Handle cases where original size is smaller than 32
                    new_width = max(32, width)
                    new_height = max(32, height)
                    new_width = (new_width // 32) * 32
                    new_height = (new_height // 32) * 32

                hr_img = hr_img.resize((new_width, new_height), Image.LANCZOS)
                width, height = hr_img.size # Update width and height
        elif width % 32 != 0 or height % 32 != 0:
            new_width = (width // 32) * 32
            new_height = (height // 32) * 32
            if new_width == 0 or new_height == 0: # Handle cases where original size is smaller than 32
                new_width = max(32, width)
                new_height = max(32, height)
                new_width = (new_width // 32) * 32
                new_height = (new_height // 32) * 32
            hr_img = hr_img.resize((new_width, new_height), Image.LANCZOS)
            width, height = hr_img.size # Update width and height


        if self.transform_hr:
            hr = self.transform_hr(hr_img)
        else:
            hr = transforms.ToTensor()(hr_img)

        if self.load_lr_directly and self.lr_dir:
            lr_path = self.lr_images[idx]
            lr_img = Image.open(lr_path).convert('RGB')
            if self.transform_lr:
                lr = self.transform_lr(lr_img)
            else:
                lr = transforms.ToTensor()(lr_img)
        else:
            lr_width, lr_height = width // self.scale, height // self.scale
            lr_img = hr_img.resize((lr_width, lr_height), Image.BICUBIC)
            if self.transform_lr:
                lr = self.transform_lr(lr_img)
            else:
                lr = transforms.ToTensor()(lr_img)

        return {'hr': hr, 'lr': lr}

#######################################
# Main Function for Validation
#######################################

def main():
    # Define the device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Initialize the SwinIR model (using the same parameters as in training)
    scale = 4
    model = SwinIR(upscale=scale, in_channels=3, embed_dim=96, depths=[8,8,8,8],
                   num_heads=[6,6,6,6], window_size=8, mlp_ratio=4.).to(device)

    # Load the pretrained model weights
    pretrained_model_path = "/kaggle/input/swinir_model/pytorch/default/1/SWINIR_25.56.pth" # Path from the user's output
    if os.path.exists(pretrained_model_path):
        model.load_state_dict(torch.load(pretrained_model_path, map_location=device))
        print(f"Loaded pretrained model from: {pretrained_model_path}")
    else:
        print(f"Pretrained model not found at: {pretrained_model_path}")
        return

    model.eval()

    # --- Load Validation Datasets ---
    root_dir = "/kaggle/input"
    val_datasets = {}
    max_hr_size = 512 # You can adjust this based on your GPU memory

    # DIV2K (HR only, will downsample)
    div2k_val_hr_directory = os.path.join(root_dir, "div2k-dataset/DIV2K_valid_HR/DIV2K_valid_HR")
    div2k_dataset = ValDataset(hr_dir=div2k_val_hr_directory, scale=scale, transform_hr=val_transform_hr, max_hr_size=max_hr_size)
    print(f"DIV2K Dataset size: {len(div2k_dataset)}")
    val_datasets['DIV2K'] = DataLoader(div2k_dataset, batch_size=1, shuffle=False, num_workers=2)

    # BSD100 (HR only, will downsample)
    bsd100_hr_dir = os.path.join(root_dir, "bsd100/bsd100/bicubic_4x/train/HR")
    bsd100_dataset = ValDataset(hr_dir=bsd100_hr_dir, scale=scale, transform_hr=val_transform_hr, max_hr_size=max_hr_size)
    print(f"BSD100 Dataset size: {len(bsd100_dataset)}")
    val_datasets['BSD100'] = DataLoader(bsd100_dataset, batch_size=1, shuffle=False, num_workers=2)

    # Set14 (HR only, will downsample)
    set14_hr_dir = os.path.join(root_dir, "set-5-14-super-resolution-dataset/Set14/Set14")
    set14_dataset = ValDataset(hr_dir=set14_hr_dir, scale=scale, transform_hr=val_transform_hr, max_hr_size=max_hr_size)
    print(f"Set14 Dataset size: {len(set14_dataset)}")
    val_datasets['Set14'] = DataLoader(set14_dataset, batch_size=1, shuffle=False, num_workers=2)

    # Set5 (HR only, will downsample)
    set5_hr_dir = os.path.join(root_dir, "set-5-14-super-resolution-dataset/Set5/Set5")
    set5_dataset = ValDataset(hr_dir=set5_hr_dir, scale=scale, transform_hr=val_transform_hr, max_hr_size=max_hr_size)
    print(f"Set5 Dataset size: {len(set5_dataset)}")
    val_datasets['Set5'] = DataLoader(set5_dataset, batch_size=1, shuffle=False, num_workers=2)

    # Urban100 (HR only, will downsample)
    urban100_hr_dir = os.path.join(root_dir, "urban100/Urban 100/X4 Urban100/X4/HIGH x4 URban100")
    urban100_dataset = ValDataset(hr_dir=urban100_hr_dir, scale=scale, transform_hr=val_transform_hr, max_hr_size=max_hr_size)
    print(f"Urban100 Dataset size: {len(urban100_dataset)}")
    val_datasets['Urban100'] = DataLoader(urban100_dataset, batch_size=1, shuffle=False, num_workers=2)

    print("--- SwinIR Model Validation Results ---")
    for dataset_name, dataloader in val_datasets.items():
        avg_psnr, avg_ssim = validate_model(model, dataloader, device)
        if avg_psnr != -1:
            print(f"Dataset: {dataset_name}, Avg PSNR: {avg_psnr:.2f} dB, Avg SSIM: {avg_ssim:.4f}")

if __name__ == "__main__":
    main()

Using device: cuda
Loaded pretrained model from: /kaggle/input/swinir_model/pytorch/default/1/SWINIR_25.56.pth
DIV2K Dataset size: 100
BSD100 Dataset size: 80


  model.load_state_dict(torch.load(pretrained_model_path, map_location=device))


Set14 Dataset size: 14
Set5 Dataset size: 5
Urban100 Dataset size: 100
--- SwinIR Model Validation Results ---


Validation: 100%|██████████| 100/100 [00:10<00:00,  9.88it/s, psnr=21.35]


Dataset: DIV2K, Avg PSNR: 23.80 dB, Avg SSIM: 0.6829


Validation: 100%|██████████| 80/80 [00:07<00:00, 11.20it/s, psnr=22.16]


Dataset: BSD100, Avg PSNR: 25.30 dB, Avg SSIM: 0.6835


Validation: 100%|██████████| 14/14 [00:01<00:00,  9.03it/s, psnr=23.31]


Dataset: Set14, Avg PSNR: 24.86 dB, Avg SSIM: 0.6859


Validation: 100%|██████████| 5/5 [00:00<00:00, 11.66it/s, psnr=26.31]


Dataset: Set5, Avg PSNR: 27.90 dB, Avg SSIM: 0.8006


Validation: 100%|██████████| 100/100 [00:10<00:00,  9.28it/s, psnr=20.50]

Dataset: Urban100, Avg PSNR: 21.33 dB, Avg SSIM: 0.6166





In [14]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import numpy as np
from torchvision.transforms import ToPILImage, ToTensor, Compose
from PIL import Image
import glob
import random
import math
import torch.nn.functional as F
from torchmetrics import StructuralSimilarityIndexMeasure

# Channel Attention Module
class ChannelAttention(nn.Module):
    def __init__(self, num_channels, reduction_ratio=16):
        super(ChannelAttention, self).__init__()
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv2d(num_channels, num_channels // reduction_ratio, kernel_size=1)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(num_channels // reduction_ratio, num_channels, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        batch_size, num_channels, height, width = x.size()
        out = self.global_avg_pool(x)
        out = self.fc1(out)
        out = self.relu(out)
        out = self.fc2(out)
        weight = self.sigmoid(out)
        return x * weight

# Residual Channel Attention Block (RCAB)
class ResidualChannelAttentionBlock(nn.Module):
    def __init__(self, num_channels, reduction_ratio=16):
        super(ResidualChannelAttentionBlock, self).__init__()
        self.conv1 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1)
        self.ca = ChannelAttention(num_channels, reduction_ratio)

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.ca(out)
        return out + residual

# Residual Group (RG)
class ResidualGroup(nn.Module):
    def __init__(self, num_channels, num_rcab, reduction_ratio=16):
        super(ResidualGroup, self).__init__()
        layers = [ResidualChannelAttentionBlock(num_channels, reduction_ratio) for _ in range(num_rcab)]
        self.rcabs = nn.Sequential(*layers)
        self.conv = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1)

    def forward(self, x):
        residual = x
        out = self.rcabs(x)
        out = self.conv(out)
        return out + residual

# RCAN-inspired Network
class RCANNet(nn.Module):
    def __init__(self, num_channels=3, num_filters=64, num_rg=10, num_rcab=20, scale_factor=4, reduction_ratio=16):
        super(RCANNet, self).__init__()
        self.num_filters = num_filters
        self.scale_factor = scale_factor
        self.initial_conv = nn.Conv2d(num_channels, num_filters, kernel_size=3, padding=1)
        self.residual_groups = nn.ModuleList([
            ResidualGroup(num_filters, num_rcab, reduction_ratio) for _ in range(num_rg)
        ])
        self.conv_after_rg = nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1)
        self.upsample = nn.Sequential(
            nn.Conv2d(num_filters, num_filters * (scale_factor ** 2), kernel_size=3, padding=1),
            nn.PixelShuffle(scale_factor)
        )
        self.final_conv = nn.Conv2d(num_filters, num_channels, kernel_size=3, padding=1)

    def forward(self, x):
        x = self.initial_conv(x)
        residual = x
        for rg in self.residual_groups:
            x = rg(x)
        x = self.conv_after_rg(x)
        x += residual
        x = self.upsample(x)
        x = self.final_conv(x)
        return x

# Function to calculate PSNR
def calculate_psnr(img1, img2):
    img1 = img1.mul(255).byte().cpu().numpy()
    img2 = img2.mul(255).byte().cpu().numpy()
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0:
        return 100
    PIXEL_MAX = 255.0
    return 20 * np.log10(PIXEL_MAX / np.sqrt(mse))

#######################################
# Validation Function
#######################################

def validate_model(dataloader, model, device):
    model.eval()
    total_psnr = 0.0
    total_ssim = 0.0
    num_batches = 0
    ssim_metric = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)

    with torch.no_grad():
        pbar = tqdm(dataloader, desc="Validation")
        for batch in pbar:
            hr_images = batch['hr'].to(device)
            lr_images = batch['lr'].to(device)
            outputs = model(lr_images).clamp(0, 1)

            # Resize HR images to match the output size
            resized_hr_images = F.interpolate(hr_images, size=outputs.shape[2:], mode='bicubic', align_corners=False)

            # Calculate PSNR and SSIM for the entire batch
            current_batch_size = lr_images.size(0)
            for i in range(current_batch_size):
                total_psnr += calculate_psnr(outputs[i].cpu().float(), resized_hr_images[i].cpu().float())
                total_ssim += ssim_metric(outputs[i].unsqueeze(0), resized_hr_images[i].unsqueeze(0)).item()
            num_batches += current_batch_size

    if num_batches == 0:
        return -1, -1  # Return -1 to indicate no data was processed
    avg_psnr = total_psnr / num_batches
    avg_ssim = total_ssim / num_batches
    return avg_psnr, avg_ssim

#######################################
# Dataset Definitions (Using ValDataset from previous SwinIR script)
#######################################

val_transform_hr = Compose([
    ToTensor(),
])
val_transform_lr = Compose([
    ToTensor(),
])

class ValDataset(Dataset):
    def __init__(self, hr_dir, scale=4, transform_hr=None, transform_lr=None, load_lr_directly=False, lr_dir=None):
        self.hr_dir = hr_dir
        self.scale = scale
        self.transform_hr = transform_hr
        self.transform_lr = transform_lr
        self.load_lr_directly = load_lr_directly
        self.lr_dir = lr_dir
        self.hr_images = sorted([os.path.join(hr_dir, f) for f in os.listdir(hr_dir) if f.endswith(('.png', '.jpg', '.jpeg'))],
                                key=lambda path: int(os.path.splitext(os.path.basename(path))[0]) if os.path.splitext(os.path.basename(path))[0].isdigit() else os.path.splitext(os.path.basename(path))[0])
        if load_lr_directly and lr_dir:
            self.lr_images = sorted([os.path.join(lr_dir, f) for f in os.listdir(lr_dir) if f.endswith(('.png', '.jpg', '.jpeg'))],
                                    key=lambda path: int(os.path.splitext(os.path.basename(path))[0]) if os.path.splitext(os.path.basename(path))[0].isdigit() else os.path.splitext(os.path.basename(path))[0])
            assert len(self.hr_images) == len(self.lr_images), "Number of HR and LR images mismatch in direct load mode."

    def __len__(self):
        return len(self.hr_images)

    def __getitem__(self, idx):
        hr_path = self.hr_images[idx]
        hr_img = Image.open(hr_path).convert('RGB')

        if self.transform_hr:
            hr = self.transform_hr(hr_img)
        else:
            hr = ToTensor()(hr_img)

        if self.load_lr_directly and self.lr_dir:
            lr_path = self.lr_images[idx]
            lr_img = Image.open(lr_path).convert('RGB')
            if self.transform_lr:
                lr = self.transform_lr(lr_img)
            else:
                lr = ToTensor()(lr_img)
        else:
            width, height = hr_img.size
            lr_width, lr_height = width // self.scale, height // self.scale
            lr_img = hr_img.resize((lr_width, lr_height), Image.BICUBIC)
            if self.transform_lr:
                lr = self.transform_lr(lr_img)
            else:
                lr = ToTensor()(lr_img)

        return {'hr': hr, 'lr': lr}

#######################################
# Main Function for Validation
#######################################

if __name__ == '__main__':
    # Hyperparameters
    batch_size = 1 # Using batch size 1 as requested
    scale_factor = 4
    num_filters = 64
    num_rg = 10
    num_rcab = 20
    reduction_ratio = 16
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pretrained_model_path = "/kaggle/input/rcan_model/pytorch/default/1/best_model (2).pth" # Updated path
    root_dir = "/kaggle/input"

    # Initialize model
    model = RCANNet(num_channels=3, num_filters=num_filters, num_rg=num_rg, num_rcab=num_rcab, scale_factor=scale_factor, reduction_ratio=reduction_ratio)

    # Load pretrained model weights
    if os.path.exists(pretrained_model_path):
        checkpoint = torch.load(pretrained_model_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(device)
        print(f"Loaded pretrained model from: {pretrained_model_path}")
    else:
        print(f"Pretrained model not found at: {pretrained_model_path}")
        exit()

    # Define validation datasets and dataloaders
    val_datasets = {}

    # DIV2K (HR only, will downsample)
    div2k_val_hr_directory = os.path.join(root_dir, "div2k-dataset/DIV2K_valid_HR/DIV2K_valid_HR")
    div2k_dataset = ValDataset(hr_dir=div2k_val_hr_directory, scale=scale_factor, transform_hr=val_transform_hr)
    print(f"DIV2K Dataset size: {len(div2k_dataset)}")
    val_datasets['DIV2K'] = DataLoader(div2k_dataset, batch_size=1, shuffle=False, num_workers=2, pin_memory=True)

    # BSD100 (HR only, will downsample)
    bsd100_hr_dir = os.path.join(root_dir, "bsd100/bsd100/bicubic_4x/train/HR")
    bsd100_dataset = ValDataset(hr_dir=bsd100_hr_dir, scale=scale_factor, transform_hr=val_transform_hr)
    print(f"BSD100 Dataset size: {len(bsd100_dataset)}")
    val_datasets['BSD100'] = DataLoader(bsd100_dataset, batch_size=1, shuffle=False, num_workers=2, pin_memory=True)

    # Set14 (HR only, will downsample)
    set14_hr_dir = os.path.join(root_dir, "set-5-14-super-resolution-dataset/Set14/Set14")
    set14_dataset = ValDataset(hr_dir=set14_hr_dir, scale=scale_factor, transform_hr=val_transform_hr)
    print(f"Set14 Dataset size: {len(set14_dataset)}")
    val_datasets['Set14'] = DataLoader(set14_dataset, batch_size=1, shuffle=False, num_workers=2, pin_memory=True)

    # Set5 (HR only, will downsample)
    set5_hr_dir = os.path.join(root_dir, "set-5-14-super-resolution-dataset/Set5/Set5")
    set5_dataset = ValDataset(hr_dir=set5_hr_dir, scale=scale_factor, transform_hr=val_transform_hr)
    print(f"Set5 Dataset size: {len(set5_dataset)}")
    val_datasets['Set5'] = DataLoader(set5_dataset, batch_size=1, shuffle=False, num_workers=2, pin_memory=True)

    # Urban100 (HR only, will downsample)
    urban100_hr_dir = os.path.join(root_dir, "urban100/Urban 100/X4 Urban100/X4/HIGH x4 URban100")
    urban100_dataset = ValDataset(hr_dir=urban100_hr_dir, scale=scale_factor, transform_hr=val_transform_hr)
    print(f"Urban100 Dataset size: {len(urban100_dataset)}")
    val_datasets['Urban100'] = DataLoader(urban100_dataset, batch_size=1, shuffle=False, num_workers=2, pin_memory=True)

    # Validate the model on all datasets
    print("--- RCAN Model Validation Results ---")
    for dataset_name, dataloader in val_datasets.items():
        avg_psnr, avg_ssim = validate_model(dataloader, model, device)
        if avg_psnr != -1:
            print(f"Dataset: {dataset_name}, Avg PSNR: {avg_psnr:.4f} dB, Avg SSIM: {avg_ssim:.4f}")
        else:
            print(f"Dataset: {dataset_name}, No data found for validation.")

    print("Validation finished!")

  checkpoint = torch.load(pretrained_model_path, map_location=device)


Loaded pretrained model from: /kaggle/input/rcan_model/pytorch/default/1/best_model (2).pth
DIV2K Dataset size: 100
BSD100 Dataset size: 80
Set14 Dataset size: 14
Set5 Dataset size: 5
Urban100 Dataset size: 100
--- RCAN Model Validation Results ---


Validation: 100%|██████████| 100/100 [03:03<00:00,  1.83s/it]


Dataset: DIV2K, Avg PSNR: 33.1956 dB, Avg SSIM: 0.8091


Validation: 100%|██████████| 80/80 [00:07<00:00, 10.90it/s]


Dataset: BSD100, Avg PSNR: 31.6858 dB, Avg SSIM: 0.7306


Validation: 100%|██████████| 14/14 [00:02<00:00,  6.15it/s]


Dataset: Set14, Avg PSNR: 31.5765 dB, Avg SSIM: 0.7370


Validation: 100%|██████████| 5/5 [00:00<00:00,  7.54it/s]


Dataset: Set5, Avg PSNR: 32.3066 dB, Avg SSIM: 0.8417


Validation: 100%|██████████| 100/100 [00:45<00:00,  2.19it/s]

Dataset: Urban100, Avg PSNR: 31.2350 dB, Avg SSIM: 0.7392
Validation finished!





In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import numpy as np
from torchvision.transforms import ToPILImage, ToTensor, Compose
from PIL import Image
import torch.nn.functional as F
from torchmetrics import StructuralSimilarityIndexMeasure

# Channel Attention Module (Reused)
class ChannelAttention(nn.Module):
    def __init__(self, num_channels, reduction_ratio=16):
        super(ChannelAttention, self).__init__()
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv2d(num_channels, num_channels // reduction_ratio, kernel_size=1)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(num_channels // reduction_ratio, num_channels, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        batch_size, num_channels, height, width = x.size()
        out = self.global_avg_pool(x)
        out = self.fc1(out)
        out = self.relu(out)
        out = self.fc2(out)
        weight = self.sigmoid(out)
        return x * weight

# Simplified Spatial Attention Module (Convolution-based)
class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        out = torch.cat([avg_out, max_out], dim=1)
        out = self.conv(out)
        return self.sigmoid(out)

# Hybrid Attention Block (Simplified)
class HybridAttentionBlock(nn.Module):
    def __init__(self, num_channels, reduction_ratio=16, spatial_kernel_size=7):
        super(HybridAttentionBlock, self).__init__()
        self.channel_attention = ChannelAttention(num_channels, reduction_ratio)
        self.spatial_attention = SpatialAttention(spatial_kernel_size)
        self.conv1 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1)

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.relu(out)
        out = self.conv2(out)
        out_ca = self.channel_attention(out)
        out_sa = self.spatial_attention(out)
        out = out_ca * out_sa
        return out + residual

# HAT-Inspired Network
class HATInspiredNet(nn.Module):
    def __init__(self, num_channels=3, num_filters=64, num_hab=10, scale_factor=4, reduction_ratio=16, spatial_kernel_size=7):
        super(HATInspiredNet, self).__init__()
        self.num_filters = num_filters
        self.scale_factor = scale_factor
        self.initial_conv = nn.Conv2d(num_channels, num_filters, kernel_size=3, padding=1)
        self.hab_blocks = nn.Sequential(*[
            HybridAttentionBlock(num_filters, reduction_ratio, spatial_kernel_size) for _ in range(num_hab)
        ])
        self.conv_after_hab = nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1)
        self.upsample = nn.Sequential(
            nn.Conv2d(num_filters, num_filters * (scale_factor ** 2), kernel_size=3, padding=1),
            nn.PixelShuffle(scale_factor)
        )
        self.final_conv = nn.Conv2d(num_filters, num_channels, kernel_size=3, padding=1)

    def forward(self, x):
        x = self.initial_conv(x)
        residual = x
        x = self.hab_blocks(x)
        x = self.conv_after_hab(x)
        x += residual
        x = self.upsample(x)
        x = self.final_conv(x)
        return x

# Residual Channel Attention Block (RCAB)
class ResidualChannelAttentionBlock(nn.Module):
    def __init__(self, num_channels, reduction_ratio=16):
        super(ResidualChannelAttentionBlock, self).__init__()
        self.conv1 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1)
        self.ca = ChannelAttention(num_channels, reduction_ratio)

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.ca(out)
        return out + residual

# Residual Group (RG)
class ResidualGroup(nn.Module):
    def __init__(self, num_channels, num_rcab, reduction_ratio=16):
        super(ResidualGroup, self).__init__()
        layers = [ResidualChannelAttentionBlock(num_channels, reduction_ratio) for _ in range(num_rcab)]
        self.rcabs = nn.Sequential(*layers)
        self.conv = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1)

    def forward(self, x):
        residual = x
        out = self.rcabs(x)
        out = self.conv(out)
        return out + residual

# RCAN-inspired Network
class RCANNet(nn.Module):
    def __init__(self, num_channels=3, num_filters=64, num_rg=10, num_rcab=20, scale_factor=4, reduction_ratio=16):
        super(RCANNet, self).__init__()
        self.num_filters = num_filters
        self.scale_factor = scale_factor
        self.initial_conv = nn.Conv2d(num_channels, num_filters, kernel_size=3, padding=1)
        self.residual_groups = nn.ModuleList([
            ResidualGroup(num_filters, num_rcab, reduction_ratio) for _ in range(num_rg)
        ])
        self.conv_after_rg = nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1)
        self.upsample = nn.Sequential(
            nn.Conv2d(num_filters, num_filters * (scale_factor ** 2), kernel_size=3, padding=1),
            nn.PixelShuffle(scale_factor)
        )
        self.final_conv = nn.Conv2d(num_filters, num_channels, kernel_size=3, padding=1)

    def forward(self, x):
        x = self.initial_conv(x)
        residual = x
        for rg in self.residual_groups:
            x = rg(x)
        x = self.conv_after_rg(x)
        x += residual
        x = self.upsample(x)
        x = self.final_conv(x)
        return x

# Simplified SwinIR Block (Convolutional Approximation)
class SwinIRBlock(nn.Module):
    def __init__(self, num_channels, window_size=8, reduction_ratio=4):
        super(SwinIRBlock, self).__init__()
        self.conv1 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1)
        self.ca = ChannelAttention(num_channels, reduction_ratio)

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.ca(out)
        return out + residual

# SwinIR-Inspired Network (Simplified)
class SwinIRInspiredNet(nn.Module):
    def __init__(self, num_channels=3, num_filters=64, num_blocks=10, scale_factor=4, window_size=8):
        super(SwinIRInspiredNet, self).__init__()
        self.num_filters = num_filters
        self.scale_factor = scale_factor
        self.initial_conv = nn.Conv2d(num_channels, num_filters, kernel_size=3, padding=1)
        self.swinir_blocks = nn.Sequential(*[
            SwinIRBlock(num_filters, window_size) for _ in range(num_blocks)
        ])
        self.conv_after_blocks = nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1)
        self.upsample = nn.Sequential(
            nn.Conv2d(num_filters, num_filters * (scale_factor ** 2), kernel_size=3, padding=1),
            nn.PixelShuffle(scale_factor)
        )
        self.final_conv = nn.Conv2d(num_filters, num_channels, kernel_size=3, padding=1)

    def forward(self, x):
        x = self.initial_conv(x)
        residual = x
        x = self.swinir_blocks(x)
        x = self.conv_after_blocks(x)
        x += residual
        x = self.upsample(x)
        x = self.final_conv(x)
        return x

# Function to calculate PSNR (Reused)
def calculate_psnr(img1, img2):
    img1 = img1.mul(255).byte().cpu().numpy()
    img2 = img2.mul(255).byte().cpu().numpy()
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0:
        return 100
    PIXEL_MAX = 255.0
    return 20 * np.log10(PIXEL_MAX / np.sqrt(mse))

#######################################
# Validation Function for Ensemble Model
#######################################

def validate_ensemble_model(dataloader, hat_model, rcan_model, swinir_model, device):
    hat_model.eval()
    rcan_model.eval()
    swinir_model.eval()
    total_psnr = 0.0
    total_ssim = 0.0
    num_batches = 0
    ssim_metric = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)

    with torch.no_grad():
        pbar = tqdm(dataloader, desc="Validation Ensemble")
        for batch in pbar:
            hr_images = batch['hr'].to(device)
            lr_images = batch['lr'].to(device)

            hat_output = hat_model(lr_images).clamp(0, 1)
            rcan_output = rcan_model(lr_images).clamp(0, 1)
            swinir_output = swinir_model(lr_images).clamp(0, 1)

            # Average the outputs of the three models
            avg_output = (hat_output + rcan_output + swinir_output) / 3

            # Resize HR images to match the output size
            resized_hr_images = F.interpolate(hr_images, size=avg_output.shape[2:], mode='bicubic', align_corners=False)

            # Calculate PSNR and SSIM for the entire batch
            current_batch_size = lr_images.size(0)
            for i in range(current_batch_size):
                total_psnr += calculate_psnr(avg_output[i].cpu().float(), resized_hr_images[i].cpu().float())
                total_ssim += ssim_metric(avg_output[i].unsqueeze(0), resized_hr_images[i].unsqueeze(0)).item()
            num_batches += current_batch_size

    if num_batches == 0:
        return -1, -1  # Return -1 to indicate no data was processed
    avg_psnr = total_psnr / num_batches
    avg_ssim = total_ssim / num_batches
    return avg_psnr, avg_ssim

#######################################
# Dataset Definitions (Reused)
#######################################

val_transform_hr = Compose([
    ToTensor(),
])
val_transform_lr = Compose([
    ToTensor(),
])

class ValDataset(Dataset):
    def __init__(self, hr_dir, scale=4, transform_hr=None, transform_lr=None, load_lr_directly=False, lr_dir=None):
        self.hr_dir = hr_dir
        self.scale = scale
        self.transform_hr = transform_hr
        self.transform_lr = transform_lr
        self.load_lr_directly = load_lr_directly
        self.lr_dir = lr_dir
        self.hr_images = sorted([os.path.join(hr_dir, f) for f in os.listdir(hr_dir) if f.endswith(('.png', '.jpg', '.jpeg'))],
                                key=lambda path: int(os.path.splitext(os.path.basename(path))[0]) if os.path.splitext(os.path.basename(path))[0].isdigit() else os.path.splitext(os.path.basename(path))[0])
        if load_lr_directly and lr_dir:
            self.lr_images = sorted([os.path.join(lr_dir, f) for f in os.listdir(lr_dir) if f.endswith(('.png', '.jpg', '.jpeg'))],
                                    key=lambda path: int(os.path.splitext(os.path.basename(path))[0]) if os.path.splitext(os.path.basename(path))[0].isdigit() else os.path.splitext(os.path.basename(path))[0])
            assert len(self.hr_images) == len(self.lr_images), "Number of HR and LR images mismatch in direct load mode."

    def __len__(self):
        return len(self.hr_images)

    def __getitem__(self, idx):
        hr_path = self.hr_images[idx]
        hr_img = Image.open(hr_path).convert('RGB')

        if self.transform_hr:
            hr = self.transform_hr(hr_img)
        else:
            hr = ToTensor()(hr_img)

        if self.load_lr_directly and self.lr_dir:
            lr_path = self.lr_images[idx]
            lr_img = Image.open(lr_path).convert('RGB')
            if self.transform_lr:
                lr = self.transform_lr(lr_img)
            else:
                lr = ToTensor()(lr_img)
        else:
            width, height = hr_img.size
            lr_width, lr_height = width // self.scale, height // self.scale
            lr_img = hr_img.resize((lr_width, lr_height), Image.BICUBIC)
            if self.transform_lr:
                lr = self.transform_lr(lr_img)
            else:
                lr = ToTensor()(lr_img)

        return {'hr': hr, 'lr': lr}

#######################################
# Main Function for Ensemble Validation
#######################################

if __name__ == '__main__':
    # Hyperparameters
    batch_size = 1 # Using batch size 1 for validation
    scale_factor = 4
    num_filters = 64
    num_hab = 10
    reduction_ratio_hat = 16
    spatial_kernel_size = 7
    num_rg = 10
    num_rcab = 20
    reduction_ratio_rcan = 16
    num_blocks_swinir = 10
    window_size_swinir = 8
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pretrained_model_path = "/kaggle/input/ensemble_model_best/pytorch/default/1/best_ensemble_model (2).pth" # Path to the saved ensemble checkpoint
    root_dir = "/kaggle/input"

    # Initialize individual models
    hat_model = HATInspiredNet(num_channels=3, num_filters=num_filters, num_hab=num_hab, scale_factor=scale_factor, reduction_ratio=reduction_ratio_hat, spatial_kernel_size=spatial_kernel_size)
    rcan_model = RCANNet(num_channels=3, num_filters=num_filters, num_rg=num_rg, num_rcab=num_rcab, scale_factor=scale_factor, reduction_ratio=reduction_ratio_rcan)
    swinir_model = SwinIRInspiredNet(num_channels=3, num_filters=num_filters, num_blocks=num_blocks_swinir, scale_factor=scale_factor, window_size=window_size_swinir)

    # Load pretrained ensemble model weights
    if os.path.exists(pretrained_model_path):
        checkpoint = torch.load(pretrained_model_path, map_location=device)
        hat_model.load_state_dict(checkpoint['hat_state_dict'])
        rcan_model.load_state_dict(checkpoint['rcan_state_dict'])
        swinir_model.load_state_dict(checkpoint['swinir_state_dict'])
        hat_model.to(device)
        rcan_model.to(device)
        swinir_model.to(device)
        print(f"Loaded pretrained ensemble model from: {pretrained_model_path}")
    else:
        print(f"Pretrained ensemble model not found at: {pretrained_model_path}")
        exit()

    # Define validation datasets and dataloaders
    val_datasets = {}

    # DIV2K (HR only, will downsample)
    div2k_val_hr_directory = os.path.join(root_dir, "div2k-dataset/DIV2K_valid_HR/DIV2K_valid_HR")
    div2k_dataset = ValDataset(hr_dir=div2k_val_hr_directory, scale=scale_factor, transform_hr=val_transform_hr)
    print(f"DIV2K Dataset size: {len(div2k_dataset)}")
    val_datasets['DIV2K'] = DataLoader(div2k_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

    # BSD100 (HR only, will downsample)
    bsd100_hr_dir = os.path.join(root_dir, "bsd100/bsd100/bicubic_4x/train/HR")
    bsd100_dataset = ValDataset(hr_dir=bsd100_hr_dir, scale=scale_factor, transform_hr=val_transform_hr)
    print(f"BSD100 Dataset size: {len(bsd100_dataset)}")
    val_datasets['BSD100'] = DataLoader(bsd100_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

    # Set14 (HR only, will downsample)
    set14_hr_dir = os.path.join(root_dir, "set-5-14-super-resolution-dataset/Set14/Set14")
    set14_dataset = ValDataset(hr_dir=set14_hr_dir, scale=scale_factor, transform_hr=val_transform_hr)
    print(f"Set14 Dataset size: {len(set14_dataset)}")
    val_datasets['Set14'] = DataLoader(set14_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

    # Set5 (HR only, will downsample)
    set5_hr_dir = os.path.join(root_dir, "set-5-14-super-resolution-dataset/Set5/Set5")
    set5_dataset = ValDataset(hr_dir=set5_hr_dir, scale=scale_factor, transform_hr=val_transform_hr)
    print(f"Set5 Dataset size: {len(set5_dataset)}")
    val_datasets['Set5'] = DataLoader(set5_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

    # Urban100 (HR only, will downsample)
    urban100_hr_dir = os.path.join(root_dir, "urban100/Urban 100/X4 Urban100/X4/HIGH x4 URban100")
    urban100_dataset = ValDataset(hr_dir=urban100_hr_dir, scale=scale_factor, transform_hr=val_transform_hr)
    print(f"Urban100 Dataset size: {len(urban100_dataset)}")
    val_datasets['Urban100'] = DataLoader(urban100_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

    # Validate the ensemble model on all datasets
    print("--- Ensemble Model Validation Results ---")
    for dataset_name, dataloader in val_datasets.items():
        avg_psnr, avg_ssim = validate_ensemble_model(dataloader, hat_model, rcan_model, swinir_model, device)
        if avg_psnr != -1:
            print(f"Dataset: {dataset_name}, Avg PSNR: {avg_psnr:.4f} dB, Avg SSIM: {avg_ssim:.4f}")
        else:
            print(f"Dataset: {dataset_name}, No data found for validation.")

    print("Ensemble validation finished!")

  checkpoint = torch.load(pretrained_model_path, map_location=device)


Loaded pretrained ensemble model from: /kaggle/input/ensemble_model_best/pytorch/default/1/best_ensemble_model (2).pth
DIV2K Dataset size: 100
BSD100 Dataset size: 80
Set14 Dataset size: 14
Set5 Dataset size: 5
Urban100 Dataset size: 100
--- Ensemble Model Validation Results ---


Validation Ensemble: 100%|██████████| 100/100 [04:05<00:00,  2.45s/it]


Dataset: DIV2K, Avg PSNR: 33.2126 dB, Avg SSIM: 0.8099


Validation Ensemble: 100%|██████████| 80/80 [00:08<00:00,  9.08it/s]


Dataset: BSD100, Avg PSNR: 31.6892 dB, Avg SSIM: 0.7309


Validation Ensemble: 100%|██████████| 14/14 [00:02<00:00,  5.16it/s]


Dataset: Set14, Avg PSNR: 31.6018 dB, Avg SSIM: 0.7385


Validation Ensemble: 100%|██████████| 5/5 [00:00<00:00,  7.18it/s]


Dataset: Set5, Avg PSNR: 32.3534 dB, Avg SSIM: 0.8446


Validation Ensemble: 100%|██████████| 100/100 [01:02<00:00,  1.61it/s]

Dataset: Urban100, Avg PSNR: 31.2291 dB, Avg SSIM: 0.7383
Ensemble validation finished!



