In [None]:
import os
import random
import numpy as np
from glob import glob
from PIL import Image, ImageOps
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

import wandb; wandb.login()
from wandb.keras import WandbCallback

In [None]:
# Start a W&B run so the dataset artifact is fetched through (and tracked by) it.
run = wandb.init()

# Declare the dataset artifact as an input of this run and download its files.
artifact = run.use_artifact(
    "ivangoncharov/Low Light Enhancement with Zero-DCE/Lol_Dataset:v0",
    type="dataset",
)
artifact_dir = artifact.download()

# Location of the zip archive inside the downloaded artifact directory.
artifact_path = os.path.join(artifact_dir, "lol_dataset.zip")

In [None]:
!unzip $artifact_path

In [None]:
# Side length (pixels) that every image is resized to in load_data (height == width).
IMAGE_SIZE = 256
# Number of images per batch produced by data_generator.
BATCH_SIZE = 16
# Split index into the sorted training folder: paths before it are used for
# training, the remaining paths for validation.
MAX_TRAIN_IMAGES = 400


In [None]:
def load_data(image_path):
    """Read one image file and return it as a resized, rescaled tensor.

    Args:
        image_path: Path of the PNG file to load (a string or string tensor).

    Returns:
        The image decoded with 3 channels, resized to
        IMAGE_SIZE x IMAGE_SIZE and divided by 255.0.
    """
    raw_bytes = tf.io.read_file(image_path)
    decoded = tf.image.decode_png(raw_bytes, channels=3)
    resized = tf.image.resize(images=decoded, size=[IMAGE_SIZE, IMAGE_SIZE])
    # Dividing by 255 maps 8-bit pixel intensities onto the [0, 1] range.
    return resized / 255.0

In [None]:
def data_generator(low_light_images):
    """Build a batched tf.data pipeline from a list of image paths.

    Args:
        low_light_images: Sequence of image file paths.

    Returns:
        A tf.data.Dataset that loads each path with load_data (in parallel,
        with the parallelism level chosen by tf.data.AUTOTUNE) and groups the
        results into batches of BATCH_SIZE. A trailing incomplete batch is
        dropped.
    """
    return (
        tf.data.Dataset.from_tensor_slices(low_light_images)
        .map(load_data, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(BATCH_SIZE, drop_remainder=True)
    )

In [None]:
# List the low-light training folder once and split it, instead of globbing and
# sorting the same directory twice.
all_low_light_images = sorted(glob("./lol_dataset/our485/low/*"))
if not all_low_light_images:
    # Fail here with a clear message; otherwise an empty list only surfaces
    # later as an obscure error inside the tf.data pipeline.
    raise FileNotFoundError(
        "No images found under ./lol_dataset/our485/low/ - "
        "was the dataset archive extracted into the working directory?"
    )

# The first MAX_TRAIN_IMAGES paths are used for training, the rest for validation.
train_low_light_images = all_low_light_images[:MAX_TRAIN_IMAGES]
val_low_light_images = all_low_light_images[MAX_TRAIN_IMAGES:]

# Separate evaluation folder shipped with the dataset.
test_low_light_images = sorted(glob("./lol_dataset/eval15/low/*"))

In [None]:
# Build the training and validation pipelines from their respective path lists.
train_dataset, val_dataset = (
    data_generator(image_paths)
    for image_paths in (train_low_light_images, val_low_light_images)
)

In [None]:
# Show the dataset objects (element spec) as a quick sanity check.
print(
    f"Train Dataset: {train_dataset}",
    f"Validation Dataset: {val_dataset}",
    sep="\n",
)

In [None]:
# Log a side-by-side preview of low-light / normal-light image pairs to W&B.
wandb.init(project="low_light_zero_DCE", job_type="EDA")

table = wandb.Table(columns=["Low light", "High light"])

# Only the first 100 training pairs are logged to keep the table small.
for img_path in train_low_light_images[:100]:
    # The matching well-lit image has the same file name in the sibling "high"
    # folder. Only the directory component is swapped; a plain
    # str.replace("low", "high") would also rewrite any other "low" substring
    # in the path or file name.
    split_dir = os.path.dirname(os.path.dirname(img_path))
    high_path = os.path.join(split_dir, "high", os.path.basename(img_path))

    # Context managers release the image files as soon as the pixel data has
    # been copied into NumPy arrays.
    with Image.open(img_path) as lowlight_image, Image.open(high_path) as highlight_image:
        table.add_data(
            wandb.Image(np.array(lowlight_image)),
            wandb.Image(np.array(highlight_image)),
        )

wandb.log({"Dataset table": table})

wandb.finish()
     