# In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transforms

# =================== CHANNEL ATTENTION ===================
class ChannelAttention(nn.Module):
    """
    RCAN-style Channel Attention (as in IRE paper).

    Each channel is globally average-pooled to a single value, passed through
    a two-layer 1x1-conv bottleneck ending in a sigmoid, and the resulting
    per-channel weights rescale the input.

    Args:
        num_channels: Number of channels of the input feature map.
        reduction_ratio: Divisor applied to ``num_channels`` to size the
            bottleneck. The bottleneck never drops below one channel.

    Raises:
        ValueError: If ``num_channels`` or ``reduction_ratio`` is smaller
            than 1.
    """
    def __init__(self, num_channels, reduction_ratio=16):
        super().__init__()
        if num_channels < 1 or reduction_ratio < 1:
            raise ValueError(
                "num_channels and reduction_ratio must be >= 1, got "
                "num_channels={!r}, reduction_ratio={!r}".format(num_channels, reduction_ratio)
            )

        # Integer division yields 0 whenever num_channels < reduction_ratio,
        # which would create a bottleneck with no channels at all; keep at
        # least one so the attention weights still depend on the input.
        bottleneck_channels = max(1, num_channels // reduction_ratio)

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(num_channels, bottleneck_channels, kernel_size=1, stride=1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(bottleneck_channels, num_channels, kernel_size=1, stride=1, padding=0, bias=True),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return ``x`` scaled channel-wise by the learned attention weights."""
        # Pooled descriptor has spatial size 1x1, so the product broadcasts
        # one weight per channel over the whole feature map.
        w = self.fc(self.avg_pool(x))
        return x * w

# =================== IMPROVED DENSE BLOCK ===================
class ImprovedDenseBlock(nn.Module):
    """
    Dense block of the IRE network.

    Layout:
    - two 3x3 convs + LeakyReLU, each fed the block input concatenated with
      every earlier layer output
    - a third 3x3 conv that fuses all features back to `num_features` channels
    - channel attention on the fused features
    - residual connection with the attended features scaled by 0.2
    """
    def __init__(self, num_features=64, growth_channels=32):
        super().__init__()
        # All three convs share the same spatial configuration; only the
        # channel counts differ (input grows by `growth_channels` per layer).
        conv_cfg = dict(kernel_size=3, stride=1, padding=1, bias=True)
        self.conv1 = nn.Conv2d(num_features, growth_channels, **conv_cfg)
        self.conv2 = nn.Conv2d(num_features + growth_channels, growth_channels, **conv_cfg)
        self.conv3 = nn.Conv2d(num_features + 2 * growth_channels, num_features, **conv_cfg)

        self.ca = ChannelAttention(num_features)
        self.lrelu = nn.LeakyReLU(0.2, inplace=True)
        self.res_scale = 0.2  # residual scaling as in IRE paper

    def forward(self, x):
        """Run the dense convs, apply attention, and add the scaled residual."""
        first = self.lrelu(self.conv1(x))
        second = self.lrelu(self.conv2(torch.cat((x, first), dim=1)))
        # The fusing conv has no activation; attention is applied to its raw output.
        fused = self.conv3(torch.cat((x, first, second), dim=1))
        attended = self.ca(fused)
        return x + attended * self.res_scale

# =================== RRDB BLOCK ===================
class RRDB_IRE(nn.Module):
    """
    Residual-in-Residual Dense Block.

    Chains three Improved Dense Blocks and adds their output, scaled by 0.2,
    back onto the block input.
    """
    def __init__(self, num_features=64, growth_channels=32):
        super().__init__()
        block_args = (num_features, growth_channels)
        self.db1 = ImprovedDenseBlock(*block_args)
        self.db2 = ImprovedDenseBlock(*block_args)
        self.db3 = ImprovedDenseBlock(*block_args)
        self.res_scale = 0.2  # residual-in-residual scaling

    def forward(self, x):
        """Pass `x` through the three dense blocks and add the scaled result to `x`."""
        features = x
        for dense_block in (self.db1, self.db2, self.db3):
            features = dense_block(features)
        return x + features * self.res_scale

# =================== IRE GENERATOR ===================
class IREGenerator(nn.Module):
    """
    IRE Generator for paired LR-HR datasets.

    No downsampling inside; uses an RRDB backbone followed by nearest-neighbour
    upsampling + conv stages for SR.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the produced image.
        num_features: Width of the feature maps throughout the network.
        num_rrdb_blocks: Number of RRDB_IRE blocks in the trunk.
        growth_channels: Growth channels passed to every RRDB_IRE block.
        scale_factor: Overall upscaling factor. Must be a positive power of
            two (1, 2, 4, 8, ...). Each factor of two adds one
            "nearest x2 + conv" stage, registered as ``upconv1``,
            ``upconv2``, ... so the default of 4 yields exactly ``upconv1``
            and ``upconv2``.

    Raises:
        ValueError: If ``scale_factor`` is not a positive power of two.
    """
    def __init__(self,
                 in_channels=3,
                 out_channels=3,
                 num_features=64,
                 num_rrdb_blocks=23,
                 growth_channels=32,
                 scale_factor=4):
        super().__init__()

        # The upsampler is built from x2 stages, so only powers of two can be
        # honoured; anything else is rejected instead of silently producing
        # a different scale than requested.
        scale = int(scale_factor)
        if scale != scale_factor or scale < 1 or scale & (scale - 1):
            raise ValueError(
                "scale_factor must be a positive power of two, got {!r}".format(scale_factor)
            )

        self.scale_factor = scale_factor
        self.num_upsamples = scale.bit_length() - 1  # log2(scale)

        # Feature extraction
        self.conv_first = nn.Conv2d(in_channels, num_features, kernel_size=3, stride=1, padding=1, bias=True)

        # RRDB trunk
        self.rrdb_trunk = nn.Sequential(
            *[RRDB_IRE(num_features, growth_channels) for _ in range(num_rrdb_blocks)]
        )
        self.trunk_conv = nn.Conv2d(num_features, num_features, kernel_size=3, stride=1, padding=1, bias=True)

        # Upsampling layers (nearest + conv), one per factor of two.
        # Registered under individual names (not a ModuleList) so the
        # state_dict keys of the default 4x model stay upconv1 / upconv2.
        for stage in range(1, self.num_upsamples + 1):
            self.add_module(
                "upconv{}".format(stage),
                nn.Conv2d(num_features, num_features, kernel_size=3, stride=1, padding=1, bias=True),
            )

        # High-resolution reconstruction
        self.hr_conv = nn.Conv2d(num_features, num_features, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv_last = nn.Conv2d(num_features, out_channels, kernel_size=3, stride=1, padding=1, bias=True)

        self.lrelu = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        """Map an LR batch to an output whose height/width are scaled by ``scale_factor``."""
        fea = self.conv_first(x)
        trunk = self.trunk_conv(self.rrdb_trunk(fea))
        # Long skip connection around the whole RRDB trunk.
        fea = fea + trunk

        # Upsampling: each stage doubles the spatial resolution.
        for stage in range(1, self.num_upsamples + 1):
            upconv = getattr(self, "upconv{}".format(stage))
            fea = self.lrelu(upconv(F.interpolate(fea, scale_factor=2, mode='nearest')))

        out = self.conv_last(self.lrelu(self.hr_conv(fea)))
        return out

# =================== PAIRED LR-HR DATASET ===================
class PairedSRDataset(Dataset):
    """
    Dataset for paired LR-HR images.

    Assumes LR and HR images already exist. Each directory listing is sorted
    and the two lists are paired by position, so corresponding files must
    sort into the same order in both directories. Only the file counts are
    verified, not the names.

    Args:
        lr_dir: Directory containing the low-resolution images.
        hr_dir: Directory containing the high-resolution images.

    Raises:
        ValueError: If the two directories hold a different number of images.
    """
    # Extensions (compared case-insensitively) that count as images.
    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, lr_dir, hr_dir):
        self.lr_paths = self._list_images(lr_dir)
        self.hr_paths = self._list_images(hr_dir)

        # Raise explicitly instead of using `assert`, which is stripped when
        # Python runs with -O and would let mismatched pairs through.
        if len(self.lr_paths) != len(self.hr_paths):
            raise ValueError(
                "LR and HR images count mismatch! (LR: {}, HR: {})".format(
                    len(self.lr_paths), len(self.hr_paths)
                )
            )

        self.transform = transforms.ToTensor()

    @classmethod
    def _list_images(cls, directory):
        """Return the sorted full paths of the image files directly inside `directory`."""
        return sorted(
            os.path.join(directory, name)
            for name in os.listdir(directory)
            if name.lower().endswith(cls._IMAGE_EXTENSIONS)
        )

    def _load(self, path):
        """Open `path`, convert it to RGB and apply the dataset transform."""
        # Context manager releases the underlying file handle deterministically.
        with Image.open(path) as img:
            return self.transform(img.convert("RGB"))

    def __len__(self):
        return len(self.lr_paths)

    def __getitem__(self, idx):
        """Return the `(lr, hr)` pair at position `idx` after `transforms.ToTensor()`."""
        lr = self._load(self.lr_paths[idx])
        hr = self._load(self.hr_paths[idx])
        return lr, hr

# =================== QUICK TEST ===================
if __name__ == "__main__":
    # Smoke test: build a small generator, check its output shape, and, when
    # sample data is available, load one LR/HR pair.
    model = IREGenerator(
        in_channels=3,
        out_channels=3,
        num_features=64,
        num_rrdb_blocks=3,  # reduced for quick test
        growth_channels=32
    )

    # Test forward pass. Gradients are not needed for a shape check, so skip
    # building the autograd graph.
    x = torch.randn(1, 3, 64, 64)
    with torch.no_grad():
        y = model(x)
    print("Generator output shape:", y.shape)  # Expected: [1, 3, 256, 256]

    # Test dataset. The sample directories are optional: report a skip
    # instead of crashing when they are missing or empty.
    lr_dir, hr_dir = "data/LR", "data/HR"
    if os.path.isdir(lr_dir) and os.path.isdir(hr_dir):
        dataset = PairedSRDataset(lr_dir, hr_dir)
        if len(dataset) > 0:
            lr, hr = dataset[0]
            print("LR shape:", lr.shape, "HR shape:", hr.shape)
        else:
            print("Dataset test skipped: no images found in", lr_dir, "and", hr_dir)
    else:
        print("Dataset test skipped: missing directory", lr_dir, "or", hr_dir)
