In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import torchvision.models as models
!pip install pytorch-msssim
from pytorch_msssim import ssim as pytorch_ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
import numpy as np
import matplotlib.pyplot as plt
from torchvision.utils import save_image 

Collecting pytorch-msssim
  Downloading pytorch_msssim-1.0.0-py3-none-any.whl.metadata (8.0 kB)
Downloading pytorch_msssim-1.0.0-py3-none-any.whl (7.7 kB)
Installing collected packages: pytorch-msssim
Successfully installed pytorch-msssim-1.0.0


In [2]:
class Inc(nn.Module):
    def __init__(self,in_channels,filters):
        super(Inc, self).__init__()
        self.branch1 = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=filters, kernel_size=(1, 1), stride=(1, 1),dilation=1,padding=(1-1) // 2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=filters, out_channels=filters, kernel_size=(3, 3), stride=(1, 1),dilation=1,padding=(3-1) // 2),
            nn.LeakyReLU(),
            )
        self.branch2 = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=filters, kernel_size=(1, 1), stride=(1, 1),dilation=1,padding=(1-1) // 2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=filters, out_channels=filters, kernel_size=(5, 5), stride=(1, 1),dilation=1,padding=(5-1) // 2),
            nn.LeakyReLU(),
            )
        self.branch3 = nn.Sequential(
            nn.MaxPool2d(kernel_size=(3, 3), stride=(1, 1), padding=1),
            nn.Conv2d(in_channels=in_channels, out_channels=filters, kernel_size=(1, 1), stride=(1, 1),dilation=1),
            nn.LeakyReLU(),

        )
        self.branch4 = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=filters, kernel_size=(1, 1), stride=(1, 1),dilation=1),
            nn.LeakyReLU(),
        )
    def forward(self,input):
        o1 = self.branch1(input)
        o2 = self.branch2(input)
        o3 = self.branch3(input)
        o4 = self.branch4(input)
        return torch.cat([o1,o2,o3,o4],dim=1)

def swish(x):
    return x * x.sigmoid()

def hard_sigmoid(x, inplace=False):
    return nn.ReLU6(inplace=inplace)(x + 3) / 6

def hard_swish(x, inplace=False):
    return x * hard_sigmoid(x, inplace)

class HardSigmoid(nn.Module):
    def __init__(self, inplace=False):
        super(HardSigmoid, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        return hard_sigmoid(x, inplace=self.inplace)

class HardSwish(nn.Module):
    def __init__(self, inplace=False):
        super(HardSwish, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        return hard_swish(x, inplace=self.inplace)

def _make_divisible(v, divisor=8, min_value=None):  ## 将通道数变成8的整数倍
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by 8
    It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    :param v:
    :param divisor:
    :param min_value:
    :return:
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


class SELayer(nn.Module):
    def __init__(self, inp, oup, reduction=4):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
                nn.Conv2d(oup, _make_divisible(inp // reduction), 1, 1, 0,),
                nn.ReLU(),
                nn.Conv2d(_make_divisible(inp // reduction), oup, 1, 1, 0),
                HardSigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y

class DSConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DSConvBlock, self).__init__()
        self.DW = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=3, stride=1, groups=in_channels, padding=1, bias=False)
        self.BN1 = nn.BatchNorm2d(in_channels)
        self.HS = HardSwish()
        self.PW = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0, bias=False)
        self.BN2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        a = self.HS(self.BN1(self.DW(x)))
        a = self.HS(self.BN2(self.PW(a)))
        return a

class ConvBlock1(nn.Module):
    def __init__(self):
        super(ConvBlock1, self).__init__()
        self.DW = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, groups=16, padding=1, bias=False)
        self.BN = nn.BatchNorm2d(16)
        self.HS = HardSwish()
        self.PW = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=1, stride=1, padding=0, bias=False)
        self.BNN = nn.BatchNorm2d(32)

    def forward(self, x):
        a = self.HS(self.BN(self.DW(x)))
        a = self.HS(self.BNN(self.PW(a)))
        return a

class ConvBlock2(nn.Module):
    def __init__(self):
        super(ConvBlock2, self).__init__()
        self.DW = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, groups=32, padding=1, bias=False)
        self.BN = nn.BatchNorm2d(32)
        self.HS = HardSwish()
        self.PW = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=1, stride=1, padding=0, bias=False)
        self.BNN = nn.BatchNorm2d(64)

    def forward(self, x):
        a = self.HS(self.BN(self.DW(x)))
        a = self.HS(self.BNN(self.PW(a)))
        return a

class ConvBlock3(nn.Module):
    def __init__(self):
        super(ConvBlock3, self).__init__()
        self.DW = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, groups=64, padding=1, bias=False)
        self.BN = nn.BatchNorm2d(64)
        self.HS = HardSwish()
        self.PW = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=1, stride=1, padding=0, bias=False)
        self.BNN = nn.BatchNorm2d(32)

    def forward(self, x):
        a = self.HS(self.BN(self.DW(x)))
        a = self.HS(self.BNN(self.PW(a)))
        return a

class ConvBlock4(nn.Module):
    def __init__(self):
        super(ConvBlock4, self).__init__()
        self.DW = nn.Conv2d(in_channels=80, out_channels=80, kernel_size=3, stride=1, groups=80, padding=1, bias=False)
        self.BN = nn.BatchNorm2d(80)
        self.HS = HardSwish()
        self.PW = nn.Conv2d(in_channels=80, out_channels=32, kernel_size=1, stride=1, padding=0, bias=False)
        self.BNN = nn.BatchNorm2d(32)
        self.SE = SELayer(80, 80)

    def forward(self, x):

        a = self.HS(self.BN(self.DW(x)))
        a = self.SE(a)
        a = self.HS(self.BNN(self.PW(a)))
        return a

class Mynet(nn.Module):
    def __init__(self, num_layers=3):
        super(Mynet, self).__init__()
        self.input = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=1, stride=1, padding=0, bias=False)  ## 第一层卷积
        self.output = nn.Conv2d(in_channels=32, out_channels=3, kernel_size=1, stride=1, padding=0, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.block1 = ConvBlock1()
        self.block2 = ConvBlock2()
        self.block3 = ConvBlock3()
        self.block4 = ConvBlock4()

    def forward(self, x):
        x = self.input(x)
        x1 = self.block1(x)
        x2 = self.block2(x1)
        # x2 = torch.cat((x, x2), 1)
        x3 = self.block3(x2)
        x3 = torch.cat((x, x1, x3), 1)
        x4 = self.block4(x3)
        out = self.output(x4)
        return out

In [3]:
# Define the complete model with additional DS Conv blocks and ConvBlock4
class CustomModel(nn.Module):
    def __init__(self):
        super(CustomModel, self).__init__()
        self.inception_block_r = Inc(in_channels=1, filters=64)
        self.inception_block_g = Inc(in_channels=1, filters=64)
        self.inception_block_b = Inc(in_channels=1, filters=64)
        self.se_layer_r = SELayer(inp=256, oup=256)
        self.se_layer_g = SELayer(inp=256, oup=256)
        self.se_layer_b = SELayer(inp=256, oup=256)
        self.ds_conv1 = DSConvBlock(in_channels=768, out_channels=256)
        self.ds_conv2 = DSConvBlock(in_channels=256, out_channels=128)
        self.ds_conv3 = DSConvBlock(in_channels=128, out_channels=16)
        self.ds_conv4 = DSConvBlock(in_channels=16, out_channels=32)
        self.ds_conv5 = DSConvBlock(in_channels=32, out_channels=64)
        self.ds_conv6 = DSConvBlock(in_channels=64, out_channels=32)
        self.conv_block4 = ConvBlock4()
        self.final_conv = nn.Conv2d(in_channels=32, out_channels=3, kernel_size=1, stride=1, padding=0)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Split the input into R, G, B channels
        r, g, b = x[:, 0:1, :, :], x[:, 1:2, :, :], x[:, 2:3, :, :]

        # Process each channel independently through Inception and SE layers
        r = self.se_layer_r(self.inception_block_r(r))
        g = self.se_layer_g(self.inception_block_g(g))
        b = self.se_layer_b(self.inception_block_b(b))

        # Concatenate the outputs along the channel dimension (dim=1)
        x = torch.cat([r, g, b], dim=1)  # Shape: (1, 768, 256, 256)

        # Pass through the initial depthwise separable convolution blocks
        x = self.ds_conv1(x)  # Output shape: (1, 256, 256, 256)
        x = self.ds_conv2(x)  # Output shape: (1, 128, 256, 256)
        x = self.ds_conv3(x)  # Output shape: (1, 16, 256, 256)

        # Apply additional DS Conv blocks
        x1 = self.ds_conv4(x)  # Output shape: (1, 32, 256, 256)
        x2 = self.ds_conv5(x1)  # Output shape: (1, 64, 256, 256)
        x3 = self.ds_conv6(x2)  # Output shape: (1, 32, 256, 256)

        # Concatenate all outputs along the channel dimension (dim=1)
        x = torch.cat([x, x1, x2, x3], dim=1)  # Shape: (1, 16 + 32 + 64 + 32, 256, 256) = (1, 144, 256, 256)
        x = x[:,0:80,:,:]  # Output shape: (1, 80, 256, 256)
        # Adjust channels before passing through ConvBlock4
        x = self.conv_block4(x)  # Output shape: (1, 32, 256, 256)
        # Apply final 1x1 Conv with sigmoid
        x = self.final_conv(x)  # Output shape: (1, 3, 256, 256)
        x = self.sigmoid(x)  # Output shape: (1, 3, 256, 256)
        return x


# Loss Functions
class SSIMLoss(nn.Module):
    def __init__(self):
        super(SSIMLoss, self).__init__()

    def forward(self, img1, img2):
        ssim_loss = 1 - pytorch_ssim(img1, img2, data_range=1, size_average=True)
        return ssim_loss

class VGGLoss(nn.Module):
    def __init__(self):
        super(VGGLoss, self).__init__()
        vgg = models.vgg19(pretrained=True).features
        self.vgg = nn.Sequential(*list(vgg)[:36]).eval()
        for param in self.vgg.parameters():
            param.requires_grad = False
        self.criterion = nn.MSELoss()

    def forward(self, x, y):
        x_vgg = self.vgg(x)
        y_vgg = self.vgg(y)
        loss = self.criterion(x_vgg, y_vgg)
        return loss

class CombinedLoss(nn.Module):
    def __init__(self):
        super(CombinedLoss, self).__init__()
        self.ssim_loss = SSIMLoss()
        self.vgg_loss = VGGLoss()
        self.mse_loss = nn.MSELoss()

    def forward(self, output, target):
        loss = self.mse_loss(output, target) + self.ssim_loss(output, target) + self.vgg_loss(output, target)
        return loss

# Custom Dataset class for paired images
class UIEDataset(Dataset):
    def __init__(self, raw_dir, reference_dir, transform=None):
        self.raw_dir = raw_dir
        self.reference_dir = reference_dir
        self.transform = transform
        self.image_names = [img for img in os.listdir(raw_dir) if img in os.listdir(reference_dir)]

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        raw_image_path = os.path.join(self.raw_dir, self.image_names[idx])
        reference_image_path = os.path.join(self.reference_dir, self.image_names[idx])

        raw_image = Image.open(raw_image_path).convert("RGB")
        reference_image = Image.open(reference_image_path).convert("RGB")

        if self.transform:
            raw_image = self.transform(raw_image)
            reference_image = self.transform(reference_image)

        return raw_image, reference_image  # Return both raw and reference images as a pair

# Calculate PSNR
def calculate_psnr(img1, img2):
    img1_np = img1.detach().cpu().numpy().transpose(0, 2, 3, 1)
    img2_np = img2.detach().cpu().numpy().transpose(0, 2, 3, 1)
    return np.mean([psnr(img1_np[i], img2_np[i], data_range=1) for i in range(img1_np.shape[0])])


In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, Subset
from PIL import Image
import torchvision.models as models
from pytorch_msssim import ssim as pytorch_ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
import numpy as np
import matplotlib.pyplot as plt
from torchvision.utils import save_image

class Inc(nn.Module):
    def __init__(self,in_channels,filters):
        super(Inc, self).__init__()
        self.branch1 = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=filters, kernel_size=(1, 1), stride=(1, 1),dilation=1,padding=(1-1) // 2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=filters, out_channels=filters, kernel_size=(3, 3), stride=(1, 1),dilation=1,padding=(3-1) // 2),
            nn.LeakyReLU(),
            )
        self.branch2 = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=filters, kernel_size=(1, 1), stride=(1, 1),dilation=1,padding=(1-1) // 2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=filters, out_channels=filters, kernel_size=(5, 5), stride=(1, 1),dilation=1,padding=(5-1) // 2),
            nn.LeakyReLU(),
            )
        self.branch3 = nn.Sequential(
            nn.MaxPool2d(kernel_size=(3, 3), stride=(1, 1), padding=1),
            nn.Conv2d(in_channels=in_channels, out_channels=filters, kernel_size=(1, 1), stride=(1, 1),dilation=1),
            nn.LeakyReLU(),

        )
        self.branch4 = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=filters, kernel_size=(1, 1), stride=(1, 1),dilation=1),
            nn.LeakyReLU(),
        )
    def forward(self,input):
        o1 = self.branch1(input)
        o2 = self.branch2(input)
        o3 = self.branch3(input)
        o4 = self.branch4(input)
        return torch.cat([o1,o2,o3,o4],dim=1)

def swish(x):
    return x * x.sigmoid()

def hard_sigmoid(x, inplace=False):
    return nn.ReLU6(inplace=inplace)(x + 3) / 6

def hard_swish(x, inplace=False):
    return x * hard_sigmoid(x, inplace)

class HardSigmoid(nn.Module):
    def __init__(self, inplace=False):
        super(HardSigmoid, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        return hard_sigmoid(x, inplace=self.inplace)

class HardSwish(nn.Module):
    def __init__(self, inplace=False):
        super(HardSwish, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        return hard_swish(x, inplace=self.inplace)

def _make_divisible(v, divisor=8, min_value=None):  ## å°†é€šé“æ•°å˜æˆ8çš„æ•´æ•°å€
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by 8
    It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    :param v:
    :param divisor:
    :param min_value:
    :return:
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


class SELayer(nn.Module):
    def __init__(self, inp, oup, reduction=4):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
                nn.Conv2d(oup, _make_divisible(inp // reduction), 1, 1, 0,),
                nn.ReLU(),
                nn.Conv2d(_make_divisible(inp // reduction), oup, 1, 1, 0),
                HardSigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y

class DSConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DSConvBlock, self).__init__()
        self.DW = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=3, stride=1, groups=in_channels, padding=1, bias=False)
        self.BN1 = nn.BatchNorm2d(in_channels)
        self.HS = HardSwish()
        self.PW = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0, bias=False)
        self.BN2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        a = self.HS(self.BN1(self.DW(x)))
        a = self.HS(self.BN2(self.PW(a)))
        return a

class ConvBlock1(nn.Module):
    def __init__(self):
        super(ConvBlock1, self).__init__()
        self.DW = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, groups=16, padding=1, bias=False)
        self.BN = nn.BatchNorm2d(16)
        self.HS = HardSwish()
        self.PW = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=1, stride=1, padding=0, bias=False)
        self.BNN = nn.BatchNorm2d(32)

    def forward(self, x):
        a = self.HS(self.BN(self.DW(x)))
        a = self.HS(self.BNN(self.PW(a)))
        return a

class ConvBlock2(nn.Module):
    def __init__(self):
        super(ConvBlock2, self).__init__()
        self.DW = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, groups=32, padding=1, bias=False)
        self.BN = nn.BatchNorm2d(32)
        self.HS = HardSwish()
        self.PW = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=1, stride=1, padding=0, bias=False)
        self.BNN = nn.BatchNorm2d(64)

    def forward(self, x):
        a = self.HS(self.BN(self.DW(x)))
        a = self.HS(self.BNN(self.PW(a)))
        return a

class ConvBlock3(nn.Module):
    def __init__(self):
        super(ConvBlock3, self).__init__()
        self.DW = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, groups=64, padding=1, bias=False)
        self.BN = nn.BatchNorm2d(64)
        self.HS = HardSwish()
        self.PW = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=1, stride=1, padding=0, bias=False)
        self.BNN = nn.BatchNorm2d(32)

    def forward(self, x):
        a = self.HS(self.BN(self.DW(x)))
        a = self.HS(self.BNN(self.PW(a)))
        return a

class ConvBlock4(nn.Module):
    def __init__(self):
        super(ConvBlock4, self).__init__()
        self.DW = nn.Conv2d(in_channels=80, out_channels=80, kernel_size=3, stride=1, groups=80, padding=1, bias=False)
        self.BN = nn.BatchNorm2d(80)
        self.HS = HardSwish()
        self.PW = nn.Conv2d(in_channels=80, out_channels=32, kernel_size=1, stride=1, padding=0, bias=False)
        self.BNN = nn.BatchNorm2d(32)
        self.SE = SELayer(80, 80)

    def forward(self, x):

        a = self.HS(self.BN(self.DW(x)))
        a = self.SE(a)
        a = self.HS(self.BNN(self.PW(a)))
        return a

class Mynet(nn.Module):
    def __init__(self, num_layers=3):
        super(Mynet, self).__init__()
        self.input = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=1, stride=1, padding=0, bias=False)  ## ç¬¬ä¸€å±‚å·ç§¯
        self.output = nn.Conv2d(in_channels=32, out_channels=3, kernel_size=1, stride=1, padding=0, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.block1 = ConvBlock1()
        self.block2 = ConvBlock2()
        self.block3 = ConvBlock3()
        self.block4 = ConvBlock4()

    def forward(self, x):
        x = self.input(x)
        x1 = self.block1(x)
        x2 = self.block2(x1)
        # x2 = torch.cat((x, x2), 1)
        x3 = self.block3(x2)
        x3 = torch.cat((x, x1, x3), 1)
        x4 = self.block4(x3)
        out = self.output(x4)
        return out

from torchvision.models import VGG19_Weights
# Define the complete model with additional DS Conv blocks and ConvBlock4
class CustomModel(nn.Module):
    def __init__(self):
        super(CustomModel, self).__init__()
        self.inception_block_r = Inc(in_channels=1, filters=64)
        self.inception_block_g = Inc(in_channels=1, filters=64)
        self.inception_block_b = Inc(in_channels=1, filters=64)
        self.se_layer_r = SELayer(inp=256, oup=256)
        self.se_layer_g = SELayer(inp=256, oup=256)
        self.se_layer_b = SELayer(inp=256, oup=256)
        self.ds_conv1 = DSConvBlock(in_channels=768, out_channels=256)
        self.ds_conv2 = DSConvBlock(in_channels=256, out_channels=128)
        self.ds_conv3 = DSConvBlock(in_channels=128, out_channels=16)
        self.ds_conv4 = DSConvBlock(in_channels=16, out_channels=32)
        self.ds_conv5 = DSConvBlock(in_channels=32, out_channels=64)
        self.ds_conv6 = DSConvBlock(in_channels=64, out_channels=32)
        self.conv_block4 = ConvBlock4()
        self.final_conv = nn.Conv2d(in_channels=32, out_channels=3, kernel_size=1, stride=1, padding=0)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Split the input into R, G, B channels
        r, g, b = x[:, 0:1, :, :], x[:, 1:2, :, :], x[:, 2:3, :, :]

        # Process each channel independently through Inception and SE layers
        r = self.se_layer_r(self.inception_block_r(r))
        g = self.se_layer_g(self.inception_block_g(g))
        b = self.se_layer_b(self.inception_block_b(b))

        # Concatenate the outputs along the channel dimension (dim=1)
        x = torch.cat([r, g, b], dim=1)  # Shape: (1, 768, 256, 256)

        # Pass through the initial depthwise separable convolution blocks
        x = self.ds_conv1(x)  # Output shape: (1, 256, 256, 256)
        x = self.ds_conv2(x)  # Output shape: (1, 128, 256, 256)
        x = self.ds_conv3(x)  # Output shape: (1, 16, 256, 256)

        # Apply additional DS Conv blocks
        x1 = self.ds_conv4(x)  # Output shape: (1, 32, 256, 256)
        x2 = self.ds_conv5(x1)  # Output shape: (1, 64, 256, 256)
        x3 = self.ds_conv6(x2)  # Output shape: (1, 32, 256, 256)

        # Concatenate all outputs along the channel dimension (dim=1)
        x = torch.cat([x, x1, x2, x3], dim=1)  # Shape: (1, 16 + 32 + 64 + 32, 256, 256) = (1, 144, 256, 256)
        x = x[:,0:80,:,:]  # Output shape: (1, 80, 256, 256)
        # Adjust channels before passing through ConvBlock4
        x = self.conv_block4(x)  # Output shape: (1, 32, 256, 256)
        # Apply final 1x1 Conv with sigmoid
        x = self.final_conv(x)  # Output shape: (1, 3, 256, 256)
        x = self.sigmoid(x)  # Output shape: (1, 3, 256, 256)
        return x


# Loss Functions
class SSIMLoss(nn.Module):
    def __init__(self):
        super(SSIMLoss, self).__init__()

    def forward(self, img1, img2):
        ssim_loss = 1 - pytorch_ssim(img1, img2, data_range=1, size_average=True)
        return ssim_loss

class VGGLoss(nn.Module):
    def __init__(self):
        super(VGGLoss, self).__init__()
        vgg = models.vgg19(weights=VGG19_Weights.DEFAULT).features
        self.vgg = nn.Sequential(*list(vgg)[:36]).eval()
        for param in self.vgg.parameters():
            param.requires_grad = False
        self.criterion = nn.MSELoss()

    def forward(self, x, y):
        x_vgg = self.vgg(x)
        y_vgg = self.vgg(y)
        loss = self.criterion(x_vgg, y_vgg)
        return loss

class CombinedLoss(nn.Module):
    def __init__(self):
        super(CombinedLoss, self).__init__()
        self.ssim_loss = SSIMLoss()
        self.vgg_loss = VGGLoss()
        self.mse_loss = nn.MSELoss()

    def forward(self, output, target):
        loss = self.mse_loss(output, target) + self.ssim_loss(output, target) + self.vgg_loss(output, target)
        return loss


import os
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.utils import save_image, make_grid
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from skimage.metrics import peak_signal_noise_ratio as psnr
import pytorch_msssim
from tqdm import tqdm  # Progress bar for caching

class StandardDataset(Dataset):
    def __init__(self, raw_dir, reference_dir, transform=None):
        self.raw_dir = raw_dir
        self.reference_dir = reference_dir
        self.transform = transform

        self.raw_images = sorted(os.listdir(self.raw_dir))
        self.reference_images = sorted(os.listdir(self.reference_dir))

    def __len__(self):
        return len(self.raw_images)

    def __getitem__(self, idx):
        raw_image_path = os.path.join(self.raw_dir, self.raw_images[idx])
        reference_image_path = os.path.join(self.reference_dir, self.reference_images[idx])

        raw_image = Image.open(raw_image_path).convert("RGB")
        reference_image = Image.open(reference_image_path).convert("RGB")

        if self.transform:
            raw_image = self.transform(raw_image)
            reference_image = self.transform(reference_image)

        return raw_image, reference_image

# Ensure calculations like PSNR are performed on the GPU
def calculate_psnr(outputs, targets):
    outputs_np = outputs.detach().cpu().numpy()
    targets_np = targets.detach().cpu().numpy()
    psnrs = [psnr(outputs_np[i].transpose(1, 2, 0), targets_np[i].transpose(1, 2, 0), data_range=1.0)
             for i in range(outputs.size(0))]
    return np.mean(psnrs)

# Calculate Mean Square Error (MSE) on GPU
def calculate_mse(outputs, targets):
    mse_loss = nn.MSELoss().to(outputs.device)
    mse_value = mse_loss(outputs, targets)
    return mse_value.item()

def train_model(resume_training=True, checkpoint_path='/home4/qaiser.khan/UnderWaterImgEnh/Dataset2_model_state_dict.pth'):
    # Hyperparameters
    learning_rate = 0.0002
    batch_size = 4
    num_epochs = 100
    start_epoch = 10  # Will update if resuming training

    # Device configuration
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Model, Loss, Optimizer
    model = CustomModel().to(device)
    criterion = CombinedLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Resume training from checkpoint if provided
    if resume_training:
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint)
        print(f"Resuming training from saved model weights...")

    # Dataset and DataLoader
    transform = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])
    train_dataset = StandardDataset(
        raw_dir='/home4/qaiser.khan/UnderWaterImgEnh/ImgEnh-2/Train/Raw',
        reference_dir='/home4/qaiser.khan/UnderWaterImgEnh/ImgEnh-2/Train/Reference',
        transform=transform
    )
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    # For plotting
    epoch_losses = []
    epoch_psnrs = []
    epoch_ssims = []
    epoch_mses = []

    best_psnr = float('-inf')
    best_mse = float('inf')
    best_epoch = -1

    print(f"Number of images in the dataset: {len(train_dataset)}")
    print("Training started...")

    total_start_time = time.time()

    for epoch in range(start_epoch, num_epochs):
        epoch_start_time = time.time()

        model.train()
        epoch_loss = 0
        epoch_psnr = 0
        epoch_ssim = 0
        epoch_mse = 0

        for batch_idx, (images, targets) in enumerate(train_loader):
            images = images.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            epoch_psnr += calculate_psnr(outputs, targets)
            epoch_ssim += pytorch_msssim.ssim(outputs, targets).item()
            epoch_mse += calculate_mse(outputs, targets)

        avg_loss = epoch_loss / len(train_loader)
        avg_psnr = epoch_psnr / len(train_loader)
        avg_ssim = epoch_ssim / len(train_loader)
        avg_mse = epoch_mse / len(train_loader)

        epoch_losses.append(avg_loss)
        epoch_psnrs.append(avg_psnr)
        epoch_ssims.append(avg_ssim)
        epoch_mses.append(avg_mse)

        epoch_end_time = time.time()
        epoch_duration = epoch_end_time - epoch_start_time

        print(f'Epoch [{epoch+1}/{num_epochs}], Time: {epoch_duration:.2f}s, '
              f'Average Loss: {avg_loss:.4f}, Average PSNR: {avg_psnr:.4f}, '
              f'Average SSIM: {avg_ssim:.4f}, Average MSE: {avg_mse:.4f}')

        # Save best weights every 25 epochs
        if (epoch + 1) % 25 == 0:
            torch.save(model.state_dict(), f'best_model_weights_epoch_{epoch+1}.pth')
            print(f'Best model weights saved at epoch {epoch + 1}')

    total_end_time = time.time()
    total_training_time = total_end_time - total_start_time

    print(f'Total Training Time: {total_training_time:.2f}s')

    # Save the final model
    torch.save(model, 'final_model.pt')
    print('Entire model saved!')
    torch.save(model.state_dict(), 'final_model_state_dict.pth')
    print('Model state dict saved!')

    # Save enhanced images
    os.makedirs('enhanced_images', exist_ok=True)
    model.eval()
    with torch.no_grad():
        for i, (images, _) in enumerate(train_loader):
            images = images.to(device)
            outputs = model(images)
            for j in range(images.size(0)):
                save_image(outputs[j].cpu(), f'enhanced_images/enhanced_image_{i}_{j}.png')
    print('All enhanced images saved!')

    # Plot training metrics
    plt.figure(figsize=(16, 4))

    plt.subplot(1, 4, 1)
    plt.plot(range(1, num_epochs + 1), epoch_losses, label='Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Loss over Epochs')

    plt.subplot(1, 4, 2)
    plt.plot(range(1, num_epochs + 1), epoch_psnrs, label='PSNR', color='g')
    plt.xlabel('Epoch')
    plt.ylabel('PSNR')
    plt.title('PSNR over Epochs')

    plt.subplot(1, 4, 3)
    plt.plot(range(1, num_epochs + 1), epoch_ssims, label='SSIM', color='r')
    plt.xlabel('Epoch')
    plt.ylabel('SSIM')
    plt.title('SSIM over Epochs')

    plt.subplot(1, 4, 4)
    plt.plot(range(1, num_epochs + 1), epoch_mses, label='MSE', color='b')
    plt.xlabel('Epoch')
    plt.ylabel('MSE')
    plt.title('MSE over Epochs')

    plt.tight_layout()
    plt.savefig('training_metrics.png')
    plt.show()

# To resume training
train_model(resume_training=True)

# To start training from scratch
#train_model()


In [None]:
# Use this code to start training on multi GPUs (2, 3 etc)
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from skimage.metrics import peak_signal_noise_ratio as psnr
import pytorch_msssim

# updated to save all enhanced images with a size of 416 by 416
def calculate_psnr(outputs, targets):
    """
    Calculate PSNR for each image in the batch.
    """
    outputs_np = outputs.detach().cpu().numpy()
    targets_np = targets.detach().cpu().numpy()
    psnrs = []
    for i in range(outputs_np.shape[0]):
        psnr_value = psnr(outputs_np[i].transpose(1, 2, 0), targets_np[i].transpose(1, 2, 0), data_range=1.0)
        psnrs.append(psnr_value)
    return np.mean(psnrs)

# Calculate Mean Square Error (MSE)
def calculate_mse(outputs, targets):
    """
    Calculate MSE for each image in the batch.
    """
    mse_loss = nn.MSELoss()
    mse_value = mse_loss(outputs, targets)
    return mse_value.item()

def train_model():
    # Hyperparameters
    learning_rate = 0.01
    batch_size = 4
    num_epochs = 2

    # Device configuration
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    devices = [0, 1]  # Specify GPU IDs to use

    # Model, Loss, Optimizer
    model = CustomModel().to(device)
    model = nn.DataParallel(model, device_ids=devices)  # Use DataParallel to split the model across multiple GPUs

    criterion = CombinedLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Dataset and DataLoader
    transform = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])
    train_dataset = UIEDataset('/kaggle/input/imgenh-2/dataset-2/Train', transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=batch_size * len(devices), shuffle=True)

    # For plotting
    epoch_losses = []
    epoch_psnrs = []
    epoch_ssims = []
    epoch_mses = []

    print(f"Number of images in the dataset: {len(train_dataset)}")
    print("Training started...")

    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0
        epoch_psnr = 0
        epoch_ssim = 0
        epoch_mse = 0
        for batch_idx, (images, targets) in enumerate(train_loader):
            images = images.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            epoch_psnr += calculate_psnr(outputs, targets)
            epoch_ssim += pytorch_msssim.ssim(outputs, targets).item()
            epoch_mse += calculate_mse(outputs, targets)

        avg_loss = epoch_loss / len(train_loader)
        avg_psnr = epoch_psnr / len(train_loader)
        avg_ssim = epoch_ssim / len(train_loader)
        avg_mse = epoch_mse / len(train_loader)

        epoch_losses.append(avg_loss)
        epoch_psnrs.append(avg_psnr)
        epoch_ssims.append(avg_ssim)
        epoch_mses.append(avg_mse)

        print(f'Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}, Average PSNR: {avg_psnr:.4f}, Average SSIM: {avg_ssim:.4f}, Average MSE: {avg_mse:.4f}')

    # Save model weights at the end of training
    torch.save(model.state_dict(), 'UIEModel_final.pth')
    print('Model weights saved!')

    # Save all enhanced images at the end of training
    os.makedirs('enhanced_images', exist_ok=True)  # Create a directory for enhanced images
    model.eval()
    with torch.no_grad():
        for i, (images, _) in enumerate(train_loader):
            images = images.to(device)
            outputs = model(images)
            for j in range(images.size(0)):
                save_image(outputs[j].cpu(), f'enhanced_images/enhanced_image_{i}_{j}.png')
    print('All enhanced images saved!')

    # Plotting the metrics
    plt.figure(figsize=(16, 4))

    plt.subplot(1, 4, 1)
    plt.plot(range(1, num_epochs + 1), epoch_losses, label='Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Loss vs Epoch')
    plt.legend()

    plt.subplot(1, 4, 2)
    plt.plot(range(1, num_epochs + 1), epoch_psnrs, label='PSNR')
    plt.xlabel('Epoch')
    plt.ylabel('PSNR')
    plt.title('PSNR vs Epoch')
    plt.legend()

    plt.subplot(1, 4, 3)
    plt.plot(range(1, num_epochs + 1), epoch_ssims, label='SSIM')
    plt.xlabel('Epoch')
    plt.ylabel('SSIM')
    plt.title('SSIM vs Epoch')
    plt.legend()

    plt.subplot(1, 4, 4)
    plt.plot(range(1, num_epochs + 1), epoch_mses, label='MSE')
    plt.xlabel('Epoch')
    plt.ylabel('MSE')
    plt.title('MSE vs Epoch')
    plt.legend()

    plt.tight_layout()
    plt.savefig('training_metrics.png')
    plt.show()

if __name__ == "__main__":
    train_model()


In [None]:
import os
import zipfile

# Define the path to the Kaggle output directory and the directory to be zipped
output_dir = '/kaggle/working/'  # Path to the directory containing images
zip_filename = '/kaggle/working/enhanced_images.zip'  # Path for the zip file

# Create a zip file
with zipfile.ZipFile(zip_filename, 'w') as zipf:
    # Walk through the directory
    for root, dirs, files in os.walk(output_dir):
        for file in files:
            # Create the full file path
            file_path = os.path.join(root, file)
            # Add file to the zip file with relative path
            zipf.write(file_path, os.path.relpath(file_path, os.path.join(output_dir, '..')))

print(f'Zip file created at: {zip_filename}')

# Note: You can download the zip file using Kaggle's interface.


In [None]:
# Our training model code 
import os
import time  # Import time module to measure training time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from skimage.metrics import peak_signal_noise_ratio as psnr
import pytorch_msssim

# updated to save all enhanced images with a size of 256 by 256 
def calculate_psnr(outputs, targets):
    """
    Calculate PSNR for each image in the batch.
    """
    outputs_np = outputs.detach().cpu().numpy()
    targets_np = targets.detach().cpu().numpy()
    psnrs = []
    for i in range(outputs_np.shape[0]):
        psnr_value = psnr(outputs_np[i].transpose(1, 2, 0), targets_np[i].transpose(1, 2, 0), data_range=1.0)
        psnrs.append(psnr_value)
    return np.mean(psnrs)

# Calculate Mean Square Error (MSE)
def calculate_mse(outputs, targets):
    """
    Calculate MSE for each image in the batch.
    """
    mse_loss = nn.MSELoss()
    mse_value = mse_loss(outputs, targets)
    return mse_value.item()

def train_model():
    # Hyperparameters
    learning_rate = 0.01
    batch_size = 4
    num_epochs = 150

    # Device configuration
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Model, Loss, Optimizer
    model = CustomModel().to(device)
    criterion = CombinedLoss().to(device)
    optimizer = optim.Adamax(model.parameters(), lr=learning_rate)

    # Dataset and DataLoader
    transform = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])
    train_dataset = UIEDataset(
    raw_dir='/kaggle/input/uieb-data/UIEB/Train/Raw',
    reference_dir='/kaggle/input/uieb-data/UIEB/Train/Reference',
    transform=transform
    )
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)


    # For plotting
    epoch_losses = []
    epoch_psnrs = []
    epoch_ssims = []
    epoch_mses = []  # Added list to store MSE values

    print(f"Number of images in the dataset: {len(train_dataset)}")
    print("Training started...")

    # Start the total training timer
    total_start_time = time.time()

    for epoch in range(num_epochs):
        epoch_start_time = time.time()  # Start the timer for the epoch

        model.train()
        epoch_loss = 0
        epoch_psnr = 0
        epoch_ssim = 0
        epoch_mse = 0  # Initialize MSE for each epoch
        for batch_idx, (images, targets) in enumerate(train_loader):
            images = images.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            epoch_psnr += calculate_psnr(outputs, targets)
            epoch_ssim += pytorch_msssim.ssim(outputs, targets).item()
            epoch_mse += calculate_mse(outputs, targets)  # Calculate and accumulate MSE

        avg_loss = epoch_loss / len(train_loader)
        avg_psnr = epoch_psnr / len(train_loader)
        avg_ssim = epoch_ssim / len(train_loader)
        avg_mse = epoch_mse / len(train_loader)  # Average MSE for the epoch

        epoch_losses.append(avg_loss)
        epoch_psnrs.append(avg_psnr)
        epoch_ssims.append(avg_ssim)
        epoch_mses.append(avg_mse)  # Store the average MSE

        epoch_end_time = time.time()  # End the timer for the epoch
        epoch_duration = epoch_end_time - epoch_start_time

        print(f'Epoch [{epoch+1}/{num_epochs}], Time: {epoch_duration:.2f}s, '
              f'Average Loss: {avg_loss:.4f}, Average PSNR: {avg_psnr:.4f}, '
              f'Average SSIM: {avg_ssim:.4f}, Average MSE: {avg_mse:.4f}')

    # End the total training timer
    total_end_time = time.time()
    total_training_time = total_end_time - total_start_time

    print(f'Total Training Time: {total_training_time:.2f}s')

    # Save the entire model at the end of training
    torch.save(model, 'Dataset1_model.pth')
    print('Entire model saved!')

    # Save all enhanced images at the end of training
    os.makedirs('enhanced_images', exist_ok=True)  # Create a directory for enhanced images
    model.eval()
    with torch.no_grad():
        for i, (images, _) in enumerate(train_loader):
            images = images.to(device)
            outputs = model(images)
            for j in range(images.size(0)):
                save_image(outputs[j].cpu(), f'enhanced_images/enhanced_image_{i}_{j}.png')
    print('All enhanced images saved!')

    # Plotting the metrics
    plt.figure(figsize=(16, 4))

    plt.subplot(1, 4, 1)
    plt.plot(range(1, num_epochs + 1), epoch_losses, label='Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Loss vs Epoch')
    plt.legend()

    plt.subplot(1, 4, 2)
    plt.plot(range(1, num_epochs + 1), epoch_psnrs, label='PSNR')
    plt.xlabel('Epoch')
    plt.ylabel('PSNR')
    plt.title('PSNR vs Epoch')
    plt.legend()

    plt.subplot(1, 4, 3)
    plt.plot(range(1, num_epochs + 1), epoch_ssims, label='SSIM')
    plt.xlabel('Epoch')
    plt.ylabel('SSIM')
    plt.title('SSIM vs Epoch')
    plt.legend()

    plt.subplot(1, 4, 4)  # Added subplot for MSE
    plt.plot(range(1, num_epochs + 1), epoch_mses, label='MSE')
    plt.xlabel('Epoch')
    plt.ylabel('MSE')
    plt.title('MSE vs Epoch')
    plt.legend()

    plt.tight_layout()
    plt.savefig('training_metrics.png')
    plt.show()

if __name__ == "__main__":
    train_model()


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image
import os
import time

# Define your custom dataset class
class UIEDataset(torch.utils.data.Dataset):
    def __init__(self, raw_dir, reference_dir, transform=None):
        self.raw_dir = raw_dir
        self.reference_dir = reference_dir
        self.transform = transform
        self.raw_images = sorted(os.listdir(raw_dir))
        self.reference_images = sorted(os.listdir(reference_dir))

    def __len__(self):
        return len(self.raw_images)

    def __getitem__(self, idx):
        raw_image = Image.open(os.path.join(self.raw_dir, self.raw_images[idx])).convert('RGB')
        reference_image = Image.open(os.path.join(self.reference_dir, self.reference_images[idx])).convert('RGB')
        
        if self.transform:
            raw_image = self.transform(raw_image)
            reference_image = self.transform(reference_image)
        
        return raw_image, reference_image

# Define the complete model with additional DS Conv blocks and ConvBlock4
class CustomModel(nn.Module):
    def __init__(self):
        super(CustomModel, self).__init__()
        self.inception_block_r = Inc(in_channels=1, filters=64)
        self.inception_block_g = Inc(in_channels=1, filters=64)
        self.inception_block_b = Inc(in_channels=1, filters=64)
        self.se_layer_r = SELayer(inp=256, oup=256)
        self.se_layer_g = SELayer(inp=256, oup=256)
        self.se_layer_b = SELayer(inp=256, oup=256)
        self.ds_conv1 = DSConvBlock(in_channels=768, out_channels=256)
        self.ds_conv2 = DSConvBlock(in_channels=256, out_channels=128)
        self.ds_conv3 = DSConvBlock(in_channels=128, out_channels=16)
        self.ds_conv4 = DSConvBlock(in_channels=16, out_channels=32)
        self.ds_conv5 = DSConvBlock(in_channels=32, out_channels=64)
        self.ds_conv6 = DSConvBlock(in_channels=64, out_channels=32)
        self.conv_block4 = ConvBlock4()
        self.final_conv = nn.Conv2d(in_channels=32, out_channels=3, kernel_size=1, stride=1, padding=0)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Split the input into R, G, B channels
        r, g, b = x[:, 0:1, :, :], x[:, 1:2, :, :], x[:, 2:3, :, :]

        # Process each channel independently through Inception and SE layers
        r = self.se_layer_r(self.inception_block_r(r))
        g = self.se_layer_g(self.inception_block_g(g))
        b = self.se_layer_b(self.inception_block_b(b))

        # Concatenate the outputs along the channel dimension (dim=1)
        x = torch.cat([r, g, b], dim=1)  # Shape: (1, 768, 256, 256)

        # Pass through the initial depthwise separable convolution blocks
        x = self.ds_conv1(x)  # Output shape: (1, 256, 256, 256)
        x = self.ds_conv2(x)  # Output shape: (1, 128, 256, 256)
        x = self.ds_conv3(x)  # Output shape: (1, 16, 256, 256)

        # Apply additional DS Conv blocks
        x1 = self.ds_conv4(x)  # Output shape: (1, 32, 256, 256)
        x2 = self.ds_conv5(x1)  # Output shape: (1, 64, 256, 256)
        x3 = self.ds_conv6(x2)  # Output shape: (1, 32, 256, 256)

        # Concatenate all outputs along the channel dimension (dim=1)
        x = torch.cat([x, x1, x2, x3], dim=1)  # Shape: (1, 16 + 32 + 64 + 32, 256, 256) = (1, 144, 256, 256)
        x = x[:,0:80,:,:]  # Output shape: (1, 80, 256, 256)
        # Adjust channels before passing through ConvBlock4
        x = self.conv_block4(x)  # Output shape: (1, 32, 256, 256)
        # Apply final 1x1 Conv with sigmoid
        x = self.final_conv(x)  # Output shape: (1, 3, 256, 256)
        x = self.sigmoid(x)  # Output shape: (1, 3, 256, 256)
        return x


# Utility function for plotting loss
def plot_loss(train_losses, val_losses=None, output_dir=None, batch_size=None, lr=None, optimizer_name=None):
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Train Loss')
    if val_losses:
        plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title(f'Loss over Epochs (BS={batch_size}, LR={lr}, Opt={optimizer_name})')
    plt.grid(True)
    plt.savefig(os.path.join(output_dir, f'loss_plot_bs{batch_size}_lr{lr}_opt{optimizer_name}.png'))
    plt.show()

# Define directories
raw_dir = '/kaggle/input/uieb-data/UIEB/Test/Raw'
reference_dir = '/kaggle/input/uieb-data/UIEB/Test/Reference'
output_dir = '/kaggle/working/'
model_path = '/kaggle/input/uieb-model/UIEModel_final(1).pth'

# Hyperparameter grid
batch_sizes = [2, 4]
learning_rates = [0.0002, 0.0001, 0.02,0.01]
optimizers = {
    'SGD': optim.SGD,
    'Adam': optim.Adam,
    'WSGD': optim.SGD,  # Placeholder for your WSGD implementation
    'Adamax': optim.Adamax,
    'Nadam': optim.NAdam
}

# Check for GPU availability
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load weights function
def load_weights(model, path):
    checkpoint = torch.load(path)
    model_state_dict = model.state_dict()
    
    # Check if all keys match
    for key in checkpoint.keys():
        if key not in model_state_dict:
            print(f"Skipping key: {key}")
    # Load only the matching keys
    model_state_dict.update({k: v for k, v in checkpoint.items() if k in model_state_dict})
    model.load_state_dict(model_state_dict)

# Training function
def train_and_evaluate(batch_size, lr, optimizer_name):
    print(f"\nTraining with Batch Size: {batch_size}, Learning Rate: {lr}, Optimizer: {optimizer_name}")
    
    transform = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])
    dataset = UIEDataset(raw_dir, reference_dir, transform=transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    
    model = CustomModel().to(device)  # Move model to GPU
    load_weights(model, model_path)
    model.train()

    criterion = nn.MSELoss()  # Use appropriate loss function
    optimizer_class = optimizers[optimizer_name]
    optimizer = optimizer_class(model.parameters(), lr=lr)

    num_epochs = 10
    train_losses = []
    results = []

    for epoch in range(num_epochs):
        epoch_loss = 0.0
        start_time = time.time()  # Start time for the epoch
        
        for raw, reference in dataloader:
            raw, reference = raw.to(device), reference.to(device)  # Move data to GPU
            optimizer.zero_grad()
            outputs = model(raw)
            loss = criterion(outputs, reference)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item() * raw.size(0)

        epoch_loss /= len(dataloader.dataset)
        train_losses.append(epoch_loss)
        end_time = time.time()  # End time for the epoch
        elapsed_time = end_time - start_time
        
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, Time Taken: {elapsed_time:.2f}s")
        results.append({'batch_size': batch_size, 'learning_rate': lr, 'optimizer': optimizer_name,
                        'epoch': epoch + 1, 'loss': epoch_loss, 'time_taken': elapsed_time})

    # Save loss plot
    plot_loss(train_losses, output_dir=output_dir, batch_size=batch_size, lr=lr, optimizer_name=optimizer_name)
    
    # Save results to CSV
    df = pd.DataFrame(results)
    csv_filename = f'results_bs{batch_size}_lr{lr}_opt{optimizer_name}.csv'
    df.to_csv(os.path.join(output_dir, csv_filename), index=False)

# Grid search
all_results = []
for batch_size in batch_sizes:
    for lr in learning_rates:
        for optimizer_name in optimizers.keys():
            train_and_evaluate(batch_size, lr, optimizer_name)
            all_results.append((batch_size, lr, optimizer_name))

# Plotting the comparative analysis similar to the provided image
def plot_comparative_analysis(output_dir, all_results):
    data = []
    for (batch_size, lr, optimizer_name) in all_results:
        csv_filename = f'results_bs{batch_size}_lr{lr}_opt{optimizer_name}.csv'
        df = pd.read_csv(os.path.join(output_dir, csv_filename))
        final_loss = df['loss'].iloc[-1]  # Get the final loss value
        data.append((batch_size, lr, optimizer_name, final_loss))

    # Create a DataFrame for plotting
    comparative_df = pd.DataFrame(data, columns=['Batch Size', 'Learning Rate', 'Optimizer', 'Final Loss'])

    # Plot the comparative analysis
    plt.figure(figsize=(12, 6))
    
    markers = ['o', 's', 'D', '^', 'v', '*', 'x', '+']
    for i, optimizer_name in enumerate(optimizers.keys()):
        subset = comparative_df[comparative_df['Optimizer'] == optimizer_name]
        plt.plot(subset['Learning Rate'], subset['Final Loss'], label=optimizer_name, marker=markers[i % len(markers)])
        
        # Highlight the best-performing point for each optimizer
        best_index = subset['Final Loss'].idxmin()
        best_lr = subset.loc[best_index, 'Learning Rate']
        best_loss = subset.loc[best_index, 'Final Loss']
        plt.scatter(best_lr, best_loss, color='red', s=100, edgecolor='black', zorder=5)

    plt.xlabel('Learning Rate')
    plt.ylabel('Final Loss')
    plt.title('Comparative Analysis of Optimizers and Learning Rates')
    plt.xscale('log')
    plt.grid(True, which="both", ls="--")
    plt.legend(loc='upper right')
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'comparative_analysis.png'))
    plt.show()

# Plot and save the comparative analysis
plot_comparative_analysis(output_dir, all_results)


In [4]:
#Evaluation code from github with normalizing images 
import os
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
import torch
from torchvision.utils import save_image

# Check for GPU availability
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load saved model
model = torch.load('/kaggle/input/data2-finalmodels/final_model.pt')
model.to(device)  # Ensure the model is on the correct device (GPU or CPU)
print("Full model loaded successfully!")

# Optional: Check some of the model's parameters to confirm loading
for name, param in model.named_parameters():
    print(f"Layer: {name} | Size: {param.size()} | Values : {param[:2]}")  # Print the first two values
    break  # Remove or modify this to print more layers or parameters

# Define the RawImageDataset
class RawImageDataset(Dataset):
    def __init__(self, raw_dir, transform=None):
        self.raw_dir = raw_dir
        self.transform = transform
        self.raw_image_paths = [os.path.join(raw_dir, img) for img in os.listdir(raw_dir)]

    def __len__(self):
        return len(self.raw_image_paths)

    def __getitem__(self, idx):
        raw_image_path = self.raw_image_paths[idx]
        raw_image = Image.open(raw_image_path).convert("RGB")
        
        if self.transform:
            raw_image = self.transform(raw_image)
        
        return raw_image, os.path.basename(raw_image_path)  # Return the image and its filename
   
    # Dataset and DataLoader
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])

# Adjust the path for the raw test images
test_dataset = RawImageDataset('/kaggle/input/imgenh-2/dataset-2/Test/Raw', transform=transform)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# Save all enhanced test images with the same names as input images
os.makedirs('enhanced_test_images', exist_ok=True)
model.eval()
with torch.no_grad():
    for i, (images, image_names) in enumerate(test_loader):
        images = images.to(device)
        outputs = model(images)
        for j in range(images.size(0)):
            # Save the output image with the same name as the input image
            save_image(outputs[j].cpu(), os.path.join('enhanced_test_images', image_names[j]))

print('All enhanced test images saved!')


Full model loaded successfully!
Layer: inception_block_r.branch1.0.weight | Size: torch.Size([64, 1, 1, 1]) | Values : tensor([[[[ 0.8855]]],


        [[[-0.0899]]]], device='cuda:0', grad_fn=<SliceBackward0>)
All enhanced test images saved!


In [5]:
#EVALUATION METRICS CODE
#UIQM FILE CONTENT
import os
import cv2
import numpy as np
from skimage import data, color
import numpy as np
import matplotlib.pyplot as plt
from scipy import signal as sig
import math
from skimage.util import img_as_ubyte, img_as_float64
from skimage.color import rgb2gray
from skimage.color import rgb2hsv
import matplotlib.pyplot as plt
from skimage.io import imread
import warnings
warnings.filterwarnings('ignore')
import skimage
from numpy import load
from numpy import expand_dims
import matplotlib
from matplotlib import pyplot
import sys
import PIL
from PIL import Image
import pandas as pd
import numpy as np
import scipy.misc
import imageio
import glob
import os
import cv2
!pip install sewar
import sewar
import math
from math import log2, log10
from scipy import ndimage
import skimage
from skimage import color
from skimage.metrics import structural_similarity
def mu_a(x, alpha_L=0.1, alpha_R=0.1):
    """
      Calculates the asymetric alpha-trimmed mean
    """
    # sort pixels by intensity - for clipping
    x = sorted(x)
    # get number of pixels
    K = len(x)
    # calculate T alpha L and T alpha R
    T_a_L = math.ceil(alpha_L*K)
    T_a_R = math.floor(alpha_R*K)
    # calculate mu_alpha weight
    weight = (1/(K-T_a_L-T_a_R))
    # loop through flattened image starting at T_a_L+1 and ending at K-T_a_R
    s   = int(T_a_L+1)
    e   = int(K-T_a_R)
    val = sum(x[s:e])
    val = weight*val
    return val

def s_a(x, mu):
    val = 0
    for pixel in x:
        val += math.pow((pixel-mu), 2)
    return val/len(x)

def _uicm(x):
    R = x[:,:,0].flatten()
    G = x[:,:,1].flatten()
    B = x[:,:,2].flatten()
    RG = R-G
    YB = ((R+G)/2)-B
    mu_a_RG = mu_a(RG)
    mu_a_YB = mu_a(YB)
    s_a_RG = s_a(RG, mu_a_RG)
    s_a_YB = s_a(YB, mu_a_YB)
    l = math.sqrt( (math.pow(mu_a_RG,2)+math.pow(mu_a_YB,2)) )
    r = math.sqrt(s_a_RG+s_a_YB)
    return (-0.0268*l)+(0.1586*r)

def sobel(x):
    dx = ndimage.sobel(x,0)
    dy = ndimage.sobel(x,1)
    mag = np.hypot(dx, dy)
    mag *= 255.0 / np.max(mag)
    return mag

def eme(x, window_size):
    """
      Enhancement measure estimation
      x.shape[0] = height
      x.shape[1] = width
    """
    # if 4 blocks, then 2x2...etc.
    k1 = x.shape[1]/window_size
    k2 = x.shape[0]/window_size
    # weight
    w = 2./(k1*k2)
    blocksize_x = window_size
    blocksize_y = window_size
    # make sure image is divisible by window_size - doesn't matter if we cut out some pixels
    x = x[:int(blocksize_y*k2), :int(blocksize_x*k1)]
    val = 0
    for l in range(int(k1)):
        for k in range(int(k2)):
            block = x[k*window_size:window_size*(k+1), l*window_size:window_size*(l+1)]
            max_ = np.max(block)
            min_ = np.min(block)
            # bound checks, can't do log(0)
            if min_ == 0.0: val += 0
            elif max_ == 0.0: val += 0
            else: val += math.log(max_/min_)
    return w*val

def _uism(x):
    """
      Underwater Image Sharpness Measure
    """
    # get image channels
    R = x[:,:,0]
    G = x[:,:,1]
    B = x[:,:,2]
    # first apply Sobel edge detector to each RGB component
    Rs = sobel(R)
    Gs = sobel(G)
    Bs = sobel(B)
    # multiply the edges detected for each channel by the channel itself
    R_edge_map = np.multiply(Rs, R)
    G_edge_map = np.multiply(Gs, G)
    B_edge_map = np.multiply(Bs, B)
    # get eme for each channel
    r_eme = eme(R_edge_map, 10)
    g_eme = eme(G_edge_map, 10)
    b_eme = eme(B_edge_map, 10)
    # coefficients
    lambda_r = 0.299
    lambda_g = 0.587
    lambda_b = 0.144
    return (lambda_r*r_eme) + (lambda_g*g_eme) + (lambda_b*b_eme)

def plip_g(x,mu=1026.0):
    return mu-x

def plip_theta(g1, g2, k):
    g1 = plip_g(g1)
    g2 = plip_g(g2)
    return k*((g1-g2)/(k-g2))

def plip_cross(g1, g2, gamma):
    g1 = plip_g(g1)
    g2 = plip_g(g2)
    return g1+g2-((g1*g2)/(gamma))

def plip_diag(c, g, gamma):
    g = plip_g(g)
    return gamma - (gamma * math.pow((1 - (g/gamma) ), c) )

def plip_multiplication(g1, g2):
    return plip_phiInverse(plip_phi(g1) * plip_phi(g2))
    #return plip_phiInverse(plip_phi(plip_g(g1)) * plip_phi(plip_g(g2)))

def plip_phiInverse(g):
    plip_lambda = 1026.0
    plip_beta   = 1.0
    return plip_lambda * (1 - math.pow(math.exp(-g / plip_lambda), 1 / plip_beta));

def plip_phi(g):
    plip_lambda = 1026.0
    plip_beta   = 1.0
    return -plip_lambda * math.pow(math.log(1 - g / plip_lambda), plip_beta)

def _uiconm(x, window_size):
    plip_lambda = 1026.0
    plip_gamma  = 1026.0
    plip_beta   = 1.0
    plip_mu     = 1026.0
    plip_k      = 1026.0
    # if 4 blocks, then 2x2...etc.
    k1 = x.shape[1]/window_size
    k2 = x.shape[0]/window_size
    # weight
    w = -1./(k1*k2)
    blocksize_x = window_size
    blocksize_y = window_size
    # make sure image is divisible by window_size - doesn't matter if we cut out some pixels
    x = x[:int(blocksize_y*k2), :int(blocksize_x*k1)]
    # entropy scale - higher helps with randomness
    alpha = 1
    val = 0
    for l in range(int(k1)):
        for k in range(int(k2)):
            block = x[k*window_size:window_size*(k+1), l*window_size:window_size*(l+1), :]
            max_ = np.max(block)
            min_ = np.min(block)
            top = max_-min_
            bot = max_+min_
            if math.isnan(top) or math.isnan(bot) or bot == 0.0 or top == 0.0: val += 0.0
            else: val += alpha*math.pow((top/bot),alpha) * math.log(top/bot)
            #try: val += plip_multiplication((top/bot),math.log(top/bot))
    return w*val

##########################################################################################

def getUIQM(x):
    """
      Function to return UIQM to be called from other programs
      x: image
    """
    x = x.astype(np.float32)
    ### from https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=7300447
    #c1 = 0.4680; c2 = 0.2745; c3 = 0.2576
    ### from https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=7300447
    c1 = 0.0282; c2 = 0.2953; c3 = 3.5753
    uicm   = _uicm(x)
    uism   = _uism(x)
    uiconm = _uiconm(x, 10)
    uiqm = (c1*uicm) + (c2*uism) + (c3*uiconm)
    return uiqm


def getUCIQE(rgb_in):
    # calculate Chroma
    rgb_in = cv2.normalize(rgb_in, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)
    (l,a,b)=cv2.split(rgb_in)
    Chroma = np.sqrt(a*a + b*b)
    StdVarianceChroma = np.std(np.reshape(Chroma[:,:],(-1,1)))

    hsv = skimage.color.rgb2hsv(rgb_in)
    Saturation = hsv[:,:,2]
    MeanSaturation = np.mean(np.reshape(Saturation[:,:],(-1,1)))

    ContrastLuminance = max(np.reshape(l[:,:],(-1,1))) - min(np.reshape(l[:,:],(-1,1)))
    UCIQE = 0.4680 * StdVarianceChroma + 0.2745 * ContrastLuminance + 0.2576 * MeanSaturation
    return float(UCIQE)

def improve_contrast_image_using_clahe(bgr_image):
    hsv = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2HSV)
    hsv_planes = cv2.split(hsv)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    hsv_planes[2] = clahe.apply(hsv_planes[2])
    hsv = cv2.merge(hsv_planes)
    return cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)

Collecting sewar
  Downloading sewar-0.4.6.tar.gz (11 kB)
  Preparing metadata (setup.py) ... [?25ldone
Building wheels for collected packages: sewar
  Building wheel for sewar (setup.py) ... [?25ldone
[?25h  Created wheel for sewar: filename=sewar-0.4.6-py3-none-any.whl size=11420 sha256=d2d94363ae2ae094d3ba00bb0286e0fd1f28c6c41147c5cbf97f333483606ebf
  Stored in directory: /root/.cache/pip/wheels/3f/af/02/9c6556ba287b62a945d737def09b8b8c35c9e1d82b9dfae84c
Successfully built sewar
Installing collected packages: sewar
Successfully installed sewar-0.4.6


In [6]:
from skimage.metrics import structural_similarity, peak_signal_noise_ratio, mean_squared_error
import numpy as np
!pip install libsvm-official
from libsvm.svmutil import *
import matplotlib.pyplot as plt
import cv2
import os
import skimage
import imageio as iio

def NormalizeData(data):
    return (data - np.min(data)) / (np.max(data) - np.min(data))

generated = []
gt = []

# SET THE TEST IMAGE PATH IN gt_addrr
gt_addrr = "/kaggle/input/imgenh-2/dataset-2/Test/Reference"

# SET THE ENHANCED TEST IMAGE PATH IN addrr
addrr = "/kaggle/working/enhanced_test_images"

# Ensure both lists of images have the same order by sorting the filenames
gt_filenames = sorted(os.listdir(gt_addrr))
generated_filenames = sorted(os.listdir(addrr))

# Load the images, resize them, and append them to the lists
for item in generated_filenames:
    if item.endswith(".jpg"):
        image_path = os.path.join(addrr, item)
        image = cv2.imread(image_path)
        if image is not None:  # Ensure the image was loaded
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype("float32")
            image = cv2.resize(image, (256, 256))  # Resize to a common size
            generated.append(image)
        else:
            print(f"Warning: Could not load image {image_path}")

for item in gt_filenames:
    if item.endswith(".jpg"):
        image_path = os.path.join(gt_addrr, item)
        image = cv2.imread(image_path)
        if image is not None:  # Ensure the image was loaded
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype("float32")
            image = cv2.resize(image, (256, 256))  # Resize to a common size
            gt.append(image)
        else:
            print(f"Warning: Could not load image {image_path}")

# Ensure both lists have the same length before proceeding
if len(generated) != len(gt):
    raise ValueError("The number of generated images and ground truth images do not match!")

# Initialize lists to store the metrics
SSIM_results = []
PSNR_results = []
UIQM = []
UCIQE = []
MSE = []

# Calculate metrics for each image pair
for i in range(len(generated)):
    print(f"Processing image pair {i+1}/{len(generated)}")
    
    # Normalize images
    norm_generated = NormalizeData(generated[i])
    norm_gt = NormalizeData(gt[i])
    
    # Calculate and store UIQM and UCIQE metrics
    UIQM.append(getUIQM(norm_generated))
    UCIQE.append(getUCIQE(norm_generated))
    
    # Calculate and store PSNR and SSIM with explicit win_size, channel_axis, and data_range
    PSNR_results.append(peak_signal_noise_ratio(norm_generated, norm_gt, data_range=1.0))
    SSIM_results.append(structural_similarity(norm_generated, norm_gt, win_size=7, channel_axis=-1, data_range=1.0))
    
    # Calculate and store MSE
    MSE.append(mean_squared_error(norm_generated, norm_gt))

# Print the average of the metrics
print(f"Average SSIM: {np.mean(SSIM_results):.4f}")
print(f"Average PSNR: {np.mean(PSNR_results):.4f} dB")
print(f"Average MSE: {np.mean(MSE):.4f}")
print(f"Average UIQM: {np.mean(UIQM):.4f}")
print(f"Average UCIQE: {np.mean(UCIQE):.4f}")


Collecting libsvm-official
  Downloading libsvm_official-3.35.0.tar.gz (39 kB)
  Preparing metadata (setup.py) ... [?25ldone
Building wheels for collected packages: libsvm-official
  Building wheel for libsvm-official (setup.py) ... [?25ldone
[?25h  Created wheel for libsvm-official: filename=libsvm_official-3.35.0-cp310-cp310-linux_x86_64.whl size=53027 sha256=8ef4976e642e7c001518530e09bcd551cf5f4ae603db3b11aedc8958a56c4a97
  Stored in directory: /root/.cache/pip/wheels/ec/50/a6/962c82577759a39080b2d0b51640bc1be45d02bd5d0d5dde7d
Successfully built libsvm-official
Installing collected packages: libsvm-official
Successfully installed libsvm-official-3.35.0
Processing image pair 1/1806
Processing image pair 2/1806
Processing image pair 3/1806
Processing image pair 4/1806
Processing image pair 5/1806
Processing image pair 6/1806
Processing image pair 7/1806
Processing image pair 8/1806
Processing image pair 9/1806
Processing image pair 10/1806
Processing image pair 11/1806
Processing i

In [5]:
import numpy as np
import torchvision
from torchvision import transforms
import time
from tqdm import tqdm
import os
import torch
import torch.nn as nn
from torchvision.utils import save_image
from glob import glob
from PIL import Image
from os.path import join
from scipy import ndimage
from torch.utils.data import Dataset, DataLoader, Subset
from scipy.ndimage import gaussian_filter  # **Import gaussian_filter**
import math

# Ensure environment variable for Kaggle
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

class UIEDataset(Dataset):
    def __init__(self, raw_dir, reference_dir, transform=None):
        self.raw_dir = raw_dir
        self.reference_dir = reference_dir
        self.transform = transform
        self.image_names = [img for img in os.listdir(raw_dir) if img in os.listdir(reference_dir)]

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        raw_image_path = os.path.join(self.raw_dir, self.image_names[idx])
        reference_image_path = os.path.join(self.reference_dir, self.image_names[idx])

        raw_image = Image.open(raw_image_path).convert("RGB")
        reference_image = Image.open(reference_image_path).convert("RGB")

        if self.transform:
            raw_image = self.transform(raw_image)
            reference_image = self.transform(reference_image)

        return raw_image, reference_image

# SSIM and PSNR Calculation
def getSSIM(X, Y):
    assert (X.shape == Y.shape), "Image patches provided have different dimensions"
    nch = 1 if X.ndim == 2 else X.shape[-1]
    mssim = []
    for ch in range(nch):
        Xc, Yc = X[..., ch].astype(np.float64), Y[..., ch].astype(np.float64)
        mssim.append(compute_ssim(Xc, Yc))  # **Updated call to compute_ssim**
    return np.mean(mssim)

def compute_ssim(X, Y):
    K1 = 0.01
    K2 = 0.03
    sigma = 1.5
    win_size = 5

    ux = gaussian_filter(X, sigma)  # **Update for SSIM calculation**
    uy = gaussian_filter(Y, sigma)  # **Update for SSIM calculation**

    uxx = gaussian_filter(X * X, sigma)
    uyy = gaussian_filter(Y * Y, sigma)
    uxy = gaussian_filter(X * Y, sigma)

    N = win_size ** X.ndim
    unbiased_norm = N / (N - 1)
    vx = (uxx - ux * ux) * unbiased_norm
    vy = (uyy - uy * uy) * unbiased_norm
    vxy = (uxy - ux * uy) * unbiased_norm

    R = 255
    C1 = (K1 * R) ** 2
    C2 = (K2 * R) ** 2
    sim = (2 * ux * uy + C1) * (2 * vxy + C2)
    D = (ux ** 2 + uy ** 2 + C1) * (vx + vy + C2)
    SSIM = sim / D
    mssim = SSIM.mean()

    return mssim

def getPSNR(X, Y):
    target_data = np.array(X, dtype=np.float64)
    ref_data = np.array(Y, dtype=np.float64)
    diff = ref_data - target_data
    diff = diff.flatten('C')
    rmse = math.sqrt(np.mean(diff ** 2.))
    if rmse == 0: return 100
    else: return 20 * math.log10(255.0 / rmse)


# UIQM Calculation
def mu_a(x, alpha_L=0.1, alpha_R=0.1):
    x = sorted(x)
    K = len(x)
    T_a_L = math.ceil(alpha_L * K)
    T_a_R = math.floor(alpha_R * K)
    weight = (1 / (K - T_a_L - T_a_R))
    s = int(T_a_L + 1)
    e = int(K - T_a_R)
    val = sum(x[s:e])
    return weight * val

def s_a(x, mu):
    return sum(math.pow((pixel - mu), 2) for pixel in x) / len(x)

def _uicm(x):
    R = x[:, :, 0].flatten()
    G = x[:, :, 1].flatten()
    B = x[:, :, 2].flatten()
    RG = R - G
    YB = ((R + G) / 2) - B
    mu_a_RG = mu_a(RG)
    mu_a_YB = mu_a(YB)
    s_a_RG = s_a(RG, mu_a_RG)
    s_a_YB = s_a(YB, mu_a_YB)
    l = math.sqrt(math.pow(mu_a_RG, 2) + math.pow(mu_a_YB, 2))
    r = math.sqrt(s_a_RG + s_a_YB)
    return (-0.0268 * l) + (0.1586 * r)

def sobel(x):
    dx = ndimage.sobel(x, 0)
    dy = ndimage.sobel(x, 1)
    mag = np.hypot(dx, dy)
    mag *= 255.0 / np.max(mag)
    return mag

def eme(x, window_size):
    k1 = int(x.shape[1] / window_size)
    k2 = int(x.shape[0] / window_size)
    w = 2. / (k1 * k2)
    x = x[:window_size * k2, :window_size * k1]
    val = 0
    for l in range(k1):
        for k in range(k2):
            block = x[k * window_size:window_size * (k + 1), l * window_size:window_size * (l + 1)]
            max_ = np.max(block)
            min_ = np.min(block)
            if min_ == 0.0 or max_ == 0.0: 
                val += 0
            else:
                val += math.log(max_ / min_)
    return w * val

def _uism(x):
    R = x[:, :, 0]
    G = x[:, :, 1]
    B = x[:, :, 2]
    Rs = sobel(R)
    Gs = sobel(G)
    Bs = sobel(B)
    R_edge_map = np.multiply(Rs, R)
    G_edge_map = np.multiply(Gs, G)
    B_edge_map = np.multiply(Bs, B)
    r_eme = eme(R_edge_map, 10)
    g_eme = eme(G_edge_map, 10)
    b_eme = eme(B_edge_map, 10)
    lambda_r = 0.299
    lambda_g = 0.587
    lambda_b = 0.144
    return (lambda_r * r_eme) + (lambda_g * g_eme) + (lambda_b * b_eme)

def plip_g(x, mu=1026.0):
    return mu - x

def plip_theta(g1, g2, k):
    g1 = plip_g(g1)
    g2 = plip_g(g2)
    return k * ((g1 - g2) / (k - g2))

def plip_cross(g1, g2, gamma):
    g1 = plip_g(g1)
    g2 = plip_g(g2)
    return g1 + g2 - ((g1 * g2) / gamma)

def plip_diag(c, g, gamma):
    g = plip_g(g)
    return gamma - (gamma * math.pow((1 - (g / gamma)), c))

def plip_multiplication(g1, g2):
    return plip_phiInverse(plip_phi(g1) * plip_phi(g2))

def plip_phiInverse(g):
    plip_lambda = 1026.0
    plip_beta = 1.0
    return plip_lambda * (1 - math.pow(math.exp(-g / plip_lambda), 1 / plip_beta))

def plip_phi(g):
    plip_lambda = 1026.0
    plip_beta = 1.0
    return -plip_lambda * math.pow(math.log(1 - g / plip_lambda), plip_beta)

def _uiconm(x, window_size):
    plip_lambda = 1026.0
    plip_gamma = 1026.0
    plip_beta = 1.0
    plip_mu = 1026.0
    plip_k = 1026.0
    k1 = int(x.shape[1] / window_size)
    k2 = int(x.shape[0] / window_size)
    w = -1. / (k1 * k2)
    x = x[:window_size * k2, :window_size * k1]
    alpha = 1
    val = 0
    for l in range(k1):
        for k in range(k2):
            block = x[k * window_size:window_size * (k + 1), l * window_size:window_size * (l + 1), :]
            max_ = np.max(block)
            min_ = np.min(block)
            top = max_ - min_
            bot = max_ + min_
            if math.isnan(top) or math.isnan(bot) or bot == 0.0 or top == 0.0:
                val += 0.0
            else:
                val += alpha * math.pow((top / bot), alpha) * math.log(top / bot)
    return w * val

def getUIQM(x):
    x = x.astype(np.float32)
    c1 = 0.0282
    c2 = 0.2953
    c3 = 3.5753
    uicm = _uicm(x)
    uism = _uism(x)
    uiconm = _uiconm(x, 10)
    return (c1 * uicm) + (c2 * uism) + (c3 * uiconm)

def test(config, test_dataloader, test_model):
    with torch.no_grad():
        for i, (input, target) in enumerate(test_dataloader):
            input = input.to(config['device'])
            output = test_model(input)
            
            for j in range(output.size(0)):
                # Assuming name[j] is part of the test_dataloader output and is a string filename
                name = test_dataloader.dataset.image_names[i * config['batch_size'] + j]  # Get the image name

                output_image = output[j].cpu().clamp(0, 1)  # Clamp the output to the range [0, 1]
                output_image = transforms.ToPILImage()(output_image)

                output_path = os.path.join(config['output_images_path'], name)  # Combine the output path and image name
                output_image.save(output_path)

    print("Testing completed.")



def setup(config):
    if torch.cuda.is_available():
        config['device'] = "cuda"
    else:
        config['device'] = "cpu"

    # Load the entire model
    model = torch.load(config['snapshot_path'], map_location=config['device'])
    model.to(config['device'])
    model.eval()

    transform = transforms.Compose([
        transforms.Resize((config['resize'], config['resize'])),
        transforms.ToTensor()
    ])
    
    # Ensure both raw and reference directories are passed
    test_dataset = UIEDataset(config['test_images_path'], config['label_images_path'], transform)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False)
    
    print("Test Dataset Reading Completed.")
    return test_dataloader, model



def SSIMs_PSNRs(gtr_dir, gen_dir, im_res=(256, 256)):
    gtr_paths = sorted(glob(join(gtr_dir, "*.*")))
    gen_paths = sorted(glob(join(gen_dir, "*.*")))
    ssims, psnrs = [], []
    
    for gtr_path, gen_path in zip(gtr_paths, gen_paths):
        r_im = Image.open(gtr_path).resize(im_res)
        g_im = Image.open(gen_path).resize(im_res)
        
        # SSIM calculation
        ssim = getSSIM(np.array(r_im), np.array(g_im))
        ssims.append(ssim)
        
        # PSNR calculation
        r_im = r_im.convert("L")
        g_im = g_im.convert("L")
        psnr = getPSNR(np.array(r_im), np.array(g_im))
        psnrs.append(psnr)
        
    # Calculate averages and standard deviations
    avg_ssim = np.mean(ssims)
    avg_psnr = np.mean(psnrs)
    std_ssim = np.std(ssims)
    std_psnr = np.std(psnrs)
    
    return avg_ssim, avg_psnr, std_ssim, std_psnr

def measure_UIQMs(dir_name, im_res=(256, 256)):
    paths = sorted(glob(join(dir_name, "*.*")))
    uqims = []
    
    for img_path in paths:
        im = Image.open(img_path).resize(im_res)
        uiqm = getUIQM(np.array(im))
        uqims.append(uiqm)
    
    # Calculate averages and standard deviations
    avg_uiqm = np.mean(uqims)
    std_uiqm = np.std(uqims)
    
    return avg_uiqm, std_uiqm

# Kaggle specific setup
if __name__ == '__main__':
    config = {
        'snapshot_path': "/kaggle/input/data2-finalmodels/final_model.pt",
        'test_images_path': "/kaggle/input/imgenh-2/dataset-2/Test/Raw",
        'output_images_path': "/kaggle/working/Gen-output-1",
        'batch_size': 1,
        'resize': 256,
        'calculate_metrics': True,
        'label_images_path': "/kaggle/input/imgenh-2/dataset-2/Test/Reference"
    }

    if not os.path.exists(config['output_images_path']):
        os.mkdir(config['output_images_path'])

    start_time = time.time()
    ds_test, model = setup(config)
    test(config, ds_test, model)
    print("Total testing time:", time.time() - start_time)

    # Calculate metrics if specified
    if config['calculate_metrics']:
        gen_uqims, _ = measure_UIQMs(config['output_images_path'])
        avg_uiqm, _ = measure_UIQMs(config['output_images_path'])

        avg_ssim, avg_psnr, _, _ = SSIMs_PSNRs(config['label_images_path'], config['output_images_path'])
        print(f"Average SSIM: {avg_ssim:.4f}")
        print(f"Average PSNR: {avg_psnr:.4f} dB")
        print(f"Average UIQM: {avg_uiqm:.4f}")
        

      


Test Dataset Reading Completed.
Testing completed.
Total testing time: 121.78503847122192
Average SSIM: 0.8603
Average PSNR: 29.5985 dB
Average UIQM: 3.0313
