In [None]:
import os
from glob import glob

import tensorflow as tf

import wandb
from wandb.keras import WandbMetricsLogger

from absl import app, flags, logging
from ml_collections.config_flags import config_flags

from restorers.model.zero_dce import ZeroDCE, FastZeroDce
from restorers.dataloader import UnsupervisedLOLDataLoader
from restorers.utils import (
    get_model_checkpoint_callback,
    initialize_device
)

from low_light_config import get_config

In [None]:
# Weights & Biases run parameters. The trailing `#@param` markers expose these
# values as editable form fields when the notebook is opened in Colab.
wandb_project_name = 'zero-dce' #@param {type:"string"}
wandb_run_name = 'train/lol' #@param {type:"string"}
wandb_entity_name = 'ml-colabs' #@param {type:"string"}
wandb_job_type = 'test' #@param {type:"string"}

# Load the experiment configuration and start the W&B run, recording the
# complete configuration (as a plain dict) alongside the run.
experiment_configs = get_config()
wandb_init_kwargs = dict(
    project=wandb_project_name,
    name=wandb_run_name,
    entity=wandb_entity_name,
    job_type=wandb_job_type,
    config=experiment_configs.to_dict(),
)
wandb.init(**wandb_init_kwargs)

In [None]:
# Seed the random number generators so that runs are repeatable.
tf.keras.utils.set_random_seed(experiment_configs.seed)

# NOTE(review): `initialize_device` is a project helper; it presumably returns
# a `tf.distribute` strategy for the available hardware -- confirm.
strategy = initialize_device()

# The global batch size is the per-replica batch size multiplied by the number
# of replicas kept in sync by the strategy.
local_batch_size = experiment_configs.data_loader_configs.local_batch_size
batch_size = local_batch_size * strategy.num_replicas_in_sync

# Record the derived value on the W&B run config.
wandb.config.global_batch_size = batch_size

In [None]:
# Build the data loader from the experiment configuration.
# NOTE(review): the keyword `visualize_on_wand` looks like a typo for
# `visualize_on_wandb`; the loader's signature is not visible from this
# notebook, so it is left unchanged here -- confirm against
# `restorers.dataloader.UnsupervisedLOLDataLoader`.
data_loader = UnsupervisedLOLDataLoader(
    image_size=experiment_configs.data_loader_configs.image_size,
    bit_depth=experiment_configs.data_loader_configs.bit_depth,
    val_split=experiment_configs.data_loader_configs.val_split,
    visualize_on_wand=False,
    dataset_artifact_address=experiment_configs.data_loader_configs.dataset_artifact_address,
)

# Split into training and validation datasets batched at the global batch size.
# The training split must be bound to `train_dataset`: that is the name the
# `model.fit(...)` cell reads (it was previously misspelled `tran_dataset`,
# which made the fit cell fail with a NameError on a fresh kernel).
train_dataset, val_dataset = data_loader.get_datasets(batch_size=batch_size)

In [None]:
# Short aliases for the configuration sections used below.
model_configs = experiment_configs.model_configs
training_configs = experiment_configs.training_configs

# Create and compile the model within `strategy.scope()`.
with strategy.scope():
    # Both variants take the same constructor arguments; only the class differs.
    model_class = FastZeroDce if model_configs.use_faster_variant else ZeroDCE
    model = model_class(
        num_intermediate_filters=model_configs.num_intermediate_filters,
        num_iterations=model_configs.num_iterations,
        decoder_channel_factor=model_configs.decoder_channel_factor,
    )

    # Adam optimizer plus the per-term loss weights taken from the config.
    optimizer = tf.keras.optimizers.Adam(
        learning_rate=training_configs.learning_rate,
    )
    model.compile(
        optimizer=optimizer,
        weight_exposure_loss=training_configs.weight_exposure_loss,
        weight_color_constancy_loss=training_configs.weight_color_constancy_loss,
        weight_illumination_smoothness_loss=training_configs.weight_illumination_smoothness_loss,
    )

In [None]:
# Checkpoint callback built by the project helper; whether only the best
# checkpoint is kept is controlled by the training configuration.
checkpoint_callback = get_model_checkpoint_callback(
    filepath="checkpoint",
    save_best_only=experiment_configs.training_configs.save_best_checkpoint_only,
    using_wandb=True,
)

# Logs training metrics to W&B at batch frequency.
metrics_logger = WandbMetricsLogger(log_freq="batch")

# Callbacks passed to `model.fit`, in the order they were created.
callbacks = [checkpoint_callback, metrics_logger]

In [None]:
# Train for the configured number of epochs, evaluating on the validation
# dataset and running the callbacks assembled above.
num_epochs = experiment_configs.training_configs.epochs
model.fit(
    train_dataset,
    epochs=num_epochs,
    validation_data=val_dataset,
    callbacks=callbacks,
)

In [None]:
# Close the W&B run that was opened by `wandb.init(...)` earlier in the notebook.
wandb.finish()