## Imports

In [None]:
import os
import cv2
import random
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.layers import Input, Conv2D, MaxPool2D, Dropout, UpSampling2D, Concatenate
import shutil
import tempfile
import requests
from io import BytesIO
from PIL import Image

## Params & Optimizations

In [None]:
# Turn on XLA JIT compilation for TensorFlow graphs.
tf.config.optimizer.set_jit(True)

# --- Configuration Parameters ---
PATCH_SIZE = 224                  # side of each square training patch, in pixels
PATCH_STRIDE = PATCH_SIZE // 2    # half-patch stride -> 50% overlap between neighbours
BATCH_SIZE = 32
AUTOTUNE = tf.data.AUTOTUNE
MIN_VAR = 1e-6                    # patches whose pixel std is below this are skipped
MAX_PATCHES_PER_IMAGE = 6         # cap on training patches drawn from a single image
DATA_PATH = '/content/DATA'
img_path = '/content/DATA/1812.png'   # sample image shown in the preview cell

## Image degradation & load

In [None]:
def degrade_tf(img_uint8):
    """
    Applies random degradations (JPEG, brightness, contrast, noise) using TensorFlow.

    All intensity magnitudes are expressed relative to the full [0, 255] pixel
    range: brightness shifts by up to +/-15% of full scale and the additive
    Gaussian noise has a standard deviation of 3% of full scale.

    Args:
        img_uint8 (numpy.ndarray): Input image as a uint8 array; presumably
            (H, W, C) with 1 or 3 channels, as ``tf.image.random_jpeg_quality``
            requires.

    Returns:
        numpy.ndarray: Degraded image as a float32 array with values in [0, 255].
    """
    img = tf.convert_to_tensor(img_uint8, dtype=tf.uint8)
    # Random JPEG compression artifacts (quality 20-50); applied while still uint8.
    img = tf.image.random_jpeg_quality(img, 20, 50)
    img = tf.cast(img, tf.float32)
    # random_brightness adds its delta directly to the pixel values. Pixels are on a
    # 0-255 scale here, so the delta must be scaled as well: an unscaled 0.15 shifts
    # brightness by less than one gray level, i.e. no augmentation at all.
    img = tf.image.random_brightness(img, 0.15 * 255.0)
    img = tf.image.random_contrast(img, 0.8, 1.2)
    # Additive Gaussian noise, sigma = 3% of full scale.
    noise = tf.random.normal(tf.shape(img), stddev=0.03 * 255.0)
    img = img + noise
    # Clip back into the valid pixel range.
    img = tf.clip_by_value(img, 0.0, 255.0)
    return img.numpy().astype(np.float32)

def load_img(path):
    """
    Reads an image file from disk and decodes it into an RGB tensor.

    Args:
        path (str): Location of the image file.

    Returns:
        tf.Tensor: Decoded image with exactly 3 channels. Animated formats are
        not expanded into a frame dimension (``expand_animations=False``).
    """
    raw_bytes = tf.io.read_file(path)
    return tf.image.decode_image(raw_bytes, channels=3, expand_animations=False)

## Patch extraction/merge and padding/unpadding

In [None]:
def extract_patches(img, patch_size=224, step=112):
    """
    Splits an image into overlapping square patches.

    Windows start every ``step`` pixels along each axis. If the regular grid does
    not reach the bottom/right border, one extra window flush with that border is
    added, so every pixel is covered by at least one patch.

    Args:
        img (numpy.ndarray): Input image, (H, W, C) or (H, W).
        patch_size (int): Side length of each square patch.
        step (int): Stride between consecutive window starts.

    Returns:
        tuple: (Array of patches, List of (y, x) top-left coordinates,
        Original height, Original width).

    Raises:
        ValueError: If the image is smaller than ``patch_size`` in either
            dimension (pad it first), or if ``step`` is not positive.
    """
    H, W = img.shape[:2]
    if H < patch_size or W < patch_size:
        raise ValueError(
            f"Image of size {H}x{W} is smaller than patch_size={patch_size}; "
            "pad it first (e.g. with pad_to_square)."
        )
    if step < 1:
        raise ValueError(f"step must be a positive integer, got {step}")

    def _window_starts(length):
        """Start offsets of the windows along one axis of size `length`."""
        starts = list(range(0, length - patch_size + 1, step))
        # Add a final window aligned to the border if the grid stops short of it.
        if starts[-1] + patch_size < length:
            starts.append(length - patch_size)
        return starts

    y_poss = _window_starts(H)
    x_poss = _window_starts(W)

    positions = [(y, x) for y in y_poss for x in x_poss]
    patches = [img[y:y + patch_size, x:x + patch_size] for y, x in positions]

    return np.array(patches), positions, H, W

def get_weight_map(patch_size, eps=0.05):
    """
    Builds a separable 2D Hanning blending window for patch merging.

    The 1D Hanning taper is floored at ``eps`` so border pixels never get zero
    weight. Its outer product with itself forms the 2D window.

    Args:
        patch_size (int): Side length of the square window.
        eps (float): Lower bound applied to the 1D taper.

    Returns:
        numpy.ndarray: Window of shape (patch_size, patch_size, 1).
    """
    taper = np.clip(np.hanning(patch_size), eps, None)
    return np.outer(taper, taper)[:, :, np.newaxis]

def merge_patches(patches, positions, H, W, patch_size):
    """
    Reconstructs an image from patches using weighted averaging.

    Each patch is multiplied by the 2D Hanning window from ``get_weight_map``
    before being accumulated, and the sum is divided by the accumulated weights.
    Overlapping regions therefore blend smoothly instead of showing seams.

    Args:
        patches (numpy.ndarray): Array (or sequence) of (patch_size, patch_size, C)
            patches, expected on a 0-255 scale (the result is clipped to [0, 255]).
        positions (list): (y, x) top-left coordinate of each patch, in the same
            order as ``patches``.
        H (int): Target height.
        W (int): Target width.
        patch_size (int): Size of the patches.

    Returns:
        numpy.ndarray: Merged (H, W, C) image as uint8. Pixels not covered by any
        patch are 0.

    Raises:
        ValueError: If ``patches`` is empty or its length differs from
            ``positions``.
    """
    if len(patches) == 0:
        raise ValueError("merge_patches() requires at least one patch")
    if len(patches) != len(positions):
        # zip() would otherwise silently drop the surplus entries.
        raise ValueError(
            f"Got {len(patches)} patches but {len(positions)} positions"
        )

    C = patches[0].shape[2]
    output = np.zeros((H, W, C), dtype=np.float32)
    weight = np.zeros((H, W, 1), dtype=np.float32)
    weight_map = get_weight_map(patch_size)

    for patch, (y, x) in zip(patches, positions):
        output[y:y+patch_size, x:x+patch_size] += patch * weight_map
        weight[y:y+patch_size, x:x+patch_size] += weight_map

    # Floor the divisor so uncovered pixels (weight 0) do not divide by zero.
    output /= np.maximum(weight, 1e-6)
    # Round to nearest before the uint8 cast: astype() alone truncates, which
    # biases every pixel downward by about half a gray level.
    output = np.rint(np.clip(output, 0, 255))
    return output.astype(np.uint8)

def pad_to_square(img, patch_size, pad_mode="reflect"):
    """
    Pads an image into a square whose side is a multiple of ``patch_size``.

    The side is the longer image dimension rounded up to the next multiple of
    ``patch_size``. Padding is split as evenly as possible between the two sides
    of each axis; any odd pixel goes to the bottom/right.

    Args:
        img (numpy.ndarray): Image of shape (H, W, C).
        patch_size (int): The side length must be a multiple of this.
        pad_mode (str): Mode forwarded to ``numpy.pad``.

    Returns:
        tuple: (Padded image, dict with "top", "bottom", "left", "right"
        offsets and the "orig_shape" (H, W)).

    Raises:
        ValueError: If ``img`` is not 3-dimensional.
    """
    if img.ndim != 3:
        raise ValueError(f"Expected (H, W, C), got {img.shape}")

    height, width, _channels = img.shape
    longest = max(height, width)
    target = ((longest + patch_size - 1) // patch_size) * patch_size

    def _split(total):
        """Split `total` into (leading, trailing), giving the remainder to trailing."""
        leading = total // 2
        return leading, total - leading

    top, bottom = _split(target - height)
    left, right = _split(target - width)

    padded = np.pad(img, ((top, bottom), (left, right), (0, 0)), mode=pad_mode)
    return padded, {
        "top": top,
        "bottom": bottom,
        "left": left,
        "right": right,
        "orig_shape": (height, width),
    }

def unpad_image(img, pad_info):
    """
    Crops away the padding recorded by ``pad_to_square``.

    Args:
        img (numpy.ndarray): Padded image (or an output of the same size).
        pad_info (dict): Needs "top", "left" and "orig_shape" (H, W).

    Returns:
        numpy.ndarray: The (H, W, ...) region that held the original image.
    """
    row0 = pad_info["top"]
    orig_h, orig_w = pad_info["orig_shape"]
    col0 = pad_info["left"]
    rows = slice(row0, row0 + orig_h)
    cols = slice(col0, col0 + orig_w)
    return img[rows, cols]

## Dataset building

In [None]:
# Reproducible train/validation split. os.listdir() order is arbitrary, so list
# the files in sorted order and shuffle with a fixed, private RNG. The same
# images then land in the validation set on every kernel restart, keeping
# val_loss (and the checkpoints selected with it) comparable across runs.
SPLIT_SEED = 42
all_paths = sorted(
    os.path.join(DATA_PATH, f) for f in os.listdir(DATA_PATH)
    if f.lower().endswith(('.jpg', '.jpeg', '.png'))
)
if not all_paths:
    # Fail loudly here instead of silently building empty datasets.
    raise FileNotFoundError(f"No .jpg/.jpeg/.png images found in {DATA_PATH!r}")
random.Random(SPLIT_SEED).shuffle(all_paths)
split = int(0.8 * len(all_paths))  # 80% train / 20% validation
train_paths = all_paths[:split]
val_paths = all_paths[split:]

def patch_generator(paths, patch_size=PATCH_SIZE, step=PATCH_STRIDE,
                    max_patches_per_image=MAX_PATCHES_PER_IMAGE,
                    min_var=MIN_VAR, identity_prob=0.2):
    """
    Generator of training pairs (model input patch, clean target patch).

    Images are visited in a freshly shuffled order on every call. Each image is
    padded to a square with ``pad_to_square`` and tiled with ``extract_patches``.
    Up to ``max_patches_per_image`` randomly ordered patches whose pixel std is at
    least ``min_var`` are emitted. With probability ``identity_prob`` the input
    is the clean patch itself (presumably so the model also learns to leave
    clean content unchanged); otherwise the input is ``degrade_tf(patch)``.

    Yields:
        tuple: (input patch, target patch), both float32 scaled to [-1, 1].
    """
    visit_order = random.sample(paths, len(paths))
    for image_path in visit_order:
        bgr = cv2.imread(image_path)
        if bgr is None:
            continue  # unreadable file
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        square, _ = pad_to_square(rgb, patch_size=patch_size)
        tiles = extract_patches(square, patch_size=patch_size, step=step)[0]

        order = list(range(len(tiles)))
        random.shuffle(order)

        emitted = 0
        for idx in order:
            tile = tiles[idx]
            if tile.std() < min_var:
                continue  # skip flat / constant patches

            clean = tile.astype(np.float32) / 127.5 - 1.0
            source = tile.copy() if random.random() < identity_prob else degrade_tf(tile)
            noisy = (np.clip(source, 0, 255) / 127.5 - 1.0).astype(np.float32)
            yield noisy, clean

            emitted += 1
            if emitted >= max_patches_per_image:
                break

def val_patch_generator(paths, patch_size=PATCH_SIZE, step=PATCH_STRIDE, min_var=MIN_VAR):
    """
    Generator of validation pairs (degraded patch, clean target patch).

    Unlike the training generator, images are visited in the given order, every
    non-flat patch is used, and the input is always passed through ``degrade_tf``.

    Yields:
        tuple: (input patch, target patch), both float32 scaled to [-1, 1].
    """
    for image_path in paths:
        bgr = cv2.imread(image_path)
        if bgr is None:
            continue  # unreadable file
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        square, _ = pad_to_square(rgb, patch_size=patch_size)
        tiles = extract_patches(square, patch_size=patch_size, step=step)[0]

        informative = (tile for tile in tiles if tile.std() >= min_var)
        for tile in informative:
            clean = tile.astype(np.float32) / 127.5 - 1.0
            noisy = np.clip(degrade_tf(tile), 0, 255) / 127.5 - 1.0
            yield noisy.astype(np.float32), clean

# Build TF Datasets
# Both pipelines yield (degraded input, clean target) pairs scaled to [-1, 1].
_PATCH_PAIR_SIGNATURE = (
    tf.TensorSpec((PATCH_SIZE, PATCH_SIZE, 3), tf.float32),
    tf.TensorSpec((PATCH_SIZE, PATCH_SIZE, 3), tf.float32),
)

train_ds = tf.data.Dataset.from_generator(
    lambda: patch_generator(train_paths),
    output_signature=_PATCH_PAIR_SIGNATURE,
).shuffle(500).batch(BATCH_SIZE).prefetch(AUTOTUNE)

# val_patch_generator degrades every patch randomly. Without caching, each epoch
# would validate on a different noisy set, and val_loss (which drives
# ModelCheckpoint and ReduceLROnPlateau) would fluctuate for reasons unrelated to
# the model. Cache the first full pass so every epoch sees identical pairs. The
# cache goes to a temp file rather than RAM because validation patches can add
# up to several GB.
_val_cache_file = os.path.join(tempfile.mkdtemp(prefix="val_patches_"), "cache")
val_ds = tf.data.Dataset.from_generator(
    lambda: val_patch_generator(val_paths),
    output_signature=_PATCH_PAIR_SIGNATURE,
).cache(_val_cache_file).batch(BATCH_SIZE).prefetch(AUTOTUNE)


In [None]:
# Quick visual sanity check of one raw sample image.
preview = load_img(img_path)
plt.imshow(preview)

## Build Model

In [None]:
# Fixed-size patch input. The network pools 2x four times, so each spatial
# dimension must be divisible by 16 (PATCH_SIZE = 224 -> 14 at the bottleneck).
# Defined here because nothing earlier in the notebook defines `input_shape`.
input_shape = (PATCH_SIZE, PATCH_SIZE, 3)


def _conv_block(x, filters):
    """Apply two Conv3x3('same') -> BatchNorm -> ReLU stages with `filters` channels."""
    for _ in range(2):
        x = layers.Conv2D(filters, 3, padding='same')(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation('relu')(x)
    return x


# (filters, dropout rate) per resolution level, outermost first. The decoder
# mirrors the encoder by walking these levels in reverse.
_UNET_LEVELS = [(16, 0.2), (32, 0.3), (64, 0.4), (128, 0.5)]

inputs = layers.Input(shape=input_shape)

# --- Encoder (for PATCH_SIZE = 224): 224 -> 112 -> 56 -> 28 -> 14 ---
x = inputs
skips = []
for filters, rate in _UNET_LEVELS:
    x = _conv_block(x, filters)
    skips.append(x)  # pre-pooling features reused by the matching decoder level
    x = layers.MaxPool2D(2)(x)
    x = layers.Dropout(rate)(x)
skip1, skip2, skip3, skip4 = skips

# --- Bottleneck ---
x = _conv_block(x, 256)
x = layers.Dropout(0.5)(x)

# --- Decoder: 14 -> 28 -> 56 -> 112 -> 224, concatenating the skips deepest-first ---
for (filters, rate), skip in zip(reversed(_UNET_LEVELS), reversed(skips)):
    x = layers.UpSampling2D(2)(x)
    x = layers.Concatenate()([x, skip])
    x = _conv_block(x, filters)
    x = layers.Dropout(rate)(x)

# Final 1x1 projection to RGB; tanh matches targets normalized to [-1, 1].
outputs = layers.Conv2D(3, 1, activation='tanh', padding='same')(x)

# The compile/fit cells use `model`, so it must be constructed here.
model = tf.keras.Model(inputs=inputs, outputs=outputs)

## Fit

In [None]:
# Training phase 1: combined MAE + MSE loss

# ===============================================================
# LOSS FUNCTIONS AND EVALUATION METRICS
# ===============================================================

def mae_mse_loss(y_true, y_pred):
    """
    Weighted blend of Mean Absolute Error and Mean Squared Error.

    Returns 0.6 * MAE + 0.4 * MSE. The L1 part favours overall pixel accuracy,
    while the L2 part penalizes large individual errors more strongly.

    Args:
        y_true (tf.Tensor): Ground truth image patches.
        y_pred (tf.Tensor): Predicted (enhanced) image patches.

    Returns:
        tf.Tensor: Scalar loss value.
    """
    error = tf.subtract(y_true, y_pred)
    weighted_mae = 0.6 * tf.reduce_mean(tf.abs(error))
    weighted_mse = 0.4 * tf.reduce_mean(tf.square(error))
    return weighted_mae + weighted_mse

def psnr_metric(y_true, y_pred):
    """
    Peak Signal-to-Noise Ratio between target and prediction (higher is better).

    Args:
        y_true (tf.Tensor): Ground truth image patches.
        y_pred (tf.Tensor): Predicted image patches.

    Returns:
        tf.Tensor: PSNR in decibels.
    """
    dynamic_range = 2.0  # data is normalized to [-1, 1]
    return tf.image.psnr(y_true, y_pred, max_val=dynamic_range)

def ssim_metric(y_true, y_pred):
    """
    Structural Similarity Index between target and prediction.

    Compares local luminance, contrast and structure; 1 means identical.

    Args:
        y_true (tf.Tensor): Ground truth image patches.
        y_pred (tf.Tensor): Predicted image patches.

    Returns:
        tf.Tensor: SSIM value.
    """
    value_range = 2.0  # data is normalized to [-1, 1]
    return tf.image.ssim(y_true, y_pred, value_range)

# Phase-1 training. Expects `model` to have been built in the model cell above.

# AdamW: Adam with decoupled weight decay.
optimizer = tf.keras.optimizers.AdamW(learning_rate=1e-4, weight_decay=1e-5)

# Hybrid MAE/MSE loss plus image-quality metrics on the [-1, 1] data.
model.compile(
    optimizer=optimizer,
    loss=mae_mse_loss,
    metrics=[
        tf.keras.metrics.MeanSquaredError(name='mse'),
        psnr_metric,
        ssim_metric,
    ],
)

# ===============================================================
# TRAINING PIPELINE
# ===============================================================

# Save the checkpoint with the lowest val_loss seen so far.
best_checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath='best_model.keras',
    monitor='val_loss',
    save_best_only=True,
    verbose=1,
)
# Halve the learning rate after 3 epochs without val_loss improvement.
lr_on_plateau = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=3,
    min_lr=1e-7,
    verbose=1,
)
callbacks = [best_checkpoint, lr_on_plateau]

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=20,
    callbacks=callbacks,
)

In [None]:
# Training phase 2: fine-tuning with an edge-aware loss

# ===============================================================
# PERFORMANCE METRICS
# ===============================================================

def psnr_metric(y_true, y_pred):
    """
    Peak Signal-to-Noise Ratio for images normalized to [-1, 1].

    Args:
        y_true (tf.Tensor): Ground truth image patches.
        y_pred (tf.Tensor): Predicted image patches.

    Returns:
        tf.Tensor: PSNR in decibels, using a peak-to-peak range of 2.0.
    """
    return tf.image.psnr(y_true, y_pred, 2.0)


# ===============================================================
# PRIMARY LOSS FUNCTIONS
# ===============================================================

@tf.function
def mae_mse_loss(y_true, y_pred):
    """
    Base reconstruction loss: 0.6 * MAE + 0.4 * MSE.

    Args:
        y_true (tf.Tensor): Ground truth image patches.
        y_pred (tf.Tensor): Predicted image patches.

    Returns:
        tf.Tensor: Scalar weighted loss.
    """
    residual = y_true - y_pred
    mean_abs = tf.reduce_mean(tf.abs(residual))
    mean_sq = tf.reduce_mean(tf.square(residual))
    return 0.6 * mean_abs + 0.4 * mean_sq


# ===============================================================
# EDGE-AWARE LOSS COMPONENTS
# ===============================================================

@tf.function
def luminance(x):
    """
    Collapses an RGB tensor to one luma channel (BT.601 weights 0.299/0.587/0.114).

    Args:
        x (tf.Tensor): Tensor whose last axis holds R, G, B.

    Returns:
        tf.Tensor: Same leading shape with a last axis of size 1.
    """
    red, green, blue = x[..., 0:1], x[..., 1:2], x[..., 2:3]
    return 0.299 * red + 0.587 * green + 0.114 * blue


@tf.function
def edge_loss(y_true, y_pred):
    """
    L1 distance between Sobel edge maps of the target and the prediction.

    Both images are first average-pooled (4x4 window, stride 4) so the loss
    focuses on coarse structural edges rather than pixel-level noise. They are
    then reduced to luminance before the Sobel operator is applied.

    Args:
        y_true (tf.Tensor): Ground truth image patches.
        y_pred (tf.Tensor): Predicted image patches.

    Returns:
        tf.Tensor: Scalar mean absolute difference between the Sobel edge maps.
    """
    # Downsample to reduce noise sensitivity
    y_true_down = tf.nn.avg_pool2d(y_true, 4, 4, padding="SAME")
    y_pred_down = tf.nn.avg_pool2d(y_pred, 4, 4, padding="SAME")

    y_true_lum = luminance(y_true_down)
    y_pred_lum = luminance(y_pred_down)

    # The target is data, so stopping its gradient is harmless. The prediction
    # branch must NOT be stopped: wrapping it in tf.stop_gradient would cut all
    # gradient from this term, so it would only change the reported loss value
    # and never influence the weights.
    edge_true = tf.stop_gradient(tf.image.sobel_edges(y_true_lum))
    edge_pred = tf.image.sobel_edges(y_pred_lum)

    return tf.reduce_mean(tf.abs(edge_true - edge_pred))


# ===============================================================
# LOSS CONTROL PARAMETERS
# ===============================================================

# Non-trainable on/off switch for the edge term. EdgeScheduler flips it per
# batch, and total_loss reads it through tf.cond.
edge_enabled = tf.Variable(initial_value=False, trainable=False)
# Multiplier applied to edge_loss whenever the switch is on.
edge_weight = tf.constant(value=0.15, dtype=tf.float32)


# ===============================================================
# COMPOSITE LOSS CALCULATION
# ===============================================================

def total_loss(y_true, y_pred):
    """
    Composite loss: ``mae_mse_loss`` plus ``edge_weight * edge_loss``.

    The edge term is computed only while the ``edge_enabled`` variable is True;
    otherwise it contributes a constant 0.0.

    Args:
        y_true (tf.Tensor): Ground truth image patches.
        y_pred (tf.Tensor): Predicted image patches.

    Returns:
        tf.Tensor: Scalar combined loss.
    """
    def _edge_term():
        return edge_loss(y_true, y_pred)

    def _no_edge_term():
        return tf.constant(0.0, dtype=tf.float32)

    base_term = mae_mse_loss(y_true, y_pred)
    edge_term = tf.cond(edge_enabled, _edge_term, _no_edge_term)
    return base_term + edge_weight * edge_term


# ===============================================================
# TRAINING CALLBACKS
# ===============================================================

class EdgeScheduler(tf.keras.callbacks.Callback):
    """
    Keras Callback that toggles the edge-loss term on a fixed batch schedule.

    During training, the module-level ``edge_enabled`` flag is switched on for
    every ``every_n_steps``-th batch of each epoch (batch index 0, N, 2N, ...)
    and off otherwise. The Sobel edge loss is therefore computed only on a
    fraction of the batches.

    At the start of evaluation (including the validation pass inside ``fit``)
    and at the end of training, the flag is forced off. Without this, the flag
    would keep whatever value the last training batch set, and ``val_loss``
    (monitored by ModelCheckpoint / ReduceLROnPlateau) would include or exclude
    the edge term depending only on the number of training batches.

    Args:
        every_n_steps (int): Period of the schedule; must be >= 1.

    Raises:
        ValueError: If ``every_n_steps`` < 1 (0 would otherwise fail with
            ZeroDivisionError on the first training batch).
    """
    def __init__(self, every_n_steps=4):
        super().__init__()
        if every_n_steps < 1:
            raise ValueError(f"every_n_steps must be >= 1, got {every_n_steps}")
        self.every_n_steps = every_n_steps

    def on_train_batch_begin(self, batch, logs=None):
        """Enable the edge term only on every `every_n_steps`-th batch."""
        edge_enabled.assign(batch % self.every_n_steps == 0)

    def on_test_begin(self, logs=None):
        """Evaluate with the base loss only, so val_loss is well-defined."""
        edge_enabled.assign(False)

    def on_train_end(self, logs=None):
        """Leave the flag off so later evaluate()/fit() calls start clean."""
        edge_enabled.assign(False)


# ===============================================================
# MODEL COMPILATION
# ===============================================================

# Recompile for the edge-aware phase: a fresh Adam optimizer at a lower learning
# rate and the composite loss defined above.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=5e-5),
    loss=total_loss,
    metrics=[tf.keras.metrics.MeanSquaredError(name="mse"), psnr_metric],
    run_eagerly=False,
)

# NOTE(review): this ModelCheckpoint writes to the same "best_model.keras" as the
# first training cell. A new instance starts with no best-so-far value, so the
# first epoch here will overwrite the phase-1 checkpoint even if its val_loss
# is worse. Consider a separate filepath or `initial_value_threshold`.
best_checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath="best_model.keras",
    monitor="val_loss",
    save_best_only=True,
    verbose=1,
)
# Halve the learning rate after 3 epochs without val_loss improvement.
lr_on_plateau = tf.keras.callbacks.ReduceLROnPlateau(
    monitor="val_loss",
    factor=0.5,
    patience=3,
    verbose=1,
)
callbacks = [best_checkpoint, lr_on_plateau, EdgeScheduler(every_n_steps=4)]


# ===============================================================
# MODEL TRAINING
# ===============================================================

model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=10,
    callbacks=callbacks,
)