In [1]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

import os
import numpy as np
from skimage import io, filters, exposure, morphology, measure, segmentation, img_as_float, img_as_uint, transform
from scipy import ndimage
import glob
import tifffile
from pathlib import Path
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from skimage.feature import peak_local_max

# ======= CONFIGURATION =======
# Input and output folders on the mounted Google Drive.
# Adjust both paths to your own Colab environment before running.
INPUT_DIR = '/content/drive/MyDrive/knowledge/University/Master/Thesis/denoised_ordered/trial'
OUTPUT_DIR = '/content/drive/MyDrive/knowledge/University/Master/Thesis/Projected/trial_tmp4'

# Default EDF (extended depth of field) parameters.
# They are used as default argument values by the functions below.
DEFAULT_MAX_Z_DIFF = 1      # Maximum allowed z-difference between adjacent pixels
# NOTE(review): DEFAULT_SIGMA is not referenced by the code visible in this cell.
DEFAULT_SIGMA = 1.0         # Sigma for Laplacian operator smoothing
DEFAULT_GAUSS_DENOISE = 0.5 # Sigma for Gaussian denoising
DEFAULT_SCALES = [0.5, 1.0, 2.0]  # Scales for multi-scale focus measure
DEFAULT_MAX_REFINEMENT_DIST = 2  # Maximum distance to search during refinement

# Channel names, indexed by channel position in a 4D image (customize for your data).
# process_file treats channel index 1 as the nuclei channel, and
# save_results_to_disk special-cases the name 'Nuclei'.
CHANNEL_NAMES = ['Cadherins', 'Nuclei', 'Golgi']

# ======= UTILITY FUNCTIONS =======

def create_dirs():
    """Make sure the input/output folders and the per-channel result folders exist.

    Under ``OUTPUT_DIR`` every channel gets a ``tophat`` subfolder; every
    channel except ``Nuclei`` additionally gets a ``background`` subfolder.
    Existing folders are left untouched.
    """
    for base_dir in (INPUT_DIR, OUTPUT_DIR):
        os.makedirs(base_dir, exist_ok=True)

    for channel in ('Cadherins', 'Nuclei', 'Golgi'):
        # Nuclei is only written with tophat applied, so it needs no 'background' folder
        subfolders = ('tophat',) if channel == 'Nuclei' else ('tophat', 'background')
        for subfolder in subfolders:
            os.makedirs(os.path.join(OUTPUT_DIR, channel, subfolder), exist_ok=True)

def normalize_image(img):
    """Linearly rescale ``img`` so its minimum maps to 0 and its maximum to 1.

    If the maximum is not strictly greater than the minimum (e.g. a constant
    image), the input object is returned unchanged.
    """
    lowest, highest = np.min(img), np.max(img)
    # "not (a > b)" rather than "a <= b" so NaN extrema also take the pass-through path
    if not (highest > lowest):
        return img
    return (img - lowest) / (highest - lowest)

def get_valid_neighbors(y, x, height, width):
    """Return the 4-connected neighbours of ``(y, x)`` that lie inside the image.

    Candidates are listed in the order up, down, left, right; those outside
    ``0 <= row < height`` or ``0 <= col < width`` are dropped.
    """
    candidates = ((y - 1, x), (y + 1, x), (y, x - 1), (y, x + 1))
    return [
        (row, col)
        for row, col in candidates
        if 0 <= row < height and 0 <= col < width
    ]

# ======= BACKGROUND REMOVAL =======

def tophat_filter_background(image, radius=15):
    """Suppress background in a 2D image with a white top-hat filter.

    Args:
        image: Input image; it is converted to float and min-max normalized first.
        radius: Radius of the disk-shaped structuring element.

    Returns:
        The top-hat filtered image after intensity rescaling and a final
        min-max normalization.
    """
    scaled = normalize_image(img_as_float(image))

    # Disk footprint for the morphological opening behind the top-hat
    footprint = morphology.disk(radius)
    foreground = morphology.white_tophat(scaled, footprint)

    # Stretch contrast, then normalize once more as a safety net
    return normalize_image(exposure.rescale_intensity(foreground))

# ======= IMPROVED FOCUS MEASURE CALCULATION =======

def enhanced_focus_measure(image_stack, gauss_denoise=DEFAULT_GAUSS_DENOISE):
    """
    Compute a per-slice focus map by blending three focus operators.

    For every slice the function combines, each smoothed with a sigma=1.0
    Gaussian:
      * 0.5 x absolute Laplacian,
      * 0.3 x Sobel gradient magnitude,
      * 0.2 x local variance (Gaussian-weighted, sigma=1.5).

    Args:
        image_stack: 3D array indexed as (z, y, x).
        gauss_denoise: Sigma of the Gaussian pre-blur; values <= 0 disable it.

    Returns:
        Array with the same (z, y, x) shape holding the combined focus measure.
    """
    z_size, height, width = image_stack.shape
    focus_measures = np.zeros((z_size, height, width))

    def smooth(values):
        # Shared final smoothing applied to every individual measure
        return filters.gaussian(values, sigma=1.0)

    for z, plane in enumerate(image_stack):
        # Optional denoising before any derivative is taken
        blurred = filters.gaussian(plane, sigma=gauss_denoise) if gauss_denoise > 0 else plane

        # 1. Laplacian response (edges)
        laplacian = np.abs(filters.laplace(blurred, ksize=3))

        # 2. Sobel gradient magnitude (intensity changes)
        grad_h = filters.sobel_h(blurred)
        grad_v = filters.sobel_v(blurred)
        gradient_magnitude = np.sqrt(grad_h**2 + grad_v**2)

        # 3. Local variance (texture)
        neighborhood_mean = filters.gaussian(blurred, sigma=1.5)
        neighborhood_var = filters.gaussian((blurred - neighborhood_mean)**2, sigma=1.5)

        focus_measures[z] = (0.5 * smooth(laplacian) +
                             0.3 * smooth(gradient_magnitude) +
                             0.2 * smooth(neighborhood_var))

    return focus_measures

def multi_scale_focus_measure(image_stack, scales=DEFAULT_SCALES):
    """
    Average the enhanced focus measure over several image scales.

    For each scale other than 1.0 the stack is resized slice by slice, the
    focus measure is computed on the resized stack, and the result is resized
    back to the original (y, x) size. Scale 1.0 uses the stack as-is.

    Args:
        image_stack: 3D array indexed as (z, y, x).
        scales: Iterable of scale factors applied to height and width.

    Returns:
        (z, y, x) array: the mean of the per-scale focus measures.
    """
    z_size, height, width = image_stack.shape

    def resize_planes(planes, target_shape):
        # Resize every z-slice of `planes` to `target_shape` (rows, cols)
        resized = np.zeros((z_size, *target_shape))
        for z in range(z_size):
            resized[z] = transform.resize(planes[z], target_shape, anti_aliasing=True)
        return resized

    accumulated = np.zeros((z_size, height, width))

    for scale in scales:
        if scale == 1.0:
            focus_at_scale = enhanced_focus_measure(image_stack)
        else:
            scaled_shape = (int(height*scale), int(width*scale))
            scaled_focus = enhanced_focus_measure(resize_planes(image_stack, scaled_shape))
            # Bring the focus map back onto the original pixel grid
            focus_at_scale = resize_planes(scaled_focus, (height, width))

        accumulated += focus_at_scale

    return accumulated / len(scales)

# ======= IMPROVED SEED SELECTION =======

def improved_seed_selection(focus_measures, initial_best_z):
    """
    Pick seed pixels whose best z-slice is both distinct and locally consistent.

    Two per-pixel scores are computed and blended (0.7 / 0.3):
      * confidence: ``(best - second_best) / (best + 1e-6)`` of the focus
        values along z (0 when there is only one slice);
      * local consistency: the largest number of 4-connected neighbours that
        share one z-index, divided by 4.

    Pixels whose blended score is strictly greater than the 95th percentile of
    all scores become seeds. The list can therefore be empty when many pixels
    tie for the top score.

    Args:
        focus_measures: (z, y, x) array of focus values.
        initial_best_z: (y, x) integer array of per-pixel z-indices.

    Returns:
        (seed_points, combined_score): ``seed_points`` is a list of (y, x)
        tuples sorted by descending score; ``combined_score`` is the (y, x)
        float score map.
    """
    z_size, height, width = focus_measures.shape

    # Confidence: margin between the best and second-best focus value along z
    if z_size >= 2:
        ranked = np.sort(focus_measures, axis=0).astype(np.float64, copy=False)
        best, runner_up = ranked[-1], ranked[-2]
        confidence = (best - runner_up) / (best + 1e-6)
    else:
        confidence = np.zeros((height, width))

    # Local consistency: for every z-value present, count how many of the four
    # neighbours carry it (zero padding = out-of-image neighbours never vote),
    # then keep the largest count per pixel.
    labels = np.asarray(initial_best_z)
    max_agreeing = np.zeros((height, width), dtype=np.int64)
    for z_value in np.unique(labels):
        padded = np.pad(labels == z_value, 1).astype(np.int64)
        neighbor_votes = (padded[:-2, 1:-1] + padded[2:, 1:-1] +
                          padded[1:-1, :-2] + padded[1:-1, 2:])
        np.maximum(max_agreeing, neighbor_votes, out=max_agreeing)
    local_consistency = max_agreeing / 4.0  # Normalize by max possible neighbors

    # Combined score
    combined_score = 0.7 * confidence + 0.3 * local_consistency

    # Select top percentile as seeds
    seed_threshold = np.percentile(combined_score, 95)
    seed_indices = np.where(combined_score > seed_threshold)
    seed_points = list(zip(seed_indices[0], seed_indices[1]))

    # Sort seed points by score (highest first); the sort is stable for ties
    seed_points.sort(key=lambda idx: combined_score[idx[0], idx[1]], reverse=True)

    return seed_points, combined_score

# ======= ADAPTIVE REGION SIZE =======

def adaptive_region_size(image_stack):
    """
    Derive a per-pixel region size (5 to 30) from the edge density of the middle slice.

    The Sobel edge map of the middle z-slice is smoothed with a sigma=10
    Gaussian and min-max normalized; dense-edge areas map towards 5 and
    sparse areas towards 30.

    Args:
        image_stack: 3D array indexed as (z, y, x).

    Returns:
        (y, x) array of dtype int32 with the region size for each pixel.
    """
    # The middle slice serves as the reference for feature density
    reference_slice = image_stack[image_stack.shape[0] // 2]

    edge_strength = filters.sobel(reference_slice)
    feature_density = filters.gaussian(edge_strength, sigma=10.0)

    smallest, largest = 5, 30
    # Invert the density: many features -> small regions
    sparsity = 1.0 - normalize_image(feature_density)
    region_sizes = smallest + (largest - smallest) * sparsity

    return region_sizes.astype(np.int32)

# ======= IMPROVED SPATIAL CONTINUITY =======

def edf_with_spatial_continuity_region(focus_measures, max_z_diff=DEFAULT_MAX_Z_DIFF):
    """
    Build a spatially continuous z-map by region growing from seed pixels.

    Seeds keep their per-pixel argmax z. Growth then proceeds breadth-first
    over 4-connected neighbours: each newly reached pixel gets the z with the
    highest focus value within ``max_z_diff`` slices of the pixel it was
    reached from.

    If seed selection returns no seeds (it keeps only pixels strictly above
    the 95th percentile, so tied scores can leave none), the single
    highest-scoring pixel is used as the seed. Previously nothing was grown in
    that case and the nearest-processed-pixel fallback ran without any
    processed pixel to refer to.

    Args:
        focus_measures: (z, y, x) array of focus values.
        max_z_diff: Maximum z-step allowed between a pixel and the neighbour
            it is grown from.

    Returns:
        (y, x) integer array of selected z-indices.
    """
    from collections import deque  # local import: O(1) pops from the queue front

    z_size, height, width = focus_measures.shape

    # Find initial z-slice with maximum focus for each pixel
    initial_best_z = np.argmax(focus_measures, axis=0)

    # Get improved seed points
    seed_points, confidence_scores = improved_seed_selection(focus_measures, initial_best_z)

    # Guarantee at least one seed so the growth below can reach every pixel
    if not seed_points and confidence_scores.size > 0:
        top_pixel = np.unravel_index(np.argmax(confidence_scores), confidence_scores.shape)
        seed_points = [(int(top_pixel[0]), int(top_pixel[1]))]

    # Mask of pixels that already have their final z-value
    processed = np.zeros((height, width), dtype=bool)

    # Output z-map; seeds keep their argmax value
    final_best_z = np.copy(initial_best_z)

    frontier = deque(seed_points)
    for y, x in frontier:
        processed[y, x] = True

    # Breadth-first growth. On a rectangular grid with 4-connectivity any seed
    # reaches every pixel, so no pixel is left unprocessed afterwards.
    while frontier:
        y, x = frontier.popleft()
        current_z = final_best_z[y, x]

        # Allowed z-window around the current pixel (same for all its neighbours)
        z_min = max(0, current_z - max_z_diff)
        z_max = min(z_size - 1, current_z + max_z_diff)

        for ny, nx in get_valid_neighbors(y, x, height, width):
            if processed[ny, nx]:
                continue

            # Best z for the neighbour, restricted to the allowed window
            z_slice = focus_measures[z_min:z_max+1, ny, nx]
            final_best_z[ny, nx] = z_min + np.argmax(z_slice)

            processed[ny, nx] = True
            frontier.append((ny, nx))

    return final_best_z

# ======= HYBRID SELECTION STRATEGY =======

def hybrid_best_z_selection(focus_measures, max_z_diff=DEFAULT_MAX_Z_DIFF):
    """
    Choose per pixel between the raw argmax z and the spatially continuous z.

    The seed-selection score is doubled and clipped to [0, 1]. Where that
    value exceeds 0.7 the per-pixel argmax is kept; everywhere else the
    result of ``edf_with_spatial_continuity_region`` is used.

    Args:
        focus_measures: (z, y, x) array of focus values.
        max_z_diff: Passed through to the spatial-continuity step.

    Returns:
        (y, x) integer array of selected z-indices.
    """
    # Unpacking doubles as a check that the input is three-dimensional
    z_size, height, width = focus_measures.shape

    argmax_z = np.argmax(focus_measures, axis=0)

    # Only the score map is needed here; the seed list itself is not used
    _, score_map = improved_seed_selection(focus_measures, argmax_z)

    continuity_z = edf_with_spatial_continuity_region(focus_measures, max_z_diff)

    # High-confidence pixels trust their own argmax
    trust_argmax = np.clip(score_map * 2, 0, 1) > 0.7

    selected = np.zeros_like(argmax_z)
    selected[trust_argmax] = argmax_z[trust_argmax]
    selected[~trust_argmax] = continuity_z[~trust_argmax]

    return selected

# ======= POST-PROCESSING REFINEMENT =======

def refine_z_map(best_z_map, focus_measures, max_refinement_dist=DEFAULT_MAX_REFINEMENT_DIST):
    """
    Refine a z-map by searching a small z-window around each pixel's current z.

    For every pixel the slices within ``max_refinement_dist`` of its current z
    (clipped to the stack) are examined. The pixel moves to the slice with the
    highest focus value only if that value exceeds the current one by more
    than 10 %. Ties inside the window resolve to the lowest z.

    The whole image is processed with array operations instead of a Python
    loop over pixels.

    Args:
        best_z_map: (y, x) array of z-indices; expected to hold valid indices
            into ``focus_measures`` (fractional values are truncated).
        focus_measures: (z, y, x) array of focus values.
        max_refinement_dist: Half-width of the z search window.

    Returns:
        A copy of ``best_z_map`` with improved z-indices where applicable.
    """
    z_size, height, width = focus_measures.shape
    refined_map = np.copy(best_z_map)

    # Truncate towards zero, matching int() on each element
    current_z = np.asarray(best_z_map).astype(np.intp)

    # Candidate slices: current_z + offset for every offset in the window,
    # ordered from lowest to highest z so argmax ties pick the lowest z
    offsets = np.arange(-max_refinement_dist, max_refinement_dist + 1)
    candidate_z = current_z[np.newaxis, :, :] + offsets[:, np.newaxis, np.newaxis]
    in_stack = (candidate_z >= 0) & (candidate_z < z_size)

    # Gather focus values; out-of-stack candidates are masked so they never win
    candidate_focus = np.take_along_axis(
        focus_measures, np.clip(candidate_z, 0, z_size - 1), axis=0)
    candidate_focus = np.where(in_stack, candidate_focus, -np.inf)

    # Find local maximum within the window
    winner = np.argmax(candidate_focus, axis=0)[np.newaxis]
    local_best_z = np.take_along_axis(candidate_z, winner, axis=0)[0]
    local_best_focus = np.take_along_axis(candidate_focus, winner, axis=0)[0]

    current_focus = np.take_along_axis(focus_measures, current_z[np.newaxis], axis=0)[0]

    # Update only if focus is significantly better
    improved = local_best_focus > current_focus * 1.1
    refined_map[improved] = local_best_z[improved]

    return refined_map

# ======= IMPROVED REGIONAL GUIDED EDF METHOD =======

def improved_edf_regional_guided(image_stack, max_z_diff=DEFAULT_MAX_Z_DIFF,
                                max_refinement_dist=DEFAULT_MAX_REFINEMENT_DIST):
    """
    Extended-depth-of-field projection using overlapping, adaptively sized regions.

    Steps:
      1. Normalize the stack and compute multi-scale focus measures.
      2. Walk the image on a 10-pixel grid. At each grid position, run the
         hybrid z-selection on a square region whose size comes from
         ``adaptive_region_size`` (at least 5 px, cropped at the image border).
      3. Blend the regional z-maps with centre-weighted weights.
      4. Refine the blended z-map and pick each output pixel from its z-slice.

    Pixels that no region covers (regions smaller than the grid step, or edge
    regions skipped for being under 3 px) fall back to the per-pixel argmax.
    Previously they were silently assigned z=0.

    Args:
        image_stack: 3D array indexed as (z, y, x).
        max_z_diff: Maximum z-step between neighbouring pixels during region growing.
        max_refinement_dist: Half-width of the z-window used by the final refinement.

    Returns:
        (result, focus_map, refined_best_z):
          * result: (y, x) float32 projection,
          * focus_map: refined z-indices scaled to 0-1 (all zeros for a
            single-slice stack),
          * refined_best_z: (y, x) integer z-index map.
    """
    # Convert to float and ensure 0-1 range
    stack = img_as_float(image_stack)
    stack = normalize_image(stack)
    z_size, height, width = stack.shape

    # Calculate multi-scale focus measures
    print("Calculating multi-scale focus measures...")
    focus_measures = multi_scale_focus_measure(stack)

    # Per-pixel argmax; used as the fallback for pixels that no region covers
    initial_best_z = np.argmax(focus_measures, axis=0)

    # Calculate adaptive region sizes
    print("Determining adaptive region sizes...")
    region_sizes = adaptive_region_size(stack)

    # Weighted sum of regional z-values and the matching sum of weights
    best_z = np.zeros((height, width), dtype=np.float64)
    weights = np.zeros((height, width), dtype=np.float32)

    # Process image in overlapping regions with adaptive sizes
    print("Processing in adaptive regions...")
    step = 10  # Fixed grid step; the region size itself is adaptive

    for y_start in range(0, height, step):
        for x_start in range(0, width, step):
            # Look up the adaptive size at the centre of the grid cell
            probe_y, probe_x = y_start + step // 2, x_start + step // 2
            if probe_y < height and probe_x < width:
                region_size = int(region_sizes[probe_y, probe_x])
            else:
                region_size = 15  # Default when the probe falls outside the image

            # Ensure minimum size
            region_size = max(region_size, 5)

            # Define region bounds (cropped at the image border)
            y_end = min(y_start + region_size, height)
            x_end = min(x_start + region_size, width)

            # Skip too small regions
            if y_end - y_start < 3 or x_end - x_start < 3:
                continue

            # Apply hybrid selection to this region's focus measures
            region_focus = focus_measures[:, y_start:y_end, x_start:x_end]
            region_best_z = hybrid_best_z_selection(region_focus, max_z_diff)

            # Weight mask: higher weights in the centre, lower at the edges
            y_grid, x_grid = np.mgrid[y_start:y_end, x_start:x_end]
            y_center = (y_start + y_end) / 2
            x_center = (x_start + x_end) / 2

            # Chebyshev distance from the centre, in units of half the nominal size
            y_dist = np.abs(y_grid - y_center) / (region_size / 2)
            x_dist = np.abs(x_grid - x_center) / (region_size / 2)
            dist = np.maximum(y_dist, x_dist)
            region_weights = np.clip(1.0 - dist, 0.1, 1.0)

            # Add weighted contribution to output
            best_z[y_start:y_end, x_start:x_end] += region_best_z.astype(np.float64) * region_weights
            weights[y_start:y_end, x_start:x_end] += region_weights

    # Weighted average rounded to the nearest slice. Region weights are at
    # least 0.1, so a zero weight means no region touched the pixel; those
    # pixels take the per-pixel argmax instead.
    covered = weights > 0
    averaged_z = np.round(best_z / np.maximum(weights, 1e-6))
    best_z = np.where(covered, averaged_z, initial_best_z).astype(np.int32)

    # Ensure best_z values are within valid range
    best_z = np.clip(best_z, 0, z_size - 1)

    # Refine the z-map
    print("Refining z-map...")
    refined_best_z = refine_z_map(best_z, focus_measures, max_refinement_dist)

    # Create output by taking pixels from best z-slices
    print("Creating final output...")
    result = np.zeros((height, width), dtype=np.float32)
    for z in range(z_size):
        mask = refined_best_z == z
        result[mask] = stack[z][mask]

    # Focus map normalized to 0-1; guard the divisor for single-slice stacks
    focus_map = refined_best_z / max(z_size - 1, 1)

    return result, focus_map, refined_best_z

# ======= MAIN PROCESSING FUNCTIONS =======

def process_single_stack(stack, apply_tophat=False, max_z_diff=DEFAULT_MAX_Z_DIFF,
                        max_refinement_dist=DEFAULT_MAX_REFINEMENT_DIST):
    """
    Run the regional EDF projection on one z-stack, optionally after tophat filtering.

    Args:
        stack: 3D array indexed as (z, y, x).
        apply_tophat: When True, every slice is passed through
            ``tophat_filter_background`` before projection.
        max_z_diff: Forwarded to ``improved_edf_regional_guided``.
        max_refinement_dist: Forwarded to ``improved_edf_regional_guided``.

    Returns:
        Dict with keys ``'original_middle'``, ``'original_max'`` (both taken
        from the float-converted, possibly normalized stack),
        ``'regional_edf'`` and ``'focus_map_regional'``.
    """
    print(f"Processing stack with shape {stack.shape}...")

    stack = img_as_float(stack)

    # Only renormalize when values actually leave the 0-1 range
    min_val = np.min(stack)
    max_val = np.max(stack)
    if min_val < 0 or max_val > 1:
        print(f"Normalizing stack from range [{min_val}, {max_val}] to [0, 1]")
        stack = normalize_image(stack)

    # Choose the stack the projection will work on
    working_stack = stack
    if apply_tophat:
        print("Applying tophat background removal...")
        working_stack = np.zeros_like(stack, dtype=np.float32)
        for z, plane in enumerate(stack):
            working_stack[z] = tophat_filter_background(plane)

    # Reference views of the unfiltered stack, kept for comparison
    reference_views = {
        'original_middle': stack[stack.shape[0] // 2],
        'original_max': np.max(stack, axis=0)
    }

    print("Applying improved regional guided EDF...")
    projection, focus_map, _ = improved_edf_regional_guided(
        working_stack, max_z_diff=max_z_diff, max_refinement_dist=max_refinement_dist)

    return {
        **reference_views,
        'regional_edf': projection,
        'focus_map_regional': focus_map,
    }

def save_results_to_disk(results, filename, channel_name, apply_tophat):
    """Write the projection (and focus map, if present) as 16-bit TIFF files.

    Files go to ``OUTPUT_DIR/<channel_name>/<tophat|background>/``. Nothing is
    written when ``results`` has no ``'regional_edf'`` entry, or for the
    ``Nuclei`` channel without tophat.

    Args:
        results: Dict as returned by ``process_single_stack``.
        filename: Source file name; its extension is stripped for output names.
        channel_name: Channel label used for the folder and the file name.
        apply_tophat: Whether the results were produced with tophat filtering.
    """
    base_name = os.path.splitext(filename)[0]

    key = "regional_edf"
    if key not in results:
        return

    # Nuclei are only required with tophat applied
    if channel_name == "Nuclei" and not apply_tophat:
        print(f"Skipping {channel_name} without tophat (not required)")
        return

    process_type = "tophat" if apply_tophat else "background"
    name_suffix = "_tophat.tif" if apply_tophat else ".tif"

    output_dir = os.path.join(OUTPUT_DIR, channel_name, process_type)
    os.makedirs(output_dir, exist_ok=True)

    def write_uint16(array, target_path):
        # Normalize to 0-1, convert to uint16, write to disk
        tifffile.imwrite(target_path, img_as_uint(normalize_image(array)))

    # Projection
    output_path = os.path.join(output_dir, f"{base_name}_{channel_name}_regional{name_suffix}")
    write_uint16(results[key], output_path)
    print(f"Saved {output_path}")

    # Focus map, saved alongside for analysis
    if 'focus_map_regional' in results:
        focus_map_path = os.path.join(
            output_dir, f"{base_name}_{channel_name}_focus_map{name_suffix}")
        write_uint16(results['focus_map_regional'], focus_map_path)
        print(f"Saved focus map: {focus_map_path}")

def process_directory(input_dir=INPUT_DIR, max_z_diff=DEFAULT_MAX_Z_DIFF,
                     max_refinement_dist=DEFAULT_MAX_REFINEMENT_DIST,
                     file_pattern='*.tif*', save_results=True):
    """
    Run ``process_file`` on every file in ``input_dir`` that matches ``file_pattern``.

    Files are handled in sorted order. An exception raised for one file is
    printed with its traceback and does not stop the remaining files.

    Args:
        input_dir: Folder to search.
        max_z_diff: Forwarded to ``process_file``.
        max_refinement_dist: Forwarded to ``process_file``.
        file_pattern: Glob pattern joined onto ``input_dir``.
        save_results: Forwarded to ``process_file``.

    Returns:
        List of paths for which ``process_file`` returned a non-empty result
        (an empty list when nothing matched).
    """
    all_files = sorted(glob.glob(os.path.join(input_dir, file_pattern), recursive=True))
    total = len(all_files)

    if total == 0:
        print(f"No files matching '{file_pattern}' found in {input_dir}")
        return []

    print(f"Found {total} files to process")

    # Output folders must exist before any file is written
    create_dirs()

    processed_files = []

    for position, file_path in enumerate(all_files, start=1):
        short_name = os.path.basename(file_path)
        print(f"\nProcessing file {position}/{total}: {short_name}")

        try:
            results = process_file(
                file_path,
                max_z_diff=max_z_diff,
                max_refinement_dist=max_refinement_dist,
                save_results=save_results
            )

            if results:
                processed_files.append(file_path)
                print(f"Successfully processed {short_name}")

        except Exception as e:
            # Report the failure and carry on with the next file
            print(f"Error processing file {short_name}: {e}")
            import traceback
            traceback.print_exc()

    print(f"\nSuccessfully processed {len(processed_files)}/{total} files")
    return processed_files

def process_file(file_path, max_z_diff=DEFAULT_MAX_Z_DIFF,
                max_refinement_dist=DEFAULT_MAX_REFINEMENT_DIST,
                save_results=True):
    """
    Load one TIFF file and project it with channel-specific settings.

    * 3D image: treated as a single-channel z-stack named ``"default"`` and
      processed both without and with tophat.
    * 4D image: the first axis is iterated as channels. Channel index 1
      (nuclei) is processed with tophat only; every other channel is
      processed both without and with tophat.

    Args:
        file_path: Path of the TIFF file.
        max_z_diff: Forwarded to ``process_single_stack``.
        max_refinement_dist: Forwarded to ``process_single_stack``.
        save_results: When True, results are written via ``save_results_to_disk``.

    Returns:
        Dict mapping ``"<name>_no_tophat"`` / ``"<name>_tophat"`` to result
        dicts, or None if the file could not be loaded or is neither 3D nor 4D.
    """
    filename = os.path.basename(file_path)
    print(f"Processing file: {filename}")

    # Load image
    try:
        image = tifffile.imread(file_path)
        print(f"Image shape: {image.shape}, dtype: {image.dtype}")
    except Exception as e:
        print(f"Error loading image: {e}")
        return None

    def project(data, with_tophat):
        # Same EDF settings for every stack; only the tophat switch varies
        return process_single_stack(
            data, apply_tophat=with_tophat,
            max_z_diff=max_z_diff,
            max_refinement_dist=max_refinement_dist)

    n_dims = len(image.shape)

    if n_dims not in (3, 4):
        print(f"Unsupported image dimensions: {n_dims}")
        return None

    if n_dims == 3:
        # Single channel z-stack - process with and without tophat
        results_no_tophat = project(image, False)
        results_tophat = project(image, True)

        if save_results:
            save_results_to_disk(results_no_tophat, filename, "default", False)
            save_results_to_disk(results_tophat, filename, "default", True)

        return {"default_no_tophat": results_no_tophat, "default_tophat": results_tophat}

    # Multi-channel z-stack with channel-specific processing
    all_results = {}

    for ch_idx, channel_data in enumerate(image):
        ch_name = CHANNEL_NAMES[ch_idx] if ch_idx < len(CHANNEL_NAMES) else f"channel_{ch_idx}"

        if ch_idx == 1:
            # Nuclei (second channel) - tophat only
            print(f"Processing {ch_name} with tophat...")
            results = project(channel_data, True)

            all_results[f"{ch_name}_tophat"] = results

            if save_results:
                save_results_to_disk(results, filename, ch_name, True)
            continue

        # All other channels - with and without tophat
        print(f"Processing {ch_name} without tophat...")
        results_no_tophat = project(channel_data, False)

        print(f"Processing {ch_name} with tophat...")
        results_tophat = project(channel_data, True)

        all_results[f"{ch_name}_no_tophat"] = results_no_tophat
        all_results[f"{ch_name}_tophat"] = results_tophat

        if save_results:
            save_results_to_disk(results_no_tophat, filename, ch_name, False)
            save_results_to_disk(results_tophat, filename, ch_name, True)

    return all_results

# ======= VISUALIZATION FUNCTIONS =======

def visualize_focus_comparison(original_stack, improved_result, original_best_z=None, improved_best_z=None):
    """
    Plot a 2x3 comparison of the raw stack and the improved EDF result.

    Top row: middle slice, max projection, optional original focus map.
    Bottom row: improved result, normalized absolute difference to the max
    projection, optional improved focus map. Panels for focus maps that are
    not supplied are hidden.

    Args:
        original_stack: 3D array indexed as (z, y, x).
        improved_result: 2D projection to compare against the max projection.
        original_best_z: Optional (y, x) z-index map of the original method.
        improved_best_z: Optional (y, x) z-index map of the improved method.

    Returns:
        The matplotlib Figure.
    """
    z_size = original_stack.shape[0]
    # Divisor that scales z-indices to 0-1; guarded so a single-slice stack
    # does not divide by zero
    z_span = max(z_size - 1, 1)

    # Computed once, used for both the display and the difference image
    max_projection = np.max(original_stack, axis=0)

    # Create figure
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))

    # Original middle slice
    mid_z = z_size // 2
    axes[0, 0].imshow(original_stack[mid_z], cmap='gray')
    axes[0, 0].set_title(f'Original Middle Slice (z={mid_z})')

    # Original max projection
    axes[0, 1].imshow(max_projection, cmap='gray')
    axes[0, 1].set_title('Original Max Projection')

    # Original best z map if available
    if original_best_z is not None:
        axes[0, 2].imshow(original_best_z / z_span, cmap='viridis')
        axes[0, 2].set_title('Original Focus Map')
    else:
        axes[0, 2].set_visible(False)

    # Improved result
    axes[1, 0].imshow(improved_result, cmap='gray')
    axes[1, 0].set_title('Improved EDF Result')

    # Difference
    diff = normalize_image(np.abs(improved_result - max_projection))
    axes[1, 1].imshow(diff, cmap='hot')
    axes[1, 1].set_title('Difference (Improved vs Max Proj)')

    # Improved best z map if available
    if improved_best_z is not None:
        axes[1, 2].imshow(improved_best_z / z_span, cmap='viridis')
        axes[1, 2].set_title('Improved Focus Map')
    else:
        axes[1, 2].set_visible(False)

    # Remove ticks
    for ax in axes.ravel():
        ax.set_xticks([])
        ax.set_yticks([])

    plt.tight_layout()
    return fig

# Script entry point: run the whole pipeline with the module-level defaults
def main():
    """Create the folder layout, process every TIFF in ``INPUT_DIR`` and print the count."""
    create_dirs()

    run_options = dict(
        input_dir=INPUT_DIR,
        max_z_diff=DEFAULT_MAX_Z_DIFF,
        max_refinement_dist=DEFAULT_MAX_REFINEMENT_DIST,
        file_pattern='*.tif*',
        save_results=True,
    )
    completed = process_directory(**run_options)
    print(f"Processed {len(completed)} files")

if __name__ == "__main__":
    main()

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Found 1 files to process

Processing file 1/1: denoised_1.4Pa_A1_20dec21_20xA_L2RA_FlatA_seq004.tif
Processing file: denoised_1.4Pa_A1_20dec21_20xA_L2RA_FlatA_seq004.tif
Image shape: (3, 13, 1024, 1024), dtype: float32
Processing Cadherins without tophat...
Processing stack with shape (13, 1024, 1024)...
Normalizing stack from range [-0.012688316404819489, 1.9504793882369995] to [0, 1]
Applying improved regional guided EDF...
Calculating multi-scale focus measures...
Determining adaptive region sizes...
Processing in adaptive regions...
Refining z-map...
Creating final output...
Processing Cadherins with tophat...
Processing stack with shape (13, 1024, 1024)...
Normalizing stack from range [-0.012688316404819489, 1.9504793882369995] to [0, 1]
Applying tophat background removal...
Applying improved regional guided EDF...
Calculating multi-scale focus measures.