# Blind Deconvolution with FunSearch

This notebook demonstrates using FunSearch to evolve stopping criteria for the Lucy-Richardson blind deconvolution algorithm.

## Problem Description

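Blind deconvolution aims to recover a sharp image from a blurred observation when the blur kernel (point spread function, PSF) is unknown or only imperfectly known. The Lucy-Richardson algorithm is an iterative restoration method, but choosing when to stop iterating is crucial: too few iterations leave the image blurry, while too many waste computation and can amplify noise. In this notebook the PSF used to blur each test image is passed directly to the Lucy-Richardson update, so the experiments isolate the stopping criterion rather than PSF estimation.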

**FunSearch Goal**: Discover optimal stopping criteria that maximize image quality while minimizing unnecessary iterations.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy import signal
from typing import Dict, Tuple, Callable

# If running in Colab, install required packages
try:
    import google.colab
    !pip install scipy matplotlib numpy
except ImportError:
    pass

## Core Algorithm Implementation

The following code implements the Lucy-Richardson deconvolution algorithm and supporting utilities. These functions form the fixed part of the program skeleton; FunSearch evolves only the `should_stop` function defined in the specification section below.

In [None]:
def generate_test_image(size: int = 64) -> np.ndarray:
    """Build a synthetic square test image with smooth and point-like features.

    The image is the sum of a centered Gaussian blob, two sinusoidal ripple
    bands (one oscillating along x, one along y), and two bright single
    pixels, clipped to the range [0, 1].

    Args:
        size: Side length of the square image in pixels.

    Returns:
        A ``(size, size)`` float array with values in [0, 1].
    """
    axis = np.linspace(-1, 1, size)
    x, y = np.meshgrid(axis, axis)

    # Smooth components defined on the [-1, 1] x [-1, 1] grid.
    blob = 0.5 * np.exp(-(x**2 + y**2) / 0.2)
    ripple_along_x = 0.3 * np.sin(8 * np.pi * x) * np.exp(-y**2 / 0.1)
    ripple_along_y = 0.2 * np.sin(8 * np.pi * y) * np.exp(-x**2 / 0.1)
    image = blob + ripple_along_x + ripple_along_y

    # Two isolated bright pixels on the main diagonal.
    quarter = size // 4
    three_quarters = 3 * size // 4
    image[quarter, quarter] += 0.8
    image[three_quarters, three_quarters] += 0.6

    return np.clip(image, 0, 1)


def generate_blur_kernel(kernel_type: str, size: int = 15) -> np.ndarray:
    """Generate a normalized square blur kernel (point spread function).

    Args:
        kernel_type: ``"gaussian"`` for an isotropic Gaussian centered on the
            kernel with standard deviation ``size / 6``, or ``"motion"`` for a
            horizontal line segment of ``size // 2`` pixels (at least one) on
            the center row.
        size: Side length of the kernel in pixels; must be at least 1.

    Returns:
        A ``(size, size)`` non-negative float array whose entries sum to 1.

    Raises:
        ValueError: If ``size`` is smaller than 1 or ``kernel_type`` is not
            one of the supported names.
    """
    if size < 1:
        raise ValueError(f"size must be at least 1, got {size}")

    center = size // 2

    if kernel_type == "gaussian":
        sigma = size / 6
        # Squared distance of every cell from the center, built by broadcasting
        # a column of row offsets against a row of column offsets.
        offsets = np.arange(size) - center
        squared_distance = offsets[:, np.newaxis] ** 2 + offsets[np.newaxis, :] ** 2
        kernel = np.exp(-squared_distance / (2 * sigma**2))
    elif kernel_type == "motion":
        # Clamp to one pixel so a size-1 kernel is a delta rather than an
        # all-zero array that would normalize to NaN.
        length = max(1, size // 2)
        start = center - length // 2
        kernel = np.zeros((size, size))
        kernel[center, start:start + length] = 1
    else:
        raise ValueError(
            f"Unknown kernel_type {kernel_type!r}; expected 'gaussian' or 'motion'"
        )

    return kernel / np.sum(kernel)


def lucy_richardson_iteration(current_estimate: np.ndarray, 
                            observed_image: np.ndarray,
                            psf: np.ndarray) -> np.ndarray:
    """Apply one multiplicative Lucy-Richardson update to the estimate.

    Args:
        current_estimate: Current estimate of the sharp image.
        observed_image: The blurred observation being deconvolved.
        psf: 2-D point spread function used for the re-blurring step.

    Returns:
        The updated estimate, ``current_estimate`` multiplied element-wise by
        the back-projected ratio of observed to re-blurred image.
    """
    conv_options = {'mode': 'same', 'boundary': 'symm'}

    # Re-blur the estimate; floor the result so the division below is safe.
    reblurred = np.maximum(
        signal.convolve2d(current_estimate, psf, **conv_options), 1e-10
    )

    # Back-project the ratio with the PSF mirrored along both axes.
    mirrored_psf = psf[::-1, ::-1]
    correction = signal.convolve2d(
        observed_image / reblurred, mirrored_psf, **conv_options
    )
    return current_estimate * correction


def compute_psnr(image1: np.ndarray, image2: np.ndarray) -> float:
    """Return the Peak Signal-to-Noise Ratio (dB) between two images.

    The peak value is taken to be 1.0. Identical images yield infinity.
    """
    mean_squared_error = np.mean(np.square(image1 - image2))
    if mean_squared_error == 0:
        return float('inf')
    root_mse = np.sqrt(mean_squared_error)
    return 20 * np.log10(1.0 / root_mse)


def compute_image_gradient_norm(image: np.ndarray) -> float:
    """Return the root-mean-square magnitude of the image gradient.

    Gradients are taken with ``np.gradient`` along axis 1 (horizontal) and
    axis 0 (vertical); the squared magnitudes are averaged over all pixels
    before the square root.
    """
    horizontal = np.gradient(image, axis=1)
    vertical = np.gradient(image, axis=0)
    squared_magnitude = horizontal**2 + vertical**2
    return np.sqrt(np.mean(squared_magnitude))


def compute_residual_norm(current: np.ndarray, previous: np.ndarray) -> float:
    """Return the mean absolute change between iterates, relative to ``current``.

    The mean absolute difference is divided by the mean absolute value of
    ``current``; a small constant in the denominator guards against division
    by zero.
    """
    mean_abs_change = np.mean(np.abs(current - previous))
    scale = np.mean(np.abs(current)) + 1e-10
    return mean_abs_change / scale

## FunSearch Specification

The commented-out decorators are just a way to indicate the main entry point of the program (`@funsearch.run`) and the function that *FunSearch* should evolve (`@funsearch.evolve`).

In [None]:
# @funsearch.run
def evaluate(test_cases: Dict[str, Dict]) -> float:
    """Score the stopping criteria, averaged over all test cases.

    Args:
        test_cases: Mapping from case name to a dict holding the
            'ground_truth' image, the 'observed' (blurred) image and the
            'psf' blur kernel.

    Returns:
        Mean per-case score combining restoration quality, iteration
        efficiency and a stability penalty; 0.0 when there are no cases.
    """
    per_case_scores = []

    for case in test_cases.values():
        truth = case['ground_truth']

        # Deconvolve with the evolved stopping rule, capped at 100 iterations.
        restored, n_iterations, _ = run_lucy_richardson_with_stopping(
            case['observed'], case['psf'], truth, max_iterations=100
        )

        # Quality: PSNR normalized so that 30 dB or more saturates at 1.0.
        quality = min(compute_psnr(truth, restored) / 30.0, 1.0)

        # Efficiency: fraction of the 100-iteration budget left unused.
        efficiency = max(0, 1.0 - n_iterations / 100.0)

        # Stability: flat penalty outside the 5..80 iteration window.
        penalty = 0.0 if 5 <= n_iterations <= 80 else 0.1

        combined = 0.6 * quality + 0.3 * efficiency - penalty
        per_case_scores.append(max(combined, 0.0))

    if not per_case_scores:
        return 0.0
    return sum(per_case_scores, 0.0) / len(per_case_scores)


def run_lucy_richardson_with_stopping(
    observed_image: np.ndarray,
    psf: np.ndarray, 
    ground_truth: np.ndarray,
    max_iterations: int = 100
) -> Tuple[np.ndarray, int, float]:
    """Iterate Lucy-Richardson until ``should_stop`` fires or the cap is hit.

    Args:
        observed_image: Blurred observation; also used as the initial estimate.
        psf: Point spread function passed to each iteration.
        ground_truth: Reference image, used only to compute the reported PSNR.
        max_iterations: Upper bound on the number of iterations.

    Returns:
        Tuple of (final estimate, number of iterations performed, PSNR of the
        final estimate against ``ground_truth``).
    """
    estimate = observed_image.copy()
    iterations_done = max_iterations  # value reported if the rule never fires

    for step in range(max_iterations):
        last_estimate = estimate.copy()
        estimate = lucy_richardson_iteration(estimate, observed_image, psf)

        # Ask the evolved stopping rule; `step` is 0-indexed, so the number
        # of completed iterations is step + 1.
        if should_stop(estimate, last_estimate, step, psf, observed_image):
            iterations_done = step + 1
            break

    return estimate, iterations_done, compute_psnr(ground_truth, estimate)


# @funsearch.evolve
def should_stop(current_image: np.ndarray, 
               previous_image: np.ndarray, 
               iteration: int,
               blur_kernel: np.ndarray, 
               observed_image: np.ndarray) -> bool:
    """Decide whether Lucy-Richardson deconvolution should stop now.

    Args:
        current_image: Current estimate of the deconvolved image
        previous_image: Estimate from the previous iteration
        iteration: Current iteration number (0-indexed)
        blur_kernel: The point spread function used for deconvolution
        observed_image: The original blurred/observed image

    Returns:
        True to stop iterating, False to continue.
    """
    # Baseline rule: ignore the images entirely and stop once the 0-indexed
    # iteration counter reaches a fixed threshold.
    fixed_iteration_threshold = 50
    return iteration >= fixed_iteration_threshold

## Test and Demonstration

Let's create test cases and visualize the deconvolution process.

In [None]:
# Generate test cases
def create_test_cases() -> Dict[str, Dict]:
    """Build the Gaussian-blur and motion-blur deconvolution test cases.

    Each case blurs a 64x64 synthetic image with a 15x15 kernel, adds
    zero-mean Gaussian noise (standard deviation 0.01) and clips the result
    to [0, 1].

    Returns:
        Mapping from case name to a dict with keys 'ground_truth',
        'observed' and 'psf'.
    """
    # (case name, kernel type) pairs, processed in this order so the random
    # noise draws happen in the same sequence for every run of the function.
    scenarios = (
        ('gaussian_blur', 'gaussian'),
        ('motion_blur', 'motion'),
    )

    test_cases = {}
    for case_name, kernel_type in scenarios:
        sharp = generate_test_image(64)
        kernel = generate_blur_kernel(kernel_type, 15)
        blurred = signal.convolve2d(sharp, kernel, mode='same', boundary='symm')
        noise = np.random.normal(0, 0.01, blurred.shape)

        test_cases[case_name] = {
            'ground_truth': sharp,
            'observed': np.clip(blurred + noise, 0, 1),
            'psf': kernel,
        }

    return test_cases


# Build the test cases once at module level; `test_cases` is reused by the
# visualization cell below.
# NOTE(review): the observed images include random noise and no seed is set in
# the visible code, so the score can vary between runs -- confirm if intended.
test_cases = create_test_cases()

# Score the baseline (fixed-iteration) stopping criteria. `score` is printed
# here and again at the end of the visualization cell.
score = evaluate(test_cases)
print(f"Current score with trivial stopping criteria: {score:.4f}")

## Visualization

In [None]:
# Visualize the Gaussian-blur test case
case_data = test_cases['gaussian_blur']
ground_truth, observed, psf = (
    case_data[key] for key in ('ground_truth', 'observed', 'psf')
)

# Deconvolve using the current stopping criteria
deconvolved, iterations_used, final_psnr = run_lucy_richardson_with_stopping(
    observed, psf, ground_truth, max_iterations=100
)

# Absolute per-pixel error of the restored image
residual = np.abs(deconvolved - ground_truth)

fig, axes = plt.subplots(2, 3, figsize=(15, 10))

# Image panels: (axis, pixel data, colormap, title)
image_panels = [
    (axes[0, 0], ground_truth, 'gray', 'Ground Truth'),
    (axes[0, 1], observed, 'gray', 'Blurred + Noise'),
    (axes[0, 2], deconvolved, 'gray',
     f'Deconvolved\n({iterations_used} iterations)'),
    (axes[1, 0], psf, 'gray', 'Blur Kernel'),
    (axes[1, 1], residual, 'hot', f'Error Map\nPSNR: {final_psnr:.2f}dB'),
]
for panel, pixels, colormap, panel_title in image_panels:
    panel.imshow(pixels, cmap=colormap)
    panel.set_title(panel_title)
    panel.axis('off')

# Re-run the iterations without the stopping rule to trace PSNR, going up to
# 10 steps past the stopping point (capped at 60 steps in total).
psnr_history = []
current_est = observed.copy()
history_length = min(iterations_used + 10, 60)
for _step in range(history_length):
    current_est = lucy_richardson_iteration(current_est, observed, psf)
    psnr_history.append(compute_psnr(ground_truth, current_est))

# Convergence panel; history index k holds the PSNR after k + 1 iterations,
# so the stopping point sits at index iterations_used - 1.
convergence_ax = axes[1, 2]
convergence_ax.plot(psnr_history)
convergence_ax.axvline(
    x=iterations_used-1, color='red', linestyle='--', label='Stopping point'
)
convergence_ax.set_xlabel('Iteration')
convergence_ax.set_ylabel('PSNR (dB)')
convergence_ax.set_title('Convergence')
convergence_ax.legend()
convergence_ax.grid(True)

plt.tight_layout()
plt.show()

print(f"Final PSNR: {final_psnr:.2f} dB")
print(f"Iterations used: {iterations_used}")
print(f"Overall evaluation score: {score:.4f}")

## How FunSearch Will Improve This

FunSearch will evolve the `should_stop` function to discover better stopping criteria. Some potential discoveries might include:

1. **Gradient-based stopping**: Stop when image gradients stabilize
2. **Residual analysis**: Stop when iteration-to-iteration changes fall below adaptive thresholds
3. **Multi-scale convergence**: Different criteria for different frequency components
4. **Blur-kernel-aware stopping**: Adapt criteria based on estimated PSF properties
5. **Quality-efficiency trade-offs**: Dynamic balancing based on iteration count and improvement rate

The evolved functions should significantly outperform the trivial fixed-iteration stopping criterion.