Segmentation

In [None]:
# 🛠️ SETUP
import sys
from pathlib import Path

# Add karyosight modules
sys.path.append(r"D:\GITHUB_SOFTWARE\karyosight\karyosight\src")

from karyosight.segmentation_SAM import (
    segment_all_conditions,
    test_cellpose_sam_installation,
    # New functions for segmented directory workflow
    open_segmented_data_in_napari,
    quick_view_segmented,
    compare_segmented_conditions,
    list_segmented_conditions,
    get_segmented_organoid_count,
    # Comparison functions
    compare_raw_vs_optimized,
    quick_compare_raw_vs_optimized,
    # Visualization class for segmented data
    SegmentedOrganoidVisualizer
)

# Configuration
CROPPED_DIR = Path("D:/LUNGBUD_master/cropped")
DRY_RUN = False  # True for testing, False for actual processing

# Banner echoing the configuration chosen above
print(
    "SEGMENTATION PIPELINE",
    "=" * 60,
    f"Directory: {CROPPED_DIR}",
    f"Dry run: {DRY_RUN}",
    sep="\n",
)

# Environment check: report whether the Cellpose-SAM installation test passed
print("\n🧪 Testing Cellpose-SAM...")
print(
    "✅ Ready for segmentation!"
    if test_cellpose_sam_installation()
    else "❌ Environment issues - fix before proceeding"
)


In [None]:
# 🧪 TEST COMPLETE FILTERING PIPELINE (INCLUDING Z-SLICE OPTIMIZATION)

from karyosight.segmentation_SAM import test_filtering_setup

print("🧪 TESTING YOUR COMPLETE FILTERING PIPELINE")
print("=" * 60)
print("⚠️  IMPORTANT: Test this before running on all data!")
print("📋 Testing: Focus filter + Pre-seg z-filter + Post-seg z-optimization")

# Test the complete filtering setup on a sample organoid.
# CROPPED_DIR is the constant defined in the setup cell; using it here (instead of
# repeating the literal path) keeps this test pointed at the configured directory.
test_results = test_filtering_setup(
    cropped_dir=CROPPED_DIR,
    condition_name=None,  # Uses first available condition
    organoid_idx=0,       # Test organoid 0
    channel=0             # DAPI channel
)

# Verification report for the filtering test results held in `test_results`
_rule = "=" * 80
print("\n" + _rule, "🎯 COMPLETE FILTERING PIPELINE VERIFICATION", _rule, sep="\n")

if not test_results.get('error'):
    # Read the per-stage switches out of the result dictionary
    focus_enabled = test_results['focus_filtering']['enabled']
    z_slice_enabled = test_results['z_slice_filtering']['enabled']
    z_opt_enabled = test_results.get('z_slice_optimization', {}).get('enabled', False)
    would_process = test_results['final_processing_decision']['would_be_processed']

    # Label pairs indexed by bool(flag): index 0 -> falsy, index 1 -> truthy
    _switch = ('❌ DISABLED', '✅ ENABLED')
    _verdict = ('❌ WOULD BE SKIPPED', '✅ WOULD BE PROCESSED')
    print(
        "✅ Complete filtering pipeline verified:",
        f"   → 🎯 Focus filtering (skip out-of-focus): {_switch[bool(focus_enabled)]}",
        f"   → 🧹 Pre-segmentation z-filter (remove black): {_switch[bool(z_slice_enabled)]}",
        f"   → ✂️ Post-segmentation z-optimize (remove empty): {_switch[bool(z_opt_enabled)]}",
        f"   → 📊 Test organoid result: {_verdict[bool(would_process)]}",
        sep="\n",
    )

    # Extra detail is only shown when the z-slice optimization stage is on
    if z_opt_enabled:
        z_opt = test_results['z_slice_optimization']
        strategy = z_opt.get('strategy', 'unknown')
        potential_compression = z_opt.get('potential_compression', 0)

        print(
            "\n✂️ Z-SLICE OPTIMIZATION DETAILS:",
            f"   → Strategy: {strategy}",
            f"   → Potential file size reduction: ~{potential_compression:.1f}%",
            "   → This will dramatically reduce storage while keeping ALL nuclei!",
            sep="\n",
        )

        # One explanatory line per known strategy; unknown strategies print nothing
        if strategy == 'z_range':
            print("   → Method: Keep z-range from first to last nuclei-containing slice")
        elif strategy == 'padded_range':
            padding = z_opt.get('padding', 0)
            print(f"   → Method: Keep nuclei z-range + {padding} padding slices")
        elif strategy == 'covered_only':
            print("   → Method: Keep ONLY slices with segmented nuclei")

    # Overall assessment based on how many of the three stages are enabled
    total_filters = sum([focus_enabled, z_slice_enabled, z_opt_enabled])
    if total_filters >= 2:
        print(
            "\n🎉 EXCELLENT! YOUR COMPLETE FILTERING PIPELINE IS ACTIVE!",
            f"   → {total_filters}/3 filtering stages enabled",
            "   → This will provide optimal quality control AND file size reduction",
            "   → Ready to run on all data with confidence!",
            sep="\n",
        )
    elif total_filters == 1:
        print(
            "\n⚠️ PARTIAL FILTERING ACTIVE",
            f"   → {total_filters}/3 filtering stages enabled",
            "   → Consider enabling more filters for better results",
            sep="\n",
        )
    else:
        print(
            "\n⚠️  NO FILTERING APPLIED",
            "   → All organoids will be processed without quality control",
            "   → File sizes will be larger than necessary",
            sep="\n",
        )
else:
    # The test itself reported a problem
    print("❌ Test failed - fix issues before proceeding")
    print(f"Error: {test_results['error']}")

print("\n" + _rule)


In [None]:
# 🚀 PRODUCTION SEGMENTATION - SAVES TO SEPARATE 'SEGMENTED' DIRECTORY

print("🚀 PRODUCTION SEGMENTATION → NEW 'SEGMENTED' DIRECTORY")
print("=" * 70)

# Run segmentation on ALL conditions.
# CROPPED_DIR and DRY_RUN are the constants defined (and echoed) in the setup cell.
# They were previously repeated here as literals, so switching DRY_RUN to True in the
# setup cell printed "Dry run: True" while this call still processed data for real.
results = segment_all_conditions(
    cropped_dir=CROPPED_DIR,
    optimal_diameter=50.0,
    optimal_flow_threshold=0.4,
    optimal_cellprob_threshold=0.0,
    optimal_batch_size=16,
    min_nuclei_size=1000,
    channel=0,
    dry_run=DRY_RUN,
    resume=True,   # Skip already processed organoids
    verbose=True   # Set to False for less output
)


In [None]:
# 🔍 VISUALIZE Z-OPTIMIZED SEGMENTED DATA IN NAPARI

from karyosight.segmentation_SAM import (
    open_segmented_data_in_napari,
    quick_view_segmented,
    compare_segmented_conditions,
    list_segmented_conditions,
    get_segmented_organoid_count
)

print("🔍 NAPARI VISUALIZATION - Z-OPTIMIZED SEGMENTED DATA")
print("=" * 60)
print("📁 Loading from NEW segmented directory (not cropped)")

# List available segmented conditions
available_conditions = list_segmented_conditions()
print(f"📋 Available segmented conditions: {available_conditions}")

# Stays None when there is nothing to display, so later cells can test for it
viewer = None

if available_conditions:
    # Get info about first condition
    condition_info = get_segmented_organoid_count(available_conditions[0])
    if 'error' not in condition_info:
        print(f"\n📊 {condition_info['condition']} info:")
        print(f"   → Organoids: {condition_info['n_organoids']}")
        print(f"   → Total nuclei: {condition_info['total_nuclei']}")
        print(f"   → Average compression: {condition_info['average_compression_percent']:.1f}%")
        print(f"   → Available organoids: {condition_info['organoid_indices']}")

    # Quick view first available condition and organoid 0
    print(f"\n🎯 Opening napari with z-optimized data...")
    viewer = quick_view_segmented()
else:
    print("❌ No segmented conditions found")
    print("   → Run segmentation first to create segmented data")

# Or specify condition and organoid (uncomment to use).
# The previous code called `open_segmentation_napari`, which is neither imported nor
# defined in this notebook (NameError); `open_segmented_data_in_napari` is the
# imported function that takes a condition name and an organoid index.
# viewer = open_segmented_data_in_napari('Condition_D25T75', organoid_idx=0)

print("💡 INSTRUCTIONS:")
print("1. Uncomment the line above to open a specific condition/organoid in napari")
print("2. This will load and display:")
print("   → ✨ OPTIMIZED masks from cropped folder (z-slices reduced!)")
print("   → 🧬 Raw data + segmentation masks with correct channel names")
print("   → 📊 Nuclei counts and processing metadata")
print("   → 📍 Confirm masks are stored in cropped folder")
print("   → ✂️ Notice smaller z-stack size (z-slice optimization applied!)")

print("\n🎯 EXAMPLES:")
print("   viewer = quick_view_segmented()  # First available")
print("   viewer = open_segmented_data_in_napari('Condition_D50T50', organoid_idx=0)  # Specific")

print("\n✨ WHAT TO EXPECT:")
print("   → Raw data: Full z-stack (e.g., 25 slices)")
print("   → Masks: Optimized z-stack (e.g., 14 slices)")
print("   → Same nuclei count, smaller file size!")
print("   → Confirms z-slice optimization is working correctly")


Visualization

In [None]:
# 🎨 IN-NOTEBOOK VISUALIZATION - SEGMENTED ORGANOIDS WITH MASKS

print(
    "🎨 SEGMENTED ORGANOID VISUALIZATION",
    "=" * 60,
    "📊 Creating publication-quality grids with segmentation overlays",
    sep="\n",
)

# Visualizer object used by this cell and the following visualization cells
visualizer = SegmentedOrganoidVisualizer()

# Tally of organoids per condition
print("\n📋 Organoid counts per condition:")
counts = visualizer.count_organoids_per_condition()
for condition, count in counts.items():
    print(f"   → {condition}: {count} organoids")

print(
    "\n🎯 VISUALIZATION OPTIONS:",
    "1. Single condition grid with mask overlays",
    "2. All conditions overview",
    "3. Multi-channel comparison for single organoid",
    sep="\n",
)

# Option 1: grid for a single condition (the first one reported above)
if not counts:
    print("❌ No segmented conditions found")
    print("   → Run segmentation first to create visualizations")
else:
    example_condition = next(iter(counts.keys()))

    print(
        f"\n📊 Creating visualization for: {example_condition}",
        "✨ Features:",
        "   → Max projection of z-optimized raw data",
        "   → Red mask overlay showing segmented nuclei",
        "   → Nuclei count in each title",
        "   → Saves as SVG (vector format) for publications",
        sep="\n",
    )

    # Settings for the single-condition grid
    _grid_options = dict(
        condition=example_condition,
        channel=0,          # DAPI/nucleus channel
        sample_count=12,    # Show 12 organoids for overview
        mask_overlay=True,  # Show segmentation masks
        mask_alpha=0.3,     # Semi-transparent red overlay
        save_svg=True,      # Save as SVG
        save_png=False,     # Don't save PNG
        show_plot=True,     # Display in notebook
    )
    saved_path = visualizer.create_segmented_organoid_grid(**_grid_options)

    if saved_path:
        print(
            f"\n💾 Saved visualization: {saved_path}",
            f"📁 Location: {visualizer.vis_dir}",
            sep="\n",
        )








# 🎨 ADDITIONAL VISUALIZATION OPTIONS

print("\n🎨 ADDITIONAL VISUALIZATION OPTIONS", "=" * 60, sep="\n")

# Option 2: all-conditions overview. Only the recipe is printed here; the real call
# is kept commented out below to avoid overwhelming output.
print(
    "📊 Option 2: Create visualizations for ALL conditions",
    "💡 Uncomment to run:",
    "saved_paths = visualizer.create_all_conditions_grid(",
    "    channel=0,",
    "    sample_count=9,  # 9 organoids per condition for overview",
    "    mask_overlay=True,",
    "    save_svg=True",
    ")",
    sep="\n",
)

# Uncomment below to create visualizations for all conditions:
# saved_paths = visualizer.create_all_conditions_grid(
#     channel=0,
#     sample_count=9,  # 9 organoids per condition for overview
#     mask_overlay=True,
#     save_svg=True,
#     show_plot=False  # Don't show all plots (too many)
# )
# print(f"✅ Created {len(saved_paths)} condition visualizations")

# Option 3: one organoid rendered across several channels
if counts:
    example_condition = next(iter(counts.keys()))

    print(
        "\n📊 Option 3: Multi-channel comparison for single organoid",
        f"💡 Showing {example_condition}, organoid 0 across different channels",
        "✨ Features:",
        "   → Same organoid shown in multiple channels",
        "   → Segmentation masks overlaid on all channels",
        "   → Shows co-localization patterns",
        sep="\n",
    )

    # Settings for the multi-channel comparison figure
    _comparison_options = dict(
        condition=example_condition,
        organoid_idx=0,
        channels=[0, 1, 2],  # First 3 channels (nucleus, tetraploid, diploid)
        mask_overlay=True,
        mask_alpha=0.3,
        save_svg=True,
        show_plot=True,
    )
    saved_path = visualizer.create_channel_comparison(**_comparison_options)

    if saved_path:
        print(f"\n💾 Multi-channel comparison saved: {saved_path}")

# Reference text: tunable parameters
print(
    "\n💡 CUSTOMIZATION OPTIONS:",
    "🎛️  Key parameters you can adjust:",
    "   • channel: 0=nucleus, 1=tetraploid, 2=diploid, 3=sox9, 4=sox2, 5=brightfield",
    "   • sample_count: Number of organoids to show (None=all)",
    "   • mask_overlay: True/False to show segmentation masks",
    "   • mask_alpha: 0.1-0.8 for mask transparency",
    "   • auto_scale_method: 'individual' or 'global' intensity scaling",
    "   • save_svg: True for vector graphics (best for publications)",
    "   • save_png: True for raster images (smaller files)",
    sep="\n",
)

# Reference text: where files go and how they are named
print(
    "\n📁 OUTPUT STRUCTURE:",
    "All visualizations are saved to:",
    f"   → {visualizer.vis_dir}",
    "Files are named automatically based on:",
    "   → Condition name",
    "   → Channel name",
    "   → Whether masks are included",
    "   → Number of organoids shown",
    "   → Scaling method used",
    sep="\n",
)

# Reference text: intended uses
print(
    "\n✨ PERFECT FOR:",
    "   • Publications (SVG format)",
    "   • Quality control",
    "   • Condition comparisons",
    "   • Segmentation validation",
    "   • Presentations",
    sep="\n",
)





Extra

In [None]:
# 🔍 COMPARE ORIGINAL vs Z-OPTIMIZED DATA SIDE-BY-SIDE

print(
    "🔍 COMPARING ORIGINAL vs Z-OPTIMIZED DATA",
    "=" * 60,
    "📊 This shows the before/after of z-slice optimization",
    sep="\n",
)

# Side-by-side viewer for organoid 0 of one fixed condition
print("🎯 Opening side-by-side comparison...")
comparison_viewer = quick_compare_raw_vs_optimized('Condition_D25T75', 0)

# Reading guide for the layers listed below
print(
    "\n💡 COMPARISON GUIDE:",
    "✅ What you'll see in napari:",
    "   → 'ORIGINAL - nucleus (XX z-slices)' - Full untouched raw data",
    "   → 'Z-OPTIMIZED - nucleus (YY z-slices)' - Compressed raw data",
    "   → 'Z-Optimized Masks' - Segmentation masks (same z as optimized)",
    sep="\n",
)
print(
    "\n🔧 How to use:",
    "   • Toggle layer visibility to switch between original and optimized",
    "   • Scroll through z-slices to see which were removed",
    "   • Notice the z-slice count difference in layer names",
    "   • Original preserves all slices, optimized keeps only nuclei-containing region",
    sep="\n",
)
print(
    "\n📏 What to verify:",
    "   • Optimized version has fewer z-slices",
    "   • All nuclei are preserved in the optimized version",
    "   • File size reduction matches the compression percentage",
    "   • No important data was lost in the optimization",
    sep="\n",
)


In [None]:
# 🔀 COMPARE MULTIPLE CONDITIONS (Z-OPTIMIZED DATA)

print("🔀 CONDITION COMPARISON - Z-OPTIMIZED SEGMENTED DATA", "=" * 60, sep="\n")

# Conditions that have segmented output available
available_conditions = list_segmented_conditions()
print(f"📋 Available conditions for comparison: {available_conditions}")

_n_conditions = len(available_conditions)

if _n_conditions == 0:
    # Nothing to compare
    print("❌ No segmented conditions found for comparison")
    print("   → Run segmentation first to create segmented data")

elif _n_conditions == 1:
    # A single condition cannot be compared; show its details instead
    print("\n⚠️  Only 1 condition available - showing detailed view...")
    condition_name = available_conditions[0]

    info = get_segmented_organoid_count(condition_name)
    if 'error' not in info:
        print(
            f"\n📊 DETAILED INFO - {condition_name}:",
            f"   → Total organoids: {info['n_organoids']}",
            f"   → Total nuclei: {info['total_nuclei']}",
            f"   → Average nuclei per organoid: {info['average_nuclei_per_organoid']:.1f}",
            f"   → Average compression: {info['average_compression_percent']:.1f}%",
            sep="\n",
        )

        print("\n📋 PER-ORGANOID BREAKDOWN:")
        for org_idx, details in info['organoid_details'].items():
            original_shape = details['original_shape']
            optimized_shape = details['optimized_shape']
            compression = details['compression_percent']
            n_nuclei = details['n_nuclei']

            if not (original_shape and optimized_shape):
                # Shapes unavailable: report counts and compression only
                print(f"   → Organoid {org_idx}: {n_nuclei} nuclei, {compression:.1f}% compression")
                continue

            # Element 1 of each stored shape is reported as the z-slice count
            original_z = 'N/A' if len(original_shape) <= 1 else original_shape[1]
            optimized_z = 'N/A' if len(optimized_shape) <= 1 else optimized_shape[1]
            print(f"   → Organoid {org_idx}: {n_nuclei} nuclei, {original_z}→{optimized_z} z-slices ({compression:.1f}% compression)")

        # Open viewer for detailed inspection
        viewer = open_segmented_data_in_napari(condition_name, organoid_idx=0)

else:
    # Two or more conditions: open them together
    print(f"\n🔍 Comparing all {_n_conditions} conditions side-by-side...")
    print("💡 This will show z-optimized data for consistent comparison")

    comparison_viewer = compare_segmented_conditions(
        condition_names=None,  # Use all available conditions
        organoid_idx=0         # Compare organoid 0 across conditions
    )

    print(
        "\n💡 COMPARISON TIPS:",
        "   • Toggle layer visibility to switch between conditions",
        "   • All data is z-optimized for fair comparison",
        "   • Compare nuclei counts and compression ratios",
        "   • Use opacity sliders to overlay different masks",
        sep="\n",
    )


In [None]:
# 🔍 DIAGNOSTIC: Z-SLICE OPTIMIZATION ANALYSIS

import numpy as np
import matplotlib.pyplot as plt
from karyosight.segmentation_SAM import optimize_z_slices_after_segmentation
import zarr

def analyze_z_slice_optimization(condition_name, organoid_idx, cropped_dir=Path("D:/LB_TEST3/cropped")):
    """
    Diagnostic function to analyze z-slice optimization behavior for one organoid.

    Reads the stored ``masks`` array of ``organoid_<idx:04d>`` from
    ``<cropped_dir>/<condition_name>/<condition_name>_bundled.zarr``, prints a
    per-slice report (axis 0 of the mask array is iterated as z), runs
    ``optimize_z_slices_after_segmentation`` with ``strategy='z_range'`` against an
    all-zero dummy image, and plots original vs. optimized masks.

    Args:
        condition_name: Condition folder name; also the prefix of the zarr bundle.
        organoid_idx: Integer organoid index, formatted as ``organoid_%04d``.
        cropped_dir: Directory containing the per-condition folders (str or Path).
            NOTE(review): this default differs from the ``CROPPED_DIR`` defined in
            the setup cell ("D:/LUNGBUD_master/cropped") - confirm which is intended.

    Returns:
        dict with keys ``original_shape``, ``optimized_shape``,
        ``z_indices_with_nuclei``, ``z_nuclei_counts``, ``compression_percent`` and
        ``metadata``; or ``None`` when the organoid has no stored masks or when no
        z-slice contains a labelled pixel.
    """

    def _positive_label_count(label_array):
        # Number of distinct labels > 0. Unlike `len(np.unique(a)) - 1`, this does
        # not undercount when the array contains no background (0) pixel.
        return int(np.count_nonzero(np.unique(label_array) > 0))

    print("🔍 ANALYZING Z-SLICE OPTIMIZATION")
    print("=" * 60)
    print(f"Condition: {condition_name}, Organoid: {organoid_idx}")

    # Load the zarr file (Path() lets callers pass a plain string as well)
    zarr_path = Path(cropped_dir) / condition_name / f"{condition_name}_bundled.zarr"
    bundle = zarr.open_group(str(zarr_path), mode='r')
    organoid_key = f"organoid_{organoid_idx:04d}"

    # Masks should exist from a previous segmentation run; bail out early if not
    if organoid_key not in bundle or 'masks' not in bundle[organoid_key]:
        print(f"❌ No masks found for {organoid_key}")
        return None

    masks = bundle[organoid_key]['masks'][:]
    # Materialise lazily-evaluated arrays (objects exposing .compute()) into memory
    masks = masks.compute() if hasattr(masks, 'compute') else np.array(masks)

    n_z = masks.shape[0]
    total_nuclei = _positive_label_count(masks)

    print(f"✅ Loaded masks: {masks.shape}")
    print(f"   Unique labels: {total_nuclei} (excluding background)")
    print(f"   Data type: {masks.dtype}")
    print(f"   Value range: {masks.min()} - {masks.max()}")

    # Analyze z-slice content
    print("\n📊 Z-SLICE ANALYSIS:")
    z_nuclei_counts = []
    z_pixel_counts = []
    for z in range(n_z):
        slice_mask = masks[z]
        nuclei_count = _positive_label_count(slice_mask)
        pixel_count = int(np.count_nonzero(slice_mask > 0))
        z_nuclei_counts.append(nuclei_count)
        z_pixel_counts.append(pixel_count)

        if pixel_count > 0:
            print(f"   Z-slice {z:2d}: {nuclei_count:3d} nuclei, {pixel_count:6d} pixels")

    z_has_nuclei = np.array([count > 0 for count in z_pixel_counts], dtype=bool)
    z_indices_with_nuclei = np.where(z_has_nuclei)[0]
    any_nuclei = len(z_indices_with_nuclei) > 0

    print("\n🎯 OPTIMIZATION DECISION:")
    print(f"   Total z-slices: {n_z}")
    print(f"   Z-slices with nuclei: {len(z_indices_with_nuclei)}")
    print(f"   First nuclei z-slice: {z_indices_with_nuclei[0] if any_nuclei else 'None'}")
    print(f"   Last nuclei z-slice: {z_indices_with_nuclei[-1] if any_nuclei else 'None'}")

    if not any_nuclei:
        # Previously this case fell off the end of the function without explanation
        print("   ⚠️  No z-slice contains nuclei - nothing to optimize")
        return None

    z_min, z_max = z_indices_with_nuclei[0], z_indices_with_nuclei[-1]
    z_range_size = z_max - z_min + 1
    compression = (n_z - z_range_size) / n_z * 100

    print(f"   Optimization range: {z_min} to {z_max} ({z_range_size} slices)")
    print(f"   Compression: {compression:.1f}%")

    # Check for potential issues
    print("\n⚠️  POTENTIAL ISSUES CHECK:")

    # Labels touching the very first/last slice are flagged as possible noise
    if z_min == 0:
        print("   🚨 Nuclei found at very first z-slice (z=0) - might be noise")
    if z_max == n_z - 1:
        print(f"   🚨 Nuclei found at very last z-slice (z={n_z - 1}) - might be noise")

    # Sparse labels at either end of the occupied range; only checked when the
    # occupied slices outnumber the two edge windows combined
    edge_buffer = 3
    if len(z_indices_with_nuclei) > edge_buffer * 2:
        first_few_counts = [z_nuclei_counts[z] for z in z_indices_with_nuclei[:edge_buffer]]
        last_few_counts = [z_nuclei_counts[z] for z in z_indices_with_nuclei[-edge_buffer:]]

        if max(first_few_counts) <= 2:
            print(f"   ⚠️  Very few nuclei in first {edge_buffer} slices: {first_few_counts}")
        if max(last_few_counts) <= 2:
            print(f"   ⚠️  Very few nuclei in last {edge_buffer} slices: {last_few_counts}")

    # Test optimization
    print("\n🧪 TESTING OPTIMIZATION:")
    fake_image = np.zeros_like(masks)  # Dummy image: only the mask result is inspected

    optimized_image, optimized_masks, metadata = optimize_z_slices_after_segmentation(
        fake_image, masks, strategy='z_range', verbose=True
    )

    print("\n📈 OPTIMIZATION RESULTS:")
    print(f"   Original shape: {masks.shape}")
    print(f"   Optimized shape: {optimized_masks.shape}")
    print(f"   Compression: {metadata.get('compression_ratio_percent', 0):.1f}%")

    # Create visualization: top row = original masks, bottom row = optimized masks.
    # Titles use a real newline ("\n"); the earlier "\\n" rendered a literal backslash-n.
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))

    # Original masks - max projection and middle slice
    axes[0, 0].imshow(np.max(masks, axis=0), cmap='tab20')
    axes[0, 0].set_title(f'Original Max Projection\n{total_nuclei} nuclei')
    axes[0, 0].axis('off')

    axes[0, 1].imshow(masks[n_z // 2], cmap='tab20')
    axes[0, 1].set_title(f'Original Middle Slice (z={n_z // 2})')
    axes[0, 1].axis('off')

    # Z-profile plot with the detected nuclei range marked
    axes[0, 2].plot(z_nuclei_counts, 'b-o', markersize=3)
    axes[0, 2].axvline(z_min, color='red', linestyle='--', label=f'Start (z={z_min})')
    axes[0, 2].axvline(z_max, color='red', linestyle='--', label=f'End (z={z_max})')
    axes[0, 2].set_xlabel('Z-slice')
    axes[0, 2].set_ylabel('Nuclei count')
    axes[0, 2].set_title('Original Z-Profile')
    axes[0, 2].legend()
    axes[0, 2].grid(True, alpha=0.3)

    # Optimized masks
    n_opt_z = optimized_masks.shape[0]

    axes[1, 0].imshow(np.max(optimized_masks, axis=0), cmap='tab20')
    axes[1, 0].set_title(f'Optimized Max Projection\n{_positive_label_count(optimized_masks)} nuclei')
    axes[1, 0].axis('off')

    axes[1, 1].imshow(optimized_masks[n_opt_z // 2], cmap='tab20')
    axes[1, 1].set_title(f'Optimized Middle Slice (z={n_opt_z // 2})')
    axes[1, 1].axis('off')

    # Optimized Z-profile
    opt_z_nuclei_counts = [_positive_label_count(optimized_masks[z]) for z in range(n_opt_z)]

    axes[1, 2].plot(opt_z_nuclei_counts, 'g-o', markersize=3)
    axes[1, 2].set_xlabel('Z-slice (optimized)')
    axes[1, 2].set_ylabel('Nuclei count')
    axes[1, 2].set_title(f'Optimized Z-Profile\n({n_opt_z} slices)')
    axes[1, 2].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    return {
        'original_shape': masks.shape,
        'optimized_shape': optimized_masks.shape,
        'z_indices_with_nuclei': z_indices_with_nuclei,
        'z_nuclei_counts': z_nuclei_counts,
        'compression_percent': metadata.get('compression_ratio_percent', 0),
        'metadata': metadata
    }

# Run diagnostic on a specific organoid
print("🔍 DIAGNOSTIC TEST - Z-SLICE OPTIMIZATION")
print("=" * 70)

# Test on the organoid you just processed
diagnostic_results = analyze_z_slice_optimization('Condition_D25T75', 0)

# "\n" (not "\\n") so a blank line is printed instead of a literal backslash-n
print("\n💡 INTERPRETATION:")
print("→ Look for very low nuclei counts at edges (might be noise)")
print("→ Check if optimization range makes biological sense")
print("→ Compare max projections - should look identical")
print("→ Z-profiles should show where nuclei actually are")


In [None]:
# 🔧 IMPROVED Z-SLICE OPTIMIZATION (TESTING)
import numpy as np
import zarr
import matplotlib.pyplot as plt
from pathlib import Path
# from karyosight.segmentation_SAM import optimize_z_slices_after_segmentation


def improved_z_slice_optimization(
    image_3d: np.ndarray,
    masks_3d: np.ndarray,
    strategy: str = 'z_range_smart',
    min_nuclei_per_slice: int = 3,
    min_pixels_per_slice: int = 500,
    edge_buffer: int = 2,
    verbose: bool = False
) -> tuple:
    """
    Improved z-slice optimization with better edge detection.

    A z-slice is "valid" when it holds at least ``min_nuclei_per_slice`` distinct
    positive labels AND at least ``min_pixels_per_slice`` labelled (> 0) pixels.
    The slices to keep are then chosen according to ``strategy``:

    * ``'z_range_smart'``: contiguous range from the first to the last valid slice,
      widened by ``edge_buffer`` slices on each side and clamped to the stack.
    * ``'valid_only'``: exactly the valid slices (may be non-contiguous).
    * anything else: contiguous range from the first to the last valid slice,
      without buffer.

    Args:
        image_3d: Original 3D image [Z, Y, X]
        masks_3d: Segmentation masks [Z, Y, X]; must be 3-dimensional
        strategy: Optimization strategy (see above)
        min_nuclei_per_slice: Minimum nuclei to consider slice "valid"
        min_pixels_per_slice: Minimum pixels to consider slice "valid"
        edge_buffer: Buffer slices to keep around valid region
        verbose: Print processing information

    Returns:
        Tuple of (optimized_image, optimized_masks, metadata). When no slice is
        valid the inputs are returned unchanged and
        ``metadata['z_optimization_applied']`` is False. All counts and indices in
        ``metadata`` are plain Python ints/bools/lists.

    Raises:
        ValueError: if ``masks_3d`` is not 3-dimensional (from shape unpacking).
    """
    if verbose:
        print(f"✂️ IMPROVED z-slice optimization (strategy={strategy})")

    # Unpacking doubles as a check that the mask stack really is 3-D
    original_z, _, _ = masks_3d.shape

    # Analyze each z-slice
    z_valid_flags = []
    z_nuclei_counts = []
    z_pixel_counts = []

    for z in range(original_z):
        slice_mask = masks_3d[z]
        # Count distinct positive labels. `len(np.unique(...)) - 1` assumed that a
        # background (0) pixel is always present and undercounted by one on slices
        # that are completely covered by labels.
        nuclei_count = int(np.count_nonzero(np.unique(slice_mask) > 0))
        pixel_count = int(np.count_nonzero(slice_mask > 0))

        # Both thresholds must be met for the slice to count as valid
        is_valid = nuclei_count >= min_nuclei_per_slice and pixel_count >= min_pixels_per_slice

        z_valid_flags.append(is_valid)
        z_nuclei_counts.append(nuclei_count)
        z_pixel_counts.append(pixel_count)

        if verbose and (nuclei_count > 0 or pixel_count > 0):
            validity = "✅ VALID" if is_valid else "❌ INVALID (noise?)"
            print(f"   Z-slice {z:2d}: {nuclei_count:3d} nuclei, {pixel_count:6d} pixels - {validity}")

    z_valid = np.array(z_valid_flags, dtype=bool)
    valid_z_indices = np.where(z_valid)[0]

    if len(valid_z_indices) == 0:
        if verbose:
            print("   ⚠️  No valid z-slices found - keeping original data")
        return image_3d, masks_3d, {
            'z_optimization_applied': False,
            'reason': 'no_valid_slices',
            'original_z_count': original_z,
            'optimized_z_count': original_z,
            'removed_z_count': 0,
            'compression_ratio_percent': 0.0
        }

    # Plain ints keep the metadata free of NumPy scalar types
    first_valid = int(valid_z_indices[0])
    last_valid = int(valid_z_indices[-1])

    # Apply strategy
    if strategy == 'z_range_smart':
        # Keep z-range from first to last VALID slice + buffer, clamped to the stack
        z_min = max(0, first_valid - edge_buffer)
        z_max = min(original_z - 1, last_valid + edge_buffer)
        z_slices_to_keep = list(range(z_min, z_max + 1))
    elif strategy == 'valid_only':
        # Keep only valid slices
        z_slices_to_keep = [int(z) for z in valid_z_indices]
    else:
        # Fallback: first-to-last valid slice without buffer
        z_slices_to_keep = list(range(first_valid, last_valid + 1))

    # Apply optimization (indexing with a list selects the kept slices along axis 0)
    optimized_image = image_3d[z_slices_to_keep]
    optimized_masks = masks_3d[z_slices_to_keep]

    kept_count = len(z_slices_to_keep)
    removed_count = original_z - kept_count
    compression_ratio = removed_count / original_z * 100

    if verbose:
        print("   ✅ IMPROVED optimization complete:")
        print(f"      → Original z-slices: {original_z}")
        print(f"      → Valid z-slices: {len(valid_z_indices)} (strict criteria)")
        print(f"      → Optimized z-slices: {kept_count}")
        print(f"      → Removed z-slices: {removed_count}")
        print(f"      → Compression: {compression_ratio:.1f}%")
        print(f"      → Valid z-range: {first_valid}-{last_valid}")
        print(f"      → Kept z-range: {z_slices_to_keep[0]}-{z_slices_to_keep[-1]} (with buffer)")

    metadata = {
        'z_optimization_applied': True,
        'strategy': strategy,
        'min_nuclei_per_slice': min_nuclei_per_slice,
        'min_pixels_per_slice': min_pixels_per_slice,
        'edge_buffer': edge_buffer,
        'original_z_count': original_z,
        'valid_z_count': len(valid_z_indices),
        'optimized_z_count': kept_count,
        'removed_z_count': removed_count,
        'compression_ratio_percent': compression_ratio,
        'valid_z_range': (first_valid, last_valid),
        'kept_z_range': (z_slices_to_keep[0], z_slices_to_keep[-1]),
        'z_slices_kept': z_slices_to_keep,
        'z_nuclei_counts': z_nuclei_counts,
        'z_pixel_counts': z_pixel_counts,
        'z_valid': z_valid.tolist()
    }

    return optimized_image, optimized_masks, metadata

def compare_optimization_methods(condition_name, organoid_idx, cropped_dir=Path("D:/LB_TEST3/cropped")):
    """
    Compare original vs improved z-slice optimization on one organoid's masks.

    Loads the stored ``masks`` of ``organoid_<idx:04d>`` from
    ``<cropped_dir>/<condition_name>/<condition_name>_bundled.zarr``, runs both
    ``optimize_z_slices_after_segmentation`` (strategy ``'z_range'``) and
    ``improved_z_slice_optimization`` (strategy ``'z_range_smart'``) against an
    all-zero dummy image, prints a summary and shows max projections side by side.

    Args:
        condition_name: Condition folder name; also the prefix of the zarr bundle.
        organoid_idx: Integer organoid index, formatted as ``organoid_%04d``.
        cropped_dir: Directory containing the per-condition folders (str or Path).
            NOTE(review): this default differs from the ``CROPPED_DIR`` defined in
            the setup cell ("D:/LUNGBUD_master/cropped") - confirm which is intended.

    Returns:
        dict with the two metadata dictionaries under ``'original_method'`` and
        ``'improved_method'``, or ``None`` when the organoid has no stored masks.
    """
    # Imported here because the cell-level import of this name is commented out
    from karyosight.segmentation_SAM import optimize_z_slices_after_segmentation

    print("🆚 COMPARING OPTIMIZATION METHODS")
    print("=" * 60)

    # Load masks (Path() lets callers pass a plain string as well)
    zarr_path = Path(cropped_dir) / condition_name / f"{condition_name}_bundled.zarr"
    bundle = zarr.open_group(str(zarr_path), mode='r')
    organoid_key = f"organoid_{organoid_idx:04d}"

    if organoid_key not in bundle or 'masks' not in bundle[organoid_key]:
        print(f"❌ No masks found for {organoid_key}")
        return None

    masks = bundle[organoid_key]['masks'][:]
    # Materialise lazily-evaluated arrays (objects exposing .compute()) into memory
    masks = masks.compute() if hasattr(masks, 'compute') else np.array(masks)

    fake_image = np.zeros_like(masks)  # Dummy image: only the mask results are compared

    print("📊 ORIGINAL METHOD:")
    _, orig_opt_masks, orig_metadata = optimize_z_slices_after_segmentation(
        fake_image, masks, strategy='z_range', verbose=True
    )

    # "\n" (not "\\n") so a blank line is printed instead of a literal backslash-n
    print("\n📊 IMPROVED METHOD:")
    _, impr_opt_masks, impr_metadata = improved_z_slice_optimization(
        fake_image, masks,
        strategy='z_range_smart',
        min_nuclei_per_slice=3,      # Require at least 3 nuclei
        min_pixels_per_slice=500,    # Require at least 500 pixels
        edge_buffer=2,               # Keep 2 buffer slices
        verbose=True
    )

    n_z = masks.shape[0]
    orig_compression = orig_metadata.get('compression_ratio_percent', 0)
    impr_compression = impr_metadata.get('compression_ratio_percent', 0)
    # Distinct positive labels; stays correct even without a background (0) pixel
    total_nuclei = int(np.count_nonzero(np.unique(masks) > 0))

    print("\n🔍 COMPARISON SUMMARY:")
    print(f"   Original optimization: {n_z} → {orig_opt_masks.shape[0]} slices ({orig_compression:.1f}% compression)")
    print(f"   Improved optimization: {n_z} → {impr_opt_masks.shape[0]} slices ({impr_compression:.1f}% compression)")

    # Visual comparison: unoptimized masks, original method, improved method
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    panels = [
        (masks, f'Original Masks\n{n_z} z-slices\n{total_nuclei} nuclei'),
        (orig_opt_masks, f'Original Optimization\n{orig_opt_masks.shape[0]} z-slices\n{orig_compression:.1f}% compression'),
        (impr_opt_masks, f'Improved Optimization\n{impr_opt_masks.shape[0]} z-slices\n{impr_compression:.1f}% compression'),
    ]
    for ax, (stack, title) in zip(axes, panels):
        ax.imshow(np.max(stack, axis=0), cmap='tab20')
        ax.set_title(title)
        ax.axis('off')

    plt.tight_layout()
    plt.show()

    return {
        'original_method': orig_metadata,
        'improved_method': impr_metadata
    }

# Test the improved method
# ("\n" instead of "\\n" so blank lines are printed, not a literal backslash-n)
print("\n🔧 TESTING IMPROVED Z-SLICE OPTIMIZATION")
print("=" * 70)

comparison_results = compare_optimization_methods('Condition_D50T50', 0)

print("\n💡 KEY IMPROVEMENTS:")
print("→ Requires minimum nuclei count per slice (reduces noise sensitivity)")
print("→ Requires minimum pixel count per slice (filters tiny artifacts)")
print("→ Adds buffer slices around valid region (safer edge handling)")
print("→ More detailed validation and reporting")
