In [None]:
# Extract the dataset archive from Google Drive (Colab) into the current working
# directory; the `paths` used below expect the result under /content.
# Only one archive is used at a time — the commented lines are alternative datasets.
#!unzip "/content/drive/MyDrive/SR_DATASET.zip"
#!unzip "/content/drive/MyDrive/SR_DATASET_SMALLER.zip"
!unzip "/content/drive/MyDrive/resto_data.zip"

In [None]:
import os
import random
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

# Directory layout of every dataset this notebook has been run on. In each entry,
# the *_x directories hold the degraded ("augmented") model inputs and the *_y
# directories hold the matching clean targets. Change DATASET to switch datasets
# (and unzip the matching archive in the first cell).
DATASET_PATHS = {
    "sr": {
        "train_x": "/content/SR_DATASET/SR_TRAIN/SR_DATA_AUGMENTED",
        "train_y": "/content/SR_DATASET/SR_TRAIN/SR_DATA",
        "test_x": "/content/SR_DATASET/SR_TEST/SR_DATA_AUGMENTED",
        "test_y": "/content/SR_DATASET/SR_TEST/SR_DATA",
    },
    "sr_smaller": {
        "train_x": "/content/SR_DATASET_SMALLER/SR_TRAIN/SR_DATA_AUGMENTED_SMALLER",
        "train_y": "/content/SR_DATASET_SMALLER/SR_TRAIN/SR_DATA_SMALLER",
        "test_x": "/content/SR_DATASET_SMALLER/SR_TEST/SR_DATA_AUGMENTED_SMALLER",
        "test_y": "/content/SR_DATASET_SMALLER/SR_TEST/SR_DATA_SMALLER",
    },
    "resto": {
        "train_x": "/content/resto_data/resto_train/resto_augmented",
        "train_y": "/content/resto_data/resto_train/resto",
        "test_x": "/content/resto_data/resto_test/resto_augmented",
        "test_y": "/content/resto_data/resto_test/resto",
    },
}
DATASET = "resto"

# Copy, so later mutation of `paths` cannot silently edit the table above.
paths = dict(DATASET_PATHS[DATASET])

class SuperResolutionDataGenerator(tf.keras.utils.Sequence):
    """Keras `Sequence` yielding (input, target) image batches scaled to [0, 1].

    Inputs are read from ``paths["train_x"]`` / ``paths["test_x"]`` and targets from
    ``paths["train_y"]`` / ``paths["test_y"]``. Files are paired by their position in
    the *sorted* directory listings. An input and its target must therefore have
    names that sort to the same rank in their respective directories.

    Args:
        paths: dict with the keys "train_x", "train_y", "test_x", "test_y".
        batch_size: images per batch; a trailing partial batch is dropped.
        train: use the train directories (True) or the test directories (False).
        size: (height, width) every image is resized to when loaded.

    Raises:
        ValueError: if the input and target directories contain different
            numbers of entries.
    """

    def __init__(self, paths, batch_size, train=True, size=(256, 256)):
        # Initialise the Sequence base class; some Keras versions rely on it.
        super().__init__()
        self.paths = paths
        self.batch_size = batch_size
        self.size = size
        split = "train" if train else "test"
        self.image_paths_x = self._get_image_paths(paths[f"{split}_x"])
        self.image_paths_y = self._get_image_paths(paths[f"{split}_y"])
        if len(self.image_paths_x) != len(self.image_paths_y):
            raise ValueError(
                f"Mismatched dataset: {len(self.image_paths_x)} input images vs "
                f"{len(self.image_paths_y)} target images for the {split} split."
            )
        self.on_epoch_end()

    def _get_image_paths(self, directory):
        """Return the full paths of the entries of `directory`, sorted by filename."""
        # os.listdir returns entries in arbitrary order; sorting both the input and
        # target listings is what keeps x[i] paired with y[i].
        return [
            os.path.join(directory, filename)
            for filename in sorted(os.listdir(directory))
        ]

    def __len__(self):
        # Number of full batches; leftover images that don't fill a batch are skipped.
        return len(self.image_paths_x) // self.batch_size

    def __getitem__(self, index):
        """Load batch `index` as float arrays divided by 255."""
        indexes = self.indexes[index * self.batch_size : (index + 1) * self.batch_size]
        batch_x = [self.image_paths_x[i] for i in indexes]
        batch_y = [self.image_paths_y[i] for i in indexes]
        images_x = self._load_images(batch_x) / 255.0
        images_y = self._load_images(batch_y) / 255.0
        return images_x, images_y

    def on_epoch_end(self):
        # One shared permutation for x and y, so pairs stay aligned after shuffling.
        self.indexes = list(range(len(self.image_paths_x)))
        random.shuffle(self.indexes)

    def _load_images(self, image_paths):
        """Load and resize each image to `self.size`, stacked into one array."""
        images = []
        for path in image_paths:
            image = tf.keras.preprocessing.image.load_img(path, target_size=self.size)
            images.append(tf.keras.preprocessing.image.img_to_array(image))
        return np.array(images)

# `high` selects the batch size: 4 images per batch when True, single images otherwise.
high = True
if high:
    batch_size = 4
else:
    batch_size = 1

# Same path table, two splits: train directories vs test directories.
train_data = SuperResolutionDataGenerator(paths, batch_size, train=True)
test_data = SuperResolutionDataGenerator(paths, batch_size, train=False)

def show_example(train_gen, num_show=4):
    """Plot input/target pairs from one random batch of a generator.

    Prints the batch shapes and value ranges. Then draws ``num_show`` randomly
    chosen inputs (top row, "Low-res") above their targets (bottom row, "High-res").

    Args:
        train_gen: indexable batch source returning ``(x, y)`` arrays scaled to
            [0, 1], e.g. a ``SuperResolutionDataGenerator``.
        num_show: how many pairs to draw. It is capped at the batch size, so
            batches with fewer than 4 images (e.g. ``batch_size=1``) work too.

    Raises:
        ValueError: if the generator has no batches.
    """
    num_batches = len(train_gen)
    if num_batches == 0:
        raise ValueError("Generator has no full batches to show.")
    x, y = train_gen[np.random.randint(num_batches)]
    print(x.shape, y.shape, np.min(x), np.max(x), np.min(y), np.max(y))

    num_show = min(num_show, x.shape[0])
    sample_indices = np.random.choice(x.shape[0], size=num_show, replace=False)

    # Round before casting: x * 255.0 can land just below an integer
    # (e.g. 254.99999) and a bare astype(np.uint8) would truncate it.
    x_u8 = np.clip(np.rint(x * 255.0), 0, 255).astype(np.uint8)
    y_u8 = np.clip(np.rint(y * 255.0), 0, 255).astype(np.uint8)

    # squeeze=False keeps `axes` 2-D even when only one column is drawn.
    fig, axes = plt.subplots(2, num_show, figsize=(3 * num_show, 6), squeeze=False)
    rows = ((x_u8, "Low-res"), (y_u8, "High-res"))
    for col, sample in enumerate(sample_indices):
        for row, (images, title) in enumerate(rows):
            ax = axes[row, col]
            ax.imshow(images[sample])
            ax.set_title(title)
            ax.axis("off")
    plt.tight_layout()
    plt.show()


show_example(train_data)

In [None]:
# Same visual sanity check on the test split.
show_example(test_data)

In [None]:
from typing import Dict, List, Optional, Tuple, Union
import tensorflow as tf
import numpy as np


class DownBlock(tf.keras.layers.Layer):
    """Submodule of `DownSampleBlock`: halve the spatial size, then rescale channels.

    Applies a 2x2 average pool with stride 2, followed by a 1x1 convolution that
    outputs ``int(channels * channel_factor)`` channels.

    Reference:

    1. [Learning Enriched Features for Fast Image Restoration and Enhancement](https://www.waqaszamir.com/publication/zamir-2022-mirnetv2/zamir-2022-mirnetv2.pdf)
    2. [Official PyTorch implementation of MirNetv2](https://github.com/swz30/MIRNetv2/blob/main/basicsr/models/archs/mirnet_v2_arch.py#L130)

    Args:
        channels (int): number of input channels.
        channel_factor (float): multiplier applied to ``channels`` to obtain the
            number of output channels.
    """

    def __init__(self, channels: int, channel_factor: float, *args, **kwargs) -> None:
        super(DownBlock, self).__init__(*args, **kwargs)

        self.channels = channels
        self.channel_factor = channel_factor

        self.average_pool = tf.keras.layers.AveragePooling2D(pool_size=2, strides=2)
        self.conv = tf.keras.layers.Conv2D(
            int(channels * channel_factor), kernel_size=1, strides=1, padding="same"
        )

    def call(self, inputs: tf.Tensor, *args, **kwargs) -> tf.Tensor:
        return self.conv(self.average_pool(inputs))

    def get_config(self) -> Dict:
        # Start from the base Layer config (name, trainable, dtype, ...) so a
        # from_config round-trip restores those too, not only the custom args.
        config = super().get_config()
        config.update({"channels": self.channels, "channel_factor": self.channel_factor})
        return config


class DownSampleBlock(tf.keras.layers.Layer):
    """Layer for downsampling feature map for the Multi-scale Residual Block.

    Chains ``int(log2(scale_factor))`` `DownBlock`s. Each one halves the spatial
    size and multiplies the channel count by ``channel_factor``.

    Reference:

    1. [Learning Enriched Features for Fast Image Restoration and Enhancement](https://www.waqaszamir.com/publication/zamir-2022-mirnetv2/zamir-2022-mirnetv2.pdf)
    2. [Official PyTorch implementation of MirNetv2](https://github.com/swz30/MIRNetv2/blob/main/basicsr/models/archs/mirnet_v2_arch.py#L142)

    Args:
        channels (int): number of input channels.
        scale_factor (int): total spatial downsampling factor. It is expected to be
            a power of two; other values are truncated by ``int(np.log2(...))``.
        channel_factor (float): multiplier applied to the channel count by each
            `DownBlock`.
    """

    def __init__(
        self, channels: int, scale_factor: int, channel_factor: float, *args, **kwargs
    ) -> None:
        super(DownSampleBlock, self).__init__(*args, **kwargs)

        self.channels = channels
        self.scale_factor = scale_factor
        self.channel_factor = channel_factor

        self.layers = []
        for _ in range(int(np.log2(scale_factor))):
            self.layers.append(DownBlock(channels, channel_factor))
            # The next DownBlock consumes this one's output channel count.
            channels = int(channels * channel_factor)

    def call(self, x: tf.Tensor, *args, **kwargs) -> tf.Tensor:
        for layer in self.layers:
            x = layer(x)
        return x

    def get_config(self) -> Dict:
        # Merge with the base Layer config (name, trainable, dtype, ...).
        config = super().get_config()
        config.update(
            {
                "channels": self.channels,
                "channel_factor": self.channel_factor,
                "scale_factor": self.scale_factor,
            }
        )
        return config

class MultiScaleResidualBlock(tf.keras.layers.Layer):
    """Implementation of the Multi-scale Residual Block.

    The Multi-scale Residual Block is a mechanism for collecting multi-scale spatial
    information. This block forms the core component of the recursive residual
    design of MIRNet-v2. The key advantages of MRB are:

    - It is capable of generating a spatially-precise output by maintaining high-resolution representations, while receiving rich contextual information from low-resolutions.

    - It allows contextualized-information transfer from the low-resolution streams to consolidate the high-resolution features.

    The block runs three streams: full resolution (``channels``), 1/2 resolution
    (``channels * channel_factor``) and 1/4 resolution
    (``channels * channel_factor**2``). Each stream is refined by a Residual Context
    Block, and lower streams are fused upward with SKFF. This is done twice, then a
    1x1 convolution and a residual connection to the block input are applied.

    Reference:

    1. [Learning Enriched Features for Fast Image Restoration and Enhancement](https://www.waqaszamir.com/publication/zamir-2022-mirnetv2/zamir-2022-mirnetv2.pdf)
    2. [Official PyTorch implementation of MirNetv2](https://github.com/swz30/MIRNetv2/blob/main/basicsr/models/archs/mirnet_v2_arch.py#L189)

    Args:
        channels (int): number of channels in the feature map.
        channel_factor (float): multiplier for the channel count at each
            lower-resolution stream.
        groups (int): number of groups in which the input is split along the
            channel axis in the convolution layers.
    """

    def __init__(
        self, channels: int, channel_factor: float, groups: int, *args, **kwargs
    ):
        super().__init__(*args, **kwargs)

        self.channels = channels
        self.channel_factor = channel_factor
        self.groups = groups

        # Residual Context Blocks
        self.rcb_top = ResidualContextBlock(
            int(channels * channel_factor**0), groups=groups
        )
        self.rcb_middle = ResidualContextBlock(
            int(channels * channel_factor**1), groups=groups
        )
        self.rcb_bottom = ResidualContextBlock(
            int(channels * channel_factor**2), groups=groups
        )

        # Downsample Blocks
        self.down_2 = DownSampleBlock(
            channels=int((channel_factor**0) * channels),
            scale_factor=2,
            channel_factor=channel_factor,
        )
        self.down_4_1 = DownSampleBlock(
            channels=int((channel_factor**0) * channels),
            scale_factor=2,
            channel_factor=channel_factor,
        )
        self.down_4_2 = DownSampleBlock(
            channels=int((channel_factor**1) * channels),
            scale_factor=2,
            channel_factor=channel_factor,
        )

        # UpSample Blocks
        self.up21_1 = UpSampleBlock(
            channels=int((channel_factor**1) * channels),
            scale_factor=2,
            channel_factor=channel_factor,
        )
        self.up21_2 = UpSampleBlock(
            channels=int((channel_factor**1) * channels),
            scale_factor=2,
            channel_factor=channel_factor,
        )
        self.up32_1 = UpSampleBlock(
            channels=int((channel_factor**2) * channels),
            scale_factor=2,
            channel_factor=channel_factor,
        )
        self.up32_2 = UpSampleBlock(
            channels=int((channel_factor**2) * channels),
            scale_factor=2,
            channel_factor=channel_factor,
        )

        # SKFF Blocks
        self.skff_top = SelectiveKernelFeatureFusion(
            channels=int(channels * channel_factor**0)
        )
        self.skff_middle = SelectiveKernelFeatureFusion(
            channels=int(channels * channel_factor**1)
        )

        # Convolution
        self.conv_out = tf.keras.layers.Conv2D(channels, kernel_size=1, padding="same")

    def call(self, inputs: tf.Tensor, *args, **kwargs) -> tf.Tensor:
        # Build the 1/2 and 1/4 resolution streams from the input.
        x_top = inputs
        x_middle = self.down_2(x_top)
        x_bottom = self.down_4_2(self.down_4_1(x_top))

        # First refinement + bottom-up fusion pass.
        x_top = self.rcb_top(x_top)
        x_middle = self.rcb_middle(x_middle)
        x_bottom = self.rcb_bottom(x_bottom)

        x_middle = self.skff_middle([x_middle, self.up32_1(x_bottom)])
        x_top = self.skff_top([x_top, self.up21_1(x_middle)])

        # Second pass reuses the same RCB/SKFF instances (shared weights);
        # only the upsampling blocks are separate per pass.
        x_top = self.rcb_top(x_top)
        x_middle = self.rcb_middle(x_middle)
        x_bottom = self.rcb_bottom(x_bottom)

        x_middle = self.skff_middle([x_middle, self.up32_2(x_bottom)])
        x_top = self.skff_top([x_top, self.up21_2(x_middle)])

        output = self.conv_out(x_top)
        output = output + inputs

        return output

    def get_config(self) -> Dict:
        # Merge with the base Layer config (name, trainable, dtype, ...).
        config = super().get_config()
        config.update(
            {
                "channels": self.channels,
                "channel_factor": self.channel_factor,
                "groups": self.groups,
            }
        )
        return config


class ContextBlock(tf.keras.layers.Layer):
    """Submodule of the Residual Contextual Block (global-context attention).

    A 1x1 convolution produces one logit per spatial position. A softmax over all
    H*W positions turns these logits into pooling weights. The weighted sum of the
    input features gives one context vector per image. That vector goes through
    1x1 conv -> LeakyReLU(0.2) -> 1x1 conv and is added back to the input
    (broadcast over H and W).

    Reference:

    1. [Learning Enriched Features for Fast Image Restoration and Enhancement](https://www.waqaszamir.com/publication/zamir-2022-mirnetv2/zamir-2022-mirnetv2.pdf)
    2. [Official PyTorch implementation of MirNetv2](https://github.com/swz30/MIRNetv2/blob/main/basicsr/models/archs/mirnet_v2_arch.py#L57)

    Args:
        channels (int): number of channels in the feature map.
    """

    def __init__(self, channels: int, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self.channels = channels

        self.mask_conv = tf.keras.layers.Conv2D(1, kernel_size=1, padding="same")

        self.channel_add_conv_1 = tf.keras.layers.Conv2D(
            channels, kernel_size=1, padding="same"
        )
        self.channel_add_conv_2 = tf.keras.layers.Conv2D(
            channels, kernel_size=1, padding="same"
        )

        # axis=1 is the flattened H*W axis of the reshaped mask in `modeling`.
        self.softmax = tf.keras.layers.Softmax(axis=1)
        self.leaky_relu = tf.keras.layers.LeakyReLU(alpha=0.2)

    def modeling(self, inputs: tf.Tensor) -> tf.Tensor:
        """Return the attention-pooled context of ``inputs``, shape (B, 1, 1, C).

        ``inputs`` is channels-last (B, H, W, C). Unknown static H/W are read at
        run time with ``tf.shape``.
        """
        _, height, width, channels = [
            tf.shape(inputs)[_shape_idx] if _shape is None else _shape
            for _shape_idx, _shape in enumerate(inputs.shape.as_list())
        ]
        # (B, H, W, C) -> (B, H*W, C) -> (B, C, H*W) -> (B, 1, C, H*W).
        # The transpose is required: reshaping channels-last data directly to
        # (B, C, H*W) would interleave pixel and channel values.
        flat_inputs = tf.reshape(inputs, (-1, height * width, channels))
        reshaped_inputs = tf.expand_dims(
            tf.transpose(flat_inputs, perm=[0, 2, 1]), axis=1
        )

        # (B, H, W, 1) -> (B, H*W, 1), softmax over positions -> (B, 1, H*W, 1).
        context_mask = self.mask_conv(inputs)
        context_mask = tf.reshape(context_mask, (-1, height * width, 1))
        context_mask = self.softmax(context_mask)
        context_mask = tf.expand_dims(context_mask, axis=1)

        # (B, 1, C, H*W) @ (B, 1, H*W, 1) -> (B, 1, C, 1) -> (B, 1, 1, C).
        context = tf.reshape(
            tf.matmul(reshaped_inputs, context_mask), (-1, 1, 1, channels)
        )
        return context

    def call(self, inputs: tf.Tensor, *args, **kwargs) -> tf.Tensor:
        context = self.modeling(inputs)
        channel_add_term = self.channel_add_conv_1(context)
        channel_add_term = self.leaky_relu(channel_add_term)
        channel_add_term = self.channel_add_conv_2(channel_add_term)
        # (B, 1, 1, C) broadcasts over every spatial position.
        return inputs + channel_add_term

    def get_config(self) -> Dict:
        # Merge with the base Layer config (name, trainable, dtype, ...).
        config = super().get_config()
        config.update({"channels": self.channels})
        return config


class ResidualContextBlock(tf.keras.layers.Layer):
    """Implementation of the Residual Contextual Block.

    The Residual Contextual Block is used to extract features in the convolutional
    streams and suppress less useful features. The overall process of RCB is
    summarized as:

    $$F_{RCB} = F_{a} + W(CM(F_{b}))$$

    where...

    - $F_{a}$ are the input feature maps.

    - $F_{b}$ represents feature maps that are obtained by applying two 3x3 group
        convolution layers to the input features.

    - $CM$ represents a **contextual module**.

    - $W$ denotes the last convolutional layer with filter size $1 \\times 1$.

    In this implementation the computation is
    conv3x3 -> LeakyReLU -> conv3x3 -> ContextBlock -> LeakyReLU, plus the input.

    Reference:

    1. [Learning Enriched Features for Fast Image Restoration and Enhancement](https://www.waqaszamir.com/publication/zamir-2022-mirnetv2/zamir-2022-mirnetv2.pdf)
    2. [Official PyTorch implementation of MirNetv2](https://github.com/swz30/MIRNetv2/blob/main/basicsr/models/archs/mirnet_v2_arch.py#L105)

    Args:
        channels (int): number of channels in the feature map.
        groups (int): number of groups in which the input is split along the
            channel axis in the convolution layers.
    """

    def __init__(self, channels: int, groups: int, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self.channels = channels
        self.groups = groups

        self.conv_1 = tf.keras.layers.Conv2D(
            channels, kernel_size=3, padding="same", groups=groups
        )
        self.conv_2 = tf.keras.layers.Conv2D(
            channels, kernel_size=3, padding="same", groups=groups
        )
        # One stateless activation layer, used at both activation points.
        self.leaky_relu = tf.keras.layers.LeakyReLU(alpha=0.2)

        self.context_block = ContextBlock(channels=channels)

    def call(self, inputs: tf.Tensor) -> tf.Tensor:
        x = self.conv_1(inputs)
        x = self.leaky_relu(x)
        x = self.conv_2(x)
        x = self.context_block(x)
        x = self.leaky_relu(x)
        x = x + inputs
        return x

    def get_config(self) -> Dict:
        # Merge with the base Layer config (name, trainable, dtype, ...).
        config = super().get_config()
        config.update({"channels": self.channels, "groups": self.groups})
        return config


class SelectiveKernelFeatureFusion(tf.keras.layers.Layer):
    """Implementation of the Selective Kernel Feature Fusion Layer.

    This layer adaptively adjusts the input receptive fields by using multi-scale
    feature generation (in the same layer) followed by feature aggregation and
    selection. This is done using two distinct operations:

    - **Fuse Operation:** The fuse operator generates global feature descriptors by
        combining the information from multiresolution streams.
    - **Select Operation:** The select operator uses the feature descriptors
        generated by the fuse operator to recalibrate the feature maps
        (of different streams) followed by their aggregation.

    Concretely, the two inputs are summed and global-average-pooled. The result is
    squeezed to ``hidden_channels`` with a 1x1 conv + LeakyReLU(0.2), then expanded
    back into one logit vector per branch. A softmax *across the two branches*
    (per channel) gives weights that sum to 1 for every channel. The output is
    the weighted sum of the two inputs.

    Reference:

    1. [Selective Kernel Networks](https://arxiv.org/abs/1903.06586)
    2. [Learning Enriched Features for Fast Image Restoration and Enhancement](https://www.waqaszamir.com/publication/zamir-2022-mirnetv2/zamir-2022-mirnetv2.pdf)
    3. [Official PyTorch implementation of MirNetv2](https://github.com/swz30/MIRNetv2/blob/main/basicsr/models/archs/mirnet_v2_arch.py#L17)

    Args:
        channels (int): number of channels in the feature map.
    """

    def __init__(self, channels: int, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self.channels = channels
        self.hidden_channels = max(int(self.channels / 8), 4)
        self.average_pooling = tf.keras.layers.GlobalAveragePooling2D(keepdims=True)

        self.conv_channel_downscale = tf.keras.layers.Conv2D(
            self.hidden_channels, kernel_size=1, padding="same"
        )
        self.conv_attention_1 = tf.keras.layers.Conv2D(
            self.channels, kernel_size=1, strides=1, padding="same"
        )
        self.conv_attention_2 = tf.keras.layers.Conv2D(
            self.channels, kernel_size=1, strides=1, padding="same"
        )
        # Applied to logits stacked on a trailing branch axis (see `call`), so
        # axis=-1 normalises across branches, not across channels.
        self.softmax = tf.keras.layers.Softmax(axis=-1)
        # Activation after the channel squeeze, as in the reference SKFF.
        self.leaky_relu = tf.keras.layers.LeakyReLU(alpha=0.2)

    def call(
        self, inputs: Tuple[tf.Tensor], training: Optional[bool] = None
    ) -> tf.Tensor:
        # Fuse operation
        combined_input_features = inputs[0] + inputs[1]
        channel_wise_statistics = self.average_pooling(combined_input_features)
        downscaled_channel_wise_statistics = self.leaky_relu(
            self.conv_channel_downscale(channel_wise_statistics)
        )
        # (B, 1, 1, C, 2): one logit per channel per branch.
        attention_logits = tf.stack(
            [
                self.conv_attention_1(downscaled_channel_wise_statistics),
                self.conv_attention_2(downscaled_channel_wise_statistics),
            ],
            axis=-1,
        )
        attention_weights = self.softmax(attention_logits)

        # Select operation: per-channel convex combination of the two branches.
        selected_features = (
            inputs[0] * attention_weights[..., 0]
            + inputs[1] * attention_weights[..., 1]
        )
        return selected_features

    def get_config(self) -> Dict:
        # Merge with the base Layer config (name, trainable, dtype, ...).
        config = super().get_config()
        config.update({"channels": self.channels})
        return config


class UpBlock(tf.keras.layers.Layer):
    """Submodule of `UpSampleBlock`: rescale channels, then double the spatial size.

    Applies a 1x1 convolution to ``int(channels // channel_factor)`` channels,
    followed by 2x bilinear upsampling.

    Reference:

    1. [Learning Enriched Features for Fast Image Restoration and Enhancement](https://www.waqaszamir.com/publication/zamir-2022-mirnetv2/zamir-2022-mirnetv2.pdf)
    2. [Official PyTorch implementation of MirNetv2](https://github.com/swz30/MIRNetv2/blob/main/basicsr/models/archs/mirnet_v2_arch.py#L158)

    Args:
        channels (int): number of input channels.
        channel_factor (float): divisor applied to ``channels`` to obtain the
            number of output channels.
    """

    def __init__(self, channels: int, channel_factor: float, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self.channels = channels
        self.channel_factor = channel_factor

        self.conv = tf.keras.layers.Conv2D(
            int(channels // channel_factor), kernel_size=1, strides=1, padding="same"
        )
        self.upsample = tf.keras.layers.UpSampling2D(size=2, interpolation="bilinear")

    def call(self, inputs: tf.Tensor, *args, **kwargs) -> tf.Tensor:
        return self.upsample(self.conv(inputs))

    def get_config(self) -> Dict:
        # Merge with the base Layer config (name, trainable, dtype, ...).
        config = super().get_config()
        config.update({"channels": self.channels, "channel_factor": self.channel_factor})
        return config


class UpSampleBlock(tf.keras.layers.Layer):
    """Layer for upsampling feature map for the Multi-scale Residual Block.

    Chains ``int(log2(scale_factor))`` `UpBlock`s. Each one divides the channel
    count by ``channel_factor`` and doubles the spatial size.

    Reference:

    1. [Learning Enriched Features for Fast Image Restoration and Enhancement](https://www.waqaszamir.com/publication/zamir-2022-mirnetv2/zamir-2022-mirnetv2.pdf)
    2. [Official PyTorch implementation of MirNetv2](https://github.com/swz30/MIRNetv2/blob/main/basicsr/models/archs/mirnet_v2_arch.py#L170)

    Args:
        channels (int): number of input channels.
        scale_factor (int): total spatial upsampling factor. It is expected to be a
            power of two; other values are truncated by ``int(np.log2(...))``.
        channel_factor (float): divisor applied to the channel count by each `UpBlock`.
    """

    def __init__(
        self, channels: int, scale_factor: int, channel_factor: float, *args, **kwargs
    ) -> None:
        super().__init__(*args, **kwargs)

        self.channels = channels
        self.scale_factor = scale_factor
        self.channel_factor = channel_factor

        self.layers = []
        for _ in range(int(np.log2(scale_factor))):
            self.layers.append(UpBlock(channels, channel_factor))
            # The next UpBlock consumes this one's output channel count.
            channels = int(channels // channel_factor)

    def call(self, x, *args, **kwargs):
        for layer in self.layers:
            x = layer(x)
        return x

    def get_config(self) -> Dict:
        # Merge with the base Layer config (name, trainable, dtype, ...).
        config = super().get_config()
        config.update(
            {
                "channels": self.channels,
                "scale_factor": self.scale_factor,
                "channel_factor": self.channel_factor,
            }
        )
        return config




class RecursiveResidualGroup(tf.keras.layers.Layer):
    """Implementation of the Recursive Residual Group.

    The Recursive Residual Group forms the basic building block on MirNetV2.
    It progressively breaks down the input signal in order to simplify the overall
    learning process, and allows the construction of very deep networks.

    Computation: ``num_mrb_blocks`` Multi-scale Residual Blocks, then a 3x3
    convolution, plus a residual connection to the group input.

    Reference:

    1. [Learning Enriched Features for Fast Image Restoration and Enhancement](https://www.waqaszamir.com/publication/zamir-2022-mirnetv2/zamir-2022-mirnetv2.pdf)
    2. [Official PyTorch implementation of MirNetv2](https://github.com/swz30/MIRNetv2/blob/main/basicsr/models/archs/mirnet_v2_arch.py#L242)

    Args:
        channels (int): number of channels in the feature map.
        num_mrb_blocks (int): number of multi-scale residual blocks.
        channel_factor (float): multiplier for the channel count at each
            lower-resolution stream inside the MRBs.
        groups (int): number of groups in which the input is split along the
            channel axis in the convolution layers.
    """

    def __init__(
        self,
        channels: int,
        num_mrb_blocks: int,
        channel_factor: float,
        groups: int,
        *args,
        **kwargs
    ) -> None:
        super().__init__(*args, **kwargs)

        self.channels = channels
        self.num_mrb_blocks = num_mrb_blocks
        self.channel_factor = channel_factor
        self.groups = groups

        self.layers = [
            MultiScaleResidualBlock(channels, channel_factor, groups)
            for _ in range(num_mrb_blocks)
        ]
        self.layers.append(
            tf.keras.layers.Conv2D(channels, kernel_size=3, strides=1, padding="same")
        )

    def call(self, inputs: tf.Tensor, *args, **kwargs) -> tf.Tensor:
        residual = inputs
        for layer in self.layers:
            residual = layer(residual)
        residual = residual + inputs
        return residual

    def get_config(self) -> Dict:
        # Merge with the base Layer config (name, trainable, dtype, ...).
        config = super().get_config()
        config.update(
            {
                "channels": self.channels,
                "num_mrb_blocks": self.num_mrb_blocks,
                "channel_factor": self.channel_factor,
                "groups": self.groups,
            }
        )
        return config


class MirNetv2(tf.keras.Model):
    """Implementation of the MirNetv2 model.

    MirNetv2 is a fully convolutional architecture that learns enriched feature
    representations for image restoration and enhancement. It is based on a
    **recursive residual design** with the **multi-scale residual block** or **MRB**
    at its core. The main branch of the MRB is dedicated to maintaining spatially-precise
    high-resolution representations through the entire network and the complementary set
    of parallel branches provide better contextualized features.

    Pipeline: 3x3 conv to ``channels`` features -> four Recursive Residual Groups
    (convolution groups 1, 2, 4, 4) -> 3x3 conv to 3 output channels, optionally
    plus the input image.

    Reference:

    1. [Learning Enriched Features for Fast Image Restoration and Enhancement](https://www.waqaszamir.com/publication/zamir-2022-mirnetv2/zamir-2022-mirnetv2.pdf)
    2. [Official PyTorch implementation of MirNetv2](https://github.com/swz30/MIRNetv2/blob/main/basicsr/models/archs/mirnet_v2_arch.py#L242)

    Args:
        channels (int): number of channels in the feature map.
        channel_factor (float): multiplier for the channel count at each
            lower-resolution stream inside the MRBs.
        num_mrb_blocks (int): number of multi-scale residual blocks per group.
        add_residual_connection (bool): add a residual connection between the inputs and the
            outputs or not.
    """

    def __init__(
        self,
        channels: int,
        channel_factor: float,
        num_mrb_blocks: int,
        add_residual_connection: bool,
        *args,
        **kwargs
    ) -> None:
        super().__init__(*args, **kwargs)

        self.channels = channels
        self.channel_factor = channel_factor
        self.num_mrb_blocks = num_mrb_blocks
        self.add_residual_connection = add_residual_connection

        # Shallow feature extraction.
        self.conv_in = tf.keras.layers.Conv2D(channels, kernel_size=3, padding="same")

        # Four residual groups with increasing convolution grouping (1, 2, 4, 4).
        # NOTE(review): sublayer creation order fixes auto-generated layer names
        # and weight order; keep it stable if saved weights must still load.
        self.rrg_block_1 = RecursiveResidualGroup(
            channels, num_mrb_blocks, channel_factor, groups=1
        )
        self.rrg_block_2 = RecursiveResidualGroup(
            channels, num_mrb_blocks, channel_factor, groups=2
        )
        self.rrg_block_3 = RecursiveResidualGroup(
            channels, num_mrb_blocks, channel_factor, groups=4
        )
        self.rrg_block_4 = RecursiveResidualGroup(
            channels, num_mrb_blocks, channel_factor, groups=4
        )

        # Project back to a 3-channel image.
        self.conv_out = tf.keras.layers.Conv2D(3, kernel_size=3, padding="same")

    def call(self, inputs: tf.Tensor, training=None, mask=None) -> tf.Tensor:
        shallow_features = self.conv_in(inputs)
        deep_features = self.rrg_block_1(shallow_features)
        deep_features = self.rrg_block_2(deep_features)
        deep_features = self.rrg_block_3(deep_features)
        deep_features = self.rrg_block_4(deep_features)
        output = self.conv_out(deep_features)
        # With the residual connection the network predicts a correction to the input.
        output = output + inputs if self.add_residual_connection else output
        return output

    def save(self, filepath: str, *args, **kwargs) -> None:
        # Wrap the subclassed model in a functional Model with an explicit
        # (None, None, 3) input, so the saved artifact accepts RGB images of any
        # spatial size. Extra args/kwargs are forwarded to Model.save.
        input_tensor = tf.keras.Input(shape=[None, None, 3])
        saved_model = tf.keras.Model(
            inputs=input_tensor, outputs=self.call(input_tensor)
        )
        saved_model.save(filepath, *args, **kwargs)

    def get_config(self) -> Dict:
        return {
            "channels": self.channels,
            "num_mrb_blocks": self.num_mrb_blocks,
            "channel_factor": self.channel_factor,
            "add_residual_connection": self.add_residual_connection,
        }


# Compact configuration: 64 base channels, channel count x1.5 at each lower-resolution
# stream, one MRB per residual group, and no input->output skip connection
# (the network outputs the restored image directly).
model = MirNetv2(
    channels=64,
    channel_factor=1.5,
    num_mrb_blocks=1,
    add_residual_connection=False,
)
# Build with a concrete (batch=4, 256x256, RGB) shape so summary() can report
# parameter counts; 4 matches batch_size when high=True.
model.build([4,256,256,3])
model.summary()

In [None]:
import numpy as np
import cv2
import tensorflow as tf
from keras.optimizers import optimizer
from PIL import Image
from keras.callbacks import *
from keras import backend as K
import scipy.misc
from skimage.metrics import structural_similarity

def save_eval_results(eval_results, filename, metrics_names=None):
    """Write evaluation results as a two-row CSV: metric names, then values.

    Args:
        eval_results: value(s) returned by ``model.evaluate``. A list/tuple/array is
            written as-is; a single scalar is written as a one-element row.
        filename: path of the CSV file to (over)write.
        metrics_names: header row. Defaults to the global ``model.metrics_names``.
    """
    # `csv` was used here without being imported anywhere in the notebook.
    import csv

    if metrics_names is None:
        metrics_names = model.metrics_names
    if not hasattr(eval_results, "__iter__") or isinstance(eval_results, str):
        eval_results = [eval_results]
    with open(filename, "w", newline="") as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(metrics_names)
        writer.writerow(eval_results)


def lr_scheduler(epoch, lr):
    """Step-decay schedule for ``LearningRateScheduler``.

    On every even epoch after epoch 0 (2, 4, ...) the current rate is multiplied
    by 0.1 and the new value is printed. On all other epochs ``lr`` is returned
    unchanged.
    """
    if epoch <= 0 or epoch % 2:
        return lr
    decayed = lr * 0.1
    print("Learning rate updated to:", decayed)
    return decayed

# Training callbacks.
# NOTE(review): these only take effect when passed as `model.fit(..., callbacks=callbacks)`.
callbacks = [
    # Save the full model to ./best_model whenever the training loss improves.
    ModelCheckpoint('best_model', save_best_only=True, monitor="loss"),
    # Stop once the training loss has not improved for 2 epochs.
    EarlyStopping(monitor='loss', patience=2),
    TensorBoard(log_dir='logs'),
   # ReduceLROnPlateau(monitor='loss', patience=1, factor=0.1),
    # Step decay from lr_scheduler: x0.1 on every second epoch.
    LearningRateScheduler(lr_scheduler)
]

def psnr(y_true, y_pred):
    """Per-image PSNR for images scaled to [0, 1] (``max_val=1.0``)."""
    return tf.image.psnr(y_true, y_pred, max_val=1.0)

def ssim(y_true, y_pred):
    """Per-image SSIM for images scaled to [0, 1] (``max_val=1.0``)."""
    return tf.image.ssim(y_true, y_pred, max_val=1.0)

@tf.function
def update_fn(p, grad, exp_avg, lr, wd, beta1, beta2):
    """Apply one Lion step to a single parameter, in place.

    Args:
        p: parameter variable; updated with ``assign``/``assign_add``.
        grad: gradient of the loss with respect to ``p``.
        exp_avg: momentum variable with the same shape as ``p``; updated in place.
        lr: learning rate.
        wd: decoupled weight-decay coefficient.
        beta1: interpolation factor between momentum and gradient for the update.
        beta2: decay rate of the momentum running average.
    """
    # Decoupled weight decay: p <- p * (1 - lr * wd).
    p.assign(p * (1 - lr * wd))

    # Lion update direction: sign of the beta1-interpolation between momentum
    # and the current gradient. The previous LinSpace(1, 0, num=1)[0] expression
    # always evaluated to 1.0, which silently dropped the gradient and ignored beta1.
    update = exp_avg * beta1 + grad * (1 - beta1)
    p.assign_add(tf.sign(update) * -lr)

    # Decay the momentum running average with beta2.
    exp_avg.assign(exp_avg * beta2 + grad * (1 - beta2))


def lerp(start, end, weight):
    """Linear interpolation: ``start`` at weight 0, ``end`` at weight 1."""
    delta = end - start
    return start + weight * delta


def sparse_lerp(start, end, weight):
    """Linear interpolation that also works when ``end`` is a sparse IndexedSlices.

    Mathematically the same as ``start + weight * (end - start)``. Only
    ``start - end`` is formed, because a dense Tensor cannot be subtracted from
    IndexedSlices; the result is then negated.
    """
    gap = start - end
    return start + weight * -gap

class Lion(optimizer.Optimizer):
    """Optimizer that implements the Lion algorithm.
    Lion was published in the paper "Symbolic Discovery of Optimization Algorithms"
    which is available at https://arxiv.org/abs/2302.06675

    Per variable, each step computes ``update = sign(lerp(ema, grad, 1 - beta_1))``,
    applies ``var -= lr * update``, then sets ``ema = lerp(ema, grad, 1 - beta_2)``.

    Args:
      learning_rate: A `tf.Tensor`, floating point value, a schedule that is a
        `tf.keras.optimizers.schedules.LearningRateSchedule`, or a callable
        that takes no arguments and returns the actual value to use. The
        learning rate. Defaults to 1e-4.
      beta_1: A float value or a scalar float tensor. Factor used to interpolate
        the current gradient and the momentum. Defaults to 0.9.
      beta_2: A float value or a scalar float tensor. The exponential decay rate
        for the momentum. Defaults to 0.99.
      weight_decay, clipnorm, clipvalue, global_clipnorm, jit_compile, name,
        **kwargs: forwarded unchanged to the base Keras optimizer.
    Notes:
    The sparse implementation of this algorithm (used when the gradient is an
    IndexedSlices object, typically because of `tf.gather` or an embedding
    lookup in the forward pass) does apply momentum to variable slices even if
    they were not used in the forward pass (meaning they have a gradient equal
    to zero). Momentum decay (beta2) is also applied to the entire momentum
    accumulator. This means that the sparse behavior is equivalent to the dense
    behavior (in contrast to some momentum implementations which ignore momentum
    unless a variable slice was actually used).
    """

    def __init__(
        self,
        learning_rate=1e-4,
        beta_1=0.9,
        beta_2=0.99,
        weight_decay=None,
        clipnorm=None,
        clipvalue=None,
        global_clipnorm=None,
        jit_compile=True,
        name="Lion",
        **kwargs
    ):
        super().__init__(
            name=name,
            weight_decay=weight_decay,
            clipnorm=clipnorm,
            clipvalue=clipvalue,
            global_clipnorm=global_clipnorm,
            jit_compile=jit_compile,
            **kwargs
        )
        self._learning_rate = self._build_learning_rate(learning_rate)
        self.beta_1 = beta_1
        self.beta_2 = beta_2

    def build(self, var_list):
        """Initialize optimizer variables.
          var_list: list of model variables to build Lion variables on.
        """
        super().build(var_list)
        # Guard against building the momentum slots twice.
        if hasattr(self, "_built") and self._built:
            return
        self._built = True
        self._emas = []
        for var in var_list:
            self._emas.append(
                self.add_variable_from_reference(
                    model_variable=var, variable_name="ema"
                )
            )

    def update_step(self, gradient, variable):
        """Update step given gradient and the associated model variable."""
        # Cast every hyper-parameter to the variable's dtype as a *scalar*.
        # Shape-(1,) constants would turn a scalar variable's update into
        # shape (1,) (so the assign fails), and float32 constants cannot be
        # multiplied with variables of another float dtype.
        lr = tf.cast(self.learning_rate, variable.dtype)
        beta_1 = tf.cast(self.beta_1, variable.dtype)
        beta_2 = tf.cast(self.beta_2, variable.dtype)

        var_key = self._var_key(variable)
        ema = self._emas[self._index_dict[var_key]]

        if isinstance(gradient, tf.IndexedSlices):
            # Sparse gradients.
            lerp_fn = sparse_lerp
        else:
            # Dense gradients.
            lerp_fn = lerp

        update = lerp_fn(ema, gradient, 1 - beta_1)
        update = tf.sign(update)
        variable.assign_sub(update * lr)

        ema.assign(lerp_fn(ema, gradient, 1 - beta_2))

    def get_config(self):
        config = super().get_config()

        config.update(
            {
                "learning_rate": self._serialize_hyperparameter(
                    self._learning_rate
                ),
                "beta_1": self.beta_1,
                "beta_2": self.beta_2,
            }
        )
        return config

def fast_check(generator, amount: int = 20):
    """Lazily yield at most ``amount`` items from an iterable.

    Used to evaluate on a small slice of a data generator.

    Args:
        generator: any iterable, e.g. a Keras ``Sequence`` or a Python generator.
        amount: maximum number of items to yield. ``0`` or a negative value yields
            nothing; previously the count was checked only after the first yield,
            so these values yielded the entire iterable.

    Yields:
        Items of ``generator`` in order. Nothing beyond item ``amount`` is pulled.
    """
    import itertools

    if amount <= 0:
        return
    yield from itertools.islice(generator, amount)

def SSIMLoss(y_true, y_pred):
    """Loss = 1 - mean SSIM over the batch, for images scaled to [0, 1]."""
    per_image_ssim = tf.image.ssim(y_true, y_pred, 1.0)
    mean_ssim = tf.reduce_mean(per_image_ssim)
    return 1 - mean_ssim

# NOTE: `mae` is kept for experimentation; training below uses SSIMLoss.
mae = tf.keras.losses.MeanAbsoluteError()
epochs = 5
save = True  # save a full model snapshot after every epoch as model_epoch_<n>
eval_batches = min(20, len(test_data))  # evaluate on a small slice of the test set

# Named so it does not shadow the `optimizer` module that `Lion` subclasses from.
lion_optimizer = Lion(1e-5)
model.compile(optimizer=lion_optimizer, loss=SSIMLoss, metrics=[psnr, ssim])

# A single fit() call keeps the epoch counter continuous, so the LR schedule,
# EarlyStopping and ModelCheckpoint in `callbacks` actually take effect.
# (Previously they were never passed to fit, and each fit(epochs=1) restarted
# the epoch counter at 0.)
epoch_callbacks = list(callbacks)
if save:
    epoch_callbacks.append(
        tf.keras.callbacks.LambdaCallback(
            on_epoch_end=lambda epoch, logs: model.save(
                f"model_epoch_{epoch + 1}", save_format="tf"
            )
        )
    )

model.fit(
    train_data,
    epochs=epochs,
    validation_data=test_data if eval_batches else None,
    validation_steps=eval_batches or None,
    callbacks=epoch_callbacks,
)
model.save("model", save_format="tf")


In [None]:
import shutil
# Move the chosen snapshot onto the mounted Google Drive (presumably so it outlives
# the Colab runtime).
# NOTE(review): hard-codes the epoch-2 snapshot; this fails if training stopped
# before epoch 2 or if per-epoch saving was disabled.
#shutil.copy("model.h5", "/content/drive/MyDrive/model_resto.h5")
shutil.move("model_epoch_2", "/content/drive/MyDrive/model_epoch_2")
