# Import libraries

In [2]:
!pip install --force-reinstall "protobuf==3.20.*"

^C
[31mERROR: Operation cancelled by user[0m[31m
[0m

In [3]:
"""
Super-Resolution Project
- SRCNN baseline (PSNR-oriented)
- SRGAN baseline (SRResNet + BCE + VGG content)
- Attentive ESRGAN (no BN, channel attention, RaLSGAN + VGG + L1)

Paste this into a Kaggle notebook cell and run it once to define everything.
"""

# ============================================================
# 0. Imports & Global Config
# ============================================================

import os
import math
import random
import time

# The TensorFlow native runtime picks up TF_CPP_MIN_LOG_LEVEL while it is being
# loaded, so the variable has to be exported BEFORE `import tensorflow`;
# assigning it afterwards does not silence the start-up INFO/WARNING messages.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import numpy as np
import cv2
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models, Input

from tensorflow.keras.applications import VGG19
from tensorflow.keras.applications.vgg19 import preprocess_input

print("TensorFlow version:", tf.__version__)

TensorFlow version: 2.18.0


In [4]:
# ---------------- Configuration ----------------
# Everything a reader may need to tune lives in this cell.

# --- Dataset location (Kaggle layout; edit for other environments) ---
DIV2K_ROOT = "/kaggle/input/div2k-high-resolution-images"

# --- Pre-trained / warm-up weight files (Kaggle layout; edit as needed) ---
SRCNN_PRETRAINED_PATH = "/kaggle/input/my-sr-models/srcnn_baseline_model.h5"
SRRESNET_WARMUP_PATH = "/kaggle/input/my-sr-models/srresnet_warmup.h5"
ATTENTIVE_WARMUP_PATH = "/kaggle/input/my-sr-models/attentive_generator_warmup.h5"

# --- Training hyperparameters ---
UPSCALE = 4          # super-resolution factor
HR_CROP_SIZE = 128   # side length of the square HR training patch, in pixels
BATCH_SIZE = 16      # patches per training batch

# The LR patch side follows from the HR patch side and the upscale factor.
LR_CROP_SIZE = HR_CROP_SIZE // UPSCALE


# 1. Data Pipeline for SRGAN / ESRGAN

In [5]:
# ============================================================
# 1. Data Pipeline for SRGAN / ESRGAN
# ============================================================

class SRGANDataGenerator(keras.utils.Sequence):
    """
    Keras ``Sequence`` yielding (LR, HR) patch batches for SRGAN / ESRGAN.

    For every image of a batch the generator:
    - loads the HR image from ``hr_dir`` with OpenCV and converts BGR -> RGB,
    - takes a random ``crop_size`` x ``crop_size`` HR crop,
    - builds the LR patch by bicubic resizing to ``crop_size // scale_factor``,
    - rescales both patches from [0, 255] to [-1, 1] as float32.

    Images that cannot be read, or that are smaller than ``crop_size`` in
    either dimension, are skipped, so a returned batch may contain fewer than
    ``batch_size`` samples.

    Args:
        hr_dir: Directory holding the HR images (.png / .jpg / .jpeg).
        batch_size: Number of images drawn per batch.
        crop_size: Side length of the square HR crop.
        scale_factor: Down-sampling factor used to derive the LR patch.
        shuffle: Shuffle the image order at construction and on every
            ``on_epoch_end()`` call.
        **kwargs: Forwarded to the Keras base class constructor.
    """

    def __init__(self, hr_dir, batch_size=16, crop_size=128, scale_factor=4,
                 shuffle=True, **kwargs):
        # The Keras base class expects its constructor to be called by subclasses.
        super().__init__(**kwargs)

        self.hr_dir = hr_dir
        self.batch_size = batch_size
        self.crop_size = crop_size
        self.scale_factor = scale_factor
        self.shuffle = shuffle

        try:
            self.image_files = sorted(
                f for f in os.listdir(hr_dir)
                if f.lower().endswith((".png", ".jpg", ".jpeg"))
            )
            print(f"[SRGANDataGenerator] {len(self.image_files)} images found in: {hr_dir}")
        except FileNotFoundError:
            # Best effort: keep the object usable (len() == 0) and report the problem.
            print(f"[SRGANDataGenerator] ERROR: Directory not found: {hr_dir}")
            self.image_files = []

        self.indexes = np.arange(len(self.image_files))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __len__(self):
        """Number of batches per epoch (the last one may be partial)."""
        if not self.image_files:
            return 0
        return math.ceil(len(self.image_files) / self.batch_size)

    def on_epoch_end(self):
        """Reshuffle the image order when shuffling is enabled."""
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def _load_pair(self, fname):
        """Return one normalized (lr_patch, hr_patch) pair, or None if the image is unusable."""
        hr_img = cv2.imread(os.path.join(self.hr_dir, fname))
        if hr_img is None:
            return None
        hr_img = cv2.cvtColor(hr_img, cv2.COLOR_BGR2RGB)

        h, w, _ = hr_img.shape
        if h < self.crop_size or w < self.crop_size:
            return None

        # Random HR crop
        x = random.randint(0, w - self.crop_size)
        y = random.randint(0, h - self.crop_size)
        hr_patch = hr_img[y:y + self.crop_size, x:x + self.crop_size]

        # LR via bicubic downsampling
        lr_side = self.crop_size // self.scale_factor
        lr_patch = cv2.resize(hr_patch, (lr_side, lr_side), interpolation=cv2.INTER_CUBIC)

        # Normalize to [-1, 1]; cast first so the patches stay float32 instead
        # of being promoted to float64 by the division.
        return (lr_patch.astype(np.float32) / 127.5 - 1.0,
                hr_patch.astype(np.float32) / 127.5 - 1.0)

    def _collect_batch(self, index):
        """Load every usable sample of batch ``index``; the lists may come back empty."""
        start = index * self.batch_size
        end = (index + 1) * self.batch_size

        lr_batch, hr_batch = [], []
        for i in self.indexes[start:end]:
            fname = self.image_files[i]
            try:
                pair = self._load_pair(fname)
            except Exception as e:
                print(f"[SRGANDataGenerator] Error processing {fname}: {e}")
                continue
            if pair is None:
                continue
            lr_batch.append(pair[0])
            hr_batch.append(pair[1])
        return lr_batch, hr_batch

    def __getitem__(self, index):
        """
        Return ``(lr_batch, hr_batch)`` as float32 arrays in [-1, 1].

        If the requested batch yields no usable sample, the following batches
        are tried in turn (wrapping around). The search is bounded to one full
        pass over the dataset.

        Raises:
            RuntimeError: if no batch in the whole dataset yields a sample.
        """
        n_batches = len(self)
        for attempt in range(max(n_batches, 1)):
            batch_index = index if attempt == 0 else (index + attempt) % n_batches
            lr_batch, hr_batch = self._collect_batch(batch_index)
            if lr_batch:
                return (np.asarray(lr_batch, dtype=np.float32),
                        np.asarray(hr_batch, dtype=np.float32))

        raise RuntimeError(
            f"[SRGANDataGenerator] No usable image found in: {self.hr_dir}"
        )


def resolve_div2k_paths(root):
    """
    Locate the DIV2K train / validation HR folders below ``root``.

    Kaggle copies of the dataset sometimes nest each split one level deeper
    (``<root>/DIV2K_train_HR/DIV2K_train_HR``). Each split is resolved on its
    own: the nested folder is used when it exists, otherwise the flat
    ``<root>/<split>`` path is returned (whether or not it exists).

    Args:
        root: Dataset root directory.

    Returns:
        (train_hr_dir, valid_hr_dir)
    """
    def _resolve(split_name):
        nested = os.path.join(root, split_name, split_name)
        if os.path.isdir(nested):
            return nested
        return os.path.join(root, split_name)

    return _resolve("DIV2K_train_HR"), _resolve("DIV2K_valid_HR")


# Prepare generators (you can comment these out and call later if you want)
train_hr_dir, valid_hr_dir = resolve_div2k_paths(DIV2K_ROOT)

# Both splits share the same batch / crop / scale settings; the train generator
# is built first, then the validation one.
train_gen, val_gen = (
    SRGANDataGenerator(split_dir,
                       batch_size=BATCH_SIZE,
                       crop_size=HR_CROP_SIZE,
                       scale_factor=UPSCALE)
    for split_dir in (train_hr_dir, valid_hr_dir)
)

[SRGANDataGenerator] 800 images found in: /kaggle/input/div2k-high-resolution-images/DIV2K_train_HR/DIV2K_train_HR
[SRGANDataGenerator] 100 images found in: /kaggle/input/div2k-high-resolution-images/DIV2K_valid_HR/DIV2K_valid_HR


# 2. SRCNN Baseline (PSNR-oriented)

In [6]:
# ============================================================
# 2. SRCNN Baseline (PSNR-oriented)
# ============================================================

def build_srcnn():
    """
    Build and compile the three-layer SRCNN baseline.

    Layers (all with 'same' padding):
        1. 64 filters, 9x9, ReLU
        2. 32 filters, 1x1, ReLU
        3.  3 filters, 5x5, linear

    The network accepts RGB images of any spatial size; per the notebook's
    convention its inputs and outputs are in [0, 1]. It is compiled with
    Adam (lr=1e-3) and an MSE loss, and MSE is also reported as a metric.
    """
    srcnn = models.Sequential()
    srcnn.add(layers.Conv2D(64, (9, 9), activation='relu', padding='same',
                            input_shape=(None, None, 3)))
    srcnn.add(layers.Conv2D(32, (1, 1), activation='relu', padding='same'))
    srcnn.add(layers.Conv2D(3, (5, 5), activation='linear', padding='same'))

    adam = keras.optimizers.Adam(learning_rate=1e-3)
    srcnn.compile(optimizer=adam,
                  loss='mean_squared_error',
                  metrics=['mean_squared_error'])
    return srcnn


def predict_srcnn_full_image(model, image_path, scale_factor=4):
    """
    Run the SRCNN baseline on one full-resolution image and plot the result.

    Pipeline: HR -> bicubic downscale -> bicubic upscale -> SRCNN, with all
    network inputs / outputs in [0, 1]. A three-panel figure (bicubic, SRCNN,
    ground truth) is shown with the PSNR of the first two against the HR
    image; the SRCNN title turns green when it beats bicubic.

    Args:
        model: SRCNN model exposing ``predict``.
        image_path: Path of the HR image to evaluate.
        scale_factor: Down/up-sampling factor.

    Returns:
        None. Prints a message and returns early if the image cannot be read.
    """
    bgr = cv2.imread(image_path)
    if bgr is None:
        print(f"[SRCNN] Could not load image: {image_path}")
        return
    ground_truth = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    # Trim so that both sides are exact multiples of the scale factor.
    rows, cols, _ = ground_truth.shape
    h = (rows // scale_factor) * scale_factor
    w = (cols // scale_factor) * scale_factor
    ground_truth = ground_truth[:h, :w, :]

    # cv2.resize takes (width, height).
    low_res = cv2.resize(ground_truth, (w // scale_factor, h // scale_factor),
                         interpolation=cv2.INTER_CUBIC)
    upsampled = cv2.resize(low_res, (w, h), interpolation=cv2.INTER_CUBIC)

    # The bicubic result doubles as the SRCNN input, scaled to [0, 1].
    bicubic_img = upsampled.astype(np.float32) / 255.0
    prediction = model.predict(np.expand_dims(bicubic_img, axis=0), verbose=0)[0]
    sr_img = np.clip(prediction, 0.0, 1.0)
    hr_img_01 = ground_truth.astype(np.float32) / 255.0

    tf_hr = tf.convert_to_tensor(hr_img_01, tf.float32)

    def _psnr_against_hr(candidate):
        tf_candidate = tf.convert_to_tensor(candidate, tf.float32)
        return tf.image.psnr(tf_hr, tf_candidate, max_val=1.0).numpy()

    psnr_sr = _psnr_against_hr(sr_img)
    psnr_bic = _psnr_against_hr(bicubic_img)

    sr_title_style = {
        "color": "green" if psnr_sr > psnr_bic else "black",
        "fontweight": "bold",
    }
    panels = (
        (bicubic_img, f"Bicubic\nPSNR: {psnr_bic:.2f} dB", {}),
        (sr_img, f"SRCNN\nPSNR: {psnr_sr:.2f} dB", sr_title_style),
        (hr_img_01, "Ground Truth", {}),
    )

    fig, axes = plt.subplots(1, 3, figsize=(24, 10))
    for ax, (image, title, title_style) in zip(axes, panels):
        ax.imshow(image)
        ax.set_title(title, **title_style)
        ax.axis("off")

    plt.tight_layout()
    plt.show()

print("SRCNN helpers defined.")

SRCNN helpers defined.


# 3. SRGAN Baseline (SRResNet + BCE + VGG)

In [None]:
# ============================================================
# 3. SRGAN Baseline (SRResNet + BCE + VGG)
# ============================================================

def residual_block_bn(x):
    """
    SRGAN-style residual block with batch normalization.

    Branch: Conv(64, 3x3) -> BN -> PReLU -> Conv(64, 3x3) -> BN, which is then
    added to the unmodified block input.
    """
    branch = layers.Conv2D(64, 3, padding='same')(x)
    branch = layers.BatchNormalization(momentum=0.8)(branch)
    branch = layers.PReLU(shared_axes=[1, 2])(branch)
    branch = layers.Conv2D(64, 3, padding='same')(branch)
    branch = layers.BatchNormalization(momentum=0.8)(branch)

    # Skip connection: branch output + block input.
    return layers.Add()([branch, x])


def upsample_block_bn(x):
    """
    x2 upsampling block of the baseline generator.

    A 256-filter 3x3 convolution is followed by depth-to-space with block
    size 2 (pixel shuffle) and a PReLU shared across the spatial axes.
    """
    expanded = layers.Conv2D(256, 3, padding='same')(x)
    shuffled = layers.Lambda(lambda z: tf.nn.depth_to_space(z, block_size=2))(expanded)
    return layers.PReLU(shared_axes=[1, 2])(shuffled)


def build_srgan_generator(scale=4, num_res_blocks=16):
    """
    Baseline SRGAN generator (SRResNet) with BN in residual blocks.

    Structure: Conv 9x9 + PReLU -> ``num_res_blocks`` residual blocks ->
    Conv 3x3 + BN + long skip connection -> one x2 pixel-shuffle block per
    factor of two in ``scale`` -> Conv 9x9 with tanh.

    Args:
        scale: Upscale factor. Must be a positive power of two (1, 2, 4, 8, ...)
            because upsampling is built from x2 blocks.
        num_res_blocks: Number of residual blocks in the trunk.

    Returns:
        Keras model mapping an LR image to an SR image; output: tanh in [-1,1].

    Raises:
        ValueError: if ``scale`` is not a positive power of two. Such values
            would otherwise yield a network whose real upscale factor
            silently differs from the requested one.
    """
    scale_int = int(scale)
    if scale_int != scale or scale_int < 1 or scale_int & (scale_int - 1):
        raise ValueError(
            f"scale must be a positive power of two (1, 2, 4, ...), got {scale!r}"
        )
    # log2(scale): number of x2 upsampling blocks required.
    num_upsample_stages = scale_int.bit_length() - 1

    lr_input = Input(shape=(None, None, 3))

    x1 = layers.Conv2D(64, 9, padding='same')(lr_input)
    x1 = layers.PReLU(shared_axes=[1, 2])(x1)

    x = x1
    for _ in range(num_res_blocks):
        x = residual_block_bn(x)

    x = layers.Conv2D(64, 3, padding='same')(x)
    x = layers.BatchNormalization(momentum=0.8)(x)
    # Long skip connection around the whole residual trunk.
    x = layers.Add()([x, x1])

    for _ in range(num_upsample_stages):
        x = upsample_block_bn(x)

    out = layers.Conv2D(3, 9, padding='same', activation='tanh')(x)
    return models.Model(lr_input, out, name="SRGAN_Generator")


def discriminator_block(x, filters, strides=1, batch_norm=True):
    """
    Conv(3x3) -> optional BatchNorm -> LeakyReLU(0.2) block for the discriminator.

    Args:
        x: Input tensor.
        filters: Number of convolution filters.
        strides: Convolution stride.
        batch_norm: Whether to insert batch normalization after the convolution.

    Returns:
        The block's output tensor.
    """
    x = layers.Conv2D(filters, 3, strides=strides, padding='same')(x)
    if batch_norm:
        x = layers.BatchNormalization(momentum=0.8)(x)
    # The slope is passed positionally: Keras 3 renamed the keyword from
    # `alpha` to `negative_slope` and warns on the old name.
    x = layers.LeakyReLU(0.2)(x)
    return x


def build_srgan_discriminator(input_shape):
    """
    Build the baseline SRGAN discriminator.

    Eight convolution blocks (64-64-128-128-256-256-512-512 filters, with
    alternating strides 1 / 2 and no batch norm on the first block) are
    followed by Flatten -> Dense(1024) -> LeakyReLU(0.2) -> Dense(1, sigmoid).

    Args:
        input_shape: Shape of the input image, e.g. ``(H, W, 3)``. It must be
            fully defined because of the Flatten / Dense head.

    Returns:
        Keras model producing one sigmoid output per image.
    """
    img_input = Input(shape=input_shape)

    x = discriminator_block(img_input, 64, strides=1, batch_norm=False)
    remaining_blocks = (
        (64, 2),
        (128, 1), (128, 2),
        (256, 1), (256, 2),
        (512, 1), (512, 2),
    )
    for filters, strides in remaining_blocks:
        x = discriminator_block(x, filters, strides=strides)

    x = layers.Flatten()(x)
    x = layers.Dense(1024)(x)
    # The slope is passed positionally: Keras 3 renamed the keyword from
    # `alpha` to `negative_slope` and warns on the old name.
    x = layers.LeakyReLU(0.2)(x)

    validity = layers.Dense(1, activation='sigmoid')(x)
    return models.Model(img_input, validity, name="SRGAN_Discriminator")


def build_vgg(hr_shape):
    """
    Build a frozen VGG19 feature extractor for the content loss.

    An ImageNet-weighted VGG19 without its classification head is truncated at
    the output of layer ``block5_conv4`` and marked non-trainable.

    Args:
        hr_shape: Input image shape, e.g. ``(H, W, 3)``.

    Returns:
        Keras model mapping an image to its ``block5_conv4`` feature map.
    """
    backbone = VGG19(weights="imagenet", include_top=False, input_shape=hr_shape)
    feature_layer = backbone.get_layer("block5_conv4")

    extractor = models.Model(inputs=backbone.inputs, outputs=feature_layer.output)
    extractor.trainable = False
    return extractor


def build_srgan_combined(generator, discriminator, vgg, lr_shape):
    """
    Build and compile the combined model used for the generator update.

    The model maps an LR image to ``[discriminator(G(lr)), vgg(G(lr))]`` and is
    compiled with BCE (weight 1e-3) on the first output and MSE (weight 1.0)
    on the second, using Adam (lr=1e-4).

    The discriminator and the VGG extractor are flagged non-trainable while the
    combined model is built and compiled; the discriminator's flags are switched
    back on before returning so it can still be trained on its own.

    NOTE(review): whether the combined model keeps treating the discriminator
    as frozen once the flags are flipped back depends on when the installed
    Keras version resolves ``trainable`` -- TODO confirm that training the
    combined model does not also update the discriminator's weights.

    Args:
        generator: Generator model (LR -> SR).
        discriminator: Discriminator model.
        vgg: Feature-extractor model.
        lr_shape: Shape of the LR input, e.g. ``(h, w, 3)``.

    Returns:
        The compiled combined Keras model.
    """
    def _set_trainable(network, flag):
        # Flag the model itself first, then every layer it contains.
        network.trainable = flag
        for layer in network.layers:
            layer.trainable = flag

    # Freeze D and VGG for the combined model only.
    _set_trainable(discriminator, False)
    _set_trainable(vgg, False)

    lr_input = Input(shape=lr_shape)
    sr = generator(lr_input)
    sr_features = vgg(sr)
    validity = discriminator(sr)

    combined = models.Model(lr_input, [validity, sr_features])
    combined.compile(
        loss=['binary_crossentropy', 'mse'],
        loss_weights=[1e-3, 1.0],
        optimizer=keras.optimizers.Adam(learning_rate=1e-4),
    )

    # IMPORTANT: re-enable D so it can still be trained separately.
    _set_trainable(discriminator, True)

    return combined

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

def train_srgan_baseline(generator,
                         discriminator,
                         srgan,
                         vgg,
                         train_loader,
                         val_loader=None,
                         epochs=30,
                         steps_per_epoch=50):
    """
    Classical SRGAN training:
      - D: BCE on real/fake HR images
      - G: combined model with BCE (fool D) + VGG MSE (content)

    Each step draws batch ``step % len(train_loader)`` from ``train_loader``.
    Batches that fail to load are reported and skipped; batches whose size
    differs from the expected batch size are skipped because the label arrays
    have a fixed length. After every epoch ``train_loader.on_epoch_end()`` is
    called (when available) so the loader can reshuffle, and the generator is
    saved to ``srgan_generator_epoch_<n>.keras``.

    NOTE(review): ``vgg`` is fed the loader's [-1, 1] images directly here (and
    the generator's tanh output inside the combined model), whereas
    ``train_attentive_esrgan`` applies ``preprocess_input`` to 0-255 images
    first -- confirm which input convention the VGG extractor should see.

    Args:
        generator: Generator model (``predict`` / ``save``).
        discriminator: Compiled discriminator (``train_on_batch``).
        srgan: Compiled combined model (``train_on_batch``).
        vgg: Feature extractor (``predict``).
        train_loader: Indexable loader returning ``(lr_batch, hr_batch)``. Its
            ``batch_size`` attribute is used when present, else ``BATCH_SIZE``.
        val_loader: Optional loader used for PSNR / SSIM on up to 3 batches.
        epochs: Number of epochs.
        steps_per_epoch: Number of training steps per epoch.

    Returns:
        history: dict with per-epoch averages:
            - 'epoch'
            - 'd_loss'
            - 'g_loss'
            - 'val_psnr' (if val_loader given)
            - 'val_ssim' (if val_loader given)
        All lists hold one entry per epoch; entries are NaN when nothing could
        be measured in that epoch.
    """
    batch_size = getattr(train_loader, "batch_size", None)
    if batch_size is None:
        batch_size = BATCH_SIZE
    real_labels = np.ones((batch_size, 1))
    fake_labels = np.zeros((batch_size, 1))

    def _total_loss(result):
        # train_on_batch returns a sequence [loss, ...] when the model has
        # metrics or several outputs, and a bare scalar otherwise.
        return result[0] if isinstance(result, (list, tuple)) else result

    history = {
        "epoch": [],
        "d_loss": [],
        "g_loss": [],
    }
    if val_loader is not None:
        history["val_psnr"] = []
        history["val_ssim"] = []

    print("[SRGAN] Starting training...")

    for epoch in range(epochs):
        d_losses, g_losses = [], []
        start = time.time()

        for step in range(steps_per_epoch):
            try:
                lr_imgs, hr_imgs = train_loader[step % len(train_loader)]
            except Exception as exc:
                # Keep training on a bad batch, but do not hide the failure.
                print(f"[SRGAN] Skipping step {step}: failed to load batch ({exc})")
                continue

            if len(lr_imgs) != batch_size:
                continue

            # ---------- Train Discriminator ----------
            fake_imgs = generator.predict(lr_imgs, verbose=0)

            d_real_val = _total_loss(discriminator.train_on_batch(hr_imgs, real_labels))
            d_fake_val = _total_loss(discriminator.train_on_batch(fake_imgs, fake_labels))
            d_loss = 0.5 * (d_real_val + d_fake_val)
            d_losses.append(d_loss)

            # ---------- Train Generator ----------
            hr_features = vgg.predict(hr_imgs, verbose=0)
            g_total = _total_loss(srgan.train_on_batch(lr_imgs, [real_labels, hr_features]))
            g_losses.append(g_total)

            if step % 10 == 0:
                print(f"[SRGAN] Epoch {epoch+1}/{epochs} "
                      f"Step {step}/{steps_per_epoch} | D: {d_loss:.4f} | G: {g_total:.4f}")

        # ---- End of epoch: aggregate ----
        avg_d = float(np.mean(d_losses)) if d_losses else np.nan
        avg_g = float(np.mean(g_losses)) if g_losses else np.nan

        history["epoch"].append(epoch + 1)
        history["d_loss"].append(avg_d)
        history["g_loss"].append(avg_g)

        # ---- Optional: validation PSNR / SSIM on a few batches ----
        if val_loader is not None:
            val_psnrs, val_ssims = [], []
            # only a few batches to keep it cheap
            num_val_batches = min(3, len(val_loader))
            for i in range(num_val_batches):
                try:
                    lr_val, hr_val = val_loader[i]
                except Exception as exc:
                    print(f"[SRGAN] Skipping validation batch {i}: {exc}")
                    continue
                if len(lr_val) == 0:
                    continue

                # The loader already yields [-1,1] inputs, as G expects.
                sr_val = generator.predict(lr_val, verbose=0)
                # Convert to [0,1]
                hr_01 = np.clip((hr_val + 1.0) / 2.0, 0.0, 1.0)
                sr_01 = np.clip((sr_val + 1.0) / 2.0, 0.0, 1.0)

                # Compute PSNR / SSIM per image, then average
                for h_im, s_im in zip(hr_01, sr_01):
                    tf_hr = tf.convert_to_tensor(h_im, tf.float32)
                    tf_sr = tf.convert_to_tensor(s_im, tf.float32)
                    val_psnrs.append(tf.image.psnr(tf_hr, tf_sr, max_val=1.0).numpy())
                    val_ssims.append(tf.image.ssim(tf_hr, tf_sr, max_val=1.0).numpy())

            if val_psnrs:
                history["val_psnr"].append(float(np.mean(val_psnrs)))
                history["val_ssim"].append(float(np.mean(val_ssims)))
                print(f"[SRGAN] Epoch {epoch+1} VAL | "
                      f"PSNR: {history['val_psnr'][-1]:.2f} dB | "
                      f"SSIM: {history['val_ssim'][-1]:.4f}")
            else:
                # Always append, so every history list has one entry per epoch.
                history["val_psnr"].append(np.nan)
                history["val_ssim"].append(np.nan)

        print(f"[SRGAN] Epoch {epoch+1} done in {time.time()-start:.1f}s "
              f"| D: {avg_d:.4f} | G: {avg_g:.4f}")
        generator.save(f"srgan_generator_epoch_{epoch+1}.keras")

        # This loop indexes the loader directly, so nothing else triggers the
        # loader's end-of-epoch hook (reshuffling for the notebook's generator).
        end_of_epoch_hook = getattr(train_loader, "on_epoch_end", None)
        if callable(end_of_epoch_hook):
            end_of_epoch_hook()

    print("[SRGAN] Training complete.")
    return history



def predict_srgan_full_image(generator, image_path, scale_factor=4):
    """
    Run the SRGAN baseline on one full image and plot it against bicubic.

    The HR image is trimmed to a multiple of ``scale_factor``, downscaled with
    bicubic interpolation, mapped to [-1, 1] and fed to the generator. The
    generator output is mapped back from [-1, 1] to [0, 1]. A three-panel
    figure (bicubic, SRGAN, ground truth) is shown with PSNR / SSIM against the
    HR image; the SRGAN title turns green when its PSNR beats bicubic.

    Args:
        generator: Generator model exposing ``predict``.
        image_path: Path of the HR image to evaluate.
        scale_factor: Down/up-sampling factor.

    Returns:
        None. Prints a message and returns early if the image cannot be read.
    """
    bgr = cv2.imread(image_path)
    if bgr is None:
        print(f"[SRGAN] Could not load image: {image_path}")
        return
    ground_truth = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    # Trim so that both sides are exact multiples of the scale factor.
    rows, cols, _ = ground_truth.shape
    h = (rows // scale_factor) * scale_factor
    w = (cols // scale_factor) * scale_factor
    ground_truth = ground_truth[:h, :w, :]

    # cv2.resize takes (width, height).
    lr_img_small = cv2.resize(ground_truth, (w // scale_factor, h // scale_factor),
                              interpolation=cv2.INTER_CUBIC)

    # [0,255] -> [0,1] -> [-1,1] for the generator.
    lr_input = (lr_img_small.astype(np.float32) / 255.0) * 2.0 - 1.0
    sr = generator.predict(np.expand_dims(lr_input, axis=0), verbose=0)[0]
    sr_01 = np.clip((sr + 1.0) / 2.0, 0.0, 1.0)

    bicubic = cv2.resize(lr_img_small, (w, h), interpolation=cv2.INTER_CUBIC)
    bicubic = bicubic.astype(np.float32) / 255.0

    hr_01 = ground_truth.astype(np.float32) / 255.0
    tf_hr = tf.convert_to_tensor(hr_01, tf.float32)

    def _psnr_ssim_against_hr(candidate):
        tf_candidate = tf.convert_to_tensor(candidate, tf.float32)
        return (tf.image.psnr(tf_hr, tf_candidate, max_val=1.0).numpy(),
                tf.image.ssim(tf_hr, tf_candidate, max_val=1.0).numpy())

    psnr_sr, ssim_sr = _psnr_ssim_against_hr(sr_01)
    psnr_bic, ssim_bic = _psnr_ssim_against_hr(bicubic)

    sr_title_style = {
        "color": "green" if psnr_sr > psnr_bic else "black",
        "fontweight": "bold",
    }
    panels = (
        (bicubic, f"Bicubic\nPSNR: {psnr_bic:.2f} dB | SSIM: {ssim_bic:.4f}", {}),
        (sr_01, f"SRGAN\nPSNR: {psnr_sr:.2f} dB | SSIM: {ssim_sr:.4f}", sr_title_style),
        (hr_01, "Ground Truth", {}),
    )

    fig, axes = plt.subplots(1, 3, figsize=(24, 10))
    for ax, (image, title, title_style) in zip(axes, panels):
        ax.imshow(image)
        ax.set_title(title, **title_style)
        ax.axis("off")

    plt.tight_layout()
    plt.show()

print("SRGAN baseline models & helpers defined.")


SRGAN baseline models & helpers defined.


In [None]:
# Build the baseline generator and initialise it from the warm-up weights file.
srgan_gen = build_srgan_generator(scale=UPSCALE, num_res_blocks=16)
srgan_gen.load_weights(SRRESNET_WARMUP_PATH)
# Discriminator and VGG extractor both operate on HR-sized patches.
srgan_disc = build_srgan_discriminator(input_shape=(HR_CROP_SIZE, HR_CROP_SIZE, 3))
vgg = build_vgg(hr_shape=(HR_CROP_SIZE, HR_CROP_SIZE, 3))

# The discriminator is compiled on its own BEFORE the combined model is built:
# build_srgan_combined toggles its trainable flags, so the order matters here.
srgan_disc.compile(
    loss='binary_crossentropy',
    optimizer=keras.optimizers.Adam(learning_rate=1e-4),
    metrics=['accuracy'],
)

# Combined model (G -> [D, VGG]) used for the generator update.
srgan_combined = build_srgan_combined(
    srgan_gen, srgan_disc, vgg,
    lr_shape=(LR_CROP_SIZE, LR_CROP_SIZE, 3)
)

# Run adversarial training; the returned dict holds the per-epoch averages.
srgan_history = train_srgan_baseline(
    generator=srgan_gen,
    discriminator=srgan_disc,
    srgan=srgan_combined,
    vgg=vgg,
    train_loader=train_gen,
    val_loader=val_gen,      # can be None if you don't want val metrics
    epochs=30,
    steps_per_epoch=50,
)


# 4. Attentive ESRGAN (No BN, Channel Attention, RaLSGAN)

In [10]:
# ============================================================
# 4. Attentive ESRGAN (No BN, Channel Attention, RaLSGAN)
# ============================================================

def channel_attention_block(x, ratio=16):
    """
    Squeeze-and-excitation style channel attention.

    The feature map is globally average-pooled (keeping the spatial dims), passed
    through a bias-free bottleneck Dense(channels // ratio, ReLU) and a
    bias-free Dense(channels, sigmoid), and the resulting per-channel gates
    multiply the input.

    Args:
        x: Input feature map; the channel count is read from its last axis
            (64 is assumed when it is undefined).
        ratio: Channel reduction ratio of the bottleneck (at least 1 unit is kept).

    Returns:
        The re-weighted feature map.
    """
    num_channels = x.shape[-1]
    if num_channels is None:
        num_channels = 64  # safe fallback for this architecture
    num_channels = int(num_channels)
    bottleneck_units = max(num_channels // ratio, 1)

    gates = layers.GlobalAveragePooling2D(keepdims=True)(x)
    gates = layers.Dense(bottleneck_units, activation='relu', use_bias=False)(gates)
    gates = layers.Dense(num_channels, activation='sigmoid', use_bias=False)(gates)
    return layers.Multiply()([x, gates])


# Registering the layer lets Keras resolve it by name when a saved model that
# contains it is loaded, without passing `custom_objects`.
@keras.utils.register_keras_serializable(package="SuperResolution")
class PixelShuffle(layers.Layer):
    """
    Pixel-shuffle (depth-to-space) upsampling layer.

    Rearranges channel blocks into space via ``tf.nn.depth_to_space`` with
    ``block_size=scale``. The layer has no weights.

    Args:
        scale: Block size passed to ``tf.nn.depth_to_space``.
        **kwargs: Standard Keras ``Layer`` keyword arguments.
    """

    def __init__(self, scale=2, **kwargs):
        super().__init__(**kwargs)
        self.scale = scale

    def call(self, inputs):
        return tf.nn.depth_to_space(inputs, block_size=self.scale)

    def get_config(self):
        # Include `scale` so the layer can be rebuilt from its saved config.
        cfg = super().get_config()
        cfg.update({"scale": self.scale})
        return cfg


# def attentive_residual_block(x):
#     shortcut = x

#     x = layers.Conv2D(64, 3, padding='same')(x)
#     x = layers.PReLU(shared_axes=[1, 2])(x)
#     x = layers.Conv2D(64, 3, padding='same')(x)
#     x = channel_attention_block(x)

#     # residual scaling 0.2, no Lambda
#     x = layers.Multiply()([x, tf.constant(0.2, dtype=tf.float32)])
#     return layers.Add()([x, shortcut])

def attentive_residual_block(x):
    """
    BN-free residual block with channel attention and residual scaling.

    Branch: Conv(64, 3x3) -> PReLU -> Conv(64, 3x3) -> channel attention,
    scaled by 0.2 and then added to the unmodified block input.
    """
    branch = layers.Conv2D(64, 3, padding='same')(x)
    branch = layers.PReLU(shared_axes=[1, 2])(branch)
    branch = layers.Conv2D(64, 3, padding='same')(branch)
    branch = channel_attention_block(branch)

    # Residual scaling (0.2) written as plain tensor arithmetic.
    scaled_branch = branch * 0.2

    return layers.Add()([scaled_branch, x])

def upsample_block_no_bn(x, scale=2):
    """
    BN-free upsampling block: Conv -> PixelShuffle(scale) -> PReLU.

    The convolution emits ``64 * scale**2`` channels so that 64 channels remain
    after the pixel shuffle.
    """
    conv_filters = 64 * (scale ** 2)
    expanded = layers.Conv2D(conv_filters, 3, padding='same')(x)
    shuffled = PixelShuffle(scale)(expanded)
    return layers.PReLU(shared_axes=[1, 2])(shuffled)


def build_attentive_generator(scale=4, num_res_blocks=16):
    """
    Build the BN-free generator with channel-attention residual blocks.

    Structure: Conv 9x9 + PReLU -> ``num_res_blocks`` attentive residual blocks
    -> Conv 3x3 + long skip connection -> one x2 pixel-shuffle block per factor
    of two in ``scale`` -> Conv 9x9 with tanh.

    Args:
        scale: Upscale factor. Must be a positive power of two (1, 2, 4, 8, ...)
            because upsampling is built from x2 blocks.
        num_res_blocks: Number of attentive residual blocks in the trunk.

    Returns:
        Keras model mapping an LR image to an SR image (tanh output).

    Raises:
        ValueError: if ``scale`` is not a positive power of two. Such values
            would otherwise yield a network whose real upscale factor
            silently differs from the requested one.
    """
    scale_int = int(scale)
    if scale_int != scale or scale_int < 1 or scale_int & (scale_int - 1):
        raise ValueError(
            f"scale must be a positive power of two (1, 2, 4, ...), got {scale!r}"
        )
    # log2(scale): number of x2 upsampling blocks required.
    num_upsample_stages = scale_int.bit_length() - 1

    lr_input = Input(shape=(None, None, 3))

    x1 = layers.Conv2D(64, 9, padding='same')(lr_input)
    x1 = layers.PReLU(shared_axes=[1, 2])(x1)

    x = x1
    for _ in range(num_res_blocks):
        x = attentive_residual_block(x)

    x = layers.Conv2D(64, 3, padding='same')(x)
    # Long skip connection around the whole residual trunk.
    x = layers.Add()([x, x1])

    for _ in range(num_upsample_stages):
        x = upsample_block_no_bn(x, scale=2)

    out = layers.Conv2D(3, 9, padding='same', activation='tanh')(x)
    return models.Model(lr_input, out, name="Attentive_Generator")


def build_relativistic_discriminator(input_shape):
    """
    Build the discriminator used with the relativistic (RaLSGAN) loss.

    Eight Conv(3x3) -> [BatchNorm] -> LeakyReLU(0.2) blocks
    (64-64-128-128-256-256-512-512 filters, alternating strides 1 / 2, no batch
    norm on the first block) are followed by Flatten -> Dense(1024) ->
    LeakyReLU(0.2) -> Dense(1). The final layer has no activation, so the model
    outputs raw logits.

    Args:
        input_shape: Shape of the input image, e.g. ``(H, W, 3)``. It must be
            fully defined because of the Flatten / Dense head.

    Returns:
        Keras model producing one logit per image.
    """
    img_input = Input(shape=input_shape)

    def d_block(x, filters, strides=1, bn=True):
        x = layers.Conv2D(filters, 3, strides=strides, padding='same')(x)
        if bn:
            x = layers.BatchNormalization(momentum=0.8)(x)
        # The slope is passed positionally: Keras 3 renamed the keyword from
        # `alpha` to `negative_slope` and warns on the old name.
        x = layers.LeakyReLU(0.2)(x)
        return x

    x = d_block(img_input, 64, strides=1, bn=False)
    remaining_blocks = (
        (64, 2),
        (128, 1), (128, 2),
        (256, 1), (256, 2),
        (512, 1), (512, 2),
    )
    for filters, strides in remaining_blocks:
        x = d_block(x, filters, strides=strides)

    x = layers.Flatten()(x)
    x = layers.Dense(1024)(x)
    x = layers.LeakyReLU(0.2)(x)

    validity = layers.Dense(1)(x)  # logits
    return models.Model(img_input, validity, name="Relativistic_Discriminator")

def train_attentive_esrgan(generator,
                           discriminator,
                           vgg,
                           train_loader,
                           val_loader=None,
                           epochs=30,
                           steps_per_epoch=50):
    """
    ESRGAN-style training:
      - RaLSGAN adversarial loss
      - VGG19 perceptual loss
      - L1 pixel loss

    Each step draws batch ``step % len(train_loader)`` from ``train_loader``.
    Batches that fail to load are reported and skipped; batches whose size
    differs from the expected batch size are skipped. After every epoch the
    generator is saved to ``attentive_esrgan_epoch_<n>.keras`` and
    ``train_loader.on_epoch_end()`` is called (when available) so the loader
    can reshuffle.

    Generator loss: ``0.006 * content + 5e-3 * adversarial + 1e-2 * L1``.

    Args:
        generator: Generator model (LR in [-1,1] -> SR in [-1,1]).
        discriminator: Discriminator model returning logits.
        vgg: VGG feature extractor; it receives ``preprocess_input`` of the
            images rescaled from [-1, 1] to [0, 255].
        train_loader: Indexable loader returning ``(lr_batch, hr_batch)``. Its
            ``batch_size`` attribute is used when present, else ``BATCH_SIZE``.
        val_loader: Optional loader used for PSNR / SSIM on up to 3 batches.
        epochs: Number of epochs.
        steps_per_epoch: Number of training steps per epoch.

    Returns:
        history dict with:
            - 'epoch'
            - 'd_loss'
            - 'g_loss'
            - 'val_psnr' (optional)
            - 'val_ssim' (optional)
        All lists hold one entry per epoch; entries are NaN when nothing could
        be measured in that epoch.
    """
    g_opt = keras.optimizers.Adam(learning_rate=1e-4, beta_1=0.9, beta_2=0.999, clipnorm=1.0)
    d_opt = keras.optimizers.Adam(learning_rate=5e-5, beta_1=0.9, beta_2=0.999, clipnorm=1.0)
    mse = keras.losses.MeanSquaredError()

    batch_size = getattr(train_loader, "batch_size", None)
    if batch_size is None:
        batch_size = BATCH_SIZE

    history = {
        "epoch": [],
        "d_loss": [],
        "g_loss": [],
    }
    if val_loader is not None:
        history["val_psnr"] = []
        history["val_ssim"] = []

    print("[Attentive ESRGAN] Starting training (RaLSGAN)...")

    for epoch in range(epochs):
        start = time.time()
        epoch_d_losses, epoch_g_losses = [], []

        for step in range(steps_per_epoch):
            try:
                lr_imgs, hr_imgs = train_loader[step % len(train_loader)]
            except Exception as exc:
                # Keep training on a bad batch, but do not hide the failure.
                print(f"[Attentive ESRGAN] Skipping step {step}: "
                      f"failed to load batch ({exc})")
                continue

            if len(lr_imgs) != batch_size:
                continue

            # Work on float32 tensors regardless of the dtype the loader yields.
            lr_imgs = tf.convert_to_tensor(lr_imgs, dtype=tf.float32)
            hr_imgs = tf.convert_to_tensor(hr_imgs, dtype=tf.float32)

            # -------- Train Discriminator --------
            # The fake batch is a constant for the D update, so it is produced
            # outside the tape.
            fake_imgs = generator(lr_imgs, training=True)

            with tf.GradientTape() as tape_d:
                real_logits = discriminator(hr_imgs, training=True)
                fake_logits = discriminator(fake_imgs, training=True)

                mean_fake = tf.reduce_mean(fake_logits, axis=0, keepdims=True)
                mean_real = tf.reduce_mean(real_logits, axis=0, keepdims=True)

                # Relativistic average: each logit relative to the opposite class mean.
                real_rel = real_logits - mean_fake
                fake_rel = fake_logits - mean_real

                d_loss_real = tf.reduce_mean((real_rel - 1.0) ** 2)
                d_loss_fake = tf.reduce_mean((fake_rel + 1.0) ** 2)
                d_loss = 0.5 * (d_loss_real + d_loss_fake)

            d_grads = tape_d.gradient(d_loss, discriminator.trainable_variables)
            d_opt.apply_gradients(zip(d_grads, discriminator.trainable_variables))
            epoch_d_losses.append(d_loss.numpy())

            # -------- Train Generator --------
            # Terms that do not depend on the generator are computed outside the tape.
            real_logits = discriminator(hr_imgs, training=False)
            mean_real = tf.reduce_mean(real_logits, axis=0, keepdims=True)
            hr_vgg = preprocess_input((hr_imgs + 1.0) * 127.5)
            img_features = vgg(hr_vgg, training=False)

            with tf.GradientTape() as tape_g:
                fake_imgs = generator(lr_imgs, training=True)
                fake_logits = discriminator(fake_imgs, training=False)

                mean_fake = tf.reduce_mean(fake_logits, axis=0, keepdims=True)

                real_rel = real_logits - mean_fake
                fake_rel = fake_logits - mean_real

                # Targets are swapped with respect to the discriminator loss.
                g_loss_real = tf.reduce_mean((real_rel + 1.0) ** 2)
                g_loss_fake = tf.reduce_mean((fake_rel - 1.0) ** 2)
                adv_loss = 0.5 * (g_loss_real + g_loss_fake)

                # VGG perceptual loss
                fake_vgg = preprocess_input((fake_imgs + 1.0) * 127.5)
                gen_features = vgg(fake_vgg, training=False)
                content_loss = mse(img_features, gen_features)

                # L1 pixel loss
                pixel_loss = tf.reduce_mean(tf.abs(hr_imgs - fake_imgs))

                total_g_loss = (0.006 * content_loss) + (5e-3 * adv_loss) + (1e-2 * pixel_loss)

            g_grads = tape_g.gradient(total_g_loss, generator.trainable_variables)
            g_opt.apply_gradients(zip(g_grads, generator.trainable_variables))
            epoch_g_losses.append(total_g_loss.numpy())

            if step % 10 == 0:
                print(f"[Attentive ESRGAN] Epoch {epoch+1}/{epochs} "
                      f"Step {step}/{steps_per_epoch} | "
                      f"D: {d_loss.numpy():.4f} | G: {total_g_loss.numpy():.4f}")

        # ---- end epoch: aggregate ----
        avg_d = float(np.mean(epoch_d_losses)) if epoch_d_losses else np.nan
        avg_g = float(np.mean(epoch_g_losses)) if epoch_g_losses else np.nan

        history["epoch"].append(epoch + 1)
        history["d_loss"].append(avg_d)
        history["g_loss"].append(avg_g)

        # ---- optional validation metrics ----
        if val_loader is not None:
            val_psnrs, val_ssims = [], []
            num_val_batches = min(3, len(val_loader))
            for i in range(num_val_batches):
                try:
                    lr_val, hr_val = val_loader[i]
                except Exception as exc:
                    print(f"[Attentive ESRGAN] Skipping validation batch {i}: {exc}")
                    continue
                if len(lr_val) == 0:
                    continue

                sr_val = generator.predict(lr_val, verbose=0)
                # [-1,1] -> [0,1]
                hr_01 = np.clip((hr_val + 1.0) / 2.0, 0.0, 1.0)
                sr_01 = np.clip((sr_val + 1.0) / 2.0, 0.0, 1.0)

                for h_im, s_im in zip(hr_01, sr_01):
                    tf_hr = tf.convert_to_tensor(h_im, tf.float32)
                    tf_sr = tf.convert_to_tensor(s_im, tf.float32)
                    val_psnrs.append(tf.image.psnr(tf_hr, tf_sr, max_val=1.0).numpy())
                    val_ssims.append(tf.image.ssim(tf_hr, tf_sr, max_val=1.0).numpy())

            if val_psnrs:
                history["val_psnr"].append(float(np.mean(val_psnrs)))
                history["val_ssim"].append(float(np.mean(val_ssims)))
                print(f"[Attentive ESRGAN] Epoch {epoch+1} VAL | "
                      f"PSNR: {history['val_psnr'][-1]:.2f} dB | "
                      f"SSIM: {history['val_ssim'][-1]:.4f}")
            else:
                # Always append, so every history list has one entry per epoch.
                history["val_psnr"].append(np.nan)
                history["val_ssim"].append(np.nan)

        generator.save(f"attentive_esrgan_epoch_{epoch+1}.keras")
        print(f"[Attentive ESRGAN] Epoch {epoch+1} done in {time.time()-start:.1f}s | "
              f"D: {avg_d:.4f} | G: {avg_g:.4f}")

        # This loop indexes the loader directly, so nothing else triggers the
        # loader's end-of-epoch hook (reshuffling for the notebook's generator).
        end_of_epoch_hook = getattr(train_loader, "on_epoch_end", None)
        if callable(end_of_epoch_hook):
            end_of_epoch_hook()

    print("[Attentive ESRGAN] Training complete.")
    return history



def predict_attentive_full_image(generator, image_path, scale_factor=4):
    """
    Super-resolve one full image with Attentive ESRGAN and plot the result.

    The image is read with OpenCV, cropped so both sides are multiples of
    `scale_factor`, and bicubic-downscaled to synthesize the LR input. The LR
    image is rescaled to [-1,1] and passed to `generator`; the prediction is
    mapped back to [0,1]. Bicubic / SR / Ground Truth are then shown side by
    side, with PSNR and SSIM measured against the cropped original.

    Args:
        generator: model exposing `predict(batch, verbose=0)`.
        image_path: path of the HR image to read.
        scale_factor: downscale factor used to build the LR input.

    Returns:
        None. Prints a message and returns early if the image cannot be read.
    """
    bgr = cv2.imread(image_path)
    if bgr is None:
        print(f"[Attentive ESRGAN] Could not load image: {image_path}")
        return
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    # Trim the borders so height and width divide evenly by the scale factor.
    rows = (rgb.shape[0] // scale_factor) * scale_factor
    cols = (rgb.shape[1] // scale_factor) * scale_factor
    ground_truth = rgb[:rows, :cols]

    # cv2.resize takes (width, height).
    low_res = cv2.resize(
        ground_truth,
        (cols // scale_factor, rows // scale_factor),
        interpolation=cv2.INTER_CUBIC,
    )

    # uint8 [0,255] -> float [0,1] -> [-1,1], plus a leading batch axis.
    net_input = (low_res.astype(np.float32) / 255.0) * 2.0 - 1.0
    net_input = net_input[np.newaxis, ...]

    prediction = generator.predict(net_input, verbose=0)[0]   # [-1,1]
    sr_01 = np.clip((prediction + 1.0) / 2.0, 0.0, 1.0)

    bicubic = cv2.resize(low_res, (cols, rows), interpolation=cv2.INTER_CUBIC)
    bicubic = bicubic.astype(np.float32) / 255.0
    hr_01 = ground_truth.astype(np.float32) / 255.0

    reference = tf.convert_to_tensor(hr_01, tf.float32)

    def _score(candidate_01):
        # PSNR and SSIM of a [0,1] image against the ground truth.
        candidate = tf.convert_to_tensor(candidate_01, tf.float32)
        return (
            tf.image.psnr(reference, candidate, max_val=1.0).numpy(),
            tf.image.ssim(reference, candidate, max_val=1.0).numpy(),
        )

    psnr_sr, ssim_sr = _score(sr_01)
    psnr_bic, ssim_bic = _score(bicubic)

    # Highlight the SR title in green only when it beats bicubic on PSNR.
    if psnr_sr > psnr_bic:
        title_col = "green"
    else:
        title_col = "black"

    _, (ax_bic, ax_sr, ax_gt) = plt.subplots(1, 3, figsize=(24, 10))

    ax_bic.imshow(bicubic)
    ax_bic.set_title(f"Bicubic\nPSNR: {psnr_bic:.2f} dB | SSIM: {ssim_bic:.4f}")

    ax_sr.imshow(sr_01)
    ax_sr.set_title(
        f"Attentive ESRGAN\nPSNR: {psnr_sr:.2f} dB | SSIM: {ssim_sr:.4f}",
        color=title_col, fontweight="bold",
    )

    ax_gt.imshow(hr_01)
    ax_gt.set_title("Ground Truth")

    for panel in (ax_bic, ax_sr, ax_gt):
        panel.axis("off")

    plt.tight_layout()
    plt.show()

# Marker printed once every definition in this cell has been executed.
print("Attentive ESRGAN models & helpers defined.")



Attentive ESRGAN models & helpers defined.


In [11]:
# Build the Attentive ESRGAN generator and start it from the warm-up weights
# stored at ATTENTIVE_WARMUP_PATH.
att_gen = build_attentive_generator(scale=UPSCALE, num_res_blocks=16)
att_gen.load_weights(ATTENTIVE_WARMUP_PATH)
# The discriminator and the VGG network are both built for HR-sized inputs
# (HR_CROP_SIZE x HR_CROP_SIZE x 3).
att_disc = build_relativistic_discriminator(input_shape=(HR_CROP_SIZE, HR_CROP_SIZE, 3))
vgg_esr = build_vgg(hr_shape=(HR_CROP_SIZE, HR_CROP_SIZE, 3))

# Training run: 100 epochs of 50 steps each, with val_gen passed as the
# validation loader. The result is kept in att_history, which later cells
# save to JSON and plot.
att_history = train_attentive_esrgan(
    generator=att_gen,
    discriminator=att_disc,
    vgg=vgg_esr,
    train_loader=train_gen,
    val_loader=val_gen,
    epochs=100,
    steps_per_epoch=50,
)


[Attentive ESRGAN] Starting training (RaLSGAN)...


Expected: ['keras_tensor_462']
Received: inputs=Tensor(shape=(16, 128, 128, 3))


[Attentive ESRGAN] Epoch 1/100 Step 0/50 | D: 1.3509 | G: 0.1321
[Attentive ESRGAN] Epoch 1/100 Step 10/50 | D: 1.6118 | G: 0.1139
[Attentive ESRGAN] Epoch 1/100 Step 20/50 | D: 2.6380 | G: 0.1706
[Attentive ESRGAN] Epoch 1/100 Step 30/50 | D: 0.9081 | G: 0.1625
[Attentive ESRGAN] Epoch 1/100 Step 40/50 | D: 1.0236 | G: 0.1425
[Attentive ESRGAN] Epoch 1 VAL | PSNR: 21.24 dB | SSIM: 0.5260
[Attentive ESRGAN] Epoch 1 done in 174.6s | D: 2.1114 | G: 0.1428
[Attentive ESRGAN] Epoch 2/100 Step 0/50 | D: 0.7063 | G: 0.0774
[Attentive ESRGAN] Epoch 2/100 Step 10/50 | D: 0.9908 | G: 0.1407
[Attentive ESRGAN] Epoch 2/100 Step 20/50 | D: 0.6910 | G: 0.0913
[Attentive ESRGAN] Epoch 2/100 Step 30/50 | D: 0.3261 | G: 0.1947
[Attentive ESRGAN] Epoch 2/100 Step 40/50 | D: 0.6228 | G: 0.1789
[Attentive ESRGAN] Epoch 2 VAL | PSNR: 21.16 dB | SSIM: 0.5388
[Attentive ESRGAN] Epoch 2 done in 169.2s | D: 0.5406 | G: 0.1293
[Attentive ESRGAN] Epoch 3/100 Step 0/50 | D: 0.8355 | G: 0.0866
[Attentive ESRGAN] 

In [16]:
import json

# --- Saving the dictionary to a .json file ---
# Serialize before opening the file: if the history holds something json cannot
# encode, the error is raised here and an existing esrgan_history.json is not
# truncated. `default=float` converts non-builtin numeric scalars (e.g. NumPy
# float32) to plain floats. NaN entries are written as the non-standard `NaN`
# token, which Python's json module reads back as float('nan').
history_json = json.dumps(att_history, indent=4, default=float)  # 'indent' makes the file human-readable
with open('esrgan_history.json', 'w', encoding='utf-8') as json_file:
    json_file.write(history_json)

# Visualization

In [None]:
def plot_gan_history(history, title_prefix="Model"):
    """
    Plot D/G loss and optional validation PSNR/SSIM vs epoch.

    Args:
        history: dict returned by train_srgan_baseline / train_attentive_esrgan.
            "d_loss" and "g_loss" are required. "epoch", "val_psnr" and
            "val_ssim" are optional.
        title_prefix: text placed in front of each figure title.

    Returns:
        None. Figures are displayed with plt.show().

    The validation figure is skipped when the validation lists are missing or
    empty. That happens when training ran without a validation loader: the
    keys exist but nothing was appended, and plotting an empty series against
    a non-empty epoch axis raises a shape-mismatch ValueError.
    """
    epochs = history.get("epoch", list(range(1, len(history.get("d_loss", [])) + 1)))

    def _x_axis(series):
        # Use the recorded epoch numbers when they line up with the series;
        # otherwise fall back to 1..len(series) so the plot call cannot fail
        # on a length mismatch.
        if len(series) == len(epochs):
            return epochs
        return list(range(1, len(series) + 1))

    d_loss = history["d_loss"]
    g_loss = history["g_loss"]

    # ---- 1) Loss curves ----
    plt.figure(figsize=(8, 5))
    plt.plot(_x_axis(d_loss), d_loss, label="D loss")
    plt.plot(_x_axis(g_loss), g_loss, label="G loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title(f"{title_prefix} Training Loss")
    plt.legend()
    plt.grid(True)
    plt.show()

    # ---- 2) Validation metrics (if available) ----
    val_psnr = history.get("val_psnr")
    val_ssim = history.get("val_ssim")
    if val_psnr is None or val_ssim is None:
        return
    if len(val_psnr) == 0 or len(val_ssim) == 0:
        # Keys are present but no validation was recorded: nothing to draw.
        return

    fig, ax1 = plt.subplots(figsize=(8, 5))

    ax1.set_xlabel("Epoch")
    ax1.set_ylabel("PSNR (dB)", color="tab:blue")
    ax1.plot(_x_axis(val_psnr), val_psnr, marker="o", label="PSNR", color="tab:blue")
    ax1.tick_params(axis='y', labelcolor="tab:blue")
    ax1.grid(True, axis='y', linestyle='--', alpha=0.3)

    # SSIM lives on a second y-axis because its scale differs from PSNR.
    ax2 = ax1.twinx()
    ax2.set_ylabel("SSIM", color="tab:orange")
    ax2.plot(_x_axis(val_ssim), val_ssim, marker="s", label="SSIM", color="tab:orange")
    ax2.tick_params(axis='y', labelcolor="tab:orange")

    plt.title(f"{title_prefix} Validation Metrics")
    fig.tight_layout()
    plt.show()


# Plot the curves for both runs. srgan_history and att_history are not defined
# in this cell, so they must already exist in the kernel.
plot_gan_history(srgan_history, title_prefix="SRGAN Baseline")
plot_gan_history(att_history, title_prefix="Attentive ESRGAN")


In [None]:
def compare_two_gan_models(gen_a,
                           gen_b,
                           label_a="Model A",
                           label_b="Model B",
                           image_path=None,
                           scale_factor=4):
    """
    Evaluate two GAN super-resolution generators on the same image.

    The image is cropped to a multiple of `scale_factor` and bicubic-downscaled
    to form the LR input, which is rescaled to [-1,1] and fed to both
    generators. Their outputs are mapped from [-1,1] back to [0,1]. PSNR/SSIM
    against the cropped original are printed for bicubic and both models, and
    a four-panel figure is shown: Bicubic, Model A, Model B, Ground Truth.

    Args:
        gen_a, gen_b: models exposing `predict(batch, verbose=0)`.
        label_a, label_b: names used in the printed lines and panel titles.
        image_path: path of the HR image; a message is printed when it is None.
        scale_factor: downscale factor used to build the LR input.

    Returns:
        None. Returns early (after printing) when no path is given or the
        image cannot be read.
    """
    if image_path is None:
        print("Please provide image_path.")
        return

    # Read the HR image and convert OpenCV's BGR layout to RGB.
    bgr = cv2.imread(image_path)
    if bgr is None:
        print(f"[Compare] Could not load image: {image_path}")
        return
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    # Trim so both sides divide evenly by the scale factor.
    rows = (rgb.shape[0] // scale_factor) * scale_factor
    cols = (rgb.shape[1] // scale_factor) * scale_factor
    ground_truth = rgb[:rows, :cols]

    # Synthesize the LR image; cv2.resize takes (width, height).
    low_res = cv2.resize(
        ground_truth,
        (cols // scale_factor, rows // scale_factor),
        interpolation=cv2.INTER_CUBIC,
    )

    # Shared network input: [0,255] -> [0,1] -> [-1,1], with a batch axis.
    net_input = (low_res.astype(np.float32) / 255.0) * 2.0 - 1.0
    net_input = net_input[np.newaxis, ...]

    raw_a = gen_a.predict(net_input, verbose=0)[0]    # [-1,1]
    raw_b = gen_b.predict(net_input, verbose=0)[0]    # [-1,1]

    def _to_unit_range(pred):
        # Map a [-1,1] prediction to [0,1] and clamp any overshoot.
        return np.clip((pred + 1.0) / 2.0, 0.0, 1.0)

    sr_a_01 = _to_unit_range(raw_a)
    sr_b_01 = _to_unit_range(raw_b)

    bicubic = cv2.resize(low_res, (cols, rows), interpolation=cv2.INTER_CUBIC)
    bicubic = bicubic.astype(np.float32) / 255.0
    hr_01 = ground_truth.astype(np.float32) / 255.0

    reference = tf.convert_to_tensor(hr_01, tf.float32)

    def _score(candidate_01):
        # PSNR and SSIM of a [0,1] image against the ground truth.
        candidate = tf.convert_to_tensor(candidate_01, tf.float32)
        return (
            tf.image.psnr(reference, candidate, max_val=1.0).numpy(),
            tf.image.ssim(reference, candidate, max_val=1.0).numpy(),
        )

    psnr_bic, ssim_bic = _score(bicubic)
    psnr_a, ssim_a = _score(sr_a_01)
    psnr_b, ssim_b = _score(sr_b_01)

    print(f"[Bicubic]   PSNR: {psnr_bic:.2f} dB | SSIM: {ssim_bic:.4f}")
    print(f"[{label_a}] PSNR: {psnr_a:.2f} dB | SSIM: {ssim_a:.4f}")
    print(f"[{label_b}] PSNR: {psnr_b:.2f} dB | SSIM: {ssim_b:.4f}")

    # One (image, title) pair per panel, left to right.
    panels = [
        (bicubic, f"Bicubic\nPSNR: {psnr_bic:.2f} | SSIM: {ssim_bic:.4f}"),
        (sr_a_01, f"{label_a}\nPSNR: {psnr_a:.2f} | SSIM: {ssim_a:.4f}"),
        (sr_b_01, f"{label_b}\nPSNR: {psnr_b:.2f} | SSIM: {ssim_b:.4f}"),
        (hr_01, "Ground Truth"),
    ]

    _, axes = plt.subplots(1, 4, figsize=(28, 8))
    for panel, (picture, caption) in zip(axes, panels):
        panel.imshow(picture)
        panel.set_title(caption)
        panel.axis("off")

    plt.tight_layout()
    plt.show()


In [None]:
test_y_dir = "/kaggle/input/super-resolution-test-cases/test cases"
# An unmounted dataset yields an empty list instead of a FileNotFoundError.
test_files = sorted(os.listdir(test_y_dir)) if os.path.isdir(test_y_dir) else []

if test_files:
    # Compare both generators on the first test file (alphabetical order).
    img_path = os.path.join(test_y_dir, test_files[0])

    compare_two_gan_models(
        gen_a=srgan_gen,
        gen_b=att_gen,
        label_a="SRGAN Baseline",
        label_b="Attentive ESRGAN",
        image_path=img_path,
        scale_factor=4,
    )
else:
    # Say why nothing was plotted instead of failing with an IndexError.
    print(f"[Compare] No test files found in: {test_y_dir}")


In [None]:
# # ============================================================
# # 5. Example usage (comment/uncomment as needed)
# # ============================================================

# if __name__ == "__main__":
#     print("Setup complete. Edit this block to run training/inference.")

#     # Example test images dir (must add this dataset in Kaggle "Data" tab)
#     test_y_dir = "/kaggle/input/super-resolution-test-cases/test cases"
#     test_files = sorted(os.listdir(test_y_dir)) if os.path.exists(test_y_dir) else []

#     # ---------- 1) SRCNN baseline ----------
#     # srcnn = build_srcnn()
#     # try:
#     #     srcnn = keras.models.load_model(SRCNN_PRETRAINED_PATH, compile=False)
#     #     print("Loaded pretrained SRCNN from:", SRCNN_PRETRAINED_PATH)
#     # except Exception as e:
#     #     print("Could not load pretrained SRCNN:", e)
#     #
#     # if test_files:
#     #     img_path = os.path.join(test_y_dir, test_files[0])
#     #     predict_srcnn_full_image(srcnn, img_path)

#     # ---------- 2) SRGAN baseline ----------
#     # srgan_gen = build_srgan_generator(scale=UPSCALE, num_res_blocks=16)
#     # srgan_disc = build_srgan_discriminator(input_shape=(HR_CROP_SIZE, HR_CROP_SIZE, 3))
#     # vgg = build_vgg(hr_shape=(HR_CROP_SIZE, HR_CROP_SIZE, 3))
#     # srgan_disc.compile(loss='binary_crossentropy',
#     #                    optimizer=keras.optimizers.Adam(learning_rate=1e-4),
#     #                    metrics=['accuracy'])
#     # srgan_combined = build_srgan_combined(srgan_gen, srgan_disc, vgg,
#     #                                       lr_shape=(LR_CROP_SIZE, LR_CROP_SIZE, 3))
#     #
#     # try:
#     #     srgan_gen.load_weights(SRRESNET_WARMUP_PATH)
#     #     print("Loaded SRResNet warm-up weights from:", SRRESNET_WARMUP_PATH)
#     # except Exception as e:
#     #     print("Could not load SRResNet warm-up weights:", e)
#     #
#     # # Quick sanity training run
#     # # train_srgan_baseline(srgan_gen, srgan_disc, srgan_combined, vgg,
#     # #                      train_gen, epochs=1, steps_per_epoch=10)
#     #
#     # if test_files:
#     #     img_path = os.path.join(test_y_dir, test_files[0])
#     #     predict_srgan_full_image(srgan_gen, img_path)

#     # ---------- 3) Attentive ESRGAN ----------
#     # att_gen = build_attentive_generator(scale=UPSCALE, num_res_blocks=16)
#     # att_disc = build_relativistic_discriminator(input_shape=(HR_CROP_SIZE, HR_CROP_SIZE, 3))
#     # vgg_esr = build_vgg(hr_shape=(HR_CROP_SIZE, HR_CROP_SIZE, 3))
#     #
#     # try:
#     #     att_gen.load_weights(ATTENTIVE_WARMUP_PATH)
#     #     print("Loaded attentive warm-up weights from:", ATTENTIVE_WARMUP_PATH)
#     # except Exception as e:
#     #     print("Could not load attentive warm-up weights:", e)
#     #
#     # # Quick sanity training run
#     # # train_attentive_esrgan(att_gen, att_disc, vgg_esr,
#     # #                        train_gen, epochs=1, steps_per_epoch=10)
#     #
#     # if test_files:
#     #     img_path = os.path.join(test_y_dir, test_files[0])
#     #     predict_attentive_full_image(att_gen, img_path)
# # 