In [None]:
# Mount Google Drive at /content/drive; the INPUT_DIR and OUTPUT_DIR paths
# configured below both live under this mount point.
from google.colab import drive
drive.mount('/content/drive')

import os
import numpy as np
from skimage import io, filters, exposure, morphology, measure, segmentation, img_as_float, img_as_uint
from skimage.feature import peak_local_max
from scipy import ndimage
import glob
import tifffile
from pathlib import Path
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

# ======= CONFIGURATION =======
# Set paths specifically for your Google Colab environment.
# INPUT_DIR is the folder scanned for TIFF stacks; OUTPUT_DIR is the root
# under which the per-channel projection folders are created.
INPUT_DIR = '/content/drive/MyDrive/knowledge/University/Master/Thesis/denoised_ordered/trial'
OUTPUT_DIR = '/content/drive/MyDrive/knowledge/University/Master/Thesis/Projected/trial2'

# Default EDF parameters
DEFAULT_MAX_Z_DIFF = 1      # Base limit on the z-index step between neighbouring pixels
DEFAULT_SIGMA = 1.0         # Sigma of the Gaussian that smooths the combined focus measure
DEFAULT_GAUSS_DENOISE = 0.5 # Sigma of the Gaussian blur applied to each slice before measuring focus

# Channel names, looked up by index along the first axis of a 4-D image
# (customize these based on your data)
CHANNEL_NAMES = ['Cadherins', 'Nuclei', 'Golgi']

# ======= UTILITY FUNCTIONS =======

def create_dirs():
    """Ensure the input root, the output root and all channel folders exist.

    Every channel gets a ``tophat`` sub-folder below ``OUTPUT_DIR``; every
    channel except ``Nuclei`` also gets a ``background`` sub-folder.
    Directories that already exist are left as they are.
    """
    for root in (INPUT_DIR, OUTPUT_DIR):
        os.makedirs(root, exist_ok=True)

    for channel in ('Cadherins', 'Nuclei', 'Golgi'):
        # Nuclei is only written with top-hat filtering, so it needs no
        # 'background' folder.
        if channel == 'Nuclei':
            subfolders = ('tophat',)
        else:
            subfolders = ('tophat', 'background')

        for subfolder in subfolders:
            os.makedirs(os.path.join(OUTPUT_DIR, channel, subfolder), exist_ok=True)

def normalize_image(img):
    """Min-max rescale ``img`` so its smallest value maps to 0 and its largest to 1.

    If the maximum is not greater than the minimum (e.g. a constant image)
    there is no range to scale by and the input is returned unchanged.
    """
    lowest = np.min(img)
    highest = np.max(img)

    # Guard clause: nothing to stretch when the image has no spread
    if not highest > lowest:
        return img

    span = highest - lowest
    return (img - lowest) / span

# ======= BACKGROUND REMOVAL =======

def tophat_filter_background(image, radius=15):
    """Suppress background in a 2-D image with a white top-hat filter.

    The image is converted to float and min-max normalized, filtered with a
    disk-shaped structuring element of the given ``radius``, contrast
    stretched, and min-max normalized once more before being returned.
    """
    scaled = normalize_image(img_as_float(image))

    # Disk footprint: structures smaller than it survive the top-hat,
    # larger (background-like) structures are removed
    footprint = morphology.disk(radius)
    filtered = morphology.white_tophat(scaled, footprint)

    # Contrast stretch first, then a second min-max pass as a safety net
    stretched = exposure.rescale_intensity(filtered)
    return normalize_image(stretched)

# ======= ENHANCED FOCUS MEASURE CALCULATION =======

def calculate_focus_measures(stack, gauss_denoise=0.5, sigma=1.0):
    """Compute a per-pixel focus score for every slice of a z-stack.

    Three sharpness cues are computed per slice and blended as a weighted sum:

    * absolute Laplacian, weight 0.5 (used as-is, not rescaled),
    * Sobel gradient magnitude, weight 0.3 (min-max normalized),
    * local 3x3 variance, weight 0.2 (min-max normalized).

    Args:
        stack: 3-D array ordered (z, height, width).
        gauss_denoise: Sigma of the Gaussian blur applied to each slice before
            the cues are computed; values <= 0 skip the blur.
        sigma: Sigma of the Gaussian applied to the combined score of each
            slice; values <= 0 skip this smoothing.

    Returns:
        Float array of shape (z, height, width) holding the focus score.
    """
    z_size, height, width = stack.shape
    focus_measures = np.zeros((z_size, height, width))

    for z in range(z_size):
        # Apply Gaussian filter to reduce noise
        if gauss_denoise > 0:
            blurred = filters.gaussian(stack[z], sigma=gauss_denoise)
        else:
            blurred = stack[z]

        # 1. Laplacian (edge detection)
        laplacian = np.abs(filters.laplace(blurred, ksize=3))

        # 2. Sobel gradient magnitude (another edge detector)
        sobel_h = filters.sobel_h(blurred)
        sobel_v = filters.sobel_v(blurred)
        sobel = np.sqrt(sobel_h**2 + sobel_v**2)

        # 3. Variance (local contrast) over a 3x3 window.
        # Computed as E[x^2] - E[x]^2 with two box filters instead of calling
        # np.var through generic_filter once per pixel. uniform_filter uses
        # the same default 'reflect' boundary handling as generic_filter.
        plane = np.asarray(blurred, dtype=np.float64)
        local_mean = ndimage.uniform_filter(plane, size=3)
        local_mean_sq = ndimage.uniform_filter(plane * plane, size=3)
        # Rounding can push the difference slightly below zero; clamp it
        variance = np.maximum(local_mean_sq - local_mean * local_mean, 0.0)
        variance = normalize_image(variance)

        # Combine measures (weighted sum)
        combined = (0.5 * laplacian +
                    0.3 * normalize_image(sobel) +
                    0.2 * variance)

        # Apply small Gaussian to make the decision more robust
        if sigma > 0:
            focus_measures[z] = filters.gaussian(combined, sigma=sigma)
        else:
            focus_measures[z] = combined

    return focus_measures

# ======= ADAPTIVE REGION SIZE =======

def determine_adaptive_region_size(image, min_size=10, max_size=25):
    """Pick a tile size between ``min_size`` and ``max_size`` from edge density.

    Edge density is the fraction of pixels whose Sobel response exceeds 0.05.
    A density above 0.2 yields ``min_size``, a density below 0.05 yields
    ``max_size``, and densities in between are linearly interpolated (and
    truncated to an int), so busier images get smaller tiles.
    """
    edge_strength = filters.sobel(image)
    density = np.mean(edge_strength > 0.05)

    if density > 0.2:
        return min_size
    if density < 0.05:
        return max_size

    # Linear ramp from max_size (density 0.05) down to min_size (density 0.2)
    shrink = (density - 0.05) * (max_size - min_size) / 0.15
    return int(max_size - shrink)

# ======= BETTER SEED POINT SELECTION =======

def select_seed_points(focus_measures, initial_best_z):
    """Choose seed pixels at local maxima of the focus-confidence map.

    The confidence of a pixel is the focus measure at the z-slice selected
    for it in ``initial_best_z``. Local maxima of that map are ranked by
    confidence, highest first, and the list is truncated to at most 1% of
    the pixel count (but never fewer than 10 allowed).

    Args:
        focus_measures: Array of shape (z, height, width).
        initial_best_z: Integer array of shape (height, width) with the
            selected z-index of every pixel.

    Returns:
        List of (row, col) coordinate arrays sorted by decreasing confidence.
        The list is empty when no local maximum passes the detection
        threshold.
    """
    _, height, width = focus_measures.shape

    # Confidence = focus measure at the selected z-slice, gathered in a single
    # vectorized lookup rather than one Python iteration per pixel.
    confidence = np.take_along_axis(
        focus_measures, initial_best_z[np.newaxis, ...], axis=0
    )[0].astype(np.float64, copy=False)

    # Find local maxima in the confidence map
    local_max_coords = peak_local_max(
        confidence,
        min_distance=5,  # Minimum distance between peaks
        threshold_abs=0.1,  # Absolute threshold for peak detection
        exclude_border=2   # Exclude border pixels
    )

    # Sort by confidence value (highest first)
    local_max_values = confidence[local_max_coords[:, 0], local_max_coords[:, 1]]
    sorted_indices = np.argsort(local_max_values)[::-1]
    sorted_seeds = [local_max_coords[i] for i in sorted_indices]

    # Limit number of seeds to prevent over-segmentation
    max_seeds = max(10, int(0.01 * height * width))  # Max 1% of pixels as seeds, but at least 10
    return sorted_seeds[:max_seeds]

# ======= ADAPTIVE Z-DIFFERENCE PARAMETER =======

def calculate_adaptive_z_diff(focus_measures, initial_best_z, base_max_z_diff=1):
    """Derive a per-pixel z-step limit from how much the z-map varies locally.

    The variance of ``initial_best_z`` is measured in a 5x5 neighbourhood
    (edges replicated). Its square root is added to ``base_max_z_diff`` and
    the sum is clamped to the range [base, 3 * base] and cast to int32, so
    rougher areas are allowed larger z jumps.
    """
    # Unpacking keeps the requirement that focus_measures is 3-D
    _depth, _rows, _cols = focus_measures.shape

    z_as_float = initial_best_z.astype(float)
    local_variance = ndimage.generic_filter(
        z_as_float, np.var, size=5, mode='nearest'
    )
    local_spread = np.sqrt(local_variance)

    lower_bound = base_max_z_diff
    upper_bound = base_max_z_diff * 3  # never more than three times the base
    tolerance = np.clip(base_max_z_diff + local_spread, lower_bound, upper_bound)

    return tolerance.astype(np.int32)

# ======= ENHANCED REGIONAL GUIDED EDF =======

def edf_with_spatial_continuity_region_enhanced(focus_measures, max_z_diff=1):
    """Build a spatially consistent z-map by flooding outwards from seeds.

    Seeds keep their arg-max z-index. Every other pixel is visited
    breadth-first through 4-connected neighbours and receives the best-focus
    z-index that lies within ``max_z_diff`` of the pixel it was reached from.

    Args:
        focus_measures: Array of shape (z, height, width).
        max_z_diff: Largest allowed z-index step between a pixel and the
            neighbour it was reached from.

    Returns:
        Integer array of shape (height, width) with the chosen z-index.
    """
    from collections import deque

    z_size, height, width = focus_measures.shape

    # Find initial z-slice with maximum focus for each pixel
    initial_best_z = np.argmax(focus_measures, axis=0)

    # Get better seed points, as plain (row, col) int tuples
    seeds = [(int(y), int(x))
             for y, x in select_seed_points(focus_measures, initial_best_z)]

    if not seeds:
        # No local maximum passed the seed threshold (e.g. a dim or very
        # small region). Without a seed nothing would be processed and there
        # would be no processed pixel to copy a z-value from, so start from
        # the single pixel with the strongest focus response instead.
        peak_focus = np.max(focus_measures, axis=0)
        flat_index = int(np.argmax(peak_focus))
        seed_y, seed_x = np.unravel_index(flat_index, peak_focus.shape)
        seeds = [(int(seed_y), int(seed_x))]

    # Mask of pixels that already hold their final z-value
    processed = np.zeros((height, width), dtype=bool)

    # Output z-map; seeds keep their initial value, the rest is overwritten
    final_best_z = np.copy(initial_best_z)

    # deque gives O(1) pops from the left, unlike list.pop(0)
    frontier = deque()
    for y, x in seeds:
        processed[y, x] = True
        frontier.append((y, x))

    # Breadth-first flood. The pixel grid is 4-connected, so starting from at
    # least one seed every pixel is reached and no post-pass is required.
    while frontier:
        y, x = frontier.popleft()
        current_z = int(final_best_z[y, x])

        # Allowed z-range around the current pixel (same for all neighbours)
        z_min = max(0, current_z - max_z_diff)
        z_max = min(z_size - 1, current_z + max_z_diff)

        for ny, nx in ((y - 1, x), (y + 1, x), (y, x - 1), (y, x + 1)):
            if not (0 <= ny < height and 0 <= nx < width):
                continue
            if processed[ny, nx]:
                continue

            # Best-focus slice inside the allowed window
            z_slice = focus_measures[z_min:z_max + 1, ny, nx]
            final_best_z[ny, nx] = z_min + np.argmax(z_slice)

            processed[ny, nx] = True
            frontier.append((ny, nx))

    return final_best_z

def edf_regional_guided_enhanced(image_stack, base_max_z_diff=DEFAULT_MAX_Z_DIFF,
                                sigma=DEFAULT_SIGMA, gauss_denoise=DEFAULT_GAUSS_DENOISE):
    """Extended-depth-of-field projection computed on overlapping tiles.

    The stack is normalized to 0-1 and focus measures are computed for every
    slice. The image is then split into overlapping square tiles whose size
    is derived from the middle slice. Each tile gets its own spatially
    consistent z-map, and the tile z-maps are blended with weights that fall
    off from the tile centre. The blended z-map is rounded and used to pick
    one source slice per pixel.

    Args:
        image_stack: 3-D array ordered (z, height, width).
        base_max_z_diff: Base z-step limit; it is raised locally where the
            initial z-map varies strongly.
        sigma: Sigma used to smooth the combined focus measure.
        gauss_denoise: Sigma of the blur applied before measuring focus.

    Returns:
        Tuple ``(result, focus_map, best_z)``: the float32 projection, the
        z-map scaled to 0-1, and the int32 z-index map.
    """
    # Convert to float and ensure 0-1 range
    stack = img_as_float(image_stack)
    stack = normalize_image(stack)
    z_size, height, width = stack.shape

    # Determine adaptive region size based on middle slice
    mid_z = z_size // 2
    adaptive_region_size = determine_adaptive_region_size(stack[mid_z])
    overlap = max(3, adaptive_region_size // 3)  # Adaptive overlap

    print(f"Using adaptive region size: {adaptive_region_size}, overlap: {overlap}")

    # Calculate enhanced focus measures
    focus_measures = calculate_focus_measures(stack, gauss_denoise, sigma)

    # Find initial best z-slices
    initial_best_z = np.argmax(focus_measures, axis=0)

    # Calculate adaptive max_z_diff
    adaptive_z_diff = calculate_adaptive_z_diff(focus_measures, initial_best_z, base_max_z_diff)

    # Accumulators for the weighted average of the per-tile z-maps
    best_z = np.zeros((height, width), dtype=np.float64)
    weights = np.zeros((height, width), dtype=np.float32)

    # Process image in overlapping regions
    step = adaptive_region_size - overlap
    half_size = adaptive_region_size / 2
    for y_start in range(0, height, step):
        for x_start in range(0, width, step):
            # Define region bounds (tiles on the far edges may be smaller)
            y_end = min(y_start + adaptive_region_size, height)
            x_end = min(x_start + adaptive_region_size, width)

            # Extract region
            region_focus = focus_measures[:, y_start:y_end, x_start:x_end]
            region_z_diff = adaptive_z_diff[y_start:y_end, x_start:x_end]

            # One z-step limit per tile: mean of the per-pixel limits, rounded up
            max_z_diff_region = int(np.ceil(np.mean(region_z_diff)))

            # Process region with spatial continuity
            region_best_z = edf_with_spatial_continuity_region_enhanced(
                region_focus,
                max_z_diff=max_z_diff_region
            )

            # Create weight mask (higher weights in center, lower at edges)
            y_grid, x_grid = np.mgrid[y_start:y_end, x_start:x_end]
            y_center = (y_start + y_end) / 2
            x_center = (x_start + x_end) / 2

            # Chebyshev distance from the tile centre in units of half a tile;
            # the weight never drops below 0.1, so every pixel keeps a
            # non-zero total weight.
            y_dist = np.abs(y_grid - y_center) / half_size
            x_dist = np.abs(x_grid - x_center) / half_size
            dist = np.maximum(y_dist, x_dist)
            region_weights = np.clip(1.0 - dist, 0.1, 1.0)

            # Add weighted contribution to output
            best_z[y_start:y_end, x_start:x_end] += region_best_z.astype(np.float64) * region_weights
            weights[y_start:y_end, x_start:x_end] += region_weights

    # Normalize by weights and round to nearest integer
    best_z = np.round(best_z / np.maximum(weights, 1e-6)).astype(np.int32)

    # Ensure best_z values are within valid range
    best_z = np.clip(best_z, 0, z_size - 1)

    # Create output by taking each pixel from its selected z-slice
    result = np.take_along_axis(
        stack, best_z[np.newaxis, ...], axis=0
    )[0].astype(np.float32)

    # Focus map normalized to 0-1. A single-slice stack has z_size - 1 == 0,
    # so guard the divisor to avoid producing NaNs.
    focus_map = best_z / max(z_size - 1, 1)

    return result, focus_map, best_z

# ======= FOCUS MAP REFINEMENT =======

def refine_focus_map(best_z, confidence, smoothness=0.5):
    """Smooth a z-index map while keeping its values at strong edges.

    The map is blurred with a Gaussian of sigma ``smoothness``. Pixels whose
    Sobel response on ``confidence`` lies in the top 20% are treated as edges
    and keep their original z-value.

    Args:
        best_z: 2-D z-index map (integer or float).
        confidence: 2-D focus-confidence map with the same shape.
        smoothness: Sigma of the Gaussian blur.

    Returns:
        Float array with the refined (non-integer) z-values.
    """
    # Edge mask from the confidence map: top 20% of edge responses
    edges = filters.sobel(confidence)
    edge_mask = edges > np.percentile(edges, 80)

    # Cast to float BEFORE smoothing. Given an integer image, filters.gaussian
    # rescales it by the dtype range (img_as_float conventions), which would
    # shrink z-indices to ~0 and make every smoothed pixel round to slice 0.
    z_values = np.asarray(best_z, dtype=np.float64)

    # Plain Gaussian smoothing; edges are restored below
    refined = filters.gaussian(z_values, sigma=smoothness)

    # Keep original values at edge locations
    refined[edge_mask] = z_values[edge_mask]

    return refined

# ======= MAIN PROCESSING FUNCTIONS =======

def process_single_stack(stack, apply_tophat=False, max_z_diff=DEFAULT_MAX_Z_DIFF):
    """Run the enhanced regional EDF projection on one z-stack.

    Args:
        stack: 3-D array ordered (z, height, width).
        apply_tophat: When True, every slice is background-corrected with the
            white top-hat filter before projection.
        max_z_diff: Base z-step limit passed on to the regional EDF.

    Returns:
        Dict with the keys ``'original_middle'`` (middle slice of the input),
        ``'original_max'`` (maximum projection of the input),
        ``'regional_edf'`` (projection built from the refined z-map),
        ``'focus_map_regional'`` (normalized z-map) and
        ``'focus_confidence'`` (focus measure at the selected slice).
    """
    print(f"Processing stack with shape {stack.shape}...")

    # Convert to float and ensure valid range (0-1)
    stack = img_as_float(stack)

    # Ensure stack values are in 0-1 range
    min_val = np.min(stack)
    max_val = np.max(stack)
    if min_val < 0 or max_val > 1:
        print(f"Normalizing stack from range [{min_val}, {max_val}] to [0, 1]")
        stack = normalize_image(stack)

    # Apply background removal if requested
    if apply_tophat:
        print("Applying tophat background removal...")
        working_stack = np.zeros_like(stack, dtype=np.float32)
        for z in range(stack.shape[0]):
            working_stack[z] = tophat_filter_background(stack[z])
    else:
        working_stack = stack

    # Reference images taken from the (unfiltered) input stack
    mid_z_idx = stack.shape[0] // 2
    results = {
        'original_middle': stack[mid_z_idx],
        'original_max': np.max(stack, axis=0)
    }

    # Apply enhanced regional method
    print("Applying enhanced regional guided EDF...")
    proj, focus_map, best_z = edf_regional_guided_enhanced(
        working_stack, base_max_z_diff=max_z_diff)

    # Focus confidence = focus measure at the selected slice of each pixel,
    # gathered in one vectorized lookup rather than a per-pixel Python loop
    focus_measures = calculate_focus_measures(working_stack)
    confidence = np.take_along_axis(
        focus_measures, best_z[np.newaxis, ...], axis=0
    )[0].astype(np.float32)

    # Apply focus map refinement for smoother results
    refined_best_z = refine_focus_map(best_z, confidence, smoothness=0.75)

    # Round the refined map once; it does not change between slices
    refined_index = np.round(refined_best_z).astype(np.int32)

    # Create refined projection using the refined focus map
    refined_proj = np.zeros_like(proj)
    for z in range(working_stack.shape[0]):
        mask = refined_index == z
        refined_proj[mask] = working_stack[z][mask]

    results['regional_edf'] = refined_proj
    results['focus_map_regional'] = focus_map

    # Add focus confidence map for debugging/visualization
    results['focus_confidence'] = confidence

    return results

def save_results_to_disk(results, filename, channel_name, apply_tophat):
    """Write the EDF projection (and its confidence map) as 16-bit TIFFs.

    Files go to ``OUTPUT_DIR/<channel_name>/<tophat|background>/`` and are
    named ``<base>_<channel_name>_regional[_tophat].tif``; the confidence map
    uses the same name with ``_confidence`` inserted before the extension.

    Args:
        results: Dict from ``process_single_stack``. Nothing is written when
            it has no ``'regional_edf'`` entry.
        filename: Name of the source file; its extension is stripped to form
            the base name.
        channel_name: Channel label used for the folder and the file name.
        apply_tophat: Whether the projection was computed with top-hat
            filtering. Selects the sub-folder and the ``_tophat`` suffix.
    """
    key = "regional_edf"
    if key not in results:
        return

    # Skip if this is not one of the required outputs
    if channel_name == "Nuclei" and not apply_tophat:
        print(f"Skipping {channel_name} without tophat (not required)")
        return

    base_name = os.path.splitext(filename)[0]

    # Determine processing type folder
    process_type = "tophat" if apply_tophat else "background"
    output_dir = os.path.join(OUTPUT_DIR, channel_name, process_type)
    os.makedirs(output_dir, exist_ok=True)

    # Output name: <base>_<channel>_regional[_tophat].tif
    output_stem = f"{base_name}_{channel_name}_regional"
    if apply_tophat:
        output_stem += "_tophat"
    output_path = os.path.join(output_dir, output_stem + ".tif")

    # Normalize to 0-1 range and convert to uint16
    projection_uint = img_as_uint(normalize_image(results[key]))

    # Save
    tifffile.imwrite(output_path, projection_uint)
    print(f"Saved {output_path}")

    # Additionally save focus confidence map for debugging (optional)
    if 'focus_confidence' in results:
        # Build the name from the stem so that only the extension is touched;
        # str.replace would also rewrite any other '.tif' inside the path.
        conf_path = os.path.join(output_dir, output_stem + "_confidence.tif")
        conf_uint = img_as_uint(normalize_image(results['focus_confidence']))
        tifffile.imwrite(conf_path, conf_uint)
        print(f"Saved confidence map: {conf_path}")

def process_directory(input_dir=INPUT_DIR, max_z_diff=DEFAULT_MAX_Z_DIFF,
                     file_pattern='*.tif*', save_results=True):
    """Run ``process_file`` on every file in ``input_dir`` matching ``file_pattern``.

    Files are handled in sorted order. A failure in one file is reported
    with its traceback and does not stop the remaining files.

    Returns:
        List of the paths that were processed successfully (empty when no
        file matches the pattern).
    """
    import traceback

    search_expr = os.path.join(input_dir, file_pattern)
    all_files = sorted(glob.glob(search_expr, recursive=True))
    total = len(all_files)

    if total == 0:
        print(f"No files matching '{file_pattern}' found in {input_dir}")
        return []

    print(f"Found {total} files to process")

    # Make sure the output folder tree is in place before writing anything
    create_dirs()

    processed_files = []
    for position, file_path in enumerate(all_files, start=1):
        short_name = os.path.basename(file_path)
        print(f"\nProcessing file {position}/{total}: {short_name}")

        try:
            outcome = process_file(
                file_path, max_z_diff=max_z_diff,
                save_results=save_results
            )
        except Exception as e:
            # Report and move on so one bad file does not abort the batch
            print(f"Error processing file {os.path.basename(file_path)}: {e}")
            traceback.print_exc()
            continue

        if outcome:
            processed_files.append(file_path)
            print(f"Successfully processed {os.path.basename(file_path)}")

    print(f"\nSuccessfully processed {len(processed_files)}/{total} files")
    return processed_files

def process_file(file_path, max_z_diff=DEFAULT_MAX_Z_DIFF, save_results=True):
    """Load one TIFF and project it, channel by channel.

    * 3-D image: treated as a single z-stack and projected both without and
      with top-hat filtering under the channel name ``"default"``.
    * 4-D image: the first axis is the channel axis. Channel index 1 is
      projected with top-hat only; every other channel is projected both
      without and with top-hat.

    Returns:
        Dict mapping ``"<channel>_no_tophat"`` / ``"<channel>_tophat"`` to
        the result dict of each run, or ``None`` when the file cannot be
        loaded or has an unsupported number of dimensions.
    """
    filename = os.path.basename(file_path)
    print(f"Processing file: {filename}")

    # Load image
    try:
        image = tifffile.imread(file_path)
        print(f"Image shape: {image.shape}, dtype: {image.dtype}")
    except Exception as e:
        print(f"Error loading image: {e}")
        return None

    n_dims = len(image.shape)

    if n_dims == 3:
        # Single-channel z-stack: both variants, no per-variant log line
        plain = process_single_stack(
            image, apply_tophat=False, max_z_diff=max_z_diff)
        filtered = process_single_stack(
            image, apply_tophat=True, max_z_diff=max_z_diff)

        if save_results:
            save_results_to_disk(plain, filename, "default", False)
            save_results_to_disk(filtered, filename, "default", True)

        return {"default_no_tophat": plain, "default_tophat": filtered}

    if n_dims != 4:
        print(f"Unsupported image dimensions: {n_dims}")
        return None

    # Multi-channel z-stack with channel-specific processing
    all_results = {}
    for ch_idx, channel_data in enumerate(image):
        if ch_idx < len(CHANNEL_NAMES):
            ch_name = CHANNEL_NAMES[ch_idx]
        else:
            ch_name = f"channel_{ch_idx}"

        # Second channel (index 1): top-hat only; all others: both variants
        variants = (True,) if ch_idx == 1 else (False, True)

        # Compute every variant first, then store and save, so that nothing
        # is recorded or written for a channel if one of its runs fails.
        computed = []
        for use_tophat in variants:
            mode = "with" if use_tophat else "without"
            print(f"Processing {ch_name} {mode} tophat...")
            computed.append((use_tophat, process_single_stack(
                channel_data, apply_tophat=use_tophat, max_z_diff=max_z_diff)))

        for use_tophat, channel_result in computed:
            suffix = "tophat" if use_tophat else "no_tophat"
            all_results[f"{ch_name}_{suffix}"] = channel_result

        if save_results:
            for use_tophat, channel_result in computed:
                save_results_to_disk(channel_result, filename, ch_name, use_tophat)

    return all_results

# ======= VISUALIZATION FUNCTIONS =======

def visualize_focus_map(focus_map, original_image, projected_image, title="Focus Map Visualization"):
    """Draw a three-panel figure: original, projection, and focus-map overlay.

    The third panel shows the projection in gray with the focus map drawn
    semi-transparently on top in the 'inferno' colormap, plus a colorbar.

    Returns:
        The matplotlib figure holding the three panels.
    """
    figure = plt.figure(figsize=(18, 6))

    # The first two panels are plain grayscale images
    gray_panels = (
        (original_image, "Original Image (Middle Slice)"),
        (projected_image, "Enhanced EDF Projection"),
    )
    for slot, (picture, caption) in enumerate(gray_panels, start=1):
        plt.subplot(1, 3, slot)
        plt.imshow(picture, cmap='gray')
        plt.title(caption)
        plt.axis('off')

    # Third panel: projection underneath, focus map blended on top
    plt.subplot(1, 3, 3)
    plt.imshow(projected_image, cmap='gray', alpha=0.7)
    plt.imshow(focus_map, cmap='inferno', alpha=0.5)
    plt.colorbar(label='Z-Stack Position (Normalized)')
    plt.title("Focus Map Overlay")
    plt.axis('off')

    plt.suptitle(title)
    plt.tight_layout()
    return figure

# Main execution - just call process_directory with your parameters
def main():
    """Create the folder tree, then process every TIFF in ``INPUT_DIR``."""
    create_dirs()

    run_options = {
        'input_dir': INPUT_DIR,
        'max_z_diff': DEFAULT_MAX_Z_DIFF,
        'file_pattern': '*.tif*',
        'save_results': True,
    }
    processed_files = process_directory(**run_options)

    print(f"Processed {len(processed_files)} files")

if __name__ == "__main__":
    main()

Mounted at /content/drive
Found 2 files to process

Processing file 1/2: denoised_1.4Pa_A1_19dec21_20xA_L2RA_FlatA_seq001.tif
Processing file: denoised_1.4Pa_A1_19dec21_20xA_L2RA_FlatA_seq001.tif
Image shape: (3, 13, 1024, 1024), dtype: float32
Processing Cadherins without tophat...
Processing stack with shape (13, 1024, 1024)...
Normalizing stack from range [-0.0228258203715086, 7.299477577209473] to [0, 1]
Applying enhanced regional guided EDF...
Using adaptive region size: 25, overlap: 8
Processing Cadherins with tophat...
Processing stack with shape (13, 1024, 1024)...
Normalizing stack from range [-0.0228258203715086, 7.299477577209473] to [0, 1]
Applying tophat background removal...
Applying enhanced regional guided EDF...
Using adaptive region size: 25, overlap: 8
Saved /content/drive/MyDrive/knowledge/University/Master/Thesis/Projected/trial2/Cadherins/background/denoised_1.4Pa_A1_19dec21_20xA_L2RA_FlatA_seq001_Cadherins_regional.tif
Saved confidence map: /content/drive/MyDrive