<a href="https://colab.research.google.com/github/PedrolyraC/Campo-Minado/blob/main/learning_to_see_in_the_dark.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import math
import os

import numpy as np
import PIL.Image as Image
import tensorflow as tf
import tensorflow.keras.callbacks as tfkc
import tensorflow.keras.initializers as tkfi
import tensorflow.keras.layers as tfkl
import tensorflow.keras.optimizers as tfko
import tensorflow.keras.preprocessing as tfkp
import tensorflow.keras.utils as tfku
from einops import rearrange
from tensorflow.keras import Sequential

# **PSNR (Peak Signal-to-Noise Ratio)**

In [None]:
class PSNR(tf.keras.metrics.Metric):
    """Streaming mean PSNR (Peak Signal-to-Noise Ratio) metric.

    Each update compares ``y_true`` and ``y_pred`` with ``tf.image.psnr``
    using ``max_val=1.0``, so both are expected to be scaled to ``[0, 1]``.
    ``result()`` is the (optionally weighted) mean of all per-image PSNR
    values accumulated since the last ``reset_state()``.
    """

    def __init__(self, name='psnr', dtype=tf.float32, **kwargs):
        super(PSNR, self).__init__(name=name, dtype=dtype, **kwargs)
        # Running sum of (weighted) per-image PSNR values.
        self.psnr_sum = self.add_weight(name='psnr_sum', initializer='zeros')
        # Running sum of weights; equals the image count when unweighted.
        self.total_samples = self.add_weight(name='total_samples', initializer='zeros')

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate PSNR for a batch of image pairs.

        Args:
            y_true: Ground-truth images, expected in ``[0, 1]``.
            y_pred: Predicted images with the same shape as ``y_true``.
            sample_weight: Optional weights broadcastable to the per-image
                PSNR tensor returned by ``tf.image.psnr``.
        """
        psnr = tf.cast(tf.image.psnr(y_true, y_pred, max_val=1.0), self.dtype)

        if sample_weight is None:
            weights = tf.ones_like(psnr)
        else:
            # Use the dynamic shape so this also works when the batch
            # dimension is unknown at trace time (psnr.shape may hold None).
            weights = tf.broadcast_to(tf.cast(sample_weight, self.dtype), tf.shape(psnr))

        self.psnr_sum.assign_add(tf.reduce_sum(psnr * weights))
        # Accumulate the total weight (not the element count) so weighted
        # updates yield a true weighted mean.
        self.total_samples.assign_add(tf.reduce_sum(weights))

    def result(self):
        """Return the mean PSNR, or 0 when nothing has been accumulated."""
        # A Python `if` on a tensor fails inside tf.function; divide_no_nan
        # returns 0 for a zero denominator, matching the old fallback.
        return tf.math.divide_no_nan(self.psnr_sum, self.total_samples)

    def reset_state(self):
        """Clear the accumulated sums."""
        self.psnr_sum.assign(0.0)
        self.total_samples.assign(0.0)

# **Dataloaders**

In [None]:
class BaseDataLoader:
    """Load low-quality / ground-truth image halves from one dataset folder.

    Args:
        root_dir: Directory containing one sub-directory per dataset.
        validation_split: Validation fraction; forced to 0.0 when ``train``
            is False. Stored for callers/subclasses; ``load_dataset`` below
            does not read it.
        seed: Seed forwarded to ``image_dataset_from_directory``.
        train: Read the ``train`` split when True, ``test`` otherwise.
    """

    def __init__(self, root_dir, validation_split, seed, train):
        self.root_dir = root_dir
        self.validation_split = validation_split if train else 0.0
        self.seed = seed
        self.train = train

    def load_dataset(self, dataset, image_size):
        """Return ``(lq_ds, gt_ds)`` for ``<root_dir>/<dataset>/<train|test>``.

        Files are read unshuffled and cut in half with
        ``validation_split=0.5, subset='both'``: the first half is returned
        as ``lq_ds`` and the second half as ``gt_ds``.
        NOTE(review): this relies on the low-quality sub-directory sorting
        before the ground-truth one (e.g. ``input`` < ``target``) and on both
        holding the same number of images; ``load_datasets`` verifies the
        resulting path alignment.

        Raises:
            ValueError: If no images could be loaded from the directory.
        """
        full_dir = os.path.join(
            self.root_dir, dataset, 'train' if self.train else 'test'
        )

        try:
            lq_ds, gt_ds = tfkp.image_dataset_from_directory(
                full_dir,
                labels=None,
                color_mode='rgb',
                batch_size=None,
                image_size=image_size,
                shuffle=False,
                seed=self.seed,
                validation_split=0.5,
                subset='both',
                crop_to_aspect_ratio=True,
            )
        except (ValueError, OSError, tf.errors.OpError) as err:
            # Previously this printed and fell through to an
            # UnboundLocalError on the return; surface the real cause instead.
            raise ValueError(f'No dataset found in {full_dir}') from err

        return lq_ds, gt_ds

In [None]:
class RawDataLoader(BaseDataLoader):
    """Loader whose ground truth is looked up per image ID.

    For every low-quality file the ID is the file-name prefix before the
    first ``_``; the ground truth is ``<split>/target/<id>_00_10s.png``,
    falling back to ``<id>_00_30s.png``.
    """

    # Ground-truth file suffixes tried in order for each image ID.
    _GT_SUFFIXES = ('_00_10s.png', '_00_30s.png')

    def load_dataset(self, dataset, image_size):
        """Return ``(lq_ds, lq_val_ds, gt_ds, gt_val_ds)``.

        ``lq_ds`` / ``lq_val_ds`` are the two datasets produced by
        ``BaseDataLoader.load_dataset``. NOTE(review): that method splits
        the directory 50/50; here the halves are treated as the train and
        validation low-quality sets — confirm this matches the raw layout.
        The ground-truth arrays are split at the same index so that every
        ``gt`` element lines up with its ``lq`` counterpart.

        Raises:
            FileNotFoundError: If no ground-truth file exists for an ID.
        """
        lq_ds, lq_val_ds = super().load_dataset(dataset, image_size)
        full_dir = os.path.join(
            self.root_dir, dataset, 'train' if self.train else 'test'
        )

        lq_paths = list(lq_ds.file_paths) + list(lq_val_ds.file_paths)
        gt_images = np.array([
            self._load_target(full_dir, path, image_size) for path in lq_paths
        ])

        # Split by the lq train length rather than `[:-num_val]`, which
        # returned an empty train set whenever the validation count was 0.
        num_train = len(lq_ds.file_paths)
        gt_ds = tf.data.Dataset.from_tensor_slices(gt_images[:num_train])
        gt_val_ds = tf.data.Dataset.from_tensor_slices(gt_images[num_train:])

        return lq_ds, lq_val_ds, gt_ds, gt_val_ds

    def _load_target(self, full_dir, img_path, image_size):
        """Load the ground-truth image matching ``img_path`` as an array."""
        img_id = os.path.basename(img_path).split('_')[0]
        for suffix in self._GT_SUFFIXES:
            candidate = os.path.join(full_dir, 'target', f'{img_id}{suffix}')
            if os.path.exists(candidate):
                img = tfku.load_img(candidate, target_size=image_size)
                return tfku.img_to_array(img)
        raise FileNotFoundError(
            f'No ground-truth image for id {img_id!r} in {full_dir}/target'
        )

In [None]:
def load_datasets(
    root_dir,
    dataset,
    image_size,
    validation_split=0.2,
    seed=None,
    shuffle=True,
    train=True
):
    """Build a zipped ``(lq, gt)`` dataset and optionally split off validation.

    Args:
        root_dir: Directory containing one sub-directory per dataset.
        dataset: Dataset sub-directory name.
        image_size: Size passed to the image loader.
        validation_split: Fraction of pairs held out for validation.
        seed: Shuffle seed.
        shuffle: Shuffle once before splitting (train only).
        train: Load the ``train`` split and divide it; otherwise load ``test``.

    Returns:
        ``(train_ds, val_ds)`` when ``train`` is True, else ``(ds, None)``
        where ``ds`` holds every pair.

    Raises:
        ValueError: If the lq and gt file lists are not aligned.
    """
    loader = BaseDataLoader(root_dir, validation_split, seed, train)
    lq_ds, gt_ds = loader.load_dataset(dataset, image_size)

    lq_fp = lq_ds.file_paths
    gt_fp = gt_ds.file_paths
    expected_lq_fp = [
        path.replace('target', 'input').replace('-gt', '') for path in gt_fp
    ]
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if lq_fp != expected_lq_fp:
        raise ValueError(
            f'Mismatch in dataset alignment for {dataset}: lq {lq_fp} != gt {gt_fp}'
        )

    # Tuple form works on every TF 2.x; the varargs form needs TF >= 2.13.
    pairs_ds = tf.data.Dataset.zip((lq_ds, gt_ds))

    if not train:
        return pairs_ds, None

    data_size = int(pairs_ds.cardinality().numpy())
    train_size = round(data_size * (1 - validation_split))

    if shuffle:
        # Shrink the buffer as dimension 1 of the element shape grows past
        # 640 to bound memory. NOTE(review): with unbatched images this is
        # presumably the width — confirm.
        buffer_div = lq_ds.element_spec.shape[1] / 640
        # Dataset.shuffle needs an integer buffer of at least 1; the old
        # float value broke whenever buffer_div > 1.
        buffer_size = max(1, int(min(data_size, data_size // buffer_div)))

        print(f'Shuffling: {dataset} {str(image_size)} dataset')
        # reshuffle_each_iteration must be False here: take/skip below would
        # otherwise see a different order every epoch, producing a new
        # train/val split each time and leaking validation pairs into training.
        pairs_ds = pairs_ds.shuffle(
            buffer_size=buffer_size,
            seed=seed,
            reshuffle_each_iteration=False
        )

    val_ds = pairs_ds.skip(train_size)
    train_ds = pairs_ds.take(train_size)

    return train_ds, val_ds

def prepare_train_dataset(
    root_dir,
    save_dir,
    save_ds,
    date_dirs,
    batch_size,
    validation_split=0.2,
    seed=None,
    shuffle=True,
    augment=True
):
    """Build (or load from cache) batched training and validation datasets.

    If both ``<save_dir>/train/ALL`` and ``<save_dir>/val/ALL`` exist they are
    loaded and returned. Otherwise every ``(dataset, image_size)`` entry of
    ``date_dirs`` is loaded from its per-dataset cache when present, or built
    with ``load_datasets``, batched, optionally augmented and optionally
    saved. The per-dataset results are concatenated in mapping order.

    Args:
        root_dir: Directory containing the raw dataset folders.
        save_dir: Directory holding GZIP-compressed dataset caches.
        save_ds: Save freshly built per-dataset splits.
        date_dirs: Mapping of dataset name -> image size (parameter name
            kept as-is for backward compatibility).
        batch_size: Mapping with ``'train'`` and ``'val'`` batch sizes.
        validation_split: Forwarded to ``load_datasets``.
        seed: Forwarded to ``load_datasets`` and ``apply_augmentation``.
        shuffle: Forwarded to ``load_datasets``.
        augment: Apply ``apply_augmentation`` to the training batches.

    Returns:
        ``(train_ds, val_ds)``.

    Raises:
        ValueError: If no combined cache exists and ``date_dirs`` is empty.
    """
    train_save_dir = os.path.join(save_dir, 'train')
    val_save_dir = os.path.join(save_dir, 'val')
    train_all_dir = f'{train_save_dir}/ALL'
    val_all_dir = f'{val_save_dir}/ALL'

    if os.path.exists(train_all_dir) and os.path.exists(val_all_dir):
        print(f'loading cached datasets from {train_all_dir}')
        train_ds = tf.data.Dataset.load(train_all_dir, compression='GZIP')
        val_ds = tf.data.Dataset.load(val_all_dir, compression='GZIP')
        return train_ds, val_ds

    if not date_dirs:
        raise ValueError('No datasets configured: date_dirs is empty')

    train_ds = None
    val_ds = None
    for dataset, image_size in date_dirs.items():
        print(f'processing: {dataset} {str(image_size)} dataset')
        dataset_train_path = os.path.join(train_save_dir, dataset, str(image_size))
        dataset_val_path = os.path.join(val_save_dir, dataset, str(image_size))

        if os.path.exists(dataset_train_path) and os.path.exists(dataset_val_path):
            temp_train_ds = tf.data.Dataset.load(dataset_train_path, compression='GZIP')
            temp_val_ds = tf.data.Dataset.load(dataset_val_path, compression='GZIP')
        else:
            temp_train_ds, temp_val_ds = load_datasets(
                root_dir,
                dataset,
                image_size,
                validation_split=validation_split,
                seed=seed,
                shuffle=shuffle,
                train=True
            )
            print(f'Batching {dataset} {str(image_size)} dataset')
            temp_train_ds = temp_train_ds.batch(batch_size['train'])
            temp_val_ds = temp_val_ds.batch(batch_size['val'])

            if augment:
                print(f'augmenting: {dataset} {str(image_size)} dataset')
                temp_train_ds = apply_augmentation(temp_train_ds, seed=seed)

            if save_ds:
                # NOTE(review): saving after augmentation stores one fixed
                # random draw; cached splits loaded above are not re-augmented.
                print(f'Saving: {dataset} {str(image_size)} dataset')
                temp_train_ds.save(dataset_train_path, compression='GZIP')
                temp_val_ds.save(dataset_val_path, compression='GZIP')

        if train_ds is None:
            train_ds, val_ds = temp_train_ds, temp_val_ds
        else:
            print(f'Concatenating dataset: {dataset} {str(image_size)}')
            train_ds = train_ds.concatenate(temp_train_ds)
            val_ds = val_ds.concatenate(temp_val_ds)

    return train_ds, val_ds

def prepare_test_dataset(root_dir, save_dir, save_ds, data_dirs, batch_size, seed=None):
    """Build (or load from cache) the batched test dataset.

    If ``<save_dir>/test/ALL`` exists it is loaded and returned. Otherwise
    each ``(dataset, image_size)`` entry of ``data_dirs`` is loaded from its
    per-dataset cache when present, or built with ``load_datasets``
    (``train=False``), batched with ``batch_size['val']`` and optionally
    saved. The results are concatenated in mapping order. When ``save_ds``
    is set, the concatenation is also saved to ``<save_dir>/test/ALL``.

    Raises:
        ValueError: If no combined cache exists and ``data_dirs`` is empty.
    """
    test_save_dir = os.path.join(save_dir, 'test')
    test_all_dir = f'{test_save_dir}/ALL'

    if os.path.exists(test_all_dir):
        print(f'Loading cached datasets from {test_all_dir}')
        return tf.data.Dataset.load(test_all_dir, compression='GZIP')

    if not data_dirs:
        raise ValueError('No datasets configured: data_dirs is empty')

    test_ds = None
    for dataset, image_size in data_dirs.items():
        dataset_test_path = os.path.join(test_save_dir, dataset, str(image_size))

        if os.path.exists(dataset_test_path):
            temp_test_ds = tf.data.Dataset.load(
                dataset_test_path, compression='GZIP'
            )
        else:
            temp_test_ds, _ = load_datasets(
                root_dir,
                dataset,
                image_size,
                validation_split=0,
                seed=seed,
                train=False,
            )
            print(f'Batching: {dataset} {str(image_size)} dataset')
            temp_test_ds = temp_test_ds.batch(batch_size['val'])

            if save_ds:
                print(f'Saving: {dataset} {str(image_size)} dataset')
                temp_test_ds.save(dataset_test_path, compression='GZIP')

        if test_ds is None:
            test_ds = temp_test_ds
        else:
            print(f'Concatenating dataset: {dataset} {str(image_size)}')
            test_ds = test_ds.concatenate(temp_test_ds)

    if save_ds:
        # Save under test/ALL, the path checked above. The previous code
        # wrote to val/ALL, overwriting the validation cache with test data
        # and never producing a test cache.
        test_ds.save(test_all_dir, compression='GZIP')

    return test_ds

def prepare_predict_dataset(directory):
    """Build an unbatched, unshuffled RGB dataset of the images in ``directory``.

    Every image is resized to the size of the first file directly inside
    ``directory`` (checked in sorted name order) that PIL can open.

    Raises:
        ValueError: If no file directly inside ``directory`` opens as an image.
    """
    size = None

    # Sorted so the reference image (and thus the output size) is deterministic.
    for file in sorted(os.listdir(directory)):
        full_path = os.path.join(directory, file)
        if not os.path.isfile(full_path):
            continue
        try:
            with Image.open(full_path) as img:
                # PIL reports (width, height); Keras image_size is (height, width).
                size = (img.height, img.width)
        except OSError:
            # Not a readable image (e.g. a sidecar file); try the next one.
            continue
        break

    if size is None:
        raise ValueError('No valid images found in the directory.')

    dataset = tfkp.image_dataset_from_directory(
        directory,
        labels=None,
        color_mode='rgb',
        batch_size=None,
        image_size=size,
        shuffle=False,
    )

    return dataset


class DataLoader:
    """Facade that builds train/val, test and prediction datasets from options.

    Args:
        options: Mapping that must provide ``root_dir``, ``save_dir``,
            ``data_dirs``, ``validation_split``, ``batch_size``,
            ``use_augment``, ``use_shuffle`` and ``save_ds``; each value is
            stored on the instance under the same name.
        seed: Optional seed forwarded to the dataset builders.
    """

    # Option keys copied onto the instance, in the order they are looked up.
    _OPTION_KEYS = (
        'root_dir',
        'save_dir',
        'data_dirs',
        'validation_split',
        'batch_size',
        'use_augment',
        'use_shuffle',
        'save_ds',
    )

    def __init__(self, options, seed=None):
        print('Instantiating DataLoader...')
        for key in self._OPTION_KEYS:
            setattr(self, key, options[key])
        self.seed = seed

    def load_train_data(self):
        """Return ``(train_ds, val_ds)``, each with ``prefetch(AUTOTUNE)`` applied."""
        train_ds, val_ds = prepare_train_dataset(
            self.root_dir,
            self.save_dir,
            self.save_ds,
            self.data_dirs,
            self.batch_size,
            self.validation_split,
            self.seed,
            self.use_shuffle,
            self.use_augment,
        )
        return tuple(ds.prefetch(tf.data.AUTOTUNE) for ds in (train_ds, val_ds))

    def load_test_data(self):
        """Return the test dataset, cached and prefetched."""
        test_ds = prepare_test_dataset(
            self.root_dir,
            self.save_dir,
            self.save_ds,
            self.data_dirs,
            self.batch_size,
            self.seed,
        )
        cached_ds = test_ds.cache()
        return cached_ds.prefetch(tf.data.AUTOTUNE)

    def load_predict_data(self, predict_dir):
        """Return the prediction dataset for the images in ``predict_dir``."""
        predict_ds = prepare_predict_dataset(predict_dir)
        return predict_ds

# **Data Augmentation**

In [None]:
class PairedImageAugumentation(tf.keras.layers.Layer):
    """Apply identical random flips and a 90-degree rotation to an image pair.

    The same random decisions are applied to the low-quality and the
    ground-truth tensor so the pair stays pixel-aligned. One set of
    decisions is drawn per call, so all images in a batch share it.
    """

    def __init__(self):
        super(PairedImageAugumentation, self).__init__()
        self.h_flip = tf.image.flip_left_right
        self.v_flip = tf.image.flip_up_down

    def call(self, inputs, seed=None):
        """Return the augmented ``(lq_img_batch, gt_img_batch)`` pair.

        Args:
            inputs: ``(lq_img_batch, gt_img_batch)`` tuple of image tensors.
            seed: Optional integer op-level seed for reproducibility.
        """
        lq_img_batch, gt_img_batch = inputs

        # Give each decision its own op seed: in a traced function, identical
        # random ops with the same seed produce the same values, which tied
        # the horizontal and vertical flips (and the rotation) together.
        if seed is None:
            h_seed = v_seed = rot_seed = None
        else:
            h_seed, v_seed, rot_seed = seed, seed + 1, seed + 2

        h_flip_seed = tf.random.uniform([], seed=h_seed, minval=0, maxval=2, dtype=tf.int32)
        v_flip_seed = tf.random.uniform([], seed=v_seed, minval=0, maxval=2, dtype=tf.int32)
        rot_flip_seed = tf.random.uniform([], seed=rot_seed, minval=0, maxval=4, dtype=tf.int32)

        lq_img_batch = tf.cond(h_flip_seed == 1, lambda: self.h_flip(lq_img_batch), lambda: lq_img_batch)
        gt_img_batch = tf.cond(h_flip_seed == 1, lambda: self.h_flip(gt_img_batch), lambda: gt_img_batch)

        lq_img_batch = tf.cond(v_flip_seed == 1, lambda: self.v_flip(lq_img_batch), lambda: lq_img_batch)
        gt_img_batch = tf.cond(v_flip_seed == 1, lambda: self.v_flip(gt_img_batch), lambda: gt_img_batch)

        # NOTE(review): an odd k swaps the two spatial dimensions of
        # non-square images.
        lq_img_batch = tf.image.rot90(lq_img_batch, k=rot_flip_seed)
        gt_img_batch = tf.image.rot90(gt_img_batch, k=rot_flip_seed)

        return lq_img_batch, gt_img_batch


# Correctly spelled alias; the original class name is kept for existing callers.
PairedImageAugmentation = PairedImageAugumentation


def apply_augmentation(dataset, seed=None):
    """Map the paired flip/rotation augmentation over an ``(lq, gt)`` dataset.

    NOTE(review): pixel values pass through unscaled. The previous body
    computed ``/ 255.0`` copies but never used them, so that behavior is
    preserved; the PSNR metric assumes ``max_val=1.0`` — confirm where
    normalization is meant to happen (it would also need to apply to the
    validation/test data).
    """
    custom_augmentation = PairedImageAugumentation()

    def data_augmentation(lq_img_batch, gt_img_batch):
        # The layer's call() takes the pair as a single `inputs` tuple;
        # passing two positional tensors collided with the `seed` argument.
        return custom_augmentation((lq_img_batch, gt_img_batch), seed=seed)

    return dataset.map(data_augmentation, num_parallel_calls=tf.data.AUTOTUNE)

# **Model**