In [None]:
!pip install -q --upgrade pip setuptools
!pip install -q git+https://github.com/soumik12345/mirnetv2.git

In [None]:
import logging

import wandb
from wandb.keras import WandbMetricsLogger, WandbModelCheckpoint

from tqdm.keras import TqdmCallback

import tensorflow as tf
tf.get_logger().setLevel('ERROR')

from restorers.model import NAFNet
from restorers.dataloader import LOLDataLoader
from restorers.losses import CharbonnierLoss, PSNRLoss
from restorers.metrics import PSNRMetric, SSIMMetric
from restorers.utils import initialize_device

from low_light_config import get_config

In [None]:
# W&B run metadata. The `#@param` annotations expose these values as Colab form fields.
wandb_project_name = 'mirnet-v2' #@param {type:"string"}
wandb_run_name = 'train/low-light/mirnetv2' #@param {type:"string"}
wandb_entity_name = 'ml-colabs' #@param {type:"string"}
wandb_job_type = 'test' #@param {type:"string"}

# Experiment configuration from the local `low_light_config` module.
# Sub-configs are read directly off the returned config object; there is no
# absl `FLAGS` object in a notebook, so `FLAGS.experiment_configs` would fail.
experiment_configs = get_config()
data_loader_configs = experiment_configs.data_loader_configs
model_configs = experiment_configs.model_configs
training_configs = experiment_configs.training_configs

# Start the W&B run and record the full experiment config on it.
wandb.init(
    project=wandb_project_name,
    name=wandb_run_name,
    entity=wandb_entity_name,
    job_type=wandb_job_type,
    config=experiment_configs.to_dict(),
)

In [None]:
# Distribution strategy chosen by the project helper `initialize_device`.
strategy = initialize_device()

# Global batch size = per-replica batch size * number of replicas in sync.
batch_size = data_loader_configs.local_batch_size * strategy.num_replicas_in_sync
# Record the derived global batch size on the active W&B run, if any.
# (`using_wandb` was never defined in this notebook; `wandb.run` is None when
# no run has been started.)
if wandb.run is not None:
    wandb.config.global_batch_size = batch_size

In [None]:
# Build train/validation input pipelines with the project's LOLDataLoader,
# configured from the data-loader section of the experiment config and
# batched with the global `batch_size`.
data_loader = LOLDataLoader(
    dataset_artifact_address=data_loader_configs.dataset_artifact_address,
    image_size=data_loader_configs.image_size,
    bit_depth=data_loader_configs.bit_depth,
    val_split=data_loader_configs.val_split,
    visualize_on_wandb=data_loader_configs.visualize_on_wandb,
)
(train_dataset, val_dataset) = data_loader.get_datasets(batch_size=batch_size)

In [None]:
# Create the model, loss, LR schedule, optimizer and metrics inside the
# strategy scope so their variables are created under the distribution strategy.
with strategy.scope():
    model = NAFNet(
        filters=model_configs.filters,
        middle_block_num=model_configs.middle_block_num,
        encoder_block_nums=model_configs.encoder_block_nums,
        decoder_block_nums=model_configs.decoder_block_nums,
    )
    # Charbonnier loss reduced with Reduction.SUM (summed, not averaged).
    loss = CharbonnierLoss(
        epsilon=training_configs.charbonnier_epsilon,
        reduction=tf.keras.losses.Reduction.SUM,
    )

    # Total number of optimizer steps the cosine schedule decays over.
    # Clamp to >= 1: if the training set is smaller than one global batch,
    # steps_per_epoch is 0 and CosineDecay would divide by zero (NaN LR).
    steps_per_epoch = len(data_loader.train_input_images) // batch_size
    decay_steps = max(1, steps_per_epoch * training_configs.epochs)
    # NOTE(review): CosineDecay's `alpha` is the minimum LR as a *fraction* of
    # `initial_learning_rate`, not an absolute LR. Confirm that
    # `minimum_learning_rate` in the config is meant as such a fraction.
    lr_schedule_fn = tf.keras.optimizers.schedules.CosineDecay(
        initial_learning_rate=training_configs.initial_learning_rate,
        decay_steps=decay_steps,
        alpha=training_configs.minimum_learning_rate,
    )
    optimizer = tf.keras.optimizers.experimental.AdamW(
        learning_rate=lr_schedule_fn,
        weight_decay=training_configs.weight_decay,
        beta_1=training_configs.decay_rate_1,
        beta_2=training_configs.decay_rate_2,
    )
    logging.info("Using AdamW optimizer.")

    psnr_metric = PSNRMetric(max_val=training_configs.psnr_max_val)
    logging.info("Using Peak Signal-noise Ratio Metric.")
    ssim_metric = SSIMMetric(max_val=training_configs.ssim_max_val)
    logging.info("Using Structural Similarity Metric.")

    model.compile(
        optimizer=optimizer, loss=loss, metrics=[psnr_metric, ssim_metric]
    )

In [None]:
# Training callbacks, in order:
#   - TqdmCallback: tqdm progress bar.
#   - WandbMetricsLogger: logs metrics to W&B at batch frequency.
#   - WandbModelCheckpoint: saves the full model (not only weights) to
#     "checkpoint"; with save_best_only=False it saves regardless of val_loss.
callbacks = [TqdmCallback(), WandbMetricsLogger(log_freq="batch")]
callbacks.append(
    WandbModelCheckpoint(
        filepath="checkpoint",
        monitor="val_loss",
        save_best_only=False,
        save_weights_only=False,
        initial_value_threshold=None,
    )
)

In [None]:
# Train, validating on `val_dataset` after each epoch.
# `verbose=0` disables Keras' built-in progress bar because TqdmCallback already
# renders one; otherwise two progress bars are printed. The returned History is
# kept so per-epoch metrics can be inspected after training.
history = model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=training_configs.epochs,
    callbacks=callbacks,
    verbose=0,
)

In [None]:
# Finish the W&B run started by `wandb.init`, uploading any remaining logged data.
wandb.finish()