In [None]:
import os
import glob
import math
import numpy as np
from tqdm import tqdm
from PIL import Image
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision.transforms.functional import to_tensor, to_pil_image

# Custom dataset for super-resolution training
class SRDataset(Dataset):
    """Paired low-/high-resolution image dataset for super-resolution.

    Every ``.png``, ``.jpg`` and ``.jpeg`` file found directly inside
    ``image_dir`` is used as a high-resolution (HR) ground-truth image. The
    low-resolution (LR) input is generated on the fly by bicubic downsampling.

    * Training mode yields a random square HR patch and its LR counterpart,
      with the same random flips / 90-degree rotations applied to both.
    * Evaluation mode yields the whole image, cropped so that both sides are
      exact multiples of ``scale_factor``.

    Args:
        image_dir: Directory that holds the HR images.
        scale_factor: Integer downsampling factor between HR and LR.
        patch_size: Requested HR patch side length for training. It is rounded
            down to the nearest multiple of ``scale_factor`` so that the LR
            patch times ``scale_factor`` matches the HR patch exactly.
        is_training: Select random-patch mode (True) or full-image mode (False).

    Raises:
        ValueError: If ``scale_factor`` is smaller than 1, or if in training
            mode ``patch_size`` is smaller than ``scale_factor``.
    """

    def __init__(self, image_dir, scale_factor, patch_size=96, is_training=True):
        if scale_factor < 1:
            raise ValueError(f"scale_factor must be >= 1, got {scale_factor}")
        if is_training and patch_size < scale_factor:
            raise ValueError(
                f"patch_size ({patch_size}) must be >= scale_factor ({scale_factor})"
            )

        # glob order depends on the filesystem; sorting keeps the
        # index -> file mapping reproducible between runs and machines.
        self.image_paths = sorted(
            glob.glob(os.path.join(image_dir, "*.png"))
            + glob.glob(os.path.join(image_dir, "*.jpg"))
            + glob.glob(os.path.join(image_dir, "*.jpeg"))
        )
        self.scale_factor = scale_factor
        self.patch_size = patch_size
        self.is_training = is_training

        # Both modes only need the PIL -> tensor conversion; cropping and
        # augmentation are done on the PIL images in __getitem__.
        self.lr_transforms = transforms.Compose([transforms.ToTensor()])
        self.hr_transforms = transforms.Compose([transforms.ToTensor()])

    def __len__(self):
        """Return the number of images found in the directory."""
        return len(self.image_paths)

    @staticmethod
    def _random_crop_box(width, height, crop_size):
        """Return a random (left, top, right, bottom) square box inside the image."""
        # randint's upper bound is exclusive: the +1 keeps the last valid
        # offset reachable and makes an image that is exactly crop_size
        # wide/high legal (randint(0, 0) raises ValueError).
        left = np.random.randint(0, width - crop_size + 1)
        top = np.random.randint(0, height - crop_size + 1)
        return left, top, left + crop_size, top + crop_size

    @staticmethod
    def _modcrop_size(width, height, scale_factor):
        """Return (width, height) rounded down to multiples of scale_factor."""
        return width - width % scale_factor, height - height % scale_factor

    def _make_training_pair(self, hr_img):
        """Cut a random augmented HR patch and build the matching LR patch."""
        lr_patch_size = self.patch_size // self.scale_factor
        # Round the HR patch down to a multiple of the scale so that
        # lr_patch_size * scale_factor == hr_patch_size.
        hr_patch_size = lr_patch_size * self.scale_factor

        hr_width, hr_height = hr_img.size
        if hr_width < hr_patch_size or hr_height < hr_patch_size:
            # Enlarge only the side(s) that are too short for one patch.
            # transforms.Resize takes (height, width).
            hr_img = transforms.Resize((max(hr_patch_size, hr_height),
                                        max(hr_patch_size, hr_width)))(hr_img)
            hr_width, hr_height = hr_img.size

        hr_img = hr_img.crop(self._random_crop_box(hr_width, hr_height, hr_patch_size))

        # Create LR image through downsampling
        lr_img = hr_img.resize((lr_patch_size, lr_patch_size), Image.BICUBIC)

        # Apply the same random flips and rotation to both images
        if np.random.random() > 0.5:
            hr_img = hr_img.transpose(Image.FLIP_LEFT_RIGHT)
            lr_img = lr_img.transpose(Image.FLIP_LEFT_RIGHT)
        if np.random.random() > 0.5:
            hr_img = hr_img.transpose(Image.FLIP_TOP_BOTTOM)
            lr_img = lr_img.transpose(Image.FLIP_TOP_BOTTOM)

        rotation = int(np.random.choice([0, 90, 180, 270]))
        if rotation > 0:
            # Patches are square, so a multiple-of-90 rotation keeps the size.
            hr_img = hr_img.rotate(rotation)
            lr_img = lr_img.rotate(rotation)

        return lr_img, hr_img

    def _make_eval_pair(self, hr_img):
        """Crop the full image to a multiple of the scale and downsample it."""
        hr_width, hr_height = self._modcrop_size(
            hr_img.width, hr_img.height, self.scale_factor
        )
        # Without this crop an image whose sides are not divisible by the
        # scale would give an HR tensor larger than LR * scale_factor.
        hr_img = hr_img.crop((0, 0, hr_width, hr_height))
        lr_img = hr_img.resize(
            (hr_width // self.scale_factor, hr_height // self.scale_factor),
            Image.BICUBIC,
        )
        return lr_img, hr_img

    def __getitem__(self, idx):
        """Load image ``idx`` and return ``{"lr": tensor, "hr": tensor}``."""
        # convert() loads the pixel data into a new image, so the file handle
        # opened by Image.open() can be released straight away.
        with Image.open(self.image_paths[idx]) as source_img:
            hr_img = source_img.convert('RGB')

        if self.is_training:
            lr_img, hr_img = self._make_training_pair(hr_img)
        else:
            lr_img, hr_img = self._make_eval_pair(hr_img)

        # Convert to tensor
        lr_tensor = self.lr_transforms(lr_img)
        hr_tensor = self.hr_transforms(hr_img)

        return {"lr": lr_tensor, "hr": hr_tensor}

def train_espcn(model, train_dir, val_dir, scale_factor, batch_size=16, patch_size=96, num_epochs=100, 
               learning_rate=1e-3, device='cuda', save_path='best_model.pth', num_workers=4):
    """
    Train the ESPCN model and keep the weights with the best validation PSNR.
    
    Each epoch runs one pass over random training patches (MSE loss, Adam),
    then evaluates on the full validation images. The learning rate is halved
    when the validation loss has not improved for 5 epochs.
    
    Args:
        model: ESPCN model instance
        train_dir: Directory containing training images
        val_dir: Directory containing validation images
        scale_factor: Super-resolution scale factor
        batch_size: Training batch size
        patch_size: Size of HR image patches for training
        num_epochs: Number of training epochs
        learning_rate: Learning rate
        device: Device for training ('cuda', 'cuda:N' or 'cpu'); falls back to
            CPU when CUDA is requested but unavailable
        save_path: Path to save the best model weights
        num_workers: Worker processes for the training DataLoader; pass 0 to
            load data in the main process (validation uses at most 1 worker)
    
    Raises:
        ValueError: If the training or validation directory contains no images
    """
    def psnr_from_mse(mse):
        # PSNR for signals in [0, 1]. math.log10(0) raises, and a zero error
        # means a perfect reconstruction, so report infinity for it.
        return float('inf') if mse == 0 else -10.0 * math.log10(mse)
    
    # Set device: honour any CUDA device string (e.g. 'cuda:1') when CUDA is
    # usable, otherwise fall back to the CPU.
    use_cuda = str(device).startswith('cuda') and torch.cuda.is_available()
    device = torch.device(device if use_cuda else 'cpu')
    model = model.to(device)
    
    # Create datasets and dataloaders
    train_dataset = SRDataset(train_dir, scale_factor=scale_factor, patch_size=patch_size, is_training=True)
    val_dataset = SRDataset(val_dir, scale_factor=scale_factor, is_training=False)
    
    # Fail before training starts rather than with a ZeroDivisionError after
    # the first epoch.
    if len(train_dataset) == 0:
        raise ValueError(f"No training images found in {train_dir!r}")
    if len(val_dataset) == 0:
        raise ValueError(f"No validation images found in {val_dir!r}")
    
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    # Validation images can differ in size, so they are processed one at a time.
    val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=min(1, num_workers))
    
    # Loss function and optimizer
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # The scheduler's `verbose` flag is deprecated; LR changes are reported
    # manually after each scheduler step below.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
    
    # Start below every reachable PSNR so that the first epoch always writes
    # a checkpoint, even when the validation PSNR is 0 dB or negative.
    best_psnr = float('-inf')
    
    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0.0
        train_psnr = 0.0
        
        with tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{num_epochs}") as pbar:
            for batch in pbar:
                lr_imgs = batch["lr"].to(device)
                hr_imgs = batch["hr"].to(device)
                
                # Forward pass
                optimizer.zero_grad()
                sr_imgs = model(lr_imgs)
                
                # Compute loss
                loss = criterion(sr_imgs, hr_imgs)
                
                # Backward pass
                loss.backward()
                optimizer.step()
                
                # Read the loss back once and derive the PSNR from it
                loss_value = loss.item()
                batch_psnr = psnr_from_mse(loss_value)
                
                train_loss += loss_value
                train_psnr += batch_psnr
                
                pbar.set_postfix({"loss": loss_value, "PSNR": batch_psnr})
        
        avg_train_loss = train_loss / len(train_dataloader)
        avg_train_psnr = train_psnr / len(train_dataloader)
        
        # Validation phase
        model.eval()
        val_loss = 0.0
        val_psnr = 0.0
        
        with torch.no_grad():
            for batch in val_dataloader:
                lr_imgs = batch["lr"].to(device)
                hr_imgs = batch["hr"].to(device)
                
                sr_imgs = model(lr_imgs)
                loss_value = criterion(sr_imgs, hr_imgs).item()
                
                val_loss += loss_value
                val_psnr += psnr_from_mse(loss_value)
        
        avg_val_loss = val_loss / len(val_dataloader)
        avg_val_psnr = val_psnr / len(val_dataloader)
        
        # Update learning rate and report a reduction if one happened
        previous_lr = optimizer.param_groups[0]['lr']
        scheduler.step(avg_val_loss)
        current_lr = optimizer.param_groups[0]['lr']
        if current_lr < previous_lr:
            print(f"Reducing learning rate to {current_lr:.4e}")
        
        # Save the best model
        if avg_val_psnr > best_psnr:
            best_psnr = avg_val_psnr
            torch.save(model.state_dict(), save_path)
            print(f"Best model saved with PSNR: {best_psnr:.2f} dB")
        
        print(f"Epoch {epoch+1}/{num_epochs}:")
        print(f"Train Loss: {avg_train_loss:.6f}, Train PSNR: {avg_train_psnr:.2f} dB")
        print(f"Val Loss: {avg_val_loss:.6f}, Val PSNR: {avg_val_psnr:.2f} dB")
        print("-" * 50)

def calculate_psnr(img1, img2, max_val=1.0):
    """Calculate the peak signal-to-noise ratio (in dB) between two images.

    Args:
        img1: First image tensor.
        img2: Second image tensor, broadcastable against ``img1``.
        max_val: Largest possible pixel value (1.0 for images scaled to
            [0, 1], 255 for 8-bit images).

    Returns:
        A 0-dim tensor holding ``10 * log10(max_val**2 / MSE)``. Identical
        images give ``inf``, also as a tensor, so the return type does not
        depend on the data.
    """
    # Integer images (e.g. uint8) would wrap around on subtraction and cannot
    # be averaged, so promote them to floating point first.
    if not torch.is_floating_point(img1):
        img1 = img1.float()
    if not torch.is_floating_point(img2):
        img2 = img2.float()

    mse = torch.mean((img1 - img2) ** 2)
    # Tensor arithmetic maps x / 0 to inf and log10(inf) to inf, so a zero
    # MSE needs no special case (and no host-side comparison).
    return 10 * torch.log10((max_val ** 2) / mse)

def calculate_ssim(img1, img2, data_range=1.0, window_size=11):
    """Calculate the mean SSIM between two images using a uniform window.

    Local means, variances and the covariance are estimated with a
    ``window_size`` x ``window_size`` average pool (stride 1, zero padding of
    ``window_size // 2``), and the resulting SSIM map is averaged over all
    positions and channels.

    Args:
        img1: First image tensor, unbatched (a batch axis is added) or
            already batched as a 4-D tensor.
        img2: Second image tensor with the same shape as ``img1``.
        data_range: Difference between the largest and smallest possible
            pixel value (1.0 for [0, 1] images, 255 for 8-bit images).
        window_size: Side length of the averaging window; use an odd value so
            that the pooled maps keep the input size.

    Returns:
        The mean SSIM as a Python float.

    Raises:
        ValueError: If ``window_size`` is smaller than 1.
    """
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")

    # Stabilising constants of the SSIM formula, scaled to the data range.
    C1 = (0.01 * data_range) ** 2
    C2 = (0.03 * data_range) ** 2

    def as_float_batch(img):
        # Average pooling needs a floating-point tensor; unbatched input gets
        # a leading batch axis, batched (4-D) input is used as is.
        if not torch.is_floating_point(img):
            img = img.float()
        return img.unsqueeze(0) if img.dim() < 4 else img

    def local_mean(x):
        return torch.nn.functional.avg_pool2d(
            x, kernel_size=window_size, stride=1, padding=window_size // 2
        )

    img1 = as_float_batch(img1)
    img2 = as_float_batch(img2)

    mu1 = local_mean(img1)
    mu2 = local_mean(img2)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    # Var[x] = E[x^2] - E[x]^2 and Cov[x, y] = E[xy] - E[x]E[y]
    sigma1_sq = local_mean(img1 * img1) - mu1_sq
    sigma2_sq = local_mean(img2 * img2) - mu2_sq
    sigma12 = local_mean(img1 * img2) - mu1_mu2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
    return ssim_map.mean().item()

def test_espcn(model, test_dir, output_dir, scale_factor, device='cuda'):
    """
    Test the ESPCN model on a directory of test images
    
    For every image the ground truth is cropped to a multiple of
    ``scale_factor`` and downsampled bicubically. The result is then
    super-resolved by the model and, as a baseline, upsampled bicubically.
    PSNR and SSIM of both results are printed, and the SR, bicubic and LR
    images are written to ``output_dir``.
    
    Args:
        model: Trained ESPCN model instance
        test_dir: Directory containing test images
        output_dir: Directory to save super-resolution results
        scale_factor: Super-resolution scale factor
        device: Device for inference ('cuda', 'cuda:N' or 'cpu'); falls back
            to CPU when CUDA is requested but unavailable
    """
    # Set device: honour any CUDA device string (e.g. 'cuda:1') when CUDA is
    # usable, otherwise fall back to the CPU.
    use_cuda = str(device).startswith('cuda') and torch.cuda.is_available()
    device = torch.device(device if use_cuda else 'cpu')
    model = model.to(device)
    model.eval()
    
    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)
    
    # Get all image paths (sorted so that the report order is reproducible)
    image_paths = sorted(
        glob.glob(os.path.join(test_dir, "*.png"))
        + glob.glob(os.path.join(test_dir, "*.jpg"))
        + glob.glob(os.path.join(test_dir, "*.jpeg"))
    )
    if not image_paths:
        # Nothing to evaluate; avoid dividing by zero in the averages below.
        print(f"No test images found in {test_dir}")
        return
    
    # Metrics
    total_psnr = 0.0
    total_ssim = 0.0
    num_evaluated = 0
    
    # Process each image
    for img_path in tqdm(image_paths, desc="Testing"):
        # Load image; convert() copies the pixels so the file can be closed
        img_name = os.path.basename(img_path)
        with Image.open(img_path) as source_img:
            hr_img = source_img.convert('RGB')
        
        lr_width = hr_img.width // scale_factor
        lr_height = hr_img.height // scale_factor
        if lr_width == 0 or lr_height == 0:
            print(f"Skipping {img_name}: image is smaller than the scale factor")
            continue
        
        # Crop the ground truth to a multiple of the scale factor so that it
        # has the same size as an LR image upscaled by scale_factor.
        hr_width = lr_width * scale_factor
        hr_height = lr_height * scale_factor
        hr_img = hr_img.crop((0, 0, hr_width, hr_height))
        
        # Create bicubic upscaled image for comparison
        lr_img = hr_img.resize((lr_width, lr_height), Image.BICUBIC)
        bicubic_img = lr_img.resize((hr_width, hr_height), Image.BICUBIC)
        
        # Convert to tensor
        lr_tensor = to_tensor(lr_img).unsqueeze(0).to(device)
        hr_tensor = to_tensor(hr_img).to(device)
        bicubic_tensor = to_tensor(bicubic_img).to(device)
        
        # Generate SR image
        # NOTE(review): assumes the model output is scale_factor times the LR
        # size, matching hr_tensor - confirm against the model definition.
        with torch.no_grad():
            sr_tensor = model(lr_tensor).squeeze(0).clamp(0.0, 1.0)
        
        # Calculate metrics as plain floats so that sums and formatting do
        # not depend on whether a tensor or a float comes back.
        psnr_val = float(calculate_psnr(sr_tensor, hr_tensor))
        ssim_val = float(calculate_ssim(sr_tensor, hr_tensor))
        bicubic_psnr = float(calculate_psnr(bicubic_tensor, hr_tensor))
        bicubic_ssim = float(calculate_ssim(bicubic_tensor, hr_tensor))
        
        total_psnr += psnr_val
        total_ssim += ssim_val
        num_evaluated += 1
        
        # Save SR image
        sr_img = to_pil_image(sr_tensor.cpu())
        sr_img.save(os.path.join(output_dir, f"SR_{img_name}"))
        
        # Save bicubic image for comparison
        bicubic_img.save(os.path.join(output_dir, f"Bicubic_{img_name}"))
        
        # Save LR image
        lr_img.save(os.path.join(output_dir, f"LR_{img_name}"))
        
        # Print metrics for this image
        print(f"Image: {img_name}")
        print(f"ESPCN - PSNR: {psnr_val:.2f} dB, SSIM: {ssim_val:.4f}")
        print(f"Bicubic - PSNR: {bicubic_psnr:.2f} dB, SSIM: {bicubic_ssim:.4f}")
        print(f"Improvement - PSNR: {psnr_val - bicubic_psnr:.2f} dB, SSIM: {ssim_val - bicubic_ssim:.4f}")
        print("-" * 50)
    
    if num_evaluated == 0:
        print("No test images could be evaluated")
        return
    
    # Print average metrics over the images that were actually evaluated
    avg_psnr = total_psnr / num_evaluated
    avg_ssim = total_ssim / num_evaluated
    print(f"Average PSNR: {avg_psnr:.2f} dB")
    print(f"Average SSIM: {avg_ssim:.4f}")



In [None]:
# Usage example: train an x4 ESPCN and evaluate the best checkpoint
if __name__ == "__main__":
    # Settings shared by model construction, training and testing
    upscale_factor = 4
    checkpoint_path = 'best_espcn_x4.pth'

    # Example directories
    train_dir = "dataset/train"
    val_dir = "dataset/val"
    test_dir = "dataset/test"
    output_dir = "results"

    # Initialize model: RGB in, RGB out, 64 feature channels
    model = ESPCN(
        in_channels=3,
        out_channels=3,
        channels=64,
        upscale_factor=upscale_factor,
    )

    # Train the model; the best weights are written to checkpoint_path
    train_espcn(
        model=model,
        train_dir=train_dir,
        val_dir=val_dir,
        scale_factor=upscale_factor,
        batch_size=16,
        patch_size=96,
        num_epochs=100,
        device='cuda',
        save_path=checkpoint_path,
    )

    # Load best model for testing
    best_weights = torch.load(checkpoint_path)
    model.load_state_dict(best_weights)

    # Test the model
    test_espcn(
        model=model,
        test_dir=test_dir,
        output_dir=output_dir,
        scale_factor=upscale_factor,
        device='cuda',
    )