In [1]:
import os
import glob
import logging
import random

# Suppress TensorFlow's native (C++) logging except for errors.
# NOTE: this variable has to be in the environment *before* TensorFlow is
# imported; setting it after the import (as this script used to do) is not
# reliably honoured, which is why it sits between the import groups.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import numpy as np
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt

# Silence the Python-side TensorFlow logger as well (errors only).
tf.get_logger().setLevel('ERROR')

# Set GPU memory growth if available so TensorFlow allocates GPU memory on
# demand instead of reserving it all up front.
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.list_logical_devices('GPU')
        print(f"{len(gpus)} Physical GPUs, {len(logical_gpus)} Logical GPUs")
    except RuntimeError as e:
        # Raised by TensorFlow when memory growth is configured too late
        # (after the GPUs have been initialised); report and carry on.
        print(e)

# --- Configuration (Must match training script) ---
IMG_HEIGHT, IMG_WIDTH = 256, 256
CHANNELS = 3
BATCH_SIZE = 1  # One image per batch keeps the visualization loop simple
BUFFER_SIZE = tf.data.AUTOTUNE  # Passed to Dataset.prefetch()

# Root directory under which the training run's output is available.
# With the value below the checkpoints are looked up in
# '/kaggle/input/student/checkpoints'; adjust it if the training output is
# mounted somewhere else.
BASE_OUTPUT_PATH = '/kaggle/input/'

# Directory that *contains* the 'student_ckpt' and 'teacher_ckpt' folders.
CHECKPOINT_ROOT_DIR = os.path.join(BASE_OUTPUT_PATH, 'student', 'checkpoints')
MODEL_TO_LOAD = 'student'  # 'teacher' or 'student'

# Input dataset used for visualization. Images are read from
# <path>/<name>/<type>/blur and <path>/<name>/<type>/sharp.
VISUALIZATION_DATASET_PATH = '/kaggle/input/a-curated-list-of-image-deblurring-datasets/DBlur/'
VISUALIZATION_DATASET_NAME = 'TextOCR'  # Or any other dataset you want to visualize from
VISUALIZATION_DATASET_TYPE = 'test'  # 'test' or 'validation'

# Output directory for the saved figures, relative to the current working
# directory (e.g. /kaggle/working/visualizations). Created eagerly.
VISUALIZATION_OUTPUT_DIR = './visualizations'
os.makedirs(VISUALIZATION_OUTPUT_DIR, exist_ok=True)

# --- Data Loading and Preprocessing (Copied from training script) ---

def _path_to_str(image_path):
    """Return *image_path* as text suitable for a log message.

    Accepts a plain ``str``, ``bytes``, or any tensor-like object exposing
    ``.numpy()`` (which is what ``tf.py_function`` hands to its callback).
    """
    if hasattr(image_path, 'numpy'):
        image_path = image_path.numpy()
    if isinstance(image_path, bytes):
        return image_path.decode('utf-8', errors='replace')
    return str(image_path)


def _load_single_image_py(image_path):
    """
    Load, decode, resize and normalize a single image.

    Args:
        image_path: Path of the image file, as a string tensor, ``bytes`` or
            ``str``.

    Returns:
        A float32 tensor of shape (IMG_HEIGHT, IMG_WIDTH, CHANNELS) scaled by
        1/255, or None when the file cannot be read/decoded or has an
        unsupported channel count. Failures are logged, never raised.
    """
    try:
        img_bytes = tf.io.read_file(image_path).numpy()
        img = tf.image.decode_image(img_bytes, channels=CHANNELS, expand_animations=False)
        if img is None or img.shape == (0, 0, 0):  # Check for empty/malformed tensors
            logging.warning("Skipping malformed or empty image: %s", _path_to_str(image_path))
            return None
        # Ensure image has CHANNELS channels even if decoded with fewer (e.g., grayscale)
        if img.shape[-1] != CHANNELS:
            logging.warning(
                "Image has %s channels, expected %s: %s. Converting to RGB.",
                img.shape[-1], CHANNELS, _path_to_str(image_path),
            )
            if img.shape[-1] == 1:
                img = tf.image.grayscale_to_rgb(img)
            else:
                return None  # Too complex to auto-handle all cases, better to skip.
        img = tf.image.resize(img, [IMG_HEIGHT, IMG_WIDTH], method=tf.image.ResizeMethod.BICUBIC)
        img = tf.cast(img, tf.float32) / 255.0  # Normalize to [0, 1]
        return img
    except Exception as e:
        # Deliberate best-effort: any failure means "skip this image". The
        # path is formatted via _path_to_str because the argument is usually
        # a tensor, which has no .decode(); calling .decode() here used to
        # raise from inside this handler and abort the whole pipeline.
        logging.warning("Error loading image %s: %s. Skipping.", _path_to_str(image_path), e)
        return None

def _tf_py_function_wrapper(blur_path, sharp_path):
    """
    Load one blur/sharp image pair through tf.py_function.

    Args:
        blur_path: Path of the blurred image.
        sharp_path: Path of the matching sharp image.

    Returns:
        (blurred_image, sharp_image, is_valid_pair). When either image fails
        to load, both images are all-zero placeholders and is_valid_pair is
        False so the caller can filter the pair out.
    """
    def load_and_validate(b_path, s_path):
        blur_img = _load_single_image_py(b_path)
        sharp_img = _load_single_image_py(s_path)
        pair_ok = not (blur_img is None or sharp_img is None)
        if not pair_ok:
            # Placeholders keep the output shape fixed for invalid pairs.
            placeholder_shape = (IMG_HEIGHT, IMG_WIDTH, CHANNELS)
            blur_img = tf.zeros(placeholder_shape, dtype=tf.float32)
            sharp_img = tf.zeros(placeholder_shape, dtype=tf.float32)
        return blur_img, sharp_img, tf.constant(pair_ok, dtype=tf.bool)

    outputs = tf.py_function(
        load_and_validate,
        [blur_path, sharp_path],
        [tf.float32, tf.float32, tf.bool]
    )
    blur_img, sharp_img, is_valid = outputs

    # The shapes are set explicitly on the py_function outputs so later
    # pipeline stages see fixed image shapes and a scalar validity flag.
    image_shape = [IMG_HEIGHT, IMG_WIDTH, CHANNELS]
    blur_img.set_shape(image_shape)
    sharp_img.set_shape(image_shape)
    is_valid.set_shape([])

    return blur_img, sharp_img, is_valid

def create_image_dataset_for_viz(dataset_type, batch_size, dataset_name, base_path=VISUALIZATION_DATASET_PATH):
    """
    Build the tf.data pipeline of (blur, sharp) image batches for visualization.

    Blur and sharp images are paired by identical file name; no augmentation
    is applied.

    Args:
        dataset_type: Split sub-directory, e.g. 'test' or 'validation'.
        batch_size: Number of image pairs per batch.
        dataset_name: Dataset sub-directory under base_path.
        base_path: Root directory holding the datasets.

    Returns:
        (dataset, num_pairs), where num_pairs counts the file-name matches
        found on disk (before unreadable images are filtered out), or
        (None, 0) when the directories are missing or no pair matches.
    """
    split_root = os.path.join(base_path, dataset_name, dataset_type)
    blur_dir = os.path.join(split_root, 'blur')
    sharp_dir = os.path.join(split_root, 'sharp')

    if not (os.path.exists(blur_dir) and os.path.exists(sharp_dir)):
        logging.error(f"Visualization directories not found: {blur_dir} or {sharp_dir}")
        return None, 0

    blur_candidates = sorted(glob.glob(os.path.join(blur_dir, '*.*')))
    sharp_candidates = sorted(glob.glob(os.path.join(sharp_dir, '*.*')))
    sharp_by_name = {os.path.basename(path): path for path in sharp_candidates}

    # Keep only blur images that have a sharp counterpart with the same name.
    pairs = []
    for blur_path in blur_candidates:
        sharp_path = sharp_by_name.get(os.path.basename(blur_path))
        if sharp_path is None:
            logging.warning(f"No matching sharp image found for {blur_path}. Skipping for visualization.")
            continue
        pairs.append((blur_path, sharp_path))

    if not pairs:
        logging.error(f"No valid image pairs found for visualization in {dataset_name}/{dataset_type}.")
        return None, 0

    matched_blur_paths = [blur for blur, _ in pairs]
    matched_sharp_paths = [sharp for _, sharp in pairs]

    pipeline = tf.data.Dataset.from_tensor_slices((matched_blur_paths, matched_sharp_paths))
    pipeline = pipeline.map(_tf_py_function_wrapper, num_parallel_calls=tf.data.AUTOTUNE)
    # Drop pairs flagged invalid by the loader, then strip the flag.
    pipeline = pipeline.filter(lambda blur, sharp, ok: ok)
    pipeline = pipeline.map(lambda blur, sharp, ok: (blur, sharp), num_parallel_calls=tf.data.AUTOTUNE)
    pipeline = pipeline.batch(batch_size).prefetch(buffer_size=BUFFER_SIZE)

    return pipeline, len(matched_blur_paths)

# --- Model Architecture (Copied from training script) ---
# It's crucial that the model architecture is identical to how it was saved.
def conv_block(inputs, filters, kernel_size=(3, 3), strides=(1, 1), padding='same', activation='relu', use_bn=True):
    """
    Apply Conv2D -> (optional) BatchNormalization -> Activation to `inputs`.

    The convolution bias is disabled whenever batch normalization is used.
    An Activation layer is always added, also when `activation` is None.
    """
    layers = tf.keras.layers
    convolution = layers.Conv2D(
        filters,
        kernel_size,
        strides=strides,
        padding=padding,
        use_bias=not use_bn,
    )
    features = convolution(inputs)
    if use_bn:
        features = layers.BatchNormalization()(features)
    return layers.Activation(activation)(features)

def residual_block(inputs, filters, activation='relu'):
    """
    Two conv blocks with a skip connection from `inputs`, then an activation.

    The second conv block is built with activation=None so the non-linearity
    is applied only after the residual sum. `inputs` is added to a tensor
    with `filters` channels, so it must be shape-compatible with that.
    """
    layers = tf.keras.layers
    branch = conv_block(inputs, filters, activation=activation)
    branch = conv_block(branch, filters, activation=None)
    merged = layers.Add()([inputs, branch])
    return layers.Activation(activation)(merged)

def build_enhanced_unet(input_shape=(IMG_HEIGHT, IMG_WIDTH, CHANNELS), base_filters=64):
    """
    Build the U-Net style deblurring model.

    Four encoder stages (conv block + residual block + 2x2 max-pool) with
    base_filters * 1/2/4/8 filters, a bottleneck with base_filters * 16, and
    four decoder stages (2x upsampling + skip concatenation + conv block +
    residual block) with base_filters * 8/4/2/1. A final 1x1 sigmoid
    convolution produces CHANNELS output channels in [0, 1].

    Layers are created in the same order as the training-time definition so
    that restored checkpoints line up with this model.

    Args:
        input_shape: Shape of one input image (height, width, channels).
        base_filters: Filter count of the first encoder stage.

    Returns:
        An uncompiled tf.keras.Model.
    """
    layers = tf.keras.layers
    inputs = layers.Input(shape=input_shape)

    # Encoder: each stage halves the spatial resolution and remembers its
    # pre-pooling features for the matching decoder stage.
    x = inputs
    skip_features = []
    for multiplier in (1, 2, 4, 8):
        stage_filters = base_filters * multiplier
        x = conv_block(x, stage_filters)
        x = residual_block(x, stage_filters)
        skip_features.append(x)
        x = layers.MaxPooling2D(pool_size=(2, 2))(x)

    # Bottleneck
    x = conv_block(x, base_filters * 16)
    x = residual_block(x, base_filters * 16)

    # Decoder: each stage doubles the resolution and concatenates the
    # encoder features of the same resolution (deepest skip first).
    for multiplier, skip in zip((8, 4, 2, 1), reversed(skip_features)):
        stage_filters = base_filters * multiplier
        x = layers.UpSampling2D(size=(2, 2))(x)
        x = layers.Concatenate()([x, skip])
        x = conv_block(x, stage_filters)
        x = residual_block(x, stage_filters)

    # Final output layer uses sigmoid for [0, 1] range
    output = layers.Conv2D(CHANNELS, (1, 1), activation='sigmoid', padding='same')(x)

    return tf.keras.Model(inputs=inputs, outputs=output)

# --- Visualization Function ---
def visualize_deblurring(model, dataset, output_dir, num_examples=5, filename_prefix="deblur_viz"):
    """
    Generates and saves visualizations of deblurring results.

    For each example a figure with three panels (blurred input, model
    output, ground truth) is written to
    ``<output_dir>/<filename_prefix>_example_NNN.png``.

    Args:
        model: Callable model; invoked as ``model(batch, training=False)``.
        dataset: Iterable of (blur_batch, sharp_batch) tensor pairs.
        output_dir: Directory for the PNG files; created if missing.
        num_examples: Maximum number of figures to generate.
        filename_prefix: Prefix of the generated file names.

    Returns:
        None. Progress and a final summary are printed.
    """
    os.makedirs(output_dir, exist_ok=True)

    example_count = 0
    saw_batch = False  # Distinguishes "dataset is empty" from "nothing requested"

    # Emptiness is detected by the loop itself. The previous up-front probe
    # caught tf.errors.OutOfRangeError, but Python iteration over an exhausted
    # dataset raises StopIteration, so an empty dataset crashed instead.
    for blur_img_batch, sharp_img_batch in dataset:
        saw_batch = True
        if example_count >= num_examples:
            break

        deblurred_img_batch = model(blur_img_batch, training=False)
        panels = (
            ('Blurred Input', blur_img_batch),
            (f'{MODEL_TO_LOAD.capitalize()} Deblurred', deblurred_img_batch),
            ('Ground Truth', sharp_img_batch),
        )

        for i in range(blur_img_batch.shape[0]):
            if example_count >= num_examples:
                break

            fig, axes = plt.subplots(1, 3, figsize=(15, 5))
            try:
                for ax, (title, batch) in zip(axes, panels):
                    # Clip for display only: imshow expects float RGB data in
                    # [0, 1] and would otherwise clip it with a warning.
                    ax.imshow(np.clip(batch[i].numpy(), 0.0, 1.0))
                    ax.set_title(title)
                    ax.axis('off')

                fig.suptitle(f"Deblurring Visualization - Example {example_count + 1}", fontsize=16)
                fig.savefig(os.path.join(output_dir, f'{filename_prefix}_example_{example_count:03d}.png'))
            finally:
                # Always release the figure, even if drawing or saving fails.
                plt.close(fig)

            example_count += 1
            print(f"Generated visualization {example_count}/{num_examples}")

    if not saw_batch:
        print("Error: Visualization dataset is empty. Cannot generate images.")
        return

    if example_count == 0:
        print("No images were generated. Check dataset or model output.")
    else:
        print(f"\nSuccessfully generated {example_count} visualizations in '{output_dir}'")

# --- Main Visualization Script ---
# Base filter counts of the two networks. These must match the values used
# during training. They live at module level (not under the __main__ guard)
# so run_visualization() also works when this file is imported.
TEACHER_FILTERS = 64
STUDENT_FILTERS = 48


def run_visualization():
    """
    Build the visualization dataset, restore the selected model and save figures.

    Reads the module-level configuration (MODEL_TO_LOAD, CHECKPOINT_ROOT_DIR,
    VISUALIZATION_* constants). Returns early, after printing a message, when
    no image pairs are found, MODEL_TO_LOAD is not 'teacher' or 'student', or
    no checkpoint exists.
    """
    print(f"--- Preparing Visualization Dataset from {VISUALIZATION_DATASET_NAME}/{VISUALIZATION_DATASET_TYPE} ---")
    viz_dataset, num_viz_elements = create_image_dataset_for_viz(
        VISUALIZATION_DATASET_TYPE, BATCH_SIZE, VISUALIZATION_DATASET_NAME
    )

    if viz_dataset is None or num_viz_elements == 0:
        print("Exiting: No valid images found for visualization.")
        return

    print(f"Found {num_viz_elements} potential images for visualization.")

    print(f"\n--- Loading {MODEL_TO_LOAD.capitalize()} Model ---")
    # model type -> (base filter count, checkpoint sub-directory)
    model_specs = {
        'student': (STUDENT_FILTERS, "student_ckpt"),
        'teacher': (TEACHER_FILTERS, "teacher_ckpt"),
    }
    if MODEL_TO_LOAD not in model_specs:
        print("Invalid model type specified. Choose 'teacher' or 'student'.")
        return

    base_filters, ckpt_subdir = model_specs[MODEL_TO_LOAD]
    model = build_enhanced_unet(base_filters=base_filters)
    # CheckpointManager takes the *directory* holding the checkpoint files
    # (ckpt-N.index, ckpt-N.data-..., 'checkpoint'), not a file prefix.
    checkpoint_dir = os.path.join(CHECKPOINT_ROOT_DIR, ckpt_subdir)

    checkpoint = tf.train.Checkpoint(model=model)
    manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=5)

    if not manager.latest_checkpoint:
        print(f"Error: No checkpoint found for {MODEL_TO_LOAD.capitalize()} at {checkpoint_dir}. "
              "Please ensure your training script saved checkpoints correctly and this path is accurate.")
        return

    status = checkpoint.restore(manager.latest_checkpoint)
    # Entries in the checkpoint that this bare model does not use are fine...
    status.expect_partial()
    try:
        # ...but model variables with no value in the checkpoint are not:
        # they would silently keep their random initialisation.
        status.assert_existing_objects_matched()
    except AssertionError as err:
        print(f"Warning: not every {MODEL_TO_LOAD.capitalize()} model variable was restored from "
              f"{manager.latest_checkpoint}; results may come from uninitialised weights. Details: {err}")
    else:
        print(f"Successfully loaded {MODEL_TO_LOAD.capitalize()} model from {manager.latest_checkpoint}")

    print(f"\n--- Generating Visualizations using the {MODEL_TO_LOAD.capitalize()} Model ---")
    visualize_deblurring(model, viz_dataset, VISUALIZATION_OUTPUT_DIR, num_examples=10)  # Generate 10 examples


if __name__ == '__main__':
    run_visualization()

2025-07-11 15:09:30.528705: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1752246570.740565      19 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1752246570.801022      19 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


1 Physical GPUs, 1 Logical GPUs
--- Preparing Visualization Dataset from TextOCR/test ---


I0000 00:00:1752246585.846347      19 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 15513 MB memory:  -> device: 0, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:04.0, compute capability: 6.0


Found 500 potential images for visualization.

--- Loading Student Model ---
Successfully loaded Student model from /kaggle/input/student/checkpoints/student_ckpt/ckpt-4

--- Generating Visualizations using the Student Model ---


I0000 00:00:1752246590.104338      19 cuda_dnn.cc:529] Loaded cuDNN version 90300


Generated visualization 1/10
Generated visualization 2/10
Generated visualization 3/10
Generated visualization 4/10
Generated visualization 5/10
Generated visualization 6/10
Generated visualization 7/10
Generated visualization 8/10
Generated visualization 9/10
Generated visualization 10/10

Successfully generated 10 visualizations in './visualizations'
