# ROI-Straightener

## Overview
This script processes microscopy images with associated ROI annotations (LabelMe JSON format) and optional mapping images.  
It performs **ROI skeleton extraction**, determines the **longest path** with a controlled start orientation, and straightens the curved ROI region so that the start point is consistently placed on the **left side** of the output.  
The same transformation is applied to all mapping images, ensuring spatial correspondence across channels or datasets.  

The workflow includes:
1. **Validation** of target and mapping datasets.
2. **Skeletonization** and **longest path detection** within the annotated ROI.
3. **Geometric straightening** of target and mapping images.
4. **Interactive navigation** to inspect results.
5. **Export** as either:
   - Individual straightened TIFF images, or
   - Multi-frame TIFF stacks.

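Two steps in `straighten_polygon_region` put the START point on the left. First, output column `c` samples arc length `u = L - c`, so the unflipped strip holds the END of the path in column 0. Second, `cv2.flip(img, -1)` flips both axes at the end. The toy NumPy model below illustrates this. It uses a made-up 5-pixel path with no real image, and the pixel values stand in for the sampled arc length.

```python
import numpy as np

# Toy model of the column mapping in straighten_polygon_region (not the real remap).
path_length = 5.0                              # total skeleton length L, in pixels
width = int(round(path_length))                # output width, computed as in the notebook
u_of_column = path_length - np.arange(width)   # column c samples u = L - c -> [5, 4, 3, 2, 1]

# Two-row "straightened strip" whose pixel values record the sampled arc length u.
strip = np.tile(u_of_column, (2, 1))

# cv2.flip(img, -1) flips both axes; np.flip with axis=None does the same for this array.
flipped = np.flip(strip)

print(strip[0])    # before the flip: the END of the path (largest u) sits in column 0
print(flipped[0])  # after the flip: the smallest u (nearest START) sits in column 0
```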
---

## Input Requirements
- **Target folder** containing pairs of:
  - Microscopy image files (`.tif`, `.png`, `.jpg`, etc.)
  - Corresponding ROI annotation files in **LabelMe JSON format**, sharing the image's base name (e.g. `sample1.tif` and `sample1.json`).
- **Mapping folder(s)** (optional):
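  - Each mapping folder must contain the same number of images as the target folder. Images are paired with target images by their position in sorted filename order, not by matching names.
  - Each mapping image must have the same width and height as its paired target image.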
- **Directory structure example**:
  ```
  TARGET_DIR/
      sample1.tif
      sample1.json
      sample2.tif
      sample2.json
  MAPPING_DIRS/
      Mapping1/
          sample1.tif
          sample2.tif
      Mapping2/
          sample1.tif
          sample2.tif
  ```
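Of each annotation file, the script reads only the first entry of `shapes` and its `points` list. The points are interpreted as (x, y) pixel coordinates of the ROI polygon. The sketch below shows that part of a LabelMe-style file. The values are illustrative, real files carry more keys, and `first_polygon_points` is a hypothetical helper. Unlike the notebook, which always takes `shapes[0]`, the helper skips non-polygon shapes.

```python
import json

# Minimal LabelMe-style annotation (illustrative values; real files also carry keys
# such as "version", "flags" and "imageData").
annotation_text = json.dumps({
    "shapes": [
        {"label": "roi", "shape_type": "polygon",
         "points": [[10.0, 5.0], [40.5, 8.0], [42.0, 30.0], [12.0, 28.5]]}
    ],
    "imagePath": "sample1.tif",
    "imageHeight": 64,
    "imageWidth": 64,
})

def first_polygon_points(labelme_data):
    """Hypothetical helper: return the first polygon's points as (x, y) float pairs."""
    for shape in labelme_data.get("shapes", []):
        # Shapes without an explicit shape_type are treated as polygons
        if shape.get("shape_type", "polygon") == "polygon":
            return [(float(x), float(y)) for x, y in shape["points"]]
    raise ValueError("No polygon shape found in annotation.")

data = json.loads(annotation_text)
points = first_polygon_points(data)
print(len(points), points[0])
```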

---

## Outputs
- **Straightened target and mapping images** with consistent start orientation.
- **Individual TIFF files** for each processed image:
  ```
  sample1_target_straight.tif
  sample1_map_Mapping1_straight.tif
  ...
  ```
- **Multi-frame TIFF stacks**:
  ```
  Target_Stack.tif
  Mapping1_Stack.tif
  Mapping2_Stack.tif
  ```
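The saving cell is not part of this section, so the sketch below only reproduces the naming pattern listed above. It assumes the mapping label is the base name of each mapping folder, as the `Mapping1`/`Mapping2` examples suggest. Both helper functions are hypothetical.

```python
from pathlib import Path

def straightened_output_names(target_path, mapping_dirs):
    """Hypothetical helper mirroring the individual-file names listed under Outputs."""
    stem = Path(target_path).stem
    names = [f"{stem}_target_straight.tif"]
    for m_dir in mapping_dirs:
        # Assumption: the mapping label is the mapping folder's base name, e.g. "Mapping1"
        names.append(f"{stem}_map_{Path(m_dir).name}_straight.tif")
    return names

def stack_output_names(mapping_dirs):
    """Hypothetical helper mirroring the multi-frame stack names listed under Outputs."""
    return ["Target_Stack.tif"] + [f"{Path(m).name}_Stack.tif" for m in mapping_dirs]

print(straightened_output_names("TARGET_DIR/sample1.tif", ["MAPPING_DIRS/Mapping1"]))
print(stack_output_names(["MAPPING_DIRS/Mapping1", "MAPPING_DIRS/Mapping2"]))
```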

---

## How to Run
1. **Set directory paths**:
   - Edit `TARGET_DIR`, `MAPPING_DIRS`, and `OUTPUT_DIR` in the configuration cell.
2. **Run validation**:
   - Executes `validate_and_load_files()` to ensure file count and dimensions match.
3. **Process images**:
   - Calls `process_single_image_set()` for each image set (a target image plus its mapping images).
4. **Inspect results interactively**:
   - Use the navigation widget to browse straightened results.
5. **Save outputs**:
   - Use the save buttons to export either individual files or TIFF stacks.

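**Example run sequence** (inside a Jupyter notebook):
```python
# Step 1: Validate (returns None on failure, otherwise a (file_list_map, valid_mapping_dirs) tuple)
validation_result = validate_and_load_files(TARGET_DIR, MAPPING_DIRS)
file_list_map, valid_mapping_dirs = validation_result

# Step 2: Process all images with the configured options
all_results = [
    process_single_image_set(f, PAD_TO_CANVAS, OUTPUT_WIDTH, OUTPUT_HEIGHT, SKELETON_START_PREFERENCE)
    for f in file_list_map
]

# Step 3: Inspect interactively
create_navigation_controls()

# Step 4: Save
create_save_controls()
```
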
---

In [1]:
# Import required libraries
import cv2
import numpy as np
import json
import os
import matplotlib.pyplot as plt
from PIL import Image
from skimage.morphology import skeletonize as skel_func
from skan import Skeleton as SkanSkeleton
from scipy.sparse.csgraph import dijkstra
import math
from pathlib import Path
import glob
from tqdm import tqdm
import ipywidgets as widgets
from IPython.display import display, clear_output
import warnings
warnings.filterwarnings('ignore')

# Set matplotlib to display images inline with high quality
plt.rcParams['figure.figsize'] = [15, 10]
plt.rcParams['figure.dpi'] = 150
plt.rcParams['savefig.dpi'] = 300
plt.rcParams['image.interpolation'] = 'bilinear'

## Configuration

Set your input and output directories and skeleton orientation preferences:

In [2]:
# Configuration - Update these paths according to your data
TARGET_DIR = "/Users/jyzerresico/chenlab/Synthetic Division/Fig4d/ACA-5/c1"  # Folder containing images and JSON files

MAPPING_DIRS = [
    "/Users/jyzerresico/chenlab/Synthetic Division/Fig4d/ACA-5/c2"  # Add more mapping folders as needed
]
OUTPUT_DIR = "/Users/jyzerresico/chenlab/Synthetic Division/Fig4d/ACA-5/Straightener"  # Where to save results

# Output options
SAVE_AS_STACK = False  # True for multi-frame TIFF stacks, False for individual files
PAD_TO_CANVAS = False  # True to pad/crop to fixed canvas size
OUTPUT_WIDTH = 200     # Canvas width (only used if PAD_TO_CANVAS=True)
OUTPUT_HEIGHT = 200    # Canvas height (only used if PAD_TO_CANVAS=True)

# Skeleton orientation control
# Options: 'top', 'bottom', 'left', 'right'
# This determines which end of the skeleton will be the START point
# START point will always appear on LEFT side of straightened output
SKELETON_START_PREFERENCE = 'top'  # Change this to control skeleton orientation

# Visualization options for better display quality
CONTOUR_THICKNESS = 1      # Thickness for polygon contours (1 for fine lines)
SKELETON_THICKNESS = 1     # Thickness for skeleton lines
PATH_THICKNESS = 2         # Thickness for longest path
UPSCALE_FACTOR = 2         # Factor to upscale images for better visualization

# Create output directory if it doesn't exist
os.makedirs(OUTPUT_DIR, exist_ok=True)

print(f"Target directory: {TARGET_DIR}")
print(f"Mapping directories: {MAPPING_DIRS}")
print(f"Output directory: {OUTPUT_DIR}")
print(f"Save as stack: {SAVE_AS_STACK}")
print(f"Pad to canvas: {PAD_TO_CANVAS}")
print(f"Skeleton start preference: {SKELETON_START_PREFERENCE} -> START will be on LEFT side of output")
print(f"Upscale factor for display: {UPSCALE_FACTOR}x")

Target directory: /Users/jyzerresico/chenlab/Synthetic Division/Fig4d/ACA-5/c1
Mapping directories: ['/Users/jyzerresico/chenlab/Synthetic Division/Fig4d/ACA-5/c2']
Output directory: /Users/jyzerresico/chenlab/Synthetic Division/Fig4d/ACA-5/Straightener
Save as stack: False
Pad to canvas: False
Skeleton start preference: top -> START will be on LEFT side of output
Upscale factor for display: 2x

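`select_skeleton_start_point` does not reject a misspelled `SKELETON_START_PREFERENCE`. It silently falls back to index 0, which is endpoint A of the longest path. `upscale_image_for_display` also needs an integer factor, because it passes `height * factor` and `width * factor` to `cv2.resize` as pixel sizes. The sketch below checks these values up front. `check_config` and `VALID_START_PREFERENCES` are hypothetical names, not part of the notebook.

```python
VALID_START_PREFERENCES = ("top", "bottom", "left", "right")

def check_config(start_preference, pad_to_canvas, output_width, output_height, upscale_factor):
    """Hypothetical sanity check for the configuration values; raises ValueError on bad input."""
    if start_preference not in VALID_START_PREFERENCES:
        # select_skeleton_start_point would silently fall back to index 0 (endpoint A)
        raise ValueError(
            f"SKELETON_START_PREFERENCE must be one of {VALID_START_PREFERENCES}, got {start_preference!r}"
        )
    if pad_to_canvas and (output_width <= 0 or output_height <= 0):
        raise ValueError("OUTPUT_WIDTH and OUTPUT_HEIGHT must be positive when PAD_TO_CANVAS is True")
    # The factor multiplies pixel dimensions for cv2.resize, so it must be a whole number
    if not isinstance(upscale_factor, int) or upscale_factor < 1:
        raise ValueError("UPSCALE_FACTOR must be an integer >= 1")
    return True

# In the notebook this would be called as:
# check_config(SKELETON_START_PREFERENCE, PAD_TO_CANVAS, OUTPUT_WIDTH, OUTPUT_HEIGHT, UPSCALE_FACTOR)
print(check_config("top", False, 200, 200, 2))
```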

## Core Processing Functions
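`straighten_polygon_region` parameterizes the skeleton path by cumulative arc length `u` and by a signed offset `v` along the unit normal `n(u)`. Each straightened pixel then samples the source at `p(u) + v * n(u)`. The sketch below builds the same arc lengths and normals with NumPy for a toy horizontal path. It uses `np.gradient` as a vectorized stand-in for the notebook's explicit tangent loop, which gives the same directions after normalization. It also assumes the tangent never vanishes; the notebook guards against that case. Both function names are hypothetical.

```python
import numpy as np

def arc_lengths_and_normals(path_rc):
    """Cumulative arc length and unit normals for a polyline given as (row, col) points."""
    path_rc = np.asarray(path_rc, dtype=float)
    seg_lengths = np.linalg.norm(np.diff(path_rc, axis=0), axis=1)
    arc = np.concatenate(([0.0], np.cumsum(seg_lengths)))
    # One-sided differences at the ends, central differences inside (same directions as the notebook)
    tangents = np.gradient(path_rc, axis=0)
    tangents /= np.linalg.norm(tangents, axis=1, keepdims=True)  # assumes no zero-length tangent
    normals = np.stack([-tangents[:, 1], tangents[:, 0]], axis=1)  # (dr, dc) -> (-dc, dr)
    return arc, normals

def sample_point(path_rc, arc, normals, u, v):
    """Source (row, col) for straightened coordinates u (along the path) and v (across it)."""
    i = int(np.clip(np.searchsorted(arc, u, side="right") - 1, 0, len(arc) - 2))
    t = np.clip((u - arc[i]) / (arc[i + 1] - arc[i]), 0.0, 1.0)
    p = (1 - t) * path_rc[i] + t * path_rc[i + 1]   # point on the path at arc length u
    n = (1 - t) * normals[i] + t * normals[i + 1]   # interpolated normal, re-normalized
    n /= np.linalg.norm(n)
    return p + v * n

# A horizontal path along row 10, from column 0 to column 4.
path = np.array([[10, c] for c in range(5)], dtype=float)
arc, normals = arc_lengths_and_normals(path)
print(arc[-1])                                      # total path length in pixels
print(sample_point(path, arc, normals, 2.5, 3.0))   # row 7.0, col 2.5
```

With this sign convention the normal of a left-to-right path points toward smaller row indices, so a positive `v` moves up in the image.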

In [3]:
def upscale_image_for_display(image, factor=2):
    """Upscale image for better visualization quality."""
    if factor <= 1:
        return image
    height, width = image.shape[:2]
    new_height, new_width = height * factor, width * factor
    return cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_CUBIC)


def draw_polygon_with_fine_lines(image, polygon_points, color=(0, 255, 0), thickness=1, upscale_factor=1):
    """Draw polygon with fine lines optimized for low resolution images."""
    img_copy = image.copy()
    
    if upscale_factor > 1:
        img_copy = upscale_image_for_display(img_copy, upscale_factor)
        # Scale polygon points accordingly
        scaled_points = (polygon_points * upscale_factor).astype(np.int32)
    else:
        scaled_points = polygon_points.astype(np.int32)
    
    # Draw with anti-aliasing for smoother lines
    cv2.polylines(img_copy, [scaled_points], isClosed=True, color=color, 
                 thickness=thickness, lineType=cv2.LINE_AA)
    
    return img_copy


def select_skeleton_start_point(endpoint_coordinates, preference='top'):
    """Select skeleton start point based on spatial preference.
    
    Args:
        endpoint_coordinates: Array of endpoint coordinates in (row, col) format
        preference: 'top', 'bottom', 'left', 'right'
    
    Returns:
        Index of the selected start point
    """
    if len(endpoint_coordinates) < 2:
        return 0
    
    if preference == 'top':
        # Select point with smallest row (topmost)
        return np.argmin(endpoint_coordinates[:, 0])
    elif preference == 'bottom':
        # Select point with largest row (bottommost)
        return np.argmax(endpoint_coordinates[:, 0])
    elif preference == 'left':
        # Select point with smallest column (leftmost)
        return np.argmin(endpoint_coordinates[:, 1])
    elif preference == 'right':
        # Select point with largest column (rightmost)
        return np.argmax(endpoint_coordinates[:, 1])
    else:
        # Default to first point
        return 0


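def draw_skeleton_with_path_and_start(skeleton_img, longest_path_coords_rc, start_coord_rc, 
                                     upscale_factor=1, skeleton_thickness=1, path_thickness=2):
    """Draw skeleton with highlighted longest path and marked start point.
    
    Note: skeleton_thickness is accepted but not currently used; the skeleton is
    drawn at its native 1-pixel width (before upscaling).
    """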
    # Convert to color
    if len(skeleton_img.shape) == 2:
        skeleton_display = cv2.cvtColor(skeleton_img, cv2.COLOR_GRAY2BGR)
    else:
        skeleton_display = skeleton_img.copy()
    
    if upscale_factor > 1:
        skeleton_display = upscale_image_for_display(skeleton_display, upscale_factor)
        # Scale path coordinates
        scaled_path = (longest_path_coords_rc * upscale_factor).astype(np.int32)
        scaled_start = (start_coord_rc * upscale_factor).astype(np.int32)
    else:
        scaled_path = longest_path_coords_rc.astype(np.int32)
        scaled_start = start_coord_rc.astype(np.int32)
    
    # Draw the longest path in red with anti-aliasing
    path_xy = scaled_path[:, ::-1]  # Convert rc to xy coordinates
    cv2.polylines(skeleton_display, [path_xy], isClosed=False, color=(0, 0, 255), 
                 thickness=path_thickness, lineType=cv2.LINE_AA)
    
    # Mark the start point with a green circle
    start_xy = scaled_start[::-1]  # Convert rc to xy
    radius = max(3, path_thickness + 1)
    cv2.circle(skeleton_display, tuple(start_xy.astype(int)), radius, (0, 255, 0), -1, cv2.LINE_AA)
    
    # Add "START" text near the start point
    text_pos = (int(start_xy[0] + radius + 5), int(start_xy[1] - radius - 5))
    cv2.putText(skeleton_display, "START", text_pos, 
                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1, cv2.LINE_AA)
    
    return skeleton_display


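def straighten_polygon_region(image, polygon_mask, skeleton_path_rc):
    """Straighten a polygonal region along a skeleton path.
    
    Each output pixel (row, col) samples the source at p(u) + v * n(u), where u is the
    arc length along the path, p(u) the point on the path at that arc length, n(u) the
    unit normal there, and v the signed offset along it. The u mapping is reversed and
    the result is flipped on both axes at the end, so the START of the path (its first
    point) appears on the LEFT of the output. Returns None if the mask is empty.
    """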
    skeleton_path_rc = np.array(skeleton_path_rc, dtype=np.float32)
    if len(skeleton_path_rc) < 2:
        return np.zeros((10, 10, 3), dtype=np.uint8)
    
    # Calculate path segments and arc lengths
    path_segments = np.diff(skeleton_path_rc, axis=0)
    segment_lengths = np.linalg.norm(path_segments, axis=1)
    arc_lengths = np.concatenate(([0], np.cumsum(segment_lengths)))
    total_skeleton_length = arc_lengths[-1]
    
    if total_skeleton_length < 1e-3:
        return np.zeros((10, 10, 3), dtype=np.uint8)
    
    # Calculate normals along the path
    path_normals_rc = np.zeros_like(skeleton_path_rc)
    
    # First point normal
    tangent_first = skeleton_path_rc[1] - skeleton_path_rc[0]
    norm_tangent_first = np.linalg.norm(tangent_first)
    if norm_tangent_first < 1e-6:
        tangent_first = np.array([1.0, 0.0])
    else:
        tangent_first /= norm_tangent_first
    path_normals_rc[0] = [-tangent_first[1], tangent_first[0]]
    
    # Middle points normals
    for i in range(1, len(skeleton_path_rc) - 1):
        tangent = skeleton_path_rc[i+1] - skeleton_path_rc[i-1]
        norm_tangent = np.linalg.norm(tangent)
        if norm_tangent < 1e-6:
            path_normals_rc[i] = path_normals_rc[i-1]
        else:
            tangent /= norm_tangent
            path_normals_rc[i] = [-tangent[1], tangent[0]]
    
    # Last point normal
    tangent_last = skeleton_path_rc[-1] - skeleton_path_rc[-2]
    norm_tangent_last = np.linalg.norm(tangent_last)
    if norm_tangent_last < 1e-6:
        tangent_last = np.array([1.0, 0.0])
    else:
        tangent_last /= norm_tangent_last
    path_normals_rc[-1] = [-tangent_last[1], tangent_last[0]]
    
    # Find all points in the mask and their perpendicular distances
    all_mask_points_rc = np.argwhere(polygon_mask > 0).astype(np.float32)
    if len(all_mask_points_rc) == 0:
        return None
    
    all_v_coords = []
    for p_rc in all_mask_points_rc:
        distances_to_skel_pts_sq = np.sum((skeleton_path_rc - p_rc)**2, axis=1)
        closest_skel_pt_idx = np.argmin(distances_to_skel_pts_sq)
        vec_skel_to_p = p_rc - skeleton_path_rc[closest_skel_pt_idx]
        perp_dist = np.dot(vec_skel_to_p, path_normals_rc[closest_skel_pt_idx])
        all_v_coords.append(perp_dist)
    
    if not all_v_coords:
        d_min, d_max = -20, 20
    else:
        d_min, d_max = np.min(all_v_coords), np.max(all_v_coords)
    
    # Create output dimensions
    output_width_px = int(round(total_skeleton_length))
    output_height_px = int(round(d_max - d_min))
    if output_width_px <= 0:
        output_width_px = 1
    if output_height_px <= 0:
        output_height_px = 1
    
    v_offset_display = -d_min
    
    # Create coordinate maps.
    # The u (arc-length) mapping is reversed: output column 0 samples the END of the path.
    # Combined with the two-axis flip at the end of this function, START ends up on the LEFT.
    map_c_coords = np.zeros((output_height_px, output_width_px), dtype=np.float32)
    map_r_coords = np.zeros((output_height_px, output_width_px), dtype=np.float32)
    
    for r_new in range(output_height_px):
        for c_new in range(output_width_px):
            # Reversed arc-length parameter (see the comment above)
            u_param = total_skeleton_length - float(c_new)
            v_param = float(r_new) - v_offset_display
            
            skel_pt_idx = np.searchsorted(arc_lengths, u_param, side='right') - 1
            skel_pt_idx = max(0, min(skel_pt_idx, len(arc_lengths) - 2))
            
            segment_start_arc_len = arc_lengths[skel_pt_idx]
            segment_len = arc_lengths[skel_pt_idx+1] - segment_start_arc_len
            
            if segment_len < 1e-6:
                interp_ratio = 0.0
            else:
                interp_ratio = np.clip((u_param - segment_start_arc_len) / segment_len, 0.0, 1.0)
            
            pt_on_skel_rc = (1 - interp_ratio) * skeleton_path_rc[skel_pt_idx] + interp_ratio * skeleton_path_rc[skel_pt_idx+1]
            normal_rc_interp = (1 - interp_ratio) * path_normals_rc[skel_pt_idx] + interp_ratio * path_normals_rc[skel_pt_idx+1]
            
            norm_of_normal = np.linalg.norm(normal_rc_interp)
            if norm_of_normal < 1e-6:
                normal_rc_interp = path_normals_rc[skel_pt_idx] if np.linalg.norm(path_normals_rc[skel_pt_idx]) > 1e-6 else np.array([0.0, 1.0])
            else:
                normal_rc_interp /= norm_of_normal
            
            map_r_coords[r_new, c_new] = pt_on_skel_rc[0] + v_param * normal_rc_interp[0]
            map_c_coords[r_new, c_new] = pt_on_skel_rc[1] + v_param * normal_rc_interp[1]
    
    # Apply the transformation
    masked_source_image = cv2.bitwise_and(image, image, mask=polygon_mask)
    straightened_img_content = cv2.remap(masked_source_image, map_c_coords, map_r_coords, 
                                       interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT, borderValue=(0,0,0))
    
    # Apply mask to straightened image
    mask_float = polygon_mask.astype(np.float32) / 255.0
    straightened_mask_float = cv2.remap(mask_float, map_c_coords, map_r_coords, 
                                      interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT, borderValue=0)
    final_straightened_mask = (straightened_mask_float > 0.5).astype(np.uint8) * 255
    final_output = cv2.bitwise_and(straightened_img_content, straightened_img_content, mask=final_straightened_mask)
    
    # Flip both axes (flipCode=-1). Because the u mapping above is reversed, the horizontal
    # flip puts the START of the path on the LEFT; the vertical flip also inverts the v axis.
    final_output = cv2.flip(final_output, -1)
    return final_output


def pad_and_center_image(img_to_pad, out_width, out_height):
    """Center an image on a fixed-size black canvas, cropping any axis that exceeds it."""
    if img_to_pad is None:
        return np.zeros((out_height, out_width, 3), dtype=np.uint8)
    
    h, w = img_to_pad.shape[:2]
    canvas = np.zeros((out_height, out_width, 3), dtype=np.uint8)
    
    # Treat each axis independently: crop the source where it is larger than the canvas,
    # otherwise offset the destination so the image is centered.
    paste_w, paste_h = min(w, out_width), min(h, out_height)
    src_x, src_y = max(0, (w - out_width) // 2), max(0, (h - out_height) // 2)
    dst_x, dst_y = max(0, (out_width - w) // 2), max(0, (out_height - h) // 2)
    
    canvas[dst_y:dst_y+paste_h, dst_x:dst_x+paste_w] = img_to_pad[src_y:src_y+paste_h, src_x:src_x+paste_w]
    
    return canvas


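def extract_skeleton_and_longest_path(mask, start_preference='top'):
    """
    Extract the skeleton and find the longest path with a hybrid strategy:
    1. A two-pass (double-sweep) Dijkstra search finds the longest endpoint-to-endpoint path.
       This equals the graph diameter when the skeleton is a tree; with loops it is a heuristic.
    2. The spatial preference ('top', 'bottom', 'left', 'right') is applied to the two endpoints
       of that path to decide which one becomes the START point.
    Falls back to skan's longest path when the skeleton has fewer than two endpoints.
    """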
    skeleton_rc = skel_func(mask > 0)
    skeleton_display_img = (skeleton_rc.astype(np.uint8) * 255)

    graph = SkanSkeleton(skeleton_rc)
    pixel_graph, coordinates_rc = graph.graph, graph.coordinates

    if pixel_graph.shape[0] == 0:
        raise ValueError("Skeleton graph is empty.")

    n_nodes = pixel_graph.shape[0]
    degrees = np.bincount(pixel_graph.indices, minlength=n_nodes)
    endpoint_nodes_indices = np.where(degrees == 1)[0]

    longest_path_coords_rc = None
    start_coord_rc = None

    if len(endpoint_nodes_indices) < 2:
        # Fallback to the original skan method if there are fewer than 2 endpoints
        if graph.n_paths > 0:
            longest_path_idx = np.argmax(graph.path_lengths())
            longest_path_coords_rc = graph.path_coordinates(longest_path_idx)
            if longest_path_coords_rc is None or len(longest_path_coords_rc) < 2:
                raise ValueError("Fallback skeleton path is invalid.")
            start_coord_rc = longest_path_coords_rc[0] # The first point becomes the start
        else:
            raise ValueError("No paths found and insufficient endpoints.")
    else:
        # --- Find the physically longest path (diameter) with a two-pass Dijkstra search ---
        
        # From an arbitrary endpoint, find the farthest endpoint A
        temp_start_node = endpoint_nodes_indices[0]
        
        # With return_predecessors=False, dijkstra returns only the distance array.
        distances_from_temp = dijkstra(
            csgraph=pixel_graph, directed=False, indices=temp_start_node, return_predecessors=False
        )
        
        farthest_node_A_idx = -1
        max_dist = -1
        for ep_idx in endpoint_nodes_indices:
            if distances_from_temp[ep_idx] > max_dist and np.isfinite(distances_from_temp[ep_idx]):
                max_dist = distances_from_temp[ep_idx]
                farthest_node_A_idx = ep_idx
        
        if farthest_node_A_idx == -1:
            raise ValueError("Could not find a reachable endpoint A.")

        # From endpoint A, find the farthest endpoint B and record the path
        distances_from_A, predecessors_from_A = dijkstra(
            csgraph=pixel_graph, directed=False, indices=farthest_node_A_idx, return_predecessors=True
        )

        farthest_node_B_idx = -1
        max_dist_A_to_B = -1
        for ep_idx in endpoint_nodes_indices:
            if distances_from_A[ep_idx] > max_dist_A_to_B and np.isfinite(distances_from_A[ep_idx]):
                max_dist_A_to_B = distances_from_A[ep_idx]
                farthest_node_B_idx = ep_idx
        
        if farthest_node_B_idx == -1:
            raise ValueError("Could not find a reachable endpoint B.")
            
        # Reconstruct the path from A to B
        path_indices = []
        curr = farthest_node_B_idx
        while curr != -9999:  # scipy marks "no predecessor" with -9999
            path_indices.append(curr)
            if curr == farthest_node_A_idx:
                break
            curr = predecessors_from_A[curr]

        if not path_indices or path_indices[-1] != farthest_node_A_idx:
            raise ValueError("Path reconstruction failed.")

        # Path is currently from B to A, so reverse it to be from A to B
        path_indices.reverse()
        
        # --- Apply spatial preference to the two endpoints A and B ---

        # Get the coordinates of endpoints A and B
        endpoint_A_coord = coordinates_rc[farthest_node_A_idx]
        endpoint_B_coord = coordinates_rc[farthest_node_B_idx]
        
        # Use select_skeleton_start_point to decide which of A and B better fits the preference
        two_endpoints_coords = np.array([endpoint_A_coord, endpoint_B_coord])
        
        preferred_start_index = select_skeleton_start_point(two_endpoints_coords, start_preference)

        # --- Organize the path to ensure the start point is first ---
        
        # Based on the choice, set the final start point and reverse the path if needed
        if preferred_start_index == 0:
            # Endpoint A was chosen as the start. The path is already A -> B, so the order is correct.
            start_coord_rc = endpoint_A_coord
            longest_path_coords_rc = coordinates_rc[path_indices]
        else:
            # Endpoint B was chosen as the start. Reverse the A -> B path to B -> A.
            start_coord_rc = endpoint_B_coord
            path_indices.reverse() # Reverse the path
            longest_path_coords_rc = coordinates_rc[path_indices]

    if longest_path_coords_rc is None or len(longest_path_coords_rc) < 2:
        raise ValueError("Skeleton path is invalid after final check.")

    return skeleton_display_img, longest_path_coords_rc, start_coord_rc


print("Core processing functions defined successfully!")
print("Skeleton orientation control with START point on LEFT side")

Core processing functions defined successfully!
Skeleton orientation control with START point on LEFT side

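`extract_skeleton_and_longest_path` runs the double sweep in three steps. First, it runs Dijkstra from an arbitrary endpoint and takes the farthest endpoint A. Second, it runs Dijkstra again from A to reach the farthest endpoint B. Third, it walks the predecessors back from B to A. The sketch below replays this on a toy Y-shaped graph stored as a plain dict. It uses a small `heapq`-based Dijkstra as a stand-in for `scipy.sparse.csgraph.dijkstra`. Here the predecessor of the source is `None`, where scipy uses -9999. The graph and function names are illustrative.

```python
import heapq

def dijkstra(adj, source):
    """Shortest distances and predecessors from source in an undirected weighted graph."""
    dist, pred = {source: 0.0}, {source: None}
    heap = [(0.0, source)]
    while heap:
        d, node = heapq.heappop(heap)
        if d > dist[node]:
            continue  # stale heap entry
        for nbr, w in adj[node]:
            nd = d + w
            if nd < dist.get(nbr, float("inf")):
                dist[nbr], pred[nbr] = nd, node
                heapq.heappush(heap, (nd, nbr))
    return dist, pred

def double_sweep_longest_path(adj):
    """Two-pass search: exact diameter on trees, a heuristic if the graph has cycles."""
    endpoints = [n for n, nbrs in adj.items() if len(nbrs) == 1]  # degree-1 nodes
    dist0, _ = dijkstra(adj, endpoints[0])
    a = max(endpoints, key=lambda n: dist0.get(n, -1))            # farthest endpoint A
    dist_a, pred_a = dijkstra(adj, a)
    b = max(endpoints, key=lambda n: dist_a.get(n, -1))           # farthest endpoint B
    path = [b]
    while path[-1] != a:          # walk predecessors back from B to A
        path.append(pred_a[path[-1]])
    path.reverse()                # ordered A -> B, as in the notebook
    return path, dist_a[b]

# Y-shaped toy skeleton: long arm 0-1-2-3-5 and a short branch 2-4 (unit edge weights).
adj = {
    0: [(1, 1.0)],
    1: [(0, 1.0), (2, 1.0)],
    2: [(1, 1.0), (3, 1.0), (4, 1.0)],
    3: [(2, 1.0), (5, 1.0)],
    4: [(2, 1.0)],
    5: [(3, 1.0)],
}
print(double_sweep_longest_path(adj))  # ([5, 3, 2, 1, 0], 4.0)
```

The short branch through node 4 is ignored because neither sweep ends there. The spatial preference is applied only afterwards, to decide whether the A -> B order is reversed.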

## File Validation and Loading

In [4]:
def validate_and_load_files(target_dir, mapping_dirs):
    """Validate and load file mappings between target and mapping directories."""
    image_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"}
    
    # 1. Validate Target Folder
    print("Validating target folder...")
    try:
        target_files = []
        for filename in sorted(os.listdir(target_dir)):
            name, ext = os.path.splitext(filename)
            if ext.lower() in image_extensions:
                img_path = os.path.join(target_dir, filename)
                json_path = os.path.join(target_dir, name + ".json")
                if not os.path.exists(json_path):
                    raise FileNotFoundError(f"Target image '{filename}' is missing its corresponding JSON file.")
                target_files.append({'img': img_path, 'json': json_path, 'name': filename})
        
        if not target_files:
            raise ValueError("No valid 'Image-JSON' pairs found in the target folder.")
        
        print(f"Target folder validated: {len(target_files)} image-JSON pairs found")
                
    except (FileNotFoundError, ValueError) as e:
        print(f"Target Folder Error: {e}")
        return None
    
    # 2. Validate each mapping folder
    print("\nValidating mapping folders...")
    valid_mapping_dirs = []
    file_list_map = []
    
    for i, m_dir in enumerate(mapping_dirs):
        print(f"Validating mapping folder {i+1}: {os.path.basename(m_dir)}...")
        try:
            map_files = [f for f in sorted(os.listdir(m_dir)) 
                        if os.path.splitext(f)[1].lower() in image_extensions]
            
            if len(map_files) != len(target_files):
                raise ValueError(f"File count mismatch! Target has {len(target_files)}, this folder has {len(map_files)}.")
            
            # Check dimensions match
            for j, target_file_info in enumerate(target_files):
                target_img_path = target_file_info['img']
                map_img_path = os.path.join(m_dir, map_files[j])
                
                # Check dimensions
                target_img = cv2.imread(target_img_path)
                if target_img is None:
                    raise IOError(f"Cannot read target image: {os.path.basename(target_img_path)}")
                h, w = target_img.shape[:2]
                
                map_img = cv2.imread(map_img_path)
                if map_img is None:
                    raise IOError(f"Cannot read mapping image: {os.path.basename(map_img_path)}")
                mh, mw = map_img.shape[:2]
                
                if h != mh or w != mw:
                    raise ValueError(f"Size mismatch: '{target_file_info['name']}' ({w}x{h}) vs '{map_files[j]}' ({mw}x{mh})")
            
            valid_mapping_dirs.append(m_dir)
            print(f"Mapping folder {i+1} validated successfully")
            
        except (ValueError, IOError) as e:
            print(f"Mapping folder {i+1} validation failed: {e}")
            continue
    
    if mapping_dirs and not valid_mapping_dirs:
        # Mapping folders were supplied but none passed validation
        print("No valid mapping folders found!")
        return None
    
    # 3. Build the final file map list
    print(f"\nBuilding file mappings for {len(valid_mapping_dirs)} valid mapping folders...")
    for i in range(len(target_files)):
        map_entry = {
            'target_img_path': target_files[i]['img'], 
            'target_json_path': target_files[i]['json'], 
            'mapping_img_paths': []
        }
        
        for m_dir in valid_mapping_dirs:
            map_files = [f for f in sorted(os.listdir(m_dir)) 
                        if os.path.splitext(f)[1].lower() in image_extensions]
            map_entry['mapping_img_paths'].append(os.path.join(m_dir, map_files[i]))
        
        file_list_map.append(map_entry)
    
    print(f"File mapping complete: {len(file_list_map)} file sets with {len(valid_mapping_dirs)} mapping(s) each.")
    return file_list_map, valid_mapping_dirs


# Validate and load files
validation_result = validate_and_load_files(TARGET_DIR, MAPPING_DIRS)

if validation_result is None:
    print("Validation failed! Please check your directory paths and file structure.")
else:
    file_list_map, valid_mapping_dirs = validation_result
    print(f"\nValidation successful! Ready to process {len(file_list_map)} image sets.")

Validating target folder...
Target folder validated: 1 image-JSON pairs found

Validating mapping folders...
Validating mapping folder 1: c2...
Mapping folder 1 validated successfully

Building file mappings for 1 valid mapping folders...
File mapping complete: 1 file sets with 1 mapping(s) each.

Validation successful! Ready to process 1 image sets.

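`validate_and_load_files` pairs mapping images with target images by their position in `sorted()` order, and `sorted()` compares strings lexicographically. Pairing stays correct when both folders use identical names. It breaks when the two folders name files differently, for example zero-padded in one and not the other. Lexicographic order also sets the frame order of the TIFF stacks. The sketch below shows a numeric-aware sort key. `natural_key` is a hypothetical helper; the notebook itself uses plain `sorted()`.

```python
import re

def natural_key(filename):
    """Hypothetical sort key: compare digit runs as integers so 'sample2' < 'sample10'."""
    # re.split with a capture group alternates text and digit runs, so positions line up
    return [int(tok) if tok.isdigit() else tok.lower() for tok in re.split(r"(\d+)", filename)]

names = ["sample10.tif", "sample2.tif", "sample1.tif"]
print(sorted(names))                   # lexicographic, as in validate_and_load_files
print(sorted(names, key=natural_key))  # numeric-aware
```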

## Process Images

In [5]:
def process_single_image_set(file_set, pad_to_canvas=False, output_width=200, output_height=200, start_preference='top'):
    """Process a single image set (target + mappings) with controlled skeleton orientation."""
    # Load target image and JSON annotation
    target_img_cv = cv2.imread(file_set['target_img_path'])
    img_height, img_width = target_img_cv.shape[:2]
    
    with open(file_set['target_json_path'], 'r', encoding='utf-8') as f:
        labelme_data = json.load(f)
    
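    # Extract polygon points (assuming first shape is the polygon). LabelMe stores floats,
    # so round to the nearest pixel instead of truncating toward zero.
    polygon_points_xy = labelme_data['shapes'][0]['points']
    poly_np = np.round(np.array(polygon_points_xy, dtype=np.float64)).astype(np.int32)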
    
    # Create mask from polygon
    mask = np.zeros((img_height, img_width), dtype=np.uint8)
    cv2.fillPoly(mask, [poly_np], 255)
    
    # Extract skeleton and find longest path with controlled orientation
    skeleton_img_raw, longest_path_coords_rc, start_coord_rc = extract_skeleton_and_longest_path(mask, start_preference)
    
    # Create enhanced visualization of skeleton with path and start point
    skeleton_img_display = draw_skeleton_with_path_and_start(
        skeleton_img_raw, longest_path_coords_rc, start_coord_rc,
        upscale_factor=UPSCALE_FACTOR,
        skeleton_thickness=SKELETON_THICKNESS,
        path_thickness=PATH_THICKNESS
    )
    
    # Straighten target image (START point will be on the left side of output)
    straight_target_raw = straighten_polygon_region(target_img_cv, mask, longest_path_coords_rc)
    
    # Straighten mapping images
    straight_mappings_raw = []
    for map_img_path in file_set['mapping_img_paths']:
        map_img_cv = cv2.imread(map_img_path)
        straight_map = straighten_polygon_region(map_img_cv, mask, longest_path_coords_rc)
        straight_mappings_raw.append(straight_map)
    
    # Apply padding/cropping if requested
    if pad_to_canvas:
        final_straight_target = pad_and_center_image(straight_target_raw, output_width, output_height)
        final_straight_mappings = [pad_and_center_image(img, output_width, output_height) 
                                 for img in straight_mappings_raw]
    else:
        final_straight_target = straight_target_raw
        final_straight_mappings = straight_mappings_raw
    
    return {
        "target_img_original": target_img_cv,
        "polygon": poly_np,
        "skeleton_img_display": skeleton_img_display,
        "target_straight_img": final_straight_target,
        "mapping_straight_imgs": final_straight_mappings,
        "mapping_img_paths": file_set['mapping_img_paths'],
        "start_coord_rc": start_coord_rc
    }


# Process all images
if 'file_list_map' in locals():
    print("Processing all images with controlled skeleton orientation...")
    all_results = []
    
    for index in tqdm(range(len(file_list_map)), desc="Processing images"):
        file_set = file_list_map[index]
        base_name = os.path.basename(file_set['target_img_path'])
        
        try:
            results = process_single_image_set(file_set, PAD_TO_CANVAS, OUTPUT_WIDTH, OUTPUT_HEIGHT, SKELETON_START_PREFERENCE)
            all_results.append(results)
        except Exception as e:
            print(f"\nError processing {base_name}: {e}")
            raise e
    
    print(f"\nSuccessfully processed {len(all_results)} image sets with {SKELETON_START_PREFERENCE} start preference!")
    print(f"All skeleton START points are positioned at the {SKELETON_START_PREFERENCE}, and appear on the LEFT side of straightened outputs.")
else:
    print("No valid files to process. Please run the validation cell first.")

Processing all images with controlled skeleton orientation...


Processing images: 100%|██████████| 1/1 [00:01<00:00,  1.33s/it]


Successfully processed 1 image sets with top start preference!
All skeleton START points are positioned at the top, and appear on the LEFT side of straightened outputs.

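Here one image set took about 1.33 s. The cost of `straighten_polygon_region` grows with the number of mask pixels times the number of path vertices, because its first loop finds the nearest path vertex for every mask pixel in Python. The sketch below does that step with NumPy broadcasting and processes pixels in chunks to bound memory use. It uses the same nearest-vertex choice and the same signed projection onto the normal. The function name and `chunk` size are assumptions. The per-pixel loop that builds the remap coordinates is not covered.

```python
import numpy as np

def signed_normal_offsets(mask_points_rc, path_rc, normals_rc, chunk=4096):
    """For each mask pixel, project its offset from the nearest path vertex onto that vertex's normal."""
    mask_points_rc = np.asarray(mask_points_rc, dtype=np.float32)
    path_rc = np.asarray(path_rc, dtype=np.float32)
    normals_rc = np.asarray(normals_rc, dtype=np.float32)
    out = np.empty(len(mask_points_rc), dtype=np.float32)
    for start in range(0, len(mask_points_rc), chunk):
        pts = mask_points_rc[start:start + chunk]               # (k, 2)
        diff = pts[:, None, :] - path_rc[None, :, :]            # (k, m, 2)
        nearest = np.argmin(np.sum(diff ** 2, axis=2), axis=1)  # first minimum, like np.argmin in the loop
        offsets = pts - path_rc[nearest]                        # (k, 2)
        out[start:start + chunk] = np.sum(offsets * normals_rc[nearest], axis=1)
    return out

# Horizontal path along row 5 whose normals point toward smaller row indices.
path = np.array([[5, c] for c in range(6)], dtype=np.float32)
normals = np.tile(np.array([-1.0, 0.0], dtype=np.float32), (len(path), 1))
pts = np.array([[3, 2], [5, 4], [8, 1]], dtype=np.float32)
offsets = signed_normal_offsets(pts, path, normals)
d_min, d_max = offsets.min(), offsets.max()  # the extent used for the output height
print(offsets, d_min, d_max)
```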

## Navigation Controls

Navigate through results with previous/next buttons:

In [6]:
def create_navigation_controls():
    """Create navigation controls with previous/next buttons."""
    if not globals().get('all_results'):
        print("No results to display. Please run the processing cell first.")
        return
    
    # Current index tracker
    current_index = {'value': 0}
    
    # Create buttons
    prev_button = widgets.Button(
        description='Previous',
        disabled=False,
        button_style='info',
        tooltip='Go to previous image'
    )
    
    next_button = widgets.Button(
        description='Next',
        disabled=False,
        button_style='info',
        tooltip='Go to next image'
    )
    
    info_label = widgets.HTML(value=f"<b>Image 1 of {len(all_results)}</b>")
    
    # Create output for images
    nav_output = widgets.Output()
    
    def show_current_image():
        """Display current image and update controls."""
        idx = current_index['value']
        
        # Update info label
        info_label.value = f"<b>Image {idx + 1} of {len(all_results)}</b>"
        
        # Update button states
        prev_button.disabled = (idx == 0)
        next_button.disabled = (idx == len(all_results) - 1)
        
        # Display image
        with nav_output:
            clear_output(wait=True)
            
            result = all_results[idx]
            file_info = file_list_map[idx]
            base_name = os.path.basename(file_info['target_img_path'])
            
            print(f"Displaying results for: {base_name} (Image {idx + 1}/{len(all_results)})")
            print(f"Skeleton start preference: {SKELETON_START_PREFERENCE} -> START on LEFT side of straightened output")
            
            # Calculate the number of rows needed
            num_mappings = len(result['mapping_straight_imgs'])
            num_rows = 1 + num_mappings  # 1 for target + N for mappings
            
            # squeeze=False always returns a 2-D array of axes, even when num_rows == 1
            fig, axes = plt.subplots(num_rows, 3, figsize=(18, 6 * num_rows), squeeze=False)
            
            # Display target row
            row = 0
            
            # Original target with enhanced annotation
            target_with_poly = draw_polygon_with_fine_lines(
                result['target_img_original'], 
                result['polygon'], 
                color=(0, 255, 0), 
                thickness=CONTOUR_THICKNESS,
                upscale_factor=UPSCALE_FACTOR
            )
            axes[row, 0].imshow(cv2.cvtColor(target_with_poly, cv2.COLOR_BGR2RGB))
            axes[row, 0].set_title(f"Target: {os.path.basename(TARGET_DIR)}\nOriginal + Annotation", fontsize=12, pad=10)
            axes[row, 0].axis('off')
            
            # Enhanced skeleton display with START marker
            axes[row, 1].imshow(cv2.cvtColor(result['skeleton_img_display'], cv2.COLOR_BGR2RGB))
            axes[row, 1].set_title(f"Skeleton + Longest Path\nSTART ({SKELETON_START_PREFERENCE}) marked in green", fontsize=12, pad=10)
            axes[row, 1].axis('off')
            
            # Straightened target
            axes[row, 2].imshow(cv2.cvtColor(result['target_straight_img'], cv2.COLOR_BGR2RGB))
            axes[row, 2].set_title("Straightened Target\nSTART -> LEFT side", fontsize=12, pad=10)
            axes[row, 2].axis('off')
            
            # Display mapping rows
            for i, map_img_path in enumerate(result['mapping_img_paths']):
                row = i + 1
                map_dir_name = os.path.basename(valid_mapping_dirs[i])
                
                # Original mapping with enhanced annotation
                map_img = cv2.imread(map_img_path)
                if map_img is None:
                    raise FileNotFoundError(f"Could not read mapping image: {map_img_path}")
                map_with_poly = draw_polygon_with_fine_lines(
                    map_img, 
                    result['polygon'], 
                    color=(0, 255, 0), 
                    thickness=CONTOUR_THICKNESS,
                    upscale_factor=UPSCALE_FACTOR
                )
                axes[row, 0].imshow(cv2.cvtColor(map_with_poly, cv2.COLOR_BGR2RGB))
                axes[row, 0].set_title(f"Map {i+1}: {map_dir_name}\nOriginal + Annotation", fontsize=12, pad=10)
                axes[row, 0].axis('off')
                
                # Same enhanced skeleton for all images
                axes[row, 1].imshow(cv2.cvtColor(result['skeleton_img_display'], cv2.COLOR_BGR2RGB))
                axes[row, 1].set_title("Skeleton + Longest Path\nSame START point for all", fontsize=12, pad=10)
                axes[row, 1].axis('off')
                
                # Straightened mapping
                axes[row, 2].imshow(cv2.cvtColor(result['mapping_straight_imgs'][i], cv2.COLOR_BGR2RGB))
                axes[row, 2].set_title(f"Straightened Map {i+1}\nSTART -> LEFT side", fontsize=12, pad=10)
                axes[row, 2].axis('off')
            
            plt.tight_layout(pad=2.0)
            plt.show()
    
    def on_prev_clicked(b):
        if current_index['value'] > 0:
            current_index['value'] -= 1
            show_current_image()
    
    def on_next_clicked(b):
        if current_index['value'] < len(all_results) - 1:
            current_index['value'] += 1
            show_current_image()
    
    # Connect button events
    prev_button.on_click(on_prev_clicked)
    next_button.on_click(on_next_clicked)
    
    # Layout
    controls = widgets.HBox([prev_button, info_label, next_button], 
                           layout=widgets.Layout(justify_content='center'))
    
    display(widgets.VBox([
        widgets.HTML("<h3>Quick Navigation</h3>"),
        widgets.HTML("<p>Use the buttons to navigate through images:</p>"),
        controls,
        nav_output
    ]))
    
    # Show first image
    show_current_image()


# Create navigation controls
if 'all_results' in locals() and all_results:
    create_navigation_controls()
else:
    print("No results available for navigation.")

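The navigation cell keeps a single mutable index and derives the label and both button states from it. The same logic, stripped of widgets so it can be checked in isolation, is sketched below. `NavigationState` is a hypothetical class used only for illustration; it is not part of the notebook.

```python
class NavigationState:
    """Widget-free model of the Previous/Next navigation logic."""

    def __init__(self, n_items):
        if n_items < 1:
            raise ValueError("Need at least one result to navigate.")
        self.n_items = n_items
        self.index = 0

    @property
    def prev_disabled(self):
        # Mirrors prev_button.disabled = (idx == 0)
        return self.index == 0

    @property
    def next_disabled(self):
        # Mirrors next_button.disabled = (idx == len(all_results) - 1)
        return self.index == self.n_items - 1

    def label(self):
        # Same text as info_label, without the HTML <b> tags
        return f"Image {self.index + 1} of {self.n_items}"

    def prev(self):
        # Ignore clicks that would move before the first image
        if self.index > 0:
            self.index -= 1

    def next(self):
        # Ignore clicks that would move past the last image
        if self.index < self.n_items - 1:
            self.index += 1


nav = NavigationState(3)
nav.next(); nav.next(); nav.next()     # the third click is ignored
print(nav.label(), nav.next_disabled)  # Image 3 of 3 True
```

With a single result, both flags are `True` from the start, so both buttons are disabled, as happens in the run above, which processed one image set.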

## Save Results

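Save the straightened target and mapping images either as individual TIFF files (one per image and mapping folder) or as multi-frame TIFF stacks (one per channel). Both options use LZW compression and write to `OUTPUT_DIR`.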
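Both save options derive output names from the target image's base name and the mapping folder names, following the patterns listed under Outputs. The sketch below reproduces that naming scheme; `output_names` is a hypothetical helper that mirrors the f-strings used in `save_as_individual_files` and `save_as_stacks`. One deliberate difference: it normalizes each folder path first, so a trailing slash does not yield an empty folder name (the notebook calls `os.path.basename` directly).

```python
import os


def output_names(target_img_path, mapping_dirs):
    """Return (individual file names for one sample, stack names for the dataset).

    Patterns: <name>_target_straight.tif, <name>_map_<MappingDir>_straight.tif,
    Target_Stack.tif and <MappingDir>_Stack.tif
    """
    name_only, _ = os.path.splitext(os.path.basename(target_img_path))
    # normpath strips a trailing separator, so "Mapping2/" still yields "Mapping2"
    map_names = [os.path.basename(os.path.normpath(d)) for d in mapping_dirs]

    individual = [f"{name_only}_target_straight.tif"]
    individual += [f"{name_only}_map_{m}_straight.tif" for m in map_names]

    stacks = ["Target_Stack.tif"] + [f"{m}_Stack.tif" for m in map_names]
    return individual, stacks


files, stacks = output_names("TARGET_DIR/sample1.tif",
                             ["MAPPING_DIRS/Mapping1", "MAPPING_DIRS/Mapping2/"])
print(files)
# ['sample1_target_straight.tif', 'sample1_map_Mapping1_straight.tif', 'sample1_map_Mapping2_straight.tif']
print(stacks)
# ['Target_Stack.tif', 'Mapping1_Stack.tif', 'Mapping2_Stack.tif']
```

Because every sample writes one target file plus one file per mapping folder into the same directory, the individual-file option produces `len(all_results) * (1 + number of mapping folders)` files, which is the `total_files_to_save` count computed in the save cell.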

def save_as_individual_files(all_results, file_list_map, valid_mapping_dirs, output_dir):
    """Save each straightened image as an individual TIFF file."""
    print("Saving individual TIFF files...")
    os.makedirs(output_dir, exist_ok=True)  # create the output folder if it does not exist yet
    
    num_maps = len(valid_mapping_dirs)
    total_files_to_save = len(all_results) * (1 + num_maps)
    saved_count = 0
    
    try:
        for index, result_set in enumerate(tqdm(all_results, desc="Saving files")):
            # Get source filename
            source_path = file_list_map[index]['target_img_path']
            base_name = os.path.basename(source_path)
            name_only, _ = os.path.splitext(base_name)
            
            # Save straightened target image
            target_img_np = result_set['target_straight_img']
            target_pil = Image.fromarray(cv2.cvtColor(target_img_np, cv2.COLOR_BGR2RGB))
            target_out_path = os.path.join(output_dir, f"{name_only}_target_straight.tif")
            target_pil.save(target_out_path, compression='tiff_lzw')
            saved_count += 1
            
            # Save all straightened mapping images
            for i, map_img_np in enumerate(result_set['mapping_straight_imgs']):
                map_pil = Image.fromarray(cv2.cvtColor(map_img_np, cv2.COLOR_BGR2RGB))
                map_dir_name = os.path.basename(valid_mapping_dirs[i])
                map_out_path = os.path.join(output_dir, f"{name_only}_map_{map_dir_name}_straight.tif")
                map_pil.save(map_out_path, compression='tiff_lzw')
                saved_count += 1
        
        print(f"Successfully saved {saved_count} of {total_files_to_save} individual TIFF files to {output_dir}")
        print(f"All straightened images have START point ({SKELETON_START_PREFERENCE}) on the LEFT side")
        
    except Exception as e:
        print(f"Error during saving: {e}")
        raise


def save_as_stacks(all_results, valid_mapping_dirs, output_dir):
    """Save results as multi-frame TIFF stacks."""
    print("Saving multi-frame TIFF stacks...")
    os.makedirs(output_dir, exist_ok=True)  # create the output folder if it does not exist yet
    
    num_mapping_channels = len(valid_mapping_dirs)
    
    # Collect straightened target and mapping frames as PIL images (RGB order)
    target_images_pil = []
    mapping_images_pil = [[] for _ in range(num_mapping_channels)]
    
    print("Collecting images for stacks...")
    for result_set in tqdm(all_results, desc="Collecting images"):
        # Add target image
        target_img_np = result_set['target_straight_img']
        target_images_pil.append(Image.fromarray(cv2.cvtColor(target_img_np, cv2.COLOR_BGR2RGB)))
        
        # Add mapping images
        for i, map_img_np in enumerate(result_set['mapping_straight_imgs']):
            mapping_images_pil[i].append(Image.fromarray(cv2.cvtColor(map_img_np, cv2.COLOR_BGR2RGB)))
    
    try:
        stacks_saved = 0
        
        # Save target stack
        if target_images_pil:
            out_path_target = os.path.join(output_dir, "Target_Stack.tif")
            target_images_pil[0].save(out_path_target, save_all=True, append_images=target_images_pil[1:], 
                                    format='TIFF', compression='tiff_lzw')
            print(f"Saved target stack: Target_Stack.tif ({len(target_images_pil)} frames)")
            stacks_saved += 1
        
        # Save mapping stacks
        for i, map_image_list in enumerate(mapping_images_pil):
            if map_image_list:
                map_dir_name = os.path.basename(valid_mapping_dirs[i])
                out_path_map = os.path.join(output_dir, f"{map_dir_name}_Stack.tif")
                map_image_list[0].save(out_path_map, save_all=True, append_images=map_image_list[1:], 
                                     format='TIFF', compression='tiff_lzw')
                print(f"Saved mapping stack {i+1}: {map_dir_name}_Stack.tif ({len(map_image_list)} frames)")
                stacks_saved += 1
        
        print(f"\nSuccessfully saved {stacks_saved} multi-frame TIFF stacks to {output_dir}")
        print(f"All straightened images have START point ({SKELETON_START_PREFERENCE}) on the LEFT side")
        
    except Exception as e:
        print(f"Error during saving stacks: {e}")
        raise


# Create interactive save controls
def create_save_controls():
    """Create interactive save controls."""
    # Inside a function, locals() never sees notebook variables, so look in globals()
    if not globals().get('all_results'):
        print("No results to save. Please run the processing cell first.")
        return
    
    # Create buttons
    save_individual_btn = widgets.Button(
        description='Save Individual Files',
        button_style='success',
        tooltip='Save each image as a separate TIFF file'
    )
    
    save_stacks_btn = widgets.Button(
        description='Save as Stacks',
        button_style='warning',
        tooltip='Save as multi-frame TIFF stacks'
    )
    
    # Output for save status
    save_output = widgets.Output()
    
    def on_save_individual(b):
        with save_output:
            clear_output(wait=True)
            save_as_individual_files(all_results, file_list_map, valid_mapping_dirs, OUTPUT_DIR)
    
    def on_save_stacks(b):
        with save_output:
            clear_output(wait=True)
            save_as_stacks(all_results, valid_mapping_dirs, OUTPUT_DIR)
    
    # Connect events
    save_individual_btn.on_click(on_save_individual)
    save_stacks_btn.on_click(on_save_stacks)
    
    # Layout
    save_controls = widgets.HBox([save_individual_btn, save_stacks_btn], 
                                layout=widgets.Layout(justify_content='center'))
    
    display(widgets.VBox([
        widgets.HTML("<h3>Save Results</h3>"),
        widgets.HTML(f"<p><b>Output Directory:</b> {OUTPUT_DIR}</p>"),
        widgets.HTML(f"<p><b>Ready to save:</b> {len(all_results)} processed image sets</p>"),
        widgets.HTML(f"<p><b>Skeleton orientation:</b> START point ({SKELETON_START_PREFERENCE}) -> LEFT side of outputs</p>"),
        save_controls,
        save_output
    ]))


# Create save controls
if 'all_results' in locals() and all_results:
    create_save_controls()
else:
    print("No results available for saving.")

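Two details matter when building the stacks. First, OpenCV arrays are in BGR channel order while PIL reads three-channel arrays as RGB, which is why every frame passes through `cv2.cvtColor(..., cv2.COLOR_BGR2RGB)` before `Image.fromarray`. Second, if the straightened images differ in size (for example, because skeleton lengths differ between samples), PIL still writes each page at its own size, but tools that load a multi-page TIFF as a single stack, such as ImageJ/Fiji, generally expect all frames to share the same dimensions. The NumPy-only sketch below handles both steps. `prepare_stack_frames` and the zero-padding strategy are assumptions for illustration; the notebook does not currently pad frames.

```python
import numpy as np


def prepare_stack_frames(frames_bgr, fill_value=0):
    """Convert BGR frames to RGB and pad them to a common height and width.

    Hypothetical helper: frames are H x W x 3 arrays as returned by cv2.imread.
    Padding is added on the bottom and right only, so pixel (0, 0), and with it
    the START (left) side of each straightened image, stays in place.
    """
    if not frames_bgr:
        return []
    max_h = max(f.shape[0] for f in frames_bgr)
    max_w = max(f.shape[1] for f in frames_bgr)

    prepared = []
    for frame in frames_bgr:
        rgb = frame[..., ::-1]  # reverse channel order: BGR -> RGB
        canvas = np.full((max_h, max_w, 3), fill_value, dtype=frame.dtype)
        canvas[: rgb.shape[0], : rgb.shape[1]] = rgb
        prepared.append(canvas)
    return prepared


a = np.zeros((2, 3, 3), dtype=np.uint8)
a[..., 0] = 255                      # pure blue in BGR order
b = np.zeros((4, 1, 3), dtype=np.uint8)
out = prepare_stack_frames([a, b])
print([f.shape for f in out])        # [(4, 3, 3), (4, 3, 3)]
print(out[0][0, 0].tolist())         # [0, 0, 255]  (blue is now the last channel)
```

Each prepared array could then go through `Image.fromarray(...)` and the same `save(..., save_all=True, append_images=..., compression='tiff_lzw')` call used in `save_as_stacks`; because the arrays are already RGB, the `cvtColor` step would be skipped.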