# ROI Segmentation Tutorial

This notebook demonstrates how to use the `onem_segment` module for automatic Region of Interest (ROI) segmentation with intelligent 2D/3D model selection.

## 📋 Table of Contents
1. [Setup and Imports](#setup)
2. [Image Dimension Analysis](#analysis)
3. [Single Image Segmentation](#single)
4. [Batch Segmentation](#batch)
5. [Model Comparison](#comparison)
6. [Post-processing and Refinement](#postprocessing)
7. [Results Visualization](#visualization)

## 🔧 Setup and Imports {#setup}

In [None]:
# Core imports
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

# Medical imaging imports
import nibabel as nib
import SimpleITK as sitk

# Add project root to path
project_root = Path().absolute().parent
sys.path.append(str(project_root))

# Import onem_segment modules
from onem_segment import ROISegmenter
from onem_segment.utils.image_analyzer import ImageDimensionAnalyzer
from onem_segment.config.settings import get_preset_config

# Set plotting style
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

print("✅ All modules imported successfully!")
print(f"Project root: {project_root}")

## 🔍 Image Dimension Analysis {#analysis}

In [None]:
# Initialize image analyzer
analyzer = ImageDimensionAnalyzer()
print("🔍 Image dimension analyzer initialized")

# Example image path (replace with your actual file path)
image_path = "sample_data/patient001_ct.nii.gz"

if os.path.exists(image_path):
    print(f"📁 Analyzing image: {image_path}")
    
    # Analyze image dimensions and characteristics
    analysis = analyzer.analyze_image(image_path)
    
    print("\n📊 Image Analysis Results:")
    print(f"  Shape: {analysis['shape']}")
    print(f"  Voxel spacing: {analysis['voxel_spacing']}")
    print(f"  Data type: {analysis['data_type']}")
    print(f"  Intensity range: [{analysis['min_intensity']:.2f}, {analysis['max_intensity']:.2f}]")
    
    # 2D/3D recommendation
    print(f"\n🎯 Model Recommendation:")
    print(f"  Recommended mode: {analysis['recommended_mode']}")
    print(f"  Confidence: {analysis['confidence']:.2f}")
    
    # Analysis details
    print(f"\n🔬 Analysis Details:")
    print(f"  Slice count: {analysis['slice_count']} (threshold: 30)")
    print(f"  Slice thickness: {analysis['slice_thickness']:.2f}mm (threshold: 5mm)")
    print(f"  Content variation: {analysis['content_variation']:.2f} (threshold: 0.10)")
    print(f"  Criteria met: {analysis['criteria_met']}/3")
else:
    print(f"⚠️  Sample image not found: {image_path}")
    print("Please replace with your actual NIfTI file path")

    # Fall back to a dummy analysis so a real analysis is never overwritten
    analysis = {
        'shape': (512, 512, 120),
        'voxel_spacing': [0.7, 0.7, 3.0],
        'slice_count': 120,
        'slice_thickness': 3.0,
        'content_variation': 0.15,
        'recommended_mode': '3D',
        'confidence': 0.85,
        'criteria_met': 3
    }
    print("\n🎭 Using dummy analysis for demonstration:")
    for key, value in analysis.items():
        print(f"  {key}: {value}")
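
The shape and voxel spacing reported by the analysis determine the physical size of the scan, which is what the slice-count and slice-thickness criteria describe. The following is a minimal sketch using the dummy values from this cell. The helpers `physical_extent_mm` and `voxel_volume_mm3` are illustrative names and are not part of `onem_segment`.

```python
from math import prod

def physical_extent_mm(shape, spacing):
    """Return the scan's extent along each axis in millimetres."""
    return tuple(n * s for n, s in zip(shape, spacing))

def voxel_volume_mm3(spacing):
    """Return the volume of a single voxel in cubic millimetres."""
    return prod(spacing)

# Dummy values from the analysis above
shape = (512, 512, 120)
spacing = (0.7, 0.7, 3.0)

extent = physical_extent_mm(shape, spacing)
print("Extent (mm):", [round(e, 1) for e in extent])
print(f"Voxel volume: {voxel_volume_mm3(spacing):.2f} mm^3")
```

With these values the scan covers 360 mm along the slice axis, so a 3D model sees a long stack of fairly thin slices.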

## 🎯 Single Image Segmentation {#single}

In [None]:
# Initialize the ROI segmenter
segmenter = ROISegmenter()
print("🎯 ROI segmenter initialized")

# Example image and output paths
image_path = "sample_data/patient001_ct.nii.gz"
output_path = "output/segmentations/patient001_roi.nii.gz"

# Create output directory
os.makedirs(os.path.dirname(output_path), exist_ok=True)

# Check if image exists
if os.path.exists(image_path):
    print(f"🚀 Starting segmentation...")
    
    # Perform segmentation with automatic model selection
    result = segmenter.segment_image(
        image_path=image_path,
        model_type='auto',  # '2d', '3d', or 'auto'
        output_path=output_path,
        config_name='ct_organ',  # Use CT organ segmentation preset
        confidence_threshold=0.5
    )
    
    print(f"✅ Segmentation completed!")
    print(f"📊 Model used: {result['model_used']}")
    print(f"⏱️  Processing time: {result['processing_time']:.2f}s")
    print(f"📁 Output saved to: {result['output_path']}")
    
    # Display segmentation statistics
    if 'statistics' in result:
        stats = result['statistics']
        print(f"\n📈 Segmentation Statistics:")
        print(f"  ROI volume: {stats.get('roi_volume', 'N/A')} voxels")
        print(f"  ROI percentage: {stats.get('roi_percentage', 'N/A')}%")
        print(f"  Connected components: {stats.get('connected_components', 'N/A')}")
        print(f"  Largest component: {stats.get('largest_component_size', 'N/A')} voxels")
else:
    print(f"⚠️  Sample image not found: {image_path}")
    print("Please replace with your actual NIfTI file path")

    # Fall back to a dummy result so a real result is never overwritten
    result = {
        'model_used': '3D',
        'processing_time': 45.6,
        'output_path': 'output/segmentations/patient001_roi.nii.gz',
        'statistics': {
            'roi_volume': 15420,
            'roi_percentage': 2.8,
            'connected_components': 3,
            'largest_component_size': 12350
        }
    }
    print("\n🎭 Using dummy result for demonstration:")
    for key, value in result.items():
        if isinstance(value, dict):
            print(f"  {key}:")
            for k, v in value.items():
                print(f"    {k}: {v}")
        else:
            print(f"  {key}: {value}")
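
To see what `confidence_threshold` controls, it helps to look at thresholding on its own. The sketch below assumes the model produces a per-voxel probability map and that voxels at or above the threshold are kept. Both are assumptions about `onem_segment`, and `threshold_probabilities` and `mask_statistics` are illustrative helpers rather than library functions.

```python
import numpy as np

def threshold_probabilities(prob_map, confidence_threshold=0.5):
    """Binarise a probability map: voxels at or above the threshold become ROI."""
    return prob_map >= confidence_threshold

def mask_statistics(mask):
    """Count ROI voxels and express them as a percentage of the volume."""
    roi_volume = int(mask.sum())
    return {
        "roi_volume": roi_volume,
        "roi_percentage": 100.0 * roi_volume / mask.size,
    }

# Toy 2x2x2 probability map standing in for a model output
prob_map = np.array([[[0.1, 0.6], [0.4, 0.9]],
                     [[0.5, 0.2], [0.8, 0.3]]])

strict = mask_statistics(threshold_probabilities(prob_map, 0.7))
default = mask_statistics(threshold_probabilities(prob_map, 0.5))
print("threshold 0.7:", strict)
print("threshold 0.5:", default)
```

Lowering the threshold grows the ROI and raising it shrinks the ROI, which is why the threshold is the first thing to adjust when a structure is under- or over-segmented.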

## 🔄 Batch Segmentation {#batch}

In [None]:
# Batch processing setup
image_dir = "sample_data/images/"
output_dir = "output/batch_segmentations/"

# Create output directory
os.makedirs(output_dir, exist_ok=True)

# Check if directory exists
if os.path.exists(image_dir):
    print(f"📁 Processing images from: {image_dir}")
    print(f"📁 Output directory: {output_dir}")
    
    # Perform batch segmentation
    print("🚀 Starting batch segmentation...")
    results = segmenter.segment_batch(
        image_dir=image_dir,
        output_dir=output_dir,
        model_type='auto',
        config_name='ct_organ',
        parallel=True,  # Enable parallel processing
        n_workers=4
    )
    
    print(f"✅ Batch segmentation completed!")
    print(f"📊 Processed {len(results)} images")
    
    # Summary statistics
    processing_times = [r['processing_time'] for r in results if 'processing_time' in r]
    models_used = [r['model_used'] for r in results if 'model_used' in r]
    
    print(f"\n📈 Batch Processing Summary:")
    print(f"  Total processing time: {sum(processing_times):.2f}s")
    print(f"  Average time per image: {np.mean(processing_times):.2f}s")
    print(f"  2D models used: {models_used.count('2D')}")
    print(f"  3D models used: {models_used.count('3D')}")
    
    # Display individual results
    print("\n👀 Individual Results:")
    # Loop variable is `res` so the single-image `result` is not overwritten
    for i, res in enumerate(results[:5]):  # Show first 5
        name = os.path.basename(res.get('image_path', 'unknown'))
        print(f"  {i+1}. {name}:")
        print(f"     Model: {res.get('model_used', 'unknown')}")
        print(f"     Time: {res.get('processing_time', 'unknown')}s")
    
    if len(results) > 5:
        print(f"  ... and {len(results) - 5} more results")
else:
    print(f"⚠️  Sample directory not found: {image_dir}")
    print("Please replace with your actual image directory path")
    
# Create dummy batch results for demonstration
dummy_results = [
    {'image_path': 'patient001.nii.gz', 'model_used': '3D', 'processing_time': 45.6},
    {'image_path': 'patient002.nii.gz', 'model_used': '2D', 'processing_time': 23.4},
    {'image_path': 'patient003.nii.gz', 'model_used': '3D', 'processing_time': 67.8},
    {'image_path': 'patient004.nii.gz', 'model_used': '2D', 'processing_time': 18.9},
    {'image_path': 'patient005.nii.gz', 'model_used': '3D', 'processing_time': 52.1}
]

print("\n🎭 Using dummy batch results for demonstration:")
processing_times = [r['processing_time'] for r in dummy_results]
models_used = [r['model_used'] for r in dummy_results]

print(f"  Total processed: {len(dummy_results)} images")
print(f"  Total time: {sum(processing_times):.2f}s")
print(f"  Average time: {np.mean(processing_times):.2f}s")
print(f"  2D models: {models_used.count('2D')}, 3D models: {models_used.count('3D')}")
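
The batch summary reports a single average time, but the per-model averages are usually more informative because 2D and 3D runs differ a lot in cost. Here is a minimal sketch that groups the dummy results by model. `mean_time_by_model` is an illustrative helper, and skipping entries that lack either field is an assumption about how failed runs are reported.

```python
from collections import defaultdict
from statistics import mean

def mean_time_by_model(results):
    """Group processing times by the model that produced them and average each group."""
    times = defaultdict(list)
    for r in results:
        # Skip entries (e.g. failed runs) that lack either field
        if "model_used" in r and "processing_time" in r:
            times[r["model_used"]].append(r["processing_time"])
    return {model: mean(values) for model, values in times.items()}

dummy_results = [
    {"image_path": "patient001.nii.gz", "model_used": "3D", "processing_time": 45.6},
    {"image_path": "patient002.nii.gz", "model_used": "2D", "processing_time": 23.4},
    {"image_path": "patient003.nii.gz", "model_used": "3D", "processing_time": 67.8},
    {"image_path": "patient004.nii.gz", "model_used": "2D", "processing_time": 18.9},
    {"image_path": "patient005.nii.gz", "model_used": "3D", "processing_time": 52.1},
]

summary = mean_time_by_model(dummy_results)
for model, avg in sorted(summary.items()):
    print(f"  {model}: {avg:.2f}s on average")
```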

## ⚖️ Model Comparison {#comparison}

In [None]:
# Compare different models on the same image
image_path = "sample_data/patient001_ct.nii.gz"
output_base = "output/model_comparison/"

# Create output directory
os.makedirs(output_base, exist_ok=True)

if os.path.exists(image_path):
    print(f"🔍 Comparing models on: {image_path}")
    
    # Test different model types
    model_types = ['2d', '3d', 'auto']
    comparison_results = {}
    
    for model_type in model_types:
        print(f"\n🚀 Testing {model_type.upper()} model...")
        
        output_path = os.path.join(output_base, f"patient001_{model_type}_roi.nii.gz")
        
        try:
            # Use a separate name so the single-image `result` is not overwritten
            cmp_result = segmenter.segment_image(
                image_path=image_path,
                model_type=model_type,
                output_path=output_path,
                config_name='ct_organ'
            )
            
            comparison_results[model_type] = cmp_result
            print(f"  ✅ Completed in {cmp_result['processing_time']:.2f}s")
            print(f"  📊 Model used: {cmp_result['model_used']}")
            
            if 'statistics' in cmp_result:
                stats = cmp_result['statistics']
                print(f"  📈 ROI volume: {stats.get('roi_volume', 'N/A')} voxels")
                
        except Exception as e:
            print(f"  ❌ Error: {e}")
            comparison_results[model_type] = {'error': str(e)}
    
    # Create comparison table
    print("\n📊 Model Comparison Summary:")
    comparison_data = []
    
    # Loop variable is `res` so the single-image `result` is not overwritten
    for model_type, res in comparison_results.items():
        if 'error' not in res:
            comparison_data.append({
                'Model Type': model_type.upper(),
                'Actual Model': res.get('model_used', 'N/A'),
                'Processing Time (s)': res.get('processing_time', 0),
                'ROI Volume': res.get('statistics', {}).get('roi_volume', 0)
            })
        else:
            comparison_data.append({
                'Model Type': model_type.upper(),
                'Actual Model': 'Error',
                'Processing Time (s)': 0,
                'ROI Volume': 0
            })
    
    comparison_df = pd.DataFrame(comparison_data)
    display(comparison_df)
else:
    print(f"⚠️  Sample image not found: {image_path}")
    
# Create dummy comparison for demonstration
dummy_comparison = pd.DataFrame({
    'Model Type': ['2D', '3D', 'AUTO'],
    'Actual Model': ['2D', '3D', '3D'],
    'Processing Time (s)': [23.4, 67.8, 45.6],
    'ROI Volume': [15420, 16234, 15987]
})

print("\n🎭 Using dummy comparison for demonstration:")
display(dummy_comparison)
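
Processing time and ROI volume show how the models differ in cost and size, but not whether their masks agree voxel by voxel. One common way to check that is the Dice coefficient. The sketch below uses toy masks, and `dice_coefficient` is an illustrative helper rather than part of `onem_segment`.

```python
import numpy as np

def dice_coefficient(mask_a, mask_b):
    """Dice overlap between two binary masks: 2|A∩B| / (|A| + |B|)."""
    a = np.asarray(mask_a, dtype=bool)
    b = np.asarray(mask_b, dtype=bool)
    total = a.sum() + b.sum()
    if total == 0:
        return 1.0  # Both masks empty: treat as perfect agreement
    return float(2.0 * np.logical_and(a, b).sum() / total)

# Toy masks standing in for the 2D and 3D model outputs
mask_2d = np.zeros((4, 4, 4), dtype=bool)
mask_2d[1:3, 1:3, 1:3] = True   # 8 voxels
mask_3d = np.zeros((4, 4, 4), dtype=bool)
mask_3d[1:3, 1:3, 1:4] = True   # 12 voxels, containing the 8 above

print(f"Dice (2D vs 3D): {dice_coefficient(mask_2d, mask_3d):.3f}")
```

Two masks with similar volumes can still have a low Dice score if they sit in different places, so this complements the volume column in the comparison table.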

## 🔧 Post-processing and Refinement {#postprocessing}

In [None]:
# Demonstrate post-processing options
if os.path.exists(image_path) and 'result' in locals():
    print("🔧 Demonstrating post-processing options...")
    
    # 1. Connected component analysis
    print("\n🔍 Connected Component Analysis:")
    cca_result = segmenter.apply_connected_component_analysis(
        result['output_path'],
        min_size=100,  # Minimum component size
        keep_largest=True
    )
    print(f"  Original components: {cca_result['original_components']}")
    print(f"  Filtered components: {cca_result['filtered_components']}")
    print(f"  Largest component size: {cca_result['largest_component_size']}")
    
    # 2. Morphological operations
    print("\n🔨 Morphological Operations:")
    morph_result = segmenter.apply_morphological_operations(
        result['output_path'],
        operation='closing',  # 'opening', 'closing', 'erosion', 'dilation'
        kernel_size=3,
        iterations=2
    )
    print(f"  Operation: {morph_result['operation']}")
    print(f"  Kernel size: {morph_result['kernel_size']}")
    print(f"  Volume change: {morph_result['volume_change']:.2f}%")
    
    # 3. Boundary refinement
    print("\n🎯 Boundary Refinement:")
    boundary_result = segmenter.refine_boundaries(
        result['output_path'],
        image_path=image_path,
        method='active_contour',  # 'active_contour', 'graph_cut'
        iterations=50
    )
    print(f"  Method: {boundary_result['method']}")
    print(f"  Iterations: {boundary_result['iterations']}")
    print(f"  Convergence: {boundary_result['convergence']:.4f}")
    print(f"  Boundary change: {boundary_result['boundary_change']:.2f}%")
else:
    print("⚠️  No segmentation result available for post-processing")
    
# Create dummy post-processing results for demonstration
dummy_cca = {
    'original_components': 5,
    'filtered_components': 2,
    'largest_component_size': 12350
}

dummy_morph = {
    'operation': 'closing',
    'kernel_size': 3,
    'volume_change': -2.3
}

dummy_boundary = {
    'method': 'active_contour',
    'iterations': 50,
    'convergence': 0.0001,
    'boundary_change': 5.7
}

print("\n🎭 Using dummy post-processing results for demonstration:")
print("\n🔍 Connected Component Analysis:")
for key, value in dummy_cca.items():
    print(f"  {key}: {value}")

print("\n🔨 Morphological Operations:")
for key, value in dummy_morph.items():
    print(f"  {key}: {value}")

print("\n🎯 Boundary Refinement:")
for key, value in dummy_boundary.items():
    print(f"  {key}: {value}")
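
To make the `min_size` filter concrete, here is a small, self-contained sketch of connected component filtering on a binary mask. It uses a plain breadth-first flood fill with face connectivity, and `filter_components` is an illustrative helper. How `onem_segment` implements this step internally is not shown in this notebook. For real volumes a library routine such as `scipy.ndimage.label` would be much faster.

```python
from collections import deque
import numpy as np

def filter_components(mask, min_size=100):
    """Drop face-connected components smaller than min_size voxels.

    Returns the cleaned mask and the sizes of all components found.
    """
    mask = np.asarray(mask, dtype=bool)
    visited = np.zeros_like(mask)
    cleaned = np.zeros_like(mask)
    sizes = []

    # Face-connected neighbours: one step along a single axis
    steps = []
    for axis in range(mask.ndim):
        for delta in (-1, 1):
            step = [0] * mask.ndim
            step[axis] = delta
            steps.append(tuple(step))

    for seed in zip(*np.nonzero(mask)):
        if visited[seed]:
            continue
        # Breadth-first flood fill from this unvisited foreground voxel
        component = [seed]
        visited[seed] = True
        queue = deque([seed])
        while queue:
            voxel = queue.popleft()
            for step in steps:
                nb = tuple(v + s for v, s in zip(voxel, step))
                in_bounds = all(0 <= c < n for c, n in zip(nb, mask.shape))
                if in_bounds and mask[nb] and not visited[nb]:
                    visited[nb] = True
                    component.append(nb)
                    queue.append(nb)
        sizes.append(len(component))
        if len(component) >= min_size:
            for voxel in component:
                cleaned[voxel] = True
    return cleaned, sizes

# Toy 2D mask: one 4-pixel blob and one isolated pixel
mask = np.zeros((5, 5), dtype=bool)
mask[0:2, 0:2] = True
mask[4, 4] = True

cleaned, sizes = filter_components(mask, min_size=2)
print("Component sizes:", sorted(sizes))
print("Voxels kept:", int(cleaned.sum()))
```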

## 📊 Results Visualization {#visualization}

In [None]:
# Create visualizations of segmentation results
if 'results' in locals() or 'dummy_results' in locals():
    # Use actual or dummy results
    batch_results = results if 'results' in locals() else dummy_results
    
    # 1. Processing time comparison
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle('ROI Segmentation Analysis', fontsize=16, fontweight='bold')
    
    # Processing times
    processing_times = [r.get('processing_time', 0) for r in batch_results]
    image_names = [os.path.basename(r.get('image_path', f'Image {i+1}')) 
                   for i, r in enumerate(batch_results)]
    
    axes[0, 0].bar(range(len(image_names)), processing_times, color='skyblue', alpha=0.7)
    axes[0, 0].set_title('Processing Time per Image')
    axes[0, 0].set_xlabel('Image')
    axes[0, 0].set_ylabel('Time (seconds)')
    axes[0, 0].set_xticks(range(len(image_names)))
    axes[0, 0].set_xticklabels(image_names, rotation=45, ha='right')
    axes[0, 0].grid(True, alpha=0.3)
    
    # Model type distribution
    models_used = [r.get('model_used', 'Unknown') for r in batch_results]
    model_counts = pd.Series(models_used).value_counts()
    
    axes[0, 1].pie(model_counts.values, labels=model_counts.index, autopct='%1.1f%%', 
                   colors=['lightcoral', 'lightblue'])
    axes[0, 1].set_title('Model Type Distribution')
    
    # Processing time statistics
    if processing_times:
        time_stats = {
            'Min': min(processing_times),
            'Max': max(processing_times),
            'Mean': np.mean(processing_times),
            'Median': np.median(processing_times)
        }
        
        # Convert dict views to lists so bar() receives plain sequences
        axes[1, 0].bar(list(time_stats.keys()), list(time_stats.values()), 
                       color=['green', 'red', 'blue', 'orange'], alpha=0.7)
        axes[1, 0].set_title('Processing Time Statistics')
        axes[1, 0].set_ylabel('Time (seconds)')
        axes[1, 0].grid(True, alpha=0.3)
    
    # ROI volumes (if available)
    roi_volumes = []
    for r in batch_results:
        if 'statistics' in r and 'roi_volume' in r['statistics']:
            roi_volumes.append(r['statistics']['roi_volume'])
    
    if roi_volumes:
        axes[1, 1].hist(roi_volumes, bins=10, alpha=0.7, color='gold', edgecolor='black')
        axes[1, 1].set_title('ROI Volume Distribution')
        axes[1, 1].set_xlabel('ROI Volume (voxels)')
        axes[1, 1].set_ylabel('Frequency')
        axes[1, 1].grid(True, alpha=0.3)
    else:
        axes[1, 1].text(0.5, 0.5, 'ROI Volume Data\nNot Available', 
                       ha='center', va='center', transform=axes[1, 1].transAxes)
        axes[1, 1].set_title('ROI Volume Distribution')
    
    plt.tight_layout()
    plt.show()
    
    # 2. Model comparison heatmap
    if 'comparison_df' in locals() or 'dummy_comparison' in locals():
        comp_data = comparison_df if 'comparison_df' in locals() else dummy_comparison
        
        plt.figure(figsize=(10, 6))
        
        # Create comparison metrics
        metrics = ['Processing Time (s)', 'ROI Volume']
        comparison_matrix = comp_data[metrics].values
        
        # Normalize for better visualization
        normalized_matrix = comparison_matrix / comparison_matrix.max(axis=0)
        
        sns.heatmap(normalized_matrix.T, 
                   annot=comp_data[metrics].T, 
                   xticklabels=comp_data['Model Type'],
                   yticklabels=metrics,
                   cmap='YlOrRd', 
                   fmt='.1f',
                   cbar_kws={'label': 'Normalized Value'})
        
        plt.title('Model Performance Comparison', fontsize=14, fontweight='bold')
        plt.tight_layout()
        plt.show()
else:
    print("⚠️  No results available for visualization")
    
# Create sample visualizations with dummy data
print("📊 Creating sample visualizations...")

# Sample data
sample_times = [23.4, 67.8, 45.6, 18.9, 52.1]
sample_images = ['Patient 001', 'Patient 002', 'Patient 003', 'Patient 004', 'Patient 005']

fig, axes = plt.subplots(2, 2, figsize=(15, 10))
fig.suptitle('ROI Segmentation Sample Analysis', fontsize=16, fontweight='bold')

# Processing time bar chart
axes[0, 0].bar(sample_images, sample_times, color='skyblue', alpha=0.7)
axes[0, 0].set_title('Processing Time per Patient')
axes[0, 0].set_ylabel('Time (seconds)')
axes[0, 0].tick_params(axis='x', rotation=45)
axes[0, 0].grid(True, alpha=0.3)

# Time distribution histogram
axes[0, 1].hist(sample_times, bins=5, alpha=0.7, color='lightgreen', edgecolor='black')
axes[0, 1].set_title('Processing Time Distribution')
axes[0, 1].set_xlabel('Time (seconds)')
axes[0, 1].set_ylabel('Frequency')
axes[0, 1].grid(True, alpha=0.3)

# Time statistics
time_stats = [min(sample_times), max(sample_times), np.mean(sample_times), np.median(sample_times)]
stat_labels = ['Min', 'Max', 'Mean', 'Median']
stat_colors = ['green', 'red', 'blue', 'orange']

axes[1, 0].bar(stat_labels, time_stats, color=stat_colors, alpha=0.7)
axes[1, 0].set_title('Processing Time Statistics')
axes[1, 0].set_ylabel('Time (seconds)')
axes[1, 0].grid(True, alpha=0.3)

# Model performance comparison
models = ['2D', '3D', 'Auto']
avg_times = [21.15, 59.8, 40.35]  # Average times

axes[1, 1].bar(models, avg_times, color=['lightcoral', 'lightblue', 'gold'], alpha=0.7)
axes[1, 1].set_title('Average Time by Model Type')
axes[1, 1].set_ylabel('Average Time (seconds)')
axes[1, 1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 🎯 Summary and Best Practices

### Key Takeaways:
1. **Auto Model Selection**: The 'auto' mode chooses between 2D and 3D models based on slice count, slice thickness, and content variation
2. **Processing Time**: In the dummy timings above, 2D models are roughly 3x faster, but 3D models may be more accurate for volumes with many thin slices
3. **Post-processing**: Essential for cleaning up segmentation results
4. **Batch Processing**: Use parallel processing for multiple images
5. **Quality Control**: Always visualize and validate segmentation results
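
As a rough illustration of the parallel batch processing mentioned above, the sketch below fans a list of images out to a small worker pool using the standard library. `segment_one` is a hypothetical stand-in for a per-image segmentation call and does no real work here. Whether `segment_batch` uses threads or processes internally is not stated in this notebook.

```python
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

def segment_one(image_path):
    """Hypothetical stand-in for a per-image segmentation call."""
    return {"image_path": str(image_path), "model_used": "3D"}

def segment_many(image_paths, n_workers=4):
    """Run segment_one over many images; results keep the input order."""
    with ThreadPoolExecutor(max_workers=n_workers) as pool:
        return list(pool.map(segment_one, image_paths))

paths = [Path(f"patient{i:03d}.nii.gz") for i in range(1, 4)]
batch = segment_many(paths, n_workers=2)
print([r["image_path"] for r in batch])
```

For CPU-bound work written in pure Python, `ProcessPoolExecutor` offers the same interface and avoids contention on the interpreter lock, at the cost of higher memory use per worker.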

### Decision Criteria for Model Selection:
- **Slice Count**: ≥30 slices favors 3D
- **Slice Thickness**: ≤5mm favors 3D
- **Content Variation**: ≥10% variation favors 3D
- **Confidence**: Higher confidence means more reliable auto-selection
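
The three thresholds above can be combined into a simple rule. The sketch below counts how many criteria favor 3D and picks 3D when at least two are met. The two-of-three vote is an assumption made for illustration, since this notebook does not state how `onem_segment` combines the criteria or derives its confidence score. `count_criteria` and `recommend_mode` are hypothetical helpers.

```python
def count_criteria(slice_count, slice_thickness_mm, content_variation):
    """Evaluate each 3D-favoring criterion using the thresholds listed above."""
    return {
        "slice_count": slice_count >= 30,
        "slice_thickness": slice_thickness_mm <= 5.0,
        "content_variation": content_variation >= 0.10,
    }

def recommend_mode(slice_count, slice_thickness_mm, content_variation, min_criteria=2):
    """Recommend '3D' when at least min_criteria checks pass (assumed voting rule)."""
    checks = count_criteria(slice_count, slice_thickness_mm, content_variation)
    met = sum(checks.values())
    mode = "3D" if met >= min_criteria else "2D"
    return mode, met

# Dummy analysis values from earlier: 120 slices, 3.0 mm thick, 0.15 variation
mode, met = recommend_mode(120, 3.0, 0.15)
print(f"Recommended mode: {mode} ({met}/3 criteria met)")

# A short stack of thick slices
mode_thick, met_thick = recommend_mode(12, 8.0, 0.15)
print(f"Recommended mode: {mode_thick} ({met_thick}/3 criteria met)")
```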

### Common Issues and Solutions:
- ⚠️ **Memory errors** with large 3D images → Use 2D mode or increase memory
- ⚠️ **Over-segmentation** → Apply connected component analysis
- ⚠️ **Under-segmentation** → Try different confidence thresholds
- ⚠️ **Noisy boundaries** → Apply morphological operations

### Next Steps:
- 🔗 Combine segmentation with radiomics extraction
- 📊 Use segmentation masks for feature extraction
- 🧪 Validate segmentation with expert annotations
- 🎯 Fine-tune models for specific anatomical structures