**INTRODUCTION**

Competition: https://codalab.lisn.upsaclay.fr/competitions/21112

This challenge is associated with the PRIN 2022 project "LICAM - AI-powered LiDAR fusion for next-generation smartphone cameras". The task is to deblur real low-light images taken with an iPhone 15 Pro, using both the blurred image and the co-registered depth map produced by its LiDAR sensor. The deblurred images are compared with registered sharp ground-truth images using the LPIPS perceptual quality metric.

Training and validation data come from a novel dataset of low-light iPhone images affected by noise and motion blur, each with a registered LiDAR depth map and a sharp ground-truth image. These images are the closest match to the test images. Participants may also use the ARKitScenes dataset to pretrain their models on simulated motion blur.

The LICAM project (“AI-powered LiDAR fusion for next-generation smartphone cameras”) is funded by the European Union – Next Generation EU within the PRIN 2022 program (D.D. 104 - 02/02/2022 Ministero dell’Università e della Ricerca). The contents of this website reflect only the authors' views and opinions, and the Ministry cannot be considered responsible for them.

In [None]:
# Install the third-party `keras-unet-collection` package, which is imported in the
# next cell as `from keras_unet_collection import models`. The leading `!` is
# IPython shell syntax, so this line only runs inside a notebook.
# NOTE(review): `models` is not referenced in any of the cells visible here.
!pip install keras-unet-collection
# Alternative TensorFlow pins, left disabled. NOTE(review): later cells use
# `tf.keras.ops`, which presumably requires Keras 3 (TF >= 2.16); the 2.13 pin
# would likely not work — confirm before re-enabling.
# !pip install tensorflow==2.13.0
# !pip install tensorflow-gpu==2.8.0

In [None]:
import glob
from PIL import Image
import numpy as np
import re
import tensorflow as tf
import transformers
import torch
import random
from sklearn.model_selection import train_test_split
from keras_unet_collection import models

import matplotlib.pyplot as plt

from typing import (Optional,
                    Tuple)

In [None]:
# Kaggle input locations. All inputs share the "../input" mount point, and the two
# challenge splits share the dataset folder, so the common prefixes are built once.
_INPUT_ROOT = "../input"
_CHALLENGE_ROOT = f"{_INPUT_ROOT}/lidar-challenge"
_DATASET_ROOT = f"{_CHALLENGE_ROOT}/LICAM_deblur_challenge_dataset"

# Challenge train/val split (blurred rgb, depth and sharp gt per scene).
PATH_TRAIN: str = f"{_DATASET_ROOT}/train_val"
# Challenge test split (rgb and depth only).
PATH_TEST: str = f"{_DATASET_ROOT}/test"
# Extra frames used for synthetic-blur pretraining (presumably ARKitScenes).
PATH_DATA: str = f"{_CHALLENGE_ROOT}/data"
# Previously trained UNet weights that are loaded before fine-tuning.
PATH_MODEL: str = f"{_INPUT_ROOT}/deblurring_images/other/default/4/best_model_unet.weights.h5"

# Depth map preprocessing

In [None]:
# Sentinel for invalid (zero) LiDAR depth pixels.
# NOTE(review): only the commented-out preprocessing variant references this; the
# active pretraining-depth cell fills holes with the mean valid depth instead.
MASK_VALUE: int = -1

In [None]:
# Depth maps of the challenge train/val split (cached below with the "0_" prefix).
files_train = glob.glob(f"{PATH_TRAIN}/images45/*/depth/*.png")
# High-resolution depth maps of the extra pretraining data under PATH_DATA
# (cached below with the "1_" prefix).
files_data = glob.glob(f"{PATH_DATA}/*/*/highres_depth/*.png")

In [None]:
# Normalize each challenge depth map to [0, 1] (per-image min/max) and cache it as
# ./0_depth_<id>.npy, where <id> is the PNG file stem.
_train_depth_id_re = re.compile(r".*/images45/.*/depth/(.*)\.png")
for file in files_train:
    img = np.array(Image.open(file), dtype=np.float32)
    d_lo, d_hi = img.min(), img.max()
    # A constant depth map would give 0 / 0 = NaN everywhere; store zeros instead.
    img = (img - d_lo) / (d_hi - d_lo) if d_hi > d_lo else np.zeros_like(img)
    np.save(f'./0_depth_{_train_depth_id_re.findall(file)[0]}.npy', img)

In [None]:
#for file in files_data:
#    img = np.array(Image.open(file).resize(size=(512, 512), resample=Image.NEAREST), dtype=np.int32)
#    img[img == 0] = MASK_VALUE
#    img[img != MASK_VALUE] = ((img[img != MASK_VALUE] - np.min(img[img != MASK_VALUE])) 
#                              / (np.max(img[img != MASK_VALUE]) - np.min(img[img != MASK_VALUE])))
#    np.save(f'./1_depth_{re.findall(r".*/.*/.*/highres_depth/(.*).png", file)[0]}.npy', img)

In [None]:
# Resize each pretraining depth map to 512x512 (nearest neighbour, so no depth
# values are invented), fill LiDAR holes (zeros) with the mean valid depth,
# normalize to [0, 1] and cache it as ./1_depth_<id>.npy.
_data_depth_id_re = re.compile(r".*/.*/.*/highres_depth/(.*)\.png")
for file in files_data:
    img = np.array(Image.open(file).resize(size=(512, 512), resample=Image.NEAREST),
                   dtype=np.float32)
    valid = img != 0
    # Fill in float so the mean is not truncated to an integer. A map without any
    # valid pixel (its mean would be NaN) is left as all zeros.
    if valid.any():
        img[~valid] = img[valid].mean()
    d_lo, d_hi = img.min(), img.max()
    # Guard against 0 / 0 for constant maps.
    img = (img - d_lo) / (d_hi - d_lo) if d_hi > d_lo else np.zeros_like(img)
    np.save(f'./1_depth_{_data_depth_id_re.findall(file)[0]}.npy', img)

# Image preprocessing

In [None]:
# RGB frames: challenge inputs (cached with the "0_" prefix) and pretraining frames
# under PATH_DATA (cached with the "1_" prefix) by the next two cells.
files_train = glob.glob(f"{PATH_TRAIN}/images45/*/rgb/*.png")
files_data = glob.glob(f"{PATH_DATA}/*/*/wide/*.png")

In [None]:
# Scale every challenge RGB frame to [0, 1] float32 and cache it as
# ./0_rgb_<id>.npy, <id> being the PNG file stem.
_train_rgb_id_pattern = re.compile(r".*/images45/.*/rgb/(.*).png")
for file in files_train:
    img = np.array(Image.open(file), dtype=np.float32) / 255.
    np.save(f'./0_rgb_{_train_rgb_id_pattern.findall(file)[0]}.npy', img)

In [None]:
# Resize each pretraining RGB frame to 512x512 (Lanczos), scale it to [0, 1] and
# cache it as ./1_rgb_<id>.npy. DataGenerator uses this sharp frame both as the
# target and, after a synthetic motion blur, as the input.
_data_rgb_id_re = re.compile(r".*/.*/.*/wide/(.*)\.png")
for file in files_data:
    # convert("RGB") guards against palette/RGBA PNGs, which would not fit the
    # (512, 512, 3) buffers the generator loads into.
    img = Image.open(file).convert("RGB").resize(size=(512, 512), resample=Image.LANCZOS)
    # float32 like the challenge caches. The former int32 -> float64 path doubled
    # the cache size, and the values are cast to float32 when loaded anyway.
    img = np.asarray(img, dtype=np.float32) / 255.
    np.save(f'./1_rgb_{_data_rgb_id_re.findall(file)[0]}.npy', img)

# Ground truth

In [None]:
# Sharp ground-truth frames; they exist only for the challenge split ("0_" prefix).
files_train = glob.glob(f"{PATH_TRAIN}/images45/*/gt/*.png")

In [None]:
# Scale each sharp ground-truth frame to [0, 1] float32 and cache it as
# ./0_gt_<id>.npy; this is the regression target for challenge samples.
_gt_id_pattern = re.compile(r".*/images45/.*/gt/(.*).png")
for file in files_train:
    img = np.array(Image.open(file), dtype=np.float32) / 255.
    np.save(f'./0_gt_{_gt_id_pattern.findall(file)[0]}.npy', img)

# Sequence generator

In [None]:
import os

# Sample index: one (source, id) row per cached RGB frame. Source "0" is the
# challenge split and "1" the pretraining data. Sorting makes the row order, and
# therefore any seeded split of it, independent of the filesystem listing order.
rgb_cache_pattern = re.compile(r"\./(.*)_rgb_(.*)\.npy")


def has_companion_files(source, sample_id):
    """Return True if the depth (and, for source "0", ground-truth) cache exists.

    DataGenerator loads these files for every sample. Dropping incomplete samples
    here avoids a FileNotFoundError in the middle of an epoch.
    """
    needed = [f"./{source}_depth_{sample_id}.npy"]
    if source == "0":
        needed.append(f"./{source}_gt_{sample_id}.npy")
    return all(os.path.exists(path) for path in needed)


idx = [rgb_cache_pattern.findall(file)[0] for file in sorted(glob.glob("./*_rgb_*.npy"))]
idx = np.asarray([pair for pair in idx if has_companion_files(*pair)])

In [None]:
def random_motion(steps: Optional[int] = 16,
                  initial_vector: Optional[torch.Tensor] = None,
                  alpha: Optional[float] = 0.2):
    """Sample random camera-shake trajectories and rasterize them into blur kernels.

    Each trajectory is a 2-D random walk in the complex plane. At every step the
    direction is perturbed by complex Gaussian noise scaled by ``alpha`` and
    re-normalized to unit length, so the path advances one pixel per step while
    turning smoothly.

    Args:
        steps: number of walk steps (the trajectory has ``steps + 1`` points).
        initial_vector: complex tensor of shape (batch,) with the starting
            direction of each trajectory. ``None`` draws a single random one.
        alpha: strength of the per-step direction change.

    Returns:
        Float tensor of shape (batch, 1, kh, kw) with odd kh and kw. Each kernel
        sums to 1, and its centre of mass lies on the centre pixel
        (kh // 2, kw // 2).
    """
    if initial_vector is None:
        # torch.zeros_like(None) below would raise; draw one random direction.
        initial_vector = torch.randn(1, dtype=torch.cfloat)

    motion = [torch.zeros_like(initial_vector)]
    for _ in range(steps):
        change = torch.randn(initial_vector.shape[0], dtype=torch.cfloat)
        initial_vector = initial_vector + change * alpha
        initial_vector = initial_vector / initial_vector.abs().add(1e-8)
        motion.append(motion[-1] + initial_vector)
    motion = torch.stack(motion, -1)                 # (batch, steps + 1)
    motion = motion - motion.mean(-1, keepdim=True)  # centre each trajectory

    # Half-extent of the widest trajectory plus one pixel of margin, so the
    # bilinear splat at (ix + 1, iy + 1) can never index past the kernel border.
    half_w = int(max(motion.real.max().ceil().item(), -motion.real.min().floor().item())) + 1
    half_h = int(max(motion.imag.max().ceil().item(), -motion.imag.min().floor().item())) + 1
    kernel = torch.zeros(initial_vector.shape[0], 1, 2 * half_h + 1, 2 * half_w + 1)

    # Splat every trajectory point, all steps + 1 of them. Dropping the last point
    # would move the kernel's centre of mass away from the centre pixel.
    n_points = motion.shape[-1]
    for s in range(n_points):
        v = motion[:, s] + half_w + half_h * 1j   # shift into kernel coordinates (>= 1)
        ixs = v.real.floor().long()
        iys = v.imag.floor().long()
        vxs = v.real - ixs
        vys = v.imag - iys

        # Bilinear splat: spread each point over its four neighbouring pixels.
        for i, (iy, ix, vy, vx) in enumerate(zip(iys, ixs, vys, vxs)):
            kernel[i, 0, iy, ix] += (1 - vx) * (1 - vy) / n_points
            kernel[i, 0, iy, ix + 1] += vx * (1 - vy) / n_points
            kernel[i, 0, iy + 1, ix] += (1 - vx) * vy / n_points
            kernel[i, 0, iy + 1, ix + 1] += vx * vy / n_points

    return kernel

In [None]:
class RandomMotionBlur(torch.nn.Module):
    """Apply a random camera-shake blur to a batch of channels-last images.

    ``forward`` takes a float tensor shaped (batch, H, W, C), the layout
    DataGenerator passes in, and returns a tensor of the same shape. One kernel
    from ``random_motion`` is drawn per image and applied to every channel. The
    blur is an FFT (circular) convolution on a zero-padded copy.
    """

    def __init__(self,
                 steps: Optional[int] = 17,
                 alpha: Optional[float] = 0.2):
        super().__init__()
        self.steps = steps
        self.alpha = alpha

    def forward(self,
                x: torch.Tensor):
        x = x.swapaxes(1, 3)  # (B, H, W, C) -> (B, C, W, H)
        vector = torch.randn(x.shape[0], dtype=torch.cfloat) / 3
        vector.real /= 2
        m = random_motion(self.steps, vector, alpha=self.alpha)  # (B, 1, kh, kw)
        kh, kw = m.shape[-2], m.shape[-1]

        # Pad by more than half the kernel, so the circular convolution never
        # wraps image content around the borders into the cropped region.
        xpad = [kw // 2 + 1] * 2 + [kh // 2 + 1] * 2
        x = torch.nn.functional.pad(x, xpad)
        mpad = [0, x.shape[-1] - kw, 0, x.shape[-2] - kh]
        mp = torch.nn.functional.pad(m, mpad)
        # FFT convolution uses mp[..., 0, 0] as the kernel origin, but the kernel is
        # centred at (kh // 2, kw // 2). Rolling the centre to the origin keeps the
        # blurred image aligned with the sharp target. Without the roll, the output
        # would be translated by half the kernel size.
        mp = torch.roll(mp, shifts=(-(kh // 2), -(kw // 2)), dims=(-2, -1))

        y = torch.fft.ifft2(torch.fft.fft2(x) * torch.fft.fft2(mp)).real
        y = y[..., xpad[2]:-xpad[3], xpad[0]:-xpad[1]].swapaxes(1, 3)
        return y

In [None]:
class DataGenerator(tf.keras.utils.Sequence):
    """Keras Sequence yielding ``((rgb, depth), target)`` batches from the .npy cache.

    ``idx`` is an array of (source, id) string pairs. For source "0" (challenge
    data) the input is the cached RGB frame and the target is the cached ground
    truth. For source "1" (pretraining data) the target is the sharp RGB frame
    itself and the input is a randomly motion-blurred copy of it.

    Shapes per batch: rgb (n, 512, 512, 3), depth (n, 512, 512, 1),
    target (n, 512, 512, 3), all float32.
    """

    def __init__(self, 
                 idx: np.ndarray,
                 batch_size: Optional[int] = 8,
                 shuffle: Optional[bool] = True,
                 **kwargs):
        super(DataGenerator, self).__init__(**kwargs)
        self.idx = idx
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __len__(self):
        # The last batch may be smaller than batch_size.
        return int(np.ceil(len(self.idx) / self.batch_size))

    def on_epoch_end(self):
        if self.shuffle:
            # In-place row shuffle of this generator's own copy of the index.
            np.random.shuffle(self.idx)

    def __getitem__(self,
                    index: int):
        batch = self.idx[index * self.batch_size : (index + 1) * self.batch_size]
        n = len(batch)
        X_rgb = np.zeros((n, 512, 512, 3), dtype=np.float32)
        X_depth = np.zeros((n, 512, 512), dtype=np.float32)
        y = np.zeros((n, 512, 512, 3), dtype=np.float32)
        for i, (dtype, file) in enumerate(batch):
            X_depth[i] = np.load(f"./{dtype}_depth_{file}.npy")
            X_rgb[i] = np.load(f"./{dtype}_rgb_{file}.npy")
            if dtype == "0":
                y[i] = np.load(f"./{dtype}_gt_{file}.npy")
            else:
                # The sharp frame is its own target: copy the array just loaded
                # instead of reading the same file a second time.
                y[i] = X_rgb[i]
            if dtype == "1":
                # Synthesize the blurred input with a random trajectory length.
                blur = RandomMotionBlur(steps=random.randint(5, 50))
                blurred = blur(torch.from_numpy(X_rgb[i][np.newaxis, ...]))
                X_rgb[i] = blurred[0].numpy()
        return (X_rgb, X_depth[..., np.newaxis]), y

In [None]:
# Hold out 30% of the samples (by row position in idx) for validation. A fixed
# seed makes the split repeatable across runs, given a stable idx order.
# NOTE(review): the weights in PATH_MODEL were presumably trained on an unseeded
# split, so this validation set may still overlap their training data — confirm.
idx_train, idx_test = train_test_split(range(len(idx)), test_size=0.3, random_state=42)

In [None]:
# Train/validation generators with batches of 4. Fancy indexing (idx[...]) gives
# each generator its own copy of the index, so their in-place shuffles are
# independent. NOTE(review): the validation generator also shuffles (default
# shuffle=True), and its pretraining samples get a fresh random blur on every
# pass, so validation metrics are noisy.
train_generator = DataGenerator(idx=idx[idx_train], batch_size=4)
test_generator = DataGenerator(idx=idx[idx_test], batch_size=4)

# Model

In [None]:
class SPADE(tf.keras.layers.Layer):
    """Spatially-adaptive normalization conditioned on a depth map.

    The input is normalized with parameter-free statistics over the batch and
    spatial axes (0, 1, 2). It is then modulated by per-pixel ``gamma``/``beta``
    maps predicted from the depth mask, which is resized (nearest) to the input's
    spatial size.

    Args:
        filters: channel count of ``gamma``/``beta``. It must match the channel
            count of the normalized tensor, because the two are multiplied
            element-wise.
        epsilon: added to the variance for numerical stability.
        hidden_filters: width of the shared GELU conv applied to the mask. The
            default of 128 is the previously hard-coded value.
    """

    def __init__(self, 
                 filters: int,
                 epsilon: Optional[float] = 1e-5,
                 hidden_filters: int = 128,
                 **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon
        self.hidden_filters = hidden_filters
        # Shared trunk on the mask, then two heads for the per-pixel scale and shift.
        self.conv = tf.keras.layers.Conv2D(hidden_filters, 3, padding="same", activation="gelu")
        self.conv_gamma = tf.keras.layers.Conv2D(filters, 3, padding="same")
        self.conv_beta = tf.keras.layers.Conv2D(filters, 3, padding="same")

    def build(self,
              input_shape: Tuple[int]):
        # (H, W) of the tensor being normalized; the mask is resized to match it.
        self.resize_shape = input_shape[1:3]

    def call(self,
             input_tensor: tf.Tensor,
             raw_mask: tf.Tensor):
        mask = tf.keras.ops.image.resize(raw_mask, self.resize_shape, interpolation="nearest")
        x = self.conv(mask)
        gamma = self.conv_gamma(x)
        beta = self.conv_beta(x)
        # Per-channel statistics over batch and spatial axes (batch-norm style).
        mean, var = tf.keras.ops.moments(input_tensor, axes=(0, 1, 2), keepdims=True)
        std = tf.keras.ops.sqrt(var + self.epsilon)
        normalized = (input_tensor - mean) / std
        return gamma * normalized + beta

In [None]:
class ResBlock(tf.keras.layers.Layer):
    """Residual block made of two SPADE -> LeakyReLU(0.2) -> 3x3 conv stages.

    ``filters`` is the output channel count. The shortcut is the identity when the
    input already has ``filters`` channels. Otherwise it is a learned
    SPADE -> LeakyReLU -> 3x3 conv projection. Every SPADE is conditioned on the
    same ``mask``.

    Sub-layer attribute names in ``build`` are kept unchanged: saved weights
    (e.g. PATH_MODEL) are presumably matched to sub-layers by attribute path.
    """

    def __init__(self, filters: Tuple[int], **kwargs):
        super().__init__(**kwargs)
        self.filters = filters

    def build(self, input_shape: Tuple[int]):
        channels_in = input_shape[-1]
        self.spade_1 = SPADE(channels_in)
        self.spade_2 = SPADE(self.filters)
        self.conv_1 = tf.keras.layers.Conv2D(self.filters, 3, padding="same")
        self.conv_2 = tf.keras.layers.Conv2D(self.filters, 3, padding="same")
        # A projection shortcut is only needed when the channel count changes.
        self.learned_skip = self.filters != channels_in
        if self.learned_skip:
            self.spade_3 = SPADE(channels_in)
            self.conv_3 = tf.keras.layers.Conv2D(self.filters, 3, padding="same")

    def call(self, input_tensor, mask):
        def act(t):
            return tf.keras.activations.leaky_relu(t, 0.2)

        residual = self.conv_1(act(self.spade_1(input_tensor, mask)))
        residual = self.conv_2(act(self.spade_2(residual, mask)))
        if self.learned_skip:
            shortcut = self.conv_3(act(self.spade_3(input_tensor, mask)))
        else:
            shortcut = input_tensor
        return shortcut + residual

In [None]:
class SPADEBlock(tf.keras.layers.Layer):
    """Two SPADE -> LeakyReLU(0.2) -> 3x3 conv stages, with no residual path.

    ``filters`` is the output channel count. Both SPADE layers are conditioned on
    the same ``mask``. Sub-layer attribute names are kept unchanged, because saved
    weights are presumably matched to sub-layers by attribute path.
    """

    def __init__(self, filters: Tuple[int], **kwargs):
        super().__init__(**kwargs)
        self.filters = filters

    def build(self, input_shape: Tuple[int]):
        channels_in = input_shape[-1]
        self.spade_1 = SPADE(channels_in)
        self.spade_2 = SPADE(self.filters)
        self.conv_1 = tf.keras.layers.Conv2D(self.filters, 3, padding="same")
        self.conv_2 = tf.keras.layers.Conv2D(self.filters, 3, padding="same")

    def call(self, input_tensor, mask):
        h = input_tensor
        for spade, conv in ((self.spade_1, self.conv_1), (self.spade_2, self.conv_2)):
            h = conv(tf.keras.activations.leaky_relu(spade(h, mask), 0.2))
        return h

In [None]:
def downsample_spadeblock(x: tf.Tensor,
                          mask: tf.Tensor,
                          channels: int,
                          kernels: Tuple[int],
                          strides: Optional[int] = 2):
    """Depth-conditioned SPADEBlock followed by a strided conv downsampling stage.

    Applies SPADEBlock(channels) conditioned on ``mask``, then a bias-free
    ``kernels``-sized conv with ``strides`` ("same" padding, Glorot-normal init),
    GroupNormalization(groups=-1) and LeakyReLU(0.2). Returns the resulting tensor.
    """
    conditioned = SPADEBlock(filters=channels)(x, mask)
    reduced = tf.keras.layers.Conv2D(
        filters=channels,
        kernel_size=kernels,
        strides=strides,
        padding="same",
        use_bias=False,
        kernel_initializer=tf.keras.initializers.GlorotNormal(),
    )(conditioned)
    normed = tf.keras.layers.GroupNormalization(groups=-1)(reduced)
    return tf.keras.layers.LeakyReLU(0.2)(normed)

In [None]:
def downsample(channels: int,
               kernels: Tuple[int],
               strides: Optional[int] = 2):
    """Return a Sequential conv -> GroupNorm -> LeakyReLU(0.2) encoder stage.

    The conv has ``channels`` filters of size ``kernels``, stride ``strides``,
    "same" padding, no bias and Glorot-normal init. GroupNormalization uses
    groups=-1, i.e. one group per channel according to the Keras documentation.
    """
    return tf.keras.Sequential([
        tf.keras.layers.Conv2D(filters=channels,
                               kernel_size=kernels,
                               strides=strides,
                               padding="same",
                               use_bias=False,
                               kernel_initializer=tf.keras.initializers.GlorotNormal()),
        tf.keras.layers.GroupNormalization(groups=-1),
        tf.keras.layers.LeakyReLU(0.2),
    ])

## Basic

In [None]:
# Baseline: RGB and depth are concatenated and passed through four stride-1
# (transposed) conv stages of decreasing kernel size, each followed by
# GroupNormalization and LeakyReLU(0.2), then projected to 3 channels.
input_img = tf.keras.Input(shape=(512, 512, 3), dtype=tf.float32)
input_depth = tf.keras.Input(shape=(512, 512, 1), dtype=tf.float32)
x = tf.keras.layers.concatenate([input_img, input_depth], axis=-1)
for stage_filters, stage_kernel in ((16, 7), (32, 5), (64, 3), (128, 1)):
    x = tf.keras.layers.Conv2DTranspose(filters=stage_filters,
                                        kernel_size=(stage_kernel, stage_kernel),
                                        strides=(1, 1),
                                        padding="same",
                                        use_bias=False,
                                        activation=None,
                                        kernel_initializer=tf.keras.initializers.GlorotNormal()) (x)
    x = tf.keras.layers.GroupNormalization(groups=-1) (x)
    x = tf.keras.layers.LeakyReLU(0.2) (x)
# Sigmoid rather than tanh: the targets are images scaled to [0, 1] (pixel / 255),
# so tanh would waste half its range on unreachable negative values. The UNet's
# output layer uses sigmoid as well.
x = tf.keras.layers.Conv2DTranspose(filters=3,
                                    kernel_size=(1, 1),
                                    strides=(1, 1),
                                    padding="same",
                                    use_bias=False,
                                    activation="sigmoid",
                                    kernel_initializer=tf.keras.initializers.GlorotNormal()) (x)

model = tf.keras.models.Model(inputs=[input_img, input_depth],
                              outputs=x)
model.summary()

## Basic SPADE

In [None]:
# SPADE baseline: four depth-conditioned SPADEBlocks of growing width applied to
# the RGB image, then a 1x1 projection to 3 channels.
input_img = tf.keras.Input(shape=(512, 512, 3), dtype=tf.float32)
input_depth = tf.keras.Input(shape=(512, 512, 1), dtype=tf.float32)
x = input_img
for stage_filters in (32, 64, 128, 256):
    x = SPADEBlock(filters=stage_filters) (x, input_depth)
# Sigmoid rather than tanh, so the output range matches the [0, 1] targets.
x = tf.keras.layers.Conv2DTranspose(filters=3,
                                    kernel_size=(1, 1),
                                    strides=(1, 1),
                                    padding="same",
                                    use_bias=False,
                                    activation="sigmoid",
                                    kernel_initializer=tf.keras.initializers.GlorotNormal()) (x)

model = tf.keras.models.Model(inputs=[input_img, input_depth],
                              outputs=x)
model.summary()

## UNet (selected model)

In [None]:
# Depth-conditioned UNet. This is the last definition of `model`, so it is the
# network that is trained and used for prediction below. The encoder is a plain
# conv stack on RGB+depth. The decoder uses SPADE ResBlocks that re-inject the
# full-resolution depth map at every scale.
# NOTE(review): keep the layer creation order exactly as-is; the weights loaded
# from PATH_MODEL presumably depend on this exact architecture.
input_img = tf.keras.Input(shape=(512, 512, 3), dtype=tf.float32)
input_depth = tf.keras.Input(shape=(512, 512, 1), dtype=tf.float32)
x = tf.keras.layers.concatenate([input_img, input_depth], axis=-1)
# Encoder. Output size / channels per stage ("same" padding):
#   x: 512x512x32 (stride 1), d1: 256x256x64, d2: 128x128x128,
#   d3: 64x64x256, d4: 32x32x512.
x = downsample(channels=32,
               kernels=(5, 5),
               strides=(1, 1)) (x)
d1 = downsample(channels=64,
                kernels=(3, 3),
                strides=(2, 2)) (x)
d2 = downsample(channels=128,
                kernels=(3, 3),
                strides=(2, 2)) (d1)
d3 = downsample(channels=256,
                kernels=(3, 3),
                strides=(2, 2)) (d2)
d4 = downsample(channels=512,
                kernels=(3, 3),
                strides=(2, 2)) (d3)
# Decoder. Each stage: two ResBlocks conditioned on input_depth, 2x nearest
# upsampling, then concatenation with the encoder feature map of the same size.
u4 = ResBlock(filters=512) (d4, input_depth)
u4 = ResBlock(filters=512) (u4, input_depth)
u4 = tf.keras.layers.UpSampling2D(size=(2, 2),
                                  interpolation="nearest") (u4)
u3 = tf.keras.layers.concatenate([u4, d3])
u3 = ResBlock(filters=256) (u3, input_depth)
u3 = ResBlock(filters=256) (u3, input_depth)
u3 = tf.keras.layers.UpSampling2D(size=(2, 2),
                                  interpolation="nearest") (u3)
u2 = tf.keras.layers.concatenate([u3, d2])
u2 = ResBlock(filters=128) (u2, input_depth)
u2 = ResBlock(filters=128) (u2, input_depth)
u2 = tf.keras.layers.UpSampling2D(size=(2, 2),
                                  interpolation="nearest") (u2)
u1 = tf.keras.layers.concatenate([u2, d1])
u1 = ResBlock(filters=64) (u1, input_depth)
u1 = ResBlock(filters=64) (u1, input_depth)
u1 = tf.keras.layers.UpSampling2D(size=(2, 2),
                                  interpolation="nearest") (u1)
# Back at 512x512: merge with the stride-1 stem features.
x = tf.keras.layers.concatenate([u1, x])
x = ResBlock(filters=32) (x, input_depth)
x = ResBlock(filters=32) (x, input_depth)
# 1x1 projection to RGB; sigmoid keeps predictions in [0, 1] like the targets.
x = tf.keras.layers.Conv2D(filters=3,
                           kernel_size=(1, 1),
                           strides=(1, 1),
                           padding="same",
                           use_bias=False,
                           activation="sigmoid",
                           kernel_initializer=tf.keras.initializers.GlorotNormal()) (x)

model = tf.keras.models.Model(inputs=[input_img, input_depth],
                              outputs=x)
model.summary()

## Train

In [None]:
class VGGFeatureMatchingLoss(tf.keras.losses.Loss):
    """Weighted L1 distance between VGG19 features of the target and the prediction.

    Both images are expected in [0, 1]. They are rescaled to [0, 255] and passed
    through ``vgg19.preprocess_input`` before feature extraction. Features come
    from the first conv of each of the five VGG19 blocks, weighted
    1/32, 1/16, 1/8, 1/4 and 1.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.encoder_layers = [f"block{b}_conv1" for b in range(1, 6)]
        self.weights = [1.0 / 2 ** p for p in (5, 4, 3, 2, 0)]
        vgg = tf.keras.applications.VGG19(include_top=False, weights="imagenet")
        self.vgg_model = tf.keras.Model(
            vgg.input,
            [vgg.get_layer(layer_name).output for layer_name in self.encoder_layers],
            name="VGG",
        )
        self.mae = tf.keras.losses.MeanAbsoluteError()

    def call(self, y_true, y_pred):
        def features(images):
            return self.vgg_model(tf.keras.applications.vgg19.preprocess_input(255. * images))

        real_features = features(y_true)
        fake_features = features(y_pred)
        return sum(weight * self.mae(real, fake)
                   for weight, real, fake in zip(self.weights, real_features, fake_features))

In [None]:
class SobelMaskingLoss(tf.keras.losses.Loss):
    """L1 loss weighted per pixel and channel by the edge strength of the target.

    The weight map is |dy| + |dx| of the target's Sobel response, shifted per
    image so that its minimum is exactly 1. Flat regions therefore keep weight ~1,
    and edges are emphasized regardless of their polarity or orientation.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, y_true, y_pred):
        sobel = tf.image.sobel_edges(y_true)  # (batch, h, w, c, 2): [dy, dx]
        # Use the gradient magnitude. A signed dy + dx gives bright-to-dark edges a
        # *lower* weight than flat regions.
        edge = tf.reduce_sum(tf.math.abs(sobel), axis=-1)
        edge_min = tf.reduce_min(edge, axis=[1, 2, 3], keepdims=True)
        mask_sobel = edge - edge_min + 1.0
        return tf.reduce_mean(tf.math.abs(y_true - y_pred) * mask_sobel)

In [None]:
class SobelLoss(tf.keras.losses.Loss):
    """Mean absolute difference between the Sobel gradients (dy, dx) of target and prediction."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, y_true, y_pred):
        gradient_gap = tf.image.sobel_edges(y_true) - tf.image.sobel_edges(y_pred)
        return tf.reduce_mean(tf.math.abs(gradient_gap))

In [None]:
class PerceptualLoss(tf.keras.losses.Loss):
    """Weighted sum of VGG feature matching, edge-weighted L1 and (1 - SSIM).

    loss = vgg_weight * VGGFeatureMatchingLoss
         + sobel_weight * SobelMaskingLoss
         + ssim_weight * (1 - mean SSIM)

    SSIM is computed on both images clipped to [0, 1], with max_val = 1.

    Args:
        vgg_weight, sobel_weight, ssim_weight: keyword-only term weights. The
            defaults (0.5, 10, 3) are the previously hard-coded values.
    """

    def __init__(self, *, vgg_weight=0.5, sobel_weight=10., ssim_weight=3., **kwargs):
        super().__init__(**kwargs)
        self.vgg_weight = vgg_weight
        self.sobel_weight = sobel_weight
        self.ssim_weight = ssim_weight
        self.vggloss = VGGFeatureMatchingLoss()
        self.sobel_loss = SobelMaskingLoss()

    def call(self, y_true, y_pred):
        ssim = tf.image.ssim(tf.clip_by_value(y_true, clip_value_min=0.0, clip_value_max=1.0),
                             tf.clip_by_value(y_pred, clip_value_min=0.0, clip_value_max=1.0),
                             1.0)
        return (self.vgg_weight * self.vggloss(y_true, y_pred)
                + self.sobel_weight * self.sobel_loss(y_true, y_pred)
                + self.ssim_weight * (1 - tf.reduce_mean(ssim)))

In [None]:
def charbonnier_loss(y_true, y_pred):
    """Charbonnier (smooth L1) loss: mean of sqrt((y_true - y_pred)^2 + eps^2) with eps = 1e-3."""
    residual = y_true - y_pred
    smoothing = tf.square(1e-3)
    return tf.reduce_mean(tf.sqrt(tf.square(residual) + smoothing))

In [None]:
def lr_warmup_cosine_decay(global_step,
                           warmup_steps,
                           hold = 0,
                           total_steps=0,
                           start_lr=0.0,
                           target_lr=1e-3):
    """Learning rate of a linear-warmup / hold / cosine-decay schedule.

    Phases, by global step:
      * ``[0, warmup_steps)``: linear ramp from ``start_lr`` to ``target_lr``;
      * ``[warmup_steps, warmup_steps + hold]``: constant ``target_lr``;
      * afterwards: cosine decay from ``target_lr`` down to 0, reached at
        ``total_steps`` and held at 0 from then on.

    Args:
        global_step: current step; a scalar or an array (evaluated element-wise).
        warmup_steps: length of the warmup ramp (0 disables warmup).
        hold: number of steps to stay at ``target_lr`` after the warmup.
        total_steps: step at which the rate reaches 0. If it leaves no room for
            decay (``total_steps <= warmup_steps + hold``), the rate stays at
            ``target_lr`` after warmup.
        start_lr: learning rate at step 0.
        target_lr: peak learning rate.

    Returns:
        ``np.ndarray`` with the shape of ``global_step`` (0-d for a scalar).
    """
    step = np.asarray(global_step, dtype=np.float64)
    decay_steps = total_steps - warmup_steps - hold

    if decay_steps > 0:
        # Clamping keeps the rate at target_lr before the decay starts (covering
        # the hold phase). It also keeps the rate at 0 after total_steps, instead
        # of rising again along the next cosine period.
        progress = np.clip((step - warmup_steps - hold) / float(decay_steps), 0.0, 1.0)
        learning_rate = 0.5 * target_lr * (1 + np.cos(np.pi * progress))
    else:
        learning_rate = np.full_like(step, target_lr)

    if warmup_steps > 0:
        warmup_lr = start_lr + (target_lr - start_lr) * (step / warmup_steps)
    else:
        warmup_lr = learning_rate  # no warmup phase; never selected below

    return np.where(step < warmup_steps, warmup_lr, learning_rate)

class WarmupCosineDecay(tf.keras.callbacks.Callback):
    """Per-batch learning-rate schedule driven by ``lr_warmup_cosine_decay``.

    Before each batch, the optimizer's learning rate is set from the current
    global step. After each batch, the step counter advances and the rate that was
    used is appended to ``self.lrs`` (e.g. for plotting).
    """

    def __init__(self, total_steps=0, warmup_steps=0, start_lr=0.0, target_lr=1e-3, hold=0):

        super(WarmupCosineDecay, self).__init__()
        self.start_lr = start_lr
        self.hold = hold
        self.total_steps = total_steps
        self.global_step = 0
        self.target_lr = target_lr
        self.warmup_steps = warmup_steps
        self.lrs = []

    def on_batch_end(self, batch, logs=None):
        self.global_step = self.global_step + 1
        # Read from the model this callback is attached to (self.model), not the
        # notebook-global `model`, which may be a different network.
        lr = self.model.optimizer.learning_rate.numpy()
        self.lrs.append(lr)

    def on_batch_begin(self, batch, logs=None):
        lr = lr_warmup_cosine_decay(global_step=self.global_step,
                                    total_steps=self.total_steps,
                                    warmup_steps=self.warmup_steps,
                                    start_lr=self.start_lr,
                                    target_lr=self.target_lr,
                                    hold=self.hold)
        self.model.optimizer.learning_rate = lr

In [None]:
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau

In [None]:
# TODO: save the best model; if the loss has not decreased for 10 epochs, reduce
# the learning rate (ReduceLROnPlateau).
# Train for 400 epochs.
# Save a checkpoint every time the loss decreases.

In [None]:
# Keep the weights with the lowest *validation* loss, and halve the learning rate
# after 10 epochs without improvement (floor 1e-6). Selecting on the training loss
# favours over-fitted weights. model.fit is called with validation_data, so
# val_loss is available at the end of every epoch.
# NOTE(review): pretraining samples in the validation set get a fresh random blur
# each epoch, so val_loss is noisy; consider a larger patience if LR drops too early.
checkpoint = ModelCheckpoint(
    "best_model_unet_continue.weights.h5", monitor="val_loss", save_best_only=True, mode="min", verbose=1, save_weights_only=True
)

reduce_lr = ReduceLROnPlateau(
    monitor="val_loss", factor=0.5, patience=10, min_lr=1e-6, verbose=1
)

In [None]:
# Compile with the combined perceptual loss (VGG + edge-weighted L1 + SSIM).
# NOTE(review): this configuration is replaced by the re-compile after
# load_weights; presumably kept from the original training run.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    loss=PerceptualLoss(),
    metrics=["mae", "mse"]
)

In [None]:
# Resume from previously trained weights. The current `model` (the UNet) must
# have the same architecture as the one that produced PATH_MODEL.
model.load_weights(PATH_MODEL) 

In [None]:
# Fine-tuning configuration actually used by model.fit: a 10x lower learning rate
# and a VGG feature-matching loss only. A fresh Adam optimizer is created here.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=VGGFeatureMatchingLoss(),
    metrics=["mae", "mse"]
)

In [None]:
# Fine-tune for 100 epochs with best-checkpoint saving and LR reduction on plateau.
model.fit(
    train_generator,
    epochs=100,
    validation_data=test_generator,
    callbacks=[checkpoint, reduce_lr],
    verbose=1
)

In [None]:
# Earlier compile / schedule configurations, kept as inert string literals for
# reference; nothing here is executed. In the first one, "-> este" (Spanish for
# "this one") marks the configuration the author preferred. The second refers to
# outputs named "output1"/"output2", which the current model does not have.
'''
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3), -> este
              loss=PerceptualLoss(),
              #loss="mae",
              metrics=["mae", "mse"])
'''

'''
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
              loss={"output1": PerceptualLoss(), 
                    "output2": VGGFeatureMatchingLoss()},
              metrics={"output1": ["mae", "mse"],
                       "output2": ["mae", "mse"]})

callback = WarmupCosineDecay(total_steps=143 * 50, 
                             warmup_steps=int(143 * 50 * 0.05),
                             hold=0, 
                             start_lr=0.0, 
                             target_lr=1e-3)
'''

In [None]:
# model.save("./00_model_unetv1_v2.0.h5")

# Predict

In [None]:
# Load test scene 5 as a batch of one (leading axis added).
# NOTE(review): the model inputs are declared as 512x512; test images presumably
# have that size — confirm.
X_rgb = np.asarray(Image.open(f"{PATH_TEST}/5/rgb/5.png"))[np.newaxis, ...]
X_depth = np.asarray(Image.open(f"{PATH_TEST}/5/depth/5.png"))[np.newaxis, ...]

In [None]:
# Same preprocessing as the cached training data: RGB scaled to [0, 1], and depth
# min-max normalized per image (a constant map becomes zeros instead of NaN) with
# a trailing channel axis for the depth input.
X_rgb = X_rgb / 255.
X_depth = X_depth.astype(np.float32)
_depth_lo, _depth_hi = X_depth.min(), X_depth.max()
X_depth = ((X_depth - _depth_lo) / (_depth_hi - _depth_lo)
           if _depth_hi > _depth_lo else np.zeros_like(X_depth))[..., np.newaxis]

In [None]:
# Deblur the single test scene; y_pred has a leading batch axis of size 1.
y_pred = model.predict([X_rgb, X_depth])

In [None]:
# Show the blurred input (float RGB in [0, 1]).
plt.imshow(X_rgb[0])
plt.show()

In [None]:
# Show the raw prediction (float RGB from the model's output layer).
plt.imshow(y_pred[0])
plt.show()

In [None]:
# Clip to [0, 1] and rescale to the 8-bit range [0, 255]; the result is still float.
y_pred = np.clip(y_pred, a_max=1., a_min=0.) * 255.

In [None]:
# Show the prediction as integer RGB. astype truncates the [0, 255] floats.
# NOTE(review): np.uint8 is the conventional dtype for 8-bit RGB data.
plt.imshow(y_pred[0].astype(np.uint16))
plt.show()

# Predict all

In [None]:
import os

# Folder that receives the predicted PNGs; created only if the path is missing.
newpath = r'./00_res'
needs_dir = not os.path.exists(newpath)
if needs_dir:
    os.makedirs(newpath)

In [None]:
# Deblur every test scene and write the result to ./00_res/<n>.png.
# Scene n (1-based) is read from PATH_TEST/<n>/{rgb,depth}/<n>.png.
NUM_TEST_SCENES = 15
for i in range(NUM_TEST_SCENES):
    scene = i + 1
    print(scene)
    X_rgb = np.asarray(Image.open(f"{PATH_TEST}/{scene}/rgb/{scene}.png"))[np.newaxis, ...]
    X_depth = np.asarray(Image.open(f"{PATH_TEST}/{scene}/depth/{scene}.png"))[np.newaxis, ...]

    # Same preprocessing as training: RGB in [0, 1], and per-image min-max depth
    # (zeros for a constant map, instead of NaN) plus a channel axis.
    X_rgb = X_rgb / 255.
    X_depth = X_depth.astype(np.float32)
    d_lo, d_hi = X_depth.min(), X_depth.max()
    X_depth = ((X_depth - d_lo) / (d_hi - d_lo)
               if d_hi > d_lo else np.zeros_like(X_depth))[..., np.newaxis]

    y_pred = model.predict([X_rgb, X_depth])

    array = y_pred[0]

    if array.dtype != np.uint8:
        # Round to the nearest 8-bit level. Plain truncation would bias every
        # pixel down by half a level on average.
        array = np.rint(np.clip(array, 0.0, 1.0) * 255).astype(np.uint8)

    image = Image.fromarray(array)
    image.save(f'./00_res/{scene}.png')