In [None]:
import os
import random
import numpy as np
from glob import glob
from PIL import Image, ImageOps
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

In [None]:
import os
import shutil

def load_SICE_dataset(src_dir, dst_dir_low, dst_dir_high):
    os.makedirs(dst_dir_low, exist_ok=True)
    os.makedirs(dst_dir_high, exist_ok=True)
    for dir_name in os.listdir(src_dir):
        subdir_path = os.path.join(src_dir, dir_name)
        
        if os.path.isdir(subdir_path) and dir_name not in ["low", "high"]:
            images = [f for f in os.listdir(subdir_path) if f.endswith('.JPG')]

            images.sort(key=lambda f: int(os.path.splitext(f)[0]))

            for img in images[:-1]:
                img_path = os.path.join(subdir_path, img)
                dst_path_low = os.path.join(dst_dir_low, f"{dir_name}_{img}")
                shutil.copy(img_path, dst_path_low)
            
                highest_img_path = os.path.join(subdir_path, images[-1])
                dst_path_high = os.path.join(dst_dir_high, f"{dir_name}_{img}")
                shutil.copy(highest_img_path, dst_path_high)


In [None]:
# Download SICE Part 1 (a RAR archive) from Google Drive and extract it; the next
# cells read ./Dataset_Part1 and then delete the extracted folder and the archive.
# NOTE(review): `pip install unrar` presumably installs the PyPI Python binding,
# while `!unrar x` runs the unrar command-line tool, which must already be on PATH
# — confirm the pip line is needed at all.
!gdown https://drive.google.com/uc?id=1HiLtYiyT9R7dR9DRTLRlUUrAicC4zzWN # SICE Part 1
!pip install unrar
!unrar x Dataset_Part1.rar

In [None]:
# Flatten the scene folders into paired ./data/SICE1/low and ./data/SICE1/high files.
load_SICE_dataset("./Dataset_Part1", "./data/SICE1/low", "./data/SICE1/high")

In [None]:
# Free disk space: only the flattened copies under ./data are used later.
!rm -r Dataset_Part1
!rm Dataset_Part1.rar

In [None]:
# Download and extract SICE Part 2; its flattened copy (./data/SICE2) is only
# used by the evaluation cells.
!gdown https://drive.google.com/uc?id=16VoHNPAZ5Js19zspjFOsKiGRrfkDgHoN # SICE Part 2
!unrar x Dataset_Part2.rar

In [None]:
# Flatten the scene folders into paired ./data/SICE2/low and ./data/SICE2/high files.
load_SICE_dataset("./Dataset_Part2", "./data/SICE2/low", "./data/SICE2/high")

In [None]:
# Remove the extracted folder and the archive; only the copies under ./data are used later.
!rm -r Dataset_Part2
!rm Dataset_Part2.rar

In [None]:
# Download the DICM test archive and unzip it into the current working directory.
# NOTE(review): the evaluation cell globs ./data/DICM/low and ./data/DICM/high, but
# nothing here moves the extracted files under ./data — confirm the expected layout
# (those globs may simply come back empty).
!gdown https://drive.google.com/uc?id=0B_FjaR958nw_eHhvUUN6MzBCQXc # DICM
!unzip datasets__DICM.zip

In [None]:
# Download the LIME test archive and unzip it into the current working directory.
# NOTE(review): same layout caveat as DICM (./data/LIME/low, ./data/LIME/high).
!gdown https://drive.google.com/uc?id=0B_FjaR958nw_QUpMeFlmYW5MUVE # LIME
!unzip datasets__LIME.zip

In [None]:
IMAGE_SIZE = 512          # training resolution (square)
BATCH_SIZE = 8
MAX_TRAIN_IMAGES = 2422   # first N sorted SICE1 low images go to training, the rest to validation


def load_data(image_path):
    """Read one image file and return it as a float32 RGB tensor scaled to [0, 1].

    Args:
        image_path: scalar string tensor holding a file path.

    Returns:
        Tensor of shape (IMAGE_SIZE, IMAGE_SIZE, 3). The aspect ratio is not
        preserved.
    """
    image = tf.io.read_file(image_path)
    # The SICE files are JPEGs. decode_image picks the decoder from the file
    # contents instead of relying on the PNG decoder to accept JPEG data.
    # expand_animations=False keeps the result 3-D, so resize sees a known rank.
    image = tf.io.decode_image(image, channels=3, expand_animations=False)
    image = tf.image.resize(images=image, size=[IMAGE_SIZE, IMAGE_SIZE])
    return image / 255.0


def data_generator(low_light_images):
    """Build a batched dataset of preprocessed images from a list of file paths.

    Each path is loaded with ``load_data``. Batches hold ``BATCH_SIZE`` images,
    and an incomplete final batch is dropped.
    """
    # An explicit string dtype keeps an empty path list valid: a bare [] is
    # converted to a float32 tensor, which tf.io.read_file rejects.
    paths = tf.constant(list(low_light_images), dtype=tf.string)
    dataset = tf.data.Dataset.from_tensor_slices(paths)
    dataset = dataset.map(load_data, num_parallel_calls=tf.data.AUTOTUNE)
    return dataset.batch(BATCH_SIZE, drop_remainder=True)


# List the directory once so the train/val split is taken from a single listing.
sice1_low_light_images = sorted(glob("./data/SICE1/low/*"))
train_low_light_images = sice1_low_light_images[:MAX_TRAIN_IMAGES]
val_low_light_images = sice1_low_light_images[MAX_TRAIN_IMAGES:]


train_dataset = data_generator(train_low_light_images)
val_dataset = data_generator(val_low_light_images)

print("Train Dataset:", train_dataset)
print("Validation Dataset:", val_dataset)

Train Dataset: <_BatchDataset element_spec=TensorSpec(shape=(8, 512, 512, 3), dtype=tf.float32, name=None)>
Validation Dataset: <_BatchDataset element_spec=TensorSpec(shape=(8, 512, 512, 3), dtype=tf.float32, name=None)>


In [None]:
!pip install tensorflow-addons
import tensorflow_addons as tfa

In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Layer, Multiply, Reshape, Permute, Lambda
from tensorflow.keras import backend as K
from tensorflow.keras.initializers import Zeros, Ones
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input

class SALayer(layers.Layer):
    """Shuffle-attention block (modelled on SA-Net) for channels-last (NHWC) tensors.

    The channels are split into ``groups`` contiguous sub-groups. Each
    sub-group is halved along the *channel* axis into two branches:

    * a channel-attention branch: global average pool -> per-channel affine
      (``cweight``, ``cbias``) -> sigmoid gate;
    * a spatial-attention branch: group norm -> per-channel affine
      (``sweight``, ``sbias``) -> sigmoid gate.

    The gated halves are concatenated and the sub-groups are merged back.
    Finally the channels are shuffled with 2 groups, so the two branches are
    interleaved.

    Args:
        channel: nominal channel count, kept for API compatibility. Weight
            shapes are derived from the actual input channels in ``build``,
            because ``ghost_module`` returns ``2 * (output_channels // 4)``
            channels rather than the value callers pass here.
        groups: number of channel sub-groups; ``None`` means 1. The input
            channel count must be divisible by ``2 * groups``.

    Note:
        Splitting along channels (not height) changes the weight shapes, so
        weights saved from the height-splitting version cannot be loaded.
    """

    def __init__(self, channel, groups=None, **kwargs):
        super(SALayer, self).__init__(**kwargs)
        self.channel = channel
        self.groups = 1 if groups is None else groups
        self.avg_pool = layers.GlobalAveragePooling2D(keepdims=True)
        self.sigmoid = layers.Activation("sigmoid")

    def build(self, input_shape):
        channels = input_shape[-1]
        if channels is None:
            raise ValueError("SALayer requires a statically known channel dimension.")
        channels = int(channels)
        if channels % (2 * self.groups):
            raise ValueError(
                f"SALayer: input channels ({channels}) must be divisible by "
                f"2 * groups ({2 * self.groups})."
            )
        half = channels // (2 * self.groups)  # channels per branch per group
        self.cweight = self.add_weight(name="cweight", shape=(1, 1, 1, half), initializer="zeros", trainable=True)
        self.cbias = self.add_weight(name="cbias", shape=(1, 1, 1, half), initializer="ones", trainable=True)
        self.sweight = self.add_weight(name="sweight", shape=(1, 1, 1, half), initializer="zeros", trainable=True)
        self.sbias = self.add_weight(name="sbias", shape=(1, 1, 1, half), initializer="ones", trainable=True)
        # groups == channels of the spatial branch, i.e. per-channel normalisation.
        self.gn = tfa.layers.GroupNormalization(groups=half)
        super().build(input_shape)

    @staticmethod
    def channel_shuffle(x, groups):
        """Interleave ``groups`` contiguous channel blocks of an NHWC tensor.

        For groups=2 and channels [a0, a1, b0, b1] the result is [a0, b0, a1, b1].
        """
        shape = tf.shape(x)
        b, h, w = shape[0], shape[1], shape[2]
        c = x.shape[-1]
        if c is None:
            c = shape[3]
        x = tf.reshape(x, [b, h, w, groups, c // groups])
        x = tf.transpose(x, [0, 1, 2, 4, 3])
        return tf.reshape(x, [b, h, w, c])

    def call(self, inputs):
        shape = tf.shape(inputs)
        b, h, w = shape[0], shape[1], shape[2]
        c = inputs.shape[-1]  # static, checked in build()
        g = self.groups
        c_g = c // g

        # [b, h, w, c] -> [b * g, h, w, c // g]: each contiguous block of
        # c // g channels becomes its own batch entry, without mixing pixels.
        x = tf.reshape(inputs, [b, h, w, g, c_g])
        x = tf.transpose(x, [0, 3, 1, 2, 4])
        x = tf.reshape(x, [b * g, h, w, c_g])

        # NHWC: channels are the last axis (axis=1 would split the image height
        # and fail on odd heights).
        x_0, x_1 = tf.split(x, num_or_size_splits=2, axis=-1)

        xn = self.avg_pool(x_0)
        xn = self.cweight * xn + self.cbias
        xn = x_0 * self.sigmoid(xn)

        xs = self.gn(x_1)
        xs = self.sweight * xs + self.sbias
        xs = x_1 * self.sigmoid(xs)

        out = tf.concat([xn, xs], axis=-1)

        # Undo the grouping: [b * g, h, w, c // g] -> [b, h, w, c].
        out = tf.reshape(out, [b, g, h, w, c_g])
        out = tf.transpose(out, [0, 2, 3, 1, 4])
        out = tf.reshape(out, [b, h, w, c])
        # Interleave the channel-attention and spatial-attention halves.
        return self.channel_shuffle(out, 2)


def ghost_module(input_tensor, output_channels):
    """Lightweight feature block: a 1x1 conv followed by a 3x3 conv applied to its output.

    Both convolutions use ``output_channels // 4`` filters, ReLU and "same"
    padding. Their outputs are concatenated on the channel axis, so the block
    returns ``2 * (output_channels // 4)`` channels, which is half of
    ``output_channels`` when that is divisible by 4.
    """
    branch_filters = output_channels // 4
    primary = layers.Conv2D(
        branch_filters, (1, 1), strides=(1, 1), activation="relu", padding="same"
    )(input_tensor)
    cheap = layers.Conv2D(
        branch_filters, (3, 3), strides=(1, 1), activation="relu", padding="same"
    )(primary)
    merged = layers.Concatenate()([primary, cheap])
    return merged

def build_dce_net():
    """Build the SA DCE-Net functional model.

    Seven ghost-module + SALayer stages are chained, with U-Net-style
    concatenation skips (stage 4+3, 5+2, 6+1). A final tanh SeparableConv2D
    predicts a 3-channel curve map in [-1, 1] for an input of any spatial size.
    """

    def ghost_sa(features, channels):
        # Same creation/call order as ghost_module(...) followed by SALayer(...).
        ghosted = ghost_module(features, channels)
        return SALayer(channels)(ghosted)

    input_img = layers.Input(shape=(None, None, 3))
    sa1 = ghost_sa(input_img, 32)
    # NOTE(review): `3 // 2` evaluates to 1, so this keeps only the first channel
    # of sa1 despite the "half" name. If half of sa1's channels was intended, use
    # something like x.shape[-1] // 2 — confirm first, because it changes the final
    # conv's input width and any saved weights.
    half_ghost1 = layers.Lambda(lambda x: x[:, :, :, :3 // 2])(sa1)
    sa2 = ghost_sa(sa1, 16)
    sa3 = ghost_sa(sa2, 16)
    sa4 = ghost_sa(sa3, 16)
    sa5 = ghost_sa(layers.Concatenate(axis=-1)([sa4, sa3]), 16)
    sa6 = ghost_sa(layers.Concatenate(axis=-1)([sa5, sa2]), 16)
    sa7 = ghost_sa(layers.Concatenate(axis=-1)([sa6, sa1]), 32)
    fused = layers.Concatenate(axis=-1)([sa7, half_ghost1])
    curve_map = layers.SeparableConv2D(
        3, (3, 3), strides=(1, 1), activation="tanh", padding="same"
    )(fused)
    return keras.Model(inputs=input_img, outputs=curve_map)

In [None]:
def color_constancy_loss(x, epsilon=1e-12):
    """Per-image colour-constancy penalty.

    Args:
        x: image batch whose axis 3 holds R, G, B (assumed (batch, H, W, 3)).
        epsilon: small constant added under the square root to keep the
            gradient finite when all channel means coincide.

    Returns:
        Tensor of shape (batch, 1, 1) equal to
        ``sqrt(D_rg**2 + D_rb**2 + D_gb**2 + epsilon)``, where
        ``D_pq = (mean_p - mean_q)**2`` uses spatial channel means.
    """
    mean_rgb = tf.reduce_mean(x, axis=(1, 2), keepdims=True)
    mr, mg, mb = mean_rgb[:, :, :, 0], mean_rgb[:, :, :, 1], mean_rgb[:, :, :, 2]
    d_rg = tf.square(mr - mg)
    d_rb = tf.square(mr - mb)
    d_gb = tf.square(mb - mg)
    # d/ds sqrt(s) is infinite at s == 0 (e.g. a gray image with R == G == B).
    # The chain rule then produces 0 * inf = NaN gradients without epsilon.
    return tf.sqrt(tf.square(d_rg) + tf.square(d_rb) + tf.square(d_gb) + epsilon)


In [None]:
def exposure_loss(x, mean_val=0.6):
    """Mean squared gap between local brightness and the target exposure level.

    The channels (axis 3) are averaged into one intensity map, which is
    average-pooled over non-overlapping 16x16 patches. The result is the mean
    of ``(patch_mean - mean_val) ** 2`` as a scalar.
    """
    intensity = tf.reduce_mean(x, axis=3, keepdims=True)
    patch_means = tf.nn.avg_pool2d(intensity, ksize=16, strides=16, padding="VALID")
    deviation = tf.square(patch_means - mean_val)
    return tf.reduce_mean(deviation)


In [None]:
def illumination_smoothness_loss(x):
    """Total-variation smoothness penalty on the curve-parameter map.

    Squared differences between vertically and horizontally adjacent pixels
    are summed over the batch and all channels. Each sum is divided by the
    number of neighbour pairs in one channel plane: (H - 1) * W vertical and
    H * (W - 1) horizontal. The result is
    ``2 * (h_tv / count_h + w_tv / count_w) / batch_size``.

    Assumes x is channels-last (batch, height, width, channels) — TODO confirm
    for any new caller.

    NOTE: axes 1/2 are height/width in NHWC. Using axes 2/3 (width, channels)
    as if the tensor were NCHW made count_h = (W - 1) * C. Fixing that lowers
    this term's scale, so loss weights tuned against the old value may need
    revisiting.
    """
    shape = tf.shape(x)
    batch_size, h_x, w_x = shape[0], shape[1], shape[2]
    count_h = (h_x - 1) * w_x
    count_w = h_x * (w_x - 1)
    h_tv = tf.reduce_sum(tf.square(x[:, 1:, :, :] - x[:, : h_x - 1, :, :]))
    w_tv = tf.reduce_sum(tf.square(x[:, :, 1:, :] - x[:, :, : w_x - 1, :]))
    batch_size = tf.cast(batch_size, dtype=tf.float32)
    count_h = tf.cast(count_h, dtype=tf.float32)
    count_w = tf.cast(count_w, dtype=tf.float32)
    return 2 * (h_tv / count_h + w_tv / count_w) / batch_size


In [None]:
class SpatialConsistencyLoss(keras.losses.Loss):
    """Spatial-consistency loss between an original and an enhanced image.

    Both images are reduced to a single intensity channel (mean over axis 3)
    and average-pooled over 4x4 regions. For every pooled region, the
    difference "centre minus neighbour" is taken towards the left, right,
    upper and lower neighbour with 3x3 kernels. The loss is the squared
    mismatch of those differences between the two images, summed over the
    four directions.

    The reduction is always "none": the result has shape
    (batch, floor(H/4), floor(W/4), 1) and callers reduce it themselves.
    """

    def __init__(self, **kwargs):
        # The reduction is fixed to "none". Only ``name`` is forwarded; any
        # other kwargs (including ``reduction``) are ignored, as before.
        super().__init__(reduction="none", name=kwargs.get("name"))

        # tf.nn.conv2d filters are [height, width, in_channels, out_channels].
        # A nested literal such as [[[[0, 0, 0]], [[-1, 1, 0]], [[0, 0, 0]]]]
        # has shape (1, 3, 1, 3): a 1x3 filter with three output channels,
        # which does not compute a neighbour difference. The 3x3 grids are
        # therefore reshaped explicitly to (3, 3, 1, 1).
        self.left_kernel = self._difference_kernel([[0, 0, 0], [-1, 1, 0], [0, 0, 0]])
        self.right_kernel = self._difference_kernel([[0, 0, 0], [0, 1, -1], [0, 0, 0]])
        self.up_kernel = self._difference_kernel([[0, -1, 0], [0, 1, 0], [0, 0, 0]])
        self.down_kernel = self._difference_kernel([[0, 0, 0], [0, 1, 0], [0, -1, 0]])

    @staticmethod
    def _difference_kernel(rows):
        """Turn a 3x3 nested list (rows = image rows) into a single-channel conv filter."""
        return tf.reshape(tf.constant(rows, dtype=tf.float32), (3, 3, 1, 1))

    @staticmethod
    def _pooled_intensity(images):
        """Mean over channels, then 4x4 average pooling with stride 4."""
        intensity = tf.reduce_mean(images, 3, keepdims=True)
        return tf.nn.avg_pool2d(intensity, ksize=4, strides=4, padding="VALID")

    def call(self, y_true, y_pred):
        original_pool = self._pooled_intensity(y_true)
        enhanced_pool = self._pooled_intensity(y_pred)

        terms = []
        for kernel in (self.left_kernel, self.right_kernel, self.up_kernel, self.down_kernel):
            d_original = tf.nn.conv2d(original_pool, kernel, strides=[1, 1, 1, 1], padding="SAME")
            d_enhanced = tf.nn.conv2d(enhanced_pool, kernel, strides=[1, 1, 1, 1], padding="SAME")
            terms.append(tf.square(d_original - d_enhanced))
        return tf.add_n(terms)


In [None]:
class SAZeroDCETiny(keras.Model):
    """Zero-reference low-light enhancement model built around the SA DCE-Net.

    ``dce_model`` maps an image batch to a curve-parameter map ``r``. The
    enhanced image is obtained by applying ``x <- x + r * (x**2 - x)`` eight
    times. Training minimises the weighted sum of four no-reference losses:
    illumination smoothness x200, spatial consistency x1, colour constancy x5
    and exposure x10. Only ``dce_model`` carries weights, so weight
    saving/loading is forwarded to it.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.dce_model = build_dce_net()

    def compile(self, learning_rate, **kwargs):
        """Compile, then install Adam(``learning_rate``) and the spatial-consistency loss."""
        super().compile(**kwargs)
        self.optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
        self.spatial_constancy_loss = SpatialConsistencyLoss(reduction="none")

    def get_enhanced_image(self, data, output):
        """Apply the curve map ``output`` to ``data`` for 8 iterations."""
        curve_map = output
        current = data
        for _step in range(8):
            current = current + curve_map * (tf.square(current) - current)
        return current

    def call(self, data):
        """Predict the curve map for ``data`` and return the enhanced image."""
        curve_map = self.dce_model(data)
        return self.get_enhanced_image(data, curve_map)

    def compute_losses(self, data, output):
        """Return the weighted loss terms and their sum for curve map ``output``."""
        enhanced = self.get_enhanced_image(data, output)
        illumination = 200 * illumination_smoothness_loss(output)
        spatial = tf.reduce_mean(self.spatial_constancy_loss(enhanced, data))
        color = 5 * tf.reduce_mean(color_constancy_loss(enhanced))
        exposure = 10 * tf.reduce_mean(exposure_loss(enhanced))
        return {
            "total_loss": illumination + spatial + color + exposure,
            "illumination_smoothness_loss": illumination,
            "spatial_constancy_loss": spatial,
            "color_constancy_loss": color,
            "exposure_loss": exposure,
        }

    def train_step(self, data):
        """Run one optimisation step on a batch of low-light images; return the loss dict."""
        with tf.GradientTape() as tape:
            losses = self.compute_losses(data, self.dce_model(data))
        variables = self.dce_model.trainable_weights
        gradients = tape.gradient(losses["total_loss"], variables)
        self.optimizer.apply_gradients(zip(gradients, variables))
        return losses

    def test_step(self, data):
        """Evaluate the loss terms on a batch without updating weights."""
        return self.compute_losses(data, self.dce_model(data))

    def save_weights(self, filepath, overwrite=True, save_format=None, options=None):
        """Save only the DCE-Net weights (the wrapper itself owns none)."""
        self.dce_model.save_weights(
            filepath, overwrite=overwrite, save_format=save_format, options=options
        )

    def load_weights(self, filepath, by_name=False, skip_mismatch=False, options=None):
        """Load weights into the DCE-Net (counterpart of ``save_weights``)."""
        self.dce_model.load_weights(
            filepath=filepath,
            by_name=by_name,
            skip_mismatch=skip_mismatch,
            options=options,
        )


In [None]:
# The model class defined above is SAZeroDCETiny; `ZeroDCE` is not defined in this notebook.
zero_dce_model = SAZeroDCETiny()
zero_dce_model.compile(learning_rate=1e-4)
history = zero_dce_model.fit(train_dataset, validation_data=val_dataset, epochs=100)


def plot_result(item):
    """Plot the training and validation curves of the logged metric ``item`` from ``history``."""
    plt.plot(history.history[item], label=item)
    plt.plot(history.history["val_" + item], label="val_" + item)
    plt.xlabel("Epochs")
    plt.ylabel(item)
    plt.title("Train and Validation {} Over Epochs".format(item), fontsize=14)
    plt.legend()
    plt.grid()
    plt.show()


for metric_name in (
    "total_loss",
    "illumination_smoothness_loss",
    "spatial_constancy_loss",
    "color_constancy_loss",
    "exposure_loss",
):
    plot_result(metric_name)


In [None]:
# zero_dce_model.build((None, 256, 256, 3))
# zero_dce_model.summary()

In [None]:
def plot_results(images, titles, figure_size=(12, 12)):
    """Show ``images`` side by side in one row, titling the i-th panel with ``titles[i]``."""
    fig = plt.figure(figsize=figure_size)
    panel_count = len(images)
    for position, image in enumerate(images, start=1):
        axis = fig.add_subplot(1, panel_count, position)
        axis.set_title(titles[position - 1])
        plt.imshow(image)
        plt.axis("off")
    plt.show()


def infer(original_image):
    """Enhance a PIL image at its native resolution with the trained ``zero_dce_model``.

    Args:
        original_image: PIL image. It is converted to RGB first, so grayscale,
            palette or RGBA inputs still give the 3 channels the network expects.

    Returns:
        The enhanced image as an 8-bit RGB ``PIL.Image`` of the same size.
    """
    rgb_image = original_image.convert("RGB")
    image = keras.preprocessing.image.img_to_array(rgb_image)
    image = image.astype("float32") / 255.0
    image = np.expand_dims(image, axis=0)
    output_image = zero_dce_model(image)
    # Clip before the uint8 cast so float round-off outside [0, 1] cannot
    # produce out-of-range byte values.
    output_image = tf.clip_by_value(output_image[0, :, :, :], 0.0, 1.0) * 255
    output_image = tf.cast(output_image, dtype=np.uint8)
    return Image.fromarray(output_image.numpy())


In [None]:
# Qualitative check: original vs. PIL autocontrast vs. model output for 5 validation images.
for val_image_file in val_low_light_images[:5]:
    original_image = Image.open(val_image_file)
    enhanced_image = infer(original_image)
    comparison = [original_image, ImageOps.autocontrast(original_image), enhanced_image]
    comparison_titles = ["Original", "PIL Autocontrast", "Enhanced"]
    plot_results(comparison, comparison_titles, figure_size=(20, 12))


In [None]:
# (dataset directory, max number of pairs taken from it). Each directory must have
# low/ and high/ sub-folders whose sorted file lists pair up one-to-one by position.
# NOTE(review): DICM and LIME are presumably distributed without normal-light
# references; confirm ./data/DICM/high and ./data/LIME/high really exist.
TEST_SOURCES = [("./data/SICE2", 2300), ("./data/DICM", 64), ("./data/LIME", 10)]

test_low_light_images = []
test_high_light_images = []
for source_dir, max_pairs in TEST_SOURCES:
    source_low = sorted(glob(os.path.join(source_dir, "low", "*")))[:max_pairs]
    source_high = sorted(glob(os.path.join(source_dir, "high", "*")))[:max_pairs]
    # A count mismatch would shift every later low/high pair, or overrun the high
    # list in the metrics loop, so fail early with a clear message instead.
    if len(source_low) != len(source_high):
        raise ValueError(
            f"{source_dir}: {len(source_low)} low vs {len(source_high)} high images"
        )
    test_low_light_images += source_low
    test_high_light_images += source_high


In [None]:
# These functions live in skimage.metrics; skimage.measure does not provide them.
from skimage.metrics import structural_similarity, peak_signal_noise_ratio
from tensorflow.keras.utils import img_to_array
import numpy as np

# Value span of 8-bit images; used as ``data_range`` for PSNR and SSIM.
PIXEL_DATA_RANGE = 255.0


def calculate_psnr(img1, img2, data_range=PIXEL_DATA_RANGE):
    """Return the PSNR (dB) of ``img2`` against the reference ``img1``.

    Args:
        img1: reference image (PIL image or array accepted by ``img_to_array``).
        img2: test image with the same shape as ``img1``.
        data_range: span of possible pixel values, 255 for 8-bit images. A
            per-image ``max - min`` range makes scores incomparable across
            images and gives -inf for a flat reference.
    """
    img1 = img_to_array(img1)
    img2 = img_to_array(img2)
    return peak_signal_noise_ratio(img1, img2, data_range=data_range)


def calculate_ssim(img1, img2, data_range=PIXEL_DATA_RANGE):
    """Return the mean SSIM between two same-shaped channels-last colour images.

    ``data_range`` must be given explicitly: ``img_to_array`` yields float
    arrays in [0, 255].
    """
    img1 = img_to_array(img1)
    img2 = img_to_array(img2)
    try:
        return structural_similarity(img1, img2, channel_axis=-1, data_range=data_range)
    except TypeError:
        # Older scikit-image releases (< 0.19) have no ``channel_axis`` keyword.
        return structural_similarity(img1, img2, multichannel=True, data_range=data_range)


def calculate_mae(img1, img2):
    """Return the mean absolute pixel difference (0-255 scale) between two same-shaped images."""
    img1 = img_to_array(img1)
    img2 = img_to_array(img2)
    return np.mean(np.abs(img1 - img2))

In [None]:
total_psnr, total_ssim, total_mae = 0, 0, 0
num_images = len(test_low_light_images)
if num_images != len(test_high_light_images):
    raise ValueError(
        f"{num_images} low-light vs {len(test_high_light_images)} reference images; lists must pair up"
    )

for idx, val_image_file in enumerate(test_low_light_images):
    original_low_image = Image.open(val_image_file)
    # Enhance the image loaded on this iteration, not a leftover global from an
    # earlier cell; otherwise every pair would be scored against the same output.
    enhanced_image = infer(original_low_image)
    original_high_image = Image.open(test_high_light_images[idx])

    psnr = calculate_psnr(original_high_image, enhanced_image)
    ssim = calculate_ssim(original_high_image, enhanced_image)
    mae = calculate_mae(original_high_image, enhanced_image)

    total_psnr += psnr
    total_ssim += ssim
    total_mae += mae

if num_images:
    avg_psnr = total_psnr / num_images
    avg_ssim = total_ssim / num_images
    avg_mae = total_mae / num_images
else:
    # No test pairs were found; report NaN rather than dividing by zero.
    avg_psnr = avg_ssim = avg_mae = float("nan")

print("Average PSNR:", avg_psnr)
print("Average SSIM:", avg_ssim)
print("Average MAE:", avg_mae)


In [None]:
# SAZeroDCETiny.save_weights stores only the inner DCE-Net weights.
model_save_path = "sa_zero_dce_tiny_model.h5"
zero_dce_model.save_weights(filepath=model_save_path, overwrite=True)

In [None]:
import tensorflow as tf

# Convert the full model (DCE-Net + the 8 curve iterations in call()) to a TFLite
# flatbuffer with the default optimisation setting.
tflite_converter = tf.lite.TFLiteConverter.from_keras_model(zero_dce_model)
tflite_converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter = tflite_converter
tflite_model = tflite_converter.convert()


In [None]:
# Persist the serialized model produced by the converter.
with open("sa_zero_dce_tiny_model.tflite", mode="wb") as tflite_file:
    tflite_file.write(tflite_model)


In [None]:
import numpy as np
import tensorflow as tf

# Load the file written by the export cell above; the previous path
# ('sa_zero_dce_tiny_512x512.tflite') is never created in this notebook.
TFLITE_MODEL_PATH = "sa_zero_dce_tiny_model.tflite"
interpreter = tf.lite.Interpreter(model_path=TFLITE_MODEL_PATH)

interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()


In [None]:
import numpy as np
from PIL import Image
import tensorflow as tf
from skimage.transform import resize

def infer_tflite(original_image, interpreter, size=(512, 512)):
    """Enhance a PIL image with a TFLite interpreter.

    Args:
        original_image: PIL image; converted to RGB and scaled to [0, 1].
        interpreter: allocated ``tf.lite.Interpreter`` whose first input takes
            a float32 batch of shape (1, size[0], size[1], 3).
        size: (height, width) to resize to before inference; 512x512 matches
            the exported model's training resolution.

    Returns:
        The enhanced image as an 8-bit RGB ``PIL.Image`` at ``size`` resolution.
    """
    image = np.asarray(original_image.convert("RGB"), dtype=np.float32) / 255.0

    # Resize the 3-D (H, W, 3) array *before* adding the batch axis. With a
    # 4-D (1, H, W, 3) input, skimage appends the trailing axis to the 3-D
    # output shape and returns (H', W', 3, 3).
    image = resize(image, (size[0], size[1], 3))

    # skimage returns float64; the TFLite input tensor is float32.
    image = np.expand_dims(image.astype(np.float32), axis=0)

    interpreter.set_tensor(interpreter.get_input_details()[0]['index'], image)

    interpreter.invoke()

    output_image = interpreter.get_tensor(interpreter.get_output_details()[0]['index'])

    # Clip so float round-off outside [0, 1] cannot wrap when cast to uint8.
    output_image = (np.clip(output_image[0], 0.0, 1.0) * 255).astype(np.uint8)

    return Image.fromarray(output_image)


In [None]:
# Qualitative check of the TFLite model on the first test image.
for val_image_file in test_low_light_images[:1]:
    original_image = Image.open(val_image_file)
    enhanced_image = infer_tflite(original_image, interpreter)
    panels = [original_image, ImageOps.autocontrast(original_image), enhanced_image]
    panel_titles = ["Original", "PIL Autocontrast", "Enhanced"]
    plot_results(panels, panel_titles, figure_size=(20, 12))
