In [1]:
import os
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
import datetime
import glob
import IPython.display as display
import logging
from tqdm.auto import tqdm
import random

# Suppress TensorFlow logging except for errors.
# NOTE(review): TF_CPP_MIN_LOG_LEVEL is presumably only honoured when it is set
# before TensorFlow is first imported; this assignment runs after the import,
# so native (C++) log lines may still be printed -- TODO confirm and, if so,
# move the assignment above the `import tensorflow` line.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
tf.get_logger().setLevel('ERROR')

# Turn on memory growth for every visible GPU (uses the stable tf.config API
# for device listing rather than the `experimental` aliases).
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.list_logical_devices('GPU')
        print(f"{len(gpus)} Physical GPUs, {len(logical_gpus)} Logical GPUs")
    except RuntimeError as e:
        # Best effort: report the configuration failure and carry on.
        print(e)

# --- Configuration ---
# Size and channel count that every loaded image is resized/decoded to
IMG_HEIGHT = 256
IMG_WIDTH = 256
CHANNELS = 3
BATCH_SIZE = 4  # Kept small for memory safety; try a larger value if the GPU permits
BUFFER_SIZE = tf.data.AUTOTUNE  # Prefetch buffer size, left for tf.data to tune

# Dataset paths
BASE_DATASET_PATH = '/kaggle/input/a-curated-list-of-image-deblurring-datasets/DBlur/'
# Sub-datasets combined when a dataset split is built without an explicit list
DATASETS_TO_USE = ['Helen', 'Wider-Face']

# Model parameters
TEACHER_FILTERS = 64  # Base filter count of the teacher U-Net
STUDENT_FILTERS = 48  # Base filter count of the (smaller) student U-Net
LAMBDA_L1 = 100.0  # Weight of the L1 term in the perceptual loss
LAMBDA_SSIM = 20.0  # Weight of the (1 - SSIM) term
LAMBDA_VGG = 0.02  # Weight of the VGG feature term
ALPHA = 0.5  # Distillation mix: ALPHA * teacher-output loss + (1 - ALPHA) * ground-truth loss

# Training parameters
# --- FULL RUN SETTINGS ---
TEACHER_EPOCHS = 50  # Upper bound on teacher epochs (early stopping may end sooner)
STUDENT_EPOCHS = 60  # Upper bound on student epochs (early stopping may end sooner)
# -------------------------
INITIAL_LR = 1e-4  # Starting learning rate of the ExponentialDecay schedules

# Early Stopping parameters
EARLY_STOPPING_PATIENCE = 5  # Epochs without improvement before training stops
MIN_DELTA = 1e-4  # Smallest increase in the monitored metric that counts as an improvement

# Target Metrics (for reporting success)
TARGET_PSNR = 25.0  # Goal: PSNR >= 25 dB
TARGET_SSIM = 0.90  # Goal: SSIM >= 0.90

# Checkpoint and logging
CHECKPOINT_DIR = './checkpoints'
LOG_DIR = './logs'
EXAMPLE_IMAGES_DIR = './example_images'

# Make sure every output directory exists before anything is written to it
for _required_dir in (CHECKPOINT_DIR, LOG_DIR, EXAMPLE_IMAGES_DIR):
    os.makedirs(_required_dir, exist_ok=True)

# --- Data Loading and Preprocessing (Reverted to more robust version) ---

def _path_for_log(image_path):
    """
    Render an image path as text for log messages.
    Accepts an eager string tensor, a bytes object or a str and never raises
    on undecodable bytes.
    """
    raw = image_path.numpy() if hasattr(image_path, 'numpy') else image_path
    if isinstance(raw, bytes):
        return raw.decode('utf-8', errors='replace')
    return str(raw)


def _load_single_image_py(image_path):
    """
    Pure Python (eager) function to load, decode, resize and normalize one image.

    Args:
        image_path: Path of the image file. Inside tf.py_function this arrives
            as a scalar string tensor; bytes and str are accepted as well.

    Returns:
        A float32 tensor of shape (IMG_HEIGHT, IMG_WIDTH, CHANNELS) with values
        in [0, 1], or None when the image is unreadable, empty or has a channel
        count that cannot be converted. Failures are logged, never raised.
    """
    try:
        img_bytes = tf.io.read_file(image_path).numpy()
        img = tf.image.decode_image(img_bytes, channels=CHANNELS, expand_animations=False)
        if img is None or img.shape == (0, 0, 0):  # Check for empty/malformed tensors
            logging.warning(f"Skipping malformed or empty image: {_path_for_log(image_path)}")
            return None
        # Ensure image has the expected channels even if decoded with fewer (e.g., grayscale)
        if img.shape[-1] != CHANNELS:
            logging.warning(f"Image has {img.shape[-1]} channels, expected {CHANNELS}: {_path_for_log(image_path)}. Converting to RGB.")
            if img.shape[-1] == 1:
                img = tf.image.grayscale_to_rgb(img)
            else:
                return None  # Too complex to auto-handle all cases, better to skip.
        img = tf.image.resize(img, [IMG_HEIGHT, IMG_WIDTH], method=tf.image.ResizeMethod.BICUBIC)
        img = tf.cast(img, tf.float32) / 255.0  # Normalize to [0, 1]
        # Bicubic interpolation can overshoot the source range, so clip to keep
        # the [0, 1] contract that the losses and metrics (max_val=1.0) rely on.
        return tf.clip_by_value(img, 0.0, 1.0)
    except Exception as e:
        # Deliberate best effort: a bad file must be skipped, not abort the
        # input pipeline. The path is rendered via _path_for_log because the
        # tensor passed in by tf.py_function has no .decode() method.
        logging.warning(f"Error loading image {_path_for_log(image_path)}: {e}. Skipping.")
        return None

def _tf_py_function_wrapper(blur_path, sharp_path):
    """
    Load a (blurred, sharp) image pair through tf.py_function.

    Returns (blurred_image, sharp_image, is_valid_pair). When either image
    fails to load, both images are zero-filled placeholders and the flag is
    False so the pair can be filtered out downstream.
    """
    image_shape = (IMG_HEIGHT, IMG_WIDTH, CHANNELS)

    def load_and_validate(b_path, s_path):
        loaded_blur = _load_single_image_py(b_path)
        loaded_sharp = _load_single_image_py(s_path)
        if loaded_blur is None or loaded_sharp is None:
            # Placeholders keep the output structure fixed for invalid pairs
            return (
                tf.zeros(image_shape, dtype=tf.float32),
                tf.zeros(image_shape, dtype=tf.float32),
                tf.constant(False, dtype=tf.bool),
            )
        return loaded_blur, loaded_sharp, tf.constant(True, dtype=tf.bool)

    outputs = tf.py_function(
        load_and_validate,
        [blur_path, sharp_path],
        [tf.float32, tf.float32, tf.bool]
    )
    blur_img, sharp_img, is_valid = outputs

    # tf.py_function loses static shape information, so restore it here
    for image in (blur_img, sharp_img):
        image.set_shape([IMG_HEIGHT, IMG_WIDTH, CHANNELS])
    is_valid.set_shape([])  # Scalar boolean

    return blur_img, sharp_img, is_valid

def augment_image_pair_tf_graph_mode(blur_image, sharp_image):
    """
    Apply the same random flip and rotation to a blurred/sharp image pair.
    Both images always receive identical transforms so they stay aligned.
    """
    # Horizontal flip with probability 0.5
    flip_wanted = tf.random.uniform(()) > 0.5
    flipped_pair = tf.cond(
        flip_wanted,
        lambda: (tf.image.flip_left_right(blur_image), tf.image.flip_left_right(sharp_image)),
        lambda: (blur_image, sharp_image),
    )

    # Rotation by a random multiple of 90 degrees (0, 90, 180 or 270)
    quarter_turns = tf.random.uniform(shape=[], minval=0, maxval=4, dtype=tf.int32)
    rotated_blur = tf.image.rot90(flipped_pair[0], quarter_turns)
    rotated_sharp = tf.image.rot90(flipped_pair[1], quarter_turns)

    return rotated_blur, rotated_sharp

def _match_image_pairs(blur_dir, sharp_dir):
    """
    Pair every file in blur_dir with the same-named file in sharp_dir.
    Returns (blur_paths, sharp_paths) in sorted blur order; blurred files
    without a sharp counterpart are logged and skipped.
    """
    sharp_by_name = {
        os.path.basename(path): path
        for path in sorted(glob.glob(os.path.join(sharp_dir, '*.*')))
    }

    blur_paths = []
    sharp_paths = []
    for blur_path in sorted(glob.glob(os.path.join(blur_dir, '*.*'))):
        sharp_path = sharp_by_name.get(os.path.basename(blur_path))
        if sharp_path is None:
            logging.warning(f"No matching sharp image found for {blur_path}. Skipping.")
            continue
        blur_paths.append(blur_path)
        sharp_paths.append(sharp_path)
    return blur_paths, sharp_paths


def create_image_dataset(dataset_type, batch_size, datasets_to_use=DATASETS_TO_USE, shuffle=True):
    """
    Creates a robust tf.data.Dataset pipeline of (blurred, sharp) image batches.

    Args:
        dataset_type: 'train', 'validation', or 'test'; selects the split
            sub-directory of each dataset.
        batch_size: Number of image pairs per batch.
        datasets_to_use: Names of the dataset folders under BASE_DATASET_PATH.
        shuffle: Shuffle the pairs; only applied to the 'train' split.

    Returns:
        (dataset, num_elements). num_elements counts the matched file pairs
        before unreadable images are filtered out, so it is an upper bound.
        When nothing is found the dataset is empty (same element structure)
        and num_elements is 0.
    """
    all_blur_paths = []
    all_sharp_paths = []

    for dataset_name in datasets_to_use:
        blur_dir = os.path.join(BASE_DATASET_PATH, dataset_name, dataset_type, 'blur')
        sharp_dir = os.path.join(BASE_DATASET_PATH, dataset_name, dataset_type, 'sharp')

        if not os.path.exists(blur_dir) or not os.path.exists(sharp_dir):
            logging.warning(f"Skipping {dataset_name}/{dataset_type} as directories not found: {blur_dir}, {sharp_dir}")
            continue

        matched_blur_paths, matched_sharp_paths = _match_image_pairs(blur_dir, sharp_dir)
        all_blur_paths.extend(matched_blur_paths)
        all_sharp_paths.extend(matched_sharp_paths)

    num_elements = len(all_blur_paths)  # This count is before filtering corrupt images
    if num_elements == 0:
        logging.error(f"No valid image pairs found for {dataset_type} across specified datasets.")

    # The explicit string dtype keeps the element type correct even when the
    # path lists are empty, so one pipeline serves both cases.
    dataset = tf.data.Dataset.from_tensor_slices((
        tf.constant(all_blur_paths, dtype=tf.string),
        tf.constant(all_sharp_paths, dtype=tf.string),
    ))

    is_training = dataset_type == 'train'
    if shuffle and is_training and num_elements > 0:
        # Shuffling cheap path strings with a full-size buffer gives a uniform shuffle
        dataset = dataset.shuffle(buffer_size=num_elements)

    dataset = dataset.map(_tf_py_function_wrapper, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.filter(lambda blur_img, sharp_img, is_valid: is_valid)  # Drop invalid pairs
    dataset = dataset.map(lambda blur_img, sharp_img, is_valid: (blur_img, sharp_img), num_parallel_calls=tf.data.AUTOTUNE)

    if is_training:
        dataset = dataset.map(augment_image_pair_tf_graph_mode, num_parallel_calls=tf.data.AUTOTUNE)

    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(buffer_size=BUFFER_SIZE)

    return dataset, num_elements

# --- Model Architecture (Enhanced U-Net from previous working version) ---
def conv_block(inputs, filters, kernel_size=(3, 3), strides=(1, 1), padding='same', activation='relu', use_bn=True):
    """
    Conv2D, optionally followed by BatchNormalization, then an Activation layer.
    The convolution bias is dropped whenever batch normalization is used.
    """
    convolution = tf.keras.layers.Conv2D(
        filters, kernel_size, strides=strides, padding=padding, use_bias=not use_bn
    )
    features = convolution(inputs)
    if use_bn:
        features = tf.keras.layers.BatchNormalization()(features)
    return tf.keras.layers.Activation(activation)(features)

def residual_block(inputs, filters, activation='relu'):
    """
    Two conv blocks with an additive skip connection around them.
    The second conv block has no activation; the activation is applied once
    after the skip addition.
    """
    branch = conv_block(inputs, filters, activation=activation)
    branch = conv_block(branch, filters, activation=None)
    merged = tf.keras.layers.Add()([inputs, branch])
    return tf.keras.layers.Activation(activation)(merged)

def build_enhanced_unet(input_shape=(IMG_HEIGHT, IMG_WIDTH, CHANNELS), base_filters=64):
    """
    Build a four-level U-Net whose stages are a conv block plus a residual block.

    The encoder doubles the filter count at each level (1x, 2x, 4x, 8x of
    base_filters) with 2x2 max pooling in between, the bottleneck uses 16x,
    and the decoder mirrors the encoder with 2x upsampling and skip
    concatenation. The output is a 1x1 sigmoid convolution with CHANNELS maps.
    """
    def _stage(tensor, width):
        # One U-Net stage: conv block followed by a residual block
        return residual_block(conv_block(tensor, width), width)

    inputs = tf.keras.layers.Input(shape=input_shape)

    # Encoder: remember each stage output for the decoder's skip connections
    skip_connections = []
    x = inputs
    for multiplier in (1, 2, 4, 8):
        x = _stage(x, base_filters * multiplier)
        skip_connections.append(x)
        x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(x)

    # Bottleneck
    x = _stage(x, base_filters * 16)

    # Decoder: upsample, merge with the matching encoder output, refine
    for multiplier, skip in zip((8, 4, 2, 1), reversed(skip_connections)):
        x = tf.keras.layers.UpSampling2D(size=(2, 2))(x)
        x = tf.keras.layers.Concatenate()([x, skip])
        x = _stage(x, base_filters * multiplier)

    # Final output layer uses sigmoid for [0, 1] range
    output = tf.keras.layers.Conv2D(CHANNELS, (1, 1), activation='sigmoid', padding='same')(x)

    return tf.keras.Model(inputs=inputs, outputs=output)

# --- VGG Loss Model (from previous working version) ---
class VGGFeatureExtractor(tf.keras.Model):
    """Frozen ImageNet VGG19 that returns the activations of five conv layers."""

    def __init__(self):
        super().__init__()
        # VGG19 backbone without the classification head, weights frozen
        backbone = tf.keras.applications.VGG19(
            include_top=False,
            weights='imagenet',
            input_shape=(IMG_HEIGHT, IMG_WIDTH, CHANNELS),
        )
        backbone.trainable = False

        # One feature map per VGG block, taken after its second convolution
        feature_layer_names = (
            'block1_conv2',
            'block2_conv2',
            'block3_conv2',
            'block4_conv2',
            'block5_conv2',
        )
        self.output_layers = [backbone.get_layer(layer_name).output for layer_name in feature_layer_names]
        self.vgg_model = tf.keras.Model(inputs=backbone.input, outputs=self.output_layers, name="VGG19_Feature_Extractor")

    @tf.function
    def call(self, inputs):
        # Inputs are expected in [0, 1]; scale them to [0, 255] first because
        # vgg19.preprocess_input works on 0-255 pixel values.
        pixel_values = tf.cast(inputs, tf.float32) * 255.0
        vgg_ready = tf.keras.applications.vgg19.preprocess_input(pixel_values)
        return self.vgg_model(vgg_ready)


# Single shared extractor instance used by the loss functions
vgg_feature_extractor = VGGFeatureExtractor()

# --- Loss Functions (Reverted and refined) ---
def perceptual_loss(y_true, y_pred):
    """
    Weighted sum of L1, SSIM and VGG-feature losses between two image batches.

    Args:
        y_true: Reference images with values in [0, 1].
        y_pred: Predicted images with values in [0, 1].

    Returns:
        Scalar tensor:
        LAMBDA_L1 * L1 + LAMBDA_SSIM * (1 - SSIM) + LAMBDA_VGG * VGG feature L1.
    """
    # L1 Loss
    l1_loss_val = tf.reduce_mean(tf.abs(y_true - y_pred))

    # SSIM Loss: SSIM is to be maximized, so minimize 1 - SSIM.
    # SSIM expects values in [0, max_val]. Our images are [0,1].
    ssim_loss_val = tf.reduce_mean(1 - tf.image.ssim(y_true, y_pred, max_val=1.0))

    # VGG Feature Loss: mean absolute difference summed over the feature layers.
    # No per-layer emptiness guard: inside tf.function a tensor-dependent
    # Python `if` has both branches traced, so logging placed in it would fire
    # at trace time rather than when the condition holds, and an empty batch
    # already turns the L1/SSIM means above into NaN.
    true_features = vgg_feature_extractor(y_true)
    pred_features = vgg_feature_extractor(y_pred)
    vgg_loss_val = 0.0
    for true_map, pred_map in zip(true_features, pred_features):
        vgg_loss_val += tf.reduce_mean(tf.abs(true_map - pred_map))

    return LAMBDA_L1 * l1_loss_val + LAMBDA_SSIM * ssim_loss_val + LAMBDA_VGG * vgg_loss_val

def distillation_loss_fn(teacher_output, student_output, ground_truth):
    """
    Blend of two perceptual losses for the student's output:
    ALPHA * loss(teacher_output, student) + (1 - ALPHA) * loss(ground_truth, student).
    """
    # L_KD: how far the student is from the teacher's prediction
    kd_term = perceptual_loss(teacher_output, student_output)
    # L_GT: how far the student is from the sharp ground truth
    gt_term = perceptual_loss(ground_truth, student_output)

    blended = ALPHA * kd_term + (1 - ALPHA) * gt_term
    return blended

# --- Metrics (Consistent with [0,1] range) ---
def calculate_psnr(y_true, y_pred):
    """Per-image PSNR (max_val=1.0) after clipping both inputs to [0, 1]."""
    reference = tf.clip_by_value(y_true, 0.0, 1.0)
    estimate = tf.clip_by_value(y_pred, 0.0, 1.0)
    psnr = tf.image.psnr(reference, estimate, max_val=1.0)
    return psnr

def calculate_ssim(y_true, y_pred):
    """Per-image SSIM (max_val=1.0) after clipping both inputs to [0, 1]."""
    reference = tf.clip_by_value(y_true, 0.0, 1.0)
    estimate = tf.clip_by_value(y_pred, 0.0, 1.0)
    ssim = tf.image.ssim(reference, estimate, max_val=1.0)
    return ssim

# --- Training Steps (Simplified for non-GAN) ---
@tf.function
def teacher_train_step(blur_images, sharp_images, teacher_model, teacher_optimizer):
    """Run one optimizer step of the teacher on a batch and return the loss."""
    with tf.GradientTape() as tape:
        predictions = teacher_model(blur_images, training=True)
        loss = perceptual_loss(sharp_images, predictions)

    weights = teacher_model.trainable_variables
    grads = tape.gradient(loss, weights)
    teacher_optimizer.apply_gradients(zip(grads, weights))
    return loss

@tf.function
def distillation_train_step(blur_images, sharp_images, student_model, teacher_model, student_optimizer):
    """
    Run one optimizer step of the student with knowledge distillation.

    Args:
        blur_images: Batch of blurred inputs.
        sharp_images: Matching batch of sharp ground-truth images.
        student_model: Model being trained.
        teacher_model: Model providing the distillation target (not updated).
        student_optimizer: Optimizer applied to the student's variables.

    Returns:
        The scalar distillation loss for this batch.
    """
    # The teacher only supplies a target, and gradients are taken solely with
    # respect to the student's variables, so run it outside the tape instead
    # of recording its forward pass for nothing.
    teacher_deblurred_images = teacher_model(blur_images, training=False)

    with tf.GradientTape() as tape:
        student_deblurred_images = student_model(blur_images, training=True)
        loss = distillation_loss_fn(teacher_deblurred_images, student_deblurred_images, sharp_images)

    gradients = tape.gradient(loss, student_model.trainable_variables)
    student_optimizer.apply_gradients(zip(gradients, student_model.trainable_variables))
    return loss

# --- Evaluation (from previous working version) ---
def evaluate_model(model, dataset):
    """
    Compute the average PSNR and SSIM of a model over a dataset.

    Args:
        model: Callable model; invoked as model(blur_batch, training=False).
        dataset: Iterable of (blurred_batch, sharp_batch) pairs.

    Returns:
        (avg_psnr, avg_ssim) averaged over all images, or (0.0, 0.0) when the
        dataset yields no batches.
    """
    psnr_values = []
    ssim_values = []

    # No separate emptiness probe: iterating an empty dataset simply yields
    # nothing (Python iteration ends with StopIteration, which a
    # tf.errors.OutOfRangeError handler would not catch), and the check after
    # the loop covers that case without loading an extra batch.
    for blur_img, sharp_img in dataset:
        deblurred_img = model(blur_img, training=False)
        psnr_values.append(calculate_psnr(sharp_img, deblurred_img))
        ssim_values.append(calculate_ssim(sharp_img, deblurred_img))

    if not psnr_values:  # Empty dataset, or every image was filtered out
        return 0.0, 0.0

    # Per-batch results are vectors of per-image scores; concatenate before
    # averaging so every image carries the same weight.
    avg_psnr = tf.reduce_mean(tf.concat(psnr_values, axis=0)).numpy()
    avg_ssim = tf.reduce_mean(tf.concat(ssim_values, axis=0)).numpy()
    return avg_psnr, avg_ssim

# --- Visualization (from previous working version) ---
def generate_and_save_images(model, test_dataset, epoch, model_name, num_examples=4):
    """
    Save a blurred / deblurred / ground-truth comparison figure for one batch.

    Args:
        model: Callable model; invoked as model(blur_batch, training=False).
        test_dataset: Dataset of (blurred_batch, sharp_batch); only the first
            batch is used.
        epoch: Epoch number shown in the title and used in the file name.
        model_name: Label shown in the figure and used in the file name.
        num_examples: Maximum number of rows (images) to plot; capped at the
            batch size.

    Returns:
        None. The figure is written to EXAMPLE_IMAGES_DIR; when no images are
        available a warning is logged and nothing is written.
    """
    # next() with a default handles the empty dataset: iterating it raises no
    # error, it just yields nothing, which would leave the batch undefined.
    first_batch = next(iter(test_dataset.take(1)), None)
    if first_batch is None:
        logging.warning(f"No images available in test_dataset for {model_name} visualization at epoch {epoch}.")
        return

    blur_img_batch, sharp_img_batch = first_batch
    deblurred_img_batch = model(blur_img_batch, training=False)

    available = blur_img_batch.shape[0]
    if available == 0:
        logging.warning(f"No images available in test_dataset for {model_name} visualization at epoch {epoch} after taking a batch.")
        return
    num_examples = min(num_examples, available)

    # Select random indices from the batch
    indices = random.sample(range(available), num_examples)

    # Column order of each row: input, model output, reference
    panels = (
        ('Blurred Input', blur_img_batch),
        (f'{model_name} Deblurred', deblurred_img_batch),
        ('Ground Truth', sharp_img_batch),
    )

    fig = plt.figure(figsize=(15, num_examples * 5))
    try:
        for row, idx in enumerate(indices):
            for column, (title, batch) in enumerate(panels, start=1):
                plt.subplot(num_examples, 3, row * 3 + column)
                plt.imshow(batch[idx].numpy())
                plt.title(title)
                plt.axis('off')
        plt.suptitle(f'{model_name} - Epoch {epoch}', fontsize=16)
        fig.savefig(os.path.join(EXAMPLE_IMAGES_DIR, f'{model_name}_epoch_{epoch:03d}.png'))
    finally:
        # Always release the figure, even if plotting or saving fails
        plt.close(fig)

# --- Main Training Script ---
def _dataset_has_elements(dataset):
    """Return True when the dataset yields at least one element."""
    # Python iteration over a tf.data.Dataset ends with StopIteration rather
    # than tf.errors.OutOfRangeError, so probe with next() and a default.
    return next(iter(dataset.take(1)), None) is not None


def _build_adam_optimizer(decay_steps):
    """Create an Adam optimizer with a staircase ExponentialDecay schedule (rate 0.96)."""
    schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=INITIAL_LR,
        decay_steps=decay_steps,
        decay_rate=0.96,
        staircase=True
    )
    return tf.keras.optimizers.Adam(learning_rate=schedule)


def _run_training_phase(role, model, train_step, epochs, steps_per_epoch,
                        train_dataset, val_dataset, test_dataset,
                        has_val, has_test, checkpoint_manager, summary_writer):
    """
    Train one model with per-epoch validation, checkpointing and early stopping.

    Args:
        role: Label used in progress and log messages ('Teacher' or 'Student').
        model: Model that is evaluated, visualized and checkpointed.
        train_step: Callable (blur_batch, sharp_batch) -> scalar loss tensor
            that performs one optimizer step.
        epochs: Maximum number of epochs.
        steps_per_epoch: Estimated batches per epoch (progress bar total only).
        train_dataset, val_dataset, test_dataset: Batched (blur, sharp) datasets.
        has_val, has_test: Whether the validation / test datasets are non-empty.
        checkpoint_manager: Saves a checkpoint whenever validation SSIM improves.
        summary_writer: TensorBoard writer for the per-epoch scalars.
    """
    best_val_ssim = -1.0  # Initialize for early stopping
    epochs_since_improvement = 0  # Counter for early stopping

    for epoch in range(epochs):
        total_loss = 0.0
        batch_count = 0

        with tqdm(train_dataset, desc=f"{role} Epoch {epoch+1}/{epochs}", unit="batch", total=steps_per_epoch) as pbar:
            pbar.set_postfix(loss="N/A")
            for blur_images, sharp_images in pbar:
                loss_value = float(train_step(blur_images, sharp_images).numpy())
                total_loss += loss_value
                batch_count += 1
                pbar.set_postfix(loss=f"{loss_value:.4f}")

        if batch_count == 0:
            print(f"{role} Epoch {epoch+1}: No batches processed. This means dataset is empty or all images were filtered. Skipping epoch.")
            continue

        avg_loss = total_loss / batch_count

        current_val_psnr, current_val_ssim = 0.0, 0.0
        if has_val:
            current_val_psnr, current_val_ssim = evaluate_model(model, val_dataset)
        else:
            print(f"Warning: Validation dataset is empty. Skipping validation metrics for {role}. Early stopping will not be effective.")

        print(f"{role} Epoch {epoch+1} - Avg Train Loss: {avg_loss:.4f} | Val PSNR: {current_val_psnr:.2f} dB | Val SSIM: {current_val_ssim:.4f}")

        with summary_writer.as_default():
            tf.summary.scalar('train_loss', avg_loss, step=epoch)
            tf.summary.scalar('val_psnr', current_val_psnr, step=epoch)
            tf.summary.scalar('val_ssim', current_val_ssim, step=epoch)

        # Checkpointing / early stopping
        if not has_val:
            # Without validation metrics there is nothing to rank epochs by:
            # keep the newest weights and never count towards early stopping
            # (otherwise the constant 0.0 SSIM would stop training after
            # EARLY_STOPPING_PATIENCE epochs and restore the first epoch).
            checkpoint_manager.save()
            print(f"Saved {role} checkpoint at epoch {epoch+1} (no validation data; keeping latest weights)")
        elif current_val_ssim > best_val_ssim + MIN_DELTA:
            best_val_ssim = current_val_ssim
            checkpoint_manager.save()
            print(f"Saved {role} checkpoint at epoch {epoch+1} (best Val SSIM: {best_val_ssim:.4f})")
            epochs_since_improvement = 0
        else:
            epochs_since_improvement += 1
            print(f"{role} Val SSIM did not improve. Epochs since last improvement: {epochs_since_improvement}/{EARLY_STOPPING_PATIENCE}")

        # Generate images every 10 epochs for visual tracking
        if (epoch + 1) % 10 == 0:
            if has_test:
                generate_and_save_images(model, test_dataset, epoch+1, role)
            else:
                logging.warning(f"No test dataset available for {role} visualization at epoch {epoch+1}.")

        # Check for early stopping
        if epochs_since_improvement >= EARLY_STOPPING_PATIENCE:
            print(f"\033[93mEarly stopping {role} training: No improvement in Val SSIM for {EARLY_STOPPING_PATIENCE} epochs.\033[0m")
            break  # Exit the training loop


def run_training():
    """
    Train the teacher, distill it into the student, and evaluate the student.

    Steps: build the train/validation/test datasets, train the teacher with
    early stopping on validation SSIM, restore its best checkpoint, train the
    student against both the teacher's output and the ground truth, restore
    the student's best checkpoint and report PSNR/SSIM on the test set
    against TARGET_PSNR / TARGET_SSIM. Returns None.
    """
    print("--- Preparing Datasets ---")
    train_dataset, train_initial_elements = create_image_dataset('train', BATCH_SIZE)
    # Validation and test use the Helen splits only
    val_dataset, _ = create_image_dataset('validation', BATCH_SIZE, datasets_to_use=['Helen'])
    test_dataset, _ = create_image_dataset('test', BATCH_SIZE, datasets_to_use=['Helen'])

    # Check if training dataset is not empty before proceeding
    if not _dataset_has_elements(train_dataset):
        print("Error: Training dataset is empty after filtering. Please check dataset paths and content.")
        return

    # Emptiness does not change between epochs, so probe each dataset once
    has_val = _dataset_has_elements(val_dataset)
    has_test = _dataset_has_elements(test_dataset)

    # Steps per epoch from the (initial) number of pairs; this is an estimate,
    # actual steps might be fewer because corrupt images are filtered out.
    steps_per_epoch = (train_initial_elements + BATCH_SIZE - 1) // BATCH_SIZE

    # Decay the learning rate every half of the planned training length;
    # max(1, ...) keeps decay_steps valid for very short runs.
    teacher_decay_steps = max(1, int(steps_per_epoch * TEACHER_EPOCHS * 0.5))
    student_decay_steps = max(1, int(steps_per_epoch * STUDENT_EPOCHS * 0.5))

    teacher_model = build_enhanced_unet(base_filters=TEACHER_FILTERS)
    student_model = build_enhanced_unet(base_filters=STUDENT_FILTERS)

    teacher_optimizer = _build_adam_optimizer(teacher_decay_steps)
    student_optimizer = _build_adam_optimizer(student_decay_steps)

    teacher_checkpoint_prefix = os.path.join(CHECKPOINT_DIR, "teacher_ckpt")
    teacher_checkpoint = tf.train.Checkpoint(optimizer=teacher_optimizer, model=teacher_model)
    teacher_manager = tf.train.CheckpointManager(teacher_checkpoint, teacher_checkpoint_prefix, max_to_keep=5)

    student_checkpoint_prefix = os.path.join(CHECKPOINT_DIR, "student_ckpt")
    student_checkpoint = tf.train.Checkpoint(optimizer=student_optimizer, model=student_model)
    student_manager = tf.train.CheckpointManager(student_checkpoint, student_checkpoint_prefix, max_to_keep=5)

    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    teacher_summary_writer = tf.summary.create_file_writer(os.path.join(LOG_DIR, 'teacher', current_time))
    student_summary_writer = tf.summary.create_file_writer(os.path.join(LOG_DIR, 'student', current_time))

    print("\n--- Training Teacher Model ---")
    _run_training_phase(
        'Teacher',
        teacher_model,
        lambda blur, sharp: teacher_train_step(blur, sharp, teacher_model, teacher_optimizer),
        TEACHER_EPOCHS,
        steps_per_epoch,
        train_dataset, val_dataset, test_dataset,
        has_val, has_test,
        teacher_manager,
        teacher_summary_writer,
    )

    # Checkpoints are only written on improvement, so the latest one is the best
    if teacher_manager.latest_checkpoint:
        teacher_checkpoint.restore(teacher_manager.latest_checkpoint).expect_partial()
        print(f"Restored best Teacher model from {teacher_manager.latest_checkpoint}")
    else:
        print("No Teacher checkpoint to restore. Training will proceed with the last epoch's model.")

    print("\n--- Training Student Model with Knowledge Distillation ---")
    _run_training_phase(
        'Student',
        student_model,
        lambda blur, sharp: distillation_train_step(blur, sharp, student_model, teacher_model, student_optimizer),
        STUDENT_EPOCHS,
        steps_per_epoch,
        train_dataset, val_dataset, test_dataset,
        has_val, has_test,
        student_manager,
        student_summary_writer,
    )

    print("\n--- Final Evaluation on Test Set ---")
    if student_manager.latest_checkpoint:
        student_checkpoint.restore(student_manager.latest_checkpoint).expect_partial()
        print(f"Restored best Student model from {student_manager.latest_checkpoint}")
    else:
        print("No Student checkpoint to restore. Final evaluation will use the last epoch's model.")

    if not has_test:
        print("No test dataset available for final evaluation.")
        return

    final_psnr, final_ssim = evaluate_model(student_model, test_dataset)
    print("\nFinal Student Model Performance on Test Set:")
    print(f"PSNR: {final_psnr:.2f} dB")
    print(f"SSIM: {final_ssim:.4f}")
    if final_psnr >= TARGET_PSNR and final_ssim >= TARGET_SSIM:
        print("\033[92mFinal Student Model ACHIEVED TARGET METRICS!\033[0m")
    else:
        print(f"\033[91mFinal Student Model DID NOT meet target metrics (Target PSNR: {TARGET_PSNR} dB, SSIM: {TARGET_SSIM}).\033[0m")


if __name__ == '__main__':
    run_training()


2025-07-09 18:01:30.318824: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1752084090.508136      19 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1752084090.557822      19 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


1 Physical GPUs, 1 Logical GPUs


I0000 00:00:1752084104.913604      19 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 15513 MB memory:  -> device: 0, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:04.0, compute capability: 6.0


Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/vgg19/vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5
[1m80134624/80134624[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step
--- Preparing Datasets ---

--- Training Teacher Model ---


Teacher Epoch 1/50:   0%|          | 0/1633 [00:00<?, ?batch/s]

I0000 00:00:1752084130.109811      57 cuda_dnn.cc:529] Loaded cuDNN version 90300


Teacher Epoch 1 - Avg Train Loss: 21.0884 | Val PSNR: 28.26 dB | Val SSIM: 0.8329
Saved Teacher checkpoint at epoch 1 (best Val SSIM: 0.8329)


Teacher Epoch 2/50:   0%|          | 0/1633 [00:00<?, ?batch/s]

Teacher Epoch 2 - Avg Train Loss: 18.6652 | Val PSNR: 28.23 dB | Val SSIM: 0.8367
Saved Teacher checkpoint at epoch 2 (best Val SSIM: 0.8367)


Teacher Epoch 3/50:   0%|          | 0/1633 [00:00<?, ?batch/s]

Teacher Epoch 3 - Avg Train Loss: 17.9065 | Val PSNR: 28.56 dB | Val SSIM: 0.8268
Teacher Val SSIM did not improve. Epochs since last improvement: 1/5


Teacher Epoch 4/50:   0%|          | 0/1633 [00:00<?, ?batch/s]

Teacher Epoch 4 - Avg Train Loss: 17.6198 | Val PSNR: 28.34 dB | Val SSIM: 0.8424
Saved Teacher checkpoint at epoch 4 (best Val SSIM: 0.8424)


Teacher Epoch 5/50:   0%|          | 0/1633 [00:00<?, ?batch/s]

Teacher Epoch 5 - Avg Train Loss: 17.2421 | Val PSNR: 28.61 dB | Val SSIM: 0.8225
Teacher Val SSIM did not improve. Epochs since last improvement: 1/5


Teacher Epoch 6/50:   0%|          | 0/1633 [00:00<?, ?batch/s]

Teacher Epoch 6 - Avg Train Loss: 16.9931 | Val PSNR: 29.16 dB | Val SSIM: 0.8413
Teacher Val SSIM did not improve. Epochs since last improvement: 2/5


Teacher Epoch 7/50:   0%|          | 0/1633 [00:00<?, ?batch/s]

Teacher Epoch 7 - Avg Train Loss: 16.7763 | Val PSNR: 28.09 dB | Val SSIM: 0.8290
Teacher Val SSIM did not improve. Epochs since last improvement: 3/5


Teacher Epoch 8/50:   0%|          | 0/1633 [00:00<?, ?batch/s]

Teacher Epoch 8 - Avg Train Loss: 16.5472 | Val PSNR: 29.06 dB | Val SSIM: 0.8292
Teacher Val SSIM did not improve. Epochs since last improvement: 4/5


Teacher Epoch 9/50:   0%|          | 0/1633 [00:00<?, ?batch/s]

Teacher Epoch 9 - Avg Train Loss: 16.2670 | Val PSNR: 29.17 dB | Val SSIM: 0.8512
Saved Teacher checkpoint at epoch 9 (best Val SSIM: 0.8512)


Teacher Epoch 10/50:   0%|          | 0/1633 [00:00<?, ?batch/s]

Teacher Epoch 10 - Avg Train Loss: 16.0558 | Val PSNR: 29.46 dB | Val SSIM: 0.8420
Teacher Val SSIM did not improve. Epochs since last improvement: 1/5


Teacher Epoch 11/50:   0%|          | 0/1633 [00:00<?, ?batch/s]

Teacher Epoch 11 - Avg Train Loss: 15.8353 | Val PSNR: 29.36 dB | Val SSIM: 0.8567
Saved Teacher checkpoint at epoch 11 (best Val SSIM: 0.8567)


Teacher Epoch 12/50:   0%|          | 0/1633 [00:00<?, ?batch/s]

Teacher Epoch 12 - Avg Train Loss: 15.5852 | Val PSNR: 29.50 dB | Val SSIM: 0.8340
Teacher Val SSIM did not improve. Epochs since last improvement: 1/5


Teacher Epoch 13/50:   0%|          | 0/1633 [00:00<?, ?batch/s]

Teacher Epoch 13 - Avg Train Loss: 15.4030 | Val PSNR: 29.50 dB | Val SSIM: 0.8325
Teacher Val SSIM did not improve. Epochs since last improvement: 2/5


Teacher Epoch 14/50:   0%|          | 0/1633 [00:00<?, ?batch/s]

Teacher Epoch 14 - Avg Train Loss: 15.1159 | Val PSNR: 29.37 dB | Val SSIM: 0.8476
Teacher Val SSIM did not improve. Epochs since last improvement: 3/5


Teacher Epoch 15/50:   0%|          | 0/1633 [00:00<?, ?batch/s]

Teacher Epoch 15 - Avg Train Loss: 14.8206 | Val PSNR: 29.42 dB | Val SSIM: 0.8575
Saved Teacher checkpoint at epoch 15 (best Val SSIM: 0.8575)


Teacher Epoch 16/50:   0%|          | 0/1633 [00:00<?, ?batch/s]

Teacher Epoch 16 - Avg Train Loss: 14.6724 | Val PSNR: 29.66 dB | Val SSIM: 0.8320
Teacher Val SSIM did not improve. Epochs since last improvement: 1/5


Teacher Epoch 17/50:   0%|          | 0/1633 [00:00<?, ?batch/s]

Teacher Epoch 17 - Avg Train Loss: 14.3804 | Val PSNR: 28.94 dB | Val SSIM: 0.8383
Teacher Val SSIM did not improve. Epochs since last improvement: 2/5


Teacher Epoch 18/50:   0%|          | 0/1633 [00:00<?, ?batch/s]

Teacher Epoch 18 - Avg Train Loss: 14.2390 | Val PSNR: 29.86 dB | Val SSIM: 0.8604
Saved Teacher checkpoint at epoch 18 (best Val SSIM: 0.8604)


Teacher Epoch 19/50:   0%|          | 0/1633 [00:00<?, ?batch/s]

Teacher Epoch 19 - Avg Train Loss: 14.0157 | Val PSNR: 30.15 dB | Val SSIM: 0.8512
Teacher Val SSIM did not improve. Epochs since last improvement: 1/5


Teacher Epoch 20/50:   0%|          | 0/1633 [00:00<?, ?batch/s]

Teacher Epoch 20 - Avg Train Loss: 13.8364 | Val PSNR: 29.95 dB | Val SSIM: 0.8591
Teacher Val SSIM did not improve. Epochs since last improvement: 2/5


Teacher Epoch 21/50:   0%|          | 0/1633 [00:00<?, ?batch/s]

Teacher Epoch 21 - Avg Train Loss: 13.7617 | Val PSNR: 30.27 dB | Val SSIM: 0.8648
Saved Teacher checkpoint at epoch 21 (best Val SSIM: 0.8648)


Teacher Epoch 22/50:   0%|          | 0/1633 [00:00<?, ?batch/s]

Teacher Epoch 22 - Avg Train Loss: 13.5709 | Val PSNR: 30.14 dB | Val SSIM: 0.8588
Teacher Val SSIM did not improve. Epochs since last improvement: 1/5


Teacher Epoch 23/50:   0%|          | 0/1633 [00:00<?, ?batch/s]

Teacher Epoch 23 - Avg Train Loss: 13.5081 | Val PSNR: 29.91 dB | Val SSIM: 0.8412
Teacher Val SSIM did not improve. Epochs since last improvement: 2/5


Teacher Epoch 24/50:   0%|          | 0/1633 [00:00<?, ?batch/s]

Teacher Epoch 24 - Avg Train Loss: 13.3304 | Val PSNR: 30.03 dB | Val SSIM: 0.8625
Teacher Val SSIM did not improve. Epochs since last improvement: 3/5


Teacher Epoch 25/50:   0%|          | 0/1633 [00:00<?, ?batch/s]

Teacher Epoch 25 - Avg Train Loss: 13.2363 | Val PSNR: 29.85 dB | Val SSIM: 0.8246
Teacher Val SSIM did not improve. Epochs since last improvement: 4/5


Teacher Epoch 26/50:   0%|          | 0/1633 [00:00<?, ?batch/s]

Teacher Epoch 26 - Avg Train Loss: 13.1242 | Val PSNR: 30.16 dB | Val SSIM: 0.8640
Teacher Val SSIM did not improve. Epochs since last improvement: 5/5
[93mEarly stopping Teacher training: No improvement in Val SSIM for 5 epochs.[0m
Restored best Teacher model from ./checkpoints/teacher_ckpt/ckpt-8

--- Training Student Model with Knowledge Distillation ---


Student Epoch 1/60:   0%|          | 0/1633 [00:00<?, ?batch/s]

Student Epoch 1 - Avg Train Loss: 18.6093 | Val PSNR: 28.11 dB | Val SSIM: 0.8256
Saved Student checkpoint at epoch 1 (best Val SSIM: 0.8256)


Student Epoch 2/60:   0%|          | 0/1633 [00:00<?, ?batch/s]

Student Epoch 2 - Avg Train Loss: 15.6110 | Val PSNR: 28.51 dB | Val SSIM: 0.8201
Student Val SSIM did not improve. Epochs since last improvement: 1/5


Student Epoch 3/60:   0%|          | 0/1633 [00:00<?, ?batch/s]

Student Epoch 3 - Avg Train Loss: 14.7097 | Val PSNR: 28.55 dB | Val SSIM: 0.8360
Saved Student checkpoint at epoch 3 (best Val SSIM: 0.8360)


Student Epoch 4/60:   0%|          | 0/1633 [00:00<?, ?batch/s]

Student Epoch 4 - Avg Train Loss: 14.2661 | Val PSNR: 28.30 dB | Val SSIM: 0.8303
Student Val SSIM did not improve. Epochs since last improvement: 1/5


Student Epoch 5/60:   0%|          | 0/1633 [00:00<?, ?batch/s]

Student Epoch 5 - Avg Train Loss: 13.9735 | Val PSNR: 28.40 dB | Val SSIM: 0.8411
Saved Student checkpoint at epoch 5 (best Val SSIM: 0.8411)


Student Epoch 6/60:   0%|          | 0/1633 [00:00<?, ?batch/s]

Student Epoch 6 - Avg Train Loss: 13.6394 | Val PSNR: 28.75 dB | Val SSIM: 0.8444
Saved Student checkpoint at epoch 6 (best Val SSIM: 0.8444)


Student Epoch 7/60:   0%|          | 0/1633 [00:00<?, ?batch/s]

Student Epoch 7 - Avg Train Loss: 13.3298 | Val PSNR: 28.85 dB | Val SSIM: 0.8407
Student Val SSIM did not improve. Epochs since last improvement: 1/5


Student Epoch 8/60:   0%|          | 0/1633 [00:00<?, ?batch/s]

Student Epoch 8 - Avg Train Loss: 13.0877 | Val PSNR: 28.31 dB | Val SSIM: 0.7995
Student Val SSIM did not improve. Epochs since last improvement: 2/5


Student Epoch 9/60:   0%|          | 0/1633 [00:00<?, ?batch/s]

Student Epoch 9 - Avg Train Loss: 12.8404 | Val PSNR: 29.08 dB | Val SSIM: 0.8328
Student Val SSIM did not improve. Epochs since last improvement: 3/5


Student Epoch 10/60:   0%|          | 0/1633 [00:00<?, ?batch/s]

Student Epoch 10 - Avg Train Loss: 12.5584 | Val PSNR: 29.33 dB | Val SSIM: 0.8424
Student Val SSIM did not improve. Epochs since last improvement: 4/5


Student Epoch 11/60:   0%|          | 0/1633 [00:00<?, ?batch/s]

Student Epoch 11 - Avg Train Loss: 12.2117 | Val PSNR: 29.24 dB | Val SSIM: 0.8383
Student Val SSIM did not improve. Epochs since last improvement: 5/5
[93mEarly stopping Student training: No improvement in Val SSIM for 5 epochs.[0m

--- Final Evaluation on Test Set ---
Restored best Student model from ./checkpoints/student_ckpt/ckpt-4

Final Student Model Performance on Test Set:
PSNR: 27.75 dB
SSIM: 0.8362
[91mFinal Student Model DID NOT meet target metrics (Target PSNR: 25.0 dB, SSIM: 0.9).[0m
