In [None]:
import wandb
from wandb.keras import WandbMetricsLogger

from tqdm.keras import TqdmCallback

import tensorflow as tf
tf.get_logger().setLevel('ERROR')

from restorers.model import NAFNet
from restorers.dataloader import LOLDataLoader
from restorers.losses import CharbonnierLoss, PSNRLoss
from restorers.metrics import PSNRMetric, SSIMMetric
from restorers.utils import get_model_checkpoint_callback

In [None]:
# Open the W&B run for this experiment. It is created before the data loader,
# which is pointed at a W&B dataset artifact below (presumably fetched through
# the active run — confirm in restorers.dataloader).
wandb_run_settings = dict(project="nafnet", entity="ml-colabs")
wandb.init(**wandb_run_settings)

# LOL data loader configuration; 20% of the data is held out for validation.
loader_kwargs = {
    "image_size": 128,
    "bit_depth": 8,
    "val_split": 0.2,
    "visualize_on_wandb": False,
    "dataset_artifact_address": "ml-colabs/dataset/LoL:v0",
}
data_loader = LOLDataLoader(**loader_kwargs)

# Batched tf.data pipelines for training and validation.
train_dataset, val_dataset = data_loader.get_datasets(batch_size=4)

In [None]:
# NAFNet restoration model with the defaults from restorers.model.
model = NAFNet()

# Training hyperparameters. BATCH_SIZE must match `get_datasets(batch_size=...)`
# in the data cell and NUM_EPOCHS must match `epochs=` in the `model.fit` cell,
# otherwise the cosine schedule ends too early or too late.
BATCH_SIZE = 4
NUM_EPOCHS = 100

# Size the LR schedule from the batched training dataset itself, so it tracks the
# real number of optimizer steps per epoch (a final partial batch included, and
# independent of whether `train_input_images` is counted before or after the
# validation split). tf.data reports negative sentinels when the cardinality is
# infinite or unknown; fall back to the image-count estimate in that case.
steps_per_epoch = int(train_dataset.cardinality().numpy())
if steps_per_epoch <= 0:
    steps_per_epoch = len(data_loader.train_input_images) // BATCH_SIZE
# The cosine schedule is a function of step / decay_steps, so never let it be 0.
decay_steps = max(steps_per_epoch, 1) * NUM_EPOCHS

# Cosine decay from 2e-4 over the whole run.
# NOTE(review): `alpha` is the final LR as a *fraction* of `initial_learning_rate`,
# so alpha=1e-6 ends at 2e-10, not at 1e-6 — confirm this is intended.
lr_schedule_fn = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=2e-4,
    decay_steps=decay_steps,
    alpha=1e-6,
)
optimizer = tf.keras.optimizers.experimental.AdamW(
    learning_rate=lr_schedule_fn, weight_decay=1e-4
)

# max_val=1.0 assumes the loader yields images scaled to [0, 1] — confirm.
psnr_metric = PSNRMetric(max_val=1.0)
ssim_metric = SSIMMetric(max_val=1.0)

# Charbonnier reconstruction loss with epsilon=1e-3.
loss = CharbonnierLoss(epsilon=1e-3)

model.compile(
    optimizer=optimizer, loss=loss, metrics=[psnr_metric, ssim_metric]
)

In [None]:
# Train the compiled model. `epochs` must stay in sync with the epoch count used
# to size `decay_steps` for the cosine LR schedule in the model-setup cell.
# Keras' own progress output is silenced (verbose=0) in favour of tqdm bars.
# The returned History (per-epoch metrics) is kept for later inspection instead
# of being echoed as the cell's output.
history = model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=100,
    verbose=0,
    callbacks=[
        # Log metrics to W&B at batch granularity.
        WandbMetricsLogger(log_freq="batch"),
        # Checkpoint without keeping only the best model; `using_wandb=True`
        # presumably also logs the checkpoint to W&B — confirm in restorers.utils.
        get_model_checkpoint_callback(
            filepath="checkpoint", save_best_only=False, using_wandb=True
        ),
        TqdmCallback(),
    ],
)

In [None]:
# Close the W&B run opened by `wandb.init(...)` earlier in this notebook.
wandb.finish()