In [None]:
# get datasets
# different kinds of qualities and sizes

#!unzip "/content/drive/MyDrive/SR_DATASET_SMALLER.zip" -d "/content/SR_DATASET_SMALLER_150_X_150"
#!unzip "/content/drive/MyDrive/SR_DATASET_SMALLER_HYSTO.zip" -d "/content/SR_DATASET_SMALLER_HYSTO"
#!unzip "/content/drive/MyDrive/SR_DATASET_SMALLER_WQ" -d "/content/SR_DATASET_SMALLER_WQ"
!unzip "/content/drive/MyDrive/SR_DATASET_SMALLER_HYSTO_128.zip" -d "/content/SR_DATASET_SMALLER_128_X_128"
#!unzip "/content/drive/MyDrive/restoration_new_12.07.zip"

In [None]:
import os
import random
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

# Paths to the different datasets.
# The original dataset was too big; SR_DATA_SMALLER is a subset of it that still
# contains ~25k training images.
# SR_DATA_SMALLER_AUGMENTED holds the same images with worsened quality.
# Every dict below uses the same four keys, which RestorationGenerator reads:
#   "train_x"/"test_x" -> degraded inputs, "train_y"/"test_y" -> clean targets.

# 150x150 variant (extracted by the commented-out SR_DATASET_SMALLER.zip command).
paths = {
    "train_x": "/content/SR_DATASET_SMALLER_150_X_150/SR_DATASET_SMALLER/SR_TRAIN/SR_DATA_SMALLER_AUGMENTED",
    "train_y": "/content/SR_DATASET_SMALLER_150_X_150/SR_DATASET_SMALLER/SR_TRAIN/SR_DATA_SMALLER",
    "test_x": "/content/SR_DATASET_SMALLER_150_X_150/SR_DATASET_SMALLER/SR_TEST/SR_DATA_SMALLER_AUGMENTED",
    "test_y": "/content/SR_DATASET_SMALLER_150_X_150/SR_DATASET_SMALLER/SR_TEST/SR_DATA_SMALLER",
}

# 128x128 variant; this is the only archive the unzip cell currently extracts.
paths_128 = {
    "train_x": "/content/SR_DATASET_SMALLER_128_X_128/SR_DATASET_SMALLER_HYSTO/SR_TRAIN/SR_DATA_SMALLER_AUGMENTED",
    "train_y": "/content/SR_DATASET_SMALLER_128_X_128/SR_DATASET_SMALLER_HYSTO/SR_TRAIN/SR_DATA_SMALLER",
    "test_x": "/content/SR_DATASET_SMALLER_128_X_128/SR_DATASET_SMALLER_HYSTO/SR_TEST/SR_DATA_SMALLER_AUGMENTED",
    "test_y": "/content/SR_DATASET_SMALLER_128_X_128/SR_DATASET_SMALLER_HYSTO/SR_TEST/SR_DATA_SMALLER",
}

# NOTE(review): none of the unzip commands in this notebook extracts to
# /content/SR_DATASET_SMALLER_HYSTO_128, and the input folder is spelled
# SR_DATA_AUGMENTED_SMALLER here (SR_DATA_SMALLER_AUGMENTED everywhere else).
# Verify these paths before using this dict.
paths_hysto = {
    "train_x": "/content/SR_DATASET_SMALLER_HYSTO_128/SR_DATASET_SMALLER/SR_TRAIN/SR_DATA_AUGMENTED_SMALLER",
    "train_y": "/content/SR_DATASET_SMALLER_HYSTO_128/SR_DATASET_SMALLER/SR_TRAIN/SR_DATA_SMALLER",
    "test_x": "/content/SR_DATASET_SMALLER_HYSTO_128/SR_DATASET_SMALLER/SR_TEST/SR_DATA_AUGMENTED_SMALLER",
    "test_y": "/content/SR_DATASET_SMALLER_HYSTO_128/SR_DATASET_SMALLER/SR_TEST/SR_DATA_SMALLER",
}

# Same dataset, but with further worsened input quality.
paths_wq = {
    "train_x": "/content/SR_DATASET_SMALLER_WQ/SR_DATASET_SMALLER/SR_TRAIN/SR_DATA_SMALLER_AUGMENTED",
    "train_y": "/content/SR_DATASET_SMALLER_WQ/SR_DATASET_SMALLER/SR_TRAIN/SR_DATA_SMALLER",
    "test_x": "/content/SR_DATASET_SMALLER_WQ/SR_DATASET_SMALLER/SR_TEST/SR_DATA_SMALLER_AUGMENTED",
    "test_y": "/content/SR_DATASET_SMALLER_WQ/SR_DATASET_SMALLER/SR_TEST/SR_DATA_SMALLER",
}

# Layout of the restoration_new_12.07 archive (its unzip command is commented out).
new_paths_128 = {
    "train_x": "/content/restoration_new_12.07/train/128x128_wq_resto",
    "train_y": "/content/restoration_new_12.07/train/128x128",
    "test_x": "/content/restoration_new_12.07/test/128x128_wq_resto",
    "test_y": "/content/restoration_new_12.07/test/128x128",
}

# generator for the restoration task
class RestorationGenerator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` that yields batches of (degraded, clean) image pairs.

    Inputs and targets are paired purely by position: the i-th file of the
    input directory is matched with the i-th file of the target directory.
    Both listings are therefore sorted by file name, because ``os.listdir``
    returns entries in arbitrary order and would otherwise pair unrelated images.

    Parameters:
        paths: dict with the keys "train_x", "train_y", "test_x" and "test_y",
            each mapping to a directory of images.
        batch_size: number of image pairs per batch.
        train: read the "train_*" directories when True, the "test_*" ones otherwise.
        size: (height, width) every image is resized to while loading.
    """

    def __init__(self, paths, batch_size, train=True, size=(256, 256)):
        # Newer Keras versions expect Sequence subclasses to initialise the base class.
        super().__init__()
        self.paths = paths
        self.batch_size = batch_size
        self.size = size
        self.image_paths_x = self._get_image_paths(
            paths["train_x"] if train else paths["test_x"]
        )
        self.image_paths_y = self._get_image_paths(
            paths["train_y"] if train else paths["test_y"]
        )
        # Creates the initial shuffled index order.
        self.on_epoch_end()

    def _get_image_paths(self, directory):
        """Return the full paths of all entries in `directory`, sorted by file name."""
        # Sorting makes the x/y pairing deterministic and independent of the
        # file system's listing order.
        return [
            os.path.join(directory, filename)
            for filename in sorted(os.listdir(directory))
        ]

    def __len__(self):
        """Number of full batches per epoch; a trailing partial batch is dropped."""
        return len(self.image_paths_x) // self.batch_size

    def __getitem__(self, index):
        """Load batch `index` and return ``(inputs, targets)`` scaled to [0, 1]."""
        indexes = self.indexes[index * self.batch_size : (index + 1) * self.batch_size]
        batch_x = [self.image_paths_x[i] for i in indexes]
        batch_y = [self.image_paths_y[i] for i in indexes]
        images_x = self._load_images(batch_x)
        images_y = self._load_images(batch_y)
        images_x = images_x / 255.0
        images_y = images_y / 255.0
        return images_x, images_y

    def on_epoch_end(self):
        """Reshuffle the sample order; also called once from ``__init__``."""
        self.indexes = list(range(len(self.image_paths_x)))
        random.shuffle(self.indexes)

    def _load_images(self, image_paths):
        """Load the given files, resize them to ``self.size`` and stack them into one array."""
        images = []
        for path in image_paths:
            image = tf.keras.preprocessing.image.load_img(
                path, target_size=self.size, interpolation="bicubic"
            )
            image = tf.keras.preprocessing.image.img_to_array(image)
            images.append(image)
        return np.array(images)


# `high` switches between the large batch size (64) and the small one (8).
# NOTE(review): the original remark says the choice "depends on image size" —
# presumably larger images need the smaller batch to fit in memory; confirm.
high = True  # depends on image size
batch_size = 64 if high else 8
# Both generators read the 128x128 dataset and load the images at 128x128.
train_data = RestorationGenerator(paths_128, batch_size, size=(128, 128))
test_data = RestorationGenerator(paths_128, batch_size, train=False, size=(128, 128))

# show example images and further information
def show_example(train_gen):
    """Print statistics of one random batch and plot four random pairs from it.

    The top row shows the network inputs ("Low-res"), the bottom row the
    matching targets ("High-res"). The batch is expected to hold values in
    [0, 1]; they are rescaled to [0, 255] for display.
    """
    batch_index = np.random.randint(len(train_gen))
    degraded, clean = train_gen[batch_index]
    print(
        degraded.shape,
        clean.shape,
        np.min(degraded),
        np.max(degraded),
        np.min(clean),
        np.max(clean),
    )
    num_samples = degraded.shape[0]
    degraded, clean = degraded * 255.0, clean * 255.0
    picks = np.random.choice(num_samples, size=4, replace=False)

    _, axes = plt.subplots(2, 4, figsize=(12, 6))
    # One grid row per image kind; both rows show the same four samples.
    rows = ((axes[0], degraded, "Low-res"), (axes[1], clean, "High-res"))
    for row_axes, images, title in rows:
        for ax, pick in zip(row_axes, picks):
            ax.imshow(images[pick].astype(np.uint8))
            ax.set_title(title)
            ax.axis("off")
    plt.tight_layout()
    plt.show()


# Visual sanity check on a random training batch.
show_example(train_data)

In [None]:
# Same visual sanity check on a random test batch.
show_example(test_data)

In [None]:
from typing import Optional, Tuple, Type
import tensorflow as tf
from tensorflow import keras

# Identifiers of the block variants accepted by NAFBlock (`mode`) and NAFNet (`block_type`).
NAFBLOCK = "nafblock"
PLAIN = "plain"
BASELINE = "baseline"


class SimpleGate(keras.layers.Layer):
    """
    Simple Gate

    Splits an input of size (b,h,w,c) along the channel axis into `factor`
    tensors of size (b,h,w,c//factor) and returns their element-wise
    (Hadamard) product.

    Parameters:
        factor: number of channel groups that are multiplied together, i.e. the
            amount by which the channel count is scaled down
    """

    def __init__(self, factor: Optional[int] = 2, **kwargs) -> None:
        super().__init__(**kwargs)
        self.factor = factor

    def call(self, x: tf.Tensor, *args, **kwargs) -> tf.Tensor:
        # Append a trailing axis so the channel groups can be lined up on it.
        expanded = tf.expand_dims(x, axis=-1)
        # After expand_dims the channel axis sits at position -2.
        groups = tf.split(expanded, num_or_size_splits=self.factor, axis=-2)
        # (b,h,w,c//factor,factor): one slot per group on the last axis.
        aligned = tf.concat(groups, axis=-1)
        return tf.reduce_prod(aligned, axis=-1)

    def get_config(self) -> dict:
        """Return the layer config extended by `factor`."""
        base_config = super().get_config()
        base_config.update({"factor": self.factor})
        return base_config


class ChannelAttention(keras.layers.Layer):
    """
    Channel Attention layer

    Rescales every channel of the input by a weight computed from the globally
    average-pooled features (1x1 conv with ReLU to channels // 2, followed by a
    1x1 conv with sigmoid back to `channels`).

    Parameters:
        channels: number of channels in input
    """

    def __init__(self, channels: int, **kwargs) -> None:
        super().__init__(**kwargs)
        self.channels = channels
        self.avg_pool = keras.layers.GlobalAveragePooling2D()
        # Bottleneck: squeeze to half the channels, then expand back.
        self.conv1 = keras.layers.Conv2D(
            filters=channels // 2, kernel_size=1, activation=keras.activations.relu
        )
        self.conv2 = keras.layers.Conv2D(
            filters=channels, kernel_size=1, activation=keras.activations.sigmoid
        )

    def call(self, inputs: tf.Tensor, *args, **kwargs) -> tf.Tensor:
        pooled = self.avg_pool(inputs)
        # Restore the spatial axes so the 1x1 convolutions can be applied.
        descriptor = tf.reshape(pooled, shape=(-1, 1, 1, self.channels))
        squeezed = self.conv1(descriptor)
        channel_weights = self.conv2(squeezed)
        return inputs * channel_weights

    def get_config(self) -> dict:
        """Return the layer config extended by `channels`."""
        base_config = super().get_config()
        base_config.update({"channels": self.channels})
        return base_config


class SimplifiedChannelAttention(keras.layers.Layer):
    """
    Simplified Channel Attention layer

    Channel attention without any non-linear activation: the globally
    average-pooled features pass through a single 1x1 convolution and the
    result rescales the input channel-wise.

    Parameters:
        channels: number of channels in input
    """

    def __init__(self, channels: int, **kwargs) -> None:
        super().__init__(**kwargs)
        self.channels = channels
        self.avg_pool = keras.layers.GlobalAveragePooling2D()
        self.conv = keras.layers.Conv2D(filters=channels, kernel_size=1)

    def call(self, inputs: tf.Tensor, *args, **kwargs) -> tf.Tensor:
        pooled = self.avg_pool(inputs)
        # Restore the spatial axes so the 1x1 convolution can be applied.
        descriptor = tf.reshape(pooled, shape=(-1, 1, 1, self.channels))
        channel_weights = self.conv(descriptor)
        return inputs * channel_weights

    def get_config(self) -> dict:
        """Return the layer config extended by `channels`."""
        base_config = super().get_config()
        base_config.update({"channels": self.channels})
        return base_config


class NAFBlock(keras.layers.Layer):
    """
    NAFBlock (Nonlinear Activation Free Block)

    The block keeps the input size: the number of channels is read from the
    input shape in `build`, and the output has the same shape as the input.

    Parameters:
        factor: factor by which the channels must be increased before being reduced by simple gate.
            (Higher factor denotes higher order polynomial in multiplication. Default factor is 2)
        drop_out_rate: dropout rate
        balanced_skip_connection: adds additional trainable parameters to the skip connections.
            The parameter denotes how much importance should be given to the sub block in the skip connection.
            When False, the scaling factors stay fixed at 1.
        mode: NAFBlock has 3 modes.
            'plain' mode uses the PlainBlock.
                It is derived from the restormer block, keeping the most common components
            'baseline' mode uses the BaselineBlock
                It is derived by adding layer normalization, channel attention to PlainBlock.
                It also replaces ReLU activation with GeLU in PlainBlock.
            'nafblock' mode uses the NAFBlock
                It is derived from BaselineBlock by removing all the non-linear activations.
                Non-linear activations are replaced by equivalent matrix multiplication operations.

    Raises:
        ValueError: if `mode` is not one of 'plain', 'baseline' or 'nafblock'.
    """

    def __init__(
        self,
        factor: Optional[int] = 2,
        drop_out_rate: Optional[float] = 0.0,
        balanced_skip_connection: Optional[bool] = False,
        mode: Optional[str] = NAFBLOCK,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.factor = factor
        self.drop_out_rate = drop_out_rate
        self.balanced_skip_connection = balanced_skip_connection

        valid_mode = {PLAIN, BASELINE, NAFBLOCK}
        if mode not in valid_mode:
            raise ValueError("Mode must be one of %r." % valid_mode)
        self.mode = mode

        # The same activation object is used in both sub-blocks.
        if self.mode == PLAIN:
            self.activation = keras.layers.Activation("relu")
        elif self.mode == BASELINE:
            self.activation = keras.layers.Activation("gelu")
        else:
            self.activation = SimpleGate(factor)

        self.dropout1 = keras.layers.Dropout(drop_out_rate)

        self.dropout2 = keras.layers.Dropout(drop_out_rate)

        # Layer normalization is only part of the baseline and nafblock variants.
        self.layer_norm1 = None
        self.layer_norm2 = None
        if self.mode in [NAFBLOCK, BASELINE]:
            self.layer_norm1 = keras.layers.LayerNormalization()
            self.layer_norm2 = keras.layers.LayerNormalization()

    def get_dw_channel(self, input_channels: int) -> int:
        """Channel count of the first 1x1 conv and the depth-wise conv.

        Only the nafblock variant widens the channels here, because its
        SimpleGate activation divides them by `factor` again.
        """
        if self.mode == NAFBLOCK:
            return input_channels * self.factor
        else:
            return input_channels

    def get_ffn_channel(self, input_channels: int) -> int:
        """Channel count of the first 1x1 conv in the second sub-block."""
        return input_channels * self.factor

    def get_attention_layer(
        self, input_shape: tf.TensorShape
    ) -> Optional[keras.layers.Layer]:
        """Return the attention layer matching `mode`, or None for the plain variant."""
        input_channels = input_shape[-1]
        if self.mode == NAFBLOCK:
            return SimplifiedChannelAttention(input_channels)
        elif self.mode == BASELINE:
            return ChannelAttention(input_channels)
        else:
            return None

    def build(self, input_shape: tf.TensorShape) -> None:
        """Create all layers whose size depends on the number of input channels."""
        input_channels = input_shape[-1]
        dw_channel = self.get_dw_channel(input_channels)

        self.conv1 = keras.layers.Conv2D(filters=dw_channel, kernel_size=1, strides=1)
        # groups == filters makes this a depth-wise 3x3 convolution.
        self.dconv2 = keras.layers.Conv2D(
            filters=dw_channel,
            kernel_size=3,
            padding="same",
            strides=1,
            groups=dw_channel,
        )

        self.attention = self.get_attention_layer(input_shape)

        self.conv3 = keras.layers.Conv2D(
            filters=input_channels, kernel_size=1, strides=1
        )

        ffn_channel = self.get_ffn_channel(input_channels)

        self.conv4 = keras.layers.Conv2D(filters=ffn_channel, kernel_size=1, strides=1)
        self.conv5 = keras.layers.Conv2D(
            filters=input_channels, kernel_size=1, strides=1
        )

        # Per-channel scaling of the two residual branches; initialised to one and
        # only trained when balanced_skip_connection is enabled.
        self.beta = tf.Variable(
            tf.ones((1, 1, 1, input_channels)), trainable=self.balanced_skip_connection
        )
        self.gamma = tf.Variable(
            tf.ones((1, 1, 1, input_channels)), trainable=self.balanced_skip_connection
        )

    def call_block1(self, inputs: tf.Tensor) -> tf.Tensor:
        """First sub-block: norm -> 1x1 conv -> depth-wise conv -> activation -> attention -> 1x1 conv -> dropout."""
        x = inputs
        if self.layer_norm1 is not None:
            x = self.layer_norm1(x)
        x = self.conv1(x)
        x = self.dconv2(x)
        x = self.activation(x)
        if self.attention is not None:
            x = self.attention(x)
        x = self.conv3(x)
        x = self.dropout1(x)
        return x

    def call_block2(self, inputs: tf.Tensor) -> tf.Tensor:
        """Second sub-block: norm -> 1x1 conv -> activation -> 1x1 conv -> dropout."""
        y = inputs
        if self.layer_norm2 is not None:
            y = self.layer_norm2(y)
        y = self.conv4(y)
        y = self.activation(y)
        y = self.conv5(y)
        y = self.dropout2(y)
        return y

    def call(self, inputs: tf.Tensor, *args, **kwargs) -> tf.Tensor:
        # Block 1
        x = self.call_block1(inputs)

        # Residual connection
        x = inputs + self.beta * x

        # Block 2
        y = self.call_block2(x)

        # Residual connection
        y = x + self.gamma * y

        return y

    def get_config(self) -> dict:
        """Add constructor arguments to the config"""
        config = super().get_config()
        config.update(
            {
                "factor": self.factor,
                "drop_out_rate": self.drop_out_rate,
                "balanced_skip_connection": self.balanced_skip_connection,
                "mode": self.mode,
            }
        )
        return config


class PixelShuffle(keras.layers.Layer):
    """
    PixelShuffle Layer

    Thin wrapper around tf.nn.depth_to_space: rearranges channel data into
    spatial blocks. For an input of size (H,W,C) the output has size
    (
        H*upscale_factor,
        W*upscale_factor,
        C//(upscale_factor**2)
    )

    Parameters:
        upscale_factor: block size handed to tf.nn.depth_to_space
    """

    def __init__(self, upscale_factor: int, **kwargs) -> None:
        super().__init__(**kwargs)
        self.upscale_factor = upscale_factor

    def call(self, inputs: tf.Tensor, *args, **kwargs) -> tf.Tensor:
        shuffled = tf.nn.depth_to_space(inputs, block_size=self.upscale_factor)
        return shuffled

    def get_config(self) -> dict:
        """Return the layer config extended by `upscale_factor`."""
        base_config = super().get_config()
        base_config.update({"upscale_factor": self.upscale_factor})
        return base_config


class UpScale(keras.layers.Layer):
    """
    UpScale Layer

    Applies a bias-free 1x1 convolution that maps the input to `channels`
    channels, followed by a pixel shuffle. For an input with spatial size
    (H,W) the output has size
    (
        H*pixel_shuffle_factor,
        W*pixel_shuffle_factor,
        channels//(pixel_shuffle_factor**2)
    )

    Parameters:
        channels: number of filters of the 1x1 convolution
        pixel_shuffle_factor: spatial upscaling factor of the pixel shuffle

    Raises:
        ValueError: if (pixel_shuffle_factor**2) does not divide channels
    """

    def __init__(self, channels: int, pixel_shuffle_factor: int, **kwargs) -> None:
        super().__init__(**kwargs)
        self.channels = channels
        self.pixel_shuffle_factor = pixel_shuffle_factor

        # depth_to_space moves factor**2 channels into every output pixel, so the
        # channel count has to be a multiple of factor**2.
        if channels % (pixel_shuffle_factor**2) != 0:
            raise ValueError(
                "The square of pixel_shuffle_factor must divide the number of channels. "
                f"In the constructor {channels} channels and "
                f"{pixel_shuffle_factor} pixel_shuffle_factor was passed."
            )

        self.conv = keras.layers.Conv2D(
            channels, kernel_size=1, strides=1, use_bias=False
        )
        self.pixel_shuffle = PixelShuffle(pixel_shuffle_factor)

    def call(self, inputs: tf.Tensor, *args, **kwargs) -> tf.Tensor:
        return self.pixel_shuffle(self.conv(inputs))

    def get_config(self) -> dict:
        """Add channels and pixel_shuffle_factor to the config"""
        config = super().get_config()
        config.update(
            {
                "channels": self.channels,
                "pixel_shuffle_factor": self.pixel_shuffle_factor,
            }
        )
        return config


class NAFNet(keras.models.Model):
    """
    NAFNet

    The input channels will be mapped to the number of filters passed.
    After each down block, the number of filters will increase by a factor of 2.
    After each up block, the number of filters will decrease by a factor of 2.
    And finally the filters will be mapped back to the initial input size.

    Overwrite create_encoder_and_down_blocks, create_decoder_and_up_blocks, create_middle_blocks
    to add your own implementation for these blocks. Overwrite get_block to use your custom block
    in NAFNet. But make sure to follow the restrictions on these methods and blocks.

    Parameters:
        filters: denotes the starting filter size.
        middle_block_num: (int) denotes the number of middle blocks.
            Each middle block is a single NAFBlock unit.
        encoder_block_nums: (tuple) the tuple size denotes the number of encoder blocks.
            Each tuple entry denotes the number of NAFBlocks in the corresponding encoder block.
            len(encoder_block_nums) should be the same as the len(decoder_block_nums)
        decoder_block_nums: (tuple) the tuple size denotes the number of decoder blocks.
            Each tuple entry denotes the number of NAFBlocks in the corresponding decoder block.
            len(decoder_block_nums) should be the same as the len(encoder_block_nums)
        block_type: (str) denotes what block to use in NAFNet
        drop_out_rate: (float) dropout rate handed to every block created by get_block

    Raises:
        ValueError: if the numbers of encoder, decoder, down and up blocks are inconsistent.
    """

    def __init__(
        self,
        filters: Optional[int] = 16,
        middle_block_num: Optional[int] = 1,
        encoder_block_nums: Optional[Tuple[int, ...]] = (1, 1, 1, 1),
        decoder_block_nums: Optional[Tuple[int, ...]] = (1, 1, 1, 1),
        block_type: Optional[str] = NAFBLOCK,
        drop_out_rate: Optional[float] = 0.0,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)

        self.filters = filters
        self.middle_block_num = middle_block_num
        self.encoder_block_nums = encoder_block_nums
        self.decoder_block_nums = decoder_block_nums
        self.block_type = block_type
        self.drop_out_rate = drop_out_rate

        self.intro = keras.layers.Conv2D(filters=filters, kernel_size=3, padding="same")

        self.encoders = []
        self.decoders = []
        self.ups = []
        self.downs = []

        # Validate the requested configuration before any block is created.
        if len(encoder_block_nums) != len(decoder_block_nums):
            raise ValueError(
                "The number of encoder blocks should match the number of decoder blocks. "
                f"In the constructor {len(encoder_block_nums)} encoder blocks"
                f" and {len(decoder_block_nums)} decoder blocks were passed."
            )

        channels = filters
        channels = self.create_encoder_and_down_blocks(channels, encoder_block_nums)

        if len(self.encoders) != len(self.downs):
            raise ValueError(
                "The number of encoder blocks should match the number of down blocks. "
                f"In `create_encoder_and_down_blocks` {len(self.encoders)} encoder blocks"
                f" and {len(self.downs)} down blocks were created."
            )

        self.create_middle_blocks(middle_block_num)

        self.create_decoder_and_up_blocks(channels, decoder_block_nums)

        if len(self.decoders) != len(self.ups):
            raise ValueError(
                "The number of decoder blocks should match the number of up blocks. "
                f"In `create_decoder_and_up_blocks` {len(self.decoders)} decoder blocks"
                f" and {len(self.ups)} up blocks were created."
            )

        # Check the blocks that were actually created: the create_* methods may be
        # overridden, so their output can differ from the requested tuple lengths.
        if len(self.encoders) != len(self.decoders):
            raise ValueError(
                "The number of encoder blocks should match the number of decoder blocks. "
                f"In `create_encoder_and_down_blocks` {len(self.encoders)} encoder blocks were created. "
                f"In `create_decoder_and_up_blocks` {len(self.decoders)} decoder blocks were created."
            )

        # The height and width of the image should be a
        #  multiple of self.expected_image_scale
        # If that is not the case, it will be fixed in the call(...) method.
        self.expected_image_scale = 2 ** len(self.encoders)

    def build(self, input_shape: tf.TensorShape) -> None:
        """Create the output layer, which maps back to the number of input channels."""
        input_channels = input_shape[-1]
        self.ending = keras.layers.Conv2DTranspose(
            filters=input_channels, kernel_size=3, padding="same"
        )

    def get_block(self) -> keras.layers.Layer:
        """
        Returns the block to be used in NAFNet
        Can be overridden to use custom blocks in NAFNet
        """
        return NAFBlock(mode=self.block_type, drop_out_rate=self.drop_out_rate)

    def create_encoder_and_down_blocks(
        self,
        channels: int,
        encoder_block_nums: Optional[Tuple[int, ...]],
    ) -> int:
        """
        Creates equal number of encoder blocks and down blocks.
        Returns the number of channels after the last down block.
        """

        for num in encoder_block_nums:
            self.encoders.append(
                keras.models.Sequential([self.get_block() for _ in range(num)])
            )
            # Strided 2x2 convolution: halves height/width, doubles the channels.
            self.downs.append(
                keras.layers.Conv2D(2 * channels, kernel_size=2, strides=2)
            )
            channels *= 2
        return channels

    def create_middle_blocks(self, middle_block_num: int) -> None:
        """
        Creates middle blocks in NAFNet
        """
        self.middle_blocks = keras.models.Sequential(
            [self.get_block() for _ in range(middle_block_num)]
        )

    def create_decoder_and_up_blocks(
        self,
        channels: int,
        decoder_block_nums: Optional[Tuple[int, ...]],
    ) -> int:
        """
        Creates equal number of decoder blocks and up blocks.
        Returns the number of channels after the last up block.
        """
        for num in decoder_block_nums:
            # UpScale expands to 2 * channels, the pixel shuffle (factor 2) then
            # leaves channels // 2 at twice the spatial resolution.
            self.ups.append(UpScale(2 * channels, pixel_shuffle_factor=2))
            channels = channels // 2
            self.decoders.append(
                keras.models.Sequential([self.get_block() for _ in range(num)])
            )
        return channels

    def call(self, inputs: tf.Tensor, *args, **kwargs) -> tf.Tensor:
        _, H, W, _ = inputs.shape
        # Scale the image to the next nearest multiple of self.expected_image_scale
        inputs = self.fix_input_shape(inputs)

        x = self.intro(inputs)

        encoder_outputs = []
        for encoder, down in zip(self.encoders, self.downs):
            x = encoder(x)
            encoder_outputs.append(x)
            x = down(x)

        x = self.middle_blocks(x)

        for decoder, up, encoder_output in zip(
            self.decoders, self.ups, encoder_outputs[::-1]
        ):
            x = up(x)
            # Residual connection of encoder blocks with decoder blocks
            x = x + encoder_output
            x = decoder(x)

        x = self.ending(x)
        # Residual connection of inputs with output
        x = x + inputs

        # Crop back to the original size
        return x[:, :H, :W, :]

    def fix_input_shape(self, inputs: tf.Tensor) -> tf.Tensor:
        """
        Fixes input shape for NAFNet
        This is because NAFNet can only work with images whose shape is
         multiple of 2**(no. of encoder blocks)
        Hence the image is padded (at the bottom and right) to match that shape

        A height or width that is not statically known is treated as 256.
        """
        _, H, W, _ = inputs.shape

        # Unknown static dimensions fall back to 256; the width gets the same
        # fallback as the height, otherwise `W % scale` fails on None.
        if H is None:
            H = 256
        if W is None:
            W = 256

        # Calculating how much padding is required; (-n) % scale is the distance
        # from n to the next multiple of scale (0 if n already is one).
        height_padding = (-H) % self.expected_image_scale
        width_padding = (-W) % self.expected_image_scale

        paddings = tf.constant(
            [[0, 0], [0, height_padding], [0, width_padding], [0, 0]]
        )
        return tf.pad(inputs, paddings)

    def get_config(self) -> dict:
        """Add the constructor arguments to the config"""
        config = super().get_config()
        config.update(
            {
                "filters": self.filters,
                "middle_block_num": self.middle_block_num,
                "encoder_block_nums": self.encoder_block_nums,
                "decoder_block_nums": self.decoder_block_nums,
                "block_type": self.block_type,
                # Without this entry a model rebuilt from the config silently
                # falls back to the default dropout rate.
                "drop_out_rate": self.drop_out_rate,
            }
        )
        return config


def create_NAFNet(
    filters: Optional[int] = 16,
    middle_block_num: Optional[int] = 1,
    encoder_block_nums: Optional[tuple] = (1, 1, 1, 1),
    decoder_block_nums: Optional[tuple] = (1, 1, 1, 1),
    block_type: Optional[str] = NAFBLOCK,
    drop_out_rate: Optional[float] = 0.0,
    input_shape: Optional[tuple] = (1, 256, 256, 3),
    summary: Optional[bool] = True,
):
    """Create a NAFNet and optionally print its summary.

    All arguments except `input_shape` and `summary` are handed to the NAFNet
    constructor unchanged.

    Parameters:
        input_shape: shape of the all-ones dummy batch used for the forward pass
            that precedes the summary.
        summary: when True, run one forward pass on the dummy batch and print
            the model summary; when False, the model is returned without being
            called.

    Returns:
        The NAFNet instance.
    """
    network_kwargs = {
        "filters": filters,
        "middle_block_num": middle_block_num,
        "encoder_block_nums": encoder_block_nums,
        "decoder_block_nums": decoder_block_nums,
        "block_type": block_type,
        "drop_out_rate": drop_out_rate,
    }
    network = NAFNet(**network_kwargs)
    if not summary:
        return network

    # The forward pass comes first: it creates the weights the summary reports.
    network(tf.ones(input_shape))
    network.summary()
    return network


# Build the restoration network for 128x128 RGB inputs. `summary` keeps its
# default (True), so create_NAFNet runs one forward pass on a dummy batch of
# shape `input_shape` and prints the model summary.
# The encoder uses one block in each of the first three stages and 28 blocks in
# the last one; every decoder stage uses a single block.
model = create_NAFNet(
    filters=32,
    middle_block_num=2,
    encoder_block_nums=(1, 1, 1, 28),
    input_shape=(1, 128, 128, 3),
    drop_out_rate=0.1,
    decoder_block_nums=(1, 1, 1, 1),
)

In [None]:
#!pip install import-ipynb
#!unzip /content/drive/MyDrive/mirnet_v2
import tensorflow as tf
from keras.optimizers import optimizer
import time
import csv
import cv2


# also tested with mirnet_v2
# import import_ipynb
# from mirnet_v2 import get_mirnet

# metrics
def psnr_metric(y_true, y_pred):
    """Peak signal-to-noise ratio between targets and predictions, for a value range of 1.0."""
    return tf.image.psnr(y_true, y_pred, max_val=1.0)


def ssim_metric(y_true, y_pred):
    """Structural similarity (SSIM) between targets and predictions, for a value range of 1.0."""
    return tf.image.ssim(y_true, y_pred, max_val=1.0)


@tf.function
def update_fn(p, grad, exp_avg, lr, wd, beta1, beta2):
    """Apply one Lion update step in place.

    Args:
        p: variable to update.
        grad: gradient of the loss with respect to `p`.
        exp_avg: variable holding the running average of past gradients (momentum).
        lr: learning rate.
        wd: weight-decay coefficient; the decay is scaled by `lr` and applied
            directly to `p`.
        beta1: interpolation factor between momentum and the current gradient
            used for the sign update.
        beta2: decay rate of the momentum running average.
    """
    # stepweight decay

    p.assign(p * (1 - lr * wd))

    # weight update: sign of the beta1-interpolation of momentum and gradient.
    # The interpolation weight used to be hard-wired to 1.0, which ignored
    # `beta1` and dropped the current gradient from the update.

    update = exp_avg * beta1 + grad * (1 - beta1)
    p.assign_add(tf.sign(update) * -lr)

    # decay the momentum running average coefficient

    exp_avg.assign(exp_avg * beta2 + grad * (1 - beta2))


def lerp(start, end, weight):
    """Linear interpolation: `start` for weight 0, `end` for weight 1."""
    delta = end - start
    return start + weight * delta


def sparse_lerp(start, end, weight):
    """Linear interpolation written so that `end` may be sparse IndexedSlices.

    Mathematically the same as `lerp`, but a dense tensor cannot be subtracted
    from sparse IndexedSlices, so the difference is taken the other way round
    and negated.
    """
    reversed_delta = start - end
    return start + weight * -reversed_delta

# lion optimizer
class Lion(optimizer.Optimizer):
    """Optimizer that implements the Lion algorithm.
    Lion was published in the paper "Symbolic Discovery of Optimization Algorithms"
    which is available at https://arxiv.org/abs/2302.06675
    Args:
      learning_rate: A `tf.Tensor`, floating point value, a schedule that is a
        `tf.keras.optimizers.schedules.LearningRateSchedule`, or a callable
        that takes no arguments and returns the actual value to use. The
        learning rate. Defaults to 1e-4.
      beta_1: A float value or a constant float tensor. Factor
         used to interpolate the current gradient and the momentum. Defaults to 0.9.
      beta_2: A float value or a constant float tensor. The
        exponential decay rate for the momentum. Defaults to 0.99.
      weight_decay, clipnorm, clipvalue, global_clipnorm, jit_compile, name and
        any further keyword arguments are forwarded unchanged to the base
        optimizer.
    Notes:
    The sparse implementation of this algorithm (used when the gradient is an
    IndexedSlices object, typically because of `tf.gather` or an embedding
    lookup in the forward pass) does apply momentum to variable slices even if
    they were not used in the forward pass (meaning they have a gradient equal
    to zero). Momentum decay (beta2) is also applied to the entire momentum
    accumulator. This means that the sparse behavior is equivalent to the dense
    behavior (in contrast to some momentum implementations which ignore momentum
    unless a variable slice was actually used).
    """

    def __init__(
        self,
        learning_rate=1e-4,
        beta_1=0.9,
        beta_2=0.99,
        weight_decay=None,
        clipnorm=None,
        clipvalue=None,
        global_clipnorm=None,
        jit_compile=True,
        name="Lion",
        **kwargs,
    ):
        super().__init__(
            name=name,
            weight_decay=weight_decay,
            clipnorm=clipnorm,
            clipvalue=clipvalue,
            global_clipnorm=global_clipnorm,
            jit_compile=jit_compile,
            **kwargs,
        )
        self._learning_rate = self._build_learning_rate(learning_rate)
        self.beta_1 = beta_1
        self.beta_2 = beta_2

    def build(self, var_list):
        """Initialize optimizer variables.
        var_list: list of model variables to build Lion variables on.
        """
        super().build(var_list)
        # Create the momentum slots only once, even if build is called again.
        if hasattr(self, "_built") and self._built:
            return
        self._built = True
        self._emas = []
        for var in var_list:
            self._emas.append(
                self.add_variable_from_reference(
                    model_variable=var, variable_name="ema"
                )
            )

    def update_step(self, gradient, variable):
        """Update step given gradient and the associated model variable."""
        lr = tf.cast(self.learning_rate, variable.dtype)
        # Scalars in the variable's dtype: a float32 tensor of shape (1,) would
        # turn the update of a scalar variable into shape (1,) and would not
        # match variables of any other float dtype.
        beta_1 = tf.cast(self.beta_1, variable.dtype)
        beta_2 = tf.cast(self.beta_2, variable.dtype)

        var_key = self._var_key(variable)
        ema = self._emas[self._index_dict[var_key]]

        if isinstance(gradient, tf.IndexedSlices):
            # Sparse gradients.
            lerp_fn = sparse_lerp
        else:
            # Dense gradients.
            lerp_fn = lerp

        # Step in the direction of the sign of the interpolated momentum/gradient.
        update = lerp_fn(ema, gradient, 1 - beta_1)
        update = tf.sign(update)
        variable.assign_sub(update * lr)

        # The momentum itself is tracked with the second coefficient.
        ema.assign(lerp_fn(ema, gradient, 1 - beta_2))

    def get_config(self):
        """Return the optimizer config extended by learning rate and both betas."""
        config = super().get_config()

        config.update(
            {
                "learning_rate": self._serialize_hyperparameter(self._learning_rate),
                "beta_1": self.beta_1,
                "beta_2": self.beta_2,
            }
        )
        return config


def fast_check(generator, amount: int = 20):
    """Yield at most the first `amount` items of `generator`.

    Intended for quick smoke tests on a small part of a dataset.

    Args:
        generator: any iterable.
        amount: maximum number of items to yield; nothing is yielded for
            values of zero or below.

    Yields:
        The items of `generator`, in order, stopping after `amount` items or
        when `generator` is exhausted. No item beyond the last yielded one is
        pulled from `generator`.
    """
    # An exact `count == amount` test never fires for amount <= 0 and would let
    # the whole iterable through, so that case is handled up front.
    if amount <= 0:
        return
    for count, item in enumerate(generator, start=1):
        yield item
        if count >= amount:
            break

# Loss functions
class SSIMLoss(tf.keras.losses.Loss):
    """Loss defined as one minus the mean SSIM between targets and predictions.

    SSIM is computed for a value range of 1.0.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def call(self, y_true, y_pred):
        mean_ssim = tf.reduce_mean(tf.image.ssim(y_true, y_pred, max_val=1.0))
        return 1 - mean_ssim


class PSNRLoss(tf.keras.losses.Loss):
    """Negative PSNR loss: a higher PSNR gives a lower loss.

    Args:
        max_val: dynamic range of the images, handed to `tf.image.psnr`.

    References:

    1. [HINet: Half Instance Normalization Network for Image Restoration](https://openaccess.thecvf.com/content/CVPR2021W/NTIRE/papers/Chen_HINet_Half_Instance_Normalization_Network_for_Image_Restoration_CVPRW_2021_paper.pdf)
    """

    def __init__(self, max_val: float = 1.0, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.max_val = max_val

    def call(self, y_true, y_pred):
        psnr = tf.image.psnr(y_true, y_pred, max_val=self.max_val)
        return -psnr


class CharbonnierLoss(tf.keras.losses.Loss):
    r"""The Charbonnier loss implemented as a `tf.keras.losses.Loss`.

    The Charbonnier loss is a smooth approximation of the absolute error that
    is used in image processing and computer vision tasks to balance the
    trade-off between the Mean Squared Error (MSE) and the Mean Absolute
    Error (MAE). Per element it is defined as

    $$L = \sqrt{x^{2} + \varepsilon^{2}}$$

    where x is the error and ε is a small positive constant (typically on the
    order of 0.001). It is less sensitive to outliers than the mean squared
    error and, unlike the mean absolute error, it is differentiable at zero.
    `call` returns the mean of this expression over all elements.

    Args:
        epsilon (float): a small positive constant.
        *args, **kwargs: forwarded unchanged to `tf.keras.losses.Loss`.
    """

    def __init__(self, epsilon: float, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.epsilon = tf.convert_to_tensor(epsilon)

    def call(self, y_true, y_pred):
        """Return the mean Charbonnier penalty between `y_true` and `y_pred`."""
        y_pred = tf.convert_to_tensor(y_pred)
        # Bring targets and epsilon to the prediction dtype; otherwise the
        # subtraction/addition below fails when the dtypes differ (e.g.
        # float16 predictions against the float32 epsilon tensor).
        y_true = tf.cast(y_true, y_pred.dtype)
        epsilon = tf.cast(self.epsilon, y_pred.dtype)
        squared_difference = tf.square(y_true - y_pred)
        return tf.reduce_mean(tf.sqrt(squared_difference + tf.square(epsilon)))

# combination loss with weighted "counter loss"
class CombinedLoss(tf.keras.losses.Loss):
    """Weighted sum of a Charbonnier loss and a "counter" loss (SSIM or PSNR).

    The value returned by `call` is
    `counter_loss_weight * counter + charbonnier_loss_weight * charbonnier`,
    where `counter` comes from `PSNRLoss` when `counter_loss_fn == "psnr"` and
    from `SSIMLoss` otherwise.

    Args:
        counter_loss_weight (float): weight of the counter loss.
        charbonnier_loss_weight (float): weight of the Charbonnier loss.
        psnr_max_val (float): `max_val` passed to `PSNRLoss`. The SSIM loss is
            built with its own defaults and is not affected by this value.
        charbonnier_epsilon (float): epsilon passed to `CharbonnierLoss`.
        counter_loss_fn (str): which counter loss to use, "ssim" or "psnr".
        *args, **kwargs: forwarded unchanged to `tf.keras.losses.Loss`.

    Raises:
        ValueError: if `counter_loss_fn` is neither "ssim" nor "psnr".
    """

    # Names accepted for `counter_loss_fn`.
    _COUNTER_LOSSES = ("ssim", "psnr")

    def __init__(
        self,
        counter_loss_weight: float = 0.2,
        charbonnier_loss_weight: float = 0.8,
        psnr_max_val: float = 1.0,
        charbonnier_epsilon: float = 1e-3,
        counter_loss_fn: str = "ssim",
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        # Reject unknown names up front: without this check a typo such as
        # "PSNR" would silently train with the SSIM loss.
        if counter_loss_fn not in self._COUNTER_LOSSES:
            raise ValueError(
                f"counter_loss_fn must be one of {self._COUNTER_LOSSES}, "
                f"got {counter_loss_fn!r}"
            )
        self.psnr_loss = PSNRLoss(max_val=psnr_max_val)
        self.ssim_loss = SSIMLoss()
        self.charbonnier_loss = CharbonnierLoss(
            epsilon=charbonnier_epsilon, reduction=tf.keras.losses.Reduction.SUM
        )
        self.counter_loss_weight = counter_loss_weight
        self.charbonnier_loss_weight = charbonnier_loss_weight
        self.counter_loss_fn = counter_loss_fn

    def call(self, y_true, y_pred):
        """Return the weighted combination of counter and Charbonnier loss."""
        if self.counter_loss_fn == "psnr":
            counter_loss = self.psnr_loss(y_true, y_pred)
        else:
            counter_loss = self.ssim_loss(y_true, y_pred)
        charbonnier_loss = self.charbonnier_loss(y_true, y_pred)
        return (
            self.counter_loss_weight * counter_loss
            + self.charbonnier_loss_weight * charbonnier_loss
        )


# custom training step with tf.GradientTape(), used because of model constraints when fitting on the generator
@tf.function
def train_step(model, optimizer, loss_fn, x, y):
    """Run one optimisation step on a single batch.

    The forward pass and loss are recorded on a gradient tape, the gradients
    with respect to the model's trainable variables are applied through
    `optimizer`, and the batch metrics are computed on the same predictions.

    Args:
        model: callable model, invoked with `training=True`.
        optimizer: object providing `apply_gradients`.
        loss_fn: callable taking `(y, y_pred)`.
        x: input batch.
        y: target batch.

    Returns:
        Tuple `(loss, psnr, ssim)`; the last two come from the module-level
        `psnr_metric` and `ssim_metric`.
    """
    with tf.GradientTape() as tape:
        y_pred = model(x, training=True)
        loss = loss_fn(y, y_pred)
    variables = model.trainable_variables
    grads = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(grads, variables))
    batch_psnr = psnr_metric(y, y_pred)
    batch_ssim = ssim_metric(y, y_pred)
    return loss, batch_psnr, batch_ssim


@tf.function
def test_step(model, loss_fn, x, y):
    """Evaluate one batch without updating the model.

    Args:
        model: callable model, invoked with `training=False`.
        loss_fn: callable taking `(y, y_pred)`.
        x: input batch.
        y: target batch.

    Returns:
        Tuple `(loss, psnr, ssim)`; the last two come from the module-level
        `psnr_metric` and `ssim_metric`.
    """
    prediction = model(x, training=False)
    return (
        loss_fn(y, prediction),
        psnr_metric(y, prediction),
        ssim_metric(y, prediction),
    )


def metric_avg(psnr, ssim, loss, decimal_places):
    """Average each metric and round it to `decimal_places` decimals.

    Mind the order: the arguments are `(psnr, ssim, loss)` and the returned
    tuple uses that same order, not `(loss, psnr, ssim)`.

    Args:
        psnr: PSNR value(s) to average.
        ssim: SSIM value(s) to average.
        loss: loss value(s) to average.
        decimal_places: number of decimals kept after rounding.

    Returns:
        Tuple `(avg_psnr, avg_ssim, avg_loss)` of rounded means.
    """
    scale = 10**decimal_places

    def _rounded_mean(values):
        # Round by scaling up, rounding to the nearest integer, scaling down.
        return tf.round(tf.reduce_mean(values) * scale) / scale

    return _rounded_mean(psnr), _rounded_mean(ssim), _rounded_mean(loss)

# quick prediction during training
def quick_pred(
    data_gen, amount: int = 1, display: bool = True, max_images: int = 24
):
    """Predict on the first batch of `data_gen` and optionally plot the result.

    Uses the module-level `model`. For each of the `amount` passes, the first
    batch yielded by `data_gen` is taken, its first `max_images` samples are
    predicted in a single `model.predict` call, and the outputs are scaled by
    255, clipped to [0, 255] and cast to uint8.

    Args:
        data_gen: iterable yielding `(x, y)` batches.
        amount (int): how many times to fetch a first batch from `data_gen`.
        display (bool): if True, plot input / ground truth / prediction for
            the most recently fetched batch.
        max_images (int): maximum number of samples used from each batch.

    Returns:
        list: one uint8 array per predicted sample, each keeping a leading
        batch axis of size 1 (so `predictions[i][0]` is the image).

    Raises:
        ValueError: if `max_images` is smaller than 1.
    """
    if max_images < 1:
        raise ValueError(f"max_images must be >= 1, got {max_images}")

    predictions = []
    # Inputs, targets and predictions of the batch that will be plotted; kept
    # together so a plotted prediction always belongs to its plotted input.
    shown_batch = None
    for _ in range(amount):
        for x, y in data_gen:
            batch_pred = model.predict(x[:max_images])
            batch_pred = np.clip(batch_pred * 255, 0, 255).astype("uint8")
            # Keep the historical per-sample layout with a leading axis of 1.
            per_sample = [batch_pred[j : j + 1] for j in range(len(batch_pred))]
            predictions.extend(per_sample)
            shown_batch = (x, y, per_sample)
            break  # only the first batch of each pass is used
    # Nothing to plot when no batch could be fetched (amount == 0 or an
    # empty data_gen).
    if display and shown_batch is not None:
        x, y, per_sample = shown_batch
        for i, pred in enumerate(per_sample):
            fig, axs = plt.subplots(1, 3, figsize=(10, 5))
            axs[0].imshow(x[i])
            axs[0].set_title("Input")
            axs[1].imshow(y[i])
            axs[1].set_title("Ground Truth")
            axs[2].imshow(pred[0])
            axs[2].set_title("Prediction")
            plt.show()
    return predictions

# write training progress to csv
def write_to_csv(
    loss_total,
    psnr_total,
    ssim_total,
    loss_epoch=None,
    psnr_epoch=None,
    ssim_epoch=None,
    train=True,
    name="results.csv",
):
    """Write recorded metrics to the CSV file `name` (overwriting it).

    The file starts with a header row followed by one row per entry of the
    `*_total` lists. When `train` is True a second header and one row per
    entry of the `*_epoch` lists follow. The first column is the 1-based
    position of the entry within its list.

    Args:
        loss_total, psnr_total, ssim_total: equally long sequences of values
            for the first section.
        loss_epoch, psnr_epoch, ssim_epoch: equally long sequences for the
            second section; `None` is treated as an empty sequence.
        train (bool): whether to write the second section.
        name (str): output path.
    """
    header = ["epoch", "loss", "psnr", "ssim"]

    def _plain(value):
        # Objects exposing `.numpy()` (e.g. eager tensors) are unwrapped so
        # the cell holds the number instead of the object's repr.
        to_numpy = getattr(value, "numpy", None)
        return to_numpy() if callable(to_numpy) else value

    def _write_section(writer, losses, psnrs, ssims):
        # One header row, then one numbered row per recorded entry.
        writer.writerow(header)
        for i, loss in enumerate(losses):
            writer.writerow([i + 1, _plain(loss), _plain(psnrs[i]), _plain(ssims[i])])

    def _or_empty(values):
        return [] if values is None else values

    with open(name, "w", newline="") as file:
        writer = csv.writer(file)
        _write_section(writer, loss_total, psnr_total, ssim_total)
        if train:
            _write_section(
                writer,
                _or_empty(loss_epoch),
                _or_empty(psnr_epoch),
                _or_empty(ssim_epoch),
            )

# training loop
count = 0  # not used in this loop
save = True  # save a checkpoint after every epoch
save_all = True  # record step-level training metrics every `log_every` steps
optimizer = Lion(1e-5, weight_decay=1e-6)
num_epochs = 5
loss_fn = CombinedLoss(
    counter_loss_weight=0.3,
    charbonnier_loss_weight=0.7,
    psnr_max_val=1.0,
    charbonnier_epsilon=1e-3,
)
max_steps = len(train_data)
decimal_places = 4
lr_decay = 0.9  # factor applied to the learning rate when the epoch loss rose
log_every = 10  # step interval for the step-level metric history
# initialize lists to save training progress
loss_total, psnr_total, ssim_total = [], [], []  # every `log_every` steps
loss_epoch, psnr_epoch, ssim_epoch = [], [], []  # training mean per epoch
loss_eval, psnr_eval, ssim_eval = [], [], []  # evaluation mean per epoch
for epoch in range(num_epochs):
    # Decay the learning rate at most once per epoch, when the previous
    # epoch's training loss was higher than the one before it. (Checking this
    # inside the step loop would shrink the rate on every single step.)
    if len(loss_epoch) >= 2 and loss_epoch[-1] > loss_epoch[-2]:
        optimizer.learning_rate = optimizer.learning_rate * lr_decay
        print("changed lr to ", optimizer.learning_rate.numpy())
    # Restart the timer together with the step counter so that the
    # time-per-step estimate stays valid after the first epoch.
    epoch_start_time = time.time()
    step = 0
    step_losses, step_psnrs, step_ssims = [], [], []
    for x, y in train_data:
        step += 1
        loss, psnr, ssim = train_step(model, optimizer, loss_fn, x, y)
        avg_psnr, avg_ssim, avg_loss = metric_avg(psnr, ssim, loss, decimal_places)
        # Store plain numbers rather than tensors so the CSV stays readable.
        step_losses.append(avg_loss.numpy())
        step_psnrs.append(avg_psnr.numpy())
        step_ssims.append(avg_ssim.numpy())
        if save_all and step % log_every == 0:
            loss_total.append(step_losses[-1])
            psnr_total.append(step_psnrs[-1])
            ssim_total.append(step_ssims[-1])
        res = [loss.numpy(), avg_psnr.numpy(), avg_ssim.numpy()]
        current_time = time.time()
        elapsed_time = current_time - epoch_start_time
        average_time_per_step = elapsed_time / step
        remaining_steps = max_steps - step
        remaining_time = remaining_steps * average_time_per_step
        print(
            f"\rEpoch: {epoch + 1}/{num_epochs}, Loss: {res[0]:.{decimal_places}f}, PSNR: {res[1]:.{decimal_places}f}, SSIM: {res[2]:.{decimal_places}f}, Step: {step}/{max_steps}, Remaining Time: {remaining_time:.2f} seconds/ {remaining_time/60:.2f} minutes -- elapsed: {elapsed_time:.1f} seconds",
            end=" " * 5,
            flush=True,
        )
    # Per-epoch training metrics are the mean over all steps of the epoch,
    # not just the value of the last batch.
    loss_epoch.append(round(float(np.mean(step_losses)), decimal_places))
    psnr_epoch.append(round(float(np.mean(step_psnrs)), decimal_places))
    ssim_epoch.append(round(float(np.mean(step_ssims)), decimal_places))

    # Evaluate on every test batch and average, instead of reporting only the
    # last batch.
    eval_losses, eval_psnrs, eval_ssims = [], [], []
    for x, y in test_data:
        batch_loss, batch_psnr, batch_ssim = test_step(model, loss_fn, x, y)
        # metric_avg takes and returns (psnr, ssim, loss) -- in that order.
        batch_psnr, batch_ssim, batch_loss = metric_avg(
            batch_psnr, batch_ssim, batch_loss, decimal_places
        )
        eval_losses.append(batch_loss.numpy())
        eval_psnrs.append(batch_psnr.numpy())
        eval_ssims.append(batch_ssim.numpy())
    loss_e = round(float(np.mean(eval_losses)), decimal_places)
    psnr_e = round(float(np.mean(eval_psnrs)), decimal_places)
    ssim_e = round(float(np.mean(eval_ssims)), decimal_places)
    loss_eval.append(loss_e)
    psnr_eval.append(psnr_e)
    ssim_eval.append(ssim_e)
    print(
        f"\nEpoch: {epoch + 1}, Eval Loss: {loss_e:.{decimal_places}f}, Eval PSNR: {psnr_e:.{decimal_places}f}, Eval SSIM: {ssim_e:.{decimal_places}f}"
    )
    quick_pred(test_data, amount=2, display=True)
    if save:
        model.save(f"NafNet_{str(epoch + 1)}_2", save_format="tf")
        print("Model saved after epoch", epoch + 1)
    # Rewrite the file with the full history so every epoch gets its own row
    # (the file is opened in "w" mode and would otherwise keep only one row).
    write_to_csv(loss_eval, psnr_eval, ssim_eval, train=False, name="results_eval.csv")
model.save("NafNet", save_format="tf")
write_to_csv(loss_total, psnr_total, ssim_total, loss_epoch, psnr_epoch, ssim_epoch)
print("Training complete, model saved, results saved.")

In [None]:
# move models to drive
import shutil

# Earlier copy/move variants, kept commented out for reference:
# shutil.copy("model.h5", "/content/drive/MyDrive/model_resto.h5")
# shutil.move("NafNet_5", "/content/drive/MyDrive/NafNet_5_mixed_wq_fine")
# shutil.move("NafNet", "/content/drive/MyDrive/NafNet_mixed_wq_fine")
# Move "NafNet_2_2" to Google Drive under a dated name. The name matches the
# per-epoch save pattern f"NafNet_{epoch + 1}_2" of the training loop, i.e.
# presumably the model saved after epoch 2 -- confirm before relying on it.
shutil.move("NafNet_2_2", "/content/drive/MyDrive/NafNet_new_13.07_5E6F")

'/content/drive/MyDrive/NafNet_new_13.07_5E6F'

In [None]:
# Reload the model that was moved to Drive and evaluate it on the test data.
model = tf.keras.models.load_model("/content/drive/MyDrive/NafNet_new_13.07_5E6F")
model.compile(optimizer, loss_fn, metrics=[psnr_metric, ssim_metric])
# evaluate() returns the loss followed by the metrics in compile order.
loss_e, psnr_e, ssim_e = model.evaluate(test_data)
# metric_avg takes and returns (psnr, ssim, loss) -- unpack in that order so
# the three values are not swapped.
psnr_e, ssim_e, loss_e = metric_avg(psnr_e, ssim_e, loss_e, decimal_places)
# Hand plain numbers (not tensors) to the CSV writer. Note that this
# overwrites any existing results_eval.csv.
write_to_csv(
    [loss_e.numpy()],
    [psnr_e.numpy()],
    [ssim_e.numpy()],
    train=False,
    name="results_eval.csv",
)