In [3]:
import numpy as np
import os
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset

# =================== CHANNEL ATTENTION ===================
class ChannelAttention(nn.Module):
    """
    RCAN-style Channel Attention
    """
    def __init__(self, channels, reduction=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(channels, channels // reduction, 1, 1, 0),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels // reduction, channels, 1, 1, 0),
            nn.Sigmoid()
        )

    def forward(self, x):
        w = self.fc(self.avg_pool(x))
        return x * w

# =================== IMPROVED DENSE BLOCK ===================
class ImprovedDenseBlock(nn.Module):
    """
    IRE dense block (paper-accurate):
    - 3 conv layers with dense connections
    - Channel Attention applied after final conv
    - Residual scaling = 0.2
    """
    def __init__(self, nf=64, gc=32):
        super(ImprovedDenseBlock, self).__init__()
        self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1)
        self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1)
        self.conv3 = nn.Conv2d(nf + 2*gc, nf, 3, 1, 1)
        self.ca = ChannelAttention(nf)
        self.lrelu = nn.LeakyReLU(0.2, inplace=True)
        self.scale = 0.2

    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat([x, x1], dim=1)))
        x3 = self.conv3(torch.cat([x, x1, x2], dim=1))
        x3 = self.ca(x3)
        return x + x3 * self.scale

# =================== RRDB BLOCK ===================
class RRDB_IRE(nn.Module):
    """
    Residual-in-Residual Dense Block (RRDB) using 3 improved dense blocks
    """
    def __init__(self, nf=64, gc=32):
        super(RRDB_IRE, self).__init__()
        self.db1 = ImprovedDenseBlock(nf, gc)
        self.db2 = ImprovedDenseBlock(nf, gc)
        self.db3 = ImprovedDenseBlock(nf, gc)
        self.scale = 0.2

    def forward(self, x):
        out = self.db1(x)
        out = self.db2(out)
        out = self.db3(out)
        return x + out * self.scale

# =================== IRE GENERATOR ===================
class IREGenerator(nn.Module):
    """
    IRE Generator: Real-ESRGAN topology with RRDB backbone + Improved Dense Blocks
    Supports 4x super-resolution by default
    """
    def __init__(self, in_nc=3, out_nc=3, nf=64, num_rrdb=23, gc=32):
        super(IREGenerator, self).__init__()
        # First convolution
        self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1)

        # RRDB trunk
        self.rrdb_trunk = nn.Sequential(*[RRDB_IRE(nf, gc) for _ in range(num_rrdb)])
        self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1)

        # Upsampling layers for 4x SR (x2 + x2)
        self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1)
        self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1)

        # High-resolution conv
        self.hr_conv = nn.Conv2d(nf, nf, 3, 1, 1)
        self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1)

        self.lrelu = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        # Feature extraction
        fea = self.conv_first(x)

        # RRDB trunk
        trunk = self.trunk_conv(self.rrdb_trunk(fea))
        fea = fea + trunk

        # Upsampling x2 -> x2 (total x4)
        fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
        fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))

        # HR reconstruction
        out = self.conv_last(self.lrelu(self.hr_conv(fea)))
        return out

# =================== PAIRED DATASET ===================
class PairedSRDataset(Dataset):
    """
    Dataset for paired LR-HR super-resolution images.
    LR and HR folders must contain images with identical filenames.
    """
    def __init__(self, lr_dir, hr_dir, transform=None):
        super(PairedSRDataset, self).__init__()
        self.lr_paths = sorted([os.path.join(lr_dir, f) for f in os.listdir(lr_dir) if f.lower().endswith(('.png','.jpg','.jpeg'))])
        self.hr_paths = sorted([os.path.join(hr_dir, f) for f in os.listdir(hr_dir) if f.lower().endswith(('.png','.jpg','.jpeg'))])
        assert len(self.lr_paths) == len(self.hr_paths), "LR and HR datasets must have same number of images"
        self.transform = transform or nn.Sequential(
            lambda img: torch.tensor(np.array(img)).permute(2,0,1).float() / 255.
        )

    def __len__(self):
        return len(self.lr_paths)

    def __getitem__(self, idx):
        lr_image = Image.open(self.lr_paths[idx]).convert("RGB")
        hr_image = Image.open(self.hr_paths[idx]).convert("RGB")
        if self.transform:
            lr_image = self.transform(lr_image)
            hr_image = self.transform(hr_image)
        return lr_image, hr_image

# =================== QUICK TEST ===================
if __name__ == "__main__":
    # Test generator
    model = IREGenerator()
    x = torch.randn(1, 3, 64, 64)  # LR input
    y = model(x)
    print("Generator output shape:", y.shape)  # Expected: [1, 3, 256, 256] -> 4x SR

    # Test dataset
    dataset = PairedSRDataset("path/to/LR_images", "path/to/HR_images")
    lr, hr = dataset[0]
    print("LR shape:", lr.shape, "HR shape:", hr.shape)


Generator output shape: torch.Size([1, 3, 256, 256])


FileNotFoundError: [WinError 3] The system cannot find the path specified: 'path/to/LR_images'