In [None]:
import os
from glob import glob

import tensorflow as tf

import wandb
from wandb.keras import WandbMetricsLogger

from absl import app, flags, logging
from ml_collections.config_flags import config_flags

from restorers.model.zero_dce import ZeroDCE
from restorers.utils import (
    get_model_checkpoint_callback,
    initialize_device
)

from low_light_config import get_config

In [None]:
# Colab form fields: the `#@param` annotations make these editable inputs.
wandb_project_name = 'zero-dce' #@param {type:"string"}
wandb_run_name = 'train/lol' #@param {type:"string"}
wandb_entity_name = 'ml-colabs' #@param {type:"string"}
# NOTE(review): job_type is 'test' while the run name says 'train/lol' — confirm intended.
wandb_job_type = 'test' #@param {type:"string"}

# Load the experiment configuration once; the same object drives every later cell
# and is also recorded as the W&B run config.
experiment_configs = get_config()
wandb_init_kwargs = dict(
    project=wandb_project_name,
    name=wandb_run_name,
    entity=wandb_entity_name,
    job_type=wandb_job_type,
    config=experiment_configs.to_dict(),
)
wandb.init(**wandb_init_kwargs)

In [None]:
# Seed the Python, NumPy and TensorFlow RNGs in a single call for reproducibility.
tf.keras.utils.set_random_seed(experiment_configs.seed)
# Project helper; the returned object exposes `num_replicas_in_sync` and `scope()`,
# presumably a tf.distribute strategy — confirm in restorers.utils.
strategy = initialize_device()
# Global batch = per-replica batch size x number of replicas in sync.
local_batch_size = experiment_configs.data_loader_configs.local_batch_size
batch_size = local_batch_size * strategy.num_replicas_in_sync
# Record the effective global batch size on the W&B run config.
wandb.config.global_batch_size = batch_size

In [None]:
def load_data(image_path):
    """Read one image from disk and prepare it as model input.

    Args:
        image_path: Scalar string tensor holding the path of a PNG-encoded image.

    Returns:
        A float tensor of shape (image_size, image_size, 3) whose pixel values
        are divided by 255. ``image_size`` comes from
        ``experiment_configs.data_loader_configs``.
    """
    target_size = experiment_configs.data_loader_configs.image_size
    raw_bytes = tf.io.read_file(image_path)
    # Force 3 channels so grayscale/RGBA files still produce RGB tensors.
    decoded = tf.image.decode_png(raw_bytes, channels=3)
    resized = tf.image.resize(images=decoded, size=[target_size, target_size])
    return resized / 255.0


def data_generator(low_light_images, shuffle=False):
    """Build a batched, prefetched ``tf.data`` pipeline over image paths.

    Args:
        low_light_images: Non-empty list of image file paths; each element is
            decoded and resized by ``load_data``.
        shuffle: If True, reshuffle the path order on every epoch. Use it for
            the training split; leave it False (the default, matching the
            previous behavior) for validation so evaluation order is stable.

    Returns:
        A ``tf.data.Dataset`` yielding batches of exactly ``batch_size`` images
        (the global batch size computed earlier). The incomplete final batch is
        dropped, presumably to keep batch shapes static under the distribution
        strategy — confirm before changing.
    """
    dataset = tf.data.Dataset.from_tensor_slices(low_light_images)
    if shuffle:
        # Shuffle the cheap path strings before decoding; a buffer as large as
        # the whole list gives a uniform shuffle.
        dataset = dataset.shuffle(
            buffer_size=len(low_light_images), reshuffle_each_iteration=True
        )
    dataset = dataset.map(load_data, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    # Let the input pipeline prepare upcoming batches while the model trains.
    return dataset.prefetch(tf.data.AUTOTUNE)


artifact = wandb.use_artifact(
    experiment_configs.data_loader_configs.dataset_artifact_address, type='dataset'
)
artifact_dir = artifact.download()

# Deterministic split: sort the paths, then take the first (1 - val_split)
# fraction for training and the remainder for validation.
all_low_light_images = sorted(
    glob(os.path.join(artifact_dir, "our485", "low", "*"))
)
if not all_low_light_images:
    raise FileNotFoundError(
        f"No images found under {os.path.join(artifact_dir, 'our485', 'low')}; "
        "check the dataset artifact layout."
    )
num_train_images = int(
    (1 - experiment_configs.data_loader_configs.val_split) * len(all_low_light_images)
)
train_low_light_images = all_low_light_images[:num_train_images]
val_low_light_images = all_low_light_images[num_train_images:]

# With drop_remainder=True, a split smaller than one global batch yields an
# empty dataset, and Keras would likely fail later with a much less clear
# error. Fail fast for training; for validation, run without it.
if len(train_low_light_images) < batch_size:
    raise ValueError(
        f"Only {len(train_low_light_images)} training images but the global "
        f"batch size is {batch_size}; lower local_batch_size or val_split."
    )
train_dataset = data_generator(train_low_light_images, shuffle=True)

if len(val_low_light_images) >= batch_size:
    val_dataset = data_generator(val_low_light_images)
else:
    logging.warning(
        "Only %d validation images (< global batch size %d); "
        "validation is disabled for this run.",
        len(val_low_light_images),
        batch_size,
    )
    val_dataset = None

In [None]:
with strategy.scope():
    # Create the model and optimizer inside the strategy scope so their
    # variables are created under the distribution strategy.
    model_configs = experiment_configs.model_configs
    model = ZeroDCE(
        num_intermediate_filters=model_configs.num_intermediate_filters,
        num_iterations=model_configs.num_iterations,
    )
    optimizer = tf.keras.optimizers.Adam(
        learning_rate=experiment_configs.training_configs.learning_rate
    )
    # No loss is passed; presumably ZeroDCE computes its own losses in a custom
    # train step — confirm in restorers.model.zero_dce.
    model.compile(optimizer=optimizer)

In [None]:
# Checkpointing via the project helper (saves best-only when the config says
# so; the `using_wandb=True` semantics live in restorers.utils), followed by
# per-batch metric logging to the active W&B run.
checkpoint_callback = get_model_checkpoint_callback(
    filepath="checkpoint",
    save_best_only=experiment_configs.training_configs.save_best_checkpoint_only,
    using_wandb=True,
)
callbacks = [checkpoint_callback, WandbMetricsLogger(log_freq="batch")]

In [None]:
# Train Zero-DCE. Keep the returned History object so per-epoch losses can be
# inspected or plotted in later cells instead of being discarded.
# `val_dataset` may be None (validation disabled), which Model.fit accepts.
history = model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=experiment_configs.training_configs.epochs,
    callbacks=callbacks,
)

In [None]:
# Mark the W&B run as finished and finish uploading its logged data.
wandb.finish()