In [2]:
pip install tabulate

Collecting tabulate
  Downloading tabulate-0.9.0-py3-none-any.whl.metadata (34 kB)
Downloading tabulate-0.9.0-py3-none-any.whl (35 kB)
Installing collected packages: tabulate
Successfully installed tabulate-0.9.0
[0mNote: you may need to restart the kernel to use updated packages.


In [38]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision.transforms as transforms
import skimage.metrics
from tabulate import tabulate
import matplotlib.pyplot as plt

# --- Model Components ---

class LayerNorm2d(nn.Module):
    """2D Layer Normalization module."""
    def __init__(self, num_features):
        super().__init__()
        self.norm = nn.LayerNorm(num_features)

    def forward(self, x):
        x = x.permute(0, 2, 3, 1)
        x = self.norm(x)
        x = x.permute(0, 3, 1, 2)
        return x

class SpatialResidualModule(nn.Module):
    """Spatial attention-based residual module."""
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.PReLU(),
            nn.Conv2d(channels, channels, 3, padding=1)
        )
        self.spatial_att = nn.Sequential(
            nn.Conv2d(channels, 1, kernel_size=1),
            nn.Sigmoid()
        )
        self.scale = nn.Parameter(torch.FloatTensor([0.1]))

    def forward(self, x):
        residual = x
        out = self.conv(x)
        att = self.spatial_att(out)
        return residual + self.scale * (out * att)

class EnhancedResidualBlock(nn.Module):
    """Enhanced residual block with spatial residual module."""
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.norm1 = LayerNorm2d(channels)
        self.norm2 = LayerNorm2d(channels)
        self.act = nn.PReLU()
        self.spatial_residual = SpatialResidualModule(channels)
        self.scale = nn.Parameter(torch.FloatTensor([0.1]))

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.norm1(out)
        out = self.act(out)
        out = self.conv2(out)
        out = self.norm2(out)
        out = self.spatial_residual(out)
        return residual + self.scale * out

class EnhancedResidualGroup(nn.Module):
    """Group of enhanced residual blocks."""
    def __init__(self, channels, n_blocks):
        super().__init__()
        blocks = [EnhancedResidualBlock(channels) for _ in range(n_blocks)]
        self.body = nn.Sequential(*blocks)
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)
        self.scale = nn.Parameter(torch.FloatTensor([0.1]))

    def forward(self, x):
        residual = x
        out = self.body(x)
        out = self.conv(out)
        return residual + self.scale * out

class EnhancedESPCN(nn.Module):
    """Enhanced Efficient Sub-Pixel Convolutional Neural Network."""
    def __init__(self, in_channels, scale_factor=2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, 128, kernel_size=3, padding=1),
            nn.PReLU(),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.PReLU(),
            nn.Conv2d(64, 3 * (scale_factor ** 2), kernel_size=3, padding=1),
            nn.PixelShuffle(scale_factor)
        )
    
    def forward(self, x):
        return self.net(x)

# --- Main Model ---

class UltraEnhancedSR(nn.Module):
    """Ultra-enhanced super-resolution model."""
    def __init__(self, scale=2):
        super().__init__()
        self.head = nn.Sequential(
            nn.Conv2d(3, 128, 3, padding=1),
            nn.PReLU()
        )
        self.body = nn.ModuleList([
            EnhancedResidualGroup(128, 10) for _ in range(5)
        ])
        self.global_residual = nn.Conv2d(128, 128, 3, padding=1)
        self.upscale = EnhancedESPCN(128, scale)
        self.direct_path = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1),
            nn.PReLU(),
            nn.Conv2d(16, 3 * (scale ** 2), 3, padding=1),
            nn.PixelShuffle(scale)
        )
        self.refine = nn.Sequential(
            nn.Conv2d(6, 32, 3, padding=1),
            nn.PReLU(),
            nn.Conv2d(32, 3, 3, padding=1)
        )

    def forward(self, x):
        direct = self.direct_path(x)
        shallow = self.head(x)
        deep = shallow
        for block in self.body:
            deep = block(deep)
        global_res = self.global_residual(deep)
        fused = shallow + global_res
        upscaled = self.upscale(fused)
        combined = torch.cat([direct, upscaled], dim=1)
        return self.refine(combined)

# --- Utilities ---

class SRValidationDataset(Dataset):
    """Dataset for super-resolution validation."""
    def __init__(self, base_dir, scale=2):
        lr_dir = os.path.join(base_dir, "LR")
        hr_dir = os.path.join(base_dir, "HR")
        
        if not os.path.exists(lr_dir) or not os.path.exists(hr_dir):
            raise ValueError(f"LR or HR directory not found in {base_dir}")
        
        lr_paths = sorted([os.path.join(lr_dir, f) for f in os.listdir(lr_dir) if f.endswith(('.png', '.jpg', '.jpeg'))])
        hr_paths = sorted([os.path.join(hr_dir, f) for f in os.listdir(hr_dir) if f.endswith(('.png', '.jpg', '.jpeg'))])
        
        self.pairs = []
        for lr_path in lr_paths:
            lr_fname = os.path.basename(lr_path)
            hr_path = os.path.join(hr_dir, lr_fname)
            if hr_path in hr_paths:
                try:
                    lr_img = Image.open(lr_path).convert('RGB')
                    hr_img = Image.open(hr_path).convert('RGB')
                    if lr_img.size[0] < 7 or lr_img.size[1] < 7 or hr_img.size[0] < 7 or hr_img.size[1] < 7:
                        print(f"Skipping {lr_fname}: Image too small (LR: {lr_img.size}, HR: {hr_img.size})")
                        continue
                    self.pairs.append((lr_path, hr_path))
                except Exception as e:
                    print(f"Skipping {lr_fname}: Failed to load (Error: {e})")
                    continue
        
        if not self.pairs:
            raise ValueError(f"No valid LR-HR image pairs found in {base_dir}.")
        
        self.scale = scale
        self.transform = transforms.ToTensor()

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        lr_path, hr_path = self.pairs[idx]
        try:
            lr = Image.open(lr_path).convert('RGB').resize((512, 512), Image.BICUBIC)
            hr = Image.open(hr_path).convert('RGB').resize((1024, 1024), Image.BICUBIC)
            lr = self.transform(lr)
            hr = self.transform(hr)
            return lr, hr, os.path.basename(lr_path)
        except Exception as e:
            print(f"Error processing {lr_path}: {e}")
            return None, None, "error"

def calculate_psnr(img1, img2):
    """Calculate PSNR between two images."""
    img1_np = img1.squeeze(0).detach().cpu().clamp(0, 1).permute(1, 2, 0).numpy()
    img2_np = img2.squeeze(0).detach().cpu().permute(1, 2, 0).numpy()
    psnr = skimage.metrics.peak_signal_noise_ratio(img2_np, img1_np, data_range=1.0)
    return psnr

def save_comparison(lr, sr, hr, filename, output_dir):
    """Save comparison of LR, SR, and HR images."""
    os.makedirs(output_dir, exist_ok=True)
    lr_img = lr.squeeze(0).cpu().permute(1, 2, 0).numpy().astype(np.float32)
    sr_img = sr.squeeze(0).cpu().clamp(0, 1).permute(1, 2, 0).numpy().astype(np.float32)
    hr_img = hr.squeeze(0).cpu().permute(1, 2, 0).numpy().astype(np.float32)
    
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    axes[0].imshow(lr_img)
    axes[0].set_title(f'Low Resolution\n{lr_img.shape[1]}x{lr_img.shape[0]}')
    axes[0].axis('off')
    axes[1].imshow(sr_img)
    axes[1].set_title(f'Super Resolution\n{sr_img.shape[1]}x{sr_img.shape[0]}')
    axes[1].axis('off')
    axes[2].imshow(hr_img)
    axes[2].set_title(f'High Resolution\n{hr_img.shape[1]}x{hr_img.shape[0]}')
    axes[2].axis('off')
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, f"{filename}.png"), dpi=300, bbox_inches='tight')
    plt.close()

def save_super_resolved(sr, filename, output_dir):
    """Save super-resolved image."""
    os.makedirs(output_dir, exist_ok=True)
    sr_img = sr.squeeze(0).cpu().clamp(0, 1).permute(1, 2, 0).numpy().astype(np.float32)
    sr_pil = Image.fromarray((sr_img * 255).astype(np.uint8))
    sr_pil.save(os.path.join(output_dir, f"{filename}_sr.png"))

def validate_model(model, data_loader, output_dir, sr_output_dir):
    """Validate the super-resolution model on the dataset."""
    model.eval()
    results = []
    best_sr_psnr = 0
    best_lr_psnr = 0
    best_image = None
    
    with torch.no_grad():
        for lr, hr, fname in data_loader:
            if lr is None or hr is None:
                print(f"Skipping invalid image: {fname[0]}")
                continue
            lr, hr = lr.cuda(), hr.cuda()
            sr = model(lr)
            
            sr_psnr = calculate_psnr(sr, hr)
            lr_upscaled = torch.nn.functional.interpolate(lr, size=(1024, 1024), mode='bilinear', align_corners=False)
            lr_psnr = calculate_psnr(lr_upscaled, hr)
            
            fname_base = fname[0].split('.')[0]
            
            if sr_psnr > 30 and sr_psnr > lr_psnr:
                results.append([fname_base, f"{lr_psnr:.2f}", f"{sr_psnr:.2f}"])
            
            if sr_psnr > lr_psnr and sr_psnr > best_sr_psnr:
                best_sr_psnr = sr_psnr
                best_lr_psnr = lr_psnr
                best_image = fname_base
            
            save_super_resolved(sr, fname_base, sr_output_dir)
            save_comparison(lr, sr, hr, fname_base, output_dir)
    
    if results:
        print("\nBSD100 Evaluation Results:")
        print(tabulate(
            results,
            headers=['Image ID', 'PSNR Baseline (dB)', 'PSNR SR (dB)'],
            tablefmt='grid'
        ))
    
    print(f"\nBest Metrics (Image: {best_image}):")
    print(f"Best Baseline PSNR: {best_lr_psnr:.2f} dB")
    print(f"Best SR PSNR: {best_sr_psnr:.2f} dB")
    
    return results, best_sr_psnr, best_lr_psnr, best_image

# --- Main Execution ---

def main():
    """Main function to validate the super-resolution model on BSD100."""
    checkpoint_dir = "checkpoints"
    bsd100_dir = "bsd100"
    results_dir = "results"
    bsd100_output_dir = os.path.join(results_dir, "bsd100_results")
    bsd100_sr_output_dir = os.path.join(results_dir, "super_resolved_bsd100")

    os.makedirs(results_dir, exist_ok=True)

    best_model_path = None
    best_psnr = 0
    for f in os.listdir(checkpoint_dir):
        if f.startswith("best_model_epoch") and f.endswith(".pth"):
            psnr_str = f.split("_psnr_")[1].split(".pth")[0]
            psnr = float(psnr_str)
            if psnr > best_psnr:
                best_psnr = psnr
                best_model_path = os.path.join(checkpoint_dir, f)
    
    if best_model_path is None:
        raise FileNotFoundError("No best model found in checkpoints directory")

    model = UltraEnhancedSR(scale=2).cuda()
    model.load_state_dict(torch.load(best_model_path))
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    model.eval()

    dataset = SRValidationDataset(bsd100_dir, scale=2)
    data_loader = DataLoader(dataset, batch_size=1, shuffle=False)
    validate_model(model, data_loader, bsd100_output_dir, bsd100_sr_output_dir)

if __name__ == "__main__":
    torch.manual_seed(42)
    np.random.seed(42)
    torch.cuda.empty_cache()
    main()


BSD100 Evaluation Results:
+------------+----------------------+----------------+
|   Image ID |   PSNR Baseline (dB) |   PSNR SR (dB) |
|     101087 |                29.9  |          30    |
+------------+----------------------+----------------+
|     103070 |                32.91 |          33.04 |
+------------+----------------------+----------------+
|     108082 |                29.76 |          30.04 |
+------------+----------------------+----------------+
|     109053 |                31.62 |          31.76 |
+------------+----------------------+----------------+
|     123074 |                32.9  |          32.96 |
+------------+----------------------+----------------+
|     126007 |                30.78 |          30.89 |
+------------+----------------------+----------------+
|     159008 |                29.81 |          30    |
+------------+----------------------+----------------+
|      16077 |                29.97 |          30.11 |
+------------+----------------------+