# Deep Learning-Based Low Light Image Enhancement for Improved Visibility

# Importing Required Libraries

In [1]:
import tkinter as tk
from tkinter import filedialog, messagebox
from PIL import Image, ImageTk
import argparse
import os
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
import numpy as np
from scipy import fft
from skimage import io, exposure, img_as_ubyte, img_as_float
from tqdm import trange
import matplotlib.pyplot as plt  
import warnings

warnings.filterwarnings('ignore')

# Class for performing Local Illumination Map Enhancement (LIME)

In [2]:
class LIME:
    def __init__(self, iterations=10, alpha=2, rho=2, gamma=0.7, strategy=2, *args, **kwargs):
        self.iterations = iterations
        self.alpha = alpha
        self.rho = rho
        self.gamma = gamma
        self.strategy = strategy

    def load(self, imgPath):
        self.L = img_as_float(io.imread(imgPath))
        self.row = self.L.shape[0]
        self.col = self.L.shape[1]

        self.T_hat = np.max(self.L, axis=2)
        self.dv = self.firstOrderDerivative(self.row)
        self.dh = self.firstOrderDerivative(self.col, -1)
        self.vecDD = self.toeplitizMatrix(self.row * self.col, self.row)
        self.W = self.weightingStrategy()

    def firstOrderDerivative(self, n, k=1):
        return np.eye(n) * (-1) + np.eye(n, k=k)

    def toeplitizMatrix(self, n, row):
        vecDD = np.zeros(n)
        vecDD[0] = 4
        vecDD[1] = -1
        vecDD[row] = -1
        vecDD[-1] = -1
        vecDD[-row] = -1
        return vecDD

    def weightingStrategy(self):
        if self.strategy == 2:
            dTv = self.dv @ self.T_hat
            dTh = self.T_hat @ self.dh
            Wv = 1 / (np.abs(dTv) + 1)
            Wh = 1 / (np.abs(dTh) + 1)
            return np.vstack([Wv, Wh])
        else:
            return np.ones((self.row * 2, self.col))

    def __T_subproblem(self, G, Z, u):
        X = G - Z / u
        Xv = X[:self.row, :]
        Xh = X[self.row:, :]
        temp = self.dv @ Xv + Xh @ self.dh
        numerator = fft.fft(self.vectorize(2 * self.T_hat + u * temp))
        denominator = fft.fft(self.vecDD * u) + 2
        T = fft.ifft(numerator / denominator)
        T = np.real(self.reshape(T, self.row, self.col))
        return exposure.rescale_intensity(T, (0, 1), (0.001, 1))

    def __G_subproblem(self, T, Z, u, W):
        dT = self.__derivative(T)
        epsilon = self.alpha * W / u
        X = dT + Z / u
        return np.sign(X) * np.maximum(np.abs(X) - epsilon, 0)

    def __Z_subproblem(self, T, G, Z, u):
        dT = self.__derivative(T)
        return Z + u * (dT - G)

    def __u_subproblem(self, u):
        return u * self.rho

    def __derivative(self, matrix):
        v = self.dv @ matrix
        h = matrix @ self.dh
        return np.vstack([v, h])

    def illumMap(self):
        T = np.zeros((self.row, self.col))
        G = np.zeros((self.row * 2, self.col))
        Z = np.zeros((self.row * 2, self.col))
        u = 1

        for _ in trange(0, self.iterations):
            T = self.__T_subproblem(G, Z, u)
            G = self.__G_subproblem(T, Z, u, self.W)
            Z = self.__Z_subproblem(T, G, Z, u)
            u = self.__u_subproblem(u)

        return T ** self.gamma

    def enhance(self):
        self.T = self.illumMap()
        self.R = self.L / np.repeat(self.T[:, :, np.newaxis], 3, axis=2)
        self.R = exposure.rescale_intensity(self.R, (0, 1))
        self.R = img_as_ubyte(self.R)
        return self.R

    def vectorize(self, matrix):
        return matrix.T.ravel()

    def reshape(self, vector, row, col):
        return vector.reshape((row, col), order='F')

# Image Enhancement Function

In [3]:
def enhance_image(filePath, options):
    lime = LIME(**vars(options))
    lime.load(filePath)
    enhanced_image = lime.enhance()
    
    filename = os.path.split(filePath)[-1]
    
    # Create the output directory if it doesn't exist
    if options.output and not os.path.exists(options.output):
        os.makedirs(options.output)
    
    if options.output:
        savePath = os.path.join(options.output, f"enhanced_{filename}")
        plt.imsave(savePath, enhanced_image)

    original_image = Image.open(filePath)
    original_array = np.array(original_image)
    enhanced_array = np.array(enhanced_image)

    # Calculate PSNR (Peak Signal-to-Noise Ratio)
    psnr_value = peak_signal_noise_ratio(original_array, enhanced_array)
    print('PSNR', psnr_value)
    
    # Calculate SSIM (Structural Similarity Index) with explicit win_size
    ssim_value = structural_similarity(original_array, enhanced_array, multichannel=True, win_size=7, channel_axis=-1)
    print('SSIM', ssim_value)
    
    # Calculate MSE (Mean Squared Error)
    mse_value = np.mean((original_array - enhanced_array) ** 2)
    print('MSE', mse_value)

    return Image.fromarray(enhanced_array)


# Class for GUI-based image enhancement application

In [4]:
class ImageEnhancementApp:
    def __init__(self, master, options):
        self.master = master
        self.options = options
        master.title("Image Enhancement")

        self.label = tk.Label(master, text="Select an image to enhance:")
        self.label.pack()

        self.select_button = tk.Button(master, text="Select Image", command=self.select_image)
        self.select_button.pack()

        self.enhance_button = tk.Button(master, text="Enhance Image", command=self.enhance_image)
        self.enhance_button.pack()

        self.original_image_label = tk.Label(master)
        self.original_image_label.pack(pady=10)

        self.enhanced_image_label = tk.Label(master)
        self.enhanced_image_label.pack(pady=10)

    def select_image(self):
        file_path = filedialog.askopenfilename(filetypes=[("Image Files", "*.jpg;*.jpeg;*.png;*.bmp")])
        if file_path:
            self.image = Image.open(file_path)
            self.photo = ImageTk.PhotoImage(self.image)
            self.original_image_label.config(image=self.photo)

            # Update the file_path variable in the enhancement options
            self.options.filePath = file_path

    def enhance_image(self):
        if hasattr(self, 'image'):
            # Call the enhance_image function with options and the selected image
            enhanced_image = enhance_image(self.options.filePath, self.options)

            # Update the image label with the enhanced image
            self.enhanced_photo = ImageTk.PhotoImage(enhanced_image)
            self.enhanced_image_label.config(image=self.enhanced_photo)
        else:
            messagebox.showerror("Error", "Please select an image first.")


# Main GUI Function and Argument Parsing

In [None]:
def main_gui(options):
    root = tk.Tk()
    app = ImageEnhancementApp(root, options)
    root.mainloop()

if __name__ == "__main__":
    # Parse the arguments for enhancement options
    parser = argparse.ArgumentParser()
    parser.add_argument("-f", "--filePath", default="./data/1.bmp", type=str, help="image path to enhance")
    parser.add_argument("-m", "--map", action="store_true", help="save illumination map")
    parser.add_argument("-o", "--output", default="./", type=str, help="output folder")
    parser.add_argument("-i", "--iterations", default=10, type=int, help="iteration number")
    parser.add_argument("-a", "--alpha", default=2, type=int, help="parameter of alpha")
    parser.add_argument("-r", "--rho", default=2, type=int, help="parameter of rho")
    parser.add_argument("-g", "--gamma", default=0.7, type=float, help="parameter of gamma")
    parser.add_argument("-s", "--strategy", default=2, type=int, choices=[1, 2], help="weighting strategy")
    options = parser.parse_args()

    main_gui(options)


100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 38.31it/s]


PSNR 13.755897494240443
SSIM 0.47672970282533994
MSE 55.193595400454605
