In [5]:
import numpy as np 
import os
import glob
import skimage.io as io
import skimage.transform as trans
import tensorflow as tf
from keras import *
from keras import backend as K
from keras.callbacks import *
from keras.layers import *
from keras.models import *
from keras.optimizers import *
from keras.preprocessing.image import *
from keras.losses import *
import cv2

from tensorflow.python.compat import compat
from tensorflow.python.framework import *
from tensorflow.python.ops import *
from tensorflow.python.util import *
import matplotlib.pyplot as plt


In [6]:
## PATHS AND EXPERIMENT INFO ##
train_data_image_folder = 'lol_dataset/train/high'
train_data_mask_folder = 'lol_dataset/train/low'
val_data_image_folder = 'lol_dataset/test/high'
val_data_mask_folder = 'lol_dataset/test/low'
model_folder = 'model/'

IMG_HEIGHT = 256
IMG_WIDTH = 256
NUM_EPOCHS = 60
BATCH_SIZE = 2

In [7]:
# Limit GPU memory growth - add this after imports
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        print("GPU memory growth enabled")
    except RuntimeError as e:
        print(f"Error setting memory growth: {e}")

GPU memory growth enabled


In [8]:
## LOAD TRAIN AND VALIDATION IMAGES AND RESPECTIVE MASKS ##
train_images = []
train_masks = []
val_images = []
val_masks = []

train_files = os.listdir(train_data_image_folder)
train_files = sorted([file for file in train_files if file.endswith(".png")])

test_files = os.listdir(val_data_image_folder)
test_files = sorted([file for file in test_files if file.endswith(".png")])

print(train_files)
print(test_files)

# LOAD TRAIN IMAGES
for image in train_files:

  print(image)

  train_img = cv2.imread(os.path.join(train_data_image_folder, image))
  train_img = cv2.resize(train_img, (IMG_HEIGHT, IMG_WIDTH), interpolation = cv2.INTER_CUBIC)
  train_images.append(train_img)

  train_mask = cv2.imread(os.path.join(train_data_mask_folder, image))
  train_mask = cv2.resize(train_mask, (IMG_HEIGHT, IMG_WIDTH), interpolation = cv2.INTER_CUBIC)
  train_masks.append(train_mask)

# LOAD VAL IMAGES
for image in test_files:

  print(image)

  val_img = cv2.imread(os.path.join(val_data_image_folder, image))
  val_img = cv2.resize(val_img, (IMG_HEIGHT, IMG_WIDTH), interpolation = cv2.INTER_CUBIC)
  val_images.append(val_img)

  val_mask = cv2.imread(os.path.join(val_data_mask_folder, image))
  val_mask = cv2.resize(val_mask, (IMG_HEIGHT, IMG_WIDTH), interpolation = cv2.INTER_CUBIC)
  val_masks.append(val_mask)

# CREATE THE ARRAYS
train_image_array = np.array(train_images)
train_mask_array = np.array(train_masks)
val_image_array = np.array(val_images)
val_mask_array = np.array(val_masks)

print(train_image_array.shape,train_mask_array.shape)
print(val_image_array.shape,val_mask_array.shape)

['10.png', '100.png', '101.png', '102.png', '103.png', '104.png', '105.png', '106.png', '107.png', '109.png', '110.png', '112.png', '113.png', '114.png', '115.png', '116.png', '117.png', '118.png', '119.png', '12.png', '120.png', '121.png', '122.png', '123.png', '124.png', '125.png', '126.png', '127.png', '128.png', '129.png', '13.png', '130.png', '131.png', '132.png', '135.png', '136.png', '137.png', '138.png', '139.png', '14.png', '140.png', '141.png', '142.png', '143.png', '144.png', '145.png', '147.png', '149.png', '15.png', '150.png', '151.png', '152.png', '154.png', '157.png', '159.png', '16.png', '160.png', '162.png', '167.png', '169.png', '17.png', '171.png', '172.png', '173.png', '174.png', '175.png', '176.png', '18.png', '180.png', '183.png', '184.png', '185.png', '186.png', '187.png', '188.png', '189.png', '191.png', '193.png', '194.png', '195.png', '196.png', '198.png', '199.png', '2.png', '200.png', '201.png', '202.png', '203.png', '204.png', '206.png', '207.png', '209.png

In [9]:
## DEFINE METRICS AND LOSS COMPONENTS ##
# Define VGG blocks at the module level, outside of any function
from keras.applications.vgg16 import VGG16
import keras.backend as K

# Create VGG model with blocks outside of any function
vgg_model = VGG16(include_top=False, weights='imagenet', input_shape=(256, 256, 3))
loss_block1 = Model(inputs=vgg_model.input, outputs=vgg_model.get_layer('block1_conv2').output)
loss_block2 = Model(inputs=vgg_model.input, outputs=vgg_model.get_layer('block2_conv2').output)
loss_block3 = Model(inputs=vgg_model.input, outputs=vgg_model.get_layer('block3_conv3').output)
# Ensure blocks are not trainable
loss_block1.trainable = False
loss_block2.trainable = False
loss_block3.trainable = False

# PEAK SIGNAL TO NOISE RATIO
@tf.function
def psnr(y_true, y_pred):
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)
    return tf.image.psnr(y_true, y_pred, max_val=255.0)

# STRUCTURAL SIMILARITY INDEX
@tf.function
def ssim(y_true, y_pred):
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)
    return tf.image.ssim(y_true, y_pred, max_val=255.0)

# STRUCTURAL LOSS
@tf.function
def ssim_loss(y_true, y_pred):
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)
    return 1-tf.image.ssim(y_true, y_pred, max_val=255.0)

# ABSOLUTE BRIGHTNESS
@tf.function
def ab(y_true, y_pred):
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)
    return tf.abs(tf.reduce_mean(y_true[:,:,:,:3])-tf.reduce_mean(y_pred[:,:,:,:3]))
# PERCEPTUAL LOSS
def per_loss_vgg(img_true, img_generated):
    # Convert both inputs to float32
    img_true = tf.cast(img_true, tf.float32)
    img_generated = tf.cast(img_generated, tf.float32)
    
    normalisation = 15 * 256 * 256
    
    # Use the pre-defined models
    loss = K.mean(K.square(img_true - img_generated)) + \
           2 * K.mean(K.square(loss_block1(img_true) - loss_block1(img_generated))) + \
           4 * K.mean(K.square(loss_block2(img_true) - loss_block2(img_generated))) + \
           8 * K.mean(K.square(loss_block3(img_true) - loss_block3(img_generated)))
    
    return loss/normalisation

# WEIGHTED PATCH-WISE EUCLIDEAN LOSS
@tf.function
def wpw(y_true, y_pred, weight=4, percentage=0.25, patches_per_row=16):
    # Convert both inputs to float32
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)
    
    gray_org = 0.39 * y_pred[:, :, :, 0] + 0.5 * y_pred[:, :, :, 1] + 0.11 * y_pred[:, :, :, 2]
    gray_true = 0.39 * y_true[:, :, :, 0] + 0.5 * y_true[:, :, :, 1] + 0.11 * y_true[:, :, :, 2]
    gray = tf.expand_dims(gray_org, -1)
    
    # Using static values to avoid dynamic shapes
    patch_length = int(256/patches_per_row)  # Hard-coded IMG_HEIGHT/patches_per_row
    no_of_patches = patches_per_row*patches_per_row
    no_of_patches_to_consider = int(no_of_patches * percentage)
    normalization_factor = int(no_of_patches*((weight - 1)*percentage + 1)*256*256)
    
    filter_of_ones = tf.ones([patch_length,patch_length,1,1], tf.float32)
    
    strides_val = [1, patch_length, patch_length, 1]
    sum_of_patches = tf.nn.conv2d(gray, filter_of_ones, strides=strides_val, padding='SAME')
    
    sorted_sums = tf.sort(tf.reshape(sum_of_patches, [-1]))
    threshold_sum = sorted_sums[no_of_patches_to_consider]
    
    mask = tf.cast(sum_of_patches <= threshold_sum, tf.float32)
    
    weighted_mask_per_channel = tf.add(tf.multiply(float(weight), mask), 
                                      tf.subtract(float(1), mask))
    squared_loss = tf.square(gray_org - gray_true)
    squared_loss = tf.expand_dims(squared_loss, -1)
    
    sum_of_squared_loss_patches = tf.nn.conv2d(squared_loss, filter_of_ones, 
                                              strides=strides_val, padding='SAME')
    
    loss = tf.reduce_sum(tf.multiply(weighted_mask_per_channel, sum_of_squared_loss_patches))
    return loss/normalization_factor

# TOTAL LOSS
@tf.function
def total_loss(y_true, y_pred):
    w_per = 1
    w_ssim = 1
    w_wpw = 0.1
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)
    
    # Calculate each loss component 
    p_loss = per_loss_vgg(y_true, y_pred)
    s_loss = 1 - ssim(y_true, y_pred)
    w_loss = wpw(y_true, y_pred)
    
    total = w_per * p_loss + w_ssim * s_loss + w_wpw * w_loss
    return total


In [10]:
## DEFINE MODEL COMPONENTS ##

def _define_conv_block(
    input_, layers, filters,
    kernel_size=3, padding='same', activation='relu', kernel_initializer='he_normal', **kwargs):
    output_ = Conv2D(filters, kernel_size, activation = activation, padding = padding, kernel_initializer = kernel_initializer, **kwargs)(input_)
    for layer in range(1, layers):
        output_ = Conv2D(filters, kernel_size, activation = activation, padding = padding, kernel_initializer = kernel_initializer, **kwargs)(output_)
    return output_

def _define_sep_conv_block(
    input_, layers, filters,
    kernel_size=3, padding='same', activation='relu', kernel_initializer='he_normal', **kwargs):
    output_ = SeparableConv2D(filters, kernel_size, activation = activation, padding = padding, kernel_initializer = kernel_initializer, **kwargs)(input_)
    for layer in range(1, layers):
        output_ = SeparableConv2D(filters, kernel_size, activation = activation, padding = padding, kernel_initializer = kernel_initializer, **kwargs)(output_)
    return output_

def _define_aspp_block(input_, filters, dilation_rates=list((1, 3, 5, 7)),
                       kernel_size=3, padding='same', activation='relu', kernel_initializer='he_normal', **kwargs):
    num_parallel_outputs = len(dilation_rates)
    parallel_outputs = [None] * num_parallel_outputs
    for output_idx in range(num_parallel_outputs):
        dilation_rate = (dilation_rates[output_idx], dilation_rates[output_idx])
        parallel_outputs[output_idx] = SeparableConv2D(filters, kernel_size, dilation_rate = dilation_rate, activation = activation, padding = padding, kernel_initializer = kernel_initializer, **kwargs)(input_)
    output_ = concatenate(parallel_outputs, axis=3)
    return output_

def encoder_decoder_with_aspp_blocks(input_shape = (256, 256, 3)):
    inputs = Input(input_shape)
    
    block1 = _define_conv_block(inputs, 2, 64)
    pool1 = MaxPooling2D(pool_size=(2, 2))(block1)

    block2 = _define_aspp_block(pool1, 64)
    block2 = _define_sep_conv_block(block2, 2, 64)
    pool2 = MaxPooling2D(pool_size=(2, 2))(block2)
    
    block3 = _define_aspp_block(pool2, 128)
    block3 = _define_sep_conv_block(block3, 2, 128)
    pool3 = MaxPooling2D(pool_size=(2, 2))(block3)
    
    block4 = _define_aspp_block(pool3, 256)
    block4 = _define_sep_conv_block(block4, 2, 256)
    pool4 = MaxPooling2D(pool_size=(2, 2))(block4)

    block5 = _define_aspp_block(pool4, 512)
    block5 = _define_sep_conv_block(block5, 4, 512)
    
    up6 = SeparableConv2D(256, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(block5))
    merge6 = concatenate([block4, up6], axis = 3)
    block6 = _define_sep_conv_block(merge6, 2, 512)

    up7 = SeparableConv2D(128, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(block6))
    merge7 = concatenate([block3, up7], axis = 3)
    block7 = _define_sep_conv_block(merge7, 2, 256)

    up8 = SeparableConv2D(64, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(block7))
    merge8 = concatenate([block2, up8], axis = 3)
    block8 = _define_sep_conv_block(merge8, 2, 128)

    up9 = SeparableConv2D(64, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(block8))
    merge9 = concatenate([block1, up9], axis = 3)    
      
    block9 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge9)
    output = Conv2D(3, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(block9)
    model = Model(inputs = inputs, outputs = output)
    return model

In [11]:
## BUILD MODEL ##

# Model Compile
loss = total_loss

# Configure Optimizer - updated for TF 2.x
learning_rate = 1e-4
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)  # Changed lr= to learning_rate=

# Others
metrics = [psnr, ssim, ab,per_loss_vgg, ssim_loss,wpw]

# Build
model = encoder_decoder_with_aspp_blocks()
model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
model.summary()

Model: "model_3"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_2 (InputLayer)           [(None, 256, 256, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv2d (Conv2D)                (None, 256, 256, 64  1792        ['input_2[0][0]']                
                                )                                                                 
                                                                                                  
 conv2d_1 (Conv2D)              (None, 256, 256, 64  36928       ['conv2d[0][0]']                 
                                )                                                           

In [12]:
## TRAIN ##
if not os.path.exists(model_folder):
  os.makedirs(model_folder)
model_name = 'model_clienet_epoch-{epoch:03d}_val_loss-{val_loss:.4f}_psnr-{val_psnr:.3f}_ssim-{val_ssim:.3f}_ab-{val_ab:.3f}.h5'
checkpointer = ModelCheckpoint(os.path.join(model_folder, model_name), verbose=0, save_best_only=False, save_weights_only = True)



In [13]:
# Update to include normalization
train_image_array = train_image_array.astype('float32') / 255.0
train_mask_array = train_mask_array.astype('float32') / 255.0
val_image_array = val_image_array.astype('float32') / 255.0
val_mask_array = val_mask_array.astype('float32') / 255.0

In [14]:
# Save only best model
checkpointer = ModelCheckpoint(
    os.path.join(model_folder, 'model_clienet_best.h5'), 
    verbose=1, 
    save_best_only=True, 
    save_weights_only=True,
    monitor='val_psnr',
    mode='max'
)

In [15]:
history = model.fit(
    train_mask_array,  # Low-light images (input)
    train_image_array, # High-light images (target)
    batch_size=BATCH_SIZE,
    epochs=NUM_EPOCHS,
    validation_data=(val_mask_array, val_image_array),
    callbacks=[checkpointer],
    shuffle=True,
    verbose=1
)

Epoch 1/60
Epoch 1: val_psnr improved from -inf to 64.87448, saving model to model\model_clienet_best.h5
Epoch 2/60
Epoch 2: val_psnr improved from 64.87448 to 65.49338, saving model to model\model_clienet_best.h5
Epoch 3/60
Epoch 3: val_psnr did not improve from 65.49338
Epoch 4/60
Epoch 4: val_psnr did not improve from 65.49338
Epoch 5/60
Epoch 5: val_psnr improved from 65.49338 to 65.68040, saving model to model\model_clienet_best.h5
Epoch 6/60
Epoch 6: val_psnr did not improve from 65.68040
Epoch 7/60
Epoch 7: val_psnr did not improve from 65.68040
Epoch 8/60
Epoch 8: val_psnr improved from 65.68040 to 65.74101, saving model to model\model_clienet_best.h5
Epoch 9/60
Epoch 9: val_psnr did not improve from 65.74101
Epoch 10/60
Epoch 10: val_psnr improved from 65.74101 to 65.85297, saving model to model\model_clienet_best.h5
Epoch 11/60
Epoch 11: val_psnr did not improve from 65.85297
Epoch 12/60

KeyboardInterrupt: 