In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
# Model
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
#import pytorch_colors as colors
import numpy as np

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class enhance_net_nopool(nn.Module):

	def __init__(self):
		super(enhance_net_nopool, self).__init__()

		self.relu = nn.ReLU(inplace=True)

		number_f = 32
		self.e_conv1 = nn.Conv2d(3,number_f,3,1,1,bias=True)
		self.e_conv2 = nn.Conv2d(number_f,number_f,3,1,1,bias=True)
		self.e_conv3 = nn.Conv2d(number_f,number_f,3,1,1,bias=True)
		self.e_conv4 = nn.Conv2d(number_f,number_f,3,1,1,bias=True)
		self.e_conv5 = nn.Conv2d(number_f*2,number_f,3,1,1,bias=True)
		self.e_conv6 = nn.Conv2d(number_f*2,number_f,3,1,1,bias=True)
		self.e_conv7 = nn.Conv2d(number_f*2,24,3,1,1,bias=True)

		self.maxpool = nn.MaxPool2d(2, stride=2, return_indices=False, ceil_mode=False)
		self.upsample = nn.UpsamplingBilinear2d(scale_factor=2)



	def forward(self, x):

		x1 = self.relu(self.e_conv1(x))
		# p1 = self.maxpool(x1)
		x2 = self.relu(self.e_conv2(x1))
		# p2 = self.maxpool(x2)
		x3 = self.relu(self.e_conv3(x2))
		# p3 = self.maxpool(x3)
		x4 = self.relu(self.e_conv4(x3))

		x5 = self.relu(self.e_conv5(torch.cat([x3,x4],1)))
		# x5 = self.upsample(x5)
		x6 = self.relu(self.e_conv6(torch.cat([x2,x5],1)))

		x_r = F.tanh(self.e_conv7(torch.cat([x1,x6],1)))
		r1,r2,r3,r4,r5,r6,r7,r8 = torch.split(x_r, 3, dim=1)


		x = x + r1*(torch.pow(x,2)-x)
		x = x + r2*(torch.pow(x,2)-x)
		x = x + r3*(torch.pow(x,2)-x)
		enhance_image_1 = x + r4*(torch.pow(x,2)-x)
		x = enhance_image_1 + r5*(torch.pow(enhance_image_1,2)-enhance_image_1)
		x = x + r6*(torch.pow(x,2)-x)
		x = x + r7*(torch.pow(x,2)-x)
		enhance_image = x + r8*(torch.pow(x,2)-x)
		r = torch.cat([r1,r2,r3,r4,r5,r6,r7,r8],1)
		return enhance_image_1,enhance_image,r





In [3]:
# Loss
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torchvision.models.vgg import vgg16
import numpy as np

# Color Constancy Loss
class L_color(nn.Module):

    def __init__(self):
        super(L_color, self).__init__()

    def forward(self, x ):

        b,c,h,w = x.shape

        mean_rgb = torch.mean(x,[2,3],keepdim=True)
        mr,mg, mb = torch.split(mean_rgb, 1, dim=1)
        Drg = torch.pow(mr-mg,2)
        Drb = torch.pow(mr-mb,2)
        Dgb = torch.pow(mb-mg,2)
        k = torch.pow(torch.pow(Drg,2) + torch.pow(Drb,2) + torch.pow(Dgb,2),0.5)

        return k.mean()

# Spatial Consistency Loss
class L_spa(nn.Module):

    def __init__(self):
        super(L_spa, self).__init__()
        # print(1)kernel = torch.FloatTensor(kernel).unsqueeze(0).unsqueeze(0)
        kernel_left = torch.FloatTensor( [[0,0,0],[-1,1,0],[0,0,0]]).to(device).unsqueeze(0).unsqueeze(0)
        kernel_right = torch.FloatTensor( [[0,0,0],[0,1,-1],[0,0,0]]).to(device).unsqueeze(0).unsqueeze(0)
        kernel_up = torch.FloatTensor( [[0,-1,0],[0,1, 0 ],[0,0,0]]).to(device).unsqueeze(0).unsqueeze(0)
        kernel_down = torch.FloatTensor( [[0,0,0],[0,1, 0],[0,-1,0]]).to(device).unsqueeze(0).unsqueeze(0)
        self.weight_left = nn.Parameter(data=kernel_left, requires_grad=False)
        self.weight_right = nn.Parameter(data=kernel_right, requires_grad=False)
        self.weight_up = nn.Parameter(data=kernel_up, requires_grad=False)
        self.weight_down = nn.Parameter(data=kernel_down, requires_grad=False)
        self.pool = nn.AvgPool2d(4)

    def forward(self, org , enhance ):
        b,c,h,w = org.shape

        org_mean = torch.mean(org,1,keepdim=True)
        enhance_mean = torch.mean(enhance,1,keepdim=True)

        org_pool =  self.pool(org_mean)
        enhance_pool = self.pool(enhance_mean)

        weight_diff =torch.max(torch.FloatTensor([1]).to(device) + 10000*torch.min(org_pool - torch.FloatTensor([0.3]).to(device),torch.FloatTensor([0]).to(device)),torch.FloatTensor([0.5]).to(device))
        E_1 = torch.mul(torch.sign(enhance_pool - torch.FloatTensor([0.5]).to(device)) ,enhance_pool-org_pool)


        D_org_letf = F.conv2d(org_pool , self.weight_left, padding=1)
        D_org_right = F.conv2d(org_pool , self.weight_right, padding=1)
        D_org_up = F.conv2d(org_pool , self.weight_up, padding=1)
        D_org_down = F.conv2d(org_pool , self.weight_down, padding=1)

        D_enhance_letf = F.conv2d(enhance_pool , self.weight_left, padding=1)
        D_enhance_right = F.conv2d(enhance_pool , self.weight_right, padding=1)
        D_enhance_up = F.conv2d(enhance_pool , self.weight_up, padding=1)
        D_enhance_down = F.conv2d(enhance_pool , self.weight_down, padding=1)

        D_left = torch.pow(D_org_letf - D_enhance_letf,2)
        D_right = torch.pow(D_org_right - D_enhance_right,2)
        D_up = torch.pow(D_org_up - D_enhance_up,2)
        D_down = torch.pow(D_org_down - D_enhance_down,2)
        E = (D_left + D_right + D_up +D_down)

        E = 25*(D_left + D_right + D_up + D_down)
        return E.mean()

# Exposure Control Loss
class L_exp(nn.Module):

    def __init__(self,patch_size,mean_val):
        super(L_exp, self).__init__()
        # print(1)
        self.pool = nn.AvgPool2d(patch_size)
        self.mean_val = mean_val
    def forward(self, x ):

        b,c,h,w = x.shape
        x = torch.mean(x,1,keepdim=True)
        mean = self.pool(x)

        d = torch.mean(torch.pow(mean - torch.FloatTensor([self.mean_val] ).to(device),2))
        return d

# Illumination Smoothness Loss
class L_TV(nn.Module):
    def __init__(self,TVLoss_weight=1):
        super(L_TV,self).__init__()
        self.TVLoss_weight = TVLoss_weight

    def forward(self,x):
        batch_size = x.size()[0]
        h_x = x.size()[2]
        w_x = x.size()[3]
        count_h =  (x.size()[2]-1) * x.size()[3]
        count_w = x.size()[2] * (x.size()[3] - 1)
        h_tv = torch.pow((x[:,:,1:,:]-x[:,:,:h_x-1,:]),2).sum()
        w_tv = torch.pow((x[:,:,:,1:]-x[:,:,:,:w_x-1]),2).sum()
        return self.TVLoss_weight * 2 * (h_tv/count_h + w_tv/count_w) / batch_size


In [4]:
import torch
import torch.utils.data as data
import numpy as np
from PIL import Image
import glob
import random


random.seed(1143)


def populate_train_list(lowlight_images_path):

        image_list_lowlight = glob.glob(lowlight_images_path + "*.png")


        train_list = image_list_lowlight

        random.shuffle(train_list)

        return train_list

class lowlight_loader(data.Dataset):

    def __init__(self, lowlight_images_path):

        self.train_list = populate_train_list(lowlight_images_path)
        self.size = 256
        self.data_list = self.train_list
        print("Total training examples:", len(self.train_list))
        if len(self.train_list) == 0:
            raise ValueError(f"No images found in {lowlight_images_path}. Check the path and file type.")


    def __getitem__(self, index):

        data_lowlight_path = self.data_list[index]
        data_lowlight = Image.open(data_lowlight_path)

        data_lowlight = data_lowlight.resize((self.size,self.size), Image.ANTIALIAS)

        data_lowlight = (np.asarray(data_lowlight)/255.0)
        data_lowlight = torch.from_numpy(data_lowlight).float()

        return data_lowlight.permute(2,0,1)

    def __len__(self):
        return len(self.data_list)



In [5]:
#  data Loader
import torch.optim as optim


print("CUDA Available: ", torch.cuda.is_available())
# Initialize your model
model = enhance_net_nopool().to(device)

# Initialize your dataset and dataloader
dataset = lowlight_loader(lowlight_images_path='drive/MyDrive/Colab Notebooks/Zero_DCE_self_do/data/train_data/')
dataloader = data.DataLoader(dataset, batch_size=16, shuffle=True)

# Define your optimizer
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Define your loss function(s)
# Assuming you're using the provided loss functions, you'll need to instantiate them
color_loss = L_color().to(device)
spa_loss = L_spa().to(device)
exp_loss = L_exp(patch_size=16, mean_val=0.6).to(device)
tv_loss = L_TV().to(device)

# perception_loss = perception_loss().to(device) # Make sure you have pretrained VGG available

# Training loop
num_epochs = 50
for epoch in range(num_epochs):
    for i, data_lowlight in enumerate(dataloader):
        # Move data to GPU
        data_lowlight = data_lowlight.to(device)

        # Forward pass
        enhanced_image_1, enhanced_image, _ = model(data_lowlight)

        # Calculate loss
        loss_color = 5*color_loss(enhanced_image)
        loss_spa = spa_loss(data_lowlight, enhanced_image)
        loss_TV = 200*tv_loss(_)
        loss_exp = 10*torch.mean(exp_loss(enhanced_image))
        loss = loss_color + loss_spa + loss_TV + loss_exp  # Combine losses as needed


        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(dataloader)}], Loss: {loss.item()}')

print('Finished Training')
PATH = './drive/My Drive/Colab Notebooks/Zero_DCE_self_do/Zero_DCE_model.pth'
torch.save(model.state_dict(), PATH)
print('Model saved successfully')

CUDA Available:  True
Total training examples: 485


  data_lowlight = data_lowlight.resize((self.size,self.size), Image.ANTIALIAS)


Epoch [1/50], Step [1/31], Loss: 3.0667598247528076
Epoch [1/50], Step [2/31], Loss: 3.0426180362701416
Epoch [1/50], Step [3/31], Loss: 3.138869285583496
Epoch [1/50], Step [4/31], Loss: 2.9770939350128174
Epoch [1/50], Step [5/31], Loss: 3.1570420265197754
Epoch [1/50], Step [6/31], Loss: 3.1616480350494385
Epoch [1/50], Step [7/31], Loss: 3.1254079341888428
Epoch [1/50], Step [8/31], Loss: 3.02634334564209
Epoch [1/50], Step [9/31], Loss: 3.0036933422088623
Epoch [1/50], Step [10/31], Loss: 2.9788804054260254
Epoch [1/50], Step [11/31], Loss: 2.9850525856018066
Epoch [1/50], Step [12/31], Loss: 2.980145215988159
Epoch [1/50], Step [13/31], Loss: 2.9028241634368896
Epoch [1/50], Step [14/31], Loss: 2.8682448863983154
Epoch [1/50], Step [15/31], Loss: 3.0270180702209473
Epoch [1/50], Step [16/31], Loss: 2.9220030307769775
Epoch [1/50], Step [17/31], Loss: 2.6797852516174316
Epoch [1/50], Step [18/31], Loss: 2.8650057315826416
Epoch [1/50], Step [19/31], Loss: 2.5169107913970947
Epoch 

In [6]:
# Test with own image
import torch
from PIL import Image
import numpy as np
import os
from torchvision import transforms

# Load the trained model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = enhance_net_nopool().to(device)
model.load_state_dict(torch.load('/content/drive/MyDrive/Colab Notebooks/Zero_DCE_self_do/Zero_DCE_model.pth', map_location=device))
model.eval()

# Load your low-light image
def load_image(image_path):
    image = Image.open(image_path)
    # Assuming the input images are resized to 256x256 during training
    image = image.resize((600, 400), Image.ANTIALIAS)
    image = np.asarray(image) / 255.0  # Normalize if that's what you did during training
    image_tensor = torch.from_numpy(image).float()
    return image_tensor.permute(2, 0, 1).unsqueeze(0)  # Add batch dimension

# Enhance the image
def enhance_image(model, image_tensor):
    with torch.no_grad():
        image_tensor = image_tensor.to(device)
        enhanced_image, _, _ = model(image_tensor)
    return enhanced_image

# Save the enhanced image
def save_image(tensor, filename):
    image = tensor.cpu().clone()
    image = image.squeeze(0)
    image = image.permute(1, 2, 0)
    image = image.numpy()
    image = (image * 255).astype(np.uint8)
    image = Image.fromarray(image)
    image.save(filename)

# Process all images in the directory
input_folder = '/content/drive/MyDrive/Colab Notebooks/Zero_DCE_self_do/data/test_data/low'
output_folder = '/content/drive/MyDrive/Colab Notebooks/Zero_DCE_self_do/data/test_data/enhanced'
os.makedirs(output_folder, exist_ok=True)

for image_name in os.listdir(input_folder):
    if image_name.endswith('.png'):  # Assuming the images are PNGs
        input_path = os.path.join(input_folder, image_name)
        output_path = os.path.join(output_folder, image_name)

        input_image = load_image(input_path)
        enhanced_image = enhance_image(model, input_image)
        save_image(enhanced_image, output_path)
        print(f'Enhanced image saved to {output_path}')


  image = image.resize((600, 400), Image.ANTIALIAS)


Enhanced image saved to /content/drive/MyDrive/Colab Notebooks/Zero_DCE_self_do/data/test_data/enhanced/179.png
Enhanced image saved to /content/drive/MyDrive/Colab Notebooks/Zero_DCE_self_do/data/test_data/enhanced/780.png
Enhanced image saved to /content/drive/MyDrive/Colab Notebooks/Zero_DCE_self_do/data/test_data/enhanced/146.png
Enhanced image saved to /content/drive/MyDrive/Colab Notebooks/Zero_DCE_self_do/data/test_data/enhanced/547.png
Enhanced image saved to /content/drive/MyDrive/Colab Notebooks/Zero_DCE_self_do/data/test_data/enhanced/493.png
Enhanced image saved to /content/drive/MyDrive/Colab Notebooks/Zero_DCE_self_do/data/test_data/enhanced/111.png
Enhanced image saved to /content/drive/MyDrive/Colab Notebooks/Zero_DCE_self_do/data/test_data/enhanced/665.png
Enhanced image saved to /content/drive/MyDrive/Colab Notebooks/Zero_DCE_self_do/data/test_data/enhanced/669.png
Enhanced image saved to /content/drive/MyDrive/Colab Notebooks/Zero_DCE_self_do/data/test_data/enhanced/

In [7]:
import cv2
def compare_psnr(img1, img2, maxvalue):
  img1, img2 = img1.astype(np.float64), img2.astype(np.float64)
  mse = np.mean((img1 - img2) ** 2)
  return 10 * np.log10((maxvalue ** 2) / mse)

def compare_images_in_folders(folder1, folder2, maxvalue):
    psnr_values = []
    for filename1 in os.listdir(folder1):
        if filename1.endswith(('.jpg', '.png', '.jpeg')):
            filename2 = os.path.join(folder2, filename1)
            if os.path.exists(filename2):
                img1 = cv2.imread(os.path.join(folder1, filename1))
                img2 = cv2.imread(filename2)
                if img1 is not None and img2 is not None:
                    psnr = compare_psnr(img1, img2, maxvalue)
                    psnr_values.append(psnr)
                else:
                    print(f"Could not read images: {filename1}, {filename2}")
            else:
                print(f"File not found: {filename2}")
    if psnr_values:
        avg_psnr = np.mean(psnr_values)
        print(f"Average PSNR: {avg_psnr:.2f} dB")
    else:
        print("No valid image pairs found for comparison.")

folder1 = "/content/drive/MyDrive/Colab Notebooks/Zero_DCE_self_do/data/test_data/enhanced"
folder2 = "/content/drive/MyDrive/Colab Notebooks/Zero_DCE_self_do/data/test_data/high"
maxvalue = 255  # Assuming 8-bit images
compare_images_in_folders(folder1, folder2, maxvalue)



Average PSNR: 12.18 dB
