In [None]:
import os
import logging
import numpy as np
import pandas as pd
from tifffile import imread
from skimage.measure import regionprops
from collections import defaultdict
from scipy.spatial.distance import cdist
from typing import List, Dict, Any, Tuple, Optional


In [2]:
# Directory holding one 3D label-mask TIFF per timepoint.
# Set the SPOT_MASK_DIR environment variable to analyse other data without
# editing the notebook; otherwise the original location is used.
data_directory = os.environ.get(
    "SPOT_MASK_DIR",
    "/mnt/external.data/MeisterLab/mvolosko/image_project/SDC1/1268_fast_imaging_01/spots/spot_segmentation/",
)


In [None]:
# --- Configuration & Setup ---

# Quiet tifffile's own logger so that only its errors are shown.
logging.getLogger('tifffile').setLevel(logging.ERROR)
# Root logger: INFO and above, rendered as "LEVEL: message".
logging.basicConfig(format='%(levelname)s: %(message)s', level=logging.INFO)


# --- Core Functions ---

def get_mask_filepaths(directory: str) -> List[str]:
    """
    Gets all TIFF file paths from a directory, ordered so that index == timepoint.

    Filenames are sorted "naturally": runs of digits are compared as integers, so
    unpadded frame numbers come out chronologically ("t2.tif" before "t10.tif").
    Plain lexicographic sorting would put "t10" before "t2" and scramble the time
    axis. Zero-padded names keep the same order as a plain sort. The extension
    check is case-insensitive (".TIF"/".TIFF" are accepted), and subdirectories
    are ignored even if their names end in ".tif".

    Args:
        directory (str): Path to the directory containing TIFF files.

    Returns:
        List[str]: Full file paths to the TIFF masks, one per timepoint.

    Raises:
        ValueError: If the directory contains no TIFF files.
    """
    import re

    def natural_key(name: str):
        # re.split with a capturing group alternates text / digit-run chunks, so
        # odd positions are always digit runs; comparing those as ints gives
        # numeric ordering. The raw name breaks any remaining ties deterministically.
        chunks = re.split(r'(\d+)', name)
        return ([int(c) if i % 2 else c for i, c in enumerate(chunks)], name)

    files = sorted(
        (
            f for f in os.listdir(directory)
            if f.lower().endswith(('.tif', '.tiff'))
            and os.path.isfile(os.path.join(directory, f))
        ),
        key=natural_key,
    )
    if not files:
        raise ValueError(f"No TIFF files found in directory: {directory}")
    return [os.path.join(directory, f) for f in files]


def load_mask(file_path: str) -> np.ndarray:
    """
    Read one TIFF label mask and return it as a 3D int32 array.

    Singleton axes are removed with ``np.squeeze`` before the rank check, so a
    file stored with extra length-1 axes (e.g. (1, z, y, x)) is still accepted.

    Args:
        file_path (str): The path to the TIFF file.

    Returns:
        np.ndarray: A 3D int32 mask array (z, y, x).

    Raises:
        ValueError: If the squeezed array is not 3D.
        Exception: Any read/conversion error is logged and re-raised unchanged.
    """
    try:
        volume = np.squeeze(imread(file_path)).astype(np.int32)
        if volume.ndim == 3:
            return volume
        raise ValueError(f"Mask must be 3D, but file {file_path} has shape {volume.shape} after squeeze.")
    except Exception as err:
        logging.error(f"Failed to load or process mask {file_path}: {err}")
        raise


def extract_properties_from_mask(
    mask: np.ndarray, 
    time_index: int, 
    min_volume_threshold: int = 0
) -> List[Dict[str, Any]]:
    """
    Summarise every labelled object in one mask, dropping objects that are too small.

    Args:
        mask (np.ndarray): A single 3D label mask.
        time_index (int): Frame number stored in each record's 'time' field.
        min_volume_threshold (int): Objects whose pixel count ('area') is below
            this value are treated as noise and omitted. The default of 0 keeps all.

    Returns:
        List[Dict[str, Any]]: One record per kept object with keys
        'time', 'label', 'centroid', 'area' and 'bbox', in regionprops order.
    """
    return [
        {
            'time': time_index,
            'label': region.label,
            'centroid': region.centroid,
            'area': region.area,
            'bbox': region.bbox,
        }
        for region in regionprops(mask)
        # Noise filter: keep only objects at or above the volume threshold.
        if region.area >= min_volume_threshold
    ]


def link_objects_between_frames(
    props1: List[Dict[str, Any]],
    props2: List[Dict[str, Any]],
    mask1: np.ndarray,
    mask2: np.ndarray,
    overlap_threshold: float = 0.25
) -> List[Tuple[Tuple[int, int], Tuple[int, int]]]:
    """
    Link objects of frame t to objects of frame t+1 by voxel overlap.

    The overlap score is Intersection over Parent (IoP): the number of shared
    voxels divided by the frame-t object's 'area'. Normalising by the parent
    alone (rather than by the union) lets one large parent link to several
    smaller children, which is what a split looks like.

    Args:
        props1: Object records for frame t (need 'time', 'label', 'area', 'bbox').
        props2: Object records for frame t+1 (need 'time', 'label', 'bbox').
        mask1: Label mask for frame t.
        mask2: Label mask for frame t+1.
        overlap_threshold: Minimum IoP required to create a link.

    Returns:
        List of ((t, parent_label), (t+1, child_label)) pairs, ordered by
        parent then child as they appear in props1 / props2. Empty if either
        props list is empty.
    """
    if not props1 or not props2:
        return []

    time_parent = props1[0]['time']
    time_child = props2[0]['time']
    links = []

    for parent in props1:
        parent_box = parent['bbox']
        for child in props2:
            child_box = child['bbox']
            # bbox = (min_z, min_y, min_x, max_z, max_y, max_x) with exclusive maxima:
            # the boxes' overlap starts at the larger minimum and stops at the smaller maximum.
            starts = [max(parent_box[axis], child_box[axis]) for axis in range(3)]
            stops = [min(parent_box[axis + 3], child_box[axis + 3]) for axis in range(3)]
            if any(stop <= start for start, stop in zip(starts, stops)):
                continue  # disjoint boxes, so the objects cannot share voxels

            # Count shared voxels only inside the (small) box overlap.
            window = tuple(slice(start, stop) for start, stop in zip(starts, stops))
            shared = np.count_nonzero(
                (mask1[window] == parent['label']) & (mask2[window] == child['label'])
            )
            if shared == 0:
                continue

            iop = shared / parent['area']
            if iop >= overlap_threshold:
                links.append(((time_parent, parent['label']), (time_child, child['label'])))
    return links


def build_link_graph(all_props: List[Dict[int, Dict[str, Any]]], mask_files: List[str], overlap_threshold: float) -> defaultdict:
    """
    Build the directed (time, label) -> (time+1, label) overlap graph for the whole movie.

    Masks are read one at a time and only the previous and current frame are
    held, so memory stays bounded by two masks regardless of movie length.

    Args:
        all_props: Per-frame dicts mapping label -> object record.
        mask_files: Mask file paths, index-aligned with all_props.
        overlap_threshold: Minimum IoP passed to link_objects_between_frames.

    Returns:
        defaultdict mapping (time, label) -> {'parents': [...], 'children': [...]};
        unseen keys yield empty parent/child lists.
    """
    link_graph = defaultdict(lambda: {'parents': [], 'children': []})

    previous_mask = load_mask(mask_files[0])
    for t_next in range(1, len(all_props)):
        t_prev = t_next - 1
        logging.info(f"Processing link between frame {t_prev} and {t_next}...")

        current_mask = load_mask(mask_files[t_next])
        frame_links = link_objects_between_frames(
            list(all_props[t_prev].values()),
            list(all_props[t_next].values()),
            previous_mask,
            current_mask,
            overlap_threshold,
        )

        # Record every link in both directions.
        for parent_key, child_key in frame_links:
            link_graph[parent_key]['children'].append(child_key)
            link_graph[child_key]['parents'].append(parent_key)

        # Slide the two-frame window forward.
        previous_mask = current_mask

    return link_graph


def track_objects_robust(
    all_props: List[Dict[int, Dict[str, Any]]],
    link_graph: defaultdict,
    max_gap_frames: int = 2,
    max_distance: float = 50.0
) -> Tuple[Dict, defaultdict, Dict]:
    """
    Tracks objects across time, assigning unique track IDs, identifying events,
    and handling gaps in tracking.

    Each object gets exactly one track ID, and no track ID is used by two objects
    in the same frame:

    * A parentless object is first offered to gap closing: the nearest recently
      lost track within ``max_gap_frames`` / ``max_distance``. If none matches,
      it starts a new track ('appearance').
    * An object with one parent continues the parent's track. After a split only
      the largest child (first one on ties) continues it; every other child
      starts a new track ('split_child') whose lineage is the parent track.
    * An object with several parents ('merge') continues the track of the
      largest parent. If that track is already taken in this frame, a new
      track is started with all parent tracks as its lineage.

    Args:
        all_props: Per-frame dicts mapping label -> object record (needs 'time',
            'label', 'centroid', 'area'); list index must equal the frame number.
        link_graph: (time, label) -> {'parents': [...], 'children': [...]}.
        max_gap_frames: Largest (current frame - frame the lost object was last
            seen) that gap closing accepts. NOTE: a lost object is registered only
            after the following frame is processed, so the smallest difference
            ever tested is 2 (one missing frame); max_gap_frames=1 never bridges.
        max_distance: Centroid distance must be strictly below this to bridge.

    Returns:
        track_map: (time, label) -> track_id.
        event_map: (time, label) -> {'type': str, 'parents': [track_id, ...]}.
            Types used: 'appearance', 'split', 'split_child', 'merge', 'bridged',
            'bridged_disappearance', 'terminal' (an unbridged disappearance), or
            '' for a plain continuation.
        lineage: track_id -> track IDs the track was born from ([] when it starts
            from nothing). A track never lists itself.
    """
    track_counter = 1
    track_map = {}  # (time, label) -> track_id
    event_map = defaultdict(lambda: {'type': '', 'parents': []})
    lineage = {}    # track_id -> [track_id(s) the track was born from]
    # Tracks whose last object had no child: track_id -> {'key', 'props', 'frame_lost'}
    lost_tracks = {}

    def start_track(key, event_type, parent_track_ids):
        # Allocate the next track ID for `key` and record its birth event and lineage.
        nonlocal track_counter
        new_id = track_counter
        track_counter += 1
        track_map[key] = new_id
        event_map[key] = {'type': event_type, 'parents': list(parent_track_ids)}
        lineage[new_id] = list(parent_track_ids)
        return new_id

    def inherits_parent_track(parent_key, child_key):
        # After a split, only the largest child (first on ties) keeps the parent's track.
        children = link_graph[parent_key]['children']
        if len(children) <= 1:
            return True
        areas = [all_props[c[0]][c[1]]['area'] for c in children]
        return children[int(np.argmax(areas))] == child_key

    for t, timepoint_props in enumerate(all_props):
        claimed = set()  # track IDs already given to an object in frame t

        # --- Gap closing: offer parentless objects to recently lost tracks ---
        for new_prop in timepoint_props.values():
            new_key = (new_prop['time'], new_prop['label'])
            if link_graph[new_key]['parents']:
                continue

            best_match = None
            min_dist = float('inf')
            for track_id, lost_info in lost_tracks.items():
                frame_gap = t - lost_info['frame_lost']
                if 1 <= frame_gap <= max_gap_frames:
                    dist = np.linalg.norm(np.array(new_prop['centroid']) - np.array(lost_info['props']['centroid']))
                    if dist < max_distance and dist < min_dist:
                        min_dist = dist
                        best_match = track_id

            if best_match is not None:
                # Pop so the same lost track cannot be bridged twice.
                lost_info = lost_tracks.pop(best_match)
                parent_key = lost_info['key']
                parent_track_id = track_map[parent_key]

                track_map[new_key] = parent_track_id
                claimed.add(parent_track_id)
                event_map[new_key] = {'type': 'bridged', 'parents': [parent_track_id]}
                # Only relabel a plain disappearance; keep e.g. 'appearance' or 'merge'.
                if event_map[parent_key]['type'] == 'disappearance':
                    event_map[parent_key]['type'] = 'bridged_disappearance'

        # --- Assign a track to every remaining object in frame t ---
        for obj in timepoint_props.values():
            key = (t, obj['label'])
            if key in track_map:
                continue  # already bridged above

            # Keep parent keys and their track IDs index-aligned.
            parents = [p for p in link_graph[key]['parents'] if p in track_map]
            parent_track_ids = [track_map[p] for p in parents]

            if not parent_track_ids:  # Appearance
                track_id = start_track(key, 'appearance', [])
            elif len(parent_track_ids) == 1:  # Continuation, or a child split off the parent
                parent_track_id = parent_track_ids[0]
                if parent_track_id not in claimed and inherits_parent_track(parents[0], key):
                    track_id = parent_track_id
                    track_map[key] = track_id
                else:
                    track_id = start_track(key, 'split_child', [parent_track_id])
            else:  # Merge: continue the track of the largest parent
                parent_areas = [all_props[p[0]][p[1]]['area'] for p in parents]
                main_track_id = parent_track_ids[int(np.argmax(parent_areas))]
                if main_track_id not in claimed:
                    track_id = main_track_id
                    track_map[key] = track_id
                    event_map[key] = {'type': 'merge', 'parents': parent_track_ids}
                else:
                    track_id = start_track(key, 'merge', parent_track_ids)
            claimed.add(track_id)

        # --- Classify frame t-1 objects (splits / disappearances) now that t is assigned ---
        if t > 0:
            for prev_obj in all_props[t - 1].values():
                prev_key = (t - 1, prev_obj['label'])
                if prev_key not in track_map:
                    continue

                children = link_graph[prev_key]['children']
                if not children:
                    # Every childless object becomes a gap-closing candidate, whatever
                    # event it already carries; only event-less objects are relabelled.
                    if event_map[prev_key]['type'] == '':
                        event_map[prev_key]['type'] = 'disappearance'
                    lost_tracks[track_map[prev_key]] = {'key': prev_key, 'props': prev_obj, 'frame_lost': t - 1}
                elif len(children) > 1:
                    event_map[prev_key]['type'] = 'split'

    # Lost tracks that were never bridged end for good.
    for lost_info in lost_tracks.values():
        if event_map[lost_info['key']]['type'] == 'disappearance':
            event_map[lost_info['key']]['type'] = 'terminal'

    return track_map, event_map, lineage


def generate_outputs(
    all_props: List[Dict[int, Dict[str, Any]]], 
    track_map: Dict, 
    event_map: defaultdict, 
    lineage: Dict,
    output_csv: str,
    tracks_output_csv: str
) -> Dict[str, Any]:
    """
    Generates a detailed CSV report and a napari-compatible tracks file.

    Objects without a track ID are skipped. If no object is tracked at all, both
    CSV files are still written (header row only) and the tracks array has shape
    (0, 5), instead of failing inside pandas.

    Args:
        all_props: Per-frame dicts mapping label -> object record ('label',
            'centroid' as (z, y, x), 'area'); list index is the frame number.
        track_map: (time, label) -> track_id.
        event_map: (time, label) -> {'type': str, 'parents': [track_id, ...]}.
        lineage: track_id -> list of parent track IDs.
        output_csv: Path of the detailed per-object report.
        tracks_output_csv: Path of the [track_id, time, z, y, x] tracks table.

    Returns:
        dict with 'tracks_array' (N x 5 float array: track_id, time, z, y, x),
        'graph' (track_id -> parent track_id for lineage entries with exactly one
        parent), 'csv_path' and 'tracks_csv_path'.
    """
    report_columns = ['time', 'z', 'y', 'x', 'label', 'track_id', 'area', 'event_type', 'parent_track_id']
    track_columns = ['track_id', 'time', 'z', 'y', 'x']
    csv_data, tracks_data = [], []

    for t, timepoint_props in enumerate(all_props):
        for obj in timepoint_props.values():
            key = (t, obj['label'])
            track_id = track_map.get(key)
            if track_id is None:
                continue

            # .get() on purpose: do not let the defaultdict invent entries here.
            event_info = event_map.get(key, {'type': '', 'parents': []})
            parent_ids_str = ','.join(map(str, sorted(set(event_info['parents']))))
            centroid = obj['centroid']

            csv_data.append({
                'time': t,
                'z': centroid[0], 'y': centroid[1], 'x': centroid[2],
                'label': obj['label'], 'track_id': track_id, 'area': obj['area'],
                'event_type': event_info['type'], 'parent_track_id': parent_ids_str
            })
            tracks_data.append([track_id, t, centroid[0], centroid[1], centroid[2]])

    # Explicit columns keep the header (and make sort_values valid) even when empty.
    df = (
        pd.DataFrame(csv_data, columns=report_columns)
        .sort_values(by=['track_id', 'time'])
        .reset_index(drop=True)
    )
    df.to_csv(output_csv, index=False)

    # reshape keeps the (N, 5) layout when N == 0 (np.array([]) alone is 1-D).
    tracks_array = np.array(tracks_data, dtype=float).reshape(-1, len(track_columns))
    napari_df = pd.DataFrame(tracks_array, columns=track_columns)
    napari_df.to_csv(tracks_output_csv, index=False)

    napari_graph = {child_id: parents[0] for child_id, parents in lineage.items() if parents and len(parents) == 1}

    return {
        'tracks_array': tracks_array,
        'graph': napari_graph,
        'csv_path': output_csv,
        'tracks_csv_path': tracks_output_csv
    }


# --- Main Pipeline---

def segmentation_pipeline_robust(
    input_dir: str, 
    output_csv: str = 'analysis_results.csv',
    tracks_output_csv: str = 'napari_tracks.csv',
    overlap_threshold: float = 0.1,
    min_volume_fraction: float = 0.2,
    max_gap_frames: int = 2,
    max_distance_for_gap: float = 50.0,
    volume_sample_frames: int = 20
) -> Dict[str, Any]:
    """
    End-to-end robust segmentation analysis pipeline.

    This pipeline includes noise filtering, gap closing, and uses a more robust
    linking metric (IoP) to handle splits and merges effectively. Each mask is read
    once for property extraction (plus once more while building the link graph).

    Args:
        input_dir: Directory containing 3D mask TIFF files.
        output_csv: Path to save the detailed CSV report.
        tracks_output_csv: Path to save the Napari-compatible tracks CSV.
        overlap_threshold: Minimum IoP (Intersection over Parent) to link objects.
        min_volume_fraction: Fraction of the median object volume to use as a noise threshold.
        max_gap_frames: Max frames to bridge a track across a gap.
        max_distance_for_gap: Max centroid distance to bridge a track.
        volume_sample_frames: Number of leading frames whose object volumes are
            used to estimate the median volume (expected to be positive).

    Returns:
        A dictionary of results including the tracks array, graph, and output file paths.

    Raises:
        Exception: Any failure is logged as "Pipeline failed: ..." and re-raised.
    """
    try:
        mask_files = get_mask_filepaths(input_dir)
        num_timepoints = len(mask_files)
        logging.info(f"Found {num_timepoints} timepoints to process.")

        # --- Step 1: Read every mask once and extract unfiltered properties ---
        # The noise filter is just `area >= threshold`, so filtering these stored
        # records afterwards is equivalent to filtering during extraction, and the
        # sampled masks no longer have to be read and regionprops'd twice.
        logging.info("Pre-processing to determine volume threshold...")
        raw_props = [
            extract_properties_from_mask(load_mask(file_path), t)
            for t, file_path in enumerate(mask_files)
        ]

        # --- Step 2: Noise threshold from the median volume of the leading frames ---
        all_volumes = [p['area'] for props_t in raw_props[:volume_sample_frames] for p in props_t]
        if not all_volumes:
            logging.warning("No objects found in sample files. Disabling volume filter.")
            min_volume_threshold = 0
        else:
            median_volume = np.median(all_volumes)
            min_volume_threshold = int(median_volume * min_volume_fraction)
            logging.info(f"Median object volume is {median_volume:.2f}. Using min volume threshold: {min_volume_threshold}")

        all_props = [
            {p['label']: p for p in props_t if p['area'] >= min_volume_threshold}
            for props_t in raw_props
        ]
        del raw_props  # the filtered dicts are all that later steps need
        logging.info("Extracted properties for all timepoints.")

        # --- Step 3: Build the linkage graph between objects ---
        link_graph = build_link_graph(all_props, mask_files, overlap_threshold)
        logging.info("Built object linkage graph.")

        # --- Step 4: Perform robust tracking with gap closing ---
        track_map, event_map, lineage = track_objects_robust(
            all_props, link_graph, max_gap_frames, max_distance_for_gap
        )
        logging.info("Completed robust object tracking with gap closing.")

        # --- Step 5: Generate final outputs ---
        results = generate_outputs(
            all_props, track_map, event_map, lineage, output_csv, tracks_output_csv
        )
        logging.info(f"Generated detailed report: {output_csv}")
        logging.info(f"Generated Napari-compatible tracks file: {tracks_output_csv}")

        return results

    except Exception as e:
        logging.error(f"Pipeline failed: {str(e)}")
        raise


INFO: Found 18 timepoints to process.
INFO: Pre-processing to determine volume threshold...
INFO: Median object volume is 1377.00. Using min volume threshold: 413
INFO: Extracted properties for all timepoints.
INFO: Processing link between frame 0 and 1...
INFO: Processing link between frame 1 and 2...
INFO: Processing link between frame 2 and 3...
INFO: Processing link between frame 3 and 4...
INFO: Processing link between frame 4 and 5...
INFO: Processing link between frame 5 and 6...
INFO: Processing link between frame 6 and 7...
INFO: Processing link between frame 7 and 8...
INFO: Processing link between frame 8 and 9...
INFO: Processing link between frame 9 and 10...
INFO: Processing link between frame 10 and 11...
INFO: Processing link between frame 11 and 12...
INFO: Processing link between frame 12 and 13...
INFO: Processing link between frame 13 and 14...
INFO: Processing link between frame 14 and 15...
INFO: Processing link between frame 15 and 16...
INFO: Processing link bet


--- Pipeline executed successfully! ---
Detailed results saved to: analysis_results_robust.csv
Napari tracks saved to: napari_tracks_robust.csv


In [None]:
# --- Example usage (written for scripting rather than a Jupyter notebook) ---
if __name__ == "__main__":

    if not os.path.isdir(data_directory):
        # Nothing to analyse: tell the user how to point the example at real data.
        print(f"\n--- SKIPPING EXECUTION ---")
        print(f"The example data directory does not exist: '{data_directory}'")
        print("Please update the 'data_directory' variable to point to your TIFF files.")
    else:
        try:
            results = segmentation_pipeline_robust(
                input_dir=data_directory,
                output_csv="analysis_results.csv",
                tracks_output_csv="napari_tracks.csv",
                overlap_threshold=0.25,      # link when overlap >= 25% of the parent's volume
                min_volume_fraction=0.3,     # drop objects below 30% of the median volume
                max_gap_frames=2,            # gap-closing window in frames (see track_objects_robust)
                max_distance_for_gap=50.0    # max centroid distance (pixels) to re-link a reappeared object
            )
            print("\n--- Pipeline executed successfully! ---")
            print(f"Detailed results saved to: {results['csv_path']}")
            print(f"Napari tracks saved to: {results['tracks_csv_path']}")
        except Exception as e:
            print(f"\nAn error occurred during the pipeline execution: {e}")