In [None]:
import os
import wandb
import numpy as np
from glob import glob
from time import time
import tensorflow as tf
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from PIL import Image, ImageOps

from low_light_config import get_config
from restorers.model.zero_dce import ZeroDCE

In [None]:
# W&B run settings. The `#@param` annotations render these as Colab form
# fields, so each assignment must stay a single `name = value #@param` line.
# NOTE(review): the run is named 'train/lol' but its job type is 'test' —
# confirm the run name is intended for an evaluation run.
wandb_project_name = 'zero-dce' #@param {type:"string"}
wandb_run_name = 'train/lol' #@param {type:"string"}
wandb_entity_name = 'ml-colabs' #@param {type:"string"}
wandb_job_type = 'test' #@param {type:"string"}

# Start the evaluation run and record the experiment config returned by
# `low_light_config.get_config()` on it.
experiment_configs = get_config()
wandb.init(
    project=wandb_project_name,
    name=wandb_run_name,
    entity=wandb_entity_name,
    job_type=wandb_job_type,
    config=experiment_configs.to_dict(),
)

# Address of the trained Zero-DCE model artifact to evaluate. It is stored on
# the run config so it is recorded with this run.
config = wandb.config
config.model_artifact_address = "ml-colabs/zero-dce/run_55hfxg0a_model:v9" #@param {type:"string"}

In [None]:
# Resolve the model artifact, look up the configuration of the run that
# produced it, and fetch the artifact files locally.
artifact = wandb.use_artifact(config.model_artifact_address, type="model")
producer_run = artifact.logged_by()
model_configs = producer_run.config["model_configs"]
model_path = artifact.download()

# Load Model -- compile=False: the model is only used through `predict` below,
# so no optimizer or loss needs to be restored.
model = tf.keras.models.load_model(model_path, compile=False)

In [None]:
dataset_artifact_address = "ml-colabs/dataset/LoL:v0"

# Fetch the LoL dataset from its W&B dataset artifact.
artifact = wandb.use_artifact(dataset_artifact_address, type='dataset')
dataset_dir = artifact.download()

# `our485` is used as the train/val split and `eval15` as the evaluation split.
# Both the low-light and ground-truth listings are sorted so that index i of
# each list refers to the same scene (assumes matching filenames in `low/` and
# `high/`).
train_val_low_light_images = sorted(glob(
    os.path.join(dataset_dir, "our485", "low", "*.png")
))
train_val_ground_truth_images = sorted(glob(
    os.path.join(dataset_dir, "our485", "high", "*.png")
))

test_low_light_images = sorted(glob(
    os.path.join(dataset_dir, "eval15", "low", "*.png")
))
test_ground_truth_images = sorted(glob(
    os.path.join(dataset_dir, "eval15", "high", "*.png")
))

print(
    "Number of low-light images for training and validation:",
    len(train_val_low_light_images)
)
print(
    "Number of ground-truth images for training and validation:",
    len(train_val_ground_truth_images)
)

print(
    "Number of low-light images for evaluation:",
    len(test_low_light_images)
)
print(
    "Number of ground-truth images for evaluation:",
    len(test_ground_truth_images)
)

# The inference loops pair images by index, so every low-light image needs
# exactly one ground-truth counterpart.
assert len(train_val_low_light_images) == len(train_val_ground_truth_images), (
    "Mismatched number of low-light and ground-truth images in our485"
)
assert len(test_low_light_images) == len(test_ground_truth_images), (
    "Mismatched number of low-light and ground-truth images in eval15"
)

In [None]:
def preprocess_image(image):
    """Turns an image into a single-item float32 batch scaled by 1/255.

    Args:
        image: A PIL image (or anything `img_to_array` accepts).

    Returns:
        A numpy array of shape (1, height, width, channels), where channels is
        3 for RGB input. For 8-bit images the values lie in [0, 1].
    """
    pixels = tf.keras.preprocessing.image.img_to_array(image).astype("float32")
    scaled = pixels / 255.0
    # Prepend the batch axis expected by `model.predict`.
    return scaled[np.newaxis, ...]


def postprocess_image(model_output):
    """Converts a batched model prediction into a displayable image.

    Args:
        model_output: Array-like of shape (1, height, width, 3) holding values
            on a [0, 1] scale, as produced by `model.predict` on the output
            of `preprocess_image`. Out-of-range values are clipped.

    Returns:
        A single PIL.Image.Image built from the first item of the batch.
    """
    model_output = np.asarray(model_output)
    height, width = model_output.shape[1], model_output.shape[2]
    pixels = np.clip(model_output[0] * 255.0, 0, 255).reshape((height, width, 3))
    # Round to the nearest intensity level. A plain uint8 cast truncates, which
    # would bias every pixel downwards by up to one level.
    return Image.fromarray(np.round(pixels).astype(np.uint8))


def plot_results(images, titles, figure_size=(12, 12)):
    """Shows `images` side by side in one row, one titled panel per image.

    Args:
        images: Sequence of images accepted by `imshow` (PIL images or arrays).
        titles: Panel titles, indexed in parallel with `images`.
        figure_size: Matplotlib figure size `(width, height)` in inches.
    """
    figure = plt.figure(figsize=figure_size)
    panel_count = len(images)
    for panel_index, image in enumerate(images):
        # add_subplot also makes this panel the current axes for plt.* calls.
        axis = figure.add_subplot(1, panel_count, panel_index + 1)
        axis.set_title(titles[panel_index])
        plt.imshow(image)
        plt.axis("off")
    plt.show()


def infer_and_visualize(
    low_light_image_file,
    ground_truth_image_file,
    model,
    visualize_plots
):
    """Enhances one low-light image with `model` and scores it against ground truth.

    Args:
        low_light_image_file: Path to the low-light input image.
        ground_truth_image_file: Path to the matching ground-truth image.
        model: Keras model; `model.predict` is run on the preprocessed input.
        visualize_plots: If True, plots input, ground truth and prediction
            side by side.

    Returns:
        A tuple `(low_light_image, ground_truth_image, post_processed_image,
        psnr, ssim, inference_time)`:
        - the images are PIL images;
        - `psnr` and `ssim` are the tensors returned by `tf.image.psnr` /
          `tf.image.ssim` comparing the prediction with the ground truth;
        - `inference_time` is the duration of the `model.predict` call in
          seconds.
    """
    from time import perf_counter

    low_light_image = Image.open(low_light_image_file)
    ground_truth_image = Image.open(ground_truth_image_file)
    preprocessed_image = preprocess_image(low_light_image)
    preprocessed_ground_truth = preprocess_image(ground_truth_image)

    # Time only the forward pass; preprocessing is excluded. The first call
    # may also include one-off warm-up/tracing overhead.
    start = perf_counter()
    model_output = model.predict(preprocessed_image, verbose=0)
    inference_time = perf_counter() - start

    # Full-reference metrics must compare the prediction with the ground
    # truth, not with the low-light input it was derived from.
    psnr = tf.image.psnr(preprocessed_ground_truth, model_output, max_val=1.0)
    ssim = tf.image.ssim(preprocessed_ground_truth, model_output, max_val=1.0)
    post_processed_image = postprocess_image(model_output)

    if visualize_plots:
        plot_results(
            images=[
                low_light_image, ground_truth_image, post_processed_image
            ],
            titles=[
                "Low-light Image", "Ground-truth Image", "Predicted Image"
            ],
            figure_size=(22, 15)
        )
    return (
        low_light_image,
        ground_truth_image,
        post_processed_image,
        psnr, ssim, inference_time
    )

In [None]:
# One row per evaluated image; the last column records the dataset split.
table = wandb.Table(columns=[
    "Input-Image",
    "Ground-Truth",
    "Image-Enhanced-By-AutoContrast",
    "Image-Enhanced-By-ZeroDCE",
    "Peak-Signal-Noise-Ratio",
    "Structural-Similarity",
    "Inference-Time",
    "Dataset"
])

In [None]:
# Evaluate every our485 (train/val) pair; rows are tagged "LoL/Train-Val".
for idx, low_light_image_file in enumerate(tqdm(train_val_low_light_images)):
    (
        low_light_image,
        ground_truth_image,
        zero_dce_enhanced_image,
        psnr, ssim, inference_time
    ) = infer_and_visualize(
        low_light_image_file,
        train_val_ground_truth_images[idx],
        model,
        visualize_plots=False
    )
    # ImageOps.autocontrast provides a simple non-learned comparison column.
    autocontrast_enhanced_image = ImageOps.autocontrast(low_light_image)
    table.add_data(
        wandb.Image(low_light_image),
        wandb.Image(ground_truth_image),
        wandb.Image(autocontrast_enhanced_image),
        wandb.Image(zero_dce_enhanced_image),
        psnr.numpy().item(), ssim.numpy().item(),
        inference_time, "LoL/Train-Val"
    )

In [None]:
# Evaluate every eval15 pair; rows are tagged "LoL/Eval15".
for idx, low_light_path in enumerate(tqdm(test_low_light_images)):
    outputs = infer_and_visualize(
        low_light_path,
        test_ground_truth_images[idx],
        model,
        visualize_plots=False
    )
    (
        low_light_image,
        ground_truth_image,
        zero_dce_enhanced_image,
        psnr, ssim, inference_time
    ) = outputs
    # ImageOps.autocontrast provides a simple non-learned comparison column.
    autocontrast_enhanced_image = ImageOps.autocontrast(low_light_image)
    table.add_data(
        wandb.Image(low_light_image),
        wandb.Image(ground_truth_image),
        wandb.Image(autocontrast_enhanced_image),
        wandb.Image(zero_dce_enhanced_image),
        psnr.numpy().item(), ssim.numpy().item(),
        inference_time, "LoL/Eval15"
    )

In [None]:
# Upload the accumulated per-image results, then close the run.
evaluation_payload = {"Evaluation": table}
wandb.log(evaluation_payload)
wandb.finish()