In [None]:
import warnings
warnings.filterwarnings('ignore')

import json
from datetime import datetime
from pathlib import Path

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
from skimage import io, measure, filters


class NucleiSegmentationPipeline:
    """
    Automated nuclei segmentation and phenotyping pipeline (low VRAM & CPU fallback).
    Optimized for batch dataset processing.
    """

    def __init__(self, output_dir='results'):
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(exist_ok=True)
        (self.output_dir / 'plots').mkdir(exist_ok=True)

        self.results = []
        self.all_features = []  # To accumulate feature DataFrames across images

        from cellpose import models
        # Use GPU if available
        self.model = models.CellposeModel(gpu=True)
        print("✅ CellPose model loaded (GPU mode)")

    # ------------------------------------------------------------------------
    # Image preprocessing
    # ------------------------------------------------------------------------
    def preprocess_image(self, image):
        """Normalize and enhance image quality for CellPose"""
        # Handle RGBA or RGB input
        if len(image.shape) == 3:
            # Drop alpha if present
            if image.shape[2] == 4:
                image = image[:, :, :3]
            # Convert to grayscale by averaging
            image = np.mean(image, axis=2)
    
        image = image.astype(np.float32)
        image = (image - image.min()) / (image.max() - image.min() + 1e-8)
        image = filters.gaussian(image, sigma=1.0)
        return image


    # ------------------------------------------------------------------------
    # Segmentation
    # ------------------------------------------------------------------------
    def segment_cellpose(self, image):
        """Run CellPose segmentation (compatible with v3.x and v4.x)"""
        image_uint8 = (image * 255).astype(np.uint8)
        
        try:
            outputs = self.model.eval(
                image_uint8,
                diameter=30,
                channels=[0, 0]
            )
            
            # Handle both 3- and 4-return signatures
            if len(outputs) == 4:
                masks, flows, styles, diams = outputs
            elif len(outputs) == 3:
                masks, flows, styles = outputs
            else:
                raise ValueError(f"Unexpected number of return values from CellPose: {len(outputs)}")
            
            return masks
        
        except Exception as e:
            print(f"❌ Error in CellPose segmentation: {e}")
            return np.zeros_like(image_uint8)


    # ------------------------------------------------------------------------
    # Feature extraction
    # ------------------------------------------------------------------------
    def extract_features(self, image, labels, image_name):
        """Extract per-nucleus features and tag with image name"""
        props = measure.regionprops_table(
            labels,
            intensity_image=image,
            properties=[
                'label', 'area', 'perimeter', 'eccentricity',
                'solidity', 'intensity_mean', 'intensity_max',
                'centroid', 'major_axis_length', 'minor_axis_length'
            ]
        )
        df = pd.DataFrame(props)
        df['circularity'] = (4 * np.pi * df['area']) / (df['perimeter'] ** 2 + 1e-8)
        df['aspect_ratio'] = df['major_axis_length'] / (df['minor_axis_length'] + 1e-8)
        df['image_name'] = image_name
        return df

    # ------------------------------------------------------------------------
    # Population metrics
    # ------------------------------------------------------------------------
    def calculate_metrics(self, labels):
        regions = measure.regionprops(labels)
        if not regions:
            return {'num_nuclei': 0, 'mean_area': 0, 'std_area': 0, 'density': 0}

        num_nuclei = len(regions)
        areas = [r.area for r in regions]
        metrics = {
            'num_nuclei': num_nuclei,
            'mean_area': np.mean(areas),
            'std_area': np.std(areas),
            'density': num_nuclei / (labels.shape[0] * labels.shape[1]) * 1e6  # per mm²
        }
        return metrics

    # ------------------------------------------------------------------------
    # Visualization
    # ------------------------------------------------------------------------
    def visualize_results(self, image, labels, features_df, save_path):
        fig, axes = plt.subplots(2, 3, figsize=(15, 10))

        # Original
        axes[0, 0].imshow(image, cmap='gray')
        axes[0, 0].set_title('Original')
        axes[0, 0].axis('off')

        # Segmentation overlay
        axes[0, 1].imshow(image, cmap='gray')
        axes[0, 1].imshow(labels, cmap='tab20', alpha=0.5)
        axes[0, 1].set_title(f'Segmentation ({len(features_df)} nuclei)')
        axes[0, 1].axis('off')

        # Labeled
        axes[0, 2].imshow(labels, cmap='nipy_spectral')
        axes[0, 2].set_title('Labeled Regions')
        axes[0, 2].axis('off')

        # Area distribution
        axes[1, 0].hist(features_df['area'], bins=30, edgecolor='black')
        axes[1, 0].set_title('Area Distribution')
        axes[1, 0].set_xlabel('Area (px²)')

        # Circularity distribution
        axes[1, 1].hist(features_df['circularity'], bins=30, edgecolor='black')
        axes[1, 1].set_title('Circularity Distribution')

        # Scatter
        axes[1, 2].scatter(features_df['area'], features_df['intensity_mean'], alpha=0.6, s=15)
        axes[1, 2].set_xlabel('Area')
        axes[1, 2].set_ylabel('Mean Intensity')
        axes[1, 2].set_title('Area vs Intensity')

        plt.tight_layout()
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        plt.close()

    # ------------------------------------------------------------------------
    # Image processing entry point
    # ------------------------------------------------------------------------
    def process_image(self, image_path, save_results=True):
        image = io.imread(image_path)
        image = self.preprocess_image(image)
        labels = self.segment_cellpose(image)

        image_name = Path(image_path).stem
        features = self.extract_features(image, labels, image_name)
        metrics = self.calculate_metrics(labels)

        if save_results:
            plot_path = self.output_dir / 'plots' / f'{image_name}_analysis.png'
            mask_path = self.output_dir / 'plots' / f'{image_name}_mask.tif'

            self.visualize_results(image, labels, features, plot_path)
            io.imsave(mask_path, labels.astype(np.uint16))

        # Accumulate results
        self.results.append({
            'image': image_name,
            **metrics
        })
        self.all_features.append(features)

    # ------------------------------------------------------------------------
    # Dataset processing (recursive search)
    # ------------------------------------------------------------------------
    def process_dataset(self, image_dir, pattern='*.tif'):
        """Recursively scan directories for images and process all."""
        image_paths = list(Path(image_dir).rglob(pattern))
        if not image_paths:
            print(f"No images found matching {pattern} in {image_dir}")
            return

        print(f"📁 Found {len(image_paths)} images under {image_dir}")
        cnt = 0

        for img_path in image_paths:
            try:
                self.process_image(str(img_path))
                cnt += 1
                print(f"{cnt} out of {len(image_paths)} is successfully loaded.", end='\r', flush=True)
            except Exception as e:
                print(f"❌ Error processing {img_path}: {e}")

        # Save aggregated results
        print()
        self.save_combined_results()

    # ------------------------------------------------------------------------
    # Aggregated results saving
    # ------------------------------------------------------------------------
def save_combined_results(self, drop_exact_duplicates=False):
    # Columns you want to keep
    columns = ['image_name','area','perimeter','eccentricity','solidity',
               'intensity_mean','intensity_max','centroid-0','centroid-1',
               'major_axis_length','minor_axis_length','circularity','aspect_ratio']
    
    # Concatenate all DataFrames and keep only the desired columns
    df = pd.concat([df_image[columns] for df_image in pipeline.all_features], ignore_index=True)
    
    # Save to CSV
    df.to_csv('features.csv', index=False)
    
    print("CSV saved as features.csv")

    print(f"✅ Created new combined CSV at: {out_path} ({len(new_df)} rows)")

    # Save summary and dataset plots even if combined_df is empty
    summary = {
        'total_images': len(self.results),
        'total_nuclei': sum(r['num_nuclei'] for r in self.results),
        'mean_nuclei_per_image': float(np.mean([r['num_nuclei'] for r in self.results])) if self.results else 0.0,
        'mean_nucleus_area': float(np.mean([r['mean_area'] for r in self.results])) if self.results else 0.0,
        'processing_date': datetime.now().isoformat()
    }
    with open(self.output_dir / 'summary_report.json', 'w') as f:
        json.dump(summary, f, indent=2)

    self.plot_dataset_summary()
    print("🧾 Summary report saved.")

    return 

    # ------------------------------------------------------------------------
    # Dataset-level summary visualization
    # ------------------------------------------------------------------------
    def plot_dataset_summary(self):
        fig, axes = plt.subplots(1, 3, figsize=(15, 4))

        counts = [r['num_nuclei'] for r in self.results]
        axes[0].bar(range(len(counts)), counts)
        axes[0].set_title('Nuclei per Image')
        axes[0].set_xlabel('Image Index')
        axes[0].set_ylabel('Count')

        areas = [r['mean_area'] for r in self.results]
        axes[1].hist(areas, bins=20, edgecolor='black')
        axes[1].set_title('Mean Nucleus Area Distribution')
        axes[1].set_xlabel('Area')

        densities = [r['density'] for r in self.results]
        axes[2].plot(densities, marker='o')
        axes[2].set_title('Cell Density Variation')
        axes[2].set_xlabel('Image Index')
        axes[2].set_ylabel('Density (nuclei/mm²)')

        plt.tight_layout()
        plt.savefig(self.output_dir / 'dataset_summary.png', dpi=150)
        plt.close()


In [None]:
if __name__ == "__main__":
    # Example: Process single image
    pipeline = NucleiSegmentationPipeline()
    
    # For BBBC038 dataset:
    pipeline.process_dataset('/home/amon/Cell-Segmentation-Morphology/dataset', pattern='*/images/*.png')
    


In [29]:
import pandas as pd

# Columns you want to keep
columns = ['image_name','area','perimeter','eccentricity','solidity',
           'intensity_mean','intensity_max','centroid-0','centroid-1',
           'major_axis_length','minor_axis_length','circularity','aspect_ratio']

# Concatenate all DataFrames and keep only the desired columns
df = pd.concat([df_image[columns] for df_image in pipeline.all_features], ignore_index=True)

# Save to CSV
df.to_csv('features.csv', index=False)

print("CSV saved as features.csv")


CSV saved as features.csv


In [30]:
pipeline.all_features[0].head()

Unnamed: 0,label,area,perimeter,eccentricity,solidity,intensity_mean,intensity_max,centroid-0,centroid-1,major_axis_length,minor_axis_length,circularity,aspect_ratio,image_name
0,1,110.0,41.142136,0.856954,0.916667,0.311882,0.443204,3.263636,205.181818,16.662239,8.587606,0.816637,1.940266,fc5452f612a0f972fe55cc677055ede662af6723b5c161...
1,2,187.0,48.627417,0.608791,0.973958,0.202061,0.29121,8.427807,153.438503,17.333885,13.7515,0.993777,1.260509,fc5452f612a0f972fe55cc677055ede662af6723b5c161...
2,3,268.0,59.698485,0.447635,0.950355,0.291927,0.426615,15.145522,208.884328,19.609088,17.534767,0.94497,1.118298,fc5452f612a0f972fe55cc677055ede662af6723b5c161...
3,4,183.0,48.384776,0.630578,0.973404,0.337899,0.618392,14.825137,249.491803,17.436687,13.533063,0.982298,1.288451,fc5452f612a0f972fe55cc677055ede662af6723b5c161...
4,5,303.0,63.112698,0.436021,0.961905,0.325888,0.528196,23.745875,116.792079,20.71477,18.641974,0.955914,1.11119,fc5452f612a0f972fe55cc677055ede662af6723b5c161...
