In [1]:
import os
from torch.utils.data import Dataset, ConcatDataset
from torchvision import transforms
from PIL import Image

# CustomDataset: Loads HR images from nested subfolders and corresponding LR images from a given subfolder.
class CustomDataset(Dataset):
    def __init__(self, root_dir, lr_subfolder_relative_path, transform_hr=None, transform_lr=None):
        self.hr_images = []
        for dirpath, _, files in os.walk(root_dir):
            for file in files:
                if file.endswith('.png') and 'x4' not in file:
                    self.hr_images.append(os.path.join(dirpath, file))
        # Sort based on numeric order of filenames (assumes filenames are numbers)
        self.hr_images.sort(key=lambda path: int(os.path.splitext(os.path.basename(path))[0]))
        
        lr_root = os.path.join(root_dir, lr_subfolder_relative_path)
        self.lr_images = []
        for dirpath, _, files in os.walk(lr_root):
            for file in files:
                if file.endswith('.png') and 'x4' in file:
                    self.lr_images.append(os.path.join(dirpath, file))
        self.lr_images.sort(key=lambda path: int(os.path.splitext(os.path.basename(path).replace('x4',''))[0]))
        
        if len(self.hr_images) != len(self.lr_images):
            raise ValueError("The number of HR and LR images do not match!")
            
        self.transform_hr = transform_hr
        self.transform_lr = transform_lr

    def __len__(self):
        return len(self.hr_images)

    def __getitem__(self, idx):
        hr = Image.open(self.hr_images[idx]).convert('RGB')
        lr = Image.open(self.lr_images[idx]).convert('RGB')
        if self.transform_hr:
            hr = self.transform_hr(hr)
        if self.transform_lr:
            lr = self.transform_lr(lr)
        return {'hr': hr, 'lr': lr}

# SeparateDirsDataset: Loads HR and LR images from two separate directories.
class SeparateDirsDataset(Dataset):
    def __init__(self, hr_dir, lr_dir, transform_hr=None, transform_lr=None):
        self.hr_images = [os.path.join(hr_dir, f) for f in os.listdir(hr_dir) if f.endswith('.png')]
        self.lr_images = [os.path.join(lr_dir, f) for f in os.listdir(lr_dir) if f.endswith('.png')]
        self.hr_images.sort(key=lambda path: int(os.path.splitext(os.path.basename(path))[0]))
        self.lr_images.sort(key=lambda path: int(os.path.splitext(os.path.basename(path).replace('x4',''))[0]))
        if len(self.hr_images) != len(self.lr_images):
            raise ValueError("The number of HR and LR images do not match!")
        self.transform_hr = transform_hr
        self.transform_lr = transform_lr

    def __len__(self):
        return len(self.hr_images)

    def __getitem__(self, idx):
        hr = Image.open(self.hr_images[idx]).convert('RGB')
        lr = Image.open(self.lr_images[idx]).convert('RGB')
        if self.transform_hr:
            hr = self.transform_hr(hr)
        if self.transform_lr:
            lr = self.transform_lr(lr)
        return {'hr': hr, 'lr': lr}

# Define transforms.
# For x4 SR, we assume HR images are resized to 64x64 and LR images to 16x16.
transform_hr = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
])
transform_lr = transforms.Compose([
    transforms.Resize((16, 16)),
    transforms.ToTensor(),
])

# Set your dataset paths (update these if necessary):
root_directory = "/kaggle/input/lsdir-hr"             # For CustomDataset
lr_relative_path = "/kaggle/input/lsdir-hr/train_x4"     # For CustomDataset
hr_directory = "/kaggle/input/flickr2k/Flickr2K/Flickr2K_HR"  # For SeparateDirsDataset
lr_directory = "/kaggle/input/flickr2k/Flickr2K/Flickr2K_LR_bicubic/X4"  # For SeparateDirsDataset

# Create dataset objects.
dataset_nested = CustomDataset(root_dir=root_directory,
                               lr_subfolder_relative_path=lr_relative_path,
                               transform_hr=transform_hr,
                               transform_lr=transform_lr)
dataset_separate = SeparateDirsDataset(hr_dir=hr_directory,
                                       lr_dir=lr_directory,
                                       transform_hr=transform_hr,
                                       transform_lr=transform_lr)

# Combine datasets.
combined_dataset = ConcatDataset([dataset_nested, dataset_separate])
print(f"Total images: {len(combined_dataset)}")


Total images: 87641


In [2]:
import torchvision.transforms as T
import os
from torch.utils.data import Dataset, ConcatDataset
from torchvision import transforms
from PIL import Image

val_transform_hr = T.Compose([
    T.Resize((128, 128)),  # force same HR size for every image
    T.ToTensor(),
])

class ValDownsampleDataset(Dataset):
    def __init__(self, hr_dir, transform_hr=None):
        self.hr_dir = hr_dir
        self.hr_images = [os.path.join(hr_dir, f) for f in os.listdir(hr_dir) if f.endswith('.png')]
        self.hr_images.sort(key=lambda path: int(os.path.splitext(os.path.basename(path))[0]))
        self.transform_hr = transform_hr

    def __len__(self):
        return len(self.hr_images)

    def __getitem__(self, idx):
        hr_img = Image.open(self.hr_images[idx]).convert('RGB')

        # 1) Transform HR to a fixed size.
        if self.transform_hr:
            hr = self.transform_hr(hr_img)  
        else:
            hr = T.ToTensor()(hr_img)

        # 2) Create LR by x4 downsampling from *that transformed HR* (128×128 -> 32×32).
        #    Do NOT use the original hr_img.size here. Instead, use hr.size().
        c, h, w = hr.shape  # e.g. h=128, w=128
        lr_width, lr_height = w // 4, h // 4

        # Convert hr back to PIL for downsampling
        hr_pil = T.ToPILImage()(hr)
        lr_pil = hr_pil.resize((lr_width, lr_height), Image.BICUBIC)
        lr = T.ToTensor()(lr_pil)

        return {'hr': hr, 'lr': lr}

# Then use `val_transform_hr` when creating the dataset.
val_hr_directory = "/kaggle/input/div2k-dataset/DIV2K_valid_HR/DIV2K_valid_HR"
val_dataset_1 = ValDownsampleDataset(
    hr_dir=val_hr_directory,
    transform_hr=val_transform_hr  # now all HR images become exactly (3,128,128)
)

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
from torchvision.transforms import ToPILImage
from PIL import Image

# Channel Attention Module (Reused from RCAN)
class ChannelAttention(nn.Module):
    def __init__(self, num_channels, reduction_ratio=16):
        super(ChannelAttention, self).__init__()
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv2d(num_channels, num_channels // reduction_ratio, kernel_size=1)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(num_channels // reduction_ratio, num_channels, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        batch_size, num_channels, height, width = x.size()
        out = self.global_avg_pool(x)
        out = self.fc1(out)
        out = self.relu(out)
        out = self.fc2(out)
        weight = self.sigmoid(out)
        return x * weight

# Simplified Spatial Attention Module (Convolution-based)
class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        out = torch.cat([avg_out, max_out], dim=1)
        out = self.conv(out)
        return self.sigmoid(out)

# Hybrid Attention Block (Simplified)
class HybridAttentionBlock(nn.Module):
    def __init__(self, num_channels, reduction_ratio=16, spatial_kernel_size=7):
        super(HybridAttentionBlock, self).__init__()
        self.channel_attention = ChannelAttention(num_channels, reduction_ratio)
        self.spatial_attention = SpatialAttention(spatial_kernel_size)
        self.conv1 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1)

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.relu(out)
        out = self.conv2(out)
        out_ca = self.channel_attention(out)
        out_sa = self.spatial_attention(out)
        out = out_ca * out_sa
        return out + residual

# HAT-Inspired Network
class HATInspiredNet(nn.Module):
    def __init__(self, num_channels=3, num_filters=64, num_hab=10, scale_factor=4, reduction_ratio=16, spatial_kernel_size=7):
        super(HATInspiredNet, self).__init__()
        self.num_filters = num_filters
        self.scale_factor = scale_factor
        self.initial_conv = nn.Conv2d(num_channels, num_filters, kernel_size=3, padding=1)
        self.hab_blocks = nn.Sequential(*[
            HybridAttentionBlock(num_filters, reduction_ratio, spatial_kernel_size) for _ in range(num_hab)
        ])
        self.conv_after_hab = nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1)
        self.upsample = nn.Sequential(
            nn.Conv2d(num_filters, num_filters * (scale_factor ** 2), kernel_size=3, padding=1),
            nn.PixelShuffle(scale_factor)
        )
        self.final_conv = nn.Conv2d(num_filters, num_channels, kernel_size=3, padding=1)

    def forward(self, x):
        x = self.initial_conv(x)
        residual = x
        x = self.hab_blocks(x)
        x = self.conv_after_hab(x)
        x += residual
        x = self.upsample(x)
        x = self.final_conv(x)
        return x

# Function to calculate PSNR (Reused)
def calculate_psnr(img1, img2):
    img1 = img1.mul(255).byte().cpu().numpy()
    img2 = img2.mul(255).byte().cpu().numpy()
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0:
        return 100
    PIXEL_MAX = 255.0
    return 20 * np.log10(PIXEL_MAX / np.sqrt(mse))

# Training Function (Reused with minor adjustments)
def train_model(train_dataloader, val_dataloader, model, criterion, optimizer, scheduler, num_epochs, device, save_dir="hat_checkpoints"):
    os.makedirs(save_dir, exist_ok=True)
    best_psnr = 0.0
    start_epoch = 0
    checkpoint_path = os.path.join(save_dir, "best_model.pth")

    # Load checkpoint if it exists
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        best_psnr = checkpoint['best_psnr']
        print(f"Resuming training from epoch {start_epoch} with best PSNR: {best_psnr:.4f}")
        model.to(device)
    else:
        model.to(device)
        print("Starting training from scratch.")

    for epoch in range(start_epoch, num_epochs):
        model.train()
        train_loss = 0.0
        train_psnr_sum = 0.0
        num_train_batches = len(train_dataloader)
        train_progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]")

        for batch in train_progress_bar:
            lr_images = batch['lr'].to(device)
            hr_images = batch['hr'].to(device)

            optimizer.zero_grad()
            outputs = model(lr_images)
            loss = criterion(outputs, hr_images)
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * lr_images.size(0)

            for i in range(outputs.size(0)):
                train_psnr_sum += calculate_psnr(outputs[i].clamp(0, 1), hr_images[i])

            train_progress_bar.set_postfix({'loss': f'{loss.item():.4f}'})

        avg_train_loss = train_loss / len(train_dataloader.dataset)
        avg_train_psnr = train_psnr_sum / len(train_dataloader.dataset)

        model.eval()
        val_psnr_sum = 0.0
        num_val_batches = len(val_dataloader)
        val_progress_bar = tqdm(val_dataloader, desc=f"Epoch {epoch+1}/{num_epochs} [Val]")

        with torch.no_grad():
            for batch in val_progress_bar:
                lr_images = batch['lr'].to(device)
                hr_images = batch['hr'].to(device)
                outputs = model(lr_images)

                for i in range(outputs.size(0)):
                    val_psnr_sum += calculate_psnr(outputs[i].clamp(0, 1), hr_images[i])

        avg_val_psnr = val_psnr_sum / len(val_dataloader.dataset)
        scheduler.step()

        print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}, Train PSNR: {avg_train_psnr:.4f}, Val PSNR: {avg_val_psnr:.4f}, LR: {optimizer.param_groups[0]['lr']:.6f}")

        # Save the best model based on validation PSNR
        if avg_val_psnr > best_psnr:
            best_psnr = avg_val_psnr
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'best_psnr': best_psnr
            }, checkpoint_path)
            print(f"Validation PSNR improved. Saved checkpoint to {checkpoint_path}")

if __name__ == '__main__':

    # Hyperparameters for HAT-Inspired model
    batch_size = 32
    num_epochs = 400
    learning_rate = 1e-4
    num_filters = 64
    num_hab = 10  # Number of Hybrid Attention Blocks
    scale_factor = 4
    reduction_ratio = 16
    spatial_kernel_size = 7
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    save_directory = "hat_checkpoints"

    # Create data loaders (Reused)
    train_dataloader = DataLoader(combined_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
    val_dataloader = DataLoader(val_dataset_1, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

    # Initialize model, loss, and optimizer
    model = HATInspiredNet(num_channels=3, num_filters=num_filters, num_hab=num_hab, scale_factor=scale_factor, reduction_ratio=reduction_ratio, spatial_kernel_size=spatial_kernel_size)
    criterion = nn.L1Loss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999))
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.5)

    # Train the model
    train_model(train_dataloader, val_dataloader, model, criterion, optimizer, scheduler, num_epochs, device, save_dir=save_directory)

    print("Training finished!")

Starting training from scratch.


Epoch 1/400 [Train]: 100%|██████████| 2739/2739 [31:22<00:00,  1.45it/s, loss=0.0484]
Epoch 1/400 [Val]: 100%|██████████| 4/4 [00:05<00:00,  1.41s/it]


Epoch [1/400], Train Loss: 0.0562, Train PSNR: 29.8729, Val PSNR: 30.2458, LR: 0.000100
Validation PSNR improved. Saved checkpoint to hat_checkpoints/best_model.pth


Epoch 2/400 [Train]: 100%|██████████| 2739/2739 [29:45<00:00,  1.53it/s, loss=0.0442]
Epoch 2/400 [Val]: 100%|██████████| 4/4 [00:05<00:00,  1.40s/it]


Epoch [2/400], Train Loss: 0.0495, Train PSNR: 30.1667, Val PSNR: 30.3371, LR: 0.000100
Validation PSNR improved. Saved checkpoint to hat_checkpoints/best_model.pth


Epoch 3/400 [Train]:  33%|███▎      | 911/2739 [09:56<22:17,  1.37it/s, loss=0.0457]