In [1]:
import warnings
warnings.filterwarnings('ignore')

import json
from datetime import datetime
from pathlib import Path

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
from skimage import io, measure, filters
from cellpose import models

In [None]:
class NucleiSegmentationPipeline:
    """
    Automated nuclei segmentation and phenotyping pipeline (low VRAM & CPU fallback).
    Optimized for batch dataset processing.
    """

    def __init__(self, output_dir='results'):
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(exist_ok=True)
        (self.output_dir / 'plots').mkdir(exist_ok=True)

        self.results = []
        self.all_features = []  # To accumulate feature DataFrames across images

        # Use GPU if available
        self.model = models.CellposeModel(gpu=True)
        print("✅ CellPose model loaded (GPU mode)")

    # ------------------------------------------------------------------------
    # Image preprocessing
    # ------------------------------------------------------------------------
    def preprocess_image(self, image):
        """Normalize and enhance image quality for CellPose"""
        # Handle RGBA or RGB input
        if len(image.shape) == 3:
            # Drop alpha if present
            if image.shape[2] == 4:
                image = image[:, :, :3]
            # Convert to grayscale by averaging
            image = np.mean(image, axis=2)
    
        image = image.astype(np.float32)
        image = (image - image.min()) / (image.max() - image.min() + 1e-8)
        image = filters.gaussian(image, sigma=1.0)
        return image


    # ------------------------------------------------------------------------
    # Segmentation
    # ------------------------------------------------------------------------
    def segment_cellpose(self, image):
        """Run CellPose segmentation (compatible with v3.x and v4.x)"""
        image_uint8 = (image * 255).astype(np.uint8)
        
        try:
            outputs = self.model.eval(
                image_uint8,
                diameter=30,
                channels=[0, 0]
            )
            
            # Handle both 3- and 4-return signatures
            if len(outputs) == 4:
                masks, flows, styles, diams = outputs
            elif len(outputs) == 3:
                masks, flows, styles = outputs
            else:
                raise ValueError(f"Unexpected number of return values from CellPose: {len(outputs)}")
            
            return masks
        
        except Exception as e:
            print(f"❌ Error in CellPose segmentation: {e}")
            return np.zeros_like(image_uint8)


    # ------------------------------------------------------------------------
    # Feature extraction
    # ------------------------------------------------------------------------
    def extract_features(self, image, labels, image_name):
        """Extract per-nucleus features and tag with image name"""
        props = measure.regionprops_table(
            labels,
            intensity_image=image,
            properties=[
                'label', 'area', 'perimeter', 'eccentricity',
                'solidity', 'intensity_mean', 'intensity_max',
                'centroid', 'major_axis_length', 'minor_axis_length'
            ]
        )
        df = pd.DataFrame(props)
        df['circularity'] = (4 * np.pi * df['area']) / (df['perimeter'] ** 2 + 1e-8)
        df['aspect_ratio'] = df['major_axis_length'] / (df['minor_axis_length'] + 1e-8)
        df['image_name'] = image_name
        return df

    # ------------------------------------------------------------------------
    # Population metrics
    # ------------------------------------------------------------------------
    def calculate_metrics(self, labels):
        regions = measure.regionprops(labels)
        if not regions:
            return {'num_nuclei': 0, 'mean_area': 0, 'std_area': 0, 'density': 0}

        num_nuclei = len(regions)
        areas = [r.area for r in regions]
        metrics = {
            'num_nuclei': num_nuclei,
            'mean_area': np.mean(areas),
            'std_area': np.std(areas),
            'density': num_nuclei / (labels.shape[0] * labels.shape[1]) * 1e6  # per mm²
        }
        return metrics

    # ------------------------------------------------------------------------
    # Visualization
    # ------------------------------------------------------------------------
    def visualize_results(self, image, labels, features_df, save_path):
        fig, axes = plt.subplots(2, 3, figsize=(15, 10))

        # Original
        axes[0, 0].imshow(image, cmap='gray')
        axes[0, 0].set_title('Original')
        axes[0, 0].axis('off')

        # Segmentation overlay
        axes[0, 1].imshow(image, cmap='gray')
        axes[0, 1].imshow(labels, cmap='tab20', alpha=0.5)
        axes[0, 1].set_title(f'Segmentation ({len(features_df)} nuclei)')
        axes[0, 1].axis('off')

        # Labeled
        axes[0, 2].imshow(labels, cmap='nipy_spectral')
        axes[0, 2].set_title('Labeled Regions')
        axes[0, 2].axis('off')

        # Area distribution
        axes[1, 0].hist(features_df['area'], bins=30, edgecolor='black')
        axes[1, 0].set_title('Area Distribution')
        axes[1, 0].set_xlabel('Area (px²)')

        # Circularity distribution
        axes[1, 1].hist(features_df['circularity'], bins=30, edgecolor='black')
        axes[1, 1].set_title('Circularity Distribution')

        # Scatter
        axes[1, 2].scatter(features_df['area'], features_df['intensity_mean'], alpha=0.6, s=15)
        axes[1, 2].set_xlabel('Area')
        axes[1, 2].set_ylabel('Mean Intensity')
        axes[1, 2].set_title('Area vs Intensity')

        plt.tight_layout()
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        plt.close()

    # ------------------------------------------------------------------------
    # Image processing entry point
    # ------------------------------------------------------------------------
    def process_image(self, image_path, save_results=True):
        image = io.imread(image_path)
        image = self.preprocess_image(image)
        labels = self.segment_cellpose(image)

        image_name = Path(image_path).stem
        features = self.extract_features(image, labels, image_name)
        metrics = self.calculate_metrics(labels)

        if save_results:
            plot_path = self.output_dir / 'plots' / f'{image_name}_analysis.png'
            mask_path = self.output_dir / 'plots' / f'{image_name}_mask.tif'

            self.visualize_results(image, labels, features, plot_path)
            io.imsave(mask_path, labels.astype(np.uint16))

        # Accumulate results
        self.results.append({
            'image': image_name,
            **metrics
        })
        self.all_features.append(features)

    # ------------------------------------------------------------------------
    # Dataset processing (recursive search)
    # ------------------------------------------------------------------------
    def process_dataset(self, image_dir, pattern='*.tif'):
        """Recursively scan directories for images and process all."""
        image_paths = list(Path(image_dir).rglob(pattern))
        if not image_paths:
            print(f"No images found matching {pattern} in {image_dir}")
            return

        print(f"📁 Found {len(image_paths)} images under {image_dir}")
        cnt = 0

        for img_path in image_paths:
            try:
                self.process_image(str(img_path))
                cnt += 1
                print(f"{cnt} out of {len(image_paths)} is successfully loaded.", end='\r', flush=True)
            except Exception as e:
                print(f"❌ Error processing {img_path}: {e}")

        # Save aggregated results
        print()
        self.save_combined_results()

    # ------------------------------------------------------------------------
    # Aggregated results saving
    # ------------------------------------------------------------------------
def save_combined_results(self, drop_exact_duplicates=False):
    # Columns you want to keep
    columns = ['image_name','area','perimeter','eccentricity','solidity',
               'intensity_mean','intensity_max','centroid-0','centroid-1',
               'major_axis_length','minor_axis_length','circularity','aspect_ratio']
    
    # Concatenate all DataFrames and keep only the desired columns
    df = pd.concat([df_image[columns] for df_image in pipeline.all_features], ignore_index=True)
    
    # Save to CSV
    df.to_csv('features.csv', index=False)
    
    print("CSV saved as features.csv")

    print(f"✅ Created new combined CSV at: {out_path} ({len(new_df)} rows)")

    # Save summary and dataset plots even if combined_df is empty
    summary = {
        'total_images': len(self.results),
        'total_nuclei': sum(r['num_nuclei'] for r in self.results),
        'mean_nuclei_per_image': float(np.mean([r['num_nuclei'] for r in self.results])) if self.results else 0.0,
        'mean_nucleus_area': float(np.mean([r['mean_area'] for r in self.results])) if self.results else 0.0,
        'processing_date': datetime.now().isoformat()
    }
    with open(self.output_dir / 'summary_report.json', 'w') as f:
        json.dump(summary, f, indent=2)

    self.plot_dataset_summary()
    print("🧾 Summary report saved.")

    return 

    # ------------------------------------------------------------------------
    # Dataset-level summary visualization
    # ------------------------------------------------------------------------
    def plot_dataset_summary(self):
        fig, axes = plt.subplots(1, 3, figsize=(15, 4))

        counts = [r['num_nuclei'] for r in self.results]
        axes[0].bar(range(len(counts)), counts)
        axes[0].set_title('Nuclei per Image')
        axes[0].set_xlabel('Image Index')
        axes[0].set_ylabel('Count')

        areas = [r['mean_area'] for r in self.results]
        axes[1].hist(areas, bins=20, edgecolor='black')
        axes[1].set_title('Mean Nucleus Area Distribution')
        axes[1].set_xlabel('Area')

        densities = [r['density'] for r in self.results]
        axes[2].plot(densities, marker='o')
        axes[2].set_title('Cell Density Variation')
        axes[2].set_xlabel('Image Index')
        axes[2].set_ylabel('Density (nuclei/mm²)')

        plt.tight_layout()
        plt.savefig(self.output_dir / 'dataset_summary.png', dpi=150)
        plt.close()


In [None]:
# Example: Process single image
pipeline = NucleiSegmentationPipeline()

# For BBBC038 dataset:
pipeline.process_dataset('/home/amon/Cell-Segmentation-Morphology/dataset', pattern='*/images/*.png')

In [29]:
import pandas as pd

# Columns you want to keep
columns = ['image_name','area','perimeter','eccentricity','solidity',
           'intensity_mean','intensity_max','centroid-0','centroid-1',
           'major_axis_length','minor_axis_length','circularity','aspect_ratio']

# Concatenate all DataFrames and keep only the desired columns
df = pd.concat([df_image[columns] for df_image in pipeline.all_features], ignore_index=True)

# Save to CSV
df.to_csv('features.csv', index=False)

print("CSV saved as features.csv")


CSV saved as features.csv


In [30]:
pipeline.all_features[0].head()

Unnamed: 0,label,area,perimeter,eccentricity,solidity,intensity_mean,intensity_max,centroid-0,centroid-1,major_axis_length,minor_axis_length,circularity,aspect_ratio,image_name
0,1,110.0,41.142136,0.856954,0.916667,0.311882,0.443204,3.263636,205.181818,16.662239,8.587606,0.816637,1.940266,fc5452f612a0f972fe55cc677055ede662af6723b5c161...
1,2,187.0,48.627417,0.608791,0.973958,0.202061,0.29121,8.427807,153.438503,17.333885,13.7515,0.993777,1.260509,fc5452f612a0f972fe55cc677055ede662af6723b5c161...
2,3,268.0,59.698485,0.447635,0.950355,0.291927,0.426615,15.145522,208.884328,19.609088,17.534767,0.94497,1.118298,fc5452f612a0f972fe55cc677055ede662af6723b5c161...
3,4,183.0,48.384776,0.630578,0.973404,0.337899,0.618392,14.825137,249.491803,17.436687,13.533063,0.982298,1.288451,fc5452f612a0f972fe55cc677055ede662af6723b5c161...
4,5,303.0,63.112698,0.436021,0.961905,0.325888,0.528196,23.745875,116.792079,20.71477,18.641974,0.955914,1.11119,fc5452f612a0f972fe55cc677055ede662af6723b5c161...


## Model Evaluation

In [5]:
class SegmentationEvaluator:
    """
    Evaluate segmentation performance against ground truth masks
    """
    
    def __init__(self, output_dir='results/evaluation'):
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.evaluation_results = []
    
    def calculate_iou(self, pred_mask, true_mask):
        """
        Calculate Intersection over Union (IoU) for binary masks
        
        Args:
            pred_mask: Predicted segmentation mask
            true_mask: Ground truth mask
            
        Returns:
            IoU score (0-1)
        """
        # Binarize masks
        pred_binary = (pred_mask > 0).astype(np.uint8)
        true_binary = (true_mask > 0).astype(np.uint8)
        
        # Calculate intersection and union
        intersection = np.logical_and(pred_binary, true_binary).sum()
        union = np.logical_or(pred_binary, true_binary).sum()
        
        if union == 0:
            return 0.0
        
        iou = intersection / union
        return float(iou)
    
    def calculate_dice(self, pred_mask, true_mask):
        """
        Calculate Dice coefficient (F1 score for segmentation)
        """
        pred_binary = (pred_mask > 0).astype(np.uint8)
        true_binary = (true_mask > 0).astype(np.uint8)
        
        intersection = np.logical_and(pred_binary, true_binary).sum()
        
        if pred_binary.sum() + true_binary.sum() == 0:
            return 0.0
        
        dice = 2 * intersection / (pred_binary.sum() + true_binary.sum())
        return float(dice)
    
    def calculate_pixel_accuracy(self, pred_mask, true_mask):
        """Calculate pixel-wise accuracy"""
        pred_binary = (pred_mask > 0).astype(np.uint8)
        true_binary = (true_mask > 0).astype(np.uint8)
        
        correct = (pred_binary == true_binary).sum()
        total = pred_binary.size
        
        return float(correct / total)
    
    def calculate_object_level_metrics(self, pred_mask, true_mask):
        """
        Calculate object-level detection metrics
        (True Positives, False Positives, False Negatives at nucleus level)
        """
        from skimage import measure
        
        # Get connected components
        pred_labels = measure.label(pred_mask > 0)
        true_labels = measure.label(true_mask > 0)
        
        n_pred = pred_labels.max()
        n_true = true_labels.max()
        
        # Simple matching: a predicted object is TP if it overlaps with true object
        matched_pred = set()
        matched_true = set()
        
        for i in range(1, n_pred + 1):
            pred_region = (pred_labels == i)
            overlapping_true = np.unique(true_labels[pred_region])
            overlapping_true = overlapping_true[overlapping_true > 0]
            
            if len(overlapping_true) > 0:
                matched_pred.add(i)
                matched_true.update(overlapping_true)
        
        tp = len(matched_pred)
        fp = n_pred - tp
        fn = n_true - len(matched_true)
        
        # Calculate precision, recall, F1
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
        
        return {
            'n_predicted': n_pred,
            'n_true': n_true,
            'true_positives': tp,
            'false_positives': fp,
            'false_negatives': fn,
            'precision': precision,
            'recall': recall,
            'f1_score': f1
        }
    
    def evaluate_single_image(self, pred_mask_path, true_mask_path, image_name=None):
        """
        Evaluate a single prediction against ground truth
        """
        if image_name is None:
            image_name = Path(pred_mask_path).stem
        
        # Load masks
        pred_mask = io.imread(pred_mask_path)
        true_mask = io.imread(true_mask_path)
        
        # Handle RGB masks (take first channel)
        if len(true_mask.shape) == 3:
            true_mask = true_mask[:, :, 0]
        if len(pred_mask.shape) == 3:
            pred_mask = pred_mask[:, :, 0]
        
        # Calculate metrics
        iou = self.calculate_iou(pred_mask, true_mask)
        dice = self.calculate_dice(pred_mask, true_mask)
        pixel_acc = self.calculate_pixel_accuracy(pred_mask, true_mask)
        object_metrics = self.calculate_object_level_metrics(pred_mask, true_mask)
        
        result = {
            'image_name': image_name,
            'iou': iou,
            'dice': dice,
            'pixel_accuracy': pixel_acc,
            **object_metrics
        }
        
        self.evaluation_results.append(result)
        return result
    
    def evaluate_dataset(self, pred_mask_dir, true_mask_dir, pattern='*.png'):
        """
        Evaluate all predictions in a directory against ground truth
        
        Args:
            pred_mask_dir: Directory with predicted masks (e.g., 'results/plots')
            true_mask_dir: Root directory with ground truth (e.g., 'dataset')
        """
        print("🔍 Starting evaluation against ground truth...")
        
        # Find all predicted masks
        pred_mask_paths = list(Path(pred_mask_dir).rglob('*_mask.tif'))
        print(f"Found {len(pred_mask_paths)} predicted masks")
        
        # Match with ground truth
        evaluated = 0
        skipped = 0
        
        for pred_path in pred_mask_paths:
            # Extract image ID (remove '_mask.tif')
            image_id = pred_path.stem.replace('_mask', '')
            
            # Find corresponding ground truth masks
            # BBBC038 structure: dataset/{image_id}/masks/*.png
            true_mask_dir_path = Path(true_mask_dir)
            true_mask_files = list(true_mask_dir_path.glob(f'{image_id}/masks/*.png'))
            
            if not true_mask_files:
                # Try without subdirectory (in case structure is different)
                true_mask_files = list(true_mask_dir_path.glob(f'*/masks/{image_id}*.png'))
            
            if not true_mask_files:
                print(f"⚠️  No ground truth found for {image_id}")
                skipped += 1
                continue
            
            # If multiple masks, combine them (BBBC038 has one mask per nucleus)
            if len(true_mask_files) > 1:
                # Combine all individual masks into one
                combined_mask = None
                for mask_file in true_mask_files:
                    mask = io.imread(mask_file)
                    if len(mask.shape) == 3:
                        mask = mask[:, :, 0]
                    if combined_mask is None:
                        combined_mask = np.zeros_like(mask, dtype=np.uint16)
                    combined_mask = np.maximum(combined_mask, (mask > 0).astype(np.uint16))
                true_mask_path = combined_mask
            else:
                true_mask_path = true_mask_files[0]
            
            # Evaluate
            try:
                if isinstance(true_mask_path, np.ndarray):
                    # Handle combined mask
                    pred_mask = io.imread(pred_path)
                    if len(pred_mask.shape) == 3:
                        pred_mask = pred_mask[:, :, 0]
                    
                    iou = self.calculate_iou(pred_mask, true_mask_path)
                    dice = self.calculate_dice(pred_mask, true_mask_path)
                    pixel_acc = self.calculate_pixel_accuracy(pred_mask, true_mask_path)
                    object_metrics = self.calculate_object_level_metrics(pred_mask, true_mask_path)
                    
                    result = {
                        'image_name': image_id,
                        'iou': iou,
                        'dice': dice,
                        'pixel_accuracy': pixel_acc,
                        **object_metrics
                    }
                    self.evaluation_results.append(result)
                else:
                    self.evaluate_single_image(pred_path, true_mask_path, image_id)
                
                evaluated += 1
                print(f"✓ Evaluated {evaluated}/{len(pred_mask_paths)}", end='\r', flush=True)
            except Exception as e:
                print(f"\n❌ Error evaluating {image_id}: {e}")
                skipped += 1
        
        print(f"\n✅ Evaluation complete: {evaluated} images, {skipped} skipped")
        
        # Generate reports
        self.generate_evaluation_report()
        self.plot_evaluation_results()
        
        return self.evaluation_results
    
    def generate_evaluation_report(self):
        """Generate comprehensive evaluation report"""
        if not self.evaluation_results:
            print("No evaluation results to report")
            return
        
        df = pd.DataFrame(self.evaluation_results)
        
        # Calculate summary statistics
        summary = {
            'total_images_evaluated': len(df),
            'mean_iou': float(df['iou'].mean()),
            'std_iou': float(df['iou'].std()),
            'median_iou': float(df['iou'].median()),
            'min_iou': float(df['iou'].min()),
            'max_iou': float(df['iou'].max()),
            'mean_dice': float(df['dice'].mean()),
            'mean_pixel_accuracy': float(df['pixel_accuracy'].mean()),
            'mean_f1_score': float(df['f1_score'].mean()),
            'mean_precision': float(df['precision'].mean()),
            'mean_recall': float(df['recall'].mean()),
            'total_nuclei_predicted': int(df['n_predicted'].sum()),
            'total_nuclei_ground_truth': int(df['n_true'].sum()),
            'total_true_positives': int(df['true_positives'].sum()),
            'total_false_positives': int(df['false_positives'].sum()),
            'total_false_negatives': int(df['false_negatives'].sum()),
            'evaluation_date': datetime.now().isoformat()
        }
        
        # Save detailed results
        df.to_csv(self.output_dir / 'detailed_evaluation_results.csv', index=False)
        
        # Save summary
        with open(self.output_dir / 'evaluation_summary.json', 'w') as f:
            json.dump(summary, f, indent=2)
        
        # Print summary
        print("\n" + "="*70)
        print("SEGMENTATION EVALUATION SUMMARY")
        print("="*70)
        print(f"Images Evaluated: {summary['total_images_evaluated']}")
        print(f"\nPixel-Level Metrics:")
        print(f"  Mean IoU:            {summary['mean_iou']:.4f} (±{summary['std_iou']:.4f})")
        print(f"  Mean Dice:           {summary['mean_dice']:.4f}")
        print(f"  Mean Pixel Accuracy: {summary['mean_pixel_accuracy']:.4f}")
        print(f"\nObject-Level Metrics:")
        print(f"  Mean Precision:      {summary['mean_precision']:.4f}")
        print(f"  Mean Recall:         {summary['mean_recall']:.4f}")
        print(f"  Mean F1 Score:       {summary['mean_f1_score']:.4f}")
        print(f"\nDetection Counts:")
        print(f"  Ground Truth Nuclei: {summary['total_nuclei_ground_truth']:,}")
        print(f"  Predicted Nuclei:    {summary['total_nuclei_predicted']:,}")
        print(f"  True Positives:      {summary['total_true_positives']:,}")
        print(f"  False Positives:     {summary['total_false_positives']:,}")
        print(f"  False Negatives:     {summary['total_false_negatives']:,}")
        print("="*70)
        
        return summary
    
    def plot_evaluation_results(self):
        """Create comprehensive evaluation visualizations"""
        df = pd.DataFrame(self.evaluation_results)
        
        fig, axes = plt.subplots(2, 3, figsize=(18, 12))
        
        # IoU distribution
        axes[0, 0].hist(df['iou'], bins=30, edgecolor='black', alpha=0.7)
        axes[0, 0].axvline(df['iou'].mean(), color='red', linestyle='--', 
                          label=f'Mean: {df["iou"].mean():.3f}')
        axes[0, 0].set_xlabel('IoU Score')
        axes[0, 0].set_ylabel('Frequency')
        axes[0, 0].set_title('IoU Distribution')
        axes[0, 0].legend()
        axes[0, 0].grid(alpha=0.3)
        
        # Dice distribution
        axes[0, 1].hist(df['dice'], bins=30, edgecolor='black', alpha=0.7, color='green')
        axes[0, 1].axvline(df['dice'].mean(), color='red', linestyle='--',
                          label=f'Mean: {df["dice"].mean():.3f}')
        axes[0, 1].set_xlabel('Dice Coefficient')
        axes[0, 1].set_ylabel('Frequency')
        axes[0, 1].set_title('Dice Coefficient Distribution')
        axes[0, 1].legend()
        axes[0, 1].grid(alpha=0.3)
        
        # Precision vs Recall
        axes[0, 2].scatter(df['recall'], df['precision'], alpha=0.6, s=50)
        axes[0, 2].set_xlabel('Recall')
        axes[0, 2].set_ylabel('Precision')
        axes[0, 2].set_title('Precision vs Recall')
        axes[0, 2].grid(alpha=0.3)
        axes[0, 2].plot([0, 1], [0, 1], 'k--', alpha=0.3)
        
        # F1 Score distribution
        axes[1, 0].hist(df['f1_score'], bins=30, edgecolor='black', alpha=0.7, color='orange')
        axes[1, 0].axvline(df['f1_score'].mean(), color='red', linestyle='--',
                          label=f'Mean: {df["f1_score"].mean():.3f}')
        axes[1, 0].set_xlabel('F1 Score')
        axes[1, 0].set_ylabel('Frequency')
        axes[1, 0].set_title('F1 Score Distribution')
        axes[1, 0].legend()
        axes[1, 0].grid(alpha=0.3)
        
        # Predicted vs Ground Truth nuclei count
        axes[1, 1].scatter(df['n_true'], df['n_predicted'], alpha=0.6, s=50)
        axes[1, 1].set_xlabel('Ground Truth Count')
        axes[1, 1].set_ylabel('Predicted Count')
        axes[1, 1].set_title('Nuclei Count Comparison')
        axes[1, 1].grid(alpha=0.3)
        # Add perfect prediction line
        max_count = max(df['n_true'].max(), df['n_predicted'].max())
        axes[1, 1].plot([0, max_count], [0, max_count], 'r--', alpha=0.5, label='Perfect')
        axes[1, 1].legend()
        
        # Box plots of metrics
        metrics_data = df[['iou', 'dice', 'pixel_accuracy', 'f1_score']]
        axes[1, 2].boxplot([metrics_data[col] for col in metrics_data.columns],
                          labels=['IoU', 'Dice', 'Pixel Acc', 'F1'])
        axes[1, 2].set_ylabel('Score')
        axes[1, 2].set_title('Metrics Overview')
        axes[1, 2].grid(alpha=0.3, axis='y')
        axes[1, 2].set_ylim([0, 1])
        
        plt.tight_layout()
        plt.savefig(self.output_dir / 'evaluation_plots.png', dpi=150, bbox_inches='tight')
        plt.close()
        
        print(f"✅ Evaluation plots saved to {self.output_dir / 'evaluation_plots.png'}")


# ============================================================================
# Integration with your existing pipeline
# ============================================================================

def add_evaluation_to_pipeline(pipeline, dataset_root='dataset'):
    """
    Add evaluation capability to your existing NucleiSegmentationPipeline
    
    Usage:
        pipeline = NucleiSegmentationPipeline()
        pipeline.process_dataset('dataset', pattern='*.png')
        
        # Then evaluate
        evaluator = add_evaluation_to_pipeline(pipeline, dataset_root='dataset')
        results = evaluator.evaluate_dataset(
            pred_mask_dir='results/plots',
            true_mask_dir='dataset'
        )
    """
    evaluator = SegmentationEvaluator(output_dir=pipeline.output_dir / 'evaluation')
    return evaluator

In [6]:
evaluator = SegmentationEvaluator()

results = evaluator.evaluate_dataset(pred_mask_dir='results/plots',true_mask_dir='dataset')

🔍 Starting evaluation against ground truth...
Found 670 predicted masks
✓ Evaluated 670/670
✅ Evaluation complete: 670 images, 0 skipped

SEGMENTATION EVALUATION SUMMARY
Images Evaluated: 670

Pixel-Level Metrics:
  Mean IoU:            0.8504 (±0.1209)
  Mean Dice:           0.9127
  Mean Pixel Accuracy: 0.9795

Object-Level Metrics:
  Mean Precision:      0.9844
  Mean Recall:         0.9381
  Mean F1 Score:       0.9586

Detection Counts:
  Ground Truth Nuclei: 24,286
  Predicted Nuclei:    23,122
  True Positives:      22,857
  False Positives:     265
  False Negatives:     1,214
✅ Evaluation plots saved to results/evaluation/evaluation_plots.png
