<a href="https://colab.research.google.com/github/wandb/examples/blob/master/colabs/keras/restorers/Train_Zero_DCE_Restorers.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
<!--- @wandbcode{restorers-zero-dce-train} -->

# 🌈 Restorers + WandB 🪄🐝


This notebook shows how to train a [Zero-DCE](https://arxiv.org/abs/2001.06826) model for zero-reference low-light enhancement using [**restorers**](https://github.com/soumik12345/restorers) and [**wandb**](https://wandb.ai/site). For more details on how to use restorers, refer to the following report:

[![](https://raw.githubusercontent.com/wandb/assets/main/wandb-github-badge-gradient.svg)](https://wandb.ai/ml-colabs/low-light-enhancement/reports/Lighting-up-Images-in-the-Deep-Learning-Era--VmlldzozNzE4Njkz)
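Zero-DCE does not learn a direct image-to-image mapping. As described in the paper, a small network predicts per-pixel parameter maps α for a quadratic light-enhancement curve, LE(x) = x + αx(1 − x), and the curve is applied iteratively (presumably what the `num_iterations` argument of `ZeroDCE` controls). The following is a minimal NumPy sketch of the curve-application step only, not restorers' implementation: the `enhance` helper is hypothetical, and the α maps are constant placeholder values rather than the output of the network.

```python
import numpy as np


def enhance(image: np.ndarray, alpha_maps: np.ndarray) -> np.ndarray:
    """Iteratively apply the Zero-DCE light-enhancement curve (hypothetical helper).

    image:      float array in [0, 1] with shape (H, W, 3)
    alpha_maps: per-pixel curve parameters in [-1, 1], shape (num_iterations, H, W, 3)
    """
    x = image
    for alpha in alpha_maps:
        # LE(x) = x + alpha * x * (1 - x); stays in [0, 1] when alpha is in [-1, 1]
        x = x + alpha * x * (1.0 - x)
    return x


rng = np.random.default_rng(0)
dark = rng.uniform(0.0, 0.2, size=(4, 4, 3))  # a tiny, dark placeholder "image"
alphas = np.full((8, 4, 4, 3), 0.8)  # 8 iterations, constant alpha for illustration only
bright = enhance(dark, alphas)
print(dark.mean(), bright.mean())
```

Because each α lies in [-1, 1], every step maps [0, 1] onto itself monotonically, so pixel values never leave the valid range; positive α brightens a pixel and negative α darkens it.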

In [None]:
!pip install -q --upgrade pip setuptools
!pip install git+https://github.com/soumik12345/restorers.git

In [None]:
import os
from glob import glob

import tensorflow as tf

import wandb
# import the wandb callbacks for keras
from wandb.integration.keras import WandbMetricsLogger, WandbModelCheckpoint

from restorers.dataloader import UnsupervisedLOLDataLoader
from restorers.model.zero_dce import ZeroDCE

In [None]:
wandb.init(project="low-light-enhancement", job_type="train")


data_loader = UnsupervisedLOLDataLoader(
    # size of image crops on which we will train
    image_size=128,
    # bit depth of the images
    bit_depth=8,
    # fraction of images for validation
    val_split=0.2,
    # whether to visualize the dataset on WandB
    visualize_on_wandb=True,
    # the WandB artifact address of the dataset;
    # this can be found in the `Usage` tab of
    # the dataset artifact's page on WandB
    dataset_artifact_address="ml-colabs/dataset/LoL:v0",
)

# call `get_datasets` on the `data_loader` to get
# the TensorFlow datasets corresponding to the 
# training and validation splits
train_dataset, val_dataset = data_loader.get_datasets(batch_size=16)
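The `image_size`, `bit_depth`, and `val_split` arguments control how the LOL images become training batches. The sketch below shows one common way to implement these steps, under the assumption that `UnsupervisedLOLDataLoader` does something similar (its internals are not reproduced here): integer pixels are scaled to [0, 1] by dividing by 2^bit_depth − 1, random `image_size` × `image_size` crops are taken, and a `val_split` fraction of the files is held out. `split_files` and `preprocess` are hypothetical helpers, not part of restorers.

```python
import random

import numpy as np


def split_files(files, val_split, seed=0):
    """Hypothetical helper: shuffle paths deterministically, hold out `val_split` of them."""
    files = sorted(files)  # sort first so the split does not depend on glob order
    random.Random(seed).shuffle(files)
    num_val = int(len(files) * val_split)
    return files[num_val:], files[:num_val]  # (train, val)


def preprocess(image, image_size, bit_depth, rng):
    """Hypothetical helper: scale an integer image to [0, 1] and take a random square crop."""
    # the largest representable value is 2**bit_depth - 1 (255 for 8-bit images)
    image = image.astype(np.float32) / (2**bit_depth - 1)
    height, width = image.shape[:2]
    top = rng.integers(0, height - image_size + 1)
    left = rng.integers(0, width - image_size + 1)
    return image[top : top + image_size, left : left + image_size]


train_files, val_files = split_files([f"img_{i}.png" for i in range(10)], val_split=0.2)
rng = np.random.default_rng(0)
fake_image = rng.integers(0, 256, size=(400, 600, 3), dtype=np.uint8)  # placeholder pixels
crop = preprocess(fake_image, image_size=128, bit_depth=8, rng=rng)
print(len(train_files), len(val_files), crop.shape)  # 8 2 (128, 128, 3)
```

Training on small random crops rather than full frames keeps each batch cheap and adds some augmentation for free.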

In [None]:
# define the ZeroDCE model; this gives us a `tf.keras.Model`
model = ZeroDCE(
    num_intermediate_filters=32, # number of filters in the intermediate convolutional layers
    num_iterations=8, # number of iterations of enhancement
    decoder_channel_factor=1 # factor by which the number of filters in the decoder of the deep curve estimation layer is multiplied
)

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    weight_exposure_loss=1.0, # weight of the exposure control loss
    weight_color_constancy_loss=0.5, # weight of the color constancy loss
    weight_illumination_smoothness_loss=20, # weight of the illumination smoothness loss
)
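In the Zero-DCE paper, training minimizes a weighted sum of non-reference losses (spatial consistency, exposure control, color constancy, and illumination smoothness), with color constancy weighted 0.5 and illumination smoothness weighted 20, which matches the values passed to `compile` above. The NumPy sketch below approximates three of those terms so the weights have something concrete to act on. It assumes the paper's defaults (non-overlapping 16 × 16 patches and a well-exposedness level E = 0.6), uses a simplified squared-gradient total variation for the smoothness term, and omits the spatial consistency loss. All function names are hypothetical; this is not restorers' implementation.

```python
import numpy as np


def exposure_loss(enhanced, patch_size=16, well_exposed_level=0.6):
    # average over RGB, then over non-overlapping patch_size x patch_size regions
    gray = enhanced.mean(axis=-1)
    h, w = gray.shape
    h, w = h - h % patch_size, w - w % patch_size
    patches = gray[:h, :w].reshape(h // patch_size, patch_size, w // patch_size, patch_size)
    patch_means = patches.mean(axis=(1, 3))
    # penalize the distance of each region's brightness from the target level
    return np.mean(np.abs(patch_means - well_exposed_level))


def color_constancy_loss(enhanced):
    # squared differences between the average R, G and B values
    r, g, b = enhanced.reshape(-1, 3).mean(axis=0)
    return (r - g) ** 2 + (r - b) ** 2 + (g - b) ** 2


def illumination_smoothness_loss(alpha_maps):
    # simplified total variation of the curve-parameter maps, shape (N, H, W, 3)
    dy = np.diff(alpha_maps, axis=1)
    dx = np.diff(alpha_maps, axis=2)
    return np.mean(dy**2) + np.mean(dx**2)


def total_loss(enhanced, alpha_maps, weight_exposure_loss=1.0,
               weight_color_constancy_loss=0.5, weight_illumination_smoothness_loss=20.0):
    # weighted sum using the same weights passed to model.compile() above
    return (
        weight_exposure_loss * exposure_loss(enhanced)
        + weight_color_constancy_loss * color_constancy_loss(enhanced)
        + weight_illumination_smoothness_loss * illumination_smoothness_loss(alpha_maps)
    )


dark = np.full((32, 32, 3), 0.1)
well_exposed = np.full((32, 32, 3), 0.6)
flat_alphas = np.zeros((8, 32, 32, 3))
print(total_loss(dark, flat_alphas), total_loss(well_exposed, flat_alphas))
```

The dark image is penalized by the exposure term while the mid-gray one is not; the large smoothness weight discourages α maps that change abruptly between neighboring pixels.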

In [None]:
callbacks = [
    # define the metrics logger callback;
    # we set `log_freq="batch"` explicitly
    # so that metrics are logged both batch-wise and epoch-wise
    WandbMetricsLogger(log_freq="batch"),
    # define the model checkpoint callback
    WandbModelCheckpoint(
        filepath="checkpoint",
        monitor="val_loss",
        save_best_only=False,
        save_weights_only=False,
        initial_value_threshold=None,
    )
]

# call model.fit()
model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=50,
    callbacks=callbacks,
)
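`WandbModelCheckpoint` builds on Keras's `ModelCheckpoint` callback, so `monitor`, `save_best_only`, and `initial_value_threshold` follow Keras's semantics. The pure-Python sketch below mimics that selection logic for a minimized metric such as `val_loss`, to show why `save_best_only=False` writes a checkpoint every epoch. `checkpointed_epochs` is a hypothetical helper that ignores everything else the callback does (writing files, uploading artifacts, `save_freq`).

```python
def checkpointed_epochs(val_losses, save_best_only, initial_value_threshold=None):
    """Hypothetical helper: return the 1-based epochs at which a checkpoint is written.

    Mimics ModelCheckpoint's decision for monitor="val_loss", where lower is better.
    """
    best = float("inf") if initial_value_threshold is None else initial_value_threshold
    saved = []
    for epoch, loss in enumerate(val_losses, start=1):
        if not save_best_only:
            saved.append(epoch)  # every epoch is saved; `best` is irrelevant
        elif loss < best:
            best = loss
            saved.append(epoch)  # saved only when the monitored value improves
    return saved


losses = [0.9, 0.7, 0.8, 0.6]
print(checkpointed_epochs(losses, save_best_only=False))  # [1, 2, 3, 4]
print(checkpointed_epochs(losses, save_best_only=True))  # [1, 2, 4]
print(checkpointed_epochs(losses, save_best_only=True, initial_value_threshold=0.65))  # [4]
```

With `epochs=50` and `save_best_only=False`, the training run above therefore writes a checkpoint after every epoch; switching to `save_best_only=True` would write one only when `val_loss` improves.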

In [None]:
wandb.finish()