In [None]:
import os
from glob import glob

import tensorflow as tf
import wandb
from absl import app, flags, logging
from ml_collections.config_flags import config_flags
from wandb.keras import WandbMetricsLogger

from low_light_config import get_config

In [None]:
# Weights & Biases run settings. The trailing `#@param` annotations are Colab
# form-field markers and must stay on the same line as the assignment.
wandb_project_name = 'zero-dce' #@param {type:"string"}
wandb_run_name = 'train/lot' #@param {type:"string"}
wandb_entity_name = 'ml-colabs' #@param {type:"string"}
wandb_job_type = 'test' #@param {type:"string"}

# Load the experiment configuration and start the W&B run, recording the full
# configuration with it. The returned run handle is bound to `run` so that
# later cells can use it directly instead of silently discarding it.
experiment_configs = get_config()
run = wandb.init(
    project=wandb_project_name,
    name=wandb_run_name,
    entity=wandb_entity_name,
    job_type=wandb_job_type,
    config=experiment_configs.to_dict(),
)

In [None]:
# Fix the random seed from the experiment configuration for reproducibility.
tf.keras.utils.set_random_seed(experiment_configs.seed)

# NOTE(review): `initialize_device` is neither defined nor imported in the
# visible cells — confirm where it comes from. It is presumably expected to
# return a `tf.distribute` strategy, since `num_replicas_in_sync` is read below.
strategy = initialize_device()

# Global batch size = number of replicas x per-replica (local) batch size.
batch_size = (
    strategy.num_replicas_in_sync
    * experiment_configs.data_loader_configs.local_batch_size
)

# Record the derived value alongside the rest of the run configuration.
wandb.config.global_batch_size = batch_size

In [None]:
def load_data(image_path):
    """Read a PNG image from disk, resize it to a square and rescale it.

    The target side length is taken from
    `experiment_configs.data_loader_configs.image_size`.

    Args:
        image_path: Path of the PNG-encoded image file, in any form accepted
            by `tf.io.read_file`.

    Returns:
        The decoded 3-channel image, resized to `image_size` x `image_size`
        and divided by 255.0.
    """
    # The configuration object in this notebook is `experiment_configs`;
    # the previous reference to a bare `config` name raised NameError.
    image_size = experiment_configs.data_loader_configs.image_size
    image = tf.io.read_file(image_path)
    image = tf.image.decode_png(image, channels=3)
    image = tf.image.resize(images=image, size=[image_size, image_size])
    return image / 255.0


def data_generator(low_light_images):
    """Build a batched `tf.data.Dataset` of images from a list of file paths.

    Args:
        low_light_images: Image file paths; each path becomes one dataset
            element before being loaded by `load_data`.

    Returns:
        A dataset yielding batches of `batch_size` loaded images. A trailing
        batch smaller than `batch_size` is dropped.
    """
    return (
        tf.data.Dataset.from_tensor_slices(low_light_images)
        .map(load_data, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(batch_size, drop_remainder=True)
    )


# Fetch the dataset artifact through the active W&B run. `wandb.run` is the
# run started by `wandb.init(...)`, so this does not depend on the run handle
# having been stored under any particular variable name.
artifact = wandb.run.use_artifact(
    experiment_configs.data_loader_configs.dataset_artifact_address, type='dataset'
)
artifact_dir = artifact.download()

# Sort the paths so the train/validation split below is deterministic.
low_light_image_paths = sorted(glob(os.path.join(artifact_dir, "our485", "low", "*")))
if not low_light_image_paths:
    # Fail early with a clear message instead of building empty datasets.
    raise FileNotFoundError(
        f"No images found under {os.path.join(artifact_dir, 'our485', 'low')}"
    )

# The first (1 - val_split) fraction of the sorted paths is used for training
# and the remainder for validation. The full list keeps its own name so that
# re-running this cell always yields the same split.
val_split = experiment_configs.data_loader_configs.val_split
num_train_images = int((1 - val_split) * len(low_light_image_paths))
train_low_light_images = low_light_image_paths[:num_train_images]
val_low_light_images = low_light_image_paths[num_train_images:]

train_dataset = data_generator(train_low_light_images)
val_dataset = data_generator(val_low_light_images)