# __Importing required libraries__

In [None]:
import tensorflow as tf
from tensorflow.keras import applications, Model, losses, layers, optimizers
import tensorflow_datasets as tfds

import os

import numpy as np
import matplotlib.pyplot as plt

# __Config values__

In [None]:
# [========= Data Preprocessing =========]
HR_SIZE = 128  # side length of the square HR training crops (pixels)
SCALE = 4  # super-resolution upscaling factor
# Derived from SCALE (instead of a hard-coded 4) so LR and HR crop sizes stay
# consistent if the scale factor is ever changed.
LR_SIZE = HR_SIZE // SCALE
BATCH_SIZE = 8

GEN_FILTERS = 64  # base channel width of the generator
DISC_FILTERS = 64  # base channel width of the discriminator

# __Data preprocessing and augmentation__

In [None]:
# [====================================================]
# [================ Random Compressions ===============]
# [====================================================]

def random_compression(example):
    """Build an (LR, HR) pair using a randomly chosen LR degradation.

    A uniform index in [0, 7) selects the degradation:
      * 0-1: downscale ``example['hr']`` by ``SCALE`` with bicubic interpolation,
      * 2-3: ... with bilinear interpolation,
      * 4-5: ... with nearest-neighbour interpolation,
      * 6:   use the dataset's own ``example['lr']`` unchanged.

    Args:
        example: dataset element with ``'hr'`` and ``'lr'`` image entries
            (presumably uint8 HxWx3 as produced by the DIV2K tfds builder —
            TODO confirm).

    Returns:
        ``(lr, hr)``: ``lr`` is uint8; ``hr`` is ``example['hr']`` unchanged.
    """
    hr = example['hr']
    hr_shape = tf.shape(hr)
    # Integer floor division keeps the target size an int32 tensor, so no
    # float -> int conversion of a symbolic value is needed inside tf.data.
    lr_size = [hr_shape[0] // SCALE, hr_shape[1] // SCALE]

    def _downscale(method):
        # resize() returns float32; clip and round back into uint8 so every
        # branch of the switch produces the same dtype.
        lr = tf.image.resize(hr, lr_size, method = method)
        return tf.cast(tf.round(tf.clip_by_value(lr, 0, 255)), tf.uint8)

    compression_idx = tf.random.uniform(shape = (), maxval = 7, dtype = tf.int32)
    lr = tf.switch_case(
        compression_idx,
        branch_fns = [
            lambda: _downscale('bicubic'),
            lambda: _downscale('bicubic'),
            lambda: _downscale('bilinear'),
            lambda: _downscale('bilinear'),
            lambda: _downscale('nearest'),
            lambda: _downscale('nearest'),
            lambda: example['lr'],
        ],
    )

    return lr, hr

# [======================================================]
# [============= Spatial Random Augmentations ===========]
# [======================================================]

@tf.function()
def random_crop(lr, hr):
    """Cut spatially aligned LR_SIZE / HR_SIZE patches from an (LR, HR) pair.

    The LR top-left corner is drawn uniformly so the LR_SIZE patch fits inside
    ``lr``; the HR corner is the same offset multiplied by ``SCALE``.
    """
    lr_dims = tf.shape(lr)
    lr_height, lr_width = lr_dims[0], lr_dims[1]

    # Column offset first, then row offset.
    x0 = tf.random.uniform(shape = (), maxval = lr_width - LR_SIZE + 1, dtype = tf.int32)
    y0 = tf.random.uniform(shape = (), maxval = lr_height - LR_SIZE + 1, dtype = tf.int32)

    # Matching corner in HR pixel coordinates.
    hx0, hy0 = x0 * SCALE, y0 * SCALE

    return (
        lr[y0:y0 + LR_SIZE, x0:x0 + LR_SIZE],
        hr[hy0:hy0 + HR_SIZE, hx0:hx0 + HR_SIZE],
    )

@tf.function()
def random_rotate(lr, hr):
    """Rotate ``lr`` and ``hr`` by the same random number (0-3) of 90-degree turns."""
    quarter_turns = tf.random.uniform(shape = (), maxval = 4, dtype = tf.int32)
    rotated_lr = tf.image.rot90(lr, k = quarter_turns)
    rotated_hr = tf.image.rot90(hr, k = quarter_turns)
    return rotated_lr, rotated_hr

@tf.function()
def random_spatial_augmentation(lrs, hrs):
    """Randomly rotate an (LR, HR) pair and cast both to float32.

    With probability 1/2 the tensors pass through untouched; otherwise
    ``random_rotate`` applies one shared rotation to the whole input (it may
    itself pick zero turns).
    """
    def _identity():
        return lrs, hrs

    def _rotated():
        return random_rotate(lrs, hrs)

    coin = tf.random.uniform(shape = (), maxval = 1)
    out_lrs, out_hrs = tf.cond(coin < 0.5, _identity, _rotated)

    return tf.cast(out_lrs, tf.float32), tf.cast(out_hrs, tf.float32)

# __Downloading dataset and creating data loader__

In [None]:
# Training pipeline: DIV2K (bicubic x{SCALE} track) -> random LR degradation ->
# aligned random crop -> batch -> random rotation + float32 cast -> prefetch.
# NOTE(review): there is no element-level .shuffle(); shuffle_files only
# reorders the underlying files — consider adding one.
train_data = (
    tfds.load(f'div2k/bicubic_x{SCALE}', split = 'train', shuffle_files = True)
    .map(random_compression, num_parallel_calls = tf.data.AUTOTUNE)
    .map(random_crop, num_parallel_calls = tf.data.AUTOTUNE)
    # drop_remainder keeps every batch at exactly BATCH_SIZE elements.
    .batch(BATCH_SIZE, drop_remainder = True)
    .map(random_spatial_augmentation, num_parallel_calls = tf.data.AUTOTUNE)
    .prefetch(tf.data.AUTOTUNE)
)

In [None]:
# IPython shell magic: clone the upstream Real-Time-Super-Resolution repository.
# NOTE(review): nothing visible in this notebook references the cloned
# repository — confirm this cell is still needed.
!git clone https://github.com/braindotai/Real-Time-Super-Resolution.git

In [None]:
# Pull a single batch from the pipeline to sanity-check shapes, dtypes and
# value ranges.
lrs, hrs = next(iter(train_data))

print(lrs.shape, hrs.shape)
print(lrs.dtype, hrs.dtype)
print(tf.reduce_min(lrs), tf.reduce_max(lrs))
print(tf.reduce_min(hrs), tf.reduce_max(hrs))

In [None]:
def visualize_samples(images_lists, titles = None, size = (12, 12), masked = False):
    """Show images from parallel lists side by side, one figure per row.

    Args:
        images_lists: sequence of image sequences. Figure ``i`` shows the
            ``i``-th image of every list; ``zip`` stops at the shortest list.
        titles: optional per-column titles. If given, it must have one entry
            per element of ``images_lists``.
        size: matplotlib figure size used for each row.
        masked: unused; kept for backward compatibility.

    Raises:
        ValueError: if ``titles`` is given and its length differs from
            ``len(images_lists)``.
    """
    # titles is optional, so its length is only checked when it is provided
    # (len(None) would raise TypeError for the default).
    if titles is not None and len(titles) != len(images_lists):
        raise ValueError(
            f"Expected {len(images_lists)} titles, got {len(titles)}"
        )

    cols = len(images_lists)

    for row in zip(*images_lists):
        plt.figure(figsize = size)
        for col, image in enumerate(row):
            plt.subplot(1, cols, col + 1)
            # Inputs may be float (e.g. generator outputs) and slightly outside
            # [0, 255]; clip and round into displayable uint8.
            plt.imshow(tf.cast(tf.round(tf.clip_by_value(image, 0, 255)), tf.uint8))
            plt.axis('off')
            if titles:
                plt.title(titles[col])
        plt.show()

In [None]:
# Display the LR/HR pairs of the inspected batch (the [:15] slice is capped by
# the batch size).
visualize_samples(
    (lrs[:15], hrs[:15]),
    titles = ('Low Resolution', 'High Resolution'),
    size = (8, 8),
)

# __Layers for creating models__

In [None]:
class Conv2D(layers.Conv2D):
    """``layers.Conv2D`` with project defaults.

    Defaults: 3x3 kernel, 'same' padding and a zero-initialised bias. All
    other keyword arguments (``filters``, ``strides``, ...) are forwarded
    unchanged.
    """

    def __init__(self, kernel_size = 3, padding = 'same', **kwargs):
        zero_bias = tf.keras.initializers.Zeros()
        super().__init__(
            kernel_size = kernel_size,
            padding = padding,
            bias_initializer = zero_bias,
            **kwargs,
        )

# __Modules for creating models__

In [None]:
class Conv2DBlock(layers.Layer):
    """Conv2D -> optional BatchNormalization -> optional PReLU.

    Args:
        filters: number of output channels of the convolution.
        batchnorm: apply BatchNormalization after the convolution.
        activate: apply a PReLU whose slopes are shared over axes 1 and 2
            (the spatial axes for NHWC input).
        **kwargs: forwarded to the project ``Conv2D`` (e.g. ``kernel_size``,
            ``strides``, ``padding``), not to the base Layer.
    """

    def __init__(self, filters, batchnorm = True, activate = True, **kwargs):
        super().__init__()

        # Attribute names and creation order are part of the object graph
        # saved by tf.train.Checkpoint, so they are kept as-is.
        self.conv = Conv2D(filters = filters, **kwargs)
        self.batchnorm = None if not batchnorm else layers.BatchNormalization()
        self.activate = None if not activate else layers.PReLU(shared_axes = [1, 2])

    def call(self, inputs):
        x = self.conv(inputs)
        for optional_layer in (self.batchnorm, self.activate):
            if optional_layer is not None:
                x = optional_layer(x)
        return x

In [None]:
class ResidualDenseBlock(layers.Layer):
    """Three densely connected conv blocks wrapped in an identity skip.

    conv1 sees the input, conv2 sees [conv1, input], and conv3 (no activation)
    sees [conv2, conv1]. The result is added to the input, so the input
    presumably already has ``filters`` channels.
    NOTE(review): conv3 does not receive ``inputs`` in its concatenation,
    unlike a fully dense block — confirm this is intentional.
    """

    def __init__(self, filters = 64):
        super().__init__()

        half_filters = filters // 2
        self.conv1 = Conv2DBlock(filters = half_filters)
        self.conv2 = Conv2DBlock(filters = half_filters)
        self.conv3 = Conv2DBlock(filters = filters, activate = False)

    def call(self, inputs):
        feats1 = self.conv1(inputs)
        feats2 = self.conv2(tf.concat([feats1, inputs], axis = 3))
        residual = self.conv3(tf.concat([feats2, feats1], axis = 3))
        return inputs + residual

In [None]:
class RRDBlock(layers.Layer):
    """Residual-in-residual dense block: three chained RDBs plus a learned skip.

    output = rrdb_inputs_scales * inputs
             + rrdb_outputs_scales * rdb_3(rdb_2(rdb_1(inputs)))

    Both scales are trainable [1, 1, 1, filters] variables, initialised to 1.0
    and 0.5 respectively.
    """

    def __init__(self, filters, **kwargs):
        super().__init__(**kwargs)

        # Attribute names and creation order define the checkpointed object
        # graph and are kept unchanged.
        self.rdb_1 = ResidualDenseBlock(filters)
        self.rdb_2 = ResidualDenseBlock(filters)
        self.rdb_3 = ResidualDenseBlock(filters)

        scale_shape = [1, 1, 1, filters]  # one scale per channel, broadcast elsewhere
        self.rrdb_inputs_scales = tf.Variable(
            tf.fill(scale_shape, 1.0),
            name = f'{self.name}_rrdb_inputs_scales',
            trainable = True
        )
        self.rrdb_outputs_scales = tf.Variable(
            tf.fill(scale_shape, 0.5),
            name = f'{self.name}_rrdb_outputs_scales',
            trainable = True
        )

    def call(self, inputs):
        x = inputs
        for rdb in (self.rdb_1, self.rdb_2, self.rdb_3):
            x = rdb(x)
        skip = self.rrdb_inputs_scales * inputs
        return skip + self.rrdb_outputs_scales * x

In [None]:
class PixelShuffleUpSampling(layers.Layer):
    """Sub-pixel upsampling: conv -> depth_to_space(scale) -> PReLU.

    depth_to_space multiplies height and width by ``scale`` and divides the
    channel count by ``scale ** 2``, so ``filters`` must be a multiple of
    ``scale ** 2``.
    """

    def __init__(self, filters, scale, **kwargs):
        super().__init__(**kwargs)

        self.conv1 = Conv2DBlock(filters = filters, batchnorm = False, activate = False)
        self.upsample = layers.Lambda(lambda t: tf.nn.depth_to_space(t, scale))
        self.prelu = layers.PReLU(shared_axes = [1, 2])

    def call(self, x):
        for stage in (self.conv1, self.upsample, self.prelu):
            x = stage(x)
        return x

# __Building the models__

In [None]:
def Generator():
    """Build the functional super-resolution generator.

    Pipeline:
      1. Scale the input by 1/255.
      2. 3x3 and 1x1 conv blocks.
      3. Four RRDBlocks.
      4. Two x2 pixel-shuffle stages (x4 overall).
      5. Two conv blocks, then tanh, then rescale from (-1, 1) to (0, 255).

    The input spatial size is unconstrained: (None, None, 3).

    Returns:
        A ``Model`` named 'Generator' mapping an LR image to an SR image.
    """
    lr_image = layers.Input(shape = (None, None, 3))

    x = layers.Lambda(lambda t: t / 255.0)(lr_image)
    x = Conv2DBlock(filters = GEN_FILTERS, kernel_size = 3, strides = 1, padding = 'same', batchnorm = False)(x)
    x = Conv2DBlock(filters = GEN_FILTERS, kernel_size = 1, strides = 1, padding = 'valid', batchnorm = False)(x)

    # Residual-in-residual dense trunk (layers created in the same order as before).
    for _ in range(4):
        x = RRDBlock(GEN_FILTERS)(x)

    # Each stage: GEN_FILTERS * 4 channels -> depth_to_space(2) -> GEN_FILTERS channels.
    for _ in range(2):
        x = PixelShuffleUpSampling(GEN_FILTERS * 4, 2)(x)

    x = Conv2DBlock(filters = GEN_FILTERS, batchnorm = False)(x)
    x = Conv2DBlock(filters = 3, kernel_size = 3, activate = False, batchnorm = False)(x)
    x = layers.Activation('tanh')(x)

    # Map tanh output from (-1, 1) to pixel range (0, 255).
    sr_image = layers.Lambda(lambda t: (t + 1) * 127.5)(x)

    return Model(inputs = lr_image, outputs = sr_image, name = 'Generator')

In [None]:
# Instantiate the generator and print its summary with 100-character lines.
generator = Generator()
generator.summary(line_length = 100)

In [None]:
# Preview of the (untrained) generator on the first five LR crops of the batch.
visualize_samples(
    [lrs[:5], generator(lrs[:5])],
    titles = ['LR', 'SR'],
    size = (6, 6),
)

In [None]:
def Discriminator():
    """Build the discriminator: strided conv stack -> two dense layers -> one logit.

    The input is a fixed-size (HR_SIZE, HR_SIZE, 3) image, rescaled from
    [0, 255] to [-1, 1].
    - Every conv after the first is followed by BatchNormalization.
    - Every hidden conv/dense layer is followed by LeakyReLU(0.2).
    - Four stride-2 convolutions reduce each spatial side by 16 before
      flattening.

    Returns:
        A ``Model`` producing one unnormalised logit per image.
    """
    hr_image = layers.Input(shape = (HR_SIZE, HR_SIZE, 3))
    x = layers.Lambda(lambda t: t / 127.5 - 1)(hr_image)

    # Stem: no batch normalisation on the first convolution.
    x = Conv2D(kernel_size = 3, filters = DISC_FILTERS // 2)(x)
    x = layers.LeakyReLU(0.2)(x)

    # (filters, strides) of each conv -> BN -> LeakyReLU stage, in order.
    conv_stages = (
        (DISC_FILTERS // 2, 2),
        (DISC_FILTERS, 1),
        (DISC_FILTERS, 2),
        (DISC_FILTERS * 2, 1),
        (DISC_FILTERS * 2, 2),
        (DISC_FILTERS * 4, 1),
        (DISC_FILTERS * 4, 2),
    )
    for stage_filters, stage_strides in conv_stages:
        x = Conv2D(filters = stage_filters, kernel_size = 3, strides = stage_strides)(x)
        x = layers.BatchNormalization()(x)
        x = layers.LeakyReLU(0.2)(x)

    x = layers.Flatten()(x)

    for _ in range(2):
        x = layers.Dense(1024)(x)
        x = layers.LeakyReLU(0.2)(x)

    logits = layers.Dense(1)(x)

    return Model(inputs = hr_image, outputs = logits)

In [None]:
# Instantiate the discriminator and print its summary with 100-character lines.
discriminator = Discriminator()
discriminator.summary(line_length = 100)

# __Defining training procedures__

In [None]:
class PixelLossTraining:
    """Mixin providing a per-pixel reconstruction loss between SR and HR images."""

    def setup_pixel_loss(self, pixel_loss):
        """Select the pixel loss.

        Args:
            pixel_loss: ``'l1'`` (mean absolute error) or ``'l2'`` (mean
                squared error).

        Raises:
            ValueError: for any other value. This fails fast here instead of
                leaving ``pixel_loss_type`` unset, which would only surface as
                an AttributeError during training.
        """
        loss_classes = {
            'l1': losses.MeanAbsoluteError,
            'l2': losses.MeanSquaredError,
        }
        if pixel_loss not in loss_classes:
            raise ValueError(
                f"pixel_loss must be one of {sorted(loss_classes)}, got {pixel_loss!r}"
            )
        self.pixel_loss_type = loss_classes[pixel_loss]()

    @tf.function
    def pixel_loss(self, srs, hrs):
        """Return the configured pixel loss (Keras order: y_true=hrs, y_pred=srs)."""
        return self.pixel_loss_type(hrs, srs)

In [None]:
class VGGContentTraining:
    """Mixin providing a VGG19 feature (content) loss between SR and HR images."""

    def setup_content_loss(self, content_loss):
        """Select the feature distance and build a frozen VGG19 feature extractor.

        Args:
            content_loss: ``'l1'`` (mean absolute error) or ``'l2'`` (mean
                squared error).

        Raises:
            ValueError: for any other ``content_loss`` value (otherwise
                ``content_loss_type`` would be left unset).
        """
        loss_classes = {
            'l1': losses.MeanAbsoluteError,
            'l2': losses.MeanSquaredError,
        }
        if content_loss not in loss_classes:
            raise ValueError(
                f"content_loss must be one of {sorted(loss_classes)}, got {content_loss!r}"
            )
        self.content_loss_type = loss_classes[content_loss]()

        vgg = applications.VGG19(
            input_shape = (224, 224, 3),
            include_top = False,
            weights = 'imagenet'
        )

        # Features are taken from these layers with their activation removed.
        # They are presumably block2_conv2, block3_conv4 and block5_conv4 in
        # the Keras VGG19 layer list — verify. The change applies when the
        # extractor is re-called on new inputs, because the conv layers read
        # ``self.activation`` at call time.
        feature_layer_ids = (5, 10, 20)
        for layer_id in feature_layer_ids:
            vgg.layers[layer_id].activation = None

        self.feature_extrator = Model(
            inputs = vgg.input,
            outputs = [vgg.layers[layer_id].output for layer_id in feature_layer_ids]
        )
        for layer in self.feature_extrator.layers:
            layer.trainable = False

    @tf.function
    def content_loss(self, srs, hrs):
        """Sum the configured distance over the extracted VGG19 feature maps.

        Both batches are resized to 224x224 and passed through
        ``vgg19.preprocess_input`` (which expects RGB values in [0, 255]).
        Features are divided by a fixed constant, 12.75, before comparison
        (presumably to tame feature magnitudes — TODO confirm).
        """
        srs = applications.vgg19.preprocess_input(tf.image.resize(srs, (224, 224)))
        hrs = applications.vgg19.preprocess_input(tf.image.resize(hrs, (224, 224)))

        srs_features = self.feature_extrator(srs)
        hrs_features = self.feature_extrator(hrs)

        loss = 0.0
        for srs_feature, hrs_feature in zip(srs_features, hrs_features):
            loss += self.content_loss_type(hrs_feature / 12.75, srs_feature / 12.75)

        return loss

In [None]:
class GramStyleTraining:
    """Mixin providing a Gram-matrix style loss on EfficientNetB4 features."""

    def setup_gram_style_loss(self, style_loss):
        """Select the Gram-matrix distance and build a frozen EfficientNetB4 extractor.

        Args:
            style_loss: ``'l1'`` (mean absolute error) or ``'l2'`` (mean
                squared error).

        Raises:
            ValueError: for any other ``style_loss`` value (otherwise
                ``style_loss_type`` would be left unset).
        """
        loss_classes = {
            'l1': losses.MeanAbsoluteError,
            'l2': losses.MeanSquaredError,
        }
        if style_loss not in loss_classes:
            raise ValueError(
                f"style_loss must be one of {sorted(loss_classes)}, got {style_loss!r}"
            )
        self.style_loss_type = loss_classes[style_loss]()

        efficientnet = applications.EfficientNetB4(
            input_shape = (224, 224, 3),
            include_top = False,
            weights = 'imagenet'
        )

        # More layer indices (e.g. 25, 84, 143, 467) can be appended here;
        # gram_style_loss sums over every extractor output.
        style_layer_ids = (320,)
        self.style_features_extractor = Model(
            inputs = efficientnet.input,
            outputs = [efficientnet.layers[layer_id].output for layer_id in style_layer_ids]
        )
        for layer in self.style_features_extractor.layers:
            layer.trainable = False

    @tf.function
    def gram_matrix(self, features):
        """Return per-example Gram matrices of a (batch, H, W, C) feature map.

        ``result[b, c, d]`` is the sum over spatial positions of
        ``features[b, :, :, c] * features[b, :, :, d]``, giving a
        (batch, C, C) tensor, i.e. F @ F^T for F of shape (C, H*W).

        einsum avoids the reshape/transpose bookkeeping. A plain reshape of a
        channel-first tensor to (batch, H*W, C) is not its transpose and
        would mix channel and spatial entries.
        NOTE(review): the result is not normalised by H*W; consider dividing
        if the loss magnitude is too large.
        """
        return tf.einsum('bhwc,bhwd->bcd', features, features)

    @tf.function
    def gram_style_loss(self, srs, hrs):
        """Style loss between SR and HR images, summed over extractor outputs.

        Both batches are resized to 224x224 and passed through EfficientNet
        preprocessing. The configured distance is then applied between the
        Gram matrices of each corresponding feature map.
        """
        srs = applications.efficientnet.preprocess_input(tf.image.resize(srs, (224, 224)))
        hrs = applications.efficientnet.preprocess_input(tf.image.resize(hrs, (224, 224)))

        # flatten() accepts either a single tensor or a list of tensors, so
        # one or several extractor outputs are handled uniformly.
        srs_features = tf.nest.flatten(self.style_features_extractor(srs))
        hrs_features = tf.nest.flatten(self.style_features_extractor(hrs))

        style_loss = 0.0
        for srs_feature, hrs_feature in zip(srs_features, hrs_features):
            style_loss += self.style_loss_type(
                self.gram_matrix(hrs_feature),
                self.gram_matrix(srs_feature)
            )

        return style_loss

In [None]:
class AdversarialTraining:
    """Mixin providing standard ('gan') or relativistic average ('ragan') adversarial losses.

    All losses are sigmoid binary cross-entropy on raw discriminator logits.
    """

    def setup_adversarial_loss(self, adv_loss):
        """Select the adversarial formulation.

        Args:
            adv_loss: ``'gan'`` or ``'ragan'``.

        Raises:
            ValueError: for any other value. Otherwise both loss methods would
                later fail with an UnboundLocalError.
        """
        if adv_loss not in ('gan', 'ragan'):
            raise ValueError(f"adv_loss must be 'gan' or 'ragan', got {adv_loss!r}")
        self.adv_loss_type = adv_loss
        self.binary_cross_entropy = losses.BinaryCrossentropy(from_logits = True)

    @tf.function
    def gen_adv_loss(self, fake_logits, real_logits = None):
        """Generator adversarial loss.

        'gan':   BCE(1, fake_logits).
        'ragan': BCE(1, fake_logits - mean(real_logits))
                 + BCE(0, real_logits - mean(fake_logits)); requires ``real_logits``.

        Raises:
            ValueError: if ``real_logits`` is missing for 'ragan', or the loss
                type is unsupported.
        """
        if self.adv_loss_type == 'gan':
            return self.binary_cross_entropy(tf.ones_like(fake_logits), fake_logits)

        if self.adv_loss_type == 'ragan':
            if real_logits is None:
                raise ValueError("real_logits is required for the 'ragan' generator loss")
            # The generator pushes fakes to score above the average real and
            # reals to score below the average fake.
            fake_vs_real = self.binary_cross_entropy(tf.ones_like(fake_logits), fake_logits - tf.reduce_mean(real_logits))
            real_vs_fake = self.binary_cross_entropy(tf.zeros_like(real_logits), real_logits - tf.reduce_mean(fake_logits))
            return fake_vs_real + real_vs_fake

        raise ValueError(f"Unsupported adv_loss_type: {self.adv_loss_type!r}")

    @tf.function
    def disc_adv_loss(self, fake_logits, real_logits):
        """Discriminator adversarial loss (real labelled 1, fake labelled 0).

        'gan':   BCE(1, real_logits) + BCE(0, fake_logits).
        'ragan': BCE(1, real_logits - mean(fake_logits))
                 + BCE(0, fake_logits - mean(real_logits)).

        Raises:
            ValueError: if the loss type is unsupported.
        """
        if self.adv_loss_type == 'gan':
            real_loss = self.binary_cross_entropy(tf.ones_like(real_logits), real_logits)
            fake_loss = self.binary_cross_entropy(tf.zeros_like(fake_logits), fake_logits)
            return real_loss + fake_loss

        if self.adv_loss_type == 'ragan':
            real_loss = self.binary_cross_entropy(tf.ones_like(real_logits), real_logits - tf.reduce_mean(fake_logits))
            fake_loss = self.binary_cross_entropy(tf.zeros_like(fake_logits), fake_logits - tf.reduce_mean(real_logits))
            return real_loss + fake_loss

        raise ValueError(f"Unsupported adv_loss_type: {self.adv_loss_type!r}")

# __Defining SRGAN Model__

In [None]:
class SRGAN(
        Model,
        PixelLossTraining,
        GramStyleTraining,
        VGGContentTraining,
        AdversarialTraining
    ):
    """Keras model wrapping a generator/discriminator pair with a custom train_step.

    Two modes are chosen at ``compile`` time:
      * pixel pre-training (``perceptual_finetune=False``): only the generator
        is updated, on the pixel loss;
      * perceptual fine-tuning: the generator minimises weighted VGG content
        + adversarial losses, and the discriminator is updated on its
        adversarial loss in the same step.

    The latest batch (``lrs``, ``hrs``) and generator output (``srs``) are
    stored on the instance so callbacks can visualise them. The model is
    created with ``dynamic=True``, so it runs eagerly.
    """

    def __init__(
        self,
        generator,
        discriminator,
    ):
        # Only keyword arguments go to Model.__init__. Passing ``self`` again
        # as a positional argument is not part of its constructor contract.
        super(SRGAN, self).__init__(dynamic = True)

        self.generator = generator
        self.discriminator = discriminator

    def compile(
        self,

        generator_optimizer,
        discriminator_optimizer,

        perceptual_finetune,

        pixel_loss,
        style_loss,
        content_loss,
        adv_loss,

        loss_weights,
    ):
        """Attach the optimizers and configure the losses.

        Args:
            generator_optimizer: stored as ``self.generator.optimizer``.
            discriminator_optimizer: stored as ``self.discriminator.optimizer``.
            perceptual_finetune: False for pixel-loss pre-training; True for
                content + adversarial fine-tuning.
            pixel_loss: passed to ``setup_pixel_loss``.
            style_loss: accepted but currently unused (the Gram style loss is
                not set up).
            content_loss: passed to ``setup_content_loss`` (builds VGG19).
            adv_loss: passed to ``setup_adversarial_loss``.
            loss_weights: dict with 'content_loss' and 'adv_loss' weights;
                only stored (and only needed) in perceptual mode.
        """
        super(SRGAN, self).compile()

        self.generator.optimizer = generator_optimizer
        self.discriminator.optimizer = discriminator_optimizer

        self.perceptual_finetune = perceptual_finetune

        self.setup_pixel_loss(pixel_loss)
        self.setup_content_loss(content_loss)
        self.setup_adversarial_loss(adv_loss)

        if self.perceptual_finetune:
            self.loss_weights = loss_weights

    def train_step(self, batch):
        """Run one optimisation step on an ``(lrs, hrs)`` batch; return the logged losses."""
        self.lrs = batch[0]
        self.hrs = batch[1]

        if self.perceptual_finetune:
            return self._perceptual_train_step()
        return self._pixel_train_step()

    def _perceptual_train_step(self):
        """Joint discriminator + generator update on content and adversarial losses."""
        with tf.GradientTape() as disc_tape, tf.GradientTape() as gen_tape:
            self.srs = self.generator(self.lrs, training = True)

            real_logits = self.discriminator(self.hrs, training = True)
            fake_logits = self.discriminator(self.srs, training = True)

            content_loss = self.loss_weights['content_loss'] * self.content_loss(self.srs, self.hrs)
            gen_adv_loss = self.loss_weights['adv_loss'] * self.gen_adv_loss(fake_logits, real_logits)
            # The Gram style loss (GramStyleTraining) is not part of the
            # generator objective at present.
            perceptual_loss = content_loss + gen_adv_loss

            disc_adv_loss = self.disc_adv_loss(fake_logits, real_logits)

        discriminator_gradients = disc_tape.gradient(disc_adv_loss, self.discriminator.trainable_variables)
        generator_gradients = gen_tape.gradient(perceptual_loss, self.generator.trainable_variables)

        self.discriminator.optimizer.apply_gradients(zip(discriminator_gradients, self.discriminator.trainable_variables))
        self.generator.optimizer.apply_gradients(zip(generator_gradients, self.generator.trainable_variables))

        return {
            'Perceptual Loss': perceptual_loss,
            'Generator Adv Loss': gen_adv_loss,
            'Discriminator Adv Loss': disc_adv_loss,
        }

    def _pixel_train_step(self):
        """Generator-only update on the configured pixel loss."""
        with tf.GradientTape() as gen_tape:
            self.srs = self.generator(self.lrs, training = True)

            pixel_loss = self.pixel_loss(self.srs, self.hrs)

        generator_gradients = gen_tape.gradient(pixel_loss, self.generator.trainable_variables)
        self.generator.optimizer.apply_gradients(zip(generator_gradients, self.generator.trainable_variables))

        return {
            'Pixel Loss': pixel_loss,
        }

# __Checkpoint Callback__

In [None]:
class CheckpointCallback(tf.keras.callbacks.Callback):
    """Save and optionally restore generator/discriminator weights and optimizer state.

    Args:
        checkpoint_dir: directory managed by a ``tf.train.CheckpointManager``
            that keeps only the latest checkpoint.
        resume: if True, ``setup_checkpoint`` restores the latest checkpoint
            found in ``checkpoint_dir``.
        epoch_step: save every ``epoch_step`` passes over the dataset, counted
            with the batch index within the current Keras epoch.
        steps_per_epoch: batches in one pass over the dataset. If None, it is
            read as ``len(train_data)`` from the notebook-level ``train_data``
            each time it is needed.
    """

    def __init__(self, checkpoint_dir, resume = False, epoch_step = 1, steps_per_epoch = None):
        super(CheckpointCallback, self).__init__()

        self.checkpoint_dir = checkpoint_dir
        self.resume = resume
        self.epoch_step = epoch_step
        self.steps_per_epoch = steps_per_epoch

    def setup_checkpoint(self, *args, **kwargs):
        """Create the checkpoint and manager for ``self.model``; restore if ``resume``.

        ``set_model`` must have been called first. Any arguments are accepted
        for compatibility and ignored.
        """
        self.checkpoint = tf.train.Checkpoint(
            generator = self.model.generator,
            discriminator = self.model.discriminator,
            generator_optimizer = self.model.generator.optimizer,
            discriminator_optimizer = self.model.discriminator.optimizer
        )
        self.manager = tf.train.CheckpointManager(
            self.checkpoint,
            directory = self.checkpoint_dir,
            checkpoint_name = 'SRGAN',
            max_to_keep = 1
        )

        if self.resume:
            self.load_checkpoint()
        else:
            print('Starting training from scratch...\n')

    def _save_interval(self):
        """Batches between saves: ``epoch_step * steps_per_epoch``, clamped to at least 1.

        The clamp avoids a modulo-by-zero when the product truncates to 0.
        """
        steps = self.steps_per_epoch if self.steps_per_epoch is not None else len(train_data)
        return max(1, int(self.epoch_step * steps))

    def on_batch_end(self, batch, *args, **kwargs):
        """Save a checkpoint when the in-epoch batch count reaches a multiple of the interval."""
        if (batch + 1) % self._save_interval() == 0:
            print(f"\n\nCheckpoint saved to {self.manager.save()}\n")

    def load_checkpoint(self):
        """Restore the latest checkpoint in ``checkpoint_dir`` if one exists."""
        if self.manager.latest_checkpoint:
            self.checkpoint.restore(self.manager.latest_checkpoint)
            print(f"Checkpoint restored from '{self.manager.latest_checkpoint}'\n")
        else:
            print("No checkpoints found, initializing from scratch...\n")

    def set_lr(self, lr, beta_1 = 0.9):
        """Override learning rate and beta_1 on both the generator and discriminator optimizers."""
        print(f'Continuing with learning rate: {lr}')
        self.model.generator.optimizer.beta_1 = beta_1
        self.model.generator.optimizer.learning_rate = lr
        self.model.discriminator.optimizer.beta_1 = beta_1
        self.model.discriminator.optimizer.learning_rate = lr

# __Optimization Progress Callback__

In [None]:
class ProgressCallback(tf.keras.callbacks.Callback):
    """Periodically plot the model's latest LR / SR (and HR in perceptual mode) samples.

    Args:
        logs_step: currently unused; kept for compatibility.
        generator_step: plot every ``generator_step`` passes over the dataset,
            counted with the batch index within the current Keras epoch.
        steps_per_epoch: batches in one pass over the dataset. If None, it is
            read as ``len(train_data)`` from the notebook-level ``train_data``
            each time it is needed.
    """

    def __init__(self, logs_step, generator_step, steps_per_epoch = None):
        super(ProgressCallback, self).__init__()

        self.logs_step = logs_step
        self.generator_step = generator_step
        self.steps_per_epoch = steps_per_epoch

    def _plot_interval(self):
        """Batches between plots: ``generator_step * steps_per_epoch``, clamped to at least 1.

        The clamp avoids a modulo-by-zero when the product truncates to 0.
        """
        steps = self.steps_per_epoch if self.steps_per_epoch is not None else len(train_data)
        return max(1, int(self.generator_step * steps))

    def on_batch_end(self, batch, logs = None, **kwargs):
        """Show the first three samples of the most recent training batch at each interval."""
        if (batch + 1) % self._plot_interval() != 0:
            return

        images_lists = [self.model.lrs[:3], self.model.srs[:3]]
        titles = ['Low Resolution', 'Predicted Enhanced']
        size = (7, 7)
        if self.model.perceptual_finetune:
            # HR targets are shown alongside in perceptual fine-tuning.
            images_lists.append(self.model.hrs[:3])
            titles.append('High Resolution')
            size = (11, 11)

        visualize_samples(
            images_lists = tuple(images_lists),
            titles = tuple(titles),
            size = size
        )

# __Optimization Config Values__

In [None]:
# Training schedule / Adam hyper-parameters.
# NOTE(review): ckpt_callback.set_lr(0.0002, BETA_1) later overrides LR.
EPOCHS = 100
LR = 2e-5
BETA_1 = 0.8
BETA_2 = 0.999

# False -> pixel-loss pre-training; True -> content + adversarial fine-tuning.
PERCEPTUAL_FINETUNE = True

# Loss choices: 'l1' / 'l2' distances and 'gan' / 'ragan' adversarial loss.
PIXEL_LOSS = 'l1'
STYLE_LOSS = 'l1'
CONTENT_LOSS = 'l1'
ADV_LOSS = 'ragan'

LOSS_WEIGHTS = dict(content_loss = 1.0, adv_loss = 0.09, style_loss = 1.0)

CHECKPOINT_DIR = os.path.join('drive', 'MyDrive', 'Model-Checkpoints', 'Super Resolution')

# __Initializing Models__

In [None]:
# Two independent Adam optimizers with identical hyper-parameters, one per network.
generator_optimizer, discriminator_optimizer = (
    optimizers.Adam(learning_rate = LR, beta_1 = BETA_1, beta_2 = BETA_2)
    for _ in range(2)
)

srgan = SRGAN(generator, discriminator)
srgan.compile(
    generator_optimizer = generator_optimizer,
    discriminator_optimizer = discriminator_optimizer,
    perceptual_finetune = PERCEPTUAL_FINETUNE,
    pixel_loss = PIXEL_LOSS,
    style_loss = STYLE_LOSS,
    content_loss = CONTENT_LOSS,
    adv_loss = ADV_LOSS,
    loss_weights = LOSS_WEIGHTS,
)

# __Setting up checkpoint callback__

In [None]:
# Resume from CHECKPOINT_DIR when a checkpoint exists, then override both
# optimizers' learning rate and beta_1 for this run.
# NOTE(review): if checkpoint restoration of optimizer hyper-parameters is
# deferred until their variables are created, the restored values could
# replace this override — verify the effective learning rate.
ckpt_callback = CheckpointCallback(CHECKPOINT_DIR, resume = True, epoch_step = 4)
ckpt_callback.set_model(srgan)
ckpt_callback.setup_checkpoint(srgan)
ckpt_callback.set_lr(lr = 0.0002, beta_1 = BETA_1)

# __Training the Model__

In [None]:
# Each Keras epoch here covers EPOCHS // 10 passes over train_data. The
# callbacks' epoch_step / generator_step are measured in single passes
# (len(train_data) batches).
# NOTE(review): 5 epochs x (EPOCHS // 10) passes != EPOCHS passes — confirm intent.
srgan.fit(
    train_data.repeat(EPOCHS // 10),
    epochs = 5,
    callbacks = [
        ckpt_callback,
        ProgressCallback(logs_step = 0.2, generator_step = 2),
    ],
)

# __Testing the model__

In [None]:
def enhance_image(lr_image = None, path = None, output_path = None, visualize = True, size = (20, 16)):
    """Super-resolve one image with ``srgan.generator``.

    Args:
        lr_image: HxWx3 image (tensor or array), pixel values presumably in
            [0, 255]. Replaced by the decoded file when ``path`` is given.
        path: image file to read instead of ``lr_image``. It is decoded with
            ``tf.io.decode_image``, so any format that supports works.
        output_path: if given, the SR image is written here, as PNG when the
            path ends in '.png' and as JPEG otherwise.
        visualize: show the LR/SR pair with ``visualize_samples``.
        size: figure size for the visualisation.

    Returns:
        The SR image as a uint8 HxWx3 tensor.

    Raises:
        ValueError: if neither ``lr_image`` nor ``path`` is provided.
    """
    if lr_image is None and not path:
        raise ValueError('Provide either lr_image or path.')
    if path:
        # Format-agnostic decoding; expand_animations=False guarantees a 3-D
        # image even for GIF input.
        lr_image = tf.io.decode_image(tf.io.read_file(f"{path}"), channels = 3, expand_animations = False)

    # The generator takes float input in [0, 255] (it divides by 255 itself).
    batch = tf.cast(tf.expand_dims(lr_image, 0), tf.float32)
    sr_image = srgan.generator(batch, training = False)[0]
    sr_image = tf.cast(tf.round(tf.clip_by_value(sr_image, 0, 255)), tf.uint8)

    if visualize:
        visualize_samples(images_lists = [[lr_image], [sr_image]], titles = ['LR Image', 'SR_Image'], size = size)

    if output_path:
        if str(output_path).lower().endswith('.png'):
            encoded = tf.image.encode_png(sr_image)
        else:
            encoded = tf.image.encode_jpeg(sr_image)
        tf.io.write_file(output_path, encoded)

    return sr_image

In [None]:
# Qualitative check: super-resolve random 100x100 LR patches from 20 shuffled
# DIV2K validation images (with the same random LR degradations as training).
for lr, _ in tfds.load(f'div2k/bicubic_x{SCALE}', split = 'validation', shuffle_files = True).shuffle(100).take(20).map(random_compression, num_parallel_calls = tf.data.AUTOTUNE):
    lr_height, lr_width = int(lr.shape[0]), int(lr.shape[1])
    # Column offset is bounded by the width (axis 1) and row offset by the
    # height (axis 0). Swapping them truncates crops on landscape images.
    lr_x = tf.random.uniform(maxval = lr_width // 2, shape = (), dtype = tf.int32)
    lr_y = tf.random.uniform(maxval = lr_height // 2, shape = (), dtype = tf.int32)
    lr = lr[lr_y: lr_y + 100, lr_x: lr_x + 100, :]
    sr = enhance_image(lr_image = lr)