In [None]:
!pip install -q git+https://github.com/soumik12345/mirnetv2.git@model-factory

In [None]:
import tensorflow as tf
import wandb
# NOTE(review): `app`, `flags` and `config_flags` appear unused in this notebook
# (likely carried over from a command-line training script); kept so nothing
# relying on them breaks. `absl` has no `g` member, so only real names are imported.
from absl import app, flags
from ml_collections.config_flags import config_flags
from wandb.keras import WandbMetricsLogger

from restorers.callbacks import LowLightEvaluationCallback
from restorers.dataloader import LOLDataLoader
from restorers.losses import CharbonnierLoss
from restorers.metrics import PSNRMetric, SSIMMetric
from restorers.model import MirNetv2
from restorers.utils import get_model_checkpoint_callback, initialize_device

# Local experiment configuration module. It must be importable from the
# notebook's working directory — TODO confirm it is shipped/uploaded alongside.
from low_light_config import get_config

In [None]:
wandb_project_name = 'mirnet-v2' #@param {type:"string"}
wandb_run_name = 'train/low-light/mirnetv2' #@param {type:"string"}
wandb_entity_name = 'ml-colabs' #@param {type:"string"}
wandb_job_type = 'train' #@param {type:"string"}

# Load the experiment configuration and start a W&B run that records the
# whole configuration (as a plain dict) alongside the run.
experiment_configs = get_config()
wandb_run_kwargs = dict(
    project=wandb_project_name,
    name=wandb_run_name,
    entity=wandb_entity_name,
    job_type=wandb_job_type,
    config=experiment_configs.to_dict(),
)
wandb.init(**wandb_run_kwargs)

In [None]:
# Seed the Python, NumPy and TensorFlow RNGs in one call, then obtain the
# distribution strategy used for the rest of the notebook.
tf.keras.utils.set_random_seed(experiment_configs.seed)
strategy = initialize_device()

# The configured batch size is per replica; the global batch spans all
# replicas kept in sync by the strategy. Record it on the W&B run config.
local_batch_size = experiment_configs.data_loader_configs.local_batch_size
batch_size = local_batch_size * strategy.num_replicas_in_sync
wandb.config.global_batch_size = batch_size

In [None]:
# Build the LOL low-light dataset loader from the data-loader section of the
# config (the dataset is located via `dataset_artifact_address`, presumably a
# W&B artifact), then split it into training and validation datasets batched
# at the global batch size.
loader_configs = experiment_configs.data_loader_configs
data_loader = LOLDataLoader(
    image_size=loader_configs.image_size,
    bit_depth=loader_configs.bit_depth,
    val_split=loader_configs.val_split,
    visualize_on_wandb=loader_configs.visualize_on_wandb,
    dataset_artifact_address=loader_configs.dataset_artifact_address,
)
train_dataset, val_dataset = data_loader.get_datasets(batch_size=batch_size)

In [None]:
with strategy.scope():
    # Model, optimizer and metrics all create variables, so they are built
    # inside the strategy scope.
    model_configs = experiment_configs.model_configs
    training_configs = experiment_configs.training_configs

    model = MirNetv2(
        channels=model_configs.channels,
        channel_factor=model_configs.channel_factor,
        num_mrb_blocks=model_configs.num_mrb_blocks,
        add_residual_connection=model_configs.add_residual_connection,
    )
    loss = CharbonnierLoss(
        epsilon=training_configs.charbonnier_epsilon,
        reduction=tf.keras.losses.Reduction.SUM,
    )

    # Size the schedule by the number of batches fit() actually draws per
    # epoch. cardinality() is negative when unknown (-2) or infinite (-1); in
    # that case fall back to a size-based estimate. Never allow 0 steps,
    # because CosineDecay divides by decay_steps (0 would yield a NaN LR).
    steps_per_epoch = int(train_dataset.cardinality())
    if steps_per_epoch <= 0:
        steps_per_epoch = max(
            1, len(data_loader.train_low_light_images) // batch_size
        )
    decay_steps = steps_per_epoch * training_configs.epochs

    # CosineDecay's `alpha` is the final LR as a FRACTION of the initial LR,
    # not an absolute LR. The config stores `minimum_learning_rate`, assumed
    # (per its name) to be an absolute value — TODO confirm against the config.
    min_lr_fraction = (
        training_configs.minimum_learning_rate
        / training_configs.initial_learning_rate
    )
    lr_schedule_fn = tf.keras.optimizers.schedules.CosineDecay(
        initial_learning_rate=training_configs.initial_learning_rate,
        decay_steps=decay_steps,
        alpha=min_lr_fraction,
    )
    # The config names Adam's beta_1 / beta_2 `decay_rate_1` / `decay_rate_2`.
    optimizer = tf.keras.optimizers.experimental.AdamW(
        learning_rate=lr_schedule_fn,
        weight_decay=training_configs.weight_decay,
        beta_1=training_configs.decay_rate_1,
        beta_2=training_configs.decay_rate_2,
    )

    psnr_metric = PSNRMetric(max_val=training_configs.psnr_max_val)
    ssim_metric = SSIMMetric(max_val=training_configs.ssim_max_val)

    model.compile(
        optimizer=optimizer, loss=loss, metrics=[psnr_metric, ssim_metric]
    )

In [None]:
# Training callbacks, in the order Keras will run them:
#   1. model checkpointing (save_best_only=False: save regardless of the
#      monitored metric; using_wandb=True — presumably also uploads to W&B),
#   2. per-batch metric logging to W&B,
#   3. evaluation on a single element (one batch) of the validation dataset,
#      logged into tables with the columns listed below.
checkpoint_callback = get_model_checkpoint_callback(
    filepath="checkpoint", save_best_only=False, using_wandb=True
)
metrics_logger_callback = WandbMetricsLogger(log_freq="batch")
evaluation_callback = LowLightEvaluationCallback(
    validation_data=val_dataset.take(1),
    data_table_columns=["Input-Image", "Ground-Truth-Image"],
    pred_table_columns=[
        "Epoch",
        "Input-Image",
        "Ground-Truth-Image",
        "Predicted-Image",
        "Peak-Signal-To-Noise-Ratio",
        "Structural-Similarity",
    ],
)
callbacks = [checkpoint_callback, metrics_logger_callback, evaluation_callback]

In [None]:
# Train the model. Keep the returned History so per-epoch losses and metrics
# (`history.history`) can be inspected afterwards, instead of discarding it
# and echoing its opaque repr as the cell output.
history = model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=experiment_configs.training_configs.epochs,
    callbacks=callbacks,
)

In [None]:
# Close the W&B run opened by `wandb.init(...)` earlier in the notebook.
# Per the W&B docs, this also finishes uploading any pending run data.
wandb.finish()