In [1]:
import os
from torch.utils.data import Dataset, ConcatDataset
from torchvision import transforms
from PIL import Image

# CustomDataset: Loads HR images from nested subfolders and corresponding LR images from a given subfolder.
class CustomDataset(Dataset):
    def __init__(self, root_dir, lr_subfolder_relative_path, transform_hr=None, transform_lr=None):
        self.hr_images = []
        for dirpath, _, files in os.walk(root_dir):
            for file in files:
                if file.endswith('.png') and 'x4' not in file:
                    self.hr_images.append(os.path.join(dirpath, file))
        # Sort based on numeric order of filenames (assumes filenames are numbers)
        self.hr_images.sort(key=lambda path: int(os.path.splitext(os.path.basename(path))[0]))
        
        lr_root = os.path.join(root_dir, lr_subfolder_relative_path)
        self.lr_images = []
        for dirpath, _, files in os.walk(lr_root):
            for file in files:
                if file.endswith('.png') and 'x4' in file:
                    self.lr_images.append(os.path.join(dirpath, file))
        self.lr_images.sort(key=lambda path: int(os.path.splitext(os.path.basename(path).replace('x4',''))[0]))
        
        if len(self.hr_images) != len(self.lr_images):
            raise ValueError("The number of HR and LR images do not match!")
            
        self.transform_hr = transform_hr
        self.transform_lr = transform_lr

    def __len__(self):
        return len(self.hr_images)

    def __getitem__(self, idx):
        hr = Image.open(self.hr_images[idx]).convert('RGB')
        lr = Image.open(self.lr_images[idx]).convert('RGB')
        if self.transform_hr:
            hr = self.transform_hr(hr)
        if self.transform_lr:
            lr = self.transform_lr(lr)
        return {'hr': hr, 'lr': lr}

# SeparateDirsDataset: Loads HR and LR images from two separate directories.
class SeparateDirsDataset(Dataset):
    def __init__(self, hr_dir, lr_dir, transform_hr=None, transform_lr=None):
        self.hr_images = [os.path.join(hr_dir, f) for f in os.listdir(hr_dir) if f.endswith('.png')]
        self.lr_images = [os.path.join(lr_dir, f) for f in os.listdir(lr_dir) if f.endswith('.png')]
        self.hr_images.sort(key=lambda path: int(os.path.splitext(os.path.basename(path))[0]))
        self.lr_images.sort(key=lambda path: int(os.path.splitext(os.path.basename(path).replace('x4',''))[0]))
        if len(self.hr_images) != len(self.lr_images):
            raise ValueError("The number of HR and LR images do not match!")
        self.transform_hr = transform_hr
        self.transform_lr = transform_lr

    def __len__(self):
        return len(self.hr_images)

    def __getitem__(self, idx):
        hr = Image.open(self.hr_images[idx]).convert('RGB')
        lr = Image.open(self.lr_images[idx]).convert('RGB')
        if self.transform_hr:
            hr = self.transform_hr(hr)
        if self.transform_lr:
            lr = self.transform_lr(lr)
        return {'hr': hr, 'lr': lr}

# Define transforms.
# For x4 SR, we assume HR images are resized to 64x64 and LR images to 16x16.
transform_hr = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
])
transform_lr = transforms.Compose([
    transforms.Resize((16, 16)),
    transforms.ToTensor(),
])

# Set your dataset paths (update these if necessary):
root_directory = "/kaggle/input/lsdir-hr"             # For CustomDataset
lr_relative_path = "/kaggle/input/lsdir-hr/train_x4"     # For CustomDataset
hr_directory = "/kaggle/input/flickr2k/Flickr2K/Flickr2K_HR"  # For SeparateDirsDataset
lr_directory = "/kaggle/input/flickr2k/Flickr2K/Flickr2K_LR_bicubic/X4"  # For SeparateDirsDataset

# Create dataset objects.
dataset_nested = CustomDataset(root_dir=root_directory,
                               lr_subfolder_relative_path=lr_relative_path,
                               transform_hr=transform_hr,
                               transform_lr=transform_lr)
dataset_separate = SeparateDirsDataset(hr_dir=hr_directory,
                                       lr_dir=lr_directory,
                                       transform_hr=transform_hr,
                                       transform_lr=transform_lr)

# Combine datasets.
combined_dataset = ConcatDataset([dataset_nested, dataset_separate])
print(f"Total images: {len(combined_dataset)}")


Total images: 87641


In [2]:
import torchvision.transforms as T
import os
from torch.utils.data import Dataset, ConcatDataset
from torchvision import transforms
from PIL import Image

val_transform_hr = T.Compose([
    T.Resize((128, 128)),  # force same HR size for every image
    T.ToTensor(),
])

class ValDownsampleDataset(Dataset):
    def __init__(self, hr_dir, transform_hr=None):
        self.hr_dir = hr_dir
        self.hr_images = [os.path.join(hr_dir, f) for f in os.listdir(hr_dir) if f.endswith('.png')]
        self.hr_images.sort(key=lambda path: int(os.path.splitext(os.path.basename(path))[0]))
        self.transform_hr = transform_hr

    def __len__(self):
        return len(self.hr_images)

    def __getitem__(self, idx):
        hr_img = Image.open(self.hr_images[idx]).convert('RGB')

        # 1) Transform HR to a fixed size.
        if self.transform_hr:
            hr = self.transform_hr(hr_img)  
        else:
            hr = T.ToTensor()(hr_img)

        # 2) Create LR by x4 downsampling from *that transformed HR* (128×128 -> 32×32).
        #    Do NOT use the original hr_img.size here. Instead, use hr.size().
        c, h, w = hr.shape  # e.g. h=128, w=128
        lr_width, lr_height = w // 4, h // 4

        # Convert hr back to PIL for downsampling
        hr_pil = T.ToPILImage()(hr)
        lr_pil = hr_pil.resize((lr_width, lr_height), Image.BICUBIC)
        lr = T.ToTensor()(lr_pil)

        return {'hr': hr, 'lr': lr}

# Then use `val_transform_hr` when creating the dataset.
val_hr_directory = "/kaggle/input/div2k-dataset/DIV2K_valid_HR/DIV2K_valid_HR"
val_dataset_1 = ValDownsampleDataset(
    hr_dir=val_hr_directory,
    transform_hr=val_transform_hr  # now all HR images become exactly (3,128,128)
)

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np

from torchvision import transforms
from PIL import Image
from torch.utils.data import Dataset, ConcatDataset

# Re-define Dataset Loaders (if not already in the same script)
class CustomDataset(Dataset):
    def __init__(self, root_dir, lr_subfolder_relative_path, transform_hr=None, transform_lr=None):
        self.hr_images = []
        for dirpath, _, files in os.walk(root_dir):
            for file in files:
                if file.endswith('.png') and 'x4' not in file:
                    self.hr_images.append(os.path.join(dirpath, file))
        self.hr_images.sort(key=lambda path: int(os.path.splitext(os.path.basename(path))[0]))

        lr_root = os.path.join(root_dir, lr_subfolder_relative_path)
        self.lr_images = []
        for dirpath, _, files in os.walk(lr_root):
            for file in files:
                if file.endswith('.png') and 'x4' in file:
                    self.lr_images.append(os.path.join(dirpath, file))
        self.lr_images.sort(key=lambda path: int(os.path.splitext(os.path.basename(path).replace('x4',''))[0]))

        if len(self.hr_images) != len(self.lr_images):
            raise ValueError("The number of HR and LR images do not match!")

        self.transform_hr = transform_hr
        self.transform_lr = transform_lr

    def __len__(self):
        return len(self.hr_images)

    def __getitem__(self, idx):
        hr = Image.open(self.hr_images[idx]).convert('RGB')
        lr = Image.open(self.lr_images[idx]).convert('RGB')
        if self.transform_hr:
            hr = self.transform_hr(hr)
        if self.transform_lr:
            lr = self.transform_lr(lr)
        return {'hr': hr, 'lr': lr}

class SeparateDirsDataset(Dataset):
    def __init__(self, hr_dir, lr_dir, transform_hr=None, transform_lr=None):
        self.hr_images = [os.path.join(hr_dir, f) for f in os.listdir(hr_dir) if f.endswith('.png')]
        self.lr_images = [os.path.join(lr_dir, f) for f in os.listdir(lr_dir) if f.endswith('.png')]
        self.hr_images.sort(key=lambda path: int(os.path.splitext(os.path.basename(path))[0]))
        self.lr_images.sort(key=lambda path: int(os.path.splitext(os.path.basename(path).replace('x4',''))[0]))
        if len(self.hr_images) != len(self.lr_images):
            raise ValueError("The number of HR and LR images do not match!")
        self.transform_hr = transform_hr
        self.transform_lr = transform_lr

    def __len__(self):
        return len(self.hr_images)

    def __getitem__(self, idx):
        hr = Image.open(self.hr_images[idx]).convert('RGB')
        lr = Image.open(self.lr_images[idx]).convert('RGB')
        if self.transform_hr:
            hr = self.transform_hr(hr)
        if self.transform_lr:
            lr = self.transform_lr(lr)
        return {'hr': hr, 'lr': lr}

class ValDownsampleDataset(Dataset):
    def __init__(self, hr_dir, transform_hr=None):
        self.hr_dir = hr_dir
        self.hr_images = [os.path.join(hr_dir, f) for f in os.listdir(hr_dir) if f.endswith('.png')]
        self.hr_images.sort(key=lambda path: int(os.path.splitext(os.path.basename(path))[0]))
        self.transform_hr = transform_hr
        self.to_pil = transforms.ToPILImage()
        self.to_tensor = transforms.ToTensor()

    def __len__(self):
        return len(self.hr_images)

    def __getitem__(self, idx):
        hr_img = Image.open(self.hr_images[idx]).convert('RGB')

        if self.transform_hr:
            hr = self.transform_hr(hr_img)
        else:
            hr = self.to_tensor(hr_img)

        c, h, w = hr.shape
        lr_width, lr_height = w // 4, h // 4

        hr_pil = self.to_pil(hr)
        lr_pil = hr_pil.resize((lr_width, lr_height), Image.BICUBIC)
        lr = self.to_tensor(lr_pil)

        return {'hr': hr, 'lr': lr}

# Channel Attention Module (Reused)
class ChannelAttention(nn.Module):
    def __init__(self, num_channels, reduction_ratio=16):
        super(ChannelAttention, self).__init__()
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv2d(num_channels, num_channels // reduction_ratio, kernel_size=1)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(num_channels // reduction_ratio, num_channels, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        batch_size, num_channels, height, width = x.size()
        out = self.global_avg_pool(x)
        out = self.fc1(out)
        out = self.relu(out)
        out = self.fc2(out)
        weight = self.sigmoid(out)
        return x * weight

# Simplified Spatial Attention Module (Convolution-based)
class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        out = torch.cat([avg_out, max_out], dim=1)
        out = self.conv(out)
        return self.sigmoid(out)

# Hybrid Attention Block (Simplified)
class HybridAttentionBlock(nn.Module):
    def __init__(self, num_channels, reduction_ratio=16, spatial_kernel_size=7):
        super(HybridAttentionBlock, self).__init__()
        self.channel_attention = ChannelAttention(num_channels, reduction_ratio)
        self.spatial_attention = SpatialAttention(spatial_kernel_size)
        self.conv1 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1)

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.relu(out)
        out = self.conv2(out)
        out_ca = self.channel_attention(out)
        out_sa = self.spatial_attention(out)
        out = out_ca * out_sa
        return out + residual

# HAT-Inspired Network
class HATInspiredNet(nn.Module):
    def __init__(self, num_channels=3, num_filters=64, num_hab=10, scale_factor=4, reduction_ratio=16, spatial_kernel_size=7):
        super(HATInspiredNet, self).__init__()
        self.num_filters = num_filters
        self.scale_factor = scale_factor
        self.initial_conv = nn.Conv2d(num_channels, num_filters, kernel_size=3, padding=1)
        self.hab_blocks = nn.Sequential(*[
            HybridAttentionBlock(num_filters, reduction_ratio, spatial_kernel_size) for _ in range(num_hab)
        ])
        self.conv_after_hab = nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1)
        self.upsample = nn.Sequential(
            nn.Conv2d(num_filters, num_filters * (scale_factor ** 2), kernel_size=3, padding=1),
            nn.PixelShuffle(scale_factor)
        )
        self.final_conv = nn.Conv2d(num_filters, num_channels, kernel_size=3, padding=1)

    def forward(self, x):
        x = self.initial_conv(x)
        residual = x
        x = self.hab_blocks(x)
        x = self.conv_after_hab(x)
        x += residual
        x = self.upsample(x)
        x = self.final_conv(x)
        return x

# Residual Channel Attention Block (RCAB)
class ResidualChannelAttentionBlock(nn.Module):
    def __init__(self, num_channels, reduction_ratio=16):
        super(ResidualChannelAttentionBlock, self).__init__()
        self.conv1 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1)
        self.ca = ChannelAttention(num_channels, reduction_ratio)

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.ca(out)
        return out + residual

# Residual Group (RG)
class ResidualGroup(nn.Module):
    def __init__(self, num_channels, num_rcab, reduction_ratio=16):
        super(ResidualGroup, self).__init__()
        layers = [ResidualChannelAttentionBlock(num_channels, reduction_ratio) for _ in range(num_rcab)]
        self.rcabs = nn.Sequential(*layers)
        self.conv = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1)

    def forward(self, x):
        residual = x
        out = self.rcabs(x)
        out = self.conv(out)
        return out + residual

# RCAN-inspired Network
class RCANNet(nn.Module):
    def __init__(self, num_channels=3, num_filters=64, num_rg=10, num_rcab=20, scale_factor=4, reduction_ratio=16):
        super(RCANNet, self).__init__()
        self.num_filters = num_filters
        self.scale_factor = scale_factor
        self.initial_conv = nn.Conv2d(num_channels, num_filters, kernel_size=3, padding=1)
        self.residual_groups = nn.ModuleList([
            ResidualGroup(num_filters, num_rcab, reduction_ratio) for _ in range(num_rg)
        ])
        self.conv_after_rg = nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1)
        self.upsample = nn.Sequential(
            nn.Conv2d(num_filters, num_filters * (scale_factor ** 2), kernel_size=3, padding=1),
            nn.PixelShuffle(scale_factor)
        )
        self.final_conv = nn.Conv2d(num_filters, num_channels, kernel_size=3, padding=1)

    def forward(self, x):
        x = self.initial_conv(x)
        residual = x
        for rg in self.residual_groups:
            x = rg(x)
        x = self.conv_after_rg(x)
        x += residual
        x = self.upsample(x)
        x = self.final_conv(x)
        return x

# Simplified SwinIR Block (Convolutional Approximation)
class SwinIRBlock(nn.Module):
    def __init__(self, num_channels, window_size=8, reduction_ratio=4):
        super(SwinIRBlock, self).__init__()
        self.conv1 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1)
        self.ca = ChannelAttention(num_channels, reduction_ratio)

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.ca(out)
        return out + residual

# SwinIR-Inspired Network (Simplified)
class SwinIRInspiredNet(nn.Module):
    def __init__(self, num_channels=3, num_filters=64, num_blocks=10, scale_factor=4, window_size=8):
        super(SwinIRInspiredNet, self).__init__()
        self.num_filters = num_filters
        self.scale_factor = scale_factor
        self.initial_conv = nn.Conv2d(num_channels, num_filters, kernel_size=3, padding=1)
        self.swinir_blocks = nn.Sequential(*[
            SwinIRBlock(num_filters, window_size) for _ in range(num_blocks)
        ])
        self.conv_after_blocks = nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1)
        self.upsample = nn.Sequential(
            nn.Conv2d(num_filters, num_filters * (scale_factor ** 2), kernel_size=3, padding=1),
            nn.PixelShuffle(scale_factor)
        )
        self.final_conv = nn.Conv2d(num_filters, num_channels, kernel_size=3, padding=1)

    def forward(self, x):
        x = self.initial_conv(x)
        residual = x
        x = self.swinir_blocks(x)
        x = self.conv_after_blocks(x)
        x += residual
        x = self.upsample(x)
        x = self.final_conv(x)
        return x

# Function to calculate PSNR (Reused)
def calculate_psnr(img1, img2):
    img1 = img1.mul(255).byte().cpu().numpy()
    img2 = img2.mul(255).byte().cpu().numpy()
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0:
        return 100
    PIXEL_MAX = 255.0
    return 20 * np.log10(PIXEL_MAX / np.sqrt(mse))

def load_checkpoint(model, optimizer, scheduler, path, device):
    if os.path.exists(path):
        checkpoint = torch.load(path)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        epoch = checkpoint['epoch']
        best_psnr_individual = checkpoint['best_psnr']
        print(f"Loaded checkpoint from {path} at epoch {epoch} with best PSNR: {best_psnr_individual:.4f}")

        # Move optimizer state to the correct device
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)

        return epoch, best_psnr_individual
    return -1, -1

def train_ensemble_model(train_dataloader, val_dataloader, hat_model, rcan_model, swinir_model, criterion, hat_optimizer, rcan_optimizer, swinir_optimizer, hat_scheduler, rcan_scheduler, swinir_scheduler, num_epochs, device, save_dir="ensemble_checkpoints"):
    os.makedirs(save_dir, exist_ok=True)
    best_psnr = 0.0
    start_epoch = 0
    checkpoint_path = os.path.join(save_dir, "best_ensemble_model.pth")

    hat_checkpoint_path = os.path.join("hat_checkpoints", "best_model.pth")
    rcan_checkpoint_path = os.path.join("/kaggle/input/rcan/pytorch/default/1/best_model (2).pth")
    swinir_checkpoint_path = os.path.join("swinir_checkpoints", "best_model.pth")

    load_checkpoint(hat_model, hat_optimizer, hat_scheduler, hat_checkpoint_path, device)
    load_checkpoint(rcan_model, rcan_optimizer, rcan_scheduler, rcan_checkpoint_path, device)
    load_checkpoint(swinir_model, swinir_optimizer, swinir_scheduler, swinir_checkpoint_path, device)

    hat_model.to(device)
    rcan_model.to(device)
    swinir_model.to(device)

    print("Starting ensemble training.")

    for epoch in range(start_epoch, num_epochs):
        hat_model.train()
        rcan_model.train()
        swinir_model.train()

        train_loss = 0.0
        train_psnr_sum = 0.0
        num_train_batches = len(train_dataloader)
        train_progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{num_epochs} [Train Ensemble]")

        for batch in train_progress_bar:
            lr_images = batch['lr'].to(device)
            hr_images = batch['hr'].to(device)

            # Zero gradients for all optimizers
            hat_optimizer.zero_grad()
            rcan_optimizer.zero_grad()
            swinir_optimizer.zero_grad()

            # Forward passes
            hat_output = hat_model(lr_images)
            rcan_output = rcan_model(lr_images)
            swinir_output = swinir_model(lr_images)

            # Calculate individual losses
            hat_loss = criterion(hat_output, hr_images)
            rcan_loss = criterion(rcan_output, hr_images)
            swinir_loss = criterion(swinir_output, hr_images)

            # Ensemble loss (can be a simple average or a weighted sum)
            total_loss = (hat_loss + rcan_loss + swinir_loss) / 3

            # Backpropagate the total loss through all models
            total_loss.backward()

            # Update optimizers
            hat_optimizer.step()
            rcan_optimizer.step()
            swinir_optimizer.step()

            train_loss += total_loss.item() * lr_images.size(0)

            # Calculate PSNR for the average output (optional, for monitoring)
            avg_output = (hat_output + rcan_output + swinir_output) / 3
            for i in range(avg_output.size(0)):
                train_psnr_sum += calculate_psnr(avg_output[i].clamp(0, 1), hr_images[i])

            train_progress_bar.set_postfix({'loss': f'{total_loss.item():.4f}'})

        avg_train_loss = train_loss / len(train_dataloader.dataset)
        avg_train_psnr = train_psnr_sum / len(train_dataloader.dataset)

        # Validation
        hat_model.eval()
        rcan_model.eval()
        swinir_model.eval()
        val_psnr_sum = 0.0
        val_progress_bar = tqdm(val_dataloader, desc=f"Epoch {epoch+1}/{num_epochs} [Val Ensemble]")

        with torch.no_grad():
            for batch in val_progress_bar:
                lr_images = batch['lr'].to(device)
                hr_images = batch['hr'].to(device)

                hat_output = hat_model(lr_images)
                rcan_output = rcan_model(lr_images)
                swinir_output = swinir_model(lr_images)

                avg_output = (hat_output + rcan_output + swinir_output) / 3
                for i in range(avg_output.size(0)):
                    val_psnr_sum += calculate_psnr(avg_output[i].clamp(0, 1), hr_images[i])

        avg_val_psnr = val_psnr_sum / len(val_dataloader.dataset)

        hat_scheduler.step()
        rcan_scheduler.step()
        swinir_scheduler.step()

        print(f"Epoch [{epoch+1}/{num_epochs}], Ensemble Train Loss: {avg_train_loss:.4f}, Ensemble Train PSNR (Avg): {avg_train_psnr:.4f}, Ensemble Val PSNR (Avg): {avg_val_psnr:.4f}")

        # Save the ensemble state (optional - you might prefer saving individual best models)
        if avg_val_psnr > best_psnr:
            best_psnr = avg_val_psnr
            torch.save({
                'epoch': epoch,
                'hat_state_dict': hat_model.state_dict(),
                'rcan_state_dict': rcan_model.state_dict(),
                'swinir_state_dict': swinir_model.state_dict(),
                'hat_optimizer_state_dict': hat_optimizer.state_dict(),
                'rcan_optimizer_state_dict': rcan_optimizer.state_dict(),
                'swinir_optimizer_state_dict': swinir_optimizer.state_dict(),
                'hat_scheduler_state_dict': hat_scheduler.state_dict(),
                'rcan_scheduler_state_dict': rcan_scheduler.state_dict(),
                'swinir_scheduler_state_dict': swinir_scheduler.state_dict(),
                'best_psnr': best_psnr
            }, checkpoint_path)
            print(f"Ensemble Validation PSNR improved. Saved checkpoint to {checkpoint_path}")

if __name__ == '__main__':
    # Hyperparameters
    batch_size = 32
    num_epochs = 200 # Adjust as needed for ensemble training
    learning_rate = 1e-4
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize individual models
    hat_model = HATInspiredNet(num_channels=3, num_filters=64, num_hab=10, scale_factor=4)
    rcan_model = RCANNet(num_channels=3, num_filters=64, num_rg=10, num_rcab=20, scale_factor=4)
    swinir_model = SwinIRInspiredNet(num_channels=3, num_filters=64, num_blocks=10, scale_factor=4)

    # Initialize optimizers
    hat_optimizer = optim.Adam(hat_model.parameters(), lr=learning_rate, betas=(0.9, 0.999))
    rcan_optimizer = optim.Adam(rcan_model.parameters(), lr=learning_rate, betas=(0.9, 0.999))
    swinir_optimizer = optim.Adam(swinir_model.parameters(), lr=learning_rate, betas=(0.9, 0.999))

    # Initialize schedulers
    hat_scheduler = optim.lr_scheduler.StepLR(hat_optimizer, step_size=100, gamma=0.5)
    rcan_scheduler = optim.lr_scheduler.StepLR(rcan_optimizer, step_size=100, gamma=0.5)
    swinir_scheduler = optim.lr_scheduler.StepLR(swinir_optimizer, step_size=100, gamma=0.5)

    # Loss function
    criterion = nn.L1Loss()

    # Create data loaders (Reused)
    train_dataloader = DataLoader(combined_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
    val_dataloader = DataLoader(val_dataset_1, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

    # Train the ensemble model
    train_ensemble_model(train_dataloader, val_dataloader, hat_model, rcan_model, swinir_model, criterion, hat_optimizer, rcan_optimizer, swinir_optimizer, hat_scheduler, rcan_scheduler, swinir_scheduler, num_epochs, device, save_dir="ensemble_checkpoints")

    print("Ensemble training finished!")

  checkpoint = torch.load(path)


Loaded checkpoint from hat_checkpoints/best_model.pth at epoch 1 with best PSNR: 30.3371
Loaded checkpoint from /kaggle/input/rcan/pytorch/default/1/best_model (2).pth at epoch 7 with best PSNR: 30.5619
Loaded checkpoint from swinir_checkpoints/best_model.pth at epoch 1 with best PSNR: 30.3465
Starting ensemble training.


Epoch 1/200 [Train Ensemble]: 100%|██████████| 2739/2739 [33:41<00:00,  1.36it/s, loss=0.0470]
Epoch 1/200 [Val Ensemble]: 100%|██████████| 4/4 [00:06<00:00,  1.74s/it]


Epoch [1/200], Ensemble Train Loss: 0.0464, Ensemble Train PSNR (Avg): 30.4584, Ensemble Val PSNR (Avg): 30.5119
Ensemble Validation PSNR improved. Saved checkpoint to ensemble_checkpoints/best_ensemble_model.pth


Epoch 2/200 [Train Ensemble]: 100%|██████████| 2739/2739 [33:29<00:00,  1.36it/s, loss=0.0494]
Epoch 2/200 [Val Ensemble]: 100%|██████████| 4/4 [00:07<00:00,  1.76s/it]


Epoch [2/200], Ensemble Train Loss: 0.0461, Ensemble Train PSNR (Avg): 30.4812, Ensemble Val PSNR (Avg): 30.5123
Ensemble Validation PSNR improved. Saved checkpoint to ensemble_checkpoints/best_ensemble_model.pth


Epoch 3/200 [Train Ensemble]: 100%|██████████| 2739/2739 [33:27<00:00,  1.36it/s, loss=0.0425]
Epoch 3/200 [Val Ensemble]: 100%|██████████| 4/4 [00:06<00:00,  1.63s/it]


Epoch [3/200], Ensemble Train Loss: 0.0457, Ensemble Train PSNR (Avg): 30.5131, Ensemble Val PSNR (Avg): 30.5373
Ensemble Validation PSNR improved. Saved checkpoint to ensemble_checkpoints/best_ensemble_model.pth


Epoch 4/200 [Train Ensemble]: 100%|██████████| 2739/2739 [33:11<00:00,  1.38it/s, loss=0.0468]
Epoch 4/200 [Val Ensemble]: 100%|██████████| 4/4 [00:06<00:00,  1.62s/it]


Epoch [4/200], Ensemble Train Loss: 0.0453, Ensemble Train PSNR (Avg): 30.5443, Ensemble Val PSNR (Avg): 30.5470
Ensemble Validation PSNR improved. Saved checkpoint to ensemble_checkpoints/best_ensemble_model.pth


Epoch 5/200 [Train Ensemble]: 100%|██████████| 2739/2739 [33:05<00:00,  1.38it/s, loss=0.0494]
Epoch 5/200 [Val Ensemble]: 100%|██████████| 4/4 [00:06<00:00,  1.63s/it]


Epoch [5/200], Ensemble Train Loss: 0.0449, Ensemble Train PSNR (Avg): 30.5712, Ensemble Val PSNR (Avg): 30.5404


Epoch 6/200 [Train Ensemble]: 100%|██████████| 2739/2739 [33:11<00:00,  1.38it/s, loss=0.0424]
Epoch 6/200 [Val Ensemble]: 100%|██████████| 4/4 [00:06<00:00,  1.68s/it]


Epoch [6/200], Ensemble Train Loss: 0.0446, Ensemble Train PSNR (Avg): 30.5936, Ensemble Val PSNR (Avg): 30.5525
Ensemble Validation PSNR improved. Saved checkpoint to ensemble_checkpoints/best_ensemble_model.pth


Epoch 7/200 [Train Ensemble]: 100%|██████████| 2739/2739 [33:05<00:00,  1.38it/s, loss=0.0404]
Epoch 7/200 [Val Ensemble]: 100%|██████████| 4/4 [00:06<00:00,  1.73s/it]


Epoch [7/200], Ensemble Train Loss: 0.0443, Ensemble Train PSNR (Avg): 30.6140, Ensemble Val PSNR (Avg): 30.5578
Ensemble Validation PSNR improved. Saved checkpoint to ensemble_checkpoints/best_ensemble_model.pth


Epoch 8/200 [Train Ensemble]: 100%|██████████| 2739/2739 [33:20<00:00,  1.37it/s, loss=0.0381]
Epoch 8/200 [Val Ensemble]: 100%|██████████| 4/4 [00:06<00:00,  1.70s/it]


Epoch [8/200], Ensemble Train Loss: 0.0441, Ensemble Train PSNR (Avg): 30.6315, Ensemble Val PSNR (Avg): 30.5723
Ensemble Validation PSNR improved. Saved checkpoint to ensemble_checkpoints/best_ensemble_model.pth


Epoch 9/200 [Train Ensemble]: 100%|██████████| 2739/2739 [33:13<00:00,  1.37it/s, loss=0.0429]
Epoch 9/200 [Val Ensemble]: 100%|██████████| 4/4 [00:06<00:00,  1.70s/it]


Epoch [9/200], Ensemble Train Loss: 0.0438, Ensemble Train PSNR (Avg): 30.6475, Ensemble Val PSNR (Avg): 30.5722


Epoch 10/200 [Train Ensemble]:  73%|███████▎  | 2012/2739 [24:41<08:32,  1.42it/s, loss=0.0432]