In [None]:
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import time
from sklearn.model_selection import train_test_split, KFold
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import (
    Callback,
    LearningRateScheduler,
    EarlyStopping,
    ModelCheckpoint
)
from tensorflow.keras.optimizers import Adam
from tqdm import tqdm
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
from tensorflow.keras.layers import (
    Layer,
    Input,
    Conv2D,
    MaxPooling2D,
    UpSampling2D,
    Concatenate,
    Add,
    Activation,
    GlobalAveragePooling2D,
    GlobalMaxPooling2D,
    Dense,
    Reshape,
    Lambda,
    BatchNormalization,
    Multiply
)
from tensorflow.keras import backend as K
import glob
import zipfile

# Decorate custom functions for serialization
@tf.keras.utils.register_keras_serializable()
def mean_axis_last(x):
    """Average `x` over its last axis, keeping that axis with length 1."""
    reduced = K.mean(x, keepdims=True, axis=-1)
    return reduced

@tf.keras.utils.register_keras_serializable()
def max_axis_last(x):
    """Take the maximum of `x` over its last axis, keeping that axis with length 1."""
    reduced = K.max(x, keepdims=True, axis=-1)
    return reduced

@tf.keras.utils.register_keras_serializable()
def batch_dot_axes(x):
    """Batch-wise dot product of `x[0]` and `x[1]`, contracting axis 2 of both."""
    left = x[0]
    right = x[1]
    return K.batch_dot(left, right, axes=[2, 2])

@tf.keras.utils.register_keras_serializable()
def batch_dot(x):
    """Batch-wise dot product of `x[0]` and `x[1]` using the backend's default axes."""
    left = x[0]
    right = x[1]
    return K.batch_dot(left, right)

In [None]:
import os
import time
import numpy as np
import matplotlib.pyplot as plt
import glob
import cv2
import scipy.io
import zipfile
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (Input, Conv2D, MaxPooling2D, UpSampling2D, Concatenate, Add, Activation,
                                     Multiply, GlobalAveragePooling2D, Reshape, Dense, Lambda)
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import (EarlyStopping, ModelCheckpoint, LearningRateScheduler, Callback)
from tensorflow.keras.utils import plot_model
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

In [None]:
import os
import cv2
import numpy as np
import glob
import scipy.io

# Function to load images from directories
def load_images(noisy_dir, clean_dir, image_size=(128, 128)):
    """Load paired noisy/clean images from two directories.

    Both directory listings are sorted and then paired by position, so the
    two directories are expected to contain the same file names. Every pair
    that OpenCV can read is resized to `image_size` and scaled to [0, 1].

    Args:
        noisy_dir: Directory holding the noisy input images.
        clean_dir: Directory holding the matching ground-truth images.
        image_size: Target size handed to `cv2.resize` as `dsize`.

    Returns:
        Tuple `(noisy_images, clean_images)` of NumPy arrays with one entry
        per successfully loaded pair.
    """
    noisy_images = []
    clean_images = []

    # Sort files to ensure alignment
    noisy_files = sorted(os.listdir(noisy_dir))
    clean_files = sorted(os.listdir(clean_dir))

    # zip() below stops at the shorter listing without any notice; with
    # different counts the positional pairing is very likely wrong, so make
    # the situation visible instead of silently truncating.
    if len(noisy_files) != len(clean_files):
        print(
            f"Warning: File count mismatch between {noisy_dir} ({len(noisy_files)} files) "
            f"and {clean_dir} ({len(clean_files)} files); pairs may be misaligned."
        )

    for noisy_file, clean_file in zip(noisy_files, clean_files):
        noisy_path = os.path.join(noisy_dir, noisy_file)
        clean_path = os.path.join(clean_dir, clean_file)

        # Read images (cv2.imread returns None for unreadable files)
        noisy_image = cv2.imread(noisy_path)
        clean_image = cv2.imread(clean_path)

        if noisy_image is None or clean_image is None:
            print(f"Warning: Skipping unmatched or unreadable file pair: {noisy_file}, {clean_file}")
            continue

        # Resize images to target size
        noisy_image = cv2.resize(noisy_image, image_size)
        clean_image = cv2.resize(clean_image, image_size)

        # Normalize images to [0, 1]
        noisy_images.append(noisy_image / 255.0)
        clean_images.append(clean_image / 255.0)

    return np.array(noisy_images), np.array(clean_images)

# Function to load test data from .mat files
def load_data_from_mat(noisy_path, clean_path):
    """Read the noisy / ground-truth block arrays from two .mat files.

    The arrays are looked up under the keys 'ValidationNoisyBlocksSrgb' and
    'ValidationGtBlocksSrgb', divided by 255.0, and their first two axes are
    merged into one so the result is a flat stack of blocks.

    Args:
        noisy_path: Path of the .mat file with the noisy blocks.
        clean_path: Path of the .mat file with the ground-truth blocks.

    Returns:
        Tuple `(noisy_images, clean_images)` of NumPy arrays.

    Raises:
        ValueError: If either expected key is absent from its file.
    """
    noisy_blocks = scipy.io.loadmat(noisy_path).get("ValidationNoisyBlocksSrgb", None)
    clean_blocks = scipy.io.loadmat(clean_path).get("ValidationGtBlocksSrgb", None)

    if noisy_blocks is None or clean_blocks is None:
        raise ValueError("Invalid keys in .mat files. Ensure keys are 'ValidationNoisyBlocksSrgb' and 'ValidationGtBlocksSrgb'.")

    def _scale_and_flatten(blocks):
        # Scale to [0, 1] (assuming 8-bit input — TODO confirm), then collapse
        # the two leading axes into a single image axis.
        scaled = blocks / 255.0
        return scaled.reshape((-1, *scaled.shape[2:]))

    return _scale_and_flatten(noisy_blocks), _scale_and_flatten(clean_blocks)


# ------------------------------------------------------------------------------
# Paths
# NOTE: the original literals started with a blank (" /Final_DN_Traning/..."),
# which makes them relative paths into a directory literally named " ".
# They are written as absolute paths here, matching the dataset paths used
# further down in this notebook.
train_noisy_dir = "/Final_DN_Traning/DSENet/dataset/SSID_new/train/SIDD/input_crops"
train_clean_dir = "/Final_DN_Traning/DSENet/dataset/SSID_new/train/SIDD/target_crops"
val_noisy_dir = "/Final_DN_Traning/DSENet/dataset/SSID_new/val/SIDD/input_crops"
val_clean_dir = "/Final_DN_Traning/DSENet/dataset/SSID_new/val/SIDD/target_crops"
test_noisy_path = "/Final_DN_Traning/DSENet/dataset/SSID_new/test/SIDD/ValidationNoisyBlocksSrgb.mat"
test_clean_path = "/Final_DN_Traning/DSENet/dataset/SSID_new/test/SIDD/ValidationGtBlocksSrgb.mat"

# Load training and validation datasets
print("Loading training data...")
X_train, y_train = load_images(train_noisy_dir, train_clean_dir)

print("Loading validation data...")
X_val, y_val = load_images(val_noisy_dir, val_clean_dir)

# Print dataset shapes
print(f"Training Data: Noisy {X_train.shape}, Clean {y_train.shape}")
print(f"Validation Data: Noisy {X_val.shape}, Clean {y_val.shape}")

print("Loading test data...")
# Load test data from .mat files
X_test, y_test = load_data_from_mat(test_noisy_path, test_clean_path)

# Print dataset shapes for verification
print(f"Training Data (Original): Noisy {X_train.shape}, Clean {y_train.shape}")
print(f"Validation Data: Noisy {X_val.shape}, Clean {y_val.shape}")
print(f"Test Data: Noisy {X_test.shape}, Clean {y_test.shape}")

In [None]:
import tensorflow as tf

# gpus = tf.config.experimental.list_physical_devices('GPU')
# if gpus:
#     try:
#         for gpu in gpus:
#             tf.config.experimental.set_memory_growth(gpu, True)
#     except RuntimeError as e:
#         print(e)
# import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

# Run your model again
# import tensorflow as tf

# gpus = tf.config.experimental.list_physical_devices('GPU')
# if gpus:
#     try:
#         for gpu in gpus:
#             tf.config.experimental.set_memory_growth(gpu, True)
#     except RuntimeError as e:
#         print(e)


# TQDM Progress Bar Callback
class TQDMProgressBar(Callback):
    """Keras callback that draws one tqdm progress bar per training epoch.

    The bar advances by one on every batch; its postfix (loss, PSNR, SSIM) is
    refreshed every `steps_per_update` batches. When an epoch ends the bar is
    closed and a one-line summary with the epoch duration is printed.

    Args:
        total_steps: Number of batches per epoch, used as the bar's total.
        update_interval: Fraction of an epoch between postfix refreshes.
    """

    def __init__(self, total_steps, update_interval=0.1):
        super().__init__()
        self.total_steps = total_steps
        self.update_interval = update_interval
        # Never let the refresh period drop to 0, which would break the modulo below.
        self.steps_per_update = max(1, int(self.total_steps * self.update_interval))
        self.total_epochs = None

    def on_train_begin(self, logs=None):
        """Remember the total number of epochs for the bar description."""
        self.total_epochs = self.params['epochs']

    def on_epoch_begin(self, epoch, logs=None):
        """Open a fresh bar and start the epoch timer."""
        self.epoch = epoch
        self.tqdm = tqdm(total=self.total_steps, desc=f'Epoch {epoch+1}/{self.total_epochs}', position=0, leave=True)
        self.start_time = time.time()

    def on_batch_end(self, batch, logs=None):
        """Advance the bar and periodically refresh the displayed metrics."""
        # `logs` defaults to None; guard it so `.get` cannot fail.
        logs = logs or {}
        self.tqdm.update(1)
        if batch % self.steps_per_update == 0:
            loss = logs.get('loss')
            psnr = logs.get('psnr_metric')
            ssim = logs.get('ssim_metric')
            self.tqdm.set_postfix(loss=loss, PSNR=psnr, SSIM=ssim)

    def on_epoch_end(self, epoch, logs=None):
        """Close the bar and print the epoch summary."""
        logs = logs or {}
        self.tqdm.close()
        epoch_time = time.time() - self.start_time
        # Missing metrics fall back to 0 so the format specifiers stay valid.
        loss = logs.get('loss', 0)
        psnr = logs.get('psnr_metric', 0)
        ssim = logs.get('ssim_metric', 0)
        print(f"Epoch {epoch+1} completed in {epoch_time:.2f} seconds - loss: {loss:.4f}, PSNR: {psnr:.4f}, SSIM: {ssim:.4f}")

# Custom Callback for logging
class CustomCallback(Callback):
    """Print the training loss, PSNR and SSIM reported for each finished epoch."""

    def on_epoch_end(self, epoch, logs=None):
        """Emit two summary lines built from the `logs` entries of the epoch."""
        loss_line = f"Average loss in this epoch is: {logs['loss']:.4f}"
        print(loss_line)
        metrics_line = f"PSNR: {logs['psnr_metric']:.4f}, SSIM: {logs['ssim_metric']:.4f}"
        print(metrics_line)

# Function to load images
def load_images(noisy_dir, clean_dir, image_size=(128, 128), max_images=20000):
    """Load up to `max_images` paired noisy/clean images from two directories.

    Files matching `*.*` in each directory are sorted and paired by position,
    so both directories are expected to contain the same file names. Every
    pair that OpenCV can read is resized to `image_size` and scaled to [0, 1].

    Args:
        noisy_dir: Directory holding the noisy input images.
        clean_dir: Directory holding the matching ground-truth images.
        image_size: Target size handed to `cv2.resize` as `dsize`.
        max_images: Upper bound on the number of file pairs considered.

    Returns:
        Tuple `(noisy_images, clean_images)` of NumPy arrays with one entry
        per successfully loaded pair.
    """
    noisy_images = []
    clean_images = []
    noisy_pattern = os.path.join(noisy_dir, '*.*')
    clean_pattern = os.path.join(clean_dir, '*.*')
    all_noisy_files = sorted(glob.glob(noisy_pattern))
    all_clean_files = sorted(glob.glob(clean_pattern))

    # Compare the counts before truncating: once both lists are cut to
    # `max_images` a mismatch would be hidden, and zip() below would pair
    # files purely by position without any notice.
    if len(all_noisy_files) != len(all_clean_files):
        print(
            f"Warning: File count mismatch between {noisy_dir} ({len(all_noisy_files)} files) "
            f"and {clean_dir} ({len(all_clean_files)} files); pairs may be misaligned."
        )

    noisy_files = all_noisy_files[:max_images]
    clean_files = all_clean_files[:max_images]

    for noisy_path, clean_path in zip(noisy_files, clean_files):
        # cv2.imread returns None when a file cannot be decoded.
        noisy_image = cv2.imread(noisy_path)
        if noisy_image is None:
            print(f"Warning: Unable to read noisy image file {noisy_path}")
            continue

        clean_image = cv2.imread(clean_path)
        if clean_image is None:
            print(f"Warning: Unable to read clean image file {clean_path}")
            continue

        noisy_image = cv2.resize(noisy_image, image_size)
        clean_image = cv2.resize(clean_image, image_size)

        # Scale pixel values to [0, 1]
        noisy_images.append(noisy_image / 255.0)
        clean_images.append(clean_image / 255.0)

    return np.array(noisy_images), np.array(clean_images)

from tensorflow.keras.layers import (
    Input, GlobalAveragePooling2D, Reshape, Dense, Multiply,
    Conv2D, Add, Activation, Lambda, MaxPooling2D, UpSampling2D,
    Concatenate, LayerNormalization, Dropout
)
from tensorflow.keras.models import Model
from tensorflow.keras.initializers import lecun_normal
import tensorflow.keras.backend as K
import tensorflow as tf

# Squeeze-and-Excitation Block
def se_block(input_tensor, ratio=16):
    """Squeeze-and-Excitation: rescale the channels of `input_tensor`.

    A global average pool produces one value per channel, two bias-free Dense
    layers (ReLU bottleneck of `channels // ratio` units, then sigmoid back to
    `channels`) turn it into per-channel gates, and the input is multiplied
    by those gates.

    Args:
        input_tensor: Feature map whose last axis is the channel axis.
        ratio: Reduction factor of the bottleneck Dense layer.

    Returns:
        The gated feature map.
    """
    num_channels = input_tensor.shape[-1]
    bottleneck_units = max(1, num_channels // ratio)  # never fewer than one unit

    squeezed = GlobalAveragePooling2D()(input_tensor)
    squeezed = Reshape((1, 1, num_channels))(squeezed)
    gates = Dense(bottleneck_units, activation='relu', kernel_initializer=lecun_normal(), use_bias=False)(squeezed)
    gates = Dense(num_channels, activation='sigmoid', kernel_initializer=lecun_normal(), use_bias=False)(gates)
    return Multiply()([input_tensor, gates])

# Attention-SE Block Hybrid
def attention_se_block(x, g, inter_channel):
    """Gate the SE-refined features `x` with a spatial map derived from `x` and `g`.

    `x` first passes through `se_block`. Both tensors are then projected to
    `inter_channel` channels with 1x1 convolutions, summed, passed through
    ReLU, reduced to a single channel and squashed with a sigmoid. The
    resulting map multiplies the SE-refined `x`.

    Args:
        x: Features to be gated.
        g: Gating features; must be addable to `x` after the 1x1 projections.
        inter_channel: Channel count of the intermediate projections.

    Returns:
        The gated features.
    """
    refined = se_block(x)

    projected_x = Conv2D(inter_channel, 1, padding='same', kernel_initializer=lecun_normal())(refined)
    projected_g = Conv2D(inter_channel, 1, padding='same', kernel_initializer=lecun_normal())(g)
    combined = Activation('relu')(Add()([projected_x, projected_g]))

    attention_logits = Conv2D(1, 1, padding='same', kernel_initializer=lecun_normal())(combined)
    attention_map = Activation('sigmoid')(attention_logits)
    return Multiply()([attention_map, refined])

# Self-Attention-SE Block
def self_attention_se_block(inputs):
    """Apply `se_block`, then a residual self-attention over all spatial positions.

    Key and query use `channels // 8` channels (at least one), the value
    keeps the full channel count. The attention output is reshaped back to
    the feature-map shape and added to the SE-refined input.

    Args:
        inputs: Feature map with statically known spatial size and channels.

    Returns:
        Feature map of the same shape as `inputs`.
    """
    refined = se_block(inputs)

    feature_shape = K.int_shape(refined)
    channels = feature_shape[-1]
    reduced_channels = max(1, channels // 8)

    def _pointwise(out_channels):
        # 1x1 convolution applied to the SE-refined features.
        return Conv2D(out_channels, kernel_size=1, padding='same', kernel_initializer=lecun_normal())(refined)

    key = _pointwise(reduced_channels)
    query = _pointwise(reduced_channels)
    value = _pointwise(channels)

    # Flatten the spatial grid into a sequence of positions.
    key_seq = Reshape((-1, reduced_channels))(key)
    query_seq = Reshape((-1, reduced_channels))(query)
    value_seq = Reshape((-1, channels))(value)

    # Position-to-position scores, normalised over the key positions.
    scores = Lambda(lambda x: tf.matmul(x[0], x[1], transpose_b=True))([query_seq, key_seq])
    weights = Activation('softmax')(scores)

    attended = Lambda(lambda x: tf.matmul(x[0], x[1]))([weights, value_seq])
    attended = Reshape(feature_shape[1:])(attended)

    return Add()([attended, refined])

# Multi-scale block (simplified)
def multi_scale_block(inputs, filters):
    """Run parallel 3x3 and 5x5 ReLU convolutions and concatenate their outputs.

    Each branch produces `filters` channels, so the result has `2 * filters`.
    """
    branches = []
    for size in (3, 5):
        branch = Conv2D(filters, kernel_size=size, padding='same', activation='relu', kernel_initializer=lecun_normal())(inputs)
        branches.append(branch)
    return Concatenate()(branches)

# Residual multi-scale conv block (ReLU activation)
def conv_block(inputs, filters, kernel_size=3, padding='same'):
    """Multi-scale convolution followed by a residual connection and ReLU.

    The block outputs `2 * filters` channels. When the input has a different
    channel count, the shortcut is projected with a 1x1 convolution first.

    Args:
        inputs: Input feature map.
        filters: Per-branch filter count of the multi-scale block.
        kernel_size: Kernel size of the convolution after the multi-scale block.
        padding: Padding mode of that convolution.

    Returns:
        Feature map with `2 * filters` channels.
    """
    out_channels = filters * 2
    features = multi_scale_block(inputs, filters)

    if inputs.shape[-1] == out_channels:
        residual = inputs
    else:
        residual = Conv2D(out_channels, kernel_size=1, padding='same', kernel_initializer=lecun_normal())(inputs)

    features = Conv2D(out_channels, kernel_size, padding=padding, kernel_initializer=lecun_normal())(features)
    merged = Add()([features, residual])
    return Activation('relu')(merged)




# Compute filters for PyramidNet with linear increment
def compute_pyramid_filters(num_blocks, initial_filters, increment):
    """Return `num_blocks` filter counts growing linearly from `initial_filters`.

    Entry `i` equals `initial_filters + i * increment`.
    """
    widths = []
    for block_index in range(num_blocks):
        widths.append(initial_filters + block_index * increment)
    return widths

class SwinTransformerBlock(tf.keras.layers.Layer):
    """Container for the sub-layers of a simplified transformer block.

    The constructor stores its hyper-parameters and creates two layer norms,
    a multi-head attention layer, a projection Dense, a single-hidden-layer
    MLP and a dropout layer.

    NOTE(review): this class defines no `call()` in the visible source, so
    none of the sub-layers created here is applied to the input by this
    class itself — confirm where (or whether) the forward pass is implemented.
    NOTE(review): callers pass `dim=filters` while the tensors they feed in
    come from `conv_block`, which outputs `2 * filters` channels — confirm
    the intended `dim` before adding a forward pass.

    Args:
        dim: Feature dimension used for attention and projections.
        num_heads: Number of attention heads.
        window_size: Stored only; not used by any sub-layer created here.
        mlp_ratio: Hidden width of the MLP relative to `dim`.
        dropout: Dropout rate.
        **kwargs: Forwarded to `tf.keras.layers.Layer`.
    """

    def __init__(self, dim, num_heads, window_size=7, mlp_ratio=4., dropout=0.0, **kwargs):
        super().__init__(**kwargs)

        # Hyper-parameters
        self.dim, self.num_heads = dim, num_heads
        self.window_size, self.mlp_ratio = window_size, mlp_ratio
        self.dropout = dropout

        # Attention branch
        self.layernorm1 = LayerNormalization(epsilon=1e-6)
        self.attention = tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=dim)
        self.projection = Dense(dim)  # maps the attention output back to `dim`

        # Feed-forward branch: one hidden ReLU layer, then back to `dim`
        hidden_units = int(dim * mlp_ratio)
        self.layernorm2 = LayerNormalization(epsilon=1e-6)
        self.mlp_dense = Dense(hidden_units, activation='relu')
        self.mlp_output_projection = Dense(dim)

        self.dropout_layer = Dropout(rate=dropout)
        

def decoder_block(inputs, skip_features, filters, num_heads=8, window_size=7):
    """One decoder stage: upsample, gate the skip connection, merge and refine.

    Args:
        inputs: Features from the previous (coarser) stage.
        skip_features: Encoder features at the target resolution.
        filters: Filter count passed to the attention gate, `conv_block` and
            the transformer block.
        num_heads: Attention heads of the transformer block.
        window_size: Window size passed to the transformer block.

    Returns:
        The refined decoder features.
    """
    # Double the spatial resolution of the incoming features.
    upsampled = UpSampling2D(size=(2, 2))(inputs)

    # Gate the skip connection with the upsampled features.
    gated_skip = attention_se_block(skip_features, upsampled, filters)

    # Merge both paths and refine them.
    merged = Concatenate()([upsampled, gated_skip])
    refined = conv_block(merged, filters)

    swin = SwinTransformerBlock(dim=filters, num_heads=num_heads, window_size=window_size)
    return swin(refined)
    
# Encoder Block with Swin Transformer Integration
def encoder_block_with_swin(inputs, filters, num_heads=8, window_size=7):
    """One encoder stage: conv block, SE block, transformer block, 2x2 max-pool.

    Args:
        inputs: Input feature map.
        filters: Filter count passed to `conv_block` and the transformer block.
        num_heads: Attention heads of the transformer block.
        window_size: Window size passed to the transformer block.

    Returns:
        Tuple `(features, pooled)`: the full-resolution features (used as the
        skip connection) and their max-pooled version for the next stage.
    """
    features = se_block(conv_block(inputs, filters))

    swin = SwinTransformerBlock(dim=filters, num_heads=num_heads, window_size=window_size)
    features = swin(features)

    pooled = MaxPooling2D(pool_size=(2, 2))(features)
    return features, pooled

# Updated U-Net with Swin Transformer
def unet_pyramid_model(input_shape=(128, 128, 3)):
    """Build the U-Net style denoising model with linearly growing filter counts.

    Four encoder stages (32, 64, 96, 128 filters) feed a 160-filter
    bottleneck with self-attention; four decoder stages mirror the encoder
    (128, 96, 64, 32) and consume the encoder outputs as skip connections.
    A 1x1 sigmoid convolution maps the result to 3 output channels.

    Args:
        input_shape: Shape of one input image (without the batch axis).

    Returns:
        The uncompiled Keras `Model`.
    """
    inputs = Input(input_shape)

    num_encoder_blocks = 4  # encoder depth
    initial_filters = 32    # width of the first stage
    increment = 32          # width added per stage

    # One width per encoder stage plus a final one for the bottleneck.
    widths = compute_pyramid_filters(num_encoder_blocks + 1, initial_filters, increment)
    encoder_widths, bottleneck_width = widths[:-1], widths[-1]

    # Encoder: keep every stage's full-resolution output as a skip connection.
    skips = []
    x = inputs
    for width in encoder_widths:
        skip, x = encoder_block_with_swin(x, width)
        skips.append(skip)

    # Bottleneck
    x = conv_block(x, bottleneck_width)
    x = self_attention_se_block(x)

    # Decoder: walk the encoder widths and skips from deepest to shallowest.
    for width, skip in zip(encoder_widths[::-1], skips[::-1]):
        x = decoder_block(x, skip, width)

    # Output
    outputs = Conv2D(3, 1, padding='same', activation='sigmoid')(x)

    return Model(inputs, outputs)


def lr_schedule(epoch, lr):
    """Learning-rate schedule: constant at first, then exponential decay.

    The rate is returned unchanged while `epoch <= 10`; for every later epoch
    it is multiplied by exp(-0.1).

    The decay factor is computed with the standard library rather than a
    TensorFlow op: this is a plain Python scalar computation that runs once
    per epoch, and it avoids a float32 round trip through a tensor.

    Args:
        epoch: Zero-based epoch index supplied by `LearningRateScheduler`.
        lr: Current learning rate.

    Returns:
        The learning rate to use for this epoch, as a Python float.
    """
    import math

    lr = float(lr)
    if epoch > 10:
        lr *= math.exp(-0.1)
    return lr

# Define PSNR Metric
def psnr_metric(y_true, y_pred):
    """PSNR between `y_true` and `y_pred` for images with a maximum value of 1.0."""
    peak_value = 1.0
    return tf.image.psnr(y_true, y_pred, max_val=peak_value)

# Define SSIM Metric
def ssim_metric(y_true, y_pred):
    """SSIM between `y_true` and `y_pred` for images with a maximum value of 1.0."""
    peak_value = 1.0
    return tf.image.ssim(y_true, y_pred, max_val=peak_value)

def compute_metrics(original_images, denoised_images):
    """Compute per-image PSNR and SSIM with scikit-image.

    Both metrics use `data_range=1.0`; SSIM treats the last axis as the
    channel axis and uses a 3x3 window.

    Args:
        original_images: Sequence of reference images.
        denoised_images: Sequence of images to evaluate, indexed like
            `original_images`.

    Returns:
        Tuple `(psnr_values, ssim_values)` of lists with one entry per image
        in `original_images`.
    """
    psnr_values, ssim_values = [], []

    for index, reference in enumerate(original_images):
        candidate = denoised_images[index]
        psnr_values.append(
            peak_signal_noise_ratio(reference, candidate, data_range=1.0)
        )
        ssim_values.append(
            structural_similarity(reference, candidate, channel_axis=-1, data_range=1.0, win_size=3)
        )

    return psnr_values, ssim_values

# Learning Rate Logger Callback
class LearningRateLogger(Callback):
    """Record and print the optimizer's learning rate after every epoch.

    The value is stored under `logs['lr']`.
    """

    def on_epoch_end(self, epoch, logs=None):
        """Resolve the current learning rate, store it in `logs` and print it."""
        if not logs:
            logs = {}

        optimizer = self.model.optimizer
        rate = optimizer.learning_rate

        # A schedule object has to be evaluated at the current step count;
        # anything else is used as it is.
        if isinstance(rate, tf.keras.optimizers.schedules.LearningRateSchedule):
            rate = rate(optimizer.iterations)

        # Fetch the concrete numeric value.
        lr = K.get_value(rate)
        logs['lr'] = lr
        print(f"Current learning rate: {lr}")

# Instantiate the learning rate logger
lr_logger = LearningRateLogger()

# Function to plot training history
def plot_training_history(history):
    """Plot training/validation loss and, when recorded, the learning rate.

    Args:
        history: Object whose `history` attribute is a dict containing the
            per-epoch lists 'loss' and 'val_loss', and optionally 'lr'.
    """
    record = history.history
    plt.figure(figsize=(12, 5))
    epochs = range(1, len(record['loss']) + 1)

    # Left panel: loss curves
    plt.subplot(1, 2, 1)
    for key, fmt, label in (('loss', 'bo-', 'Training Loss'), ('val_loss', 'ro-', 'Validation Loss')):
        plt.plot(epochs, record[key], fmt, label=label)
    plt.title('Training and Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()

    # Right panel: learning rate, only when it was logged
    if 'lr' in record:
        plt.subplot(1, 2, 2)
        plt.plot(epochs, record['lr'], 'go-', label='Learning Rate')
        plt.title('Learning Rate over Epochs')
        plt.xlabel('Epoch')
        plt.ylabel('Learning Rate')
        plt.legend()

    plt.tight_layout()
    plt.show()

# Function to plot PSNR and SSIM metrics
def plot_psnr_ssim(psnr_values, ssim_values):
    """Plot per-image PSNR and SSIM values in two side-by-side panels.

    Args:
        psnr_values: Sequence of PSNR values, one per image.
        ssim_values: Sequence of SSIM values, indexed like `psnr_values`.
    """
    indices = range(len(psnr_values))

    plt.figure(figsize=(12, 5))

    # (panel position, values, line format, metric name)
    panels = (
        (1, psnr_values, 'bo-', 'PSNR'),
        (2, ssim_values, 'ro-', 'SSIM'),
    )
    for position, values, fmt, metric in panels:
        plt.subplot(1, 2, position)
        plt.plot(indices, values, fmt, label=metric)
        plt.title(metric + ' Values on Test Set')
        plt.xlabel('Image Index')
        plt.ylabel(metric)
        plt.legend()

    plt.tight_layout()
    plt.show()

# Custom Loss Functions converted from PyTorch to TensorFlow
def gaussian(window_size, sigma):
    """Return a 1-D Gaussian window of length `window_size` that sums to 1.

    The window is centred on index `window_size // 2`.
    """
    offsets = tf.range(window_size, dtype=tf.float32) - window_size // 2
    weights = tf.exp(-tf.square(offsets) / (2 * sigma**2))
    return weights / tf.reduce_sum(weights)

def get_gaussian_kernel(ksize, sigma):
    """Validate `ksize` and return the 1-D Gaussian window from `gaussian`.

    Raises:
        TypeError: If `ksize` is not an odd positive integer.
    """
    size_is_valid = isinstance(ksize, int) and ksize % 2 != 0 and ksize > 0
    if size_is_valid:
        return gaussian(ksize, sigma)
    raise TypeError(f"ksize must be an odd positive integer. Got {ksize}")

def get_gaussian_kernel2d(ksize, sigma):
    """Build a 2-D Gaussian kernel as the outer product of two 1-D windows.

    Args:
        ksize: Tuple of two window lengths, one per axis.
        sigma: Tuple of two standard deviations, one per axis.

    Returns:
        Tensor of shape `(ksize[0], ksize[1])`.

    Raises:
        TypeError: If `ksize` or `sigma` is not a tuple of length two.
    """
    def _is_pair(value):
        return isinstance(value, tuple) and len(value) == 2

    if not _is_pair(ksize):
        raise TypeError(f"ksize must be a tuple of length two. Got {ksize}")
    if not _is_pair(sigma):
        raise TypeError(f"sigma must be a tuple of length two. Got {sigma}")

    size_x, size_y = ksize
    sigma_x, sigma_y = sigma
    window_x = get_gaussian_kernel(size_x, sigma_x)
    window_y = get_gaussian_kernel(size_y, sigma_y)

    # axes=0 makes tensordot compute the outer product of the two windows.
    return tf.tensordot(window_x, window_y, axes=0)

# PSNR Loss
class PSNRLoss(tf.keras.losses.Loss):
    """Loss equal to the scaled mean log-MSE, i.e. the negative PSNR.

    For images in [0, 1], PSNR = -10 * log10(MSE). This loss returns
    `loss_weight * 10 * log10(MSE)` averaged over the batch, so minimising it
    maximises PSNR.

    The earlier version carried an additional minus sign, which made the loss
    equal to +PSNR: it grew as the prediction improved, so gradient descent
    would have pushed the error up.

    Args:
        loss_weight: Multiplier applied to the loss.
        toY: When True, both images are first reduced to a single channel by
            a weighted channel sum (presumably a luma conversion — confirm
            the channel order expected by `coef`).
    """

    def __init__(self, loss_weight=1.0, toY=False):
        super(PSNRLoss, self).__init__()
        self.loss_weight = loss_weight
        # Converts a natural logarithm into 10 * log10.
        self.scale = 10 / tf.math.log(10.0)
        self.toY = toY
        # Per-channel weights, broadcast over (batch, height, width, channel).
        self.coef = tf.constant([65.481, 128.553, 24.966], shape=(1, 1, 1, 3))

    def call(self, y_true, y_pred):
        """Return the batch-averaged scaled log-MSE between target and prediction."""
        if self.toY:
            y_true = tf.reduce_sum(y_true * self.coef, axis=-1, keepdims=True) + 16
            y_pred = tf.reduce_sum(y_pred * self.coef, axis=-1, keepdims=True) + 16

            y_true = y_true / 255.0
            y_pred = y_pred / 255.0

        # Per-image MSE; the epsilon keeps the logarithm finite for a perfect match.
        mse_loss = tf.reduce_mean(tf.square(y_pred - y_true), axis=[1, 2, 3])
        loss = self.loss_weight * self.scale * tf.reduce_mean(tf.math.log(mse_loss + 1e-8))
        return loss

# SSIM Loss
class SSIMLoss(tf.keras.losses.Loss):
    """Structural dissimilarity loss: mean of `(1 - SSIM) / 2`.

    Local statistics are computed with a Gaussian window (sigma 1.5) applied
    to every channel independently.

    Args:
        window_size: Side length of the Gaussian window.
        max_val: Dynamic range of the images, used for the constants C1 and C2.
    """

    def __init__(self, window_size=11, max_val=1.0):
        super().__init__()
        self.window_size = window_size
        self.max_val = max_val
        self.padding = (window_size - 1) // 2
        self.window = get_gaussian_kernel2d((window_size, window_size), (1.5, 1.5))

        # Stabilising constants of the SSIM formula.
        self.C1 = (0.01 * self.max_val) ** 2
        self.C2 = (0.03 * self.max_val) ** 2

    def filter2D(self, img, kernel):
        """Filter every channel of `img` with the same 2-D `kernel` ('SAME' padding)."""
        channels = img.shape[-1]
        depthwise_kernel = tf.tile(kernel[:, :, tf.newaxis, tf.newaxis], [1, 1, channels, 1])
        return tf.nn.depthwise_conv2d(img, depthwise_kernel, strides=[1, 1, 1, 1], padding='SAME')

    def call(self, y_true, y_pred):
        """Return the mean structural dissimilarity between target and prediction."""
        window = self.window

        # Local means
        mean_true = self.filter2D(y_true, window)
        mean_pred = self.filter2D(y_pred, window)
        mean_true_sq = tf.square(mean_true)
        mean_pred_sq = tf.square(mean_pred)
        mean_product = mean_true * mean_pred

        # Local variances and covariance
        var_true = self.filter2D(y_true * y_true, window) - mean_true_sq
        var_pred = self.filter2D(y_pred * y_pred, window) - mean_pred_sq
        covariance = self.filter2D(y_true * y_pred, window) - mean_product

        numerator = (2 * mean_product + self.C1) * (2 * covariance + self.C2)
        denominator = (mean_true_sq + mean_pred_sq + self.C1) * (var_true + var_pred + self.C2)
        ssim_map = numerator / denominator

        return tf.reduce_mean((1.0 - ssim_map) / 2.0)

# Charbonnier Loss
class CharbonnierLoss(tf.keras.losses.Loss):
    """Charbonnier loss: mean of `sqrt(diff^2 + eps^2)`.

    Args:
        eps: Small constant added (squared) under the square root.
    """

    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps

    def call(self, y_true, y_pred):
        """Return the mean Charbonnier penalty over all elements."""
        residual = y_pred - y_true
        penalty = tf.sqrt(tf.square(residual) + self.eps**2)
        return tf.reduce_mean(penalty)

# Edge Loss
class EdgeLoss(tf.keras.losses.Loss):
    """Charbonnier loss computed on Laplacian-style high-frequency residuals.

    Each image is blurred with a 5x5 separable kernel, subsampled by two,
    resized back, blurred again and subtracted from the original; the
    Charbonnier loss is then taken between the residuals of target and
    prediction.
    """

    def __init__(self):
        super(EdgeLoss, self).__init__()
        k = tf.constant([[0.05, 0.25, 0.4, 0.25, 0.05]], dtype=tf.float32)
        # `k` is a (1, 5) row vector. The 5x5 blur kernel is the outer product
        # k^T @ k, i.e. transpose_a=True. The previous transpose_b=True
        # computed k @ k^T, a single (1, 1) value, which turned the "blur"
        # into a plain rescaling of the image.
        self.kernel = tf.matmul(k, k, transpose_a=True)
        self.kernel = self.kernel[:, :, tf.newaxis, tf.newaxis]
        self.loss_fn = CharbonnierLoss()

    def conv_gauss(self, img):
        """Blur every channel of `img` with the 5x5 kernel ('SAME' padding)."""
        n_channels = img.shape[-1]
        kernel = tf.tile(self.kernel, [1, 1, n_channels, 1])
        return tf.nn.depthwise_conv2d(img, kernel, strides=[1, 1, 1, 1], padding='SAME')

    def laplacian_kernel(self, img):
        """Return `img` minus its blurred, down/up-sampled and re-blurred version."""
        filtered = self.conv_gauss(img)
        downsampled = filtered[:, ::2, ::2, :]
        # Uses the static spatial shape, so height and width must be known.
        upsampled = tf.image.resize(downsampled, filtered.shape[1:3])
        filtered_up = self.conv_gauss(upsampled)
        return img - filtered_up

    def call(self, y_true, y_pred):
        """Return the Charbonnier loss between the two high-frequency residuals."""
        return self.loss_fn(self.laplacian_kernel(y_true), self.laplacian_kernel(y_pred))

# ------------------------------------------------------------------------------
# Load images

noisy_dir = '/Final_DN_Traning/DSENet/dataset/complete_merged_dataset/train/input/'
clean_dir = '/Final_DN_Traning/DSENet/dataset/complete_merged_dataset/train/groundtruth'
X, y = load_images(noisy_dir, clean_dir)

# Split data into training, validation, and test sets:
# 20% is held out for testing, then 10% of the remainder for validation.
X_train_full, X_test, y_train_full, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
X_train, X_val, y_train, y_val = train_test_split(X_train_full, y_train_full, test_size=0.1, random_state=42)

batch_size = 16
total_epochs = 100

# Create U-Net model
input_shape = X_train.shape[1:]

# Adjust early stopping to monitor PSNR instead of loss
early_stopping_callback = EarlyStopping(
    monitor='val_psnr_metric',  # Change to 'val_ssim_metric' if SSIM is preferred
    patience=5,
    mode='max',  # PSNR and SSIM are maximization metrics
    restore_best_weights=True,
    verbose=1
)

# Model Checkpoint for the entire model
best_model_checkpoint = ModelCheckpoint(
    filepath='Checkpoint/HSENet.keras',
    monitor='val_psnr_metric',
    mode='max',
    save_best_only=True,
    verbose=1
)

# Model Checkpoint for only the weights
best_weights_checkpoint = ModelCheckpoint(
    filepath='Checkpoint/HSENet.weights.h5',
    monitor='val_psnr_metric',
    mode='max',
    save_best_only=True,
    save_weights_only=True,
    verbose=1
)

# Learning rate scheduler callback
scheduler_callback = LearningRateScheduler(lr_schedule)

# The last, possibly smaller, batch of an epoch is a step too, so the number
# of steps is the ceiling (not the floor) of samples / batch_size. With floor
# division the progress bar total was one short whenever the training set
# size was not a multiple of the batch size.
total_steps = int(np.ceil(len(X_train) / batch_size))
progress_bar_callback = TQDMProgressBar(total_steps=total_steps)

# Compile model
model = unet_pyramid_model(input_shape)
optimizer = Adam(learning_rate=1e-4)

# Define combined loss function
def combined_loss(y_true, y_pred):
    """Unweighted sum of the Charbonnier, edge and SSIM losses."""
    total = CharbonnierLoss()(y_true, y_pred)
    total = total + EdgeLoss()(y_true, y_pred)
    total = total + SSIMLoss()(y_true, y_pred)
    return total

# Compile the model with the combined loss function
model.compile(optimizer=optimizer, loss=combined_loss, metrics=[psnr_metric, ssim_metric])
model.summary()

# Start timing for training
training_start_time = time.time()

# Train the model (verbose=0: progress is reported by the callbacks instead)
history = model.fit(
    X_train, y_train,
    validation_data=(X_val, y_val),
    epochs=total_epochs,
    batch_size=batch_size,
    callbacks=[
        CustomCallback(),
        scheduler_callback,
        early_stopping_callback,
        best_model_checkpoint,
        best_weights_checkpoint,
        progress_bar_callback, 
        lr_logger
    ],
    verbose=0
)

# End timing for training
training_time = time.time() - training_start_time
print(f"Total training time: {training_time:.2f} seconds")

# Plot training history
plot_training_history(history)

# Evaluate on validation set
denoised_images_val = model.predict(X_val)
psnr_values_val, ssim_values_val = compute_metrics(y_val, denoised_images_val)
mean_psnr_val = np.mean(psnr_values_val)
mean_ssim_val = np.mean(ssim_values_val)
print(f"Validation PSNR: {mean_psnr_val:.4f}, Validation SSIM: {mean_ssim_val:.4f}")

# Evaluate on test set
denoised_images_test = model.predict(X_test)
psnr_values_test, ssim_values_test = compute_metrics(y_test, denoised_images_test)
mean_psnr_test = np.mean(psnr_values_test)
mean_ssim_test = np.mean(ssim_values_test)

print(f"\nTest PSNR: {mean_psnr_test:.4f}")
print(f"Test SSIM: {mean_ssim_test:.4f}")

# Plot PSNR and SSIM metrics
plot_psnr_ssim(psnr_values_test, ssim_values_test)

# Save and zip the model
model_save_path = '/Final_DN_Traning/DSENet/DN_model/HSENet.keras'
# Create the output directory up front: previously nothing guaranteed it
# existed before the save, so a missing folder could abort the run only
# after the whole training had finished.
os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
model.save(model_save_path)
print(f"Model saved to {model_save_path}")

zip_save_path = '/Final_DN_Traning/DSENet/DN_model/HSENet.zip'

with zipfile.ZipFile(zip_save_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
    zipf.write(model_save_path, arcname=os.path.basename(model_save_path))
print(f"Model zipped and saved to {zip_save_path}")

# Save denoised images
denoised_image_save_path = '/Final_DN_Traning/DSENet/DN_model/HSENet'
os.makedirs(denoised_image_save_path, exist_ok=True)
for i, denoised_image in enumerate(denoised_images_test):
    denoised_image_path = os.path.join(denoised_image_save_path, f'{i}.png')
    cv2.imwrite(denoised_image_path, (denoised_image * 255).astype(np.uint8))

# Visualize results
def plot_results(noisy_images, clean_images, denoised_images, n=5):
    """Show up to `n` rows of noisy / clean / denoised images side by side.

    Args:
        noisy_images: Sequence of noisy input images.
        clean_images: Sequence of ground-truth images, indexed like `noisy_images`.
        denoised_images: Sequence of model outputs, indexed like `noisy_images`.
        n: Maximum number of rows to draw. Fewer rows are drawn when any of
            the three sequences holds fewer than `n` images.
    """
    # Clamp to the available images so a small evaluation set cannot raise
    # an IndexError.
    n = min(n, len(noisy_images), len(clean_images), len(denoised_images))

    # NOTE(review): images loaded with cv2.imread are in BGR channel order,
    # while imshow interprets arrays as RGB — confirm whether the colours
    # shown here need a channel flip.
    plt.figure(figsize=(15, 10))
    for i in range(n):
        plt.subplot(n, 3, i * 3 + 1)
        plt.imshow(noisy_images[i])
        plt.title('Noisy Image')
        plt.axis('off')

        plt.subplot(n, 3, i * 3 + 2)
        plt.imshow(clean_images[i])
        plt.title('Clean Image')
        plt.axis('off')

        plt.subplot(n, 3, i * 3 + 3)
        plt.imshow(denoised_images[i])
        plt.title('Denoised Image')
        plt.axis('off')
    plt.show()

plot_results(X_test, y_test, denoised_images_test)