# Nuclear Segmentation using HoverNet

This notebook implements nuclear segmentation for cervical cancer cell classification using the HoverNet model.

**Dataset Structure:**
- Base Directory: `Augmented Dataset - Limited Enhancement` (inside `MyDrive` on Google Drive)
- Classes: `im_Dyskeratotic`, `im_Koilocytotic`, `im_Metaplastic`, `im_Parabasal`, `im_Superficial-Intermediate`
- Images Location: `<class_folder>/NLM_CLAHE/*.bmp`

**Reference:** [HoverNet GitHub](https://github.com/vqdang/hover_net)

## 1. Environment Setup

In [None]:
# Check if running on Colab
try:
    import google.colab
    IN_COLAB = True
    print("✓ Running on Google Colab")
except ImportError:
    IN_COLAB = False
    print("ℹ Not running on Google Colab")

In [None]:
# Install required dependencies
if IN_COLAB:
    print("Installing dependencies...")
    !pip install -q opencv-python-headless
    !pip install -q scikit-image
    !pip install -q scipy
    !pip install -q matplotlib
    !pip install -q tqdm
    !pip install -q imageio
    !pip install -q pillow
    
    # Install PyTorch (HoverNet uses PyTorch)
    !pip install -q torch torchvision
    
    print("\n✓ Dependencies installed successfully")
else:
    print("Skipping dependency installation (not on Colab)")

In [None]:
# Clone HoverNet repository
if IN_COLAB:
    import os
    if not os.path.exists('hover_net'):
        print("Cloning HoverNet repository...")
        !git clone -q https://github.com/vqdang/hover_net.git
        print("✓ HoverNet repository cloned")
    else:
        print("✓ HoverNet repository already exists")
    
    # Add HoverNet to Python path
    import sys
    if '/content/hover_net' not in sys.path:
        sys.path.insert(0, '/content/hover_net')
    print("✓ HoverNet added to Python path")
else:
    print("Skipping HoverNet clone (not on Colab)")

## 2. Mount Google Drive

In [None]:
# Mount Google Drive
if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive')
    print("\n✓ Google Drive mounted successfully")
else:
    print("Skipping Google Drive mount (not on Colab)")

## 3. Import Libraries

In [None]:
import os
import sys
import numpy as np
import cv2
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm.auto import tqdm
import json
from datetime import datetime
import glob
import shutil
import warnings
warnings.filterwarnings('ignore')

# Try importing torch
try:
    import torch
    print(f"✓ PyTorch version: {torch.__version__}")
    print(f"✓ CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"✓ CUDA device: {torch.cuda.get_device_name(0)}")
except ImportError:
    print("⚠ PyTorch not found. Please install it.")

print(f"✓ OpenCV version: {cv2.__version__}")
print(f"✓ NumPy version: {np.__version__}")
print("\n✓ All libraries imported successfully")

## 4. Configuration and Dataset Paths

In [None]:
# ============================================================================
# CONFIGURATION - UPDATE THESE PATHS AS NEEDED
# ============================================================================

# Google Drive base path
DRIVE_BASE_PATH = '/content/drive/MyDrive'

# HoverNet weights directory (USER PROVIDED PATH)
WEIGHTS_DIR = '/content/drive/MyDrive/Projects/6_Project Phoenix_Cervical Cancer Cell Classification/Explainability Worflows/Nucleus Masking/Hovernet Weights'

# Dataset configuration
DATASET_NAME = 'Augmented Dataset - Limited Enhancement'
DATASET_BASE_PATH = os.path.join(DRIVE_BASE_PATH, DATASET_NAME)

# Cell classes
CELL_CLASSES = [
    'im_Dyskeratotic',
    'im_Koilocytotic',
    'im_Metaplastic',
    'im_Parabasal',
    'im_Superficial-Intermediate'
]

# Subfolder containing images
IMAGE_SUBFOLDER = 'NLM_CLAHE'

# Output directory for segmentation results
OUTPUT_BASE_PATH = os.path.join(DRIVE_BASE_PATH, 'HoverNet_Segmentation_Results')
os.makedirs(OUTPUT_BASE_PATH, exist_ok=True)

# Temporary directory for HoverNet processing
TEMP_DIR = '/content/temp_hovernet'
os.makedirs(TEMP_DIR, exist_ok=True)

print("="*70)
print("CONFIGURATION")
print("="*70)
print(f"Weights directory: {WEIGHTS_DIR}")
print(f"Dataset base path: {DATASET_BASE_PATH}")
print(f"Output base path: {OUTPUT_BASE_PATH}")
print(f"Temp directory: {TEMP_DIR}")
print(f"\nCell classes to process: {len(CELL_CLASSES)}")
for i, cell_class in enumerate(CELL_CLASSES, 1):
    print(f"  {i}. {cell_class}")
print("="*70)

## 5. Verify Paths and Resources

In [None]:
# Verify weights directory
def verify_weights():
    """Verify HoverNet weights are available."""
    print("Verifying HoverNet weights...\n")
    
    if not os.path.exists(WEIGHTS_DIR):
        print(f"❌ ERROR: Weights directory not found: {WEIGHTS_DIR}")
        print("\nPlease update WEIGHTS_DIR to the correct path.")
        return False, None
    
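    # List files in sorted order so the "first model file" chosen below is reproducible
    # (os.listdir returns entries in arbitrary order)
    weight_files = sorted(os.listdir(WEIGHTS_DIR))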
    print(f"✓ Weights directory found")
    print(f"\nFiles in weights directory ({len(weight_files)} files):")
    
    model_files = []
    for f in weight_files:
        print(f"  - {f}")
        if f.endswith(('.pth', '.tar', '.ckpt', '.pt')):
            model_files.append(f)
    
    if not model_files:
        print("\n⚠ WARNING: No model files (.pth, .tar, .ckpt, .pt) found in weights directory")
        return False, None
    
    print(f"\n✓ Found {len(model_files)} model file(s)")
    
    # Use the first model file found
    model_path = os.path.join(WEIGHTS_DIR, model_files[0])
    print(f"✓ Will use model: {model_files[0]}")
    
    return True, model_path


# Verify dataset
def verify_dataset():
    """Verify dataset structure and count images per class."""
    print("\n" + "="*70)
    print("Verifying dataset structure...")
    print("="*70 + "\n")
    
    if not os.path.exists(DATASET_BASE_PATH):
        print(f"❌ ERROR: Dataset base path not found: {DATASET_BASE_PATH}")
        print("\nPlease update DATASET_BASE_PATH in the configuration cell.")
        return False, None
    
    total_images = 0
    class_image_counts = {}
    
    for cell_class in CELL_CLASSES:
        class_path = os.path.join(DATASET_BASE_PATH, cell_class)
        image_path = os.path.join(class_path, IMAGE_SUBFOLDER)
        
        if not os.path.exists(image_path):
            print(f"❌ WARNING: Image path not found: {image_path}")
            class_image_counts[cell_class] = 0
            continue
        
        # Count .bmp files
        bmp_files = glob.glob(os.path.join(image_path, '*.bmp'))
        count = len(bmp_files)
        class_image_counts[cell_class] = count
        total_images += count
        
        print(f"✓ {cell_class:40s}: {count:5d} images")
    
    print(f"\n{'='*70}")
    print(f"Total images to process: {total_images}")
    print(f"{'='*70}\n")
    
    return total_images > 0, class_image_counts


# Run verifications
weights_valid, model_path = verify_weights()
dataset_valid, image_counts = verify_dataset()

if weights_valid and dataset_valid:
    print("\n" + "="*70)
    print("✓ ALL VERIFICATIONS PASSED - READY TO PROCESS")
    print("="*70)
else:
    print("\n" + "="*70)
    print("❌ VERIFICATION FAILED - PLEASE FIX ERRORS ABOVE")
    print("="*70)
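On Linux (including Colab), `glob.glob('*.bmp')` matches the extension case-sensitively, so files saved as `.BMP` would be skipped silently by both `verify_dataset` and `process_cell_class`. A minimal sketch of a case-insensitive listing, using a hypothetical helper `list_bmp_files` that could stand in for those two `glob.glob(...)` calls:

```python
from pathlib import Path


def list_bmp_files(image_dir):
    """Hypothetical helper: sorted paths of .bmp files in image_dir.

    The extension is matched case-insensitively, so 'a.bmp' and 'B.BMP'
    are both returned. A missing directory yields an empty list.
    """
    directory = Path(image_dir)
    if not directory.is_dir():
        return []
    return sorted(
        str(p) for p in directory.iterdir()
        if p.is_file() and p.suffix.lower() == '.bmp'
    )
```

Returning an empty list for a missing folder mirrors how `verify_dataset` records a count of 0 for that class instead of stopping.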

## 6. HoverNet Model Setup and Inference Functions

In [None]:
# Optional sanity check: try to load HoverNet's tile-inference module directly.
# The pipeline below does not use this module; it always runs run_infer.py
# through the command line, so a failure here is not fatal.
import importlib.util

try:
    infer_tile_path = '/content/hover_net/infer/tile.py'
    if os.path.exists(infer_tile_path):
        spec = importlib.util.spec_from_file_location("infer.tile", infer_tile_path)
        infer_tile = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(infer_tile)
        print("✓ HoverNet inference module loaded")
    else:
        print("ℹ HoverNet inference module not found at expected location")
except Exception as e:
    print(f"ℹ Could not import HoverNet modules directly: {e}")

print("  Inference will use the run_infer.py command-line interface")
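Because inference goes through `run_infer.py`, the command line is the main integration point. The sketch below assembles that command and surfaces HoverNet's own error output when a run fails, using hypothetical helpers `build_hovernet_cmd` and `run_cmd_checked`. It reuses only the flags this notebook already passes. Using `sys.executable` instead of the bare string `'python'` ensures the subprocess runs under the same interpreter, and therefore the same installed packages, as the notebook.

```python
import os
import subprocess
import sys


def build_hovernet_cmd(hovernet_dir, model_path, input_dir, output_dir,
                       use_gpu=True, nr_types=0, model_mode='fast', workers=4):
    """Hypothetical helper: assemble the run_infer.py argument list.

    Global options come before the 'tile' sub-command; tile-mode options
    (input/output directories) come after it.
    """
    return [
        sys.executable,  # same interpreter (and packages) as the notebook
        os.path.join(hovernet_dir, 'run_infer.py'),
        '--gpu=0' if use_gpu else '--gpu=-1',
        f'--nr_types={nr_types}',
        f'--model_path={model_path}',
        f'--model_mode={model_mode}',
        f'--nr_inference_workers={workers}',
        f'--nr_post_proc_workers={workers}',
        'tile',
        f'--input_dir={input_dir}',
        f'--output_dir={output_dir}',
    ]


def run_cmd_checked(cmd, cwd=None, tail_lines=20):
    """Hypothetical helper: run cmd; on a non-zero exit, raise with the end of stderr."""
    result = subprocess.run(cmd, cwd=cwd, capture_output=True, text=True)
    if result.returncode != 0:
        tail = '\n'.join(result.stderr.strip().splitlines()[-tail_lines:])
        raise RuntimeError(f"command exited with {result.returncode}:\n{tail}")
    return result
```

A call would look like `run_cmd_checked(build_hovernet_cmd('/content/hover_net', model_path, in_dir, out_dir, use_gpu=torch.cuda.is_available()), cwd='/content/hover_net')`. Raising instead of returning quietly lets the caller decide whether to fall back.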

In [None]:
# Setup HoverNet inference using command-line interface
def setup_hovernet_inference():
    """
    Check that HoverNet's run_infer.py script is available.
    This only verifies that the file exists; inference itself is launched
    per image by run_hovernet_inference.
    """
    hovernet_dir = '/content/hover_net'
    run_infer_script = os.path.join(hovernet_dir, 'run_infer.py')
    
    if not os.path.exists(run_infer_script):
        print(f"❌ ERROR: run_infer.py not found at {run_infer_script}")
        return False
    
    print(f"✓ HoverNet inference script found: {run_infer_script}")
    return True


def run_hovernet_inference(image, image_name):
    """
    Run HoverNet inference on a single image using the command-line interface.
    
    Args:
        image: RGB image array (H x W x 3)
        image_name: Name of the image file (for temp storage)
        
    Returns:
        dict: Segmentation results containing:
            - 'instance_map': Instance segmentation map
            - 'type_map': Cell type classification map (if available)
            - 'num_nuclei': Number of detected nuclei
    """
    import subprocess

    # Temporary directories for this single image
    temp_input_dir = os.path.join(TEMP_DIR, 'input')
    temp_output_dir = os.path.join(TEMP_DIR, 'output')
    try:
        # Start from empty directories so leftovers from a previous (failed)
        # call are not re-processed together with this image
        for d in (temp_input_dir, temp_output_dir):
            shutil.rmtree(d, ignore_errors=True)
            os.makedirs(d, exist_ok=True)

        # Save image temporarily (HoverNet reads its input from disk)
        base_name = os.path.splitext(image_name)[0]
        temp_image_path = os.path.join(temp_input_dir, f"{base_name}.png")
        cv2.imwrite(temp_image_path, cv2.cvtColor(image, cv2.COLOR_RGB2BGR))

        # Build the run_infer.py command. --nr_types and --model_mode must match
        # how the checkpoint was trained (see the HoverNet README); nr_types=0
        # means instance segmentation only, without nucleus type classification.
        hovernet_dir = '/content/hover_net'
        cmd = [
            sys.executable, os.path.join(hovernet_dir, 'run_infer.py'),
            '--gpu=0' if torch.cuda.is_available() else '--gpu=-1',
            '--nr_types=0',
            f'--model_path={model_path}',
            '--model_mode=fast',
            '--nr_inference_workers=4',
            '--nr_post_proc_workers=4',
            'tile',
            f'--input_dir={temp_input_dir}',
            f'--output_dir={temp_output_dir}',
            '--draw_dot',
            '--save_qupath'
        ]

        # Each call starts a new process, so the model weights are reloaded per image
        result = subprocess.run(cmd, cwd=hovernet_dir, capture_output=True, text=True)

        json_file = os.path.join(temp_output_dir, 'json', f"{base_name}.json")
        if result.returncode != 0 or not os.path.exists(json_file):
            # Show the end of HoverNet's stderr so failures can be diagnosed
            stderr_tail = '\n'.join(result.stderr.strip().splitlines()[-5:])
            print(f"⚠ HoverNet inference failed for {image_name} "
                  f"(exit code {result.returncode}), using fallback\n{stderr_tail}")
            return fallback_segmentation(image)

        # The JSON output lists each nucleus with its contour as [x, y] points
        with open(json_file, 'r') as f:
            nuclei_dict = json.load(f)
        nuc_info = nuclei_dict.get('nuc', {})

        # Rasterise every contour into a 16-bit instance map (pixel value = nucleus ID)
        instance_map = np.zeros(image.shape[:2], dtype=np.uint16)
        for nuc_id, nuc_data in nuc_info.items():
            contour = nuc_data.get('contour', [])
            if contour:
                contour_array = np.array(contour, dtype=np.int32)
                cv2.fillPoly(instance_map, [contour_array], int(nuc_id))

        return {
            'instance_map': instance_map,
            'type_map': None,
            'num_nuclei': len(nuc_info),
            'nuclei_info': nuc_info
        }

    except Exception as e:
        print(f"⚠ Error during HoverNet inference for {image_name}: {e}")
        return fallback_segmentation(image)
    finally:
        # Always remove temporary files, whether inference succeeded or not
        shutil.rmtree(temp_input_dir, ignore_errors=True)
        shutil.rmtree(temp_output_dir, ignore_errors=True)


def fallback_segmentation(image):
    """
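    Fallback segmentation using classical computer vision (CLAHE, Otsu
    thresholding, morphological opening/closing, contour area filtering).
    Used when HoverNet inference fails. Its output has the same format as the
    HoverNet path and is counted as a successful image in the statistics.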
    
    Args:
        image: RGB image array
        
    Returns:
        dict: Basic segmentation results
    """
    # Convert to grayscale
    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    
    # Apply CLAHE for better contrast
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
    enhanced = clahe.apply(gray)
    
    # Otsu's thresholding
    _, binary = cv2.threshold(enhanced, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    
    # Morphological operations
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    binary = cv2.morphologyEx(binary, cv2.MORPH_OPEN, kernel, iterations=2)
    binary = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel, iterations=2)
    
    # Find contours
    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    
    # Filter contours by area (remove noise)
    min_area = 50
    max_area = 5000
    valid_contours = [c for c in contours if min_area < cv2.contourArea(c) < max_area]
    
    # Create instance map
    instance_map = np.zeros(image.shape[:2], dtype=np.uint16)
    for idx, contour in enumerate(valid_contours, start=1):
        cv2.drawContours(instance_map, [contour], -1, idx, -1)
    
    return {
        'instance_map': instance_map,
        'type_map': None,
        'num_nuclei': len(valid_contours)
    }


# Verify HoverNet setup
hovernet_ready = setup_hovernet_inference()
if hovernet_ready:
    print("\n✓ HoverNet inference is ready")
else:
    print("\n⚠ HoverNet inference setup incomplete")
    print("  Will use fallback segmentation method")
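`run_hovernet_inference` converts HoverNet's JSON output into an instance map by filling each nucleus contour with `cv2.fillPoly`. To check that step without OpenCV, here is a dependency-light sketch with hypothetical helpers `polygon_mask` and `nuclei_json_to_instance_map`. It applies the even-odd (ray-casting) point-in-polygon rule to integer pixel coordinates. Treat it as an approximation of `cv2.fillPoly`, not a replacement: pixels lying exactly on the contour can differ, because `fillPoly` also fills the boundary. The sketch returns `int32`, while the notebook uses `uint16`.

```python
import numpy as np


def polygon_mask(contour, shape):
    """Boolean mask of integer pixel coordinates inside a polygon (even-odd rule).

    contour: sequence of [x, y] vertices, as in HoverNet's JSON 'contour' field.
    shape:   (height, width) of the image.
    """
    h, w = shape
    mask = np.zeros((h, w), dtype=bool)
    pts = np.asarray(contour, dtype=np.float64)
    if pts.ndim != 2 or len(pts) < 3:
        return mask  # fewer than three vertices: not a polygon
    # Only test pixels inside the polygon's bounding box, clipped to the image
    x0 = max(int(np.floor(pts[:, 0].min())), 0)
    x1 = min(int(np.ceil(pts[:, 0].max())), w - 1)
    y0 = max(int(np.floor(pts[:, 1].min())), 0)
    y1 = min(int(np.ceil(pts[:, 1].max())), h - 1)
    if x0 > x1 or y0 > y1:
        return mask  # polygon lies entirely outside the image
    ys, xs = np.mgrid[y0:y1 + 1, x0:x1 + 1]
    inside = np.zeros(xs.shape, dtype=bool)
    xj, yj = pts[-1]
    for xi, yi in pts:
        # Does edge (xj, yj) -> (xi, yi) straddle each pixel's row?
        straddles = (yi > ys) != (yj > ys)
        with np.errstate(divide='ignore', invalid='ignore'):
            # x where the edge crosses that row (nan/inf for horizontal edges,
            # which `straddles` excludes anyway)
            x_cross = (xj - xi) * (ys - yi) / (yj - yi) + xi
        # Toggle pixels whose rightward ray crosses this edge
        inside ^= straddles & (xs < x_cross)
        xj, yj = xi, yi
    mask[y0:y1 + 1, x0:x1 + 1] = inside
    return mask


def nuclei_json_to_instance_map(nuc_dict, shape):
    """Hypothetical helper: rasterise HoverNet's 'nuc' dict into an instance map."""
    instance_map = np.zeros(shape, dtype=np.int32)
    for nuc_id, nuc_data in nuc_dict.items():
        # Later nuclei overwrite earlier ones where contours overlap, as with fillPoly
        instance_map[polygon_mask(nuc_data.get('contour', []), shape)] = int(nuc_id)
    return instance_map
```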

## 7. Image Processing Functions

In [None]:
def load_image(image_path):
    """
    Load a .bmp image and convert to RGB.
    
    Args:
        image_path: Path to the image file
        
    Returns:
        numpy.ndarray: RGB image
    """
    img = cv2.imread(image_path)
    if img is None:
        raise ValueError(f"Failed to load image: {image_path}")
    
    # Convert BGR to RGB
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return img_rgb


def preprocess_for_hovernet(image):
    """
    Preprocess image for HoverNet inference.
    
    Args:
        image: RGB image array
        
    Returns:
        numpy.ndarray: Preprocessed image
    """
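    # HoverNet expects RGB uint8 images with values in [0, 255]. Images from
    # load_image are already uint8, so this branch only converts float images,
    # which it assumes are scaled to [0, 1].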
    if image.dtype != np.uint8:
        image = (image * 255).astype(np.uint8)
    
    return image


def create_overlay(image, mask, alpha=0.4):
    """
    Create an overlay of the segmentation mask on the original image.
    
    Args:
        image: Original RGB image
        mask: Segmentation mask
        alpha: Transparency factor
        
    Returns:
        numpy.ndarray: Overlay image
    """
    # Create a colorized version of the mask
    colored_mask = np.zeros_like(image)
    
    # Get unique instance IDs
    unique_instances = np.unique(mask)
    unique_instances = unique_instances[unique_instances > 0]  # Skip background
    
    # Use a fixed random seed for reproducible colors
    np.random.seed(42)
    colors = np.random.randint(50, 255, size=(len(unique_instances), 3))
    
    for idx, instance_id in enumerate(unique_instances):
        colored_mask[mask == instance_id] = colors[idx]
    
    # Blend original image with colored mask
    overlay = cv2.addWeighted(image, 1-alpha, colored_mask, alpha, 0)
    
    # Draw contours for better visibility
    contours, _ = cv2.findContours(
        (mask > 0).astype(np.uint8),
        cv2.RETR_EXTERNAL,
        cv2.CHAIN_APPROX_SIMPLE
    )
    cv2.drawContours(overlay, contours, -1, (255, 255, 0), 1)
    
    return overlay


def save_segmentation_result(output_dir, image_name, original_image, segmentation_mask):
    """
    Save segmentation results including original image, mask, and overlay.
    
    Args:
        output_dir: Directory to save results
        image_name: Name of the original image
        original_image: Original RGB image
        segmentation_mask: Segmentation mask from HoverNet
        
    Returns:
        dict: Paths to saved files
    """
    os.makedirs(output_dir, exist_ok=True)
    
    base_name = os.path.splitext(image_name)[0]
    
    # Save original image
    original_path = os.path.join(output_dir, f"{base_name}_original.png")
    cv2.imwrite(original_path, cv2.cvtColor(original_image, cv2.COLOR_RGB2BGR))
    
    # Save segmentation mask (as 16-bit to preserve instance IDs)
    mask_path = os.path.join(output_dir, f"{base_name}_mask.png")
    cv2.imwrite(mask_path, segmentation_mask.astype(np.uint16))
    
    # Create and save overlay
    overlay = create_overlay(original_image, segmentation_mask)
    overlay_path = os.path.join(output_dir, f"{base_name}_overlay.png")
    cv2.imwrite(overlay_path, cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))
    
    return {
        'original': original_path,
        'mask': mask_path,
        'overlay': overlay_path
    }

print("✓ Image processing functions defined")
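`create_overlay` calls `np.random.seed(42)`. That resets NumPy's global random state every time an overlay is drawn, which can silently affect any other code that uses `np.random`. The sketch below does the colouring step with a local `np.random.default_rng` generator instead (hypothetical helper `colorize_instances`). It also replaces the per-instance loop with a single vectorised lookup. Its output could be passed to `cv2.addWeighted` in place of `colored_mask`. The colours will not match the original ones exactly, because a different generator is used.

```python
import numpy as np


def colorize_instances(mask, seed=42, low=50, high=255):
    """Hypothetical helper: map each non-zero instance ID to a random RGB colour.

    A local Generator is used, so the global np.random state is left untouched.
    Colour channels are drawn from [low, high).
    """
    rng = np.random.default_rng(seed)
    ids = np.unique(mask)
    ids = ids[ids > 0]  # skip background
    colored = np.zeros(mask.shape + (3,), dtype=np.uint8)
    if ids.size == 0:
        return colored
    colors = rng.integers(low, high, size=(ids.size, 3), dtype=np.uint8)
    # Position of each pixel's ID in the sorted `ids` array selects its colour
    idx = np.searchsorted(ids, mask)
    fg = mask > 0
    colored[fg] = colors[idx[fg]]
    return colored
```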

## 8. Batch Processing Pipeline

In [None]:
def process_single_image(image_path, output_dir):
    """
    Process a single image through the HoverNet pipeline.
    
    Args:
        image_path: Path to input image
        output_dir: Directory to save results
        
    Returns:
        dict: Processing results and statistics
    """
    try:
        # Load image
        image = load_image(image_path)
        
        # Preprocess
        preprocessed = preprocess_for_hovernet(image)
        
        # Run HoverNet inference
        image_name = os.path.basename(image_path)
        results = run_hovernet_inference(preprocessed, image_name)
        
        # Save results
        saved_paths = save_segmentation_result(
            output_dir,
            image_name,
            image,
            results['instance_map']
        )
        
        return {
            'success': True,
            'image_path': image_path,
            'image_name': image_name,
            'num_nuclei': results['num_nuclei'],
            'saved_paths': saved_paths,
            'error': None
        }
        
    except Exception as e:
        return {
            'success': False,
            'image_path': image_path,
            'image_name': os.path.basename(image_path),
            'num_nuclei': 0,
            'saved_paths': None,
            'error': str(e)
        }


def process_cell_class(cell_class, max_images=None, progress_bar=True):
    """
    Process all images for a specific cell class.
    
    Args:
        cell_class: Name of the cell class
        max_images: Maximum number of images to process (None for all)
        progress_bar: Whether to show progress bar
        
    Returns:
        dict: Processing statistics
    """
    # Setup paths
    input_dir = os.path.join(DATASET_BASE_PATH, cell_class, IMAGE_SUBFOLDER)
    output_dir = os.path.join(OUTPUT_BASE_PATH, cell_class)
    
    # Get all .bmp files
    image_files = sorted(glob.glob(os.path.join(input_dir, '*.bmp')))
    
    # Limit number of images if specified
    if max_images is not None:
        image_files = image_files[:max_images]
    
    print(f"\n{'='*70}")
    print(f"Processing {cell_class}")
    print(f"{'='*70}")
    print(f"Input directory: {input_dir}")
    print(f"Output directory: {output_dir}")
    print(f"Number of images: {len(image_files)}")
    print(f"{'='*70}\n")
    
    # Process each image
    results = []
    
    if progress_bar:
        iterator = tqdm(image_files, desc=f"Processing {cell_class}")
    else:
        iterator = image_files
    
    for image_path in iterator:
        result = process_single_image(image_path, output_dir)
        results.append(result)
        
        if not result['success']:
            msg = f"❌ Error processing {result['image_name']}: {result['error']}"
            # tqdm.write prints without breaking the progress bar
            if progress_bar:
                tqdm.write(msg)
            else:
                print(msg)
    
    # Calculate statistics
    successful = sum(1 for r in results if r['success'])
    failed = len(results) - successful
    total_nuclei = sum(r['num_nuclei'] for r in results if r['success'])
    avg_nuclei = total_nuclei / successful if successful > 0 else 0
    
    stats = {
        'cell_class': cell_class,
        'total_images': len(image_files),
        'successful': successful,
        'failed': failed,
        'total_nuclei': total_nuclei,
        'avg_nuclei_per_image': avg_nuclei,
        'results': results
    }
    
    # Print summary
    print(f"\n{'='*70}")
    print(f"Summary for {cell_class}:")
    print(f"  Total images: {stats['total_images']}")
    print(f"  Successful: {stats['successful']}")
    print(f"  Failed: {stats['failed']}")
    print(f"  Total nuclei detected: {stats['total_nuclei']}")
    print(f"  Average nuclei per image: {stats['avg_nuclei_per_image']:.2f}")
    print(f"{'='*70}\n")
    
    return stats


def process_all_classes(max_images_per_class=None):
    """
    Process all cell classes in the dataset.
    
    Args:
        max_images_per_class: Maximum images per class (None for all)
        
    Returns:
        dict: Complete processing statistics
    """
    print("\n" + "="*70)
    print("STARTING BATCH PROCESSING OF ALL CELL CLASSES")
    print("="*70)
    
    start_time = datetime.now()
    all_stats = {}
    
    for cell_class in CELL_CLASSES:
        stats = process_cell_class(cell_class, max_images=max_images_per_class)
        all_stats[cell_class] = stats
    
    end_time = datetime.now()
    processing_time = (end_time - start_time).total_seconds()
    
    # Overall statistics
    total_images = sum(stats['total_images'] for stats in all_stats.values())
    total_successful = sum(stats['successful'] for stats in all_stats.values())
    total_failed = sum(stats['failed'] for stats in all_stats.values())
    total_nuclei = sum(stats['total_nuclei'] for stats in all_stats.values())
    
    overall_stats = {
        'start_time': start_time.isoformat(),
        'end_time': end_time.isoformat(),
        'processing_time_seconds': processing_time,
        'total_images': total_images,
        'total_successful': total_successful,
        'total_failed': total_failed,
        'total_nuclei_detected': total_nuclei,
        'avg_nuclei_per_image': total_nuclei / total_successful if total_successful > 0 else 0,
        'class_statistics': all_stats
    }
    
    # Save overall statistics without the per-image results (keeps the file small).
    # Build new per-class dicts instead of popping from a shallow copy, which
    # would also delete 'results' from the returned overall_stats.
    stats_to_save = dict(overall_stats)
    stats_to_save['class_statistics'] = {
        class_name: {k: v for k, v in class_stats.items() if k != 'results'}
        for class_name, class_stats in all_stats.items()
    }
    stats_file = os.path.join(OUTPUT_BASE_PATH, 'processing_statistics.json')
    with open(stats_file, 'w') as f:
        json.dump(stats_to_save, f, indent=2)
    
    print("\n" + "="*70)
    print("OVERALL PROCESSING COMPLETE")
    print("="*70)
    print(f"Total images processed: {total_images}")
    print(f"Successful: {total_successful}")
    print(f"Failed: {total_failed}")
    print(f"Total nuclei detected: {total_nuclei}")
    print(f"Average nuclei per image: {overall_stats['avg_nuclei_per_image']:.2f}")
    print(f"Processing time: {processing_time:.2f} seconds ({processing_time/60:.2f} minutes)")
    print(f"Statistics saved to: {stats_file}")
    print("="*70 + "\n")
    
    return overall_stats

print("✓ Batch processing functions defined")
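Full runs can take a long time, and a Colab session that disconnects partway through loses its progress. `save_segmentation_result` writes a `<base_name>_mask.png` for every finished image, so a restarted run could skip those images. A minimal sketch with a hypothetical helper `pending_images`, which could filter `image_files` in `process_cell_class` right after they are listed:

```python
import os


def pending_images(image_files, output_dir, suffix='_mask.png'):
    """Hypothetical helper: keep only images that have no saved mask yet.

    Relies on save_segmentation_result naming masks '<base_name>_mask.png'.
    """
    pending = []
    for image_path in image_files:
        base_name = os.path.splitext(os.path.basename(image_path))[0]
        if not os.path.exists(os.path.join(output_dir, base_name + suffix)):
            pending.append(image_path)
    return pending
```

Skipped images are left out of that run's statistics, so the per-run counts then describe only newly processed images.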

## 9. Visualization Functions

In [None]:
def visualize_sample_results(cell_class, num_samples=3):
    """
    Visualize sample segmentation results for a cell class.
    
    Args:
        cell_class: Name of the cell class
        num_samples: Number of samples to visualize
    """
    output_dir = os.path.join(OUTPUT_BASE_PATH, cell_class)
    
    # Get sample overlay images
    overlay_files = sorted(glob.glob(os.path.join(output_dir, '*_overlay.png')))
    
    if not overlay_files:
        print(f"No results found for {cell_class}")
        return
    
    # Select samples
    import random
    samples = random.sample(overlay_files, min(num_samples, len(overlay_files)))
    
    # Create visualization
    fig, axes = plt.subplots(1, len(samples), figsize=(5*len(samples), 5))
    if len(samples) == 1:
        axes = [axes]
    
    for idx, overlay_path in enumerate(samples):
        overlay = cv2.imread(overlay_path)
        overlay_rgb = cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)
        
        axes[idx].imshow(overlay_rgb)
        base_name = os.path.basename(overlay_path).replace('_overlay.png', '')
        axes[idx].set_title(f"{cell_class.replace('im_', '')}\n{base_name}", fontsize=10)
        axes[idx].axis('off')
    
    plt.tight_layout()
    plt.show()


def plot_statistics(stats):
    """
    Plot processing statistics across all cell classes.
    
    Args:
        stats: Overall statistics dictionary
    """
    class_stats = stats['class_statistics']
    
    cell_classes = list(class_stats.keys())
    class_labels = [c.replace('im_', '') for c in cell_classes]
    nuclei_counts = [class_stats[c]['total_nuclei'] for c in cell_classes]
    avg_nuclei = [class_stats[c]['avg_nuclei_per_image'] for c in cell_classes]
    image_counts = [class_stats[c]['total_images'] for c in cell_classes]
    
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    
    # Plot 1: Total nuclei per class
    axes[0, 0].bar(range(len(cell_classes)), nuclei_counts, color='steelblue')
    axes[0, 0].set_xlabel('Cell Class', fontsize=12)
    axes[0, 0].set_ylabel('Total Nuclei Detected', fontsize=12)
    axes[0, 0].set_title('Total Nuclei Detected per Cell Class', fontsize=14, fontweight='bold')
    axes[0, 0].set_xticks(range(len(cell_classes)))
    axes[0, 0].set_xticklabels(class_labels, rotation=45, ha='right')
    axes[0, 0].grid(axis='y', alpha=0.3)
    
    # Plot 2: Average nuclei per image
    axes[0, 1].bar(range(len(cell_classes)), avg_nuclei, color='coral')
    axes[0, 1].set_xlabel('Cell Class', fontsize=12)
    axes[0, 1].set_ylabel('Average Nuclei per Image', fontsize=12)
    axes[0, 1].set_title('Average Nuclei per Image by Cell Class', fontsize=14, fontweight='bold')
    axes[0, 1].set_xticks(range(len(cell_classes)))
    axes[0, 1].set_xticklabels(class_labels, rotation=45, ha='right')
    axes[0, 1].grid(axis='y', alpha=0.3)
    
    # Plot 3: Number of images per class
    axes[1, 0].bar(range(len(cell_classes)), image_counts, color='mediumseagreen')
    axes[1, 0].set_xlabel('Cell Class', fontsize=12)
    axes[1, 0].set_ylabel('Number of Images', fontsize=12)
    axes[1, 0].set_title('Number of Images per Cell Class', fontsize=14, fontweight='bold')
    axes[1, 0].set_xticks(range(len(cell_classes)))
    axes[1, 0].set_xticklabels(class_labels, rotation=45, ha='right')
    axes[1, 0].grid(axis='y', alpha=0.3)
    
    # Plot 4: Success rate per class
    success_rates = [
        (class_stats[c]['successful'] / class_stats[c]['total_images'] * 100)
        if class_stats[c]['total_images'] > 0 else 0
        for c in cell_classes
    ]
    axes[1, 1].bar(range(len(cell_classes)), success_rates, color='mediumpurple')
    axes[1, 1].set_xlabel('Cell Class', fontsize=12)
    axes[1, 1].set_ylabel('Success Rate (%)', fontsize=12)
    axes[1, 1].set_title('Processing Success Rate by Cell Class', fontsize=14, fontweight='bold')
    axes[1, 1].set_xticks(range(len(cell_classes)))
    axes[1, 1].set_xticklabels(class_labels, rotation=45, ha='right')
    axes[1, 1].set_ylim([0, 105])
    axes[1, 1].grid(axis='y', alpha=0.3)
    
    plt.tight_layout()
    plt.show()

print("✓ Visualization functions defined")
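`visualize_sample_results` uses `random.sample` without a seed, so each call shows different overlays. When the same samples need to be compared across runs, a seeded private generator helps. The sketch below uses a hypothetical helper `pick_samples` that could replace the `random.sample(...)` line:

```python
import random


def pick_samples(paths, num_samples, seed=None):
    """Hypothetical helper: choose up to num_samples items without replacement.

    A private random.Random instance is used, so passing a seed makes the
    choice repeatable without touching the global `random` state.
    """
    paths = list(paths)
    rng = random.Random(seed)
    return rng.sample(paths, min(num_samples, len(paths)))
```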

## 10. Run Processing

### Test Run (Optional)

First, test on a small number of images to ensure everything works correctly.

In [None]:
# Test on a single class with limited images
if weights_valid and dataset_valid:
    print("Running test on first cell class with 5 images...\n")
    test_stats = process_cell_class(CELL_CLASSES[0], max_images=5)
    
    if test_stats['successful'] > 0:
        print("\n✓ Test completed successfully!")
        print("\nVisualizing test results:")
        visualize_sample_results(CELL_CLASSES[0], num_samples=min(3, test_stats['successful']))
    else:
        print("\n❌ Test failed. Please check errors above.")
else:
    print("❌ Cannot run test - verification failed. Please fix errors in previous cells.")

### Full Processing

**⚠️ WARNING:** This will process ALL images in the dataset. Depending on the dataset size, this may take a long time!

Uncomment and run the cell below to process all images.

In [None]:
# # UNCOMMENT TO RUN FULL PROCESSING
# if weights_valid and dataset_valid:
#     # Process all images
#     overall_stats = process_all_classes()
# else:
#     print("❌ Cannot run processing - verification failed. Please fix errors above.")

### Partial Processing (Recommended)

Process a limited number of images per class for testing or quick analysis.

In [None]:
# Process limited number of images per class (e.g., 50 images per class)
if weights_valid and dataset_valid:
    MAX_IMAGES_PER_CLASS = 50  # Adjust this number as needed
    
    print(f"Processing up to {MAX_IMAGES_PER_CLASS} images per class...\n")
    overall_stats = process_all_classes(max_images_per_class=MAX_IMAGES_PER_CLASS)
else:
    print("❌ Cannot run processing - verification failed. Please fix errors above.")

## 11. Visualize Results

In [None]:
# Visualize sample results for each class
if 'overall_stats' in locals():
    print("\nVisualizing sample results for each class:\n")
    
    for cell_class in CELL_CLASSES:
        print(f"\n{'='*70}")
        print(f"Sample results for {cell_class}:")
        print(f"{'='*70}")
        visualize_sample_results(cell_class, num_samples=3)
else:
    print("No results to visualize. Please run processing first.")

In [None]:
# Plot overall statistics
if 'overall_stats' in locals():
    print("\nPlotting overall statistics:\n")
    plot_statistics(overall_stats)
else:
    print("No statistics to plot. Please run processing first.")

## 12. Export Results Summary

In [None]:
def create_summary_report(stats):
    """
    Create a formatted summary report of the processing results.
    
    Args:
        stats: Overall statistics dictionary
        
    Returns:
        str: Formatted report
    """
    report = []
    report.append("="*80)
    report.append("HoverNet Nuclear Segmentation - Processing Report")
    report.append("="*80)
    report.append(f"\nProcessing Date: {stats['start_time']}")
    report.append(f"Total Processing Time: {stats['processing_time_seconds']:.2f} seconds ({stats['processing_time_seconds']/60:.2f} minutes)")
    report.append(f"\nDataset: {DATASET_NAME}")
    report.append(f"Output Location: {OUTPUT_BASE_PATH}")
    report.append(f"Weights Location: {WEIGHTS_DIR}")
    
    report.append("\n" + "="*80)
    report.append("OVERALL STATISTICS")
    report.append("="*80)
    n_total = stats['total_images'] or 1  # avoid division by zero when nothing was processed
    report.append(f"Total Images Processed: {stats['total_images']}")
    report.append(f"Successful: {stats['total_successful']} ({stats['total_successful']/n_total*100:.1f}%)")
    report.append(f"Failed: {stats['total_failed']} ({stats['total_failed']/n_total*100:.1f}%)")
    report.append(f"Total Nuclei Detected: {stats['total_nuclei_detected']}")
    report.append(f"Average Nuclei per Image: {stats['avg_nuclei_per_image']:.2f}")
    
    report.append("\n" + "="*80)
    report.append("PER-CLASS STATISTICS")
    report.append("="*80)
    
    for cell_class, class_stats in stats['class_statistics'].items():
        report.append(f"\n{cell_class}:")
        n_class = class_stats['total_images'] or 1  # a missing class folder yields 0 images
        report.append(f"  Images Processed: {class_stats['total_images']}")
        report.append(f"  Successful: {class_stats['successful']} ({class_stats['successful']/n_class*100:.1f}%)")
        report.append(f"  Failed: {class_stats['failed']}")
        report.append(f"  Total Nuclei: {class_stats['total_nuclei']}")
        report.append(f"  Avg Nuclei/Image: {class_stats['avg_nuclei_per_image']:.2f}")
    
    report.append("\n" + "="*80)
    report.append("\nReport generated by HoverNet Nuclear Segmentation Pipeline")
    report.append(f"Generated at: {datetime.now().isoformat()}")
    report.append("="*80)
    
    return "\n".join(report)


# Generate and save report
if 'overall_stats' in locals():
    report = create_summary_report(overall_stats)
    
    # Print report
    print(report)
    
    # Save report to file
    report_file = os.path.join(OUTPUT_BASE_PATH, 'processing_report.txt')
    with open(report_file, 'w') as f:
        f.write(report)
    
    print(f"\n✓ Report saved to: {report_file}")
else:
    print("No results to report. Please run processing first.")
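`processing_statistics.json` and the report hold only aggregate counts. To link nuclei counts back to individual images, for example when integrating with the classification pipeline, the per-image `results` collected by `process_cell_class` can be written to CSV. The sketch below assumes each class entry in `overall_stats` still carries its `results` list; the helper `export_per_image_csv` and any output filename are hypothetical. Classes without a `results` list are skipped.

```python
import csv


def export_per_image_csv(stats, csv_path):
    """Hypothetical helper: write one CSV row per processed image.

    Reads stats['class_statistics'][<class>]['results'] and returns the
    number of rows written (header excluded).
    """
    fieldnames = ['cell_class', 'image_name', 'success', 'num_nuclei', 'error']
    rows = 0
    with open(csv_path, 'w', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        for cell_class, class_stats in stats['class_statistics'].items():
            for r in class_stats.get('results', []):
                writer.writerow({
                    'cell_class': cell_class,
                    'image_name': r['image_name'],
                    'success': r['success'],
                    'num_nuclei': r['num_nuclei'],
                    'error': r['error'] or '',  # None -> empty cell
                })
                rows += 1
    return rows
```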

## 13. Cleanup (Optional)

In [None]:
# Clean up temporary files
if os.path.exists(TEMP_DIR):
    shutil.rmtree(TEMP_DIR)
    print(f"✓ Cleaned up temporary directory: {TEMP_DIR}")
else:
    print("No temporary files to clean up.")

## 14. Additional Notes

### Implementation Details:

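1. **HoverNet Integration**: Each image is written to a temporary folder and segmented by running HoverNet's `run_infer.py` command-line script in a subprocess. This avoids importing HoverNet internals, but the model is reloaded for every image, which adds a start-up cost per image.
2. **Fallback Mechanism**: If HoverNet fails on an image, a classical pipeline (CLAHE, Otsu thresholding, morphological opening/closing, contour area filtering) segments it instead. Fallback results are saved and counted as successful, so check the printed warnings to see which images used it.
3. **Error Handling**: Each image is processed inside `try`/`except`, so a failure is recorded in the statistics and processing moves on to the next image.
4. **GPU Acceleration**: `--gpu=0` is passed to HoverNet when `torch.cuda.is_available()` is true.
5. **Memory Management**: Images are processed one at a time to avoid memory issues.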

### Output Structure:

```
HoverNet_Segmentation_Results/
├── im_Dyskeratotic/
│   ├── image1_original.png
│   ├── image1_mask.png
│   ├── image1_overlay.png
│   └── ...
├── im_Koilocytotic/
├── im_Metaplastic/
├── im_Parabasal/
├── im_Superficial-Intermediate/
├── processing_statistics.json
└── processing_report.txt
```

### Troubleshooting:

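1. **Out of Memory**: Images are already processed one at a time, so lower `--nr_inference_workers` / `--nr_post_proc_workers` in `run_hovernet_inference`, or pass a smaller batch size to `run_infer.py` (check `python run_infer.py --help` for the options your HoverNet version supports)
2. **HoverNet Errors**: Check that the weights are compatible with the HoverNet version and that `--nr_types` and `--model_mode` in `run_hovernet_inference` match how the checkpoint was trained
3. **Slow Processing**: Make sure a GPU runtime is active (the import cell prints `CUDA available`). Each image also starts a new HoverNet process, so run time grows with the number of images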
4. **Missing Results**: Check the error messages in the processing output

### Next Steps:

1. **Quality Check**: Manually inspect sample results to ensure segmentation quality
2. **Feature Extraction**: Extract morphological features from segmented nuclei
3. **Integration**: Integrate results with your classification pipeline
4. **Fine-tuning**: Consider fine-tuning HoverNet on cervical cell images for better results
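As a starting point for the feature-extraction step, here is a minimal NumPy sketch that computes each nucleus's pixel area and centroid from one of the saved instance masks. The helper name `nucleus_area_centroid` is hypothetical. Load the mask with `cv2.imread(path, cv2.IMREAD_UNCHANGED)`, because the default `imread` flag converts images to 8-bit 3-channel and would destroy the 16-bit instance IDs.

```python
import numpy as np


def nucleus_area_centroid(instance_map):
    """Hypothetical helper: per-instance pixel area and (row, col) centroid.

    Returns {instance_id: (area, (row, col))}; background (0) is ignored.
    """
    ids, inverse, counts = np.unique(instance_map, return_inverse=True, return_counts=True)
    inverse = inverse.ravel()  # inverse's shape differs across NumPy versions
    rows, cols = np.indices(instance_map.shape)
    # Sum row/column coordinates per instance in a single pass each
    row_sum = np.bincount(inverse, weights=rows.ravel(), minlength=ids.size)
    col_sum = np.bincount(inverse, weights=cols.ravel(), minlength=ids.size)
    feats = {}
    for k, inst_id in enumerate(ids):
        if inst_id == 0:
            continue
        area = int(counts[k])
        feats[int(inst_id)] = (area, (row_sum[k] / area, col_sum[k] / area))
    return feats
```

For richer morphology (perimeter, eccentricity, solidity), a region-properties routine from an image-analysis library would be a natural next step. This sketch only covers the two simplest measures.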

### References:

- **HoverNet Paper**: Graham, S., et al. "Hover-Net: Simultaneous segmentation and classification of nuclei in multi-tissue histology images." Medical Image Analysis (2019)
- **HoverNet GitHub**: https://github.com/vqdang/hover_net
- **Project Phoenix**: Cervical Cancer Cell Classification using Explainable AI