In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the full path of every file under the Kaggle input directory,
# in os.walk's top-down traversal order.
for path in (os.path.join(root, name)
             for root, _, names in os.walk('/kaggle/input')
             for name in names):
    print(path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/dlp-jan-2025-nppe-3/sample_submission.csv
/kaggle/input/dlp-jan-2025-nppe-3/submission.py
/kaggle/input/dlp-jan-2025-nppe-3/archive/val/gt/gt_00159.png
/kaggle/input/dlp-jan-2025-nppe-3/archive/val/gt/gt_00056.png
/kaggle/input/dlp-jan-2025-nppe-3/archive/val/gt/gt_00017.png
/kaggle/input/dlp-jan-2025-nppe-3/archive/val/gt/gt_00124.png
/kaggle/input/dlp-jan-2025-nppe-3/archive/val/gt/gt_00140.png
/kaggle/input/dlp-jan-2025-nppe-3/archive/val/gt/gt_00068.png
/kaggle/input/dlp-jan-2025-nppe-3/archive/val/gt/gt_00019.png
/kaggle/input/dlp-jan-2025-nppe-3/archive/val/gt/gt_00266.png
/kaggle/input/dlp-jan-2025-nppe-3/archive/val/gt/gt_00236.png
/kaggle/input/dlp-jan-2025-nppe-3/archive/val/gt/gt_00148.png
/kaggle/input/dlp-jan-2025-nppe-3/archive/val/gt/gt_00152.png
/kaggle/input/dlp-jan-2025-nppe-3/archive/val/gt/gt_00226.png
/kaggle/input/dlp-jan-2025-nppe-3/archive/val/gt/gt_00008.png
/kaggle/input/dlp-jan-2025-nppe-3/archive/val/gt/gt_00216.png
/kaggle/input/dlp-jan-2025-n

In [2]:
# Run this in a terminal or notebook to install the third-party dependencies.
# NOTE(review): versions are unpinned, and timm / pytorch_msssim are not imported by
# any visible cell; `%pip install` would target the running kernel's environment.
!pip install torch torchvision timm pytorch_msssim opencv-python tqdm pandas numpy matplotlib

Collecting pytorch_msssim
  Downloading pytorch_msssim-1.0.0-py3-none-any.whl.metadata (8.0 kB)
Downloading pytorch_msssim-1.0.0-py3-none-any.whl (7.7 kB)
Installing collected packages: pytorch_msssim
Successfully installed pytorch_msssim-1.0.0


In [3]:
import os
import numpy as np
import cv2
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from tqdm import tqdm
import pandas as pd

# Set random seeds for reproducibility.
# torch.manual_seed seeds PyTorch's generators on all devices; numpy's legacy global
# generator and Python's `random` module keep their own state and are seeded separately.
import random

seed = 42
random.seed(seed)
torch.manual_seed(seed)
np.random.seed(seed)
# deterministic=True only restricts cuDNN to deterministic kernels; the autotuner
# (benchmark=True) could still pick different algorithms run to run, so pin it off.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Custom dataset class
class LowLightDataset(Dataset):
    """Paired low-light / ground-truth image dataset read from disk.

    Directory layout under ``root_dir``:
      * ``'train'`` / ``'val'``: inputs in ``<root>/<split>/<split>``, targets in
        ``<root>/<split>/gt``. Inputs and targets are paired by position after
        sorting each directory listing by filename.
      * any other split (test): inputs in ``<root>/test``, no targets.

    Items are float32 CHW tensors in RGB order, scaled to [0, 1]:
      * train/val -> ``(lr_tensor, hr_tensor, lr_filename)``
      * test      -> ``(lr_tensor, lr_filename)``

    Args:
        root_dir: dataset root directory.
        split: ``'train'``, ``'val'``, or anything else for the test split.
        transform: stored as ``self.transform`` but NOT applied by ``__getitem__``.

    Raises:
        ValueError: (train/val) if the input and target directories hold a
            different number of files, since position-based pairing would be wrong.
        FileNotFoundError: from ``__getitem__`` if OpenCV cannot read an image.
    """

    def __init__(self, root_dir, split='train', transform=None):
        self.root_dir = root_dir
        self.split = split
        # NOTE(review): kept for interface compatibility; never applied below.
        self.transform = transform

        if split in ['train', 'val']:
            self.lr_dir = os.path.join(root_dir, split, split)
            self.hr_dir = os.path.join(root_dir, split, 'gt')
            self.lr_files = sorted(os.listdir(self.lr_dir))
            self.hr_files = sorted(os.listdir(self.hr_dir))
            # Pairs are matched by sorted index, so a count mismatch would
            # silently misalign every pair after the first gap.
            if len(self.lr_files) != len(self.hr_files):
                raise ValueError(
                    f"{split}: {len(self.lr_files)} input files in {self.lr_dir} "
                    f"but {len(self.hr_files)} target files in {self.hr_dir}"
                )
        else:  # test: inputs only
            self.lr_dir = os.path.join(root_dir, 'test')
            self.lr_files = sorted(os.listdir(self.lr_dir))
            self.hr_files = None

    def __len__(self):
        return len(self.lr_files)

    @staticmethod
    def _read_rgb(path):
        """Read an image file as an RGB uint8 HWC array, failing loudly if unreadable."""
        img = cv2.imread(path)
        # cv2.imread signals failure by returning None rather than raising.
        if img is None:
            raise FileNotFoundError(f"cv2.imread could not read image: {path}")
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    @staticmethod
    def _to_tensor(img):
        """uint8 HWC array -> float32 CHW tensor scaled to [0, 1]."""
        return torch.from_numpy((img.astype(np.float32) / 255.0).transpose(2, 0, 1))

    def __getitem__(self, idx):
        lr_name = self.lr_files[idx]
        lr_tensor = self._to_tensor(self._read_rgb(os.path.join(self.lr_dir, lr_name)))

        if self.split in ['train', 'val']:
            hr_path = os.path.join(self.hr_dir, self.hr_files[idx])
            hr_tensor = self._to_tensor(self._read_rgb(hr_path))
            return lr_tensor, hr_tensor, lr_name
        return lr_tensor, lr_name

# Data visualization function
def visualize_samples(dataset, num_samples=2):
    """Plot the first ``num_samples`` (input, ground-truth) pairs of ``dataset``.

    Each item must be a ``(lr_tensor, hr_tensor, name)`` triple of CHW image
    tensors (what LowLightDataset returns for the train/val splits); the test
    split's 2-tuples are not supported.

    Args:
        dataset: indexable paired dataset.
        num_samples: number of pairs (figure rows) to draw.
    """
    # squeeze=False keeps `axes` 2-D even when num_samples == 1; otherwise
    # plt.subplots returns a 1-D array and axes[i, 0] raises IndexError.
    fig, axes = plt.subplots(num_samples, 2, figsize=(12, 8), squeeze=False)

    for i in range(num_samples):
        lr_img, hr_img, _ = dataset[i]
        panels = ((axes[i, 0], lr_img, 'Low Resolution'),
                  (axes[i, 1], hr_img, 'High Resolution'))
        for ax, img, title in panels:
            # CHW tensor -> HWC array for imshow.
            ax.imshow(img.numpy().transpose(1, 2, 0))
            ax.set_title(title)
            ax.axis('off')

    plt.tight_layout()
    plt.show()

# Root of the competition archive; the train/val/test splits live underneath it.
DATA_ROOT = '/kaggle/input/dlp-jan-2025-nppe-3/archive'

# Create datasets
train_dataset = LowLightDataset(root_dir=DATA_ROOT, split='train')
val_dataset = LowLightDataset(root_dir=DATA_ROOT, split='val')
test_dataset = LowLightDataset(root_dir=DATA_ROOT, split='test')

# Visualize training samples
# visualize_samples(train_dataset)

# Create dataloaders. Validation and test use batch_size=1, so per-batch metrics
# in the training loop are per-image metrics.
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=2, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=2, pin_memory=True)

In [None]:
def psnr(pred, target, max_val=1.0, eps=1e-10):
    """Peak signal-to-noise ratio in dB between ``pred`` and ``target``.

    If the shapes differ, ``pred`` is bilinearly resized to ``target``'s spatial
    size first (4-D NCHW tensors are required for that path). The MSE is pooled
    over the whole batch, so for batch size > 1 this is the PSNR of the mean MSE,
    not the mean of per-image PSNRs.

    Args:
        pred: predicted tensor.
        target: reference tensor.
        max_val: peak signal value (1.0 for images scaled to [0, 1]).
        eps: lower bound on the MSE, so identical inputs give a large finite
            value (100 dB for the defaults) instead of +inf.

    Returns:
        0-dim tensor with the PSNR in dB.
    """
    if pred.shape != target.shape:
        pred = F.interpolate(pred, size=target.shape[2:], mode='bilinear', align_corners=False)

    mse = F.mse_loss(pred, target)
    # Without the clamp a perfect prediction returns inf, which poisons any running average.
    return 10 * torch.log10(max_val ** 2 / mse.clamp_min(eps))

def combined_loss(pred, target):
    """Training loss: mean absolute error (L1) between ``pred`` and ``target``.

    Despite the name, only an L1 term is used (L1 tends to track PSNR well).
    When shapes differ, ``pred`` is first bilinearly resized to ``target``'s
    spatial size.
    """
    resized = pred
    if resized.shape != target.shape:
        resized = F.interpolate(resized, size=target.shape[2:], mode='bilinear', align_corners=False)
    return F.l1_loss(resized, target)

In [5]:
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# from torchvision.models import resnet18, ResNet18_Weights
# import math # For calculating padding in ResidualBlock

# # --- Helper Blocks ---

# class SEBlock(nn.Module):
#     """ Squeeze-and-Excitation Block """
#     def __init__(self, channel, reduction=16):
#         super(SEBlock, self).__init__()
#         self.avg_pool = nn.AdaptiveAvgPool2d(1)
#         self.fc = nn.Sequential(
#             nn.Linear(channel, channel // reduction, bias=False),
#             nn.ReLU(inplace=True),
#             nn.Linear(channel // reduction, channel, bias=False),
#             nn.Sigmoid()
#         )

#     def forward(self, x):
#         b, c, _, _ = x.size()
#         y = self.avg_pool(x).view(b, c)
#         y = self.fc(y).view(b, c, 1, 1)
#         return x * y.expand_as(x)

# class ResidualBlock(nn.Module):
#     """ Basic Residual Block """
#     def __init__(self, channels, use_batchnorm=True):
#         super(ResidualBlock, self).__init__()
#         self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=not use_batchnorm)
#         self.bn1 = nn.BatchNorm2d(channels) if use_batchnorm else nn.Identity()
#         self.relu = nn.ReLU(inplace=True)
#         self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=not use_batchnorm)
#         self.bn2 = nn.BatchNorm2d(channels) if use_batchnorm else nn.Identity()

#     def forward(self, x):
#         residual = x
#         out = self.relu(self.bn1(self.conv1(x)))
#         out = self.bn2(self.conv2(out))
#         out += residual # Add skip connection
#         out = self.relu(out) # Apply ReLU *after* addition
#         return out

# class ResidualDecoderBlockSE(nn.Module):
#     """Upsamples, concatenates skip, adjusts channels, applies Residual Blocks and SE."""
#     def __init__(self, in_channels_up, in_channels_skip, out_channels, num_res_blocks=2, use_batchnorm=True, use_se=True):
#         super().__init__()
        
#         # Upsampling using PixelShuffle
#         self.upsample_conv = nn.Conv2d(in_channels_up, out_channels * 4, kernel_size=1, padding=0)
#         self.pixel_shuffle = nn.PixelShuffle(2)
#         # Channels after pixel_shuffle = out_channels

#         # Convolution to adjust channels after concatenation
#         concat_channels = out_channels + in_channels_skip
#         self.conv_adjust = nn.Conv2d(concat_channels, out_channels, kernel_size=1, padding=0) # 1x1 conv for channel adjustment
#         self.relu_adjust = nn.ReLU(inplace=True)

#         # Residual Blocks
#         res_blocks = []
#         for _ in range(num_res_blocks):
#             res_blocks.append(ResidualBlock(out_channels, use_batchnorm=use_batchnorm))
#         self.res_blocks = nn.Sequential(*res_blocks)

#         # Optional SE Block
#         self.se = SEBlock(out_channels) if use_se else nn.Identity()

#     def forward(self, x_up, x_skip):
#         x_up = self.upsample_conv(x_up)
#         x_up = self.pixel_shuffle(x_up)

#         # Ensure spatial dimensions match for concatenation
#         if x_up.shape[2:] != x_skip.shape[2:]:
#              x_up = F.interpolate(x_up, size=x_skip.shape[2:], mode='bilinear', align_corners=False)

#         x = torch.cat([x_up, x_skip], dim=1)

#         # Adjust channels and apply residual blocks + SE
#         x = self.relu_adjust(self.conv_adjust(x))
#         x = self.res_blocks(x)
#         x = self.se(x)
#         return x

# class UpsampleBlockSR(nn.Module):
#     """ Upsampling block for the SR head """
#     def __init__(self, in_channels, out_channels, use_batchnorm=True):
#          super().__init__()
#          # Use Conv -> PixelShuffle -> ReLU pattern
#          self.conv = nn.Conv2d(in_channels, out_channels * 4, kernel_size=3, padding=1, bias=not use_batchnorm)
#          self.ps = nn.PixelShuffle(2)
#          self.bn = nn.BatchNorm2d(out_channels) if use_batchnorm else nn.Identity()
#          self.relu = nn.ReLU(inplace=True)

#     def forward(self, x):
#         x = self.conv(x)
#         x = self.ps(x)
#         x = self.relu(self.bn(x))
#         return x


# # --- Main Model V2 ---

# class UNetSR4x_v2(nn.Module):
#     def __init__(self, pretrained=True, use_batchnorm_decoder=True, use_se_decoder=True, decoder_res_blocks=2, bottleneck_res_blocks=2):
#         super().__init__()

#         # --- Encoder (ResNet-18 based) ---
#         weights = ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
#         resnet = resnet18(weights=weights)

#         self.encoder_init_conv = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu) # 64, H/2, W/2
#         self.encoder_pool = resnet.maxpool # 64, H/4, W/4
#         self.encoder_layer1 = resnet.layer1 # 64, H/4, W/4
#         self.encoder_layer2 = resnet.layer2 # 128, H/8, W/8
#         self.encoder_layer3 = resnet.layer3 # 256, H/16, W/16
#         self.encoder_layer4 = resnet.layer4 # 512, H/32, W/32

#         # --- Bottleneck (with Residual Blocks) ---
#         bottleneck_layers = [nn.Conv2d(512, 512, kernel_size=1)] # Adjust channels if needed first
#         for _ in range(bottleneck_res_blocks):
#              bottleneck_layers.append(ResidualBlock(512, use_batchnorm=use_batchnorm_decoder))
#         # Optional: Add SE block in bottleneck
#         # bottleneck_layers.append(SEBlock(512))
#         self.bottleneck = nn.Sequential(*bottleneck_layers)


#         # --- Decoder (Using ResidualDecoderBlockSE) ---
#         # Channels: up_in, skip_in, out_ch
#         self.decoder_layer4 = ResidualDecoderBlockSE(512, 256, 256, num_res_blocks=decoder_res_blocks, use_batchnorm=use_batchnorm_decoder, use_se=use_se_decoder) # H/16
#         self.decoder_layer3 = ResidualDecoderBlockSE(256, 128, 128, num_res_blocks=decoder_res_blocks, use_batchnorm=use_batchnorm_decoder, use_se=use_se_decoder) # H/8
#         self.decoder_layer2 = ResidualDecoderBlockSE(128, 64, 64, num_res_blocks=decoder_res_blocks, use_batchnorm=use_batchnorm_decoder, use_se=use_se_decoder)   # H/4
#         self.decoder_layer1 = ResidualDecoderBlockSE(64, 64, 64, num_res_blocks=decoder_res_blocks, use_batchnorm=use_batchnorm_decoder, use_se=use_se_decoder)    # H/2
#         self.decoder_init = ResidualDecoderBlockSE(64, 64, 64, num_res_blocks=decoder_res_blocks, use_batchnorm=use_batchnorm_decoder, use_se=use_se_decoder)     # H/1

#         # # --- Super-Resolution Head (Upsample 2 more times for 4x total) ---
#         # # Using the simpler UpsampleBlockSR here, could also use Residual Blocks
#         # self.upsampler_sr = nn.Sequential(
#         #     UpsampleBlockSR(64, 32, use_batchnorm=use_batchnorm_decoder), # -> 2x Res (32 channels)
#         #     UpsampleBlockSR(32, 16, use_batchnorm=use_batchnorm_decoder)  # -> 4x Res (16 channels)
#         # )

#         # --- Super-Resolution Head (Upsample 2 more times for 4x total) ---
#         self.upsampler_sr = nn.Sequential(
#             UpsampleBlockSR(64, 32, use_batchnorm=use_batchnorm_decoder), # -> 2x Res (32 channels)
#             UpsampleBlockSR(32, 16, use_batchnorm=use_batchnorm_decoder)  # -> 4x Res (16 channels)  <--- THIS IS THE SECOND 2x UPSCALE
#         )

#         # --- Final Output Layer ---
#         self.final_conv = nn.Conv2d(16, 3, kernel_size=3, padding=1)

#     def forward(self, x):
#         # --- Encoder ---
#         skip_init = self.encoder_init_conv(x)       # H/2, 64
#         pooled = self.encoder_pool(skip_init)       # H/4, 64
#         skip_layer1 = self.encoder_layer1(pooled)   # H/4, 64
#         skip_layer2 = self.encoder_layer2(skip_layer1) # H/8, 128
#         skip_layer3 = self.encoder_layer3(skip_layer2) # H/16, 256
#         encoded = self.encoder_layer4(skip_layer3)   # H/32, 512

#         # --- Bottleneck ---
#         bottleneck = self.bottleneck(encoded)

#         # --- Decoder ---
#         d4 = self.decoder_layer4(bottleneck, skip_layer3) # H/16, 256
#         d3 = self.decoder_layer3(d4, skip_layer2)         # H/8, 128
#         d2 = self.decoder_layer2(d3, skip_layer1)         # H/4, 64
#         d1 = self.decoder_layer1(d2, skip_init)           # H/2, 64 # Using skip_init (after conv+bn+relu)
#         d0 = self.decoder_init(d1, skip_init)             # H/1, 64 # Re-using skip_init here

#         # --- Super-Resolution Head ---
#         up_sr = self.upsampler_sr(d0) # 4x Res, 16 channels

#         # --- Final Output ---
#         output = self.final_conv(up_sr)

#         # Consider if sigmoid is the best choice. If targets are [0,1], it's okay.
#         # If using losses like L1/L2 on non-normalized data, remove activation.
#         # Tanh might be used if targets are [-1, 1].
#         return torch.sigmoid(output)

# # --- Initialization and Summary ---
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# # Instantiate the new model
# model = UNetSR4x_v2(
#     pretrained=True,
#     use_batchnorm_decoder=True, # BatchNorm often helps stabilization
#     use_se_decoder=True,        # Use Channel Attention
#     decoder_res_blocks=2,       # Number of residual blocks per decoder stage
#     bottleneck_res_blocks=2     # Number of residual blocks in the bottleneck
# ).to(device)

# # Print model summary
# def count_parameters(model):
#     return sum(p.numel() for p in model.parameters() if p.requires_grad)

# print(f"Model V2 has {count_parameters(model):,} trainable parameters")

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet50, ResNet50_Weights

# --- Helper Blocks remain the same ---
class SEBlock(nn.Module):
    """Squeeze-and-Excitation channel attention.

    Global-average-pools each channel, passes the pooled vector through a
    bias-free bottleneck MLP (channel -> channel // reduction -> channel) with
    ReLU then sigmoid, and rescales every input channel by its resulting weight.
    Expects a 4-D (N, C, H, W) input.
    """
    def __init__(self, channel, reduction=16):
        super().__init__()
        hidden = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        batch, channels, _, _ = x.size()
        squeezed = self.avg_pool(x).view(batch, channels)
        gates = self.fc(squeezed).view(batch, channels, 1, 1)
        return x * gates.expand_as(x)

class ResidualBlock(nn.Module):
    """Two 3x3 same-size convolutions with an identity skip: relu(x + F(x)).

    F(x) = bn2(conv2(relu(bn1(conv1(x))))). When ``use_batchnorm`` is False the
    BN layers are identities and the convolutions get a bias instead.
    """
    def __init__(self, channels, use_batchnorm=True):
        super().__init__()
        # A conv bias is redundant when BN's affine shift follows it.
        conv_bias = not use_batchnorm
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=conv_bias)
        self.bn1 = nn.BatchNorm2d(channels) if use_batchnorm else nn.Identity()
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=conv_bias)
        self.bn2 = nn.BatchNorm2d(channels) if use_batchnorm else nn.Identity()

    def forward(self, x):
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))
        # ReLU is applied after the skip addition.
        return self.relu(branch + x)

class ResidualDecoderBlockSE(nn.Module):
    """Decoder stage: 2x PixelShuffle upsample, fuse with a skip, refine.

    Steps: 1x1 conv to ``out_channels * 4`` then PixelShuffle(2) on ``x_up``;
    bilinearly resize the result to ``x_skip``'s spatial size if they differ;
    concatenate with ``x_skip``; 1x1 conv + ReLU back to ``out_channels``; then
    ``num_res_blocks`` ResidualBlocks and an optional SEBlock.
    The output therefore always has ``x_skip``'s height and width.
    """
    def __init__(self, in_channels_up, in_channels_skip, out_channels, num_res_blocks=2, use_batchnorm=True, use_se=True):
        super().__init__()
        self.upsample_conv = nn.Conv2d(in_channels_up, out_channels * 4, kernel_size=1, padding=0)
        self.pixel_shuffle = nn.PixelShuffle(2)
        self.conv_adjust = nn.Conv2d(out_channels + in_channels_skip, out_channels, kernel_size=1, padding=0)
        self.relu_adjust = nn.ReLU(inplace=True)
        self.res_blocks = nn.Sequential(
            *(ResidualBlock(out_channels, use_batchnorm=use_batchnorm) for _ in range(num_res_blocks))
        )
        self.se = SEBlock(out_channels) if use_se else nn.Identity()

    def forward(self, x_up, x_skip):
        up = self.pixel_shuffle(self.upsample_conv(x_up))
        skip_hw = x_skip.shape[2:]
        if up.shape[2:] != skip_hw:
            up = F.interpolate(up, size=skip_hw, mode='bilinear', align_corners=False)
        fused = self.relu_adjust(self.conv_adjust(torch.cat([up, x_skip], dim=1)))
        return self.se(self.res_blocks(fused))

class UpsampleBlockSR(nn.Module):
    """2x upsampling stage of the SR head: 3x3 conv -> PixelShuffle(2) -> BN -> ReLU.

    The conv produces ``out_channels * 4`` maps, which PixelShuffle rearranges
    into ``out_channels`` maps at twice the height and width.
    """
    def __init__(self, in_channels, out_channels, use_batchnorm=True):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels * 4, kernel_size=3, padding=1, bias=not use_batchnorm)
        self.ps = nn.PixelShuffle(2)
        self.bn = nn.BatchNorm2d(out_channels) if use_batchnorm else nn.Identity()
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn(self.ps(self.conv(x))))

# --- Updated Main Model using ResNet-50 ---
class UNetSR4x_v2_ResNet50(nn.Module):
    """U-Net style image-to-image network: torchvision ResNet-50 encoder,
    PixelShuffle decoder with skip connections, a two-stage PixelShuffle head,
    and a sigmoid output (3 channels, values in (0, 1)).

    Args:
        pretrained: initialise the encoder from ``ResNet50_Weights.IMAGENET1K_V2``.
        use_batchnorm_decoder: use BatchNorm in bottleneck, decoder and SR-head blocks.
        use_se_decoder: end each decoder stage with a Squeeze-and-Excitation block.
        decoder_res_blocks: ResidualBlocks per decoder stage.
        bottleneck_res_blocks: ResidualBlocks in the bottleneck.

    NOTE(review): despite "4x" in the name, tracing the shapes gives an output of
    2x the input height/width (for sides divisible by 32). ``decoder_init``
    upsamples d1 (H/2) to H, but ResidualDecoderBlockSE resizes the upsampled
    tensor to its skip's size and the skip here is ``skip_init`` (H/2), so d0 stays
    at H/2. The two UpsampleBlockSR stages then give 2H. Callers in this notebook
    bilinearly resize the output to the target size, which hides this. Fixing it
    needs a full-resolution skip and would change the checkpoint layout.

    NOTE(review): the dataset scales inputs to [0, 1] but does not apply ImageNet
    mean/std normalisation expected by the pretrained encoder — confirm intent.
    """
    def __init__(self, pretrained=True, use_batchnorm_decoder=True, use_se_decoder=True, decoder_res_blocks=2, bottleneck_res_blocks=2):
        super().__init__()

        # --- Encoder (ResNet-50 based) ---
        weights = ResNet50_Weights.IMAGENET1K_V2 if pretrained else None
        resnet = resnet50(weights=weights)

        # Stem (conv1 stride 2 + BN + ReLU); its output doubles as the highest-resolution skip.
        self.encoder_init_conv = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu)   # 64, H/2, W/2
        self.encoder_pool = resnet.maxpool  # 64, H/4, W/4
        self.encoder_layer1 = resnet.layer1  # 256, H/4, W/4
        self.encoder_layer2 = resnet.layer2  # 512, H/8, W/8
        self.encoder_layer3 = resnet.layer3  # 1024, H/16, W/16
        self.encoder_layer4 = resnet.layer4  # 2048, H/32, W/32

        # --- Bottleneck ---
        # 1x1 conv reduces 2048 -> 512 channels, the decoder's expected input width.
        self.channel_reduction = nn.Conv2d(2048, 512, kernel_size=1)
        bottleneck_layers = [ResidualBlock(512, use_batchnorm=use_batchnorm_decoder) for _ in range(bottleneck_res_blocks)]
        self.bottleneck = nn.Sequential(*bottleneck_layers)

        # --- Decoder (Using ResidualDecoderBlockSE) ---
        # Each wide ResNet-50 skip is narrowed with a 1x1 conv before concatenation.
        # skip_layer3: 1024 channels -> 256
        self.reduce_skip3 = nn.Conv2d(1024, 256, kernel_size=1)
        self.decoder_layer4 = ResidualDecoderBlockSE(512, 256, 256, num_res_blocks=decoder_res_blocks, use_batchnorm=use_batchnorm_decoder, use_se=use_se_decoder)
        # skip_layer2: 512 channels -> 128
        self.reduce_skip2 = nn.Conv2d(512, 128, kernel_size=1)
        self.decoder_layer3 = ResidualDecoderBlockSE(256, 128, 128, num_res_blocks=decoder_res_blocks, use_batchnorm=use_batchnorm_decoder, use_se=use_se_decoder)
        # skip_layer1: 256 channels -> 64
        self.reduce_skip1 = nn.Conv2d(256, 64, kernel_size=1)
        self.decoder_layer2 = ResidualDecoderBlockSE(128, 64, 64, num_res_blocks=decoder_res_blocks, use_batchnorm=use_batchnorm_decoder, use_se=use_se_decoder)
        # skip_init (64 channels) is used unchanged by both of the last two stages.
        self.decoder_layer1 = ResidualDecoderBlockSE(64, 64, 64, num_res_blocks=decoder_res_blocks, use_batchnorm=use_batchnorm_decoder, use_se=use_se_decoder)
        self.decoder_init = ResidualDecoderBlockSE(64, 64, 64, num_res_blocks=decoder_res_blocks, use_batchnorm=use_batchnorm_decoder, use_se=use_se_decoder)

        # --- Super-Resolution Head: two 2x PixelShuffle stages (64 -> 32 -> 16 channels) ---
        self.upsampler_sr = nn.Sequential(
            UpsampleBlockSR(64, 32, use_batchnorm=use_batchnorm_decoder),
            UpsampleBlockSR(32, 16, use_batchnorm=use_batchnorm_decoder)
        )

        # Project to 3 output channels.
        self.final_conv = nn.Conv2d(16, 3, kernel_size=3, padding=1)

    def forward(self, x):
        # --- Encoder ---
        skip_init = self.encoder_init_conv(x)        # H/2, 64
        pooled = self.encoder_pool(skip_init)          # H/4, 64
        skip_layer1 = self.encoder_layer1(pooled)        # H/4, 256
        skip_layer2 = self.encoder_layer2(skip_layer1)   # H/8, 512
        skip_layer3 = self.encoder_layer3(skip_layer2)   # H/16, 1024
        encoded = self.encoder_layer4(skip_layer3)       # H/32, 2048

        # --- Bottleneck ---
        reduced = self.channel_reduction(encoded)        # H/32, 512
        bottleneck = self.bottleneck(reduced)            # H/32, 512

        # --- Decoder ---
        # Narrow each skip, then upsample-and-fuse; every stage's output takes its skip's size.
        skip3 = self.reduce_skip3(skip_layer3)           # H/16, 256
        d4 = self.decoder_layer4(bottleneck, skip3)        # H/16, 256
        skip2 = self.reduce_skip2(skip_layer2)           # H/8, 128
        d3 = self.decoder_layer3(d4, skip2)                # H/8, 128
        skip1 = self.reduce_skip1(skip_layer1)           # H/4, 64
        d2 = self.decoder_layer2(d3, skip1)                # H/4, 64
        d1 = self.decoder_layer1(d2, skip_init)            # H/2, 64
        d0 = self.decoder_init(d1, skip_init)              # H/2, 64 (upsampled to H, then resized back to skip_init's H/2)

        # --- Super-Resolution Head ---
        up_sr = self.upsampler_sr(d0)                    # 2H, 16 (two 2x stages from H/2)
        output = self.final_conv(up_sr)                  # 2H, 3
        # Sigmoid maps to (0, 1), matching targets scaled by 1/255 in the dataset.
        return torch.sigmoid(output)

# --- Example usage ---
# Run on the GPU when one is visible, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

model = UNetSR4x_v2_ResNet50(
    pretrained=True,              # ImageNet-initialised ResNet-50 encoder
    use_batchnorm_decoder=True,
    use_se_decoder=True,
    decoder_res_blocks=2,
    bottleneck_res_blocks=2,
)
model = model.to(device)

def count_parameters(model):
    """Return the total number of elements across ``model``'s trainable (requires_grad) parameters."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

n_trainable = count_parameters(model)
print(f"Model with ResNet-50 backbone has {n_trainable:,} trainable parameters")

In [32]:
def _train_one_epoch(model, loader, optimizer, device, desc):
    """Run one optimisation pass over ``loader``; return the mean per-batch loss."""
    model.train()
    total_loss = 0.0
    with tqdm(loader, desc=desc) as pbar:
        for lr_imgs, hr_imgs, _ in pbar:
            lr_imgs, hr_imgs = lr_imgs.to(device), hr_imgs.to(device)

            optimizer.zero_grad()
            outputs = model(lr_imgs)
            # The model's output size need not equal the target's; resize before the loss.
            if outputs.shape != hr_imgs.shape:
                outputs = F.interpolate(outputs, size=hr_imgs.shape[2:], mode='bilinear', align_corners=False)

            loss = combined_loss(outputs, hr_imgs)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            pbar.set_postfix({'loss': loss.item()})
    # max(..., 1) avoids ZeroDivisionError on an empty loader.
    return total_loss / max(len(loader), 1)


def _validate(model, loader, device, desc):
    """Return the mean per-batch PSNR (dB) of ``model`` on ``loader`` (no gradients)."""
    model.eval()
    total_psnr = 0.0
    with torch.no_grad():
        for lr_imgs, hr_imgs, _ in tqdm(loader, desc=desc):
            lr_imgs, hr_imgs = lr_imgs.to(device), hr_imgs.to(device)
            outputs = model(lr_imgs)
            if outputs.shape != hr_imgs.shape:
                outputs = F.interpolate(outputs, size=hr_imgs.shape[2:], mode='bilinear', align_corners=False)
            total_psnr += psnr(outputs, hr_imgs).item()
    return total_psnr / max(len(loader), 1)


def train_model(model, train_loader, val_loader, num_epochs=5, learning_rate=1e-4, best_model_path='best_model.pth'):
    """Train ``model`` with Adam and checkpoint the weights with the best validation PSNR.

    Each epoch runs one training pass (``combined_loss``) and one validation pass
    (mean per-batch ``psnr``). The LR is halved by ReduceLROnPlateau after
    3 epochs without PSNR improvement.

    Args:
        model: module to train; its first parameter's device is used for data.
        train_loader: yields ``(lr_imgs, hr_imgs, names)`` training batches.
        val_loader: yields ``(lr_imgs, hr_imgs, names)`` validation batches.
        num_epochs: number of epochs.
        learning_rate: initial Adam learning rate.
        best_model_path: where the best ``state_dict`` is saved.

    Returns:
        ``best_model_path``. A checkpoint is written after the first epoch unless
        validation PSNR is NaN, and overwritten whenever PSNR improves.
    """
    device = next(model.parameters()).device
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # `verbose` is deprecated in torch.optim.lr_scheduler; LR reductions are logged below instead.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3)

    # -inf (not 0.0) so the first finite validation score always produces a checkpoint.
    best_psnr = float('-inf')

    for epoch in range(num_epochs):
        tag = f'Epoch {epoch+1}/{num_epochs}'
        avg_train_loss = _train_one_epoch(model, train_loader, optimizer, device, f'{tag} [Train]')
        avg_val_psnr = _validate(model, val_loader, device, f'{tag} [Val]')

        print(f"{tag} - Train Loss: {avg_train_loss:.4f}, Val PSNR: {avg_val_psnr:.2f} dB")

        # Update learning rate based on validation PSNR
        prev_lr = optimizer.param_groups[0]['lr']
        scheduler.step(avg_val_psnr)
        new_lr = optimizer.param_groups[0]['lr']
        if new_lr < prev_lr:
            print(f"Reducing learning rate: {prev_lr:.2e} -> {new_lr:.2e}")

        # Save the best model
        if avg_val_psnr > best_psnr:
            best_psnr = avg_val_psnr
            torch.save(model.state_dict(), best_model_path)
            print(f"Model saved with PSNR: {best_psnr:.2f} dB")

    return best_model_path

# Train the model; returns the path of the checkpoint with the best validation PSNR.
NUM_EPOCHS = 70
best_model_path = train_model(model, train_loader, val_loader, num_epochs=NUM_EPOCHS)

In [9]:
import torch
import torch.nn.functional as F
import numpy as np
import cv2 # Make sure to import OpenCV
import os
from tqdm import tqdm

# --- Helper for TTA ---
def augment_and_predict(model, img_tensor, use_tta=False):
    """Predict on ``img_tensor``, optionally with flip test-time augmentation.

    Args:
        model: callable mapping an (N, C, H, W) batch to an (N, C, H', W') batch.
        img_tensor: input batch, (N, C, H, W).
        use_tta: if False (default), return ``model(img_tensor)`` unchanged.
            If True, also predict on horizontally and vertically flipped copies,
            undo each flip on its prediction, and return the mean of the three.

    Returns:
        The prediction tensor.
    """
    pred = model(img_tensor)
    if not use_tta:
        return pred

    predictions = [pred]
    # dims=[3] flips width, dims=[2] flips height. Flipping each prediction back
    # re-aligns it with the un-augmented output before averaging.
    for dims in ([3], [2]):
        flipped_pred = model(torch.flip(img_tensor, dims=dims))
        predictions.append(torch.flip(flipped_pred, dims=dims))
    return torch.stack(predictions).mean(dim=0)


# --- Main Prediction Function ---
def _load_weights(model, model_path, device):
    """Load a state_dict from ``model_path`` into ``model``.

    Tries ``torch.load(..., weights_only=True)`` first, which avoids executing
    pickled code. Falls back to a plain ``torch.load`` only when that keyword is
    rejected with TypeError (older PyTorch).
    """
    try:
        state = torch.load(model_path, map_location=device, weights_only=True)
        mode = 'weights_only=True'
    except TypeError:  # older PyTorch: unknown keyword `weights_only`
        state = torch.load(model_path, map_location=device)
        mode = 'standard loading'
    model.load_state_dict(state)
    print(f"Loaded model weights from {model_path} ({mode})")


def _to_uint8_hwc(image_chw):
    """(C, H, W) float tensor in [0, 1] -> (H, W, C) uint8 array, rounded to nearest."""
    arr = image_chw.detach().cpu().numpy().transpose(1, 2, 0)
    # Round rather than truncate: astype(uint8) alone floors, biasing every pixel
    # down by ~0.5 on average and costing PSNR.
    return np.clip(np.rint(arr * 255.0), 0, 255).astype(np.uint8)


def _save_rgb(path, image_hwc):
    """Write an RGB uint8 (H, W, 3) array with OpenCV; report (don't raise) on failure."""
    try:
        # OpenCV expects BGR channel order.
        ok = cv2.imwrite(path, cv2.cvtColor(image_hwc, cv2.COLOR_RGB2BGR))
    except Exception as e:
        print(f"Error saving image {path}: {e}")
        return
    # cv2.imwrite reports many failures (unwritable path, unknown extension) by
    # returning False instead of raising.
    if not ok:
        print(f"Error saving image {path}: cv2.imwrite returned False")


def run_prediction(model, model_path, test_loader, upscale_factor=4, output_dir='predictions'):
    """Load a checkpoint, predict every test image, and save each prediction as an image.

    Args:
        model: network to run; its first parameter's device is used.
        model_path: checkpoint (state_dict) to load into ``model``.
        test_loader: yields ``(images, filenames)`` batches, or ``(images,)``
            batches, in which case files are named ``test_img_XXXX.png`` by image index.
        upscale_factor: expected output/input size ratio. Outputs of any other
            size are bilinearly resized to ``input_size * upscale_factor``.
        output_dir: directory that receives the predictions (created if missing).

    Returns:
        None. If loading the weights fails, the error is printed and no predictions
        are written (best-effort; files already in ``output_dir`` are left as-is).
    """
    device = next(model.parameters()).device
    try:
        _load_weights(model, model_path, device)
    except Exception as e:
        print(f"Error loading model weights: {e}")
        return  # Exit if model loading fails

    model.eval()
    os.makedirs(output_dir, exist_ok=True)
    print(f"Saving predictions to: {output_dir}")

    n_images = 0  # running image count, used for names when the loader gives none
    with torch.no_grad():
        for batch_data in tqdm(test_loader, desc="Generating predictions"):
            if len(batch_data) == 2:
                lr_imgs, filenames = batch_data
            elif len(batch_data) == 1:  # loader returns images only
                lr_imgs = batch_data[0]
                filenames = [f"test_img_{n_images + i:04d}.png" for i in range(lr_imgs.shape[0])]
            else:
                print(f"Unexpected data format from DataLoader: {batch_data}")
                continue

            lr_imgs = lr_imgs.to(device)
            outputs = augment_and_predict(model, lr_imgs)

            target_h = lr_imgs.shape[2] * upscale_factor
            target_w = lr_imgs.shape[3] * upscale_factor
            if outputs.shape[2] != target_h or outputs.shape[3] != target_w:
                print(f"Warning: Output shape {outputs.shape[2:]} doesn't match target {target_h}x{target_w}. Interpolating.")
                outputs = F.interpolate(outputs, size=(target_h, target_w), mode='bilinear', align_corners=False)

            # Save every image in the batch (previously only the first was kept).
            for output, filename in zip(outputs, filenames):
                _save_rgb(os.path.join(output_dir, filename), _to_uint8_hwc(output))
            n_images += lr_imgs.shape[0]

    print(f"\nAll prediction images saved to '{output_dir}' folder")

In [10]:
# Write 4x-sized predictions for the whole test set to ./predictions.
run_prediction(model=model, model_path=best_model_path, test_loader=test_loader, upscale_factor=4)

Loaded model weights from best_model.pth (weights_only=True)
Saving predictions to: predictions


Generating predictions:   0%|          | 0/60 [00:00<?, ?it/s]



Generating predictions:   2%|▏         | 1/60 [00:00<00:14,  3.96it/s]



Generating predictions:   5%|▌         | 3/60 [00:00<00:06,  8.60it/s]



Generating predictions:   8%|▊         | 5/60 [00:00<00:04, 11.09it/s]



Generating predictions:  12%|█▏        | 7/60 [00:00<00:04, 12.87it/s]



Generating predictions:  15%|█▌        | 9/60 [00:00<00:03, 13.97it/s]



Generating predictions:  18%|█▊        | 11/60 [00:00<00:03, 14.45it/s]



Generating predictions:  22%|██▏       | 13/60 [00:01<00:03, 14.87it/s]



Generating predictions:  25%|██▌       | 15/60 [00:01<00:02, 15.27it/s]



Generating predictions:  28%|██▊       | 17/60 [00:01<00:02, 15.38it/s]



Generating predictions:  32%|███▏      | 19/60 [00:01<00:02, 15.27it/s]



Generating predictions:  35%|███▌      | 21/60 [00:01<00:02, 15.51it/s]



Generating predictions:  38%|███▊      | 23/60 [00:01<00:02, 15.61it/s]



Generating predictions:  42%|████▏     | 25/60 [00:01<00:02, 15.54it/s]



Generating predictions:  45%|████▌     | 27/60 [00:01<00:02, 15.44it/s]



Generating predictions:  48%|████▊     | 29/60 [00:02<00:02, 15.48it/s]



Generating predictions:  52%|█████▏    | 31/60 [00:02<00:01, 15.41it/s]



Generating predictions:  55%|█████▌    | 33/60 [00:02<00:01, 15.68it/s]



Generating predictions:  58%|█████▊    | 35/60 [00:02<00:01, 15.82it/s]



Generating predictions:  62%|██████▏   | 37/60 [00:02<00:01, 15.70it/s]



Generating predictions:  65%|██████▌   | 39/60 [00:02<00:01, 15.67it/s]



Generating predictions:  68%|██████▊   | 41/60 [00:02<00:01, 15.63it/s]



Generating predictions:  72%|███████▏  | 43/60 [00:02<00:01, 15.22it/s]



Generating predictions:  75%|███████▌  | 45/60 [00:03<00:00, 15.18it/s]



Generating predictions:  78%|███████▊  | 47/60 [00:03<00:00, 14.94it/s]



Generating predictions:  82%|████████▏ | 49/60 [00:03<00:00, 14.97it/s]



Generating predictions:  85%|████████▌ | 51/60 [00:03<00:00, 14.87it/s]



Generating predictions:  88%|████████▊ | 53/60 [00:03<00:00, 15.17it/s]



Generating predictions:  95%|█████████▌| 57/60 [00:03<00:00, 15.90it/s]



Generating predictions: 100%|██████████| 60/60 [00:04<00:00, 14.72it/s]


All prediction images saved to 'predictions' folder





In [11]:
import numpy as np
import pandas as pd
from PIL import Image

def images_to_csv(folder_path, output_csv):
    """Write a submission CSV with one row per image file in ``folder_path``.

    Each image is converted to 8-bit grayscale (PIL mode 'L'), flattened
    row-major, and every 8th pixel is kept. The row ID is the filename stem with
    'test_' replaced by 'gt_'. Files are processed in sorted order so the CSV row
    order is deterministic. Extensions are matched case-insensitively.

    Args:
        folder_path: directory containing the prediction images.
        output_csv: path of the CSV to write (columns: ID, pixel_0, pixel_1, ...).

    Raises:
        ValueError: if ``folder_path`` contains no supported image files.
    """
    import os  # this cell does not import os itself

    extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff')
    data_rows = []
    for filename in sorted(os.listdir(folder_path)):
        if not filename.lower().endswith(extensions):
            continue
        image_path = os.path.join(folder_path, filename)
        # The context manager closes the file handle PIL keeps open.
        with Image.open(image_path) as image:
            image_array = np.array(image.convert('L')).flatten()[::8]
        # splitext keeps dotted stems intact (split('.')[0] would truncate them).
        image_id = os.path.splitext(filename)[0].replace('test_', 'gt_')
        data_rows.append([image_id, *image_array])

    if not data_rows:
        raise ValueError(f"No image files found in {folder_path}")

    column_names = ['ID'] + [f'pixel_{i}' for i in range(len(data_rows[0]) - 1)]
    df = pd.DataFrame(data_rows, columns=column_names)
    df.to_csv(output_csv, index=False)
    print(f'Successfully saved to {output_csv}')

# Flatten every saved prediction into the submission CSV.
folder_path = '/kaggle/working/predictions'
output_csv = 'submission.csv'
images_to_csv(folder_path=folder_path, output_csv=output_csv)

Successfully saved to submission.csv
