<a href="https://colab.research.google.com/github/PedrolyraC/Campo-Minado/blob/main/learning_to_see_in_the_dark.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [26]:
!pip install einops



In [None]:
!nvidia -smi

In [27]:
import argparse
import math
import numpy as np
import os
import PIL.Image as Image
import random
import tensorflow as tf
import yaml
import tensorflow.keras.callbacks as tfkc
import tensorflow.keras.initializers as tfki
import tensorflow.keras.layers as tfkl
import tensorflow.keras.models as tfkm
import tensorflow.keras.preprocessing as tfkp
import tensorflow.keras.optimizers as tfko
import tensorflow.keras.utils as tfku

from collections import OrderedDict
from einops import rearrange
from tensorflow.keras import Sequential
from tqdm import tqdm

In [28]:
AUTO = tf.data.AUTOTUNE
YAML_PATH = '/content/drive/MyDrive/Computer Vision/options.yml'
LOL_PATH = '/content/drive/MyDrive/Computer Vision/Dataset/train_datasets/lol_dataset'
SID_PATH = '/content/drive/MyDrive/Computer Vision/Dataset/train_datasets/SID dataset'

# **PSNR(Peak Signal-to-Noise Ratio)**

In [29]:
class PSNR(tf.keras.metrics.Metric):
    def __init__(self, name='psnr', dtype=tf.float32, **kwargs):
        super(PSNR, self).__init__(name=name, dtype=dtype, **kwargs)
        self.psnr_sum = self.add_weight(name='psnr_sum', initializer='zeros')
        self.total_samples = self.add_weight(name='total_samples', initializer='zeros')

    def update_state(self, y_true, y_pred, sample_weight=None):
        psnr = tf.image.psnr(y_true, y_pred, max_val=1.0)
        psnr = tf.cast(psnr, self.dtype)

        if sample_weight is not None:
            sample_weight = tf.cast(sample_weight, self.dtype)
            sample_weight = tf.broadcast_to(sample_weight, psnr.shape)
            psnr = tf.multiply(psnr, sample_weight)

        self.psnr_sum.assign_add(tf.reduce_sum(psnr))
        self.total_samples.assign_add(tf.cast(tf.size(psnr), self.dtype))

    def result(self):
        return self.psnr_sum / self.total_samples if self.total_samples != 0.0 else 0.0

    def reset_state(self):
        self.psnr_sum.assign(0)
        self.total_samples.assign(0)

# **Dataloaders**

In [30]:
class BaseDataLoader:
    def __init__(self, root_dir, validation_split, seed, train):
        self.root_dir = root_dir
        self.validation_split = validation_split if train else 0.0
        self.seed = seed
        self.train = train

    def load_dataset(self, dataset, image_size):
        full_dir = os.path.join(
            self.root_dir, dataset, 'train' if self.train else 'test'
        )

        try:
            lq_ds, gt_ds = tfkp.image_dataset_from_directory(
                full_dir,
                labels=None,
                color_mode='rgb',
                batch_size=None,
                image_size=image_size,
                shuffle=False,
                seed=self.seed,
                validation_split=0.5,
                subset='both',
                crop_to_aspect_ratio=True,
            )
        except:
            print(f'No dataset found')
            pass

        return lq_ds, gt_ds

In [31]:
class RawDataLoader(BaseDataLoader):
    def load_dataset(self, dataset, image_size):
        lq_ds, lq_val_ds, _, _ = super().load_dataset(dataset, image_size)
        full_dir = os.path.join(
            self.root_dir, dataset, 'train' if self.train else 'test'
        )

        all_images = []
        for img_path in tqdm(lq_ds.file_paths + lq_val_ds.file_paths):
            img_id = os.path.basename(img_path).split('_')[0]
            try:
                img = tfku.load_img(
                    os.path.join(full_dir, 'target', f'{img_id}_00_10s.png'),
                    target_size = image_size
                )
            except FileNotFoundError:
                img = tfku.load_img(
                    os.path.join(full_dir, 'target', f'{img_id}_00_30s.png'),
                    target_size = image_size
                )

            img_array = tfku.img_to_array(img)
            all_images.append(img_array)

        all_images_np = np.array(all_images)
        num_val_samples = int(self.validation_split * len(all_images_np))
        gt_ds = tf.data.Dataset.from_tensor_slices(all_images_np[:-num_val_samples])
        gt_val_ds = tf.data.Dataset.from_tensor_slices(all_images_np[-num_val_samples:])

        return lq_ds, lq_val_ds, gt_ds, gt_val_ds

In [32]:
def load_datasets(
    root_dir,
    dataset,
    image_size,
    validation_split=0.2,
    seed=None,
    shuffle=True,
    train=True
  ):

    loader = BaseDataLoader(root_dir, validation_split, seed, train)
    lq_ds, gt_ds = loader.load_dataset(dataset, image_size)

    lq_fp = lq_ds.file_paths
    gt_fp = gt_ds.file_paths

    assert lq_fp == [
        path.replace('target', 'input').replace('-gt', '') for path in gt_fp
    ], f'Mismatach in dataset alignment for {dataset}: lq {lq_fp} != gt {gt_fp}'

    train_ds = tf.data.Dataset.zip(lq_ds, gt_ds)

    if train:
        buffer_div = lq_ds.element_spec.shape[1] / 640

        data_size = train_ds.cardinality()
        train_size = round(data_size.numpy() * (1 - validation_split))
        buffer_size = min(data_size.numpy(), data_size.numpy() // buffer_div)

        if shuffle:
            print(f'Shuffling: {dataset} {str(image_size)} dataset')
            train_ds = train_ds.shuffle(
                  buffer_size=buffer_size,
                  seed=seed,
                  reshuffle_each_iteration=True
            )

        val_ds = train_ds.skip(train_size)
        train_ds = train_ds.take(train_size)

        return train_ds, val_ds
    else:
        return train_ds, None

def prepare_train_dataset(
    root_dir,
    save_dir,
    save_ds,
    data_dirs,
    batch_size,
    validation_split=0.2,
    seed=None,
    shuffle=True,
    augment=True
  ):
    train_save_dir = os.path.join(save_dir, 'train')
    val_save_dir = os.path.join(save_dir, 'val')

    if os.path.exists(f'{train_save_dir}/ALL') and os.path.exist(
        f'{val_save_dir}/ALL'
    ):
        print(f'loading cached datasets from {train_save_dir}/ALL')
        train_ds = tf.data.Dataset.load(f'{train_save_dir}/ALL', compression='GZIP')
        val_ds = tf.data.Dataset.load(f'{val_save_dir}/ALL', compression='GZIP')
    else:
        for dataset, image_size in data_dirs.items():
            print(f'processing: {dataset} {str(image_size)} dataset')
            dataset_train_path = os.path.join(train_save_dir, dataset, str(image_size))
            dataset_val_path = os.path.join(val_save_dir, dataset, str(image_size))

            if os.path.exists(dataset_train_path) and os.path.exists(dataset_val_path):
                temp_train_ds = tf.data.Dataset.load(dataset_train_path, compression='GZIP')
                temp_val_ds = tf.data.Dataset.load(dataset_val_path, compression='GZIP')
            else:
                temp_train_ds, temp_val_ds = load_datasets(
                    root_dir,
                    dataset,
                    image_size,
                    validation_split=validation_split,
                    seed=seed,
                    shuffle=shuffle,
                    train=True
                )
                print(f'Batching {dataset} {str(image_size)} dataset')
                temp_train_ds = temp_train_ds.batch(batch_size['train'])
                temp_val_ds = temp_val_ds.batch(batch_size['val'])

                if augment:
                    print(f'augmenting: {dataset} {str(image_size)} dataset')
                    temp_train_ds = apply_augmentation(temp_train_ds, seed=seed)

                if save_ds:
                    print(f'Saving: {dataset} {str(image_size)} dataset')
                    temp_train_ds.save(dataset_train_path, compression='GZIP')
                    temp_val_ds.save(dataset_val_path, compression='GZIP')

            if  'train_ds' in locals():
                train_ds = train_ds.concatenate(temp_train_ds)
            else:
                print(f'Concatenating dataset: {dataset} {str(image_size)}')
                train_ds = temp_train_ds

            if 'val_ds' in locals():
              val_ds = val_ds.concatenate(temp_val_ds)
            else:
              val_ds = temp_val_ds

    return train_ds, val_ds

def prepare_test_dataset(root_dir, save_dir, save_ds, data_dirs, batch_size, seed=None):
    test_save_dir = os.path.join(save_dir, 'test')
    if os.path.exists(f'{test_save_dir}/ALL'):
        print(f'Loading cached datasets from {test_save_dir}/ALL')
        test_ds = tf.data.Dataset.load(f'{test_save_dir}/ALL', compression='GZIP')
    else:
        for dataset, image_size in data_dirs.items():
            dataset_test_path = os.path.join(save_dir, 'test', dataset, str(image_size))

            if os.path.exists(dataset_test_path):
                temp_test_ds = tf.data.Dataset.load(
                    dataset_test_path, compression='GZIP'
                )
            else:
                temp_test_ds, _ = load_datasets(
                    root_dir,
                    dataset,
                    image_size,
                    validation_split=0,
                    seed=seed,
                    train=False,
                )
                print(f'Batching: {dataset} {str(image_size)} dataset')
                temp_test_ds = temp_test_ds.batch(batch_size['val'])

                if save_ds:
                    print(f'Saving: {dataset} {str(image_size)} dataset')
                    temp_test_ds.save(dataset_test_path, compression='GZIP')

            if 'test_ds' in locals():
                print(f'Concatenating dataset: {dataset} {str(image_size)}')
                test_ds = test_ds.concatenate(temp_test_ds)
            else:
                test_ds = temp_test_ds
        if save_ds:
            test_ds.save(os.path.join(save_dir, 'val/ALL'), compression='GZIP')

    return test_ds

def prepare_predict_dataset(directory):
    size = None

    for file in os.listdir(directory):
        full_path = os.path.join(directory, file)
        if os.path.isfile(full_path):
            try:
                with Image.open(full_path) as img:
                    size = img.size
                    break
            except IOError:
                pass

    if size is None:
        raise ValueError('No valid images found in the directory.')

    dataset = tfkp.image_dataset_from_directory(
        directory,
        labels=None,
        color_mode='rgb',
        batch_size=None,
        image_size=size,
        shuffle=False,
    )

    return dataset

class DataLoader:
    def __init__(self, options, seed=None):
        print('Instantiating DataLoader...')
        self.root_dir = options['root_dir']
        self.save_dir = options['save_dir']
        self.data_dirs = options['data_dirs']
        self.validation_split = options['validation_split']
        self.batch_size = options['batch_size']
        self.use_augment = options['use_augment']
        self.use_shuffle = options['use_shuffle']
        self.save_ds = options['save_ds']
        self.seed = seed

    def load_train_data(self):
        train_ds, val_ds = prepare_train_dataset(
            self.root_dir,
            self.save_dir,
            self.save_ds,
            self.data_dirs,
            self.batch_size,
            self.validation_split,
            self.seed,
            self.use_shuffle,
            self.use_augment,
        )

        return (
            train_ds.prefetch(tf.data.AUTOTUNE),
            val_ds.prefetch(tf.data.AUTOTUNE),
        )

    def load_test_data(self):
        test_ds = prepare_test_dataset(
            self.root_dir,
            self.save_dir,
            self.save_ds,
            self.data_dirs,
            self.batch_size,
            self.seed,
        )

        return test_ds.cache().prefetch(tf.data.AUTOTUNE)

    def load_predict_data(self, predict_dir):
        return prepare_predict_dataset(predict_dir)

# **Data Augmentation**

In [33]:
class PairedImageAugumentation(tf.keras.layers.Layer):
    def __init__(self):
        super(PairedImageAugumentation, self).__init__()
        self.h_flip = tf.image.flip_left_right
        self.v_flip = tf.image.flip_up_down

    def call(self, inputs, seed=None):
        lq_img_batch, gt_img_batch = inputs

        h_flip_seed = tf.random.uniform([], seed=seed, minval=0, maxval=2, dtype=tf.int32)
        v_flip_seed = tf.random.uniform([], seed=seed, minval=0, maxval=2, dtype=tf.int32)
        rot_flip_seed = tf.random.uniform([], seed=seed, minval=0, maxval=4, dtype=tf.int32)

        lq_img_batch = tf.cond(h_flip_seed == 1, lambda: self.h_flip(lq_img_batch), lambda: lq_img_batch)
        gt_img_batch = tf.cond(h_flip_seed == 1, lambda: self.h_flip(gt_img_batch), lambda: gt_img_batch)

        lq_img_batch = tf.cond(v_flip_seed == 1, lambda: self.v_flip(lq_img_batch), lambda: lq_img_batch)
        gt_img_batch = tf.cond(v_flip_seed == 1, lambda: self.v_flip(gt_img_batch), lambda: gt_img_batch)

        lq_img_batch = tf.image.rot90(lq_img_batch, k=rot_flip_seed)
        gt_img_batch = tf.image.rot90(gt_img_batch, k=rot_flip_seed)

        return lq_img_batch, gt_img_batch

def apply_augmentation(dataset, seed=None):
  custom_augmentation = PairedImageAugumentation()

  def data_augmentation(lq_img_batch, gt_img_batch, seed=seed):
      normalized_lq_batch = lq_img_batch/255.0
      normalized_gt_batch = gt_img_batch/255.0

      augmented_lq_batch, augmented_gt_batch = custom_augmentation(
          lq_img_batch, gt_img_batch, seed=seed
      )
      return augmented_lq_batch, augmented_gt_batch

  return dataset.map(
      lambda lq_img_batch, gt_img_batch: data_augmentation(
          lq_img_batch, gt_img_batch, seed=seed
      ),
      num_parallel_calls=AUTO
  )

# **Model**

In [34]:
class CosineDecayCycleRestarts(tfko.schedules.LearningRateSchedule):
    def __init__(
        self,
        base_lr,
        first_decay_steps,
        t_mul=2.0,
        m_mul=1.0,
        alpha=[0.0, 0.0],
        name=None,
    ):
        super().__init__()
        self.base_lr = base_lr
        self.first_decay_steps = first_decay_steps
        self._t_mul = t_mul
        self._m_mul = m_mul
        self.alpha = alpha
        self.name = name

    def __call__(self, step):
        with tf.name_scope(self.name or 'SGDRDecay') as name:
            base_lr = tf.convert_to_tensor(self.base_lr, name='base_lr')
            dtype = base_lr.dtype
            first_decay_steps = tf.cast(self.first_decay_steps, dtype)
            t_mul = tf.cast(self._t_mul, dtype)
            m_mul = tf.cast(self._m_mul, dtype)

            global_step_recomp = tf.cast(step, dtype)
            completed_fraction = global_step_recomp / first_decay_steps

            alpha = tf.cond(
                tf.greater(completed_fraction, 1.0),
                lambda: tf.cast(self.alpha[1], dtype),
                lambda: tf.cast(self.alpha[0], dtype),
            )

            def compute_step(completed_fraction, geometric=False):
                if geometric:
                    i_restart = tf.floor(
                        tf.math.log(1.0 - completed_fraction * (1.0 - t_mul))
                        / tf.math.log(t_mul)
                    )

                    sum_r = (1.0 - t_mul**i_restart) / (1.0 - t_mul)
                    completed_fraction = (
                        completed_fraction - sum_r
                    ) / t_mul**i_restart

                else:
                    i_restart = tf.floor(completed_fraction)
                    completed_fraction -= i_restart

                return i_restart, completed_fraction

            i_restart, completed_fraction = tf.cond(
                tf.equal(t_mul, 1.0),
                lambda: compute_step(completed_fraction, geometric=False),
                lambda: compute_step(completed_fraction, geometric=True),
            )

            m_fac = m_mul**i_restart
            cosine_decayed = (
                0.5
                * m_fac
                * (1.0 + tf.cos(tf.constant(math.pi, dtype=dtype) * completed_fraction))
            )
            decayed = (1 - alpha) * cosine_decayed + alpha

            return tf.multiply(base_lr, decayed, name=name)

    def get_config(self):
        return {
            'base_lr': self.base_lr,
            'first_decay_steps': self.first_decay_steps,
            't_mul': self._t_mul,
            'm_mul': self._m_mul,
            'alpha': self.alpha,
            'name': self.name,
        }

In [35]:
def dim_swap(tensor, order=[]):
    if len(order) != 2:
        raise ValueError('Order list must have exactly two elements.')

    ndims = len(tensor.shape)

    order = [d if d >= 0 else ndims + d for d in order]

    if order[0] >= ndims or order[1] >= ndims or order[0] < 0 or order[1] < 0:
        raise IndexError('Order indices are out of range for the tensor dimensions.')

    perm = list(range(ndims))
    perm[order[0]], perm[order[1]] = perm[order[1]], perm[order[0]]

    return tf.transpose(tensor, perm=perm)


def flatten(input_tensor, start_dim, end_dim):
    shape = tf.shape(input_tensor)
    slice_numel = tf.reduce_prod(shape[start_dim : end_dim + 1])
    new_shape = tf.concat(
        [
            shape[:start_dim],
            [slice_numel],
            shape[end_dim + 1 :],
        ],
        axis=0,
    )

    return tf.reshape(input_tensor, new_shape)

class PreNorm(tfkl.Layer):
    def __init__(self, fn):
        super(PreNorm, self).__init__()
        self.fn = fn
        self.norm = tfkl.LayerNormalization(axis=-1, epsilon=1e-6)

    def call(self, x):
        x = self.norm(x)
        return self.fn(x)

class Illumination_Estimator(tfkl.Layer):
    def __init__(self, n_fea_middle, n_fea_out=3):
        super(Illumination_Estimator, self).__init__()

        self.conv1 = tfkl.Conv2D(n_fea_middle, kernel_size=1, use_bias=True)
        self.depth_conv = tfkl.DepthwiseConv2D(
            kernel_size=5, padding='same', use_bias=True
        )

        self.conv2 = tfkl.Conv2D(n_fea_out, kernel_size=1, use_bias=True)

    def call(self, img):
        mean_c = tf.expand_dims(tf.reduce_mean(img, axis=3), axis=3)
        input = tf.concat([img, mean_c], axis=3)

        x_1 = self.conv1(input)
        illu_fea = self.depth_conv(x_1)
        illu_map = self.conv2(illu_fea)
        return illu_fea, illu_map


class IG_MSA(tfkm.Model):
    def __init__(
        self,
        dim,
        dim_head=40,
        heads=8,
    ):
        super().__init__()
        self.num_heads = heads
        self.dim_head = dim_head
        self.to_q = tfkl.Dense(
            dim_head * heads,
            use_bias=False,
            kernel_initializer=tfki.TruncatedNormal(stddev=0.02),
        )
        self.to_k = tfkl.Dense(
            dim_head * heads,
            use_bias=False,
            kernel_initializer=tfki.TruncatedNormal(stddev=0.02),
        )
        self.to_v = tfkl.Dense(
            dim_head * heads,
            use_bias=False,
            kernel_initializer=tfki.TruncatedNormal(stddev=0.02),
        )
        self.rescale = tf.Variable(tf.ones([heads, 1, 1]))
        self.proj = tfkl.Dense(
            dim, use_bias=True, kernel_initializer=tfki.TruncatedNormal(stddev=0.02)
        )
        self.pos_emb = Sequential(
            [
                tfkl.DepthwiseConv2D(
                    kernel_size=3,
                    strides=1,
                    padding='same',
                    use_bias=True,
                    activation='gelu',
                ),
                tfkl.DepthwiseConv2D(kernel_size=3, padding='same', use_bias=True),
            ]
        )

        self.dim = dim

    def call(self, x_in, illu_fea_trans):
        b, h, w, c = (
            tf.shape(x_in)[0],
            tf.shape(x_in)[1],
            tf.shape(x_in)[2],
            tf.shape(x_in)[3],
        )
        x = tf.reshape(x_in, [b, h * w, c])
        q_inp = self.to_q(x)
        k_inp = self.to_k(x)
        v_inp = self.to_v(x)
        illu_attn = illu_fea_trans
        illu_attn_flat = flatten(illu_attn, 1, 2)
        q, k, v, illu_attn = map(
            lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.num_heads),
            (q_inp, k_inp, v_inp, illu_attn_flat),
        )
        v = v * illu_attn
        q = dim_swap(q, [-2, -1])
        k = dim_swap(k, [-2, -1])
        v = dim_swap(v, [-2, -1])
        q = tf.nn.l2_normalize(q)
        k = tf.nn.l2_normalize(k)
        attn = tf.matmul(k, dim_swap(q, [-2, -1]))
        attn = attn * self.rescale
        attn = tf.nn.softmax(attn)
        x = tf.matmul(attn, v)
        x = tf.transpose(x, [0, 3, 1, 2])
        x = tf.reshape(x, [b, h * w, self.num_heads * self.dim_head])
        out_c = tf.reshape(self.proj(x), [b, h, w, c])
        out_p = self.pos_emb(tf.reshape(v_inp, [b, h, w, c]))
        out = out_c + out_p

        return out

class FeedForward(tfkm.Model):
    def __init__(self):
        super().__init__()
        self.net = Sequential(
            [
                tfkl.DepthwiseConv2D(
                    kernel_size=1,
                    strides=1,
                    use_bias=False,
                    activation='gelu',
                ),
                tfkl.DepthwiseConv2D(
                    kernel_size=3,
                    strides=1,
                    padding='same',
                    use_bias=False,
                    activation='gelu',
                ),
                tfkl.DepthwiseConv2D(kernel_size=1, strides=1, use_bias=False),
            ]
        )

    def call(self, x):
        out = self.net(x)
        return out

class IGAB(tfkl.Layer):
    def __init__(self, dim, dim_head=40, heads=8, num_blocks=2):
        super().__init__()
        self.blocks = []
        for _ in range(num_blocks):
            self.blocks.append(
                [
                    IG_MSA(dim=dim, dim_head=dim_head, heads=heads),
                    PreNorm(fn=FeedForward()),
                ]
            )

    def call(self, x, illu_fea):
        for attn, ff in self.blocks:
            x = attn(x, illu_fea_trans=illu_fea) + x
            x = ff(x) + x
        out = x
        return out

class Corruption_Restorer(tfkl.Layer):
    def __init__(self, out_dim=3, dim=40, level=2, num_blocks=[1, 2, 2]):
        super(Corruption_Restorer, self).__init__()
        self.dim = dim
        self.level = level

        # Input projection
        self.embedding = tfkl.Conv2D(
            self.dim, kernel_size=3, strides=1, padding='same', use_bias=False
        )

        # Encoder
        self.encoder_layers = []
        dim_level = dim
        for i in range(level):
            self.encoder_layers.append(
                [
                    IGAB(
                        dim=dim_level,
                        num_blocks=num_blocks[i],
                        dim_head=dim,
                        heads=dim_level // dim,
                    ),
                    tfkl.Conv2D(
                        dim_level * 2,
                        kernel_size=4,
                        strides=2,
                        padding='same',
                        use_bias=False,
                    ),
                    tfkl.Conv2D(
                        dim_level * 2,
                        kernel_size=4,
                        strides=2,
                        padding='same',
                        use_bias=False,
                    ),
                ]
            )
            dim_level *= 2

        # Bottleneck
        self.bottleneck = IGAB(
            dim=dim_level,
            dim_head=dim,
            heads=dim_level // dim,
            num_blocks=num_blocks[-1],
        )

        # Decoder
        self.decoder_layers = []
        for i in range(level):
            self.decoder_layers.append(
                [
                    tfkl.Conv2DTranspose(dim_level // 2, kernel_size=2, strides=2),
                    tfkl.Conv2D(
                        dim_level // 2, kernel_size=1, strides=1, use_bias=False
                    ),
                    IGAB(
                        dim=dim_level // 2,
                        num_blocks=num_blocks[level - 1 - i],
                        dim_head=dim,
                        heads=(dim_level // 2) // dim,
                    ),
                ]
            )
            dim_level //= 2

        # Output projection
        self.mapping = tfkl.Conv2D(
            out_dim, kernel_size=3, strides=1, padding='same', use_bias=False
        )

    def call(self, x, illu_fea):
        # Embedding
        fea = self.embedding(x)

        # Encoder
        fea_encoder = []
        illu_fea_list = []
        for IGAB, FeaDownSample, IlluFeaDownsample in self.encoder_layers:
            fea = IGAB(fea, illu_fea)
            illu_fea_list.append(illu_fea)
            fea_encoder.append(fea)
            fea = FeaDownSample(fea)
            illu_fea = IlluFeaDownsample(illu_fea)

        # Bottleneck
        fea = self.bottleneck(fea, illu_fea)

        # Decoder
        for i, (FeaUpSample, Fution, LeWinBlcok) in enumerate(self.decoder_layers):
            fea = FeaUpSample(fea)
            fea = Fution(
                tfkl.concatenate([fea, fea_encoder[self.level - 1 - i]], axis=-1)
            )
            illu_fea = illu_fea_list[self.level - 1 - i]
            fea = LeWinBlcok(fea, illu_fea)

        # Mapping
        out = self.mapping(fea) + x
        return out

class RetinexFormer_Single_Stage(tfkl.Layer):
    def __init__(self, out_channels=3, n_feat=40, level=2, num_blocks=[1, 2, 2]):
        super(RetinexFormer_Single_Stage, self).__init__()
        self.estimator = Illumination_Estimator(n_feat)
        self.denoiser = Corruption_Restorer(
            out_dim=out_channels,
            dim=n_feat,
            level=level,
            num_blocks=num_blocks,
        )

    def call(self, img):
        illu_fea, illu_map = self.estimator(img)
        input_img = img * illu_map + img
        output_img = self.denoiser(input_img, illu_fea)

        return output_img

class RetinexFormer(tfkm.Model):
    def __init__(self, out_channels=3, n_feat=40, stage=1, num_blocks=[1, 2, 2]):
        super(RetinexFormer, self).__init__()
        self.stage = stage

        self.body = Sequential(
            [
                RetinexFormer_Single_Stage(
                    out_channels=out_channels,
                    n_feat=n_feat,
                    level=2,
                    num_blocks=num_blocks,
                )
                for _ in range(stage)
            ]
        )

    def call(self, x):
        out = self.body(x)
        return out

In [36]:
class PrintLR(tfkc.Callback):
    def on_epoch_end(self, epoch, logs=None):
        lr = self.model.optimizer._decayed_lr(tf.float32).numpy()
        print(f'Learning rate for epoch {epoch + 1} is {lr:.6f}')

class RetinexFormerModel:
    def __init__(self, options):
        self.options = options
        self.seed = self.options['manual_seed']
        self.checkpoint_dir = (
            self.options['checkpoint_dir']
            if 'checkpoint_dir' in self.options
            else './model/training_checkpoints'
        )
        self.checkpoint_prefix = os.path.join(self.checkpoint_dir, 'ckpt_{epoch:04d}')
        self.initial_epoch = 0
        self.logs_dir = self.options['logs_dir']
        self.dataset_options = self.options['dataset']
        self.model_options = self.options['model']
        self.training_options = self.options['training']
        self.data_loader = DataLoader(self.dataset_options, self.seed)
        self.model = self.create_model()

    def create_model(self):
        tfku.set_random_seed(self.seed)
        model = RetinexFormer(**self.model_options)
        return model

    def compile_model(self):
        first_decay_steps = (
            self.training_options['scheduler']['periods'][0]
            // self.dataset_options['batch_size']
        )
        base_lr = self.training_options['optimizer']['lr']
        t_mul = (
            self.training_options['scheduler']['periods'][1]
            / self.training_options['scheduler']['periods'][0]
        )
        m_mul = self.training_options['scheduler']['m_mul']
        alpha = [
            alpha / base_lr for alpha in self.training_options['scheduler']['lr_mins']
        ]
        clipnorm = 0.01 if self.training_options['optimizer']['clipnorm'] else None
        beta_1 = self.training_options['optimizer']['betas'][0]
        beta_2 = self.training_options['optimizer']['betas'][1]

        learning_rate_schedule = CosineDecayCycleRestarts(
            first_decay_steps=first_decay_steps,
            base_lr=base_lr,
            t_mul=t_mul,
            m_mul=m_mul,
            alpha=alpha,
        )
        self.model.compile(
            optimizer=tfko.Adam(
                learning_rate=learning_rate_schedule,
                beta_1=beta_1,
                beta_2=beta_2,
                global_clipnorm=clipnorm,
            ),
            loss='mae',
            metrics=[f'accuracy {PSNR()}'],
        )

    def load_weights(self):
        latest_checkpoint = tf.train.latest_checkpoint(self.checkpoint_dir)

        if latest_checkpoint:
            print(f'Loading weights from {latest_checkpoint}')
            self.model.load_weights(latest_checkpoint)
            checkpoint_name = os.path.basename(latest_checkpoint)
            self.initial_epoch = int(checkpoint_name.split('_')[-1])

    def train(self):
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        os.makedirs(self.logs_dir, exist_ok=True)

        train_ds, val_ds = self.data_loader.load_train_data()
        epochs = max(
            1,
            self.training_options['total_iter']
            // (train_ds.cardinality().numpy() * self.dataset_options['batch_size']),
        )

        callbacks = [
            tfkc.ModelCheckpoint(
                filepath=self.checkpoint_prefix, verbose=1, save_weights_only=True
            ),
            tfkc.TensorBoard(
                log_dir=self.logs_dir, histogram_freq=1, profile_batch='500,520'
            ),
            tfkc.CSVLogger(os.path.join(self.logs_dir, 'training.log')),
            PrintLR(),
        ]

        self.model.fit(
            train_ds,
            epochs=epochs,
            initial_epoch=self.initial_epoch,
            verbose='auto',
            validation_data=val_ds,
            shuffle=True,
            callbacks=callbacks,
        )

    def evaluate(self):
        test_ds = self.data_loader.load_test_data()

        callbacks = [
            tfkc.TensorBoard(log_dir=self.logs_dir, histogram_freq=1),
            tfkc.CSVLogger(os.path.join(self.logs_dir, 'test.log')),
        ]

        return self.model.evaluate(
            test_ds,
            callbacks=callbacks,
        )

    def predict(self, data):
        predict_data = self.data_loader.load_predict_data(data)
        return self.model.predict(predict_data)

# **Run**

In [40]:
def ordered_yaml():
    try:
        from yaml import CDumper as Dumper
        from yaml import CLoader as Loader
    except ImportError:
        from yaml import Dumper, Loader

    _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG

    def dict_representer(dumper, data):
        return dumper.represent_dict(data.items())

    def dict_constructor(loader, node):
        return OrderedDict(loader.construct_pairs(node))

    Dumper.add_representer(OrderedDict, dict_representer)
    Loader.add_constructor(_mapping_tag, dict_constructor)
    return Loader, Dumper

def parse_yaml(opt_path):
    with open(opt_path, mode='r') as f:
        Loader, _ = ordered_yaml()
        opt = yaml.load(f, Loader=Loader)

    return opt

def parse_options():
    parser = argparse.ArgumentParser(description='Run RetinexFormer model')
    parser.add_argument(
        '-m',
        '--mode',
        metavar='\b',
        choices=['train', 'test', 'predict'],
        required=True,
        help='Select a mode to run the model in ["train", "test", "predict"]',
    )
    parser.add_argument(
        '-o',
        '--opt',
        metavar='\b',
        type=str,
        default=YAML_PATH,
    )
    parser.add_argument(
        '-e',
        '--enhance',
        metavar='\b',
        type=str,
        default=None,
        help='Path to image file(s) to enhance',
    )

    args = parser.parse_args()

    with open(args.opt, 'r') as stream:
        try:
            opt = yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            print(exc)

    for arg in vars(args):
        opt[arg] = getattr(args, arg)

    seed = opt.get('manual_seed')
    if seed is None:
        seed = random.randint(1, 10000)
        opt['manual_seed'] = seed

    return opt

In [41]:
options = parse_options()

model = RetinexFormerModel(options)

if options['mode'] == 'train':
    model.compile_model()
    model.load_weights()
    model.train()
elif options['mode'] == 'test':
    model.load_weights()
    model.evaluate()
elif options['mode'] == 'predict':
    model.load_weights()
    model.predict(options.enhance)

usage: colab_kernel_launcher.py [-h] -m  [-o ] [-e ]
colab_kernel_launcher.py: error: the following arguments are required: -m/--mode


SystemExit: 2