## Kaggle competition link
https://www.kaggle.com/t/3e046f08abcb44969613c750d7fe9243


In [1]:
import torch
import torchvision
from torchvision import transforms
from PIL import Image
import os
from torch.utils.data import Dataset, DataLoader
from os.path import join
from torchvision.transforms import v2
from torch import optim, nn
from torchmetrics.image import PeakSignalNoiseRatio
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

In [2]:
class ImageEnhancementDataset(Dataset):
    """Dataset of (input image, label image) pairs for 4x image enhancement.

    Inputs and labels live in two separate folders. Both directory listings are
    sorted and paired by position, so the i-th input is matched with the i-th
    label. Every label must be exactly four times as wide and four times as
    tall as its input.

    Args:
        image_folder (str): Path to the folder containing input images.
        label_folder (str): Path to the folder containing label images.
        input_transform (callable, optional): Transform applied to each input image.
        label_transform (callable, optional): Transform applied to each label image.
    """

    def __init__(self, image_folder, label_folder, input_transform=None, label_transform=None):
        self.image_folder = image_folder
        self.label_folder = label_folder
        self.input_transform = input_transform
        self.label_transform = label_transform

        # Pairing is purely positional: sort both listings so index i refers to the same sample.
        self.image_filenames = sorted(os.listdir(image_folder))
        self.label_filenames = sorted(os.listdir(label_folder))

        assert len(self.image_filenames) == len(self.label_filenames), "Mismatch between image and label counts."

    def __len__(self):
        """Return the number of (input, label) pairs."""
        return len(self.image_filenames)

    def __getitem__(self, idx):
        """Load the pair at position ``idx``, check the 4x size relation and apply the transforms."""
        image = Image.open(os.path.join(self.image_folder, self.image_filenames[idx]))
        label = Image.open(os.path.join(self.label_folder, self.label_filenames[idx]))

        # PIL reports sizes as (width, height); the label has to be 4x larger along both axes.
        input_size = image.size
        label_size = label.size
        expected_label_size = tuple(4 * side for side in input_size)

        assert label_size == expected_label_size, (
            f"Label size {label_size} does not match the expected size {expected_label_size} for input {input_size}."
        )

        # Transforms are optional; without them the raw PIL images are returned.
        image = self.input_transform(image) if self.input_transform else image
        label = self.label_transform(label) if self.label_transform else label

        return image, label

In [3]:
# Kaggle dataset root; each split (train / val / test) has its own sub-folder underneath it.
root_dir = "/kaggle/input/enhance-the-dark-world/archive"
train_dir, val_dir, test_dir = (join(root_dir, split) for split in ("train", "val", "test"))

In [4]:
# Input pipeline: convert to a float32 image tensor (values rescaled by ToDtype), then
# standardise each channel. The mean/std values are presumably statistics of the
# low-light training inputs -- TODO confirm.
image_transform = v2.Compose(
    [
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        transforms.Normalize(mean=[0.183, 0.198, 0.176], std=[0.083, 0.085, 0.081]),
    ]
)

# Label pipeline: same tensor conversion, but no normalisation is applied to the ground truth.
label_transform = v2.Compose(
    [
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
    ]
)

# Each split keeps its inputs in a folder named after the split and its targets in "gt".
train_dataset = ImageEnhancementDataset(
    image_folder=join(train_dir, "train"),
    label_folder=join(train_dir, "gt"),
    input_transform=image_transform,
    label_transform=label_transform,
)

val_dataset = ImageEnhancementDataset(
    image_folder=join(val_dir, "val"),
    label_folder=join(val_dir, "gt"),
    input_transform=image_transform,
    label_transform=label_transform,
)

batch_size = 4

# Only the training data is shuffled; validation keeps the sorted file order.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)


In [5]:


class Conv2d1x1(nn.Module):
    """1x1 convolution that mixes channels without touching the spatial dimensions.

    Args:
        input_channels: Number of channels of the incoming tensor.
        reduction_factor: Used only when ``out_channels`` is not given; the output
            width is then ``input_channels // reduction_factor``.
        out_channels: Explicit number of output channels; overrides ``reduction_factor``.

    Attributes:
        out_channels: The resolved number of output channels.
    """

    def __init__(self, input_channels: int, reduction_factor: int = 1, out_channels: int = None):
        super().__init__()

        # an explicit width wins, otherwise shrink the input width by the reduction factor
        self.out_channels = input_channels // reduction_factor if out_channels is None else out_channels

        self.conv = nn.Conv2d(input_channels, self.out_channels, kernel_size=1)

    def forward(self, x):
        """Apply the 1x1 convolution to ``x``."""
        return self.conv(x)


class DepthwiseConv2d(nn.Module):
    """Depthwise convolution: every channel is filtered on its own by a square kernel.

    ``groups`` equals the channel count, so no information is mixed across channels.
    The padding is ``kernel_size // 2``, which keeps the spatial size unchanged for
    odd kernel sizes.

    Args:
        input_channels: Number of input (and output) channels.
        kernel_size: Side length of the square kernel.
    """

    def __init__(self, input_channels: int, kernel_size: int):
        super().__init__()

        self.conv = nn.Conv2d(
            input_channels,
            input_channels,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            groups=input_channels,
        )

    def forward(self, input_tensor):
        """Apply the depthwise convolution to ``input_tensor``."""
        return self.conv(input_tensor)


class PointwiseConv2d(nn.Module):
    """Pointwise (1x1) convolution built on top of :class:`Conv2d1x1`.

    Args:
        input_channels: Number of channels of the incoming tensor.
        out_channels: Number of output channels; defaults to ``input_channels``.
    """

    def __init__(self, input_channels: int, out_channels: int = None):
        super().__init__()

        # keep the channel count unless the caller asks for a different width
        target_channels = input_channels if out_channels is None else out_channels

        self.conv = Conv2d1x1(input_channels=input_channels, out_channels=target_channels)

    def forward(self, input_tensor):
        """Apply the pointwise convolution to ``input_tensor``."""
        return self.conv(input_tensor)


class TwoFoldAttentionModule(nn.Module):
    """Two-fold attention: gate a feature map with a channel branch plus a positional branch.

    The input is squeezed to ``input_channels // reduction_factor`` channels by a 1x1
    convolution, processed in parallel by :class:`ChannelUnit` and :class:`PositionalUnit`,
    summed, expanded back to ``input_channels`` channels by a second 1x1 convolution and
    passed through a sigmoid. The resulting mask multiplies the original input element-wise,
    so the output has the same shape as the input.

    Args:
        input_channels: Number of channels of the incoming feature map.
        reduction_factor: Channel squeeze applied before the two units (default 16).

    Raises:
        ValueError: If ``reduction_factor`` is smaller than 1, or if the squeezed width
            ``input_channels // reduction_factor`` is smaller than 2 (the channel unit
            needs at least one channel per half).
    """

    class ChannelUnit(nn.Module):
        """Channel attention computed from per-channel spatial means.

        The pooled statistics are split into two groups along the channel axis, each group
        goes through its own 1x1 convolution, and the re-joined result rescales the input.
        An odd channel count is supported: the second group simply gets the extra channel.

        Raises:
            ValueError: If ``input_channels`` is smaller than 2.
        """

        def __init__(self, input_channels: int):
            super().__init__()

            # fail fast: with fewer than two channels one of the halves would be empty
            if input_channels < 2:
                raise ValueError(f"ChannelUnit needs at least 2 input channels, got {input_channels}")

            self.in_channels = input_channels

            # global average pooling extracts first-order statistics of the features
            self.global_avg_pooling = nn.AdaptiveAvgPool2d(output_size=1)

            # one 1x1 convolution per channel group; neither changes the width of its group.
            # For an even channel count both groups have in_channels / 2 channels.
            first_group_channels = input_channels // 2
            second_group_channels = input_channels - first_group_channels
            self.conv1x1_1 = Conv2d1x1(input_channels=first_group_channels)
            self.conv1x1_2 = Conv2d1x1(input_channels=second_group_channels)

        def forward(self, input_tensor):
            # (N, in_channels, H, W) -> (N, in_channels, 1, 1)
            first_order_statistics = self.global_avg_pooling(input_tensor)

            # split by explicit section sizes: splitting by a chunk size of in_channels // 2
            # would produce three chunks whenever in_channels is odd
            first_group_channels = self.in_channels // 2
            sections = [first_group_channels, self.in_channels - first_group_channels]
            first_half_input, second_half_input = torch.split(first_order_statistics, sections, dim=1)

            # each group is transformed by its own 1x1 convolution
            first_half_output = self.conv1x1_1(first_half_input)
            second_half_output = self.conv1x1_2(second_half_input)

            # re-join the groups: (N, in_channels, 1, 1)
            concatenated_halves = torch.cat((first_half_output, second_half_output), dim=1)

            # the per-channel weights broadcast over the spatial dimensions, so the result
            # has the shape of the input tensor: (N, in_channels, H, W)
            return torch.mul(concatenated_halves, input_tensor)

    class PositionalUnit(nn.Module):
        """Positional attention from coarse max/average pooled maps.

        Both poolings use a 7x7 window (PyTorch's default stride equals the kernel size),
        so the input must be at least 7 pixels high and wide.
        """

        def __init__(self, input_channels: int):
            super().__init__()

            # average pooling and max pooling with a large kernel
            self.avg_pooling = nn.AvgPool2d(kernel_size=(7, 7))
            self.max_pooling = nn.MaxPool2d(kernel_size=(7, 7))

            # final convolution: fuses the two pooled maps back to input_channels channels
            kernel_size = 7
            padding_size = kernel_size // 2
            self.conv2d_1 = nn.Conv2d(in_channels=input_channels * 2, out_channels=input_channels,
                                      kernel_size=(kernel_size, kernel_size),
                                      padding=(padding_size, padding_size))

        def forward(self, input_tensor):
            # remember the spatial size so it can be restored after pooling
            height = input_tensor.size()[2]
            width = input_tensor.size()[3]

            # coarse descriptors of the input
            output_max_pool = self.max_pooling(input_tensor)
            output_avg_pool = self.avg_pooling(input_tensor)

            # stack both descriptors along the channel axis (2 * input_channels channels)
            output_pool = torch.cat((output_max_pool, output_avg_pool), dim=1)

            # upsample back to the original spatial size
            upsampled_out = F.interpolate(output_pool, size=(height, width), mode="bilinear", align_corners=False)

            # fuse with the 7x7 convolution
            return self.conv2d_1(upsampled_out)

    def __init__(self, input_channels: int, reduction_factor: int = 16):
        super().__init__()

        # validate up front instead of building a module that can only crash in forward()
        if reduction_factor < 1:
            raise ValueError(f"reduction_factor must be >= 1, got {reduction_factor}")
        if input_channels // reduction_factor < 2:
            raise ValueError(
                f"input_channels // reduction_factor must be at least 2, "
                f"got {input_channels} // {reduction_factor} = {input_channels // reduction_factor}"
            )

        # first 1x1 convolution: squeeze the channels
        self.conv1x1_1 = Conv2d1x1(input_channels=input_channels, reduction_factor=reduction_factor)

        # the Channel Unit (CA Unit) and the Positional Unit (Pos Unit) work on the squeezed width
        self.ca_unit = self.ChannelUnit(input_channels=self.conv1x1_1.out_channels)
        self.pos_unit = self.PositionalUnit(input_channels=self.conv1x1_1.out_channels)

        # last 1x1 convolution: recover the original channel dimension
        self.conv1x1_2 = Conv2d1x1(input_channels=self.conv1x1_1.out_channels, out_channels=input_channels)

        # sigmoid that generates the final attention mask
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_tensor):
        # squeeze the channels
        conv_1x1_1_out = self.conv1x1_1(input_tensor)

        # both units read the same squeezed features
        ca_unit_out = self.ca_unit(conv_1x1_1_out)
        pos_unit_out = self.pos_unit(conv_1x1_1_out)

        # aggregate the two branches by element-wise sum
        sum_output = torch.add(ca_unit_out, pos_unit_out)

        # recover the channel dimension and turn the result into the attention mask
        conv_1x1_2_out = self.conv1x1_2(sum_output)
        sigmoid_out = self.sigmoid(conv_1x1_2_out)

        # gate the input with the attention mask
        return torch.mul(input_tensor, sigmoid_out)


class AdaptiveResidualBlock(nn.Module):
    """Adaptive Residual Block (ARB) made of a bottleneck, a residual and an adaptive path.

    Data flow of :meth:`forward`::

        residual = ResidualPath(input + BottleneckPath(input))
        adaptive = AdaptivePath(input)          # one value per channel
        output   = residual + adaptive          # adaptive term broadcasts over H and W

    Args:
        input_channels: Number of channels of the incoming (and outgoing) feature map.
    """

    class BottleneckPath(nn.Module):
        """depthwise 3x3 -> pointwise -> LeakyReLU -> depthwise 3x3 -> TFAM -> LeakyReLU -> pointwise."""

        def __init__(self, input_channels: int):
            super().__init__()

            # first stage: depthwise 3x3, pointwise, LeakyReLU
            self.dw_conv_1 = DepthwiseConv2d(input_channels=input_channels, kernel_size=3)
            self.pw_conv_1 = PointwiseConv2d(input_channels=input_channels)
            self.lrelu_1 = nn.LeakyReLU()

            # second stage: depthwise 3x3, attention (TFAM), LeakyReLU
            self.dw_conv_2 = DepthwiseConv2d(input_channels=input_channels, kernel_size=3)
            self.tfam = TwoFoldAttentionModule(input_channels=input_channels)
            self.lrelu_2 = nn.LeakyReLU()

            # closing pointwise convolution
            self.pw_conv_2 = PointwiseConv2d(input_channels=input_channels)

        def forward(self, input_tensor):
            # first stage
            features = self.lrelu_1(self.pw_conv_1(self.dw_conv_1(input_tensor)))

            # second stage, with attention applied before the activation
            features = self.lrelu_2(self.tfam(self.dw_conv_2(features)))

            # closing pointwise convolution
            return self.pw_conv_2(features)

    class AdaptivePath(nn.Module):
        """Global average pooling followed by a pointwise convolution."""

        def __init__(self, input_channels: int):
            super().__init__()

            self.global_avg_pooling = nn.AdaptiveAvgPool2d(output_size=1)
            self.pw_conv = PointwiseConv2d(input_channels=input_channels)

        def forward(self, input_tensor):
            # the pooling collapses the spatial dimensions to 1x1 before the pointwise convolution
            return self.pw_conv(self.global_avg_pooling(input_tensor))

    class ResidualPath(nn.Module):
        """A single depthwise 3x3 convolution."""

        def __init__(self, input_channels: int):
            super().__init__()

            self.dw_conv = DepthwiseConv2d(input_channels=input_channels, kernel_size=3)

        def forward(self, input_tensor):
            return self.dw_conv(input_tensor)

    def __init__(self, input_channels: int):
        super().__init__()

        # the three paths, created in this order: bottleneck, adaptive, residual
        self.bn_path = self.BottleneckPath(input_channels=input_channels)
        self.ad_path = self.AdaptivePath(input_channels=input_channels)
        self.res_path = self.ResidualPath(input_channels=input_channels)

    def forward(self, input_tensor):
        # skip connection around the bottleneck path, then the residual path on top of it
        res_path_out = self.res_path(input_tensor + self.bn_path(input_tensor))

        # the adaptive path reads the untouched input
        ad_path_out = self.ad_path(input_tensor)

        # element-wise sum; the 1x1 adaptive output is broadcast over the spatial dimensions
        return res_path_out + ad_path_out


class ResidualConcatenationBlock(nn.Module):
    """Residual Concatenation Block (RCB): three densely connected passes through one shared ARB.

    A single :class:`AdaptiveResidualBlock` instance is applied three times, so all three
    passes share the same weights. After every pass the ARB output is appended to a
    running channel-wise concatenation, which a 1x1 convolution fuses back to
    ``input_channels`` channels.

    Args:
        input_channels: Number of channels of the incoming (and outgoing) feature map.
    """

    def __init__(self, input_channels: int):
        super().__init__()

        # one ARB only: calling it three times in forward() is what shares the weights
        self.arb = AdaptiveResidualBlock(input_channels=input_channels)

        # the running concatenation grows by input_channels per pass (2x, 3x, 4x);
        # each fusion convolution brings it back to input_channels
        self.conv_1x1_1 = Conv2d1x1(input_channels=2 * input_channels, out_channels=input_channels)
        self.conv_1x1_2 = Conv2d1x1(input_channels=3 * input_channels, out_channels=input_channels)
        self.conv_1x1_3 = Conv2d1x1(input_channels=4 * input_channels, out_channels=input_channels)

    def forward(self, input_tensor):
        arb_input = input_tensor   # what the shared ARB sees in the current pass
        dense = input_tensor       # running concatenation of the input and all ARB outputs

        for fuse in (self.conv_1x1_1, self.conv_1x1_2, self.conv_1x1_3):
            # append the ARB output to everything gathered so far ...
            dense = torch.cat((dense, self.arb(arb_input)), dim=1)
            # ... and fuse it back to input_channels for the next pass (or the final output)
            arb_input = fuse(dense)

        return arb_input


class ResidualModule(nn.Module):
    """Residual Module: three densely connected Residual Concatenation Blocks.

    Unlike the ARB inside an RCB, the three RCBs here are separate instances with their
    own weights. After each RCB its output is appended to a running channel-wise
    concatenation, which a 1x1 convolution fuses back to ``input_channels`` channels.

    Args:
        input_channels: Number of channels of the shallow features ``h_sfe``.
    """

    def __init__(self, input_channels: int):
        super().__init__()

        # RCBs and fusion convolutions are created interleaved: rcb, conv, rcb, conv, rcb, conv.
        # The fusion convolutions take 2x, 3x and 4x input_channels respectively.
        self.rcb_1 = ResidualConcatenationBlock(input_channels=input_channels)
        self.conv_1x1_1 = Conv2d1x1(input_channels=2 * input_channels, out_channels=input_channels)

        self.rcb_2 = ResidualConcatenationBlock(input_channels=input_channels)
        self.conv_1x1_2 = Conv2d1x1(input_channels=3 * input_channels, out_channels=input_channels)

        self.rcb_3 = ResidualConcatenationBlock(input_channels=input_channels)
        self.conv_1x1_3 = Conv2d1x1(input_channels=4 * input_channels, out_channels=input_channels)

    def forward(self, h_sfe):
        stages = (
            (self.rcb_1, self.conv_1x1_1),
            (self.rcb_2, self.conv_1x1_2),
            (self.rcb_3, self.conv_1x1_3),
        )

        rcb_input = h_sfe   # what the current RCB sees
        dense = h_sfe       # running concatenation of h_sfe and all RCB outputs

        for rcb, fuse in stages:
            # append this RCB's output to everything gathered so far ...
            dense = torch.cat((dense, rcb(rcb_input)), dim=1)
            # ... and fuse back to input_channels; after the last stage this is h_rm
            rcb_input = fuse(dense)

        return rcb_input


class FeatureModule(nn.Module):
    """Feature Module: attention plus a 3x3 convolution, with a global skip connection.

    Args:
        input_channels: Number of channels of both ``h_rm`` and ``h_sfe``.
    """

    def __init__(self, input_channels: int):
        super().__init__()

        # attention over the residual-module features
        self.tfam = TwoFoldAttentionModule(input_channels=input_channels)

        # 3x3 convolution; a padding of 1 keeps the spatial size
        self.conv = nn.Conv2d(input_channels, input_channels, kernel_size=3, padding=1)

    def forward(self, h_rm, h_sfe):
        # attention followed by the convolution gives the global features
        h_gfe = self.conv(self.tfam(h_rm))

        # global skip connection: add the shallow features back in
        return h_gfe + h_sfe


class UpNetModule(nn.Module):
    """Sub-pixel upsampler with dedicated branches for scale factors 2, 3 and 4.

    Every branch is a 3x3 convolution that multiplies the channel count by the square of
    the scale factor, followed by a pixel shuffle that trades those channels for spatial
    resolution. The 4x branch chains two independent 2x stages. All three branches are
    always instantiated; ``forward`` selects one per call.

    Args:
        input_channels: Number of channels of the incoming (and outgoing) feature map.
    """

    class Upsample2x(nn.Module):
        """3x3 convolution to ``4 * input_channels`` channels, then pixel shuffle by 2."""

        def __init__(self, input_channels: int):
            super().__init__()

            kernel_size = 3
            padding_size = kernel_size // 2

            self.conv = nn.Conv2d(in_channels=input_channels, out_channels=input_channels * 4,
                                  kernel_size=(kernel_size, kernel_size), padding=padding_size)
            self.pix_shuf = nn.PixelShuffle(upscale_factor=2)

        def forward(self, input_tensor):
            # expand the channels, then rearrange them into a 2x larger feature map
            return self.pix_shuf(self.conv(input_tensor))

    class Upsample3x(nn.Module):
        """3x3 convolution to ``9 * input_channels`` channels, then pixel shuffle by 3."""

        def __init__(self, input_channels: int):
            super().__init__()

            kernel_size = 3
            padding_size = kernel_size // 2

            self.conv = nn.Conv2d(in_channels=input_channels, out_channels=input_channels * 9,
                                  kernel_size=(kernel_size, kernel_size), padding=padding_size)
            self.pix_shuf = nn.PixelShuffle(upscale_factor=3)

        def forward(self, input_tensor):
            # expand the channels, then rearrange them into a 3x larger feature map
            return self.pix_shuf(self.conv(input_tensor))

    class Upsample4x(nn.Module):
        """Two chained 2x stages (each with its own weights), giving a 4x larger feature map."""

        def __init__(self, input_channels: int):
            super().__init__()

            self.upsample_4x = nn.Sequential(UpNetModule.Upsample2x(input_channels=input_channels),
                                             UpNetModule.Upsample2x(input_channels=input_channels))

        def forward(self, input_tensor):
            return self.upsample_4x(input_tensor)

    def __init__(self, input_channels: int):
        super().__init__()

        # submodule that produces a feature map upsampled by 2x
        self.upsample_2x = self.Upsample2x(input_channels=input_channels)

        # submodule that produces a feature map upsampled by 3x
        self.upsample_3x = self.Upsample3x(input_channels=input_channels)

        # submodule that produces a feature map upsampled by 4x
        self.upsample_4x = self.Upsample4x(input_channels=input_channels)

    def forward(self, h_fm, scale: int):
        """Upsample ``h_fm`` by ``scale``.

        Args:
            h_fm: Feature map to upsample.
            scale: Upscaling factor; must be 2, 3 or 4.

        Returns:
            The feature map with unchanged channel count and ``scale`` times larger
            height and width.

        Raises:
            ValueError: If ``scale`` is not 2, 3 or 4.
        """
        if scale == 2:
            return self.upsample_2x(h_fm)
        if scale == 3:
            return self.upsample_3x(h_fm)
        if scale == 4:
            return self.upsample_4x(h_fm)

        # ValueError is still caught by callers that handled the former generic Exception
        raise ValueError(f"Scale factor {scale} is invalid, select between 2, 3 or 4")


class MultiPathResidualNetwork(nn.Module):
    """Super-resolution network: shallow features -> residual module -> feature module -> upsampler.

    Args:
        input_channels: Number of channels of the low-resolution images (also used for the output).
        n_features: Width of the internal feature maps.
    """

    def __init__(self, input_channels: int, n_features: int = 64):
        super().__init__()

        # every 3x3 convolution below uses a padding of 1, so it keeps the spatial size
        kernel_size = 3
        padding_size = kernel_size // 2

        # shallow feature extractor: image channels -> n_features
        self.sfe = nn.Conv2d(input_channels, n_features, kernel_size=kernel_size, padding=padding_size)

        # deep feature extraction
        self.rm = ResidualModule(input_channels=n_features)
        self.fm = FeatureModule(input_channels=n_features)

        # upsampler
        self.upnet = UpNetModule(input_channels=n_features)

        # final 3x3 convolution: n_features -> image channels
        self.out_conv = nn.Conv2d(n_features, input_channels, kernel_size=kernel_size, padding=padding_size)

    def forward(self, lrs, scale: int):
        """Super-resolve the batch ``lrs`` by the factor ``scale`` (2, 3 or 4).

        Shapes below assume ``lrs`` is (N, C, H, W) and write F for ``n_features``.
        """
        h_sfe = self.sfe(lrs)                  # (N, F, H, W)
        h_rm = self.rm(h_sfe)                  # (N, F, H, W)
        h_fm = self.fm(h_rm, h_sfe)            # (N, F, H, W), includes the skip from h_sfe
        upscaled_fm = self.upnet(h_fm, scale)  # (N, F, H * scale, W * scale)

        # back to image channels: (N, C, H * scale, W * scale)
        return self.out_conv(upscaled_fm)

# Smoke test: push one random batch through a freshly initialised network and report the output size
if __name__ == "__main__":
    model = MultiPathResidualNetwork(n_features=128, input_channels=3)
    x = torch.randn((4, 3, 160, 256))  # (batch_size, channels, height, width)
    y = model(x, scale=4)
    print(f"Output shape: {y.shape}")  # a 4x upscale should give [4, 3, 640, 1024]


Output shape: torch.Size([4, 3, 640, 1024])


In [6]:
from torch.nn import DataParallel


# Build the network on the GPU and wrap it so batches can be spread over the available GPUs
model = DataParallel(MultiPathResidualNetwork(input_channels=3, n_features=128).to("cuda"))

# Random stand-in for one low-resolution image: (batch, channels, height, width)
lrs = torch.randn((1, 3, 160, 256))
scale = 4

# Inference only, so no gradients are needed
with torch.no_grad():
    srs = model(lrs, scale)

print("Input shape:", lrs.shape)
print("Output shape:", srs.shape)

# Height and width must both grow by the scale factor; batch size and channels stay the same
assert srs.shape == (1, 3, scale * 160, scale * 256), "Output shape is incorrect!"

Input shape: torch.Size([1, 3, 160, 256])
Output shape: torch.Size([1, 3, 640, 1024])


In [7]:
import torch
import torch.nn as nn
from torchmetrics.functional import structural_similarity_index_measure

class SSIM_MSELoss(nn.Module):
    """Weighted sum of an MSE term and an SSIM dissimilarity term.

    The loss is ``alpha * MSE + (1 - alpha) * (1 - SSIM)``; SSIM is evaluated with
    ``data_range=1.0``.

    Args:
        alpha (float): Weight of the MSE term. The SSIM term is weighted by ``1 - alpha``.
    """

    def __init__(self, alpha=0.8):
        super().__init__()
        self.alpha = alpha
        self.mse_loss = nn.MSELoss()

    def forward(self, predicted, target):
        """Compute the combined loss.

        Args:
            predicted (torch.Tensor): Predicted output image.
            target (torch.Tensor): Ground truth image.

        Returns:
            torch.Tensor: Combined loss value.
        """
        mse_term = self.mse_loss(predicted, target)

        # SSIM measures similarity, so it is flipped into a quantity to minimise
        ssim_term = 1 - structural_similarity_index_measure(predicted, target, data_range=1.0)

        return self.alpha * mse_term + (1 - self.alpha) * ssim_term


In [8]:
from torchmetrics import PeakSignalNoiseRatio

def train_one_epoch(model, train_loader, criterion, psnr_metric, optimizer, device):
    """
    Run one optimisation pass over the training data.

    For every batch: move it to ``device``, clear the gradients, run the model with a
    fixed scale of 4, compute the loss, backpropagate and take an optimizer step. The
    batch PSNR is computed afterwards from the same predictions, for monitoring only.

    Args:
        model (nn.Module): Super-resolution model; called as ``model(lrs=..., scale=4)``
        train_loader (DataLoader): Yields (low-resolution, high-resolution) batches
        criterion (nn.Module): Loss function
        psnr_metric (torchmetrics.Metric): PSNR metric
        optimizer (torch.optim.Optimizer): Optimization algorithm
        device (torch.device): Computing device (CPU/CUDA)

    Returns:
        tuple: (average training loss, average PSNR), both averaged over the batches
    """
    model.train()
    running_loss, running_psnr = 0.0, 0.0

    batches = tqdm(train_loader, desc='Training', unit='batch')
    for low_res, high_res in batches:
        low_res, high_res = low_res.to(device), high_res.to(device)

        # standard optimisation step
        optimizer.zero_grad()
        restored = model(lrs=low_res, scale=4)
        loss = criterion(restored, high_res)
        loss.backward()
        optimizer.step()

        # monitoring only: PSNR of the predictions made before the weight update
        psnr_value = psnr_metric(restored, high_res).mean().item()
        loss_value = loss.item()

        running_loss += loss_value
        running_psnr += psnr_value
        batches.set_postfix({
            'Loss': f'{loss_value:.4f}',
            'PSNR': f'{psnr_value:.2f}'
        })

    # per-batch averages for the epoch
    num_batches = len(train_loader)
    return running_loss / num_batches, running_psnr / num_batches



def validate(model, val_loader, criterion, psnr_metric, device):
    """
    Evaluate the model on the validation data without updating it.

    The model is switched to eval mode and all batches are processed under
    ``torch.no_grad()``; the loss and the batch PSNR are accumulated per batch.

    Args:
        model (nn.Module): Super-resolution model; called as ``model(lrs=..., scale=4)``
        val_loader (DataLoader): Yields (low-resolution, high-resolution) batches
        criterion (nn.Module): Loss function
        psnr_metric (torchmetrics.Metric): PSNR metric
        device (torch.device): Computing device (CPU/CUDA)

    Returns:
        tuple: (average validation loss, average PSNR), both averaged over the batches
    """
    model.eval()
    running_loss, running_psnr = 0.0, 0.0

    batches = tqdm(val_loader, desc='Validation', unit='batch')
    with torch.no_grad():
        for low_res, high_res in batches:
            low_res, high_res = low_res.to(device), high_res.to(device)

            # forward pass only
            restored = model(lrs=low_res, scale=4)

            loss_value = criterion(restored, high_res).item()
            psnr_value = psnr_metric(restored, high_res).mean().item()

            running_loss += loss_value
            running_psnr += psnr_value

            batches.set_postfix({
                'Loss': f'{loss_value:.4f}',
                'PSNR': f'{psnr_value:.2f}'
            })

    # per-batch averages over the validation set
    num_batches = len(val_loader)
    return running_loss / num_batches, running_psnr / num_batches

def train_super_resolution_model(
    model,
    train_loader,
    val_loader,
    num_epochs=3,
    learning_rate=1e-4,
    weight_decay=1e-5,
    checkpoint_path="best_model.pth",
    final_model_path="/kaggle/working/final_model.pth",
):
    """
    Main training loop for super-resolution model.

    For every epoch the model is trained with ``train_one_epoch`` and then
    evaluated with ``validate``. The weights with the highest validation PSNR
    seen so far are written to ``checkpoint_path``; the weights after the last
    epoch are written to ``final_model_path``.

    Args:
        model (nn.Module): Model to train; moved to CUDA when available,
            otherwise to the CPU.
        train_loader (DataLoader): Training data loader.
        val_loader (DataLoader): Validation data loader.
        num_epochs (int): Number of epochs to run.
        learning_rate (float): Initial learning rate for Adam.
        weight_decay (float): Weight decay passed to Adam.
        checkpoint_path (str): Where the best-PSNR ``state_dict`` is saved.
        final_model_path (str | None): Where the last-epoch ``state_dict`` is
            saved. Pass ``None`` to skip this save. The default keeps the
            previously hard-coded Kaggle path.

    Returns:
        dict: Training history with per-epoch lists under the keys
        ``train_loss``, ``train_psnr``, ``val_loss`` and ``val_psnr``.
    """
    # Use CUDA when it is available, otherwise fall back to the CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # Loss function defined elsewhere in this file
    # NOTE(review): presumably alpha weights an SSIM term against an MSE term — confirm
    criterion = SSIM_MSELoss(alpha=0.5)

    # PSNR metric (torchmetrics), kept on the same device as the model
    psnr_metric = PeakSignalNoiseRatio().to(device)

    # Optimizer with weight decay (L2 regularization)
    optimizer = optim.Adam(
        model.parameters(),
        lr=learning_rate,
        weight_decay=weight_decay
    )

    # Halve the learning rate when the validation loss has not improved
    # for 5 consecutive epochs
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.5,
        patience=5
    )

    # Track best model
    best_val_psnr = float('-inf')
    best_epoch = 0

    # Training history
    history = {
        'train_loss': [],
        'train_psnr': [],
        'val_loss': [],
        'val_psnr': []
    }

    # Start training loop
    for epoch in range(num_epochs):
        print(f"\nEpoch [{epoch+1}/{num_epochs}]")

        # Train the model for one epoch
        train_loss, train_psnr = train_one_epoch(
            model, train_loader, criterion, psnr_metric, optimizer, device
        )
        history['train_loss'].append(train_loss)
        history['train_psnr'].append(train_psnr)

        # Validate the model
        val_loss, val_psnr = validate(model, val_loader, criterion, psnr_metric, device)
        history['val_loss'].append(val_loss)
        history['val_psnr'].append(val_psnr)

        # Print the epoch stats
        print(f"Training Loss: {train_loss:.4f}, Training PSNR: {train_psnr:.2f}")
        print(f"Validation Loss: {val_loss:.4f}, Validation PSNR: {val_psnr:.2f}")

        # The scheduler follows validation loss; model selection below follows PSNR
        scheduler.step(val_loss)

        # Hand cached GPU memory back to the allocator between epochs
        torch.cuda.empty_cache()

        # Save the best model (based on validation PSNR)
        if val_psnr > best_val_psnr:
            best_val_psnr = val_psnr
            best_epoch = epoch
            print(f"New best model found, saving model at epoch {epoch+1}")
            torch.save(model.state_dict(), checkpoint_path)

    print(f"\nTraining completed. Best validation PSNR: {best_val_psnr:.2f} at epoch {best_epoch+1}")

    # Also keep the last-epoch weights, which may differ from the best checkpoint
    if final_model_path is not None:
        torch.save(model.state_dict(), final_model_path)

    # Return the training history
    return history

In [9]:
# Train for 15 epochs; every other hyperparameter keeps its default value.
history = train_super_resolution_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=15,
)




Epoch [1/15]


Training: 100%|██████████| 277/277 [04:40<00:00,  1.01s/batch, Loss=0.0395, PSNR=30.10]
Validation: 100%|██████████| 67/67 [00:21<00:00,  3.13batch/s, Loss=0.0304, PSNR=26.67]


Training Loss: 0.0831, Training PSNR: 25.76
Validation Loss: 0.0379, Validation PSNR: 25.63
New best model found, saving model at epoch 1

Epoch [2/15]


Training: 100%|██████████| 277/277 [04:41<00:00,  1.02s/batch, Loss=0.0578, PSNR=27.95]
Validation: 100%|██████████| 67/67 [00:21<00:00,  3.12batch/s, Loss=0.0258, PSNR=27.63]


Training Loss: 0.0513, Training PSNR: 27.42
Validation Loss: 0.0334, Validation PSNR: 26.38
New best model found, saving model at epoch 2

Epoch [3/15]


Training: 100%|██████████| 277/277 [04:41<00:00,  1.02s/batch, Loss=0.0318, PSNR=31.41]
Validation: 100%|██████████| 67/67 [00:21<00:00,  3.12batch/s, Loss=0.0228, PSNR=28.61]


Training Loss: 0.0482, Training PSNR: 28.37
Validation Loss: 0.0305, Validation PSNR: 27.25
New best model found, saving model at epoch 3

Epoch [4/15]


Training: 100%|██████████| 277/277 [04:41<00:00,  1.02s/batch, Loss=0.0323, PSNR=29.43]
Validation: 100%|██████████| 67/67 [00:21<00:00,  3.12batch/s, Loss=0.0210, PSNR=28.98]


Training Loss: 0.0425, Training PSNR: 29.43
Validation Loss: 0.0285, Validation PSNR: 27.59
New best model found, saving model at epoch 4

Epoch [5/15]


Training: 100%|██████████| 277/277 [04:41<00:00,  1.02s/batch, Loss=0.0404, PSNR=29.56]
Validation: 100%|██████████| 67/67 [00:21<00:00,  3.11batch/s, Loss=0.0214, PSNR=28.11]


Training Loss: 0.0414, Training PSNR: 29.61
Validation Loss: 0.0290, Validation PSNR: 26.88

Epoch [6/15]


Training: 100%|██████████| 277/277 [04:41<00:00,  1.02s/batch, Loss=0.0373, PSNR=28.92]
Validation: 100%|██████████| 67/67 [00:21<00:00,  3.12batch/s, Loss=0.0216, PSNR=29.18]


Training Loss: 0.0413, Training PSNR: 29.51
Validation Loss: 0.0290, Validation PSNR: 27.79
New best model found, saving model at epoch 6

Epoch [7/15]


Training: 100%|██████████| 277/277 [04:41<00:00,  1.02s/batch, Loss=0.0346, PSNR=31.18]
Validation: 100%|██████████| 67/67 [00:21<00:00,  3.12batch/s, Loss=0.0200, PSNR=29.34]


Training Loss: 0.0407, Training PSNR: 29.78
Validation Loss: 0.0273, Validation PSNR: 27.93
New best model found, saving model at epoch 7

Epoch [8/15]


Training: 100%|██████████| 277/277 [04:41<00:00,  1.02s/batch, Loss=0.0414, PSNR=29.78]
Validation: 100%|██████████| 67/67 [00:21<00:00,  3.12batch/s, Loss=0.0198, PSNR=29.25]


Training Loss: 0.0403, Training PSNR: 29.83
Validation Loss: 0.0272, Validation PSNR: 27.84

Epoch [9/15]


Training: 100%|██████████| 277/277 [04:41<00:00,  1.02s/batch, Loss=0.0305, PSNR=31.93]
Validation: 100%|██████████| 67/67 [00:21<00:00,  3.11batch/s, Loss=0.0195, PSNR=29.49]


Training Loss: 0.0399, Training PSNR: 29.87
Validation Loss: 0.0268, Validation PSNR: 28.07
New best model found, saving model at epoch 9

Epoch [10/15]


Training: 100%|██████████| 277/277 [04:41<00:00,  1.02s/batch, Loss=0.0531, PSNR=28.22]
Validation: 100%|██████████| 67/67 [00:21<00:00,  3.11batch/s, Loss=0.0196, PSNR=29.38]


Training Loss: 0.0397, Training PSNR: 29.94
Validation Loss: 0.0268, Validation PSNR: 28.00

Epoch [11/15]


Training: 100%|██████████| 277/277 [04:41<00:00,  1.02s/batch, Loss=0.0087, PSNR=36.17]
Validation: 100%|██████████| 67/67 [00:21<00:00,  3.11batch/s, Loss=0.0195, PSNR=29.41]


Training Loss: 0.0396, Training PSNR: 29.95
Validation Loss: 0.0267, Validation PSNR: 28.02

Epoch [12/15]


Training: 100%|██████████| 277/277 [04:41<00:00,  1.02s/batch, Loss=0.0337, PSNR=30.29]
Validation: 100%|██████████| 67/67 [00:21<00:00,  3.12batch/s, Loss=0.0196, PSNR=29.40]


Training Loss: 0.0394, Training PSNR: 29.98
Validation Loss: 0.0268, Validation PSNR: 28.02

Epoch [13/15]


Training: 100%|██████████| 277/277 [04:41<00:00,  1.02s/batch, Loss=0.0417, PSNR=30.20]
Validation: 100%|██████████| 67/67 [00:21<00:00,  3.12batch/s, Loss=0.0197, PSNR=29.09]


Training Loss: 0.0394, Training PSNR: 30.02
Validation Loss: 0.0269, Validation PSNR: 27.77

Epoch [14/15]


Training: 100%|██████████| 277/277 [04:41<00:00,  1.02s/batch, Loss=0.0503, PSNR=29.30]
Validation: 100%|██████████| 67/67 [00:21<00:00,  3.10batch/s, Loss=0.0201, PSNR=29.08]


Training Loss: 0.0394, Training PSNR: 29.97
Validation Loss: 0.0274, Validation PSNR: 27.68

Epoch [15/15]


Training: 100%|██████████| 277/277 [04:41<00:00,  1.02s/batch, Loss=0.0416, PSNR=29.92]
Validation: 100%|██████████| 67/67 [00:21<00:00,  3.12batch/s, Loss=0.0192, PSNR=29.58]


Training Loss: 0.0393, Training PSNR: 30.06
Validation Loss: 0.0263, Validation PSNR: 28.19
New best model found, saving model at epoch 15

Training completed. Best validation PSNR: 28.19 at epoch 15


In [None]:
# Rebuild the architecture and restore the best checkpoint written during training.
# Prefer CUDA, but fall back to the CPU so the cell also runs without a GPU.
load_device = "cuda" if torch.cuda.is_available() else "cpu"
model = MultiPathResidualNetwork(input_channels=3, n_features=128).to(load_device)

# NOTE(review): the checkpoint keys must match the wrapped module (DataParallel
# prefixes them with ``module.``) — confirm the model was wrapped the same way
# when the checkpoint was saved.
model = DataParallel(model)

# Load the state dictionary; map_location remaps tensors that were saved from
# a device which is not present on this machine.
checkpoint_path = "/kaggle/working/best_model.pth"
model.load_state_dict(torch.load(checkpoint_path, map_location=load_device))

In [None]:
import torch
from torchvision.transforms import Compose
from PIL import Image
import matplotlib.pyplot as plt
import torchvision.transforms.v2 as v2

def infer_super_res(image_path, model, transform=None, device="cpu", plot=True):
    """
    Perform super-resolution inference on a single image and optionally plot the result.

    Args:
        image_path (str): Path to the input image.
        model (torch.nn.Module): Trained super-resolution model; invoked as
            ``model(batch, scale=4)``.
        transform (callable, optional): Transformation applied to the RGB PIL
            image; it must return a CHW tensor. When omitted, the image is
            converted to a float32 tensor scaled to [0, 1].
        device (str): Device to run the inference on ("cpu" or "cuda").
        plot (bool): When true, show the input file and the prediction side
            by side with matplotlib.

    Returns:
        PIL.Image.Image: The predicted image as 8-bit data.
    """
    # Load and preprocess the input image
    input_image = Image.open(image_path).convert("RGB")
    if transform is None:
        # A PIL image has no ``unsqueeze``; fall back to a plain tensor conversion
        transform = v2.Compose([
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
        ])
    input_tensor = transform(input_image)

    # Add batch dimension and move to the device
    input_tensor = input_tensor.unsqueeze(0).to(device)
    model = model.to(device)
    model.eval()

    # Perform inference
    with torch.no_grad():
        predicted_image = model(input_tensor, scale=4).squeeze(0)  # Remove batch dimension

    # Clamp before the uint8 cast: values outside [0, 1] would otherwise wrap
    # around and show up as speckle artefacts in the saved image.
    # NOTE(review): assumes the model output is on a [0, 1] scale, as the
    # multiplication by 255 below already implied — confirm against the label transform.
    predicted_image = predicted_image.clamp(0.0, 1.0)

    # Convert the predicted tensor back to a NumPy array for plotting
    predicted_image_np = predicted_image.permute(1, 2, 0).cpu().numpy()  # HWC format
    predicted_image_np = (predicted_image_np * 255).astype("uint8")  # Scale to [0, 255]

    if plot:
        # Plot the input file next to the prediction
        plt.figure(figsize=(10, 5))
        plt.subplot(1, 2, 1)
        plt.imshow(Image.open(image_path))  # Plot the original image
        plt.title("Low-Light Image")
        plt.axis("off")

        plt.subplot(1, 2, 2)
        plt.imshow(predicted_image_np)  # Plot the predicted image
        plt.title("Predicted High-Resolution Image")
        plt.axis("off")

        plt.show()

    # Convert NumPy array to PIL image
    predicted_image_pil = Image.fromarray(predicted_image_np)

    return predicted_image_pil

# Example usage
if __name__ == "__main__":
    # Path to the input image
    image_path = "/kaggle/input/enhance-the-dark-world/archive/train/train/gt_00001.png"

    # Transform for the input image using torchvision v2
    image_transform = v2.Compose([
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        transforms.Normalize(mean=[0.183, 0.198, 0.176], std=[0.083, 0.085, 0.081]),  # Per-channel normalization
    ])

    # Perform inference and plot
    device = "cuda" if torch.cuda.is_available() else "cpu"
    infer_super_res(image_path, model, transform=image_transform, device=device)

    # Preview a handful of training images. The limit has to be checked in the
    # outer loop as well: a single ``break`` only leaves the inner loop, and
    # every file of the following directories would still be processed.
    max_previews = 5
    previews_shown = 0
    for dirname, _, filenames in os.walk('/kaggle/input/enhance-the-dark-world/archive/train/train'):
        for filename in filenames:
            image_path = os.path.join(dirname, filename)
            # Predict using the provided function
            infer_super_res(image_path, model, transform=image_transform, device=device, plot=True)
            previews_shown += 1
            if previews_shown >= max_previews:
                break
        if previews_shown >= max_previews:
            break
            

In [16]:
import os
import cv2
import numpy as np
from PIL import Image

out_folder = "/kaggle/working/output_images"

def create_directory(path):
    """
    Create ``path``, including missing parent directories.

    Calling it for a directory that already exists is a no-op. Using
    ``exist_ok=True`` instead of a separate existence check avoids the race
    where the directory appears between the check and the creation.

    Args:
        path (str): Directory to create.

    Raises:
        FileExistsError: If ``path`` exists but is not a directory.
    """
    os.makedirs(path, exist_ok=True)

create_directory(out_folder)

# Preprocessing for the test images: tensor conversion, float scaling,
# then per-channel normalization with fixed mean/std values.
image_transform = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    transforms.Normalize(mean=[0.183, 0.198, 0.176], std=[0.083, 0.085, 0.081]),
])

# Run the model on every file below the test directory and store each
# prediction as ``<original stem>.png`` in ``out_folder``.
test_root = '/kaggle/input/enhance-the-dark-world/archive/test'
for folder, _, names in os.walk(test_root):
    for name in names:
        source_path = os.path.join(folder, name)

        # Inference only, no plotting
        prediction = infer_super_res(
            source_path,
            model,
            transform=image_transform,
            device=device,
            plot=False,
        )

        # Keep the stem, force a PNG extension
        stem, _ext = os.path.splitext(name)
        target_path = os.path.join(out_folder, stem + ".png")
        prediction.save(target_path)
        

In [17]:
import pandas as pd
def images_to_csv(folder_path, output_csv):
    """
    Flatten every image in a folder into one CSV row of sampled grayscale pixels.

    Each image is converted to 8-bit grayscale, flattened, and every 8th value
    is kept. The row ID is the file name without its extension, with
    ``test_`` replaced by ``gt_``.

    Args:
        folder_path (str): Folder containing the images to export.
        output_csv (str): Path of the CSV file to write.

    Raises:
        ValueError: If ``folder_path`` contains no image files.
    """
    image_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff')
    data_rows = []

    # os.listdir returns entries in arbitrary order; sort for a reproducible CSV
    for filename in sorted(os.listdir(folder_path)):
        # Compare case-insensitively so e.g. ".PNG" files are not skipped
        if not filename.lower().endswith(image_extensions):
            continue

        image_path = os.path.join(folder_path, filename)
        # The context manager closes the file handle once the pixels are read
        with Image.open(image_path) as image:
            # NOTE(review): the stride of 8 is presumably dictated by the
            # competition's submission format — confirm before changing it
            image_array = np.array(image.convert('L')).flatten()[::8]

        # Replace 'test_' with 'gt_' in the ID; splitext matches how the
        # output files were named and keeps dots inside the stem
        image_id = os.path.splitext(filename)[0].replace('test_', 'gt_')
        data_rows.append([image_id, *image_array])

    if not data_rows:
        raise ValueError(f'No image files found in {folder_path!r}')

    # Column count comes from the first row; all images are expected to
    # yield the same number of sampled pixels
    column_names = ['ID'] + [f'pixel_{i}' for i in range(len(data_rows[0]) - 1)]
    df = pd.DataFrame(data_rows, columns=column_names)
    df.to_csv(output_csv, index=False)
    print(f'Successfully saved to {output_csv}')

images_to_csv("/kaggle/working/output_images", "/kaggle/working/21F1001709.csv")

Successfully saved to /kaggle/working/21F1001709.csv
