In [None]:
from google.colab import drive
# Mount Google Drive so the dataset and checkpoint folders below are reachable.
MOUNT_POINT = '/content/drive'
drive.mount(MOUNT_POINT)

Mounted at /content/drive


In [None]:
import os
import time
import random

from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:



class DecomNet(nn.Module):
    """Decomposition network: splits an RGB image into reflectance and illumination.

    The input is augmented with its per-pixel channel maximum (4 channels in
    total), passed through a wide shallow conv, five ReLU-activated convs and a
    reconstruction conv whose 4 output channels are split into R (3) and L (1).
    Both outputs go through a sigmoid.

    Args:
        channel: width of the hidden feature maps.
        kernel_size: kernel size of the hidden convs; the first conv uses
            ``kernel_size * 3``. Padding is derived from the kernel size, so the
            spatial size is preserved for any odd kernel. (The original
            hard-coded padding=4 / padding=1, which is only correct for 3.)
    """

    def __init__(self, channel=64, kernel_size=3):
        super(DecomNet, self).__init__()
        pad = kernel_size // 2
        wide_kernel = kernel_size * 3
        # Shallow feature extraction with a wide receptive field
        self.net1_conv0 = nn.Conv2d(4, channel, wide_kernel,
                                    padding=wide_kernel // 2, padding_mode='replicate')
        # Activated layers: [Conv, ReLU] x 5. Building them in this order keeps the
        # Sequential indices (convs at 0, 2, 4, 6, 8) and therefore the state_dict
        # keys identical to the original explicit definition, so old checkpoints load.
        layers = []
        for _ in range(5):
            layers.append(nn.Conv2d(channel, channel, kernel_size,
                                    padding=pad, padding_mode='replicate'))
            layers.append(nn.ReLU())
        self.net1_convs = nn.Sequential(*layers)
        # Final recon layer: 3 reflectance channels + 1 illumination channel
        self.net1_recon = nn.Conv2d(channel, 4, kernel_size,
                                    padding=pad, padding_mode='replicate')

    def forward(self, input_im):
        """Decompose ``input_im`` of shape (N, 3, H, W).

        Returns:
            R: reflectance, (N, 3, H, W), sigmoid-activated.
            L: illumination, (N, 1, H, W), sigmoid-activated.
        """
        # Prepend the per-pixel maximum over colour channels as a 4th input channel
        input_max = torch.max(input_im, dim=1, keepdim=True)[0]
        input_img = torch.cat((input_max, input_im), dim=1)
        feats0 = self.net1_conv0(input_img)
        featss = self.net1_convs(feats0)
        outs = self.net1_recon(featss)
        R = torch.sigmoid(outs[:, 0:3, :, :])
        L = torch.sigmoid(outs[:, 3:4, :, :])
        return R, L

class RelightNet(nn.Module):
    """Illumination-adjustment network.

    Takes the illumination map and the reflectance from DecomNet (concatenated
    to 4 channels), encodes them with three stride-2 convs, decodes with
    nearest-neighbour upsampling plus skip connections, fuses the three decoder
    scales, and outputs a 1-channel map. No activation is applied to the output.

    Args:
        channel: width of the hidden feature maps.
        kernel_size: kernel size of the encoder/decoder convs. Padding is
            derived from it (``kernel_size // 2``), so any odd size keeps the
            skip-connection shapes aligned. (The original hard-coded padding=1,
            which breaks the concatenations for kernel sizes other than 3.)
    """

    def __init__(self, channel=64, kernel_size=3):
        super(RelightNet, self).__init__()
        pad = kernel_size // 2

        self.relu         = nn.ReLU()
        self.net2_conv0_1 = nn.Conv2d(4, channel, kernel_size,
                                      padding=pad, padding_mode='replicate')

        # Encoder: each stride-2 conv roughly halves H and W
        self.net2_conv1_1 = nn.Conv2d(channel, channel, kernel_size, stride=2,
                                      padding=pad, padding_mode='replicate')
        self.net2_conv1_2 = nn.Conv2d(channel, channel, kernel_size, stride=2,
                                      padding=pad, padding_mode='replicate')
        self.net2_conv1_3 = nn.Conv2d(channel, channel, kernel_size, stride=2,
                                      padding=pad, padding_mode='replicate')

        # Decoder: input is (upsampled features ++ skip features) = 2 * channel
        self.net2_deconv1_1 = nn.Conv2d(channel * 2, channel, kernel_size,
                                        padding=pad, padding_mode='replicate')
        self.net2_deconv1_2 = nn.Conv2d(channel * 2, channel, kernel_size,
                                        padding=pad, padding_mode='replicate')
        self.net2_deconv1_3 = nn.Conv2d(channel * 2, channel, kernel_size,
                                        padding=pad, padding_mode='replicate')

        # The 1x1 fusion conv is padded by 1 (growing H and W by 2) and the 3x3
        # output conv uses padding=0 (shrinking them by 2). Together they keep the
        # input size. Kept as-is so outputs and checkpoints are unchanged.
        self.net2_fusion = nn.Conv2d(channel * 3, channel, kernel_size=1,
                                     padding=1, padding_mode='replicate')
        self.net2_output = nn.Conv2d(channel, 1, kernel_size=3, padding=0)

    @staticmethod
    def _resize_to(features, reference):
        """Nearest-neighbour resize of ``features`` to the H x W of ``reference``."""
        return F.interpolate(features, size=tuple(reference.shape[-2:]))

    def forward(self, input_L, input_R):
        """Predict the adjusted illumination.

        Args:
            input_L: illumination map, (N, 1, H, W).
            input_R: reflectance, (N, 3, H, W).
        Returns:
            Tensor of shape (N, 1, H, W), not passed through any activation.
        """
        input_img = torch.cat((input_R, input_L), dim=1)
        out0      = self.net2_conv0_1(input_img)
        out1      = self.relu(self.net2_conv1_1(out0))
        out2      = self.relu(self.net2_conv1_2(out1))
        out3      = self.relu(self.net2_conv1_3(out2))

        deconv1 = self.relu(self.net2_deconv1_1(
            torch.cat((self._resize_to(out3, out2), out2), dim=1)))
        deconv2 = self.relu(self.net2_deconv1_2(
            torch.cat((self._resize_to(deconv1, out1), out1), dim=1)))
        deconv3 = self.relu(self.net2_deconv1_3(
            torch.cat((self._resize_to(deconv2, out0), out0), dim=1)))

        # Multi-scale fusion at full input resolution
        deconv1_rs = self._resize_to(deconv1, input_R)
        deconv2_rs = self._resize_to(deconv2, input_R)
        feats_all  = torch.cat((deconv1_rs, deconv2_rs, deconv3), dim=1)
        feats_fus  = self.net2_fusion(feats_all)
        output     = self.net2_output(feats_fus)
        return output


class RetinexNet(nn.Module):
    """RetinexNet: DecomNet (image -> reflectance R, illumination I) + RelightNet.

    ``forward`` runs both sub-networks on a (low, high) pair and stores every
    loss term plus detached visualisation tensors on ``self``. ``train``,
    ``evaluate`` and ``predict`` drive it from image files.

    NOTE(review): ``train`` overrides ``nn.Module.train(mode)`` with a different
    signature. As a result, ``model.train()`` and ``model.eval()`` (which calls
    ``self.train(False)``) raise TypeError on this class. The name is kept
    because the notebook calls ``model.train(...)``. The sub-networks contain no
    dropout or batch-norm, so train/eval mode makes no numerical difference.
    """

    def __init__(self):
        super(RetinexNet, self).__init__()

        self.DecomNet  = DecomNet()
        self.RelightNet= RelightNet()

    # ------------------------------------------------------------------ helpers
    def _to_input_tensor(self, array):
        """Convert a numpy array or tensor to float32 on the device of the weights.

        The original used the module-level ``device``; following the parameters
        avoids a mismatch if the model was moved elsewhere.
        """
        target_device = next(self.parameters()).device
        return torch.as_tensor(array, dtype=torch.float32).to(target_device)

    @staticmethod
    def _load_rgb(path):
        """Read an image file as a float32 H x W x 3 array scaled to [0, 1].

        ``convert('RGB')`` guards against RGBA/greyscale PNGs, which would
        otherwise feed the wrong number of channels to the network. The ``with``
        block closes the file handle.
        """
        with Image.open(path) as img:
            return np.asarray(img.convert('RGB'), dtype=np.float32) / 255.0

    @staticmethod
    def _to_nchw(image_hwc):
        """H x W x 3 array -> 1 x 3 x H x W array (a batch of one)."""
        return np.expand_dims(np.transpose(image_hwc, (2, 0, 1)), axis=0)

    @staticmethod
    def _side_by_side(input_batch, panels):
        """Tile the input image and output panels horizontally into a uint8 PIL image.

        Args:
            input_batch: the (1, 3, H, W) array that was fed to ``forward``.
            panels: cached (1, 3, H, W) CPU tensors (``output_*`` attributes).
        Only batch item 0 is used. Indexing ``[0]`` (instead of ``np.squeeze``)
        keeps images with H or W == 1 intact.
        """
        arrays = [np.asarray(input_batch[0])] + [panel[0].numpy() for panel in panels]
        cat_image = np.transpose(np.concatenate(arrays, axis=2), (1, 2, 0))
        return Image.fromarray(np.clip(cat_image * 255.0, 0, 255.0).astype('uint8'))

    @staticmethod
    def _shuffle_pairs(low_names, high_names):
        """Shuffle two parallel sequences with the same permutation; return new lists.

        The original called ``random.shuffle(list(tmp))``, which shuffles a
        throw-away copy, so the training order never changed between passes.
        """
        pairs = list(zip(low_names, high_names))
        if not pairs:
            return [], []
        random.shuffle(pairs)
        shuffled_low, shuffled_high = zip(*pairs)
        return list(shuffled_low), list(shuffled_high)

    def _random_patch_pair(self, low_path, high_path, patch_size):
        """Load a low/high pair and apply the same random crop and augmentation to both.

        The random draws happen in the same order as the original inline code.

        Returns:
            Two float32 arrays of shape (3, patch_size, patch_size).
        Raises:
            ValueError: if the images differ in size or are smaller than the patch.
        """
        low_img = self._load_rgb(low_path)
        high_img = self._load_rgb(high_path)
        if low_img.shape != high_img.shape:
            raise ValueError("Size mismatch between %s %s and %s %s"
                             % (low_path, low_img.shape, high_path, high_img.shape))
        h, w, _ = low_img.shape
        if h < patch_size or w < patch_size:
            raise ValueError("%s is %dx%d, smaller than patch_size=%d"
                             % (low_path, h, w, patch_size))
        # Take random crops
        x = random.randint(0, h - patch_size)
        y = random.randint(0, w - patch_size)
        low_img = low_img[x: x + patch_size, y: y + patch_size, :]
        high_img = high_img[x: x + patch_size, y: y + patch_size, :]
        # Data augmentation (identical transform for both images of the pair)
        if random.random() < 0.5:
            low_img, high_img = np.flipud(low_img), np.flipud(high_img)
        if random.random() < 0.5:
            low_img, high_img = np.fliplr(low_img), np.fliplr(high_img)
        rot_type = random.randint(1, 4)
        if random.random() < 0.5:
            low_img, high_img = np.rot90(low_img, rot_type), np.rot90(high_img, rot_type)
        return np.transpose(low_img, (2, 0, 1)), np.transpose(high_img, (2, 0, 1))

    # ------------------------------------------------------------------ model
    def forward(self, input_low, input_high):
        """Run both sub-networks and compute every loss term.

        Args:
            input_low, input_high: (N, 3, H, W) float arrays or tensors (callers
                pass images scaled to [0, 1]).
        Nothing is returned. Instead, this sets on ``self``:
            ``loss_Decom`` / ``loss_Relight`` and their individual terms, and
            ``output_R_low``, ``output_I_low``, ``output_I_delta``, ``output_S``
            (detached CPU tensors used by ``evaluate`` / ``predict``).
        """
        # Forward DecompNet
        input_low  = self._to_input_tensor(input_low)
        input_high = self._to_input_tensor(input_high)
        R_low, I_low   = self.DecomNet(input_low)
        R_high, I_high = self.DecomNet(input_high)

        # Forward RelightNet
        I_delta = self.RelightNet(I_low, R_low)

        # Broadcast 1-channel illumination maps to 3 channels for R * I products
        I_low_3  = torch.cat((I_low, I_low, I_low), dim=1)
        I_high_3 = torch.cat((I_high, I_high, I_high), dim=1)
        I_delta_3= torch.cat((I_delta, I_delta, I_delta), dim=1)

        # Compute losses
        self.recon_loss_low  = F.l1_loss(R_low * I_low_3,  input_low)
        self.recon_loss_high = F.l1_loss(R_high * I_high_3, input_high)
        self.recon_loss_mutal_low  = F.l1_loss(R_high * I_low_3, input_low)
        self.recon_loss_mutal_high = F.l1_loss(R_low * I_high_3, input_high)
        self.equal_R_loss = F.l1_loss(R_low,  R_high.detach())
        self.relight_loss = F.l1_loss(R_low * I_delta_3, input_high)

        self.Ismooth_loss_low   = self.smooth(I_low, R_low)
        self.Ismooth_loss_high  = self.smooth(I_high, R_high)
        self.Ismooth_loss_delta = self.smooth(I_delta, R_low)

        self.loss_Decom = self.recon_loss_low + \
                          self.recon_loss_high + \
                          0.001 * self.recon_loss_mutal_low + \
                          0.001 * self.recon_loss_mutal_high + \
                          0.1 * self.Ismooth_loss_low + \
                          0.1 * self.Ismooth_loss_high + \
                          0.01 * self.equal_R_loss
        self.loss_Relight = self.relight_loss + \
                            3 * self.Ismooth_loss_delta

        self.output_R_low   = R_low.detach().cpu()
        self.output_I_low   = I_low_3.detach().cpu()
        self.output_I_delta = I_delta_3.detach().cpu()
        self.output_S       = R_low.detach().cpu() * I_delta_3.detach().cpu()

    def gradient(self, input_tensor, direction):
        """Absolute finite-difference gradient of a single-channel map.

        Args:
            input_tensor: (N, 1, H, W) tensor. The (1, 1, 2, 2) kernel requires
                exactly one input channel.
            direction: "x" (horizontal) or "y" (vertical).
        Returns:
            |gradient| of shape (N, 1, H + 1, W + 1) (2x2 kernel with zero padding 1).
        Raises:
            ValueError: for any other direction (the original failed with an
                UnboundLocalError).
        """
        # Kernel is built on the input's device/dtype rather than the global device.
        kernel_x = torch.tensor([[0.0, 0.0], [-1.0, 1.0]],
                                dtype=input_tensor.dtype,
                                device=input_tensor.device).view(1, 1, 2, 2)
        if direction == "x":
            kernel = kernel_x
        elif direction == "y":
            kernel = torch.transpose(kernel_x, 2, 3)
        else:
            raise ValueError("direction must be 'x' or 'y', got %r" % (direction,))
        return torch.abs(F.conv2d(input_tensor, kernel, stride=1, padding=1))

    def ave_gradient(self, input_tensor, direction):
        """3x3 box-filtered version of ``gradient`` (same output shape)."""
        return F.avg_pool2d(self.gradient(input_tensor, direction),
                            kernel_size=3, stride=1, padding=1)

    def smooth(self, input_I, input_R):
        """Structure-aware smoothness loss for an illumination map.

        Illumination gradients are penalised less where the reflectance has
        strong (locally averaged) gradients, via the ``exp(-10 * grad_R)`` weight.

        Args:
            input_I: (N, 1, H, W) illumination.
            input_R: (N, 3, H, W) reflectance, converted to luma with
                BT.601 weights (0.299, 0.587, 0.114).
        """
        input_R = 0.299*input_R[:, 0, :, :] + 0.587*input_R[:, 1, :, :] + 0.114*input_R[:, 2, :, :]
        input_R = torch.unsqueeze(input_R, dim=1)
        return torch.mean(self.gradient(input_I, "x") * torch.exp(-10 * self.ave_gradient(input_R, "x")) +
                          self.gradient(input_I, "y") * torch.exp(-10 * self.ave_gradient(input_R, "y")))

    # ------------------------------------------------------------------ I/O loops
    def evaluate(self, epoch_num, eval_low_data_names, vis_dir, train_phase):
        """Save side-by-side visualisations for every evaluation image.

        "Decom" writes [input | R_low | I_low]; "Relight" additionally writes
        [I_delta | R_low * I_delta]. Runs under ``torch.no_grad()``, since no
        gradients are needed here.

        Raises:
            ValueError: for an unknown ``train_phase`` (previously a NameError).
        """
        print("Evaluating for phase %s / epoch %d..." % (train_phase, epoch_num))
        if train_phase not in ("Decom", "Relight"):
            raise ValueError("train_phase must be 'Decom' or 'Relight', got %r" % (train_phase,))

        with torch.no_grad():
            for idx, img_path in enumerate(eval_low_data_names):
                input_low_eval = self._to_nchw(self._load_rgb(img_path))
                self.forward(input_low_eval, input_low_eval)
                panels = [self.output_R_low, self.output_I_low]
                if train_phase == "Relight":
                    panels += [self.output_I_delta, self.output_S]
                im = self._side_by_side(input_low_eval, panels)
                filepath = os.path.join(vis_dir, 'eval_%s_%d_%d.png' %
                                        (train_phase, idx + 1, epoch_num))
                im.save(filepath)

    def save(self, iter_num, ckpt_dir):
        """Write the current phase's sub-network weights to ``<ckpt_dir>/<phase>/<iter_num>.tar``."""
        save_dir = os.path.join(ckpt_dir, self.train_phase)
        os.makedirs(save_dir, exist_ok=True)
        save_name = os.path.join(save_dir, str(iter_num) + '.tar')
        if self.train_phase == 'Decom':
            torch.save(self.DecomNet.state_dict(), save_name)
        elif self.train_phase == 'Relight':
            torch.save(self.RelightNet.state_dict(), save_name)

    def load(self, ckpt_dir):
        """Restore the current phase's sub-network from its newest checkpoint.

        Looks in ``<ckpt_dir>/<train_phase>/`` for files named ``<step>.tar`` (as
        written by ``save``) and loads the one with the largest step. Other files
        are ignored; previously they crashed ``int()``. Tensors are mapped to the
        CPU and then copied into the existing parameters by ``load_state_dict``,
        so a GPU checkpoint also loads on a CPU-only runtime.

        Returns:
            (True, step) when a checkpoint was loaded, otherwise (False, 0).
        """
        load_dir = os.path.join(ckpt_dir, self.train_phase)
        if not os.path.isdir(load_dir):
            return False, 0
        candidates = []
        for fname in os.listdir(load_dir):
            stem, ext = os.path.splitext(fname)
            if ext == '.tar' and stem.isdigit():
                candidates.append((int(stem), fname))
        if not candidates:
            return False, 0
        global_step, load_ckpt = max(candidates)
        ckpt_dict = torch.load(os.path.join(load_dir, load_ckpt), map_location='cpu')
        if self.train_phase == 'Decom':
            self.DecomNet.load_state_dict(ckpt_dict)
        elif self.train_phase == 'Relight':
            self.RelightNet.load_state_dict(ckpt_dict)
        return True, global_step

    def train(self,
              train_low_data_names,
              train_high_data_names,
              eval_low_data_names,
              batch_size,
              patch_size, epoch,
              lr,
              vis_dir,
              ckpt_dir,
              eval_every_epoch,
              train_phase):
        """Train one phase ("Decom" or "Relight"), resuming from the latest checkpoint.

        Args:
            train_low_data_names / train_high_data_names: paired image paths,
                matched by index.
            eval_low_data_names: images rendered by ``evaluate``.
            batch_size, patch_size: number and side length of random crops per step.
            epoch: total number of epochs. Training resumes mid-way if a checkpoint
                exists.
            lr: per-epoch learning rates, indexable up to ``epoch - 1``.
            vis_dir, ckpt_dir: output folders for visuals and checkpoints.
            eval_every_epoch: evaluation + checkpoint interval in epochs.
            train_phase: "Decom" steps DecomNet only; "Relight" steps RelightNet
                only. NOTE(review): the Relight phase assumes DecomNet already
                holds trained weights (only RelightNet is restored here).
        Raises:
            ValueError: unknown phase, or fewer training pairs than ``batch_size``.
        """
        assert len(train_low_data_names) == len(train_high_data_names)
        if train_phase not in ("Decom", "Relight"):
            raise ValueError("train_phase must be 'Decom' or 'Relight', got %r" % (train_phase,))
        numBatch = len(train_low_data_names) // int(batch_size)
        if numBatch == 0:
            raise ValueError("Need at least batch_size=%d training pairs, got %d"
                             % (int(batch_size), len(train_low_data_names)))
        # Private copies: reshuffling must never touch the caller's lists.
        train_low_data_names = list(train_low_data_names)
        train_high_data_names = list(train_high_data_names)

        # Create the optimizers (their lr is overwritten at the start of every epoch)
        self.train_op_Decom   = optim.Adam(self.DecomNet.parameters(),
                                           lr=lr[0], betas=(0.9, 0.999))
        self.train_op_Relight = optim.Adam(self.RelightNet.parameters(),
                                           lr=lr[0], betas=(0.9, 0.999))

        # Initialize a network if its checkpoint is available
        self.train_phase = train_phase
        load_model_status, global_step = self.load(ckpt_dir)
        if load_model_status:
            iter_num    = global_step
            start_epoch = global_step // numBatch
            start_step  = global_step % numBatch
            print("Model restore success!")
        else:
            iter_num    = 0
            start_epoch = 0
            start_step  = 0
            print("No pretrained model to restore!")

        print("Start training for phase %s, with start epoch %d start iter %d : " %
              (self.train_phase, start_epoch, iter_num))

        start_time = time.time()
        image_id   = 0
        for epoch_idx in range(start_epoch, epoch):
            self.lr = lr[epoch_idx]
            # Adjust learning rate
            for optimizer in (self.train_op_Decom, self.train_op_Relight):
                for param_group in optimizer.param_groups:
                    param_group['lr'] = self.lr
            # Only the resumed epoch skips the batches already done. The original
            # applied start_step to every later epoch as well.
            first_batch = start_step if epoch_idx == start_epoch else 0
            for batch_id in range(first_batch, numBatch):
                # Generate training data for a batch
                batch_input_low  = np.zeros((batch_size, 3, patch_size, patch_size), dtype="float32")
                batch_input_high = np.zeros((batch_size, 3, patch_size, patch_size), dtype="float32")
                for patch_id in range(batch_size):
                    batch_input_low[patch_id], batch_input_high[patch_id] = self._random_patch_pair(
                        train_low_data_names[image_id], train_high_data_names[image_id], patch_size)
                    image_id = (image_id + 1) % len(train_low_data_names)
                    if image_id == 0:
                        # Full pass done: reshuffle both lists with one permutation.
                        train_low_data_names, train_high_data_names = self._shuffle_pairs(
                            train_low_data_names, train_high_data_names)
                self.input_low  = batch_input_low
                self.input_high = batch_input_high

                # Feed-Forward to the network and obtain loss
                self.forward(self.input_low, self.input_high)
                if train_phase == "Decom":
                    self.train_op_Decom.zero_grad()
                    self.loss_Decom.backward()
                    self.train_op_Decom.step()
                    loss = self.loss_Decom.item()
                else:
                    self.train_op_Relight.zero_grad()
                    self.loss_Relight.backward()
                    self.train_op_Relight.step()
                    loss = self.loss_Relight.item()

                print("%s Epoch: [%2d] [%4d/%4d] time: %4.4f, loss: %.6f" \
                      % (train_phase, epoch_idx + 1, batch_id + 1, numBatch, time.time() - start_time, loss))
                iter_num += 1

            # Evaluate the model and save a checkpoint file for it
            if (epoch_idx + 1) % eval_every_epoch == 0:
                self.evaluate(epoch_idx + 1, eval_low_data_names,
                              vis_dir=vis_dir, train_phase=train_phase)
                self.save(iter_num, ckpt_dir)

        print("Finished training for phase %s." % train_phase)

    def predict(self,
                test_low_data_names,
                res_dir,
                ckpt_dir):
        """Enhance each test image and save [input | enhanced] into ``res_dir``.

        Both sub-networks are restored from their latest per-phase checkpoints.
        Output files keep the input's file name. The original computed a path
        from ``res_dir`` but then saved to a hard-coded Drive folder.

        Raises:
            FileNotFoundError: if either phase has no checkpoint (previously a
                bare ``Exception``; FileNotFoundError is a subclass of it).
        """
        # Load the network with a pre-trained checkpoint for each phase
        for phase, success_msg in (('Decom', "  : Model restore success!"),
                                   ('Relight', ": Model restore success!")):
            self.train_phase = phase
            load_model_status, _ = self.load(ckpt_dir)
            if not load_model_status:
                print("No pretrained model to restore!")
                raise FileNotFoundError("No %s checkpoint found under %s" % (phase, ckpt_dir))
            print(self.train_phase, success_msg)

        # Set this switch to True to also save the reflectance and shading maps
        save_R_L = False
        os.makedirs(res_dir, exist_ok=True)

        # Predict for the test images
        with torch.no_grad():
            for test_img_path in test_low_data_names:
                # Accept both '/' and '\\' separators (paths may come from Windows).
                test_img_name = test_img_path.replace('\\', '/').split('/')[-1]
                print('Processing ', test_img_name)
                input_low_test = self._to_nchw(self._load_rgb(test_img_path))
                self.forward(input_low_test, input_low_test)
                if save_R_L:
                    panels = [self.output_R_low, self.output_I_low,
                              self.output_I_delta, self.output_S]
                else:
                    panels = [self.output_S]
                im = self._side_by_side(input_low_test, panels)
                im.save(os.path.join(res_dir, test_img_name))

In [None]:
# Train

import os
import argparse
from glob import glob
import numpy as np

data_dir = '/content/drive/MyDrive/Colab Notebooks/RetinexNet_self_do/train'
ckpt_dir = '/content/drive/MyDrive/Colab Notebooks/RetinexNet_self_do/checkpoints'
vis_dir = ckpt_dir + '/visuals/'

# Checkpoints and evaluation visuals are written below these folders; create
# whichever of them is missing.
for _required_dir in (ckpt_dir, vis_dir):
    if not os.path.exists(_required_dir):
        os.makedirs(_required_dir)
del _required_dir


def train(model,
          epochs=100,
          batch_size=16,
          patch_size=96,
          base_lr=0.001,
          lr_decay_epoch=20,
          data_dir='/content/drive/MyDrive/Colab Notebooks/RetinexNet_self_do/train',
          ckpt_dir='/content/drive/MyDrive/Colab Notebooks/RetinexNet_self_do/checkpoints',
          eval_every_epoch=10):
    """Train RetinexNet in two phases (Decom, then Relight) and save full snapshots.

    Hyper-parameters and paths used to be hard-coded locals. They are now keyword
    arguments whose defaults reproduce the previous values, so ``train(model)``
    behaves as before.

    Args:
        model: object exposing ``train(...)`` and ``state_dict()`` (a RetinexNet).
        epochs: epochs per phase; the lr schedule has one entry per epoch.
        batch_size, patch_size: random crops per step and their side length.
        base_lr: learning rate for epochs [0, lr_decay_epoch); later epochs use
            ``base_lr / 10``.
        lr_decay_epoch: first epoch of the reduced learning rate.
        data_dir: folder containing ``low/*.png`` and ``high/*.png`` (paired by
            sorted position) and optionally ``eval/low/*``.
        ckpt_dir: checkpoint folder; visuals go to ``<ckpt_dir>/visuals``.
        eval_every_epoch: evaluation/checkpoint interval passed to ``model.train``.

    After each phase, the complete ``model.state_dict()`` (both sub-networks) is
    written to ``<ckpt_dir>/decom.pth`` and ``<ckpt_dir>/relight.pth`` respectively.
    """
    vis_dir = os.path.join(ckpt_dir, 'visuals')
    os.makedirs(ckpt_dir, exist_ok=True)
    os.makedirs(vis_dir, exist_ok=True)

    # Step schedule: base_lr, then base_lr / 10 from lr_decay_epoch onwards.
    lr = base_lr * np.ones([epochs])
    lr[lr_decay_epoch:] = base_lr / 10.0

    train_low_data_names = sorted(glob(os.path.join(data_dir, 'low', '*.png')))
    train_high_data_names = sorted(glob(os.path.join(data_dir, 'high', '*.png')))
    # NOTE(review): eval images are looked up under <data_dir>/eval/low, while the
    # prediction cell uses RetinexNet_self_do/eval/low. Confirm which is intended.
    eval_low_data_names = sorted(glob(os.path.join(data_dir, 'eval', 'low', '*.*')))

    assert len(train_low_data_names) == len(train_high_data_names)
    print('Number of training data: %d' % len(train_low_data_names))
    if ([os.path.basename(p) for p in train_low_data_names]
            != [os.path.basename(p) for p in train_high_data_names]):
        print('Warning: low/high file names differ; pairs are matched by sorted position only.')
    if not eval_low_data_names:
        print('Warning: no evaluation images found under %s; visual evaluation is skipped.'
              % os.path.join(data_dir, 'eval', 'low'))

    for phase in ("Decom", "Relight"):
        model.train(train_low_data_names,
                    train_high_data_names,
                    eval_low_data_names,
                    batch_size=batch_size,
                    patch_size=patch_size,
                    epoch=epochs,
                    lr=lr,
                    vis_dir=vis_dir,
                    ckpt_dir=ckpt_dir,
                    eval_every_epoch=eval_every_epoch,
                    train_phase=phase)
        # Save the whole model's state after this phase (decom.pth / relight.pth)
        torch.save(model.state_dict(), os.path.join(ckpt_dir, phase.lower() + '.pth'))


# Report GPU visibility, build the full network on the selected device,
# then run both training phases.
cuda_available = torch.cuda.is_available()
model = RetinexNet().to(device)
print("CUDA Available: ", cuda_available)
train(model)



CUDA Available:  True
Number of training data: 485
Model restore success!
Start training for phase Decom, with start epoch 40 start iter 1200 : 
Decom Epoch: [41] [   1/  30] time: 1.8823, loss: 0.031096
Decom Epoch: [41] [   2/  30] time: 2.3623, loss: 0.058755
Decom Epoch: [41] [   3/  30] time: 2.8462, loss: 0.036819
Decom Epoch: [41] [   4/  30] time: 3.2991, loss: 0.036278
Decom Epoch: [41] [   5/  30] time: 3.7662, loss: 0.049514
Decom Epoch: [41] [   6/  30] time: 4.2159, loss: 0.042496
Decom Epoch: [41] [   7/  30] time: 4.6776, loss: 0.038259
Decom Epoch: [41] [   8/  30] time: 5.1558, loss: 0.032252
Decom Epoch: [41] [   9/  30] time: 5.6195, loss: 0.049727
Decom Epoch: [41] [  10/  30] time: 6.0981, loss: 0.044734
Decom Epoch: [41] [  11/  30] time: 6.5594, loss: 0.050619
Decom Epoch: [41] [  12/  30] time: 7.0373, loss: 0.052969
Decom Epoch: [41] [  13/  30] time: 7.5131, loss: 0.061880
Decom Epoch: [41] [  14/  30] time: 8.0073, loss: 0.040296
Decom Epoch: [41] [  15/  30]

In [None]:
import os
from glob import glob
import numpy as np
import torch


data_dir = '/content/drive/MyDrive/Colab Notebooks/RetinexNet_self_do/eval/low'  # Directory storing the test data
ckpt_dir = '/content/drive/MyDrive/Colab Notebooks/RetinexNet_self_do/checkpoints'  # Directory for checkpoints
res_dir = '/content/drive/MyDrive/Colab Notebooks/RetinexNet_self_do/eval/enhanced'  # Directory for saving the results

# Create directories for saving the results
os.makedirs(res_dir, exist_ok=True)

# Initialize the model
model = RetinexNet().to(device)

# The training cell saves the *whole* RetinexNet state dict twice:
# decom.pth after the Decom phase (RelightNet has not been trained by then) and
# relight.pth after the Relight phase. Only relight.pth holds both trained
# sub-networks, so prefer it. The original checked both files but loaded decom.pth.
# NOTE: model.predict() below reloads both sub-networks from the per-phase .tar
# checkpoints, so these weights are overwritten unless predict changes.
decom_ckpt_path = os.path.join(ckpt_dir, 'decom.pth')
relight_ckpt_path = os.path.join(ckpt_dir, 'relight.pth')
if os.path.exists(relight_ckpt_path):
    model.load_state_dict(torch.load(relight_ckpt_path, map_location=device))
elif os.path.exists(decom_ckpt_path):
    print("relight.pth not found; loading decom.pth (RelightNet weights not trained).")
    model.load_state_dict(torch.load(decom_ckpt_path, map_location=device))
else:
    print("Checkpoint not found.")

def predict(model):
    """Enhance every PNG in the evaluation folder with ``model.predict``.

    Output and checkpoint folders come from the module-level ``res_dir`` and
    ``ckpt_dir`` set in this cell.
    """
    eval_pattern = '/content/drive/MyDrive/Colab Notebooks/RetinexNet_self_do/eval/low/*.png'
    test_low_data_names = sorted(glob(eval_pattern))
    print('Number of evaluation images: %d' % len(test_low_data_names))

    # RetinexNet implements the per-image prediction loop itself.
    model.predict(test_low_data_names, res_dir=res_dir, ckpt_dir=ckpt_dir)


# Run the prediction
predict(model=model)


Number of evaluation images: 15
Decom   : Model restore success!
Relight : Model restore success!
Processing  1.png
1.png
Processing  111.png
111.png
Processing  146.png
146.png
Processing  179.png
179.png
Processing  22.png
22.png
Processing  23.png
23.png
Processing  493.png
493.png
Processing  547.png
547.png
Processing  55.png
55.png
Processing  665.png
665.png
Processing  669.png
669.png
Processing  748.png
748.png
Processing  778.png
778.png
Processing  780.png
780.png
Processing  79.png
79.png


In [None]:
import cv2
def compare_psnr(img1, img2, maxvalue):
    """Peak signal-to-noise ratio between two equally-shaped images, in dB.

    Args:
        img1, img2: array-likes of identical shape. They are compared in float64,
            so uint8 inputs cannot wrap around.
        maxvalue: largest possible pixel value (255 for 8-bit images).
    Returns:
        ``10 * log10(maxvalue**2 / MSE)``; ``inf`` for identical images (computed
        explicitly instead of via a divide-by-zero warning).
    Raises:
        ValueError: if the shapes differ. NumPy would otherwise broadcast, e.g.
            (H, W, 3) against (H, W, 1), and silently return a wrong value.
    """
    img1 = np.asarray(img1, dtype=np.float64)
    img2 = np.asarray(img2, dtype=np.float64)
    if img1.shape != img2.shape:
        raise ValueError('Image shapes differ: %s vs %s' % (img1.shape, img2.shape))
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0:
        return float('inf')
    return 10 * np.log10((maxvalue ** 2) / mse)

def align_to_reference(enhanced, reference):
    """Return the part of ``enhanced`` that is comparable with ``reference``, or None.

    ``RetinexNet.predict`` in this notebook saves the low-light input and the
    result side by side (concatenated along the width), with the enhanced output
    as the right-most panel. When ``enhanced`` has the reference's height and
    channels and a width that is a multiple of the reference width, that
    right-most panel is returned. Equal shapes are returned unchanged; anything
    else gives None.
    """
    if enhanced.shape == reference.shape:
        return enhanced
    ref_h, ref_w = reference.shape[:2]
    if (ref_w > 0
            and enhanced.shape[0] == ref_h
            and enhanced.shape[2:] == reference.shape[2:]
            and enhanced.shape[1] % ref_w == 0):
        return enhanced[:, -ref_w:]
    return None


def compare_images_in_folders(folder1, folder2, maxvalue):
    """Average PSNR over images that exist under the same file name in both folders.

    Args:
        folder1: folder with enhanced images (may be [input | enhanced] strips,
            see ``align_to_reference``).
        folder2: folder with ground-truth images.
        maxvalue: peak pixel value passed to ``compare_psnr``.
    Returns:
        The mean PSNR in dB (also printed), or None when no pair could be compared.
        Pairs whose shapes cannot be reconciled are reported and skipped rather
        than crashing the loop. Files are visited in sorted order for
        reproducible logs.
    """
    psnr_values = []
    for filename1 in sorted(os.listdir(folder1)):
        if not filename1.lower().endswith(('.jpg', '.png', '.jpeg')):
            continue
        filename2 = os.path.join(folder2, filename1)
        if not os.path.exists(filename2):
            print(f"File not found: {filename2}")
            continue
        img1 = cv2.imread(os.path.join(folder1, filename1))
        img2 = cv2.imread(filename2)
        if img1 is None or img2 is None:
            print(f"Could not read images: {filename1}, {filename2}")
            continue
        aligned = align_to_reference(img1, img2)
        if aligned is None:
            print(f"Shape mismatch, skipped: {filename1} {img1.shape} vs {img2.shape}")
            continue
        psnr_values.append(compare_psnr(aligned, img2, maxvalue))

    if not psnr_values:
        print("No valid image pairs found for comparison.")
        return None
    avg_psnr = float(np.mean(psnr_values))
    print(f"Average PSNR: {avg_psnr:.2f} dB")
    return avg_psnr

# Compare RetinexNet's enhanced outputs with the ground-truth normal-light images.
# NOTE(review): folder2 points into the Zero_DCE project; confirm its file names
# match the RetinexNet eval/low images, since pairs are matched by name.
folder1, folder2 = (
    "/content/drive/MyDrive/Colab Notebooks/RetinexNet_self_do/eval/enhanced",
    "/content/drive/MyDrive/Colab Notebooks/Zero_DCE_self_do/data/test_data/high",
)
maxvalue = 255  # peak value of 8-bit images
compare_images_in_folders(folder1, folder2, maxvalue=maxvalue)