In [11]:
import numpy as np 
import pandas as pd
import cv2
import matplotlib.pyplot as plt
import os
import torch
import torch.nn as nn
import os
import glob
import shutil
from PIL import Image
import torch.nn.functional as F
import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from tqdm import tqdm


In [12]:
# SCPA Block
class SCPA(nn.Module):
    """Two-branch conv block with sigmoid pixel gating and a residual connection.

    Branch 1: 1x1 conv -> 3x3 conv.
    Branch 2: 1x1 conv, then sigmoid(1x1 conv) * (3x3 conv), then a 3x3 conv.
    The branches are summed, passed through a final 1x1 conv and added to the
    (optionally projected) input.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of every internal feature map and of the output.
            When it differs from ``in_channels`` a 1x1 projection is applied to
            the skip path so the residual addition is shape-compatible.
    """

    def __init__(self, in_channels, out_channels):
        super(SCPA, self).__init__()
        self.conv1_branch1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        # Everything after the entry 1x1 convs consumes out_channels features
        # (declaring these with in_channels only worked when in == out).
        self.conv2_branch1 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.sigmoid = nn.Sigmoid()
        self.conv1_branch2 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.conv2_branch2 = nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.conv3_branch2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.conv4_branch2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.final_conv = nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, padding=0)
        # Parameter-free identity when shapes already match, so existing
        # checkpoints (in == out) keep exactly the same state_dict keys.
        if in_channels != out_channels:
            self.skip = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        else:
            self.skip = nn.Identity()

    def forward(self, x):
        # branch 1
        branch1 = self.conv2_branch1(self.conv1_branch1(x))

        # branch 2: the 1x1 path, squashed to (0, 1), gates the 3x3 path per pixel
        branch2 = self.conv1_branch2(x)
        gate = self.sigmoid(self.conv2_branch2(branch2))
        branch2 = self.conv4_branch2(gate * self.conv3_branch2(branch2))

        # combine branches, final 1x1 conv, then the residual connection
        output = branch2 + branch1
        return self.final_conv(output) + self.skip(x)

class ConvBlock(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU.

    Args:
        in_channels: input channels.
        out_channels: output channels.
        kernel_size: conv kernel size (default 3).
        padding: conv padding (default 1, which keeps H x W for a 3x3 kernel).
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

# CoordConv Block
class CoordConv(nn.Module):
    """3x3 convolution over the input augmented with two coordinate channels.

    Two channels are appended to ``x``: the column index and the row index,
    each scaled linearly to [0, 1]. The (C + 2)-channel tensor is then convolved
    with a padded 3x3 conv, so H x W is preserved.

    Args:
        in_channels: channels of the input (the conv sees in_channels + 2).
        out_channels: channels produced by the conv.
    """

    def __init__(self, in_channels, out_channels):
        super(CoordConv, self).__init__()
        self.conv = nn.Conv2d(in_channels + 2, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        batch_size, _, height, width = x.size()
        # Build coordinates directly on x's device (any CUDA index, MPS, CPU).
        # max(..., 1) avoids 0/0 = NaN when an axis has length 1.
        cols = torch.arange(width, device=x.device, dtype=torch.float32) / max(width - 1, 1)
        rows = torch.arange(height, device=x.device, dtype=torch.float32) / max(height - 1, 1)
        xx = cols.view(1, 1, 1, width).expand(batch_size, 1, height, width).to(x.dtype)
        yy = rows.view(1, 1, height, 1).expand(batch_size, 1, height, width).to(x.dtype)
        x = torch.cat([x, xx, yy], dim=1)
        return self.conv(x)

if __name__ == "__main__":
    # Smoke test: CoordConv keeps H x W and maps 3 -> 5 channels.
    # Private names (deleted afterwards) keep this cell from leaking a global
    # `model` that later cells, which use `model` for the real network, could
    # pick up by accident.
    _demo_layer = CoordConv(in_channels=3, out_channels=5)
    with torch.no_grad():
        _demo_out = _demo_layer(torch.randn(1, 3, 400, 592))
    print(_demo_out.shape)
    del _demo_layer, _demo_out
    
class AttentionBlock(nn.Module):
    """Per-pixel gating: ``x * sigmoid(conv1x1(x))``.

    A 1x1 conv squeezes the input to a single channel. Its sigmoid is broadcast
    across all channels of ``x`` as a per-pixel weight in (0, 1).

    Args:
        in_channels: channels of the input (output has the same channels).
    """

    def __init__(self, in_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 1, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        gate = self.sigmoid(self.conv1(x))
        return x * gate

# Inverted residual block: 1x1 expand -> 3x3 depthwise -> linear 1x1 project
class inverted_residual_block(nn.Module):
    """Inverted residual block (MobileNetV2-style).

    Layers, in ``self.conv`` order:
      * 1x1 expansion to ``in_channels * expansion_factor`` + BN + ReLU6
        (omitted when ``expansion_factor == 1``)
      * 3x3 depthwise conv with ``stride`` + BN + ReLU6
      * 1x1 projection to ``out_channels`` + BN (no activation)
    The input is added back when ``stride == 1`` and channel counts match.

    Args:
        in_channels: input channels.
        out_channels: output channels.
        expansion_factor: width multiplier for the hidden layer.
        stride: stride of the depthwise conv.
    """

    def __init__(self, in_channels, out_channels, expansion_factor, stride):
        super().__init__()
        self.stride = stride
        hidden_dim = in_channels * expansion_factor
        self.use_residual = stride == 1 and in_channels == out_channels

        expand = []
        if expansion_factor != 1:
            expand = [
                nn.Conv2d(in_channels, hidden_dim, kernel_size=1, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
            ]
        depthwise = [
            nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=stride, padding=1,
                      groups=hidden_dim, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU6(inplace=True),
        ]
        project = [
            nn.Conv2d(hidden_dim, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
        ]
        self.conv = nn.Sequential(*expand, *depthwise, *project)

    def forward(self, x):
        out = self.conv(x)
        return x + out if self.use_residual else out
    
class ResidualBlock(nn.Module):
    """Two 3x3 conv + BN layers (ReLU in between) with an identity skip.

    Output = BN(conv(ReLU(BN(conv(x))))) + x; channels and H x W are preserved.

    Args:
        in_channels: channels of the input and output.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(in_channels)

    def forward(self, x):
        hidden = self.relu(self.bn1(self.conv1(x)))
        return self.bn2(self.conv2(hidden)) + x

class DenseBlock(nn.Module):
    """Densely connected stack of (3x3 conv -> BN -> ReLU) layers.

    Layer i sees the concatenation of the input and all previous layer outputs
    and emits ``growth_rate`` channels. The block returns the input concatenated
    with every layer output: ``in_channels + num_layers * growth_rate`` channels.

    Args:
        in_channels: channels of the input.
        growth_rate: channels produced by each layer.
        num_layers: number of conv/BN/ReLU triples.
    """

    def __init__(self, in_channels, growth_rate=32, num_layers=4):
        super().__init__()
        # Flat list of triples: [conv_0, bn_0, relu_0, conv_1, bn_1, relu_1, ...]
        self.layers = nn.ModuleList()
        for layer_idx in range(num_layers):
            width = in_channels + layer_idx * growth_rate
            self.layers.extend([
                nn.Conv2d(width, growth_rate, kernel_size=3, padding=1),
                nn.BatchNorm2d(growth_rate),
                nn.ReLU(inplace=True),
            ])

    def forward(self, x):
        features = [x]
        modules = list(self.layers)
        for conv, bn, act in zip(modules[0::3], modules[1::3], modules[2::3]):
            features.append(act(bn(conv(torch.cat(features, dim=1)))))
        return torch.cat(features, dim=1)

class SpatialAttention(nn.Module):
    """Spatial attention computed from channel-wise mean and max maps.

    The channel mean and channel max of ``x`` are stacked (2 channels) and
    convolved (``kernel_size``, no bias) to a single-channel map ``a``.

    NOTE(review): with the default ``gate_input=False`` the block returns
    ``a * sigmoid(a)``: a 1-channel tensor, NOT the input re-weighted by the
    attention map. This is kept as the default because the existing model
    (and presumably ``model_weights.pth``) is built around that output. Pass
    ``gate_input=True`` for the usual CBAM formulation ``x * sigmoid(a)``,
    which keeps the input's channel count.

    Args:
        kernel_size: 3 or 7 (asserted).
        gate_input: if True, return ``x * sigmoid(a)`` instead of ``a * sigmoid(a)``.
    """

    def __init__(self, kernel_size=3, gate_input=False):
        super(SpatialAttention, self).__init__()
        assert kernel_size in (3, 7)
        padding = kernel_size // 2
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()
        # Plain attribute (no parameters), so state_dict keys are unchanged.
        self.gate_input = gate_input

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        attn = self.conv(torch.cat([avg_out, max_out], dim=1))
        if self.gate_input:
            return x * self.sigmoid(attn)
        return attn * self.sigmoid(attn)

class ResidualDenseAttention(nn.Module):
    """ResidualBlock -> DenseBlock -> SpatialAttention, plus a skip from the input.

    NOTE(review): DenseBlock outputs ``in_channels + num_layers * growth_rate``
    channels, and SpatialAttention (default mode) reduces them to a single
    channel. That 1-channel map is broadcast-added to ``x``, so the output has
    ``in_channels`` channels and the dense features reach the output only
    through that one map. Confirm this is intended.

    Args:
        in_channels: channels of the input and output.
        growth_rate: DenseBlock growth rate.
        num_layers: DenseBlock depth.
        kernel_size: SpatialAttention kernel size (3 or 7).
    """

    def __init__(self, in_channels, growth_rate=32, num_layers=4, kernel_size=7):
        super().__init__()
        self.residual_block = ResidualBlock(in_channels)
        self.dense_block = DenseBlock(in_channels, growth_rate, num_layers)
        self.attention_block = SpatialAttention(kernel_size)

    def forward(self, x):
        dense_features = self.dense_block(self.residual_block(x))
        return self.attention_block(dense_features) + x

class downsampling_block(nn.Module):
    """3x3 conv -> BN -> ReLU -> 2x2 max-pool (halves H and W, floor).

    Args:
        in_channels: input channels.
        out_channels: output channels.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        activated = self.relu(self.bn(self.conv(x)))
        return self.pool(activated)
    
class upsampling_block(nn.Module):
    """2x2 stride-2 transposed conv (doubles H and W) -> 3x3 conv -> BN -> ReLU.

    Args:
        in_channels: input channels.
        out_channels: output channels.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.upconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        self.conv = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        upsampled = self.upconv(x)
        return self.relu(self.bn(self.conv(upsampled)))

class SCPA_Branch(nn.Module):
    """CoordConv -> five SCPA blocks -> 3x3 conv.

    Args:
        in_channel: channels of the input image (default 3).
        out_channel: channels of the branch output (default 3).
        hidden_channels: width of the CoordConv/SCPA trunk (default 5).

    The defaults build exactly the original 3 -> 5 -> 3 architecture, with the
    same attribute names (so saved state_dicts still load).
    """

    def __init__(self, in_channel=3, out_channel=3, hidden_channels=5):
        super(SCPA_Branch, self).__init__()
        # in_channel/out_channel used to be accepted but ignored (3 was hard-coded).
        self.coord_layer = CoordConv(in_channels=in_channel, out_channels=hidden_channels)
        self.SCPA_1 = SCPA(in_channels=hidden_channels, out_channels=hidden_channels)
        self.SCPA_2 = SCPA(in_channels=hidden_channels, out_channels=hidden_channels)
        self.SCPA_3 = SCPA(in_channels=hidden_channels, out_channels=hidden_channels)
        self.SCPA_4 = SCPA(in_channels=hidden_channels, out_channels=hidden_channels)
        self.SCPA_5 = SCPA(in_channels=hidden_channels, out_channels=hidden_channels)
        self.conv_layer = nn.Conv2d(in_channels=hidden_channels, out_channels=out_channel,
                                    kernel_size=3, padding=1)

    def forward(self, x):
        x = self.coord_layer(x)
        for block in (self.SCPA_1, self.SCPA_2, self.SCPA_3, self.SCPA_4, self.SCPA_5):
            x = block(x)
        return self.conv_layer(x)

# Denoising branch
class DenoiseBranch(nn.Module):
    """Denoising branch.

    convolutional block -> 4 inverted residual blocks -> attention block ->
    convolutional block. All inner layers keep ``in_channel`` channels (so the
    inverted residual blocks use their identity skip); only the final ConvBlock
    maps to ``out_channel``.

    Args:
        in_channel: channels of the input (default 3).
        out_channel: channels of the output (default 3).

    The defaults reproduce the original 3 -> 3 architecture with the same
    attribute names.
    """

    def __init__(self, in_channel=3, out_channel=3):
        super(DenoiseBranch, self).__init__()
        # in_channel/out_channel used to be accepted but ignored (3 was hard-coded).
        self.conv_1 = ConvBlock(in_channel, in_channel)
        self.inv_1 = inverted_residual_block(in_channel, in_channel, 6, 1)
        self.inv_2 = inverted_residual_block(in_channel, in_channel, 6, 1)
        self.inv_3 = inverted_residual_block(in_channel, in_channel, 6, 1)
        self.inv_4 = inverted_residual_block(in_channel, in_channel, 6, 1)
        self.attention = AttentionBlock(in_channels=in_channel)
        self.conv_2 = ConvBlock(in_channel, out_channel)

    def forward(self, x):
        x = self.conv_1(x)
        for block in (self.inv_1, self.inv_2, self.inv_3, self.inv_4):
            x = block(x)
        x = self.attention(x)
        return self.conv_2(x)
    
class Encoder(nn.Module):
    """Two ResidualDenseAttention blocks, then a downsampling_block.

    Output has ``out_channel`` channels at half the input's H and W.

    Args:
        in_channel: channels of the input (kept through both RDA blocks).
        out_channel: channels after down-sampling.
    """

    def __init__(self, in_channel=3, out_channel=3):
        super().__init__()
        rda_kwargs = dict(in_channels=in_channel, growth_rate=32, num_layers=4, kernel_size=7)
        self.RDA_Block1 = ResidualDenseAttention(**rda_kwargs)
        self.RDA_Block2 = ResidualDenseAttention(**rda_kwargs)
        self.Downsampler = downsampling_block(in_channel, out_channel)

    def forward(self, x):
        return self.Downsampler(self.RDA_Block2(self.RDA_Block1(x)))
    
class Decoder(nn.Module):
    """Two ResidualDenseAttention blocks, then an upsampling_block.

    Output has ``out_channel`` channels at twice the input's H and W.

    Args:
        in_channel: channels of the input (kept through both RDA blocks).
        out_channel: channels after up-sampling.
    """

    def __init__(self, in_channel=3, out_channel=3):
        super().__init__()
        rda_kwargs = dict(in_channels=in_channel, growth_rate=32, num_layers=4, kernel_size=7)
        self.RDA_Block1 = ResidualDenseAttention(**rda_kwargs)
        self.RDA_Block2 = ResidualDenseAttention(**rda_kwargs)
        self.Upsampler = upsampling_block(in_channel, out_channel)

    def forward(self, x):
        return self.Upsampler(self.RDA_Block2(self.RDA_Block1(x)))
    
class FinalAutoencoder(nn.Module):
    """Four-level encoder/decoder with concatenation skip connections.

    Each Encoder halves H and W (max-pool) and each Decoder doubles them
    (transposed conv). Encoder outputs are concatenated onto the matching
    decoder inputs along the channel axis.

    The four 2x down-samplings only line up with the skips (and with the input
    size) when H and W are multiples of 16. ``forward`` therefore pads other
    sizes on the bottom/right (replicate) and crops the result back. Inputs
    that are already multiples of 16 are processed exactly as before.

    Args:
        in_channel: channels of the input (default 3).
        out_channel: channels of the output (default 3).
    """

    # Total down-sampling factor of the four encoder stages.
    _SIZE_MULTIPLE = 16

    def __init__(self, in_channel=3, out_channel=3):
        super(FinalAutoencoder, self).__init__()
        self.encoder1 = Encoder(in_channel, 32)
        self.encoder2 = Encoder(32, 64)
        self.encoder3 = Encoder(64, 128)
        self.encoder4 = Encoder(128, 256)
        self.decoder1 = Decoder(256, 128)
        self.decoder2 = Decoder(128 + 128, 64)
        self.decoder3 = Decoder(64 + 64, 32)
        self.decoder4 = Decoder(32 + 32, out_channel)

    def forward(self, x):
        height, width = x.shape[-2:]
        pad_h = (-height) % self._SIZE_MULTIPLE
        pad_w = (-width) % self._SIZE_MULTIPLE
        if pad_h or pad_w:
            # F.pad order is (left, right, top, bottom) for the last two dims.
            x = F.pad(x, (0, pad_w, 0, pad_h), mode='replicate')

        x1 = self.encoder1(x)
        x2 = self.encoder2(x1)
        x3 = self.encoder3(x2)
        x4 = self.encoder4(x3)
        x5 = self.decoder1(x4)
        x6 = self.decoder2(torch.cat([x5, x3], 1))
        x7 = self.decoder3(torch.cat([x6, x2], 1))
        x8 = self.decoder4(torch.cat([x7, x1], 1))
        # Undo the padding so the output matches the caller's spatial size.
        return x8[..., :height, :width]
    
class LowLightModel(nn.Module):
    """Full low-light enhancement network.

    output = ConvBlock(FinalAutoencoder(x + SCPA_Branch(x))) + DenoiseBranch(x)

    ``in_channel`` / ``out_channel`` are accepted but not used: every sub-network
    is built with its 3-channel defaults.
    """

    def __init__(self, in_channel=3, out_channel=3):
        super().__init__()
        self.SCPA_branch = SCPA_Branch()
        self.Denoiser = DenoiseBranch()
        self.AutoEncoder = FinalAutoencoder()
        self.Conv = ConvBlock(3, 3)

    def forward(self, x):
        denoised = self.Denoiser(x)
        # The SCPA branch output is used as a residual correction of the input.
        scpa_out = self.SCPA_branch(x)
        refined = self.Conv(self.AutoEncoder(scpa_out + x))
        return refined + denoised

torch.Size([1, 5, 400, 592])


In [13]:
class LowLightDataset(Dataset):
    """Paired low-light / reference image dataset.

    Files are paired by position after sorting each directory's file names, so
    the i-th low image is matched with the i-th high image. Both directories
    must hold the same number of files, named so that sorting aligns the pairs
    (e.g. identical file names). Sub-directories (such as
    ``.ipynb_checkpoints``) are ignored.

    Args:
        low_img_dir: directory of low-light input images.
        high_img_dir: directory of the corresponding reference images.
        transform: optional callable applied to each PIL image. It is called
            separately on the two images, so a random transform would NOT be
            applied consistently to a pair.

    Raises:
        ValueError: if the two directories contain a different number of files.
    """

    def __init__(self, low_img_dir, high_img_dir, transform=None):
        self.low_img_dir = low_img_dir
        self.high_img_dir = high_img_dir
        self.low_images = self._list_files(low_img_dir)
        self.high_images = self._list_files(high_img_dir)
        # A count mismatch would otherwise silently shift every pair or raise
        # IndexError deep inside a DataLoader worker.
        if len(self.low_images) != len(self.high_images):
            raise ValueError(
                f"Unpaired dataset: {len(self.low_images)} files in {low_img_dir!r} "
                f"but {len(self.high_images)} files in {high_img_dir!r}"
            )
        self.transform = transform

    @staticmethod
    def _list_files(directory):
        """Sorted names of the regular files directly inside ``directory``."""
        return sorted(
            name for name in os.listdir(directory)
            if os.path.isfile(os.path.join(directory, name))
        )

    def __len__(self):
        return len(self.low_images)

    def __getitem__(self, idx):
        low_img_path = os.path.join(self.low_img_dir, self.low_images[idx])
        high_img_path = os.path.join(self.high_img_dir, self.high_images[idx])
        # Context managers close the file handles as soon as pixels are decoded.
        with Image.open(low_img_path) as img:
            low_image = img.convert("RGB")
        with Image.open(high_img_path) as img:
            high_image = img.convert("RGB")

        if self.transform:
            low_image = self.transform(low_image)
            high_image = self.transform(high_image)

        return low_image, high_image

# Shared preprocessing: resize to 400x592 (H x W) and convert to a float tensor
# (ToTensor scales 8-bit pixel values into [0, 1]).
transform = transforms.Compose([transforms.Resize((400, 592)), transforms.ToTensor()])

In [14]:
def psnr(img1, img2, max_val=1.0):
    """Peak signal-to-noise ratio in dB between two tensors.

    The MSE is averaged over every element, so for a batch this is the PSNR of
    the pooled error, not the mean of per-image PSNRs.

    Args:
        img1: tensor, presumably scaled to [0, max_val].
        img2: tensor of the same shape as ``img1``.
        max_val: peak signal value (1.0 for ToTensor outputs).

    Returns:
        0-dim tensor on the inputs' device. Identical inputs give 100 dB, now
        also as a tensor, so callers can always use ``.item()`` / ``float()``.
    """
    mse = F.mse_loss(img1, img2)
    if mse == 0:
        # log10(max_val / 0) would be +inf; keep the conventional 100 dB cap.
        return torch.full_like(mse, 100.0)
    return 20 * torch.log10(max_val / torch.sqrt(mse))

In [15]:
import torch.optim as optim

# Kaggle locations of the paired low-light / reference images.
train_low_dir = '/kaggle/input/dataset00/augmented_Train/augmented/low'
train_high_dir = '/kaggle/input/dataset00/augmented_Train/augmented/high'
val_low_dir = '/kaggle/input/dataset00/augmented_Train/val/low'
val_high_dir = '/kaggle/input/dataset00/augmented_Train/val/high'

train_dataset = LowLightDataset(train_low_dir, train_high_dir, transform=transform)
val_dataset = LowLightDataset(val_low_dir, val_high_dir, transform=transform)

# Batch size 2 for both splits; only the training split is shuffled.
train_loader, val_loader = (
    DataLoader(train_dataset, batch_size=2, shuffle=True),
    DataLoader(val_dataset, batch_size=2, shuffle=False),
)

In [16]:
def gaussian(window_size, sigma):
    """Return a normalised 1-D Gaussian kernel.

    Args:
        window_size: number of taps; the peak is at index ``window_size // 2``.
        sigma: standard deviation, measured in taps.

    Returns:
        float32 tensor of shape (window_size,), normalised to sum to 1.
    """
    center = window_size // 2
    denom = float(2 * sigma ** 2)
    taps = []
    for pos in range(window_size):
        exponent = torch.tensor(-(pos - center) ** 2 / denom)
        taps.append(torch.exp(exponent))
    kernel = torch.stack(taps).to(torch.float32)
    return kernel / kernel.sum()

def create_window(window_size, channel):
    """Build a depthwise 2-D Gaussian filter (sigma = 1.5).

    Returns:
        float32 tensor of shape (channel, 1, window_size, window_size): the
        outer product of the 1-D kernel with itself, repeated once per channel
        for use with ``F.conv2d(..., groups=channel)``.
    """
    column = gaussian(window_size, 1.5).unsqueeze(1)   # (k, 1)
    kernel_2d = column.mm(column.t()).float()          # (k, k)
    return kernel_2d[None, None].expand(channel, 1, window_size, window_size).contiguous()

def ssim(img1, img2, window_size=11, size_average=True):
    """Structural similarity between two image batches using a Gaussian window.

    Args:
        img1: 4-D tensor (N, C, H, W). Values are presumably in [0, 1], since
            C1/C2 below assume a dynamic range of 1.
        img2: tensor of the same shape as ``img1``.
        window_size: side length of the Gaussian window (sigma 1.5).
        size_average: if True, return the mean SSIM over all elements;
            otherwise return one mean value per batch item.
    """
    _, channel, _, _ = img1.size()
    window = create_window(window_size, channel).to(img1.device)
    pad = window_size // 2

    def blur(t):
        # Depthwise Gaussian filtering: one window per channel.
        return F.conv2d(t, window, padding=pad, groups=channel)

    mu1, mu2 = blur(img1), blur(img2)
    mu1_sq, mu2_sq, mu1_mu2 = mu1.pow(2), mu2.pow(2), mu1 * mu2

    # Local (co)variances: E[ab] - E[a]E[b]
    sigma1_sq = blur(img1 * img1) - mu1_sq
    sigma2_sq = blur(img2 * img2) - mu2_sq
    sigma12 = blur(img1 * img2) - mu1_mu2

    C1 = 0.01 ** 2
    C2 = 0.03 ** 2
    numerator = (2 * mu1_mu2 + C1) * (2 * sigma12 + C2)
    denominator = (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
    ssim_map = numerator / denominator

    if size_average:
        return ssim_map.mean()
    return ssim_map.mean(1).mean(1).mean(1)

class SSIMLoss(nn.Module):
    """Loss module returning ``1 - ssim(img1, img2)``.

    Args:
        window_size: Gaussian window size passed to ``ssim``.
        size_average: passed to ``ssim`` (scalar vs per-sample result).
    """

    def __init__(self, window_size=11, size_average=True):
        super().__init__()
        self.window_size, self.size_average = window_size, size_average

    def forward(self, img1, img2):
        similarity = ssim(img1, img2, self.window_size, self.size_average)
        return 1 - similarity

# Combined Loss Function
class CombinedLoss(nn.Module):
    """Weighted sum of SSIM loss, L1 loss and an image-gradient L1 loss.

    total = ssim_weight * (1 - SSIM) + l1_weight * L1 + grad_weight * grad_L1

    Args:
        ssim_weight: weight of the SSIM term (default 0.1).
        l1_weight: weight of the pixel L1 term (default 1.0).
        grad_weight: weight of the gradient term (default 1.0).

    The defaults reproduce the original fixed weighting exactly.
    """

    def __init__(self, ssim_weight=0.1, l1_weight=1.0, grad_weight=1.0):
        super(CombinedLoss, self).__init__()
        self.ssim_loss = SSIMLoss()
        self.l1_loss = nn.L1Loss()
        self.ssim_weight = ssim_weight
        self.l1_weight = l1_weight
        self.grad_weight = grad_weight

    def forward(self, output, target):
        l1_loss = self.l1_loss(output, target)
        ssim_loss = self.ssim_loss(output, target)
        grad_loss = self.gradient_loss(output, target)
        return (self.ssim_weight * ssim_loss
                + self.l1_weight * l1_loss
                + self.grad_weight * grad_loss)

    def gradient_loss(self, output, target):
        """L1 distance between absolute horizontal/vertical finite differences.

        Args:
            output, target: 4-D tensors (N, C, H, W) of the same shape.

        Returns:
            Scalar tensor: mean |dx_out - dx_tgt| + mean |dy_out - dy_tgt|.
        """
        output_grad_x = torch.abs(output[:, :, :, :-1] - output[:, :, :, 1:])
        output_grad_y = torch.abs(output[:, :, :-1, :] - output[:, :, 1:, :])
        target_grad_x = torch.abs(target[:, :, :, :-1] - target[:, :, :, 1:])
        target_grad_y = torch.abs(target[:, :, :-1, :] - target[:, :, 1:, :])

        grad_loss_x = F.l1_loss(output_grad_x, target_grad_x)
        grad_loss_y = F.l1_loss(output_grad_y, target_grad_y)
        return grad_loss_x + grad_loss_y


In [17]:
model = LowLightModel()
optimizer = optim.Adam(model.parameters(), lr=0.0001)
criterion = CombinedLoss()
# map_location='cpu': a checkpoint saved from GPU tensors would otherwise fail
# to load on a CPU-only kernel. Later cells move the model to the active device.
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
model.load_state_dict(torch.load('model_weights.pth', map_location='cpu'))

<All keys matched successfully>

In [21]:
# Fine-tune the loaded model for a few more epochs.
# NOTE(review): start_epoch = 0 restarts checkpoint numbering, so
# model_epoch_*.pth files from earlier runs are overwritten; val_loader is not
# evaluated here.
new_epochs = 5
start_epoch = 0
total_epochs = start_epoch + new_epochs
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Move the model first, then build the optimizer over the on-device parameters.
model.to(device)
optimizer = optim.Adam(model.parameters(), lr=0.0001)

for epoch in range(start_epoch, total_epochs):
    model.train()
    train_loss = 0.0
    train_psnr = 0.0

    for low_img, high_img in tqdm(train_loader, desc=f"Epoch [{epoch+1}/{total_epochs}] Training", leave=False):
        low_img, high_img = low_img.to(device), high_img.to(device)

        output = model(low_img)
        loss = criterion(output, high_img)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        # PSNR is only a metric: compute it without building an autograd graph.
        # float() accepts a 0-dim tensor as well as a plain number.
        with torch.no_grad():
            train_psnr += float(psnr(output.detach(), high_img))

    model.eval()
    # Per-batch averages; max(..., 1) guards against an empty loader.
    num_batches = max(len(train_loader), 1)
    train_loss /= num_batches
    train_psnr /= num_batches

    print(f"Epoch [{epoch+1}/{total_epochs}], Train Loss: {train_loss:.4f}, Train PSNR: {train_psnr:.2f} dB")
    if (epoch + 1) % 2 == 0:
        model_path = f'model_epoch_{epoch+1}.pth'
        torch.save(model.state_dict(), model_path)
        print(f"Model saved at epoch {epoch+1}")


                                                                       

Epoch [1/5], Train Loss: 0.1361, Train PSNR: 24.16 dB


                                                                       

Epoch [2/5], Train Loss: 0.1316, Train PSNR: 24.55 dB
Model saved at epoch 2


                                                                       

Epoch [3/5], Train Loss: 0.1321, Train PSNR: 24.47 dB


                                                                       

Epoch [4/5], Train Loss: 0.1295, Train PSNR: 24.72 dB
Model saved at epoch 4


                                                                       

Epoch [5/5], Train Loss: 0.1311, Train PSNR: 24.64 dB




In [19]:
low_dir = '/kaggle/input/dataset00/augmented_Train/val/low'
high_dir = '/kaggle/input/dataset00/augmented_Train/val/high'

# Image files in the low-light directory (sub-directories skipped, sorted for
# a reproducible order). Ground truth is matched by identical file name.
low_images = sorted(
    name for name in os.listdir(low_dir)
    if os.path.isfile(os.path.join(low_dir, name))
)

# Set device to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
# BatchNorm must use running statistics at inference; do not rely on an
# earlier cell having left the model in eval mode.
model.eval()

psnr_values = []

for image_name in low_images:
    input_image_path = os.path.join(low_dir, image_name)
    high_image_path = os.path.join(high_dir, image_name)

    # Open and preprocess the input image
    input_image = Image.open(input_image_path).convert('RGB')
    input_tensor = transform(input_image).unsqueeze(0).to(device)

    with torch.no_grad():
        enhanced_tensor = model(input_tensor).cpu()

    # cv2.imread returns None (instead of raising) for missing/unreadable files
    original_image = cv2.imread(high_image_path)
    if original_image is None:
        raise FileNotFoundError(f'Cannot read ground-truth image: {high_image_path}')
    original_image_rgb = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)

    # Convert original image to tensor
    original_tensor = transform(Image.fromarray(original_image_rgb)).unsqueeze(0)

    # float() works whether psnr returns a tensor or a plain number
    psnr_value = float(psnr(original_tensor, enhanced_tensor))
    psnr_values.append(psnr_value)

if not psnr_values:
    raise ValueError(f'No images found in {low_dir}')

# Calculate the average PSNR
average_psnr = sum(psnr_values) / len(psnr_values)
print(f'Average PSNR: {average_psnr:.2f} dB')

Average PSNR: 24.58 dB


In [None]:
import random

# Same preprocessing as training: resize to 400x592 and convert to [0, 1] tensors.
transform = transforms.Compose([
    transforms.Resize((400, 592)),
    transforms.ToTensor()
])

low_dir = './test/low'
# NOTE(review): the "Ground Truth" panel and the PSNR reference are read from a
# directory named "predicted" -- confirm it holds reference images, not model outputs.
high_dir = './test/predicted'
num_samples = 5

# Sorted file listing so the candidate pool does not depend on filesystem order;
# min() keeps random.sample from raising when fewer images are available.
low_images = sorted(
    name for name in os.listdir(low_dir)
    if os.path.isfile(os.path.join(low_dir, name))
)
random_images = random.sample(low_images, min(num_samples, len(low_images)))

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
model.eval()

psnr_values = []

for image_name in random_images:
    input_image_path = os.path.join(low_dir, image_name)
    high_image_path = os.path.join(high_dir, image_name)
    input_image = Image.open(input_image_path).convert('RGB')
    input_tensor = transform(input_image).unsqueeze(0).to(device)
    with torch.no_grad():
        enhanced_tensor = model(input_tensor).cpu()

    # ToPILImage scales float tensors by 255 and casts to uint8 without clipping,
    # so out-of-range outputs would render as garbage; clamp for display only
    # (PSNR below still uses the raw output, like the other evaluation cells).
    enhanced_image = transforms.ToPILImage()(enhanced_tensor.squeeze(0).clamp(0, 1))

    original_image = cv2.imread(high_image_path)
    if original_image is None:
        raise FileNotFoundError(f'Cannot read ground-truth image: {high_image_path}')
    original_image_rgb = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)

    fig, axs = plt.subplots(1, 3, figsize=(10, 5))
    panels = (
        (input_image, 'Low Light'),
        (enhanced_image, 'Enhanced'),
        (original_image_rgb, 'Ground Truth'),
    )
    for ax, (image, title) in zip(axs, panels):
        ax.imshow(image)
        ax.set_title(title)
        ax.axis('off')
    plt.show()
    # Release the figure; otherwise every iteration keeps one alive in memory.
    plt.close(fig)

    original_tensor = transform(Image.fromarray(original_image_rgb)).unsqueeze(0)
    psnr_value = float(psnr(original_tensor, enhanced_tensor))
    psnr_values.append(psnr_value)
    print(f'PSNR: {psnr_value:.2f} dB')

if not psnr_values:
    raise ValueError(f'No images found in {low_dir}')
average_psnr = sum(psnr_values) / len(psnr_values)
print(f'Average PSNR: {average_psnr:.2f} dB')


In [None]:
import os
import torch
from PIL import Image
import torchvision.transforms as transforms
import cv2
# Imported under its own name: rebinding the global `psnr` here would silently
# swap the metric used by any earlier cell that is re-run afterwards.
from torchmetrics.functional import peak_signal_noise_ratio as tm_psnr

model.eval()
transform = transforms.Compose([
    transforms.Resize((400, 592)),
    transforms.ToTensor()
])
low_dir = './test/low'
high_dir = './test/predicted'

low_images = sorted(
    name for name in os.listdir(low_dir)
    if os.path.isfile(os.path.join(low_dir, name))
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

psnr_values = []

for image_name in low_images:
    input_image_path = os.path.join(low_dir, image_name)
    high_image_path = os.path.join(high_dir, image_name)

    input_image = Image.open(input_image_path).convert('RGB')
    input_tensor = transform(input_image).unsqueeze(0).to(device)

    with torch.no_grad():
        enhanced_tensor = model(input_tensor).cpu()

    original_image = cv2.imread(high_image_path)
    if original_image is None:
        raise FileNotFoundError(f'Cannot read ground-truth image: {high_image_path}')
    original_image_rgb = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)

    original_tensor = transform(Image.fromarray(original_image_rgb)).unsqueeze(0)

    # torchmetrics takes (preds, target). data_range=1.0 matches the [0, 1]
    # ToTensor scaling used everywhere else; left unset, torchmetrics infers
    # the range from the data, which makes scores incomparable across images.
    psnr_value = tm_psnr(enhanced_tensor, original_tensor, data_range=1.0)
    psnr_values.append(psnr_value.item())

if not psnr_values:
    raise ValueError(f'No images found in {low_dir}')
average_psnr = sum(psnr_values) / len(psnr_values)
print(f'Average PSNR: {average_psnr:.2f} dB')
