# Nuclear Segmentation using HoverNet

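This notebook segments cell nuclei in cervical cytology images, in support of cervical cancer cell classification, using the HoverNet model. A classical computer-vision pipeline (thresholding + watershed) is used instead when the model cannot be loaded, fails at inference, or detects no nuclei in an image.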

**Dataset Structure:**
- Base Directory: `Augmented Dataset - Limited Enhancement`
- Classes: `im_Dyskeratotic`, `im_Koilocytotic`, `im_Metaplastic`, `im_Parabasal`, `im_Superficial-Intermediate`
- Images Location: `<class_folder>/NLM_CLAHE/*.bmp`

**Reference:** [HoverNet GitHub](https://github.com/vqdang/hover_net)

## 1. Environment Setup

In [None]:
# Check if running on Colab.
# Only a missing `google.colab` module means "not on Colab". A bare `except:`
# would also swallow KeyboardInterrupt/SystemExit and hide real errors.
try:
    import google.colab  # noqa: F401  (imported only to probe the environment)
    IN_COLAB = True
    print("✓ Running on Google Colab")
except ImportError:
    IN_COLAB = False
    print("ℹ Not running on Google Colab")

In [None]:
# Install required dependencies.
# NOTE: the `!pip ...` lines are IPython shell escapes, so this cell only runs
# inside a notebook kernel (Colab/Jupyter), not as a plain Python script.
# Packages are installed only on Colab; elsewhere the environment must already
# provide them. No versions are pinned.
if IN_COLAB:
    print("Installing dependencies...\n")
    
    # Core dependencies (image I/O/processing, plotting, progress bars)
    !pip install -q opencv-python-headless
    !pip install -q scikit-image
    !pip install -q scipy
    !pip install -q matplotlib
    !pip install -q tqdm
    !pip install -q pillow
    
    # PyTorch (used to load and run the HoverNet checkpoint)
    !pip install -q torch torchvision
    
    # HoverNet dependencies (presumably needed by modules of the cloned
    # hover_net repo — TODO confirm against its requirements file)
    !pip install -q imgaug
    !pip install -q termcolor
    
    print("\n✓ Dependencies installed successfully")
else:
    print("Skipping dependency installation (not on Colab)")

In [None]:
# Clone HoverNet repository (Colab only) and make it importable.
# NOTE: `!git clone` is an IPython shell escape; IPython expands
# `{HOVERNET_DIR}` from the Python variable of the same name.
import os
import sys

if IN_COLAB:
    HOVERNET_DIR = '/content/hover_net'
    
    # Clone only once; re-running the cell reuses the existing checkout.
    if not os.path.exists(HOVERNET_DIR):
        print("Cloning HoverNet repository...")
        !git clone https://github.com/vqdang/hover_net.git {HOVERNET_DIR}
        print("✓ HoverNet repository cloned")
    else:
        print("✓ HoverNet repository already exists")
    
    # Add HoverNet to Python path.
    # NOTE(review): this puts the repo root itself on sys.path, which makes the
    # repo's top-level packages (e.g. `models`) importable. The model-loading
    # cell imports `hover_net.models...`, which additionally needs the parent
    # directory (/content) on sys.path — confirm for your kernel.
    if HOVERNET_DIR not in sys.path:
        sys.path.insert(0, HOVERNET_DIR)
    print("✓ HoverNet added to Python path")
else:
    # Outside Colab nothing is cloned and nothing is added to sys.path;
    # this value is only a placeholder path.
    HOVERNET_DIR = './hover_net'
    print("Skipping HoverNet clone (not on Colab)")

## 2. Mount Google Drive

In [None]:
# Mount Google Drive so the dataset, weights and output folders under
# /content/drive become reachable; off Colab there is nothing to mount.
if not IN_COLAB:
    print("Skipping Google Drive mount (not on Colab)")
else:
    from google.colab import drive
    drive.mount('/content/drive')
    print("\n✓ Google Drive mounted successfully")

## 3. Import Libraries

In [None]:
import os
import sys
import numpy as np
import cv2
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm.auto import tqdm
import json
from datetime import datetime
import glob
import shutil
import warnings
import traceback
warnings.filterwarnings('ignore')

# PyTorch imports
import torch
import torch.nn as nn
import torch.nn.functional as F

# SciPy/Scikit-image imports
from scipy.ndimage import label as scipy_label
from scipy.ndimage import binary_fill_holes, binary_dilation, binary_erosion
from scipy import ndimage

# Report library versions and choose the compute device for the model.
print(f"✓ PyTorch version: {torch.__version__}")
_cuda_available = torch.cuda.is_available()
print(f"✓ CUDA available: {_cuda_available}")
if _cuda_available:
    print(f"✓ CUDA device: {torch.cuda.get_device_name(0)}")
# Use the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if _cuda_available else 'cpu')
print(f"✓ Using device: {DEVICE}")

for _lib_name, _lib_version in (("OpenCV", cv2.__version__), ("NumPy", np.__version__)):
    print(f"✓ {_lib_name} version: {_lib_version}")
print("\n✓ All libraries imported successfully")

## 4. Configuration

In [None]:
# ============================================================================
# CONFIGURATION - UPDATE THESE PATHS AS NEEDED
# ============================================================================

# Root of the mounted Google Drive.
DRIVE_BASE_PATH = '/content/drive/MyDrive'

# Folder holding the HoverNet checkpoint (.tar) that find_weight_file() scans.
WEIGHTS_DIR = os.path.join(
    DRIVE_BASE_PATH,
    'Projects/6_Project Phoenix_Cervical Cancer Cell Classification/'
    'Explainability Worflows/Nucleus Masking/Hovernet Weights',
)

# Dataset location: <DATASET_BASE_PATH>/<class>/<IMAGE_SUBFOLDER>/*.bmp
DATASET_NAME = 'Augmented Dataset - Limited Enhancement'
DATASET_BASE_PATH = os.path.join(DRIVE_BASE_PATH, DATASET_NAME)

# Class folder names, processed in this order.
CELL_CLASSES = [
    'im_Dyskeratotic',
    'im_Koilocytotic',
    'im_Metaplastic',
    'im_Parabasal',
    'im_Superficial-Intermediate'
]

# Per-class subfolder that contains the images.
IMAGE_SUBFOLDER = 'NLM_CLAHE'

# Where results go, plus a scratch folder removed by the cleanup cell.
# Both are created eagerly.
OUTPUT_BASE_PATH = os.path.join(DRIVE_BASE_PATH, 'HoverNet_Segmentation_Results')
TEMP_DIR = '/content/temp_hovernet'
for _required_dir in (OUTPUT_BASE_PATH, TEMP_DIR):
    os.makedirs(_required_dir, exist_ok=True)

# When True, failures print detailed messages and tracebacks.
DEBUG_MODE = True

_config_rule = "=" * 70
print(_config_rule)
print("CONFIGURATION")
print(_config_rule)
print(f"Weights directory: {WEIGHTS_DIR}")
print(f"Dataset base path: {DATASET_BASE_PATH}")
print(f"Output base path: {OUTPUT_BASE_PATH}")
print(f"Debug mode: {DEBUG_MODE}")
print(f"\nCell classes to process: {len(CELL_CLASSES)}")
for _position, _class_name in enumerate(CELL_CLASSES, 1):
    print(f"  {_position}. {_class_name}")
print(_config_rule)

## 5. Verify Weights and Dataset

In [None]:
def find_weight_file(weights_dir=None):
    """Locate the HoverNet checkpoint (``.tar``) to load.

    Args:
        weights_dir: Directory to search. Defaults to the notebook-level
            ``WEIGHTS_DIR`` setting, resolved at call time.

    Returns:
        Full path of the chosen ``.tar`` file, or ``None`` when the directory
        is missing or contains no ``.tar`` file. When several candidates exist,
        the alphabetically first one is used so the choice is reproducible
        (``os.listdir`` order is arbitrary).
    """
    if weights_dir is None:
        weights_dir = WEIGHTS_DIR
    print("Searching for HoverNet weights...\n")

    # isdir rather than exists: a file path would crash os.listdir below.
    if not os.path.isdir(weights_dir):
        print(f"❌ ERROR: Weights directory not found: {weights_dir}")
        return None

    # List all files (sorted, so the listing and the pick are deterministic).
    all_names = sorted(os.listdir(weights_dir))
    print("Files in weights directory:")
    for name in all_names:
        file_path = os.path.join(weights_dir, name)
        if os.path.isfile(file_path):
            size_mb = os.path.getsize(file_path) / (1024 * 1024)
            print(f"  - {name} ({size_mb:.1f} MB)")

    # Only regular files count: a directory that happens to end in ".tar" is
    # not a checkpoint. The extension match is case-insensitive.
    tar_files = [
        name for name in all_names
        if name.lower().endswith('.tar')
        and os.path.isfile(os.path.join(weights_dir, name))
    ]

    if not tar_files:
        print("\n❌ ERROR: No .tar weight files found!")
        return None
    if len(tar_files) > 1:
        print(f"\n⚠ Multiple .tar files found, using the first alphabetically: {tar_files}")

    weight_file = os.path.join(weights_dir, tar_files[0])
    print(f"\n✓ Found weight file: {tar_files[0]}")
    return weight_file


def verify_dataset():
    """Check the dataset root and count ``.bmp`` images for each class.

    Returns:
        ``(has_images, counts)``:
        - ``has_images`` is True when at least one image was found across
          all classes.
        - ``counts`` maps every entry of ``CELL_CLASSES`` to its image count;
          a missing class folder counts as 0.
        A missing dataset root yields ``(False, {})``.
    """
    rule = "=" * 70
    print("\n" + rule)
    print("Verifying dataset...")
    print(rule + "\n")

    if not os.path.exists(DATASET_BASE_PATH):
        print(f"❌ ERROR: Dataset not found: {DATASET_BASE_PATH}")
        return False, {}

    def _count_images(folder):
        """Number of .bmp files in *folder*, or None if the folder is missing."""
        if not os.path.exists(folder):
            return None
        return len(glob.glob(os.path.join(folder, '*.bmp')))

    class_counts = {}
    grand_total = 0
    for cell_class in CELL_CLASSES:
        n_found = _count_images(os.path.join(DATASET_BASE_PATH, cell_class, IMAGE_SUBFOLDER))
        if n_found is None:
            print(f"❌ {cell_class}: Directory not found")
            class_counts[cell_class] = 0
        else:
            class_counts[cell_class] = n_found
            grand_total += n_found
            print(f"✓ {cell_class:40s}: {n_found:5d} images")

    print(f"\n{rule}")
    print(f"Total images: {grand_total}")
    print(rule)

    return grand_total > 0, class_counts


# Run verifications: locate the checkpoint and count the dataset images.
# Later cells branch on WEIGHT_FILE and dataset_valid.
WEIGHT_FILE = find_weight_file()
dataset_valid, image_counts = verify_dataset()

_verification_rule = "=" * 70
_verification_msg = (
    "✓ ALL VERIFICATIONS PASSED"
    if WEIGHT_FILE and dataset_valid
    else "❌ VERIFICATION FAILED - Please fix errors above"
)
print("\n" + _verification_rule)
print(_verification_msg)
print(_verification_rule)

## 6. Load HoverNet Model

In [None]:
# Try to import HoverNet model components and load the checkpoint.
# On any failure `model` stays None and segmentation uses the CV fallback.
HOVERNET_MODULES_AVAILABLE = False
model = None


def _unwrap_state_dict(checkpoint):
    """Return the bare parameter dict stored in a HoverNet checkpoint.

    The upstream hover_net inference code reads weights from the ``'desc'``
    key; other checkpoints commonly use ``'state_dict'``. A checkpoint that is
    already a bare state dict is returned unchanged. (The cell prints the
    checkpoint's keys, so a different layout is easy to spot.)

    Checkpoints saved from ``nn.DataParallel`` prefix every key with
    ``'module.'``. That prefix is stripped so the keys match a freshly
    created, non-parallel model.
    """
    state_dict = checkpoint
    if isinstance(checkpoint, dict):
        for wrapper_key in ('desc', 'state_dict'):
            if isinstance(checkpoint.get(wrapper_key), dict):
                state_dict = checkpoint[wrapper_key]
                break
    if state_dict and all(k.startswith('module.') for k in state_dict):
        state_dict = {k[len('module.'):]: v for k, v in state_dict.items()}
    return state_dict


def _infer_hovernet_mode(state_dict, weight_path):
    """Guess whether a checkpoint belongs to an 'original' or 'fast' HoverNet.

    Upstream parameter names do not encode the mode, so a "'fast' in key"
    test never matches. Upstream HoverNet builds its decoder with 5x5
    convolutions in 'original' mode and 3x3 in 'fast' mode, so the kernel size
    of a decoder conv is checked first (TODO confirm for custom checkpoints).
    The checkpoint file name (e.g. ``hovernet_fast_...``) is the fallback.
    """
    for key, tensor in state_dict.items():
        if key.endswith('u3.conva.weight') and getattr(tensor, 'ndim', 0) == 4:
            return 'fast' if tensor.shape[-1] == 3 else 'original'
    return 'fast' if 'fast' in os.path.basename(weight_path).lower() else 'original'


try:
    from hover_net.models.hovernet.net_desc import create_model
    HOVERNET_MODULES_AVAILABLE = True
except ImportError:
    # The clone cell puts the repo root (HOVERNET_DIR) itself on sys.path, so
    # the package may only be importable without the `hover_net.` prefix.
    try:
        from models.hovernet.net_desc import create_model
        HOVERNET_MODULES_AVAILABLE = True
    except ImportError as e:
        print(f"⚠ Could not import HoverNet model: {e}")
        print("Will use CV-based segmentation instead")
if HOVERNET_MODULES_AVAILABLE:
    print("✓ HoverNet model module imported")

# Try loading the model if modules are available
if HOVERNET_MODULES_AVAILABLE and WEIGHT_FILE:
    try:
        print(f"\nLoading model from: {WEIGHT_FILE}")

        # Load checkpoint
        checkpoint = torch.load(WEIGHT_FILE, map_location=DEVICE)
        if isinstance(checkpoint, dict):
            print(f"Checkpoint keys: {list(checkpoint.keys())}")

        state_dict = _unwrap_state_dict(checkpoint)
        if not isinstance(state_dict, dict) or not state_dict:
            raise ValueError("Checkpoint does not contain a usable state dict")

        # A type-classification ("tp") head is present when any parameter
        # mentions it. The output channels of its final conv give nr_types.
        has_type_head = any('tp.' in k for k in state_dict)
        nr_types = None
        if has_type_head:
            for key, tensor in state_dict.items():
                if 'tp.u0.conv.weight' in key:
                    nr_types = tensor.shape[0]
                    break

        print(f"Has type head: {has_type_head}")
        print(f"Number of types: {nr_types}")

        mode = _infer_hovernet_mode(state_dict, WEIGHT_FILE)
        print(f"Model mode: {mode}")

        # Create model.
        # strict=True: a key mismatch must fail loudly. With strict=False an
        # unrecognised checkpoint left the network at random initialisation
        # while still reporting success.
        model = create_model(mode=mode, nr_types=nr_types)
        model.load_state_dict(state_dict, strict=True)
        model = model.to(DEVICE)
        model.eval()

        print("\n✓ HoverNet model loaded successfully!")

    except Exception as e:
        print(f"\n⚠ Failed to load HoverNet model: {e}")
        if DEBUG_MODE:
            traceback.print_exc()
        print("Will use CV-based segmentation instead")
        model = None
else:
    print("\nUsing CV-based segmentation (HoverNet model not available)")

## 7. Segmentation Functions

In [None]:
def segment_nuclei_cv(image, min_area=30, max_area=10000):
    """
    Segment nuclei using traditional computer vision methods.
    This is a robust fallback that works on any image.

    Pipeline:
    1. CLAHE-enhanced grayscale: (Otsu AND adaptive threshold), OR a
       saturation threshold.
    2. Morphological opening/closing.
    3. Distance transform and marker-based watershed.
    4. Area filter, then consecutive relabelling.

    Args:
        image: RGB image (H, W, 3) uint8
        min_area: smallest instance area (pixels) to keep. Default 30.
        max_area: largest instance area (pixels) to keep. Default 10000.

    Returns:
        dict with:
        - 'instance_map': (H, W) int32; 0 = background,
          1..num_nuclei = instances.
        - 'num_nuclei'.
        On any internal error an all-zero map with num_nuclei 0 is returned
        (details printed when DEBUG_MODE is set).
    """
    def _empty_result():
        return {'instance_map': np.zeros(image.shape[:2], dtype=np.int32), 'num_nuclei': 0}

    try:
        # Convert to different color spaces
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)

        # Apply CLAHE for contrast enhancement
        clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
        enhanced = clahe.apply(gray)

        # Approach 1: Otsu on enhanced grayscale (dark objects -> foreground)
        _, binary1 = cv2.threshold(enhanced, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)

        # Approach 2: Adaptive thresholding
        binary2 = cv2.adaptiveThreshold(
            enhanced, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
            cv2.THRESH_BINARY_INV, 21, 5
        )

        # Approach 3: Color-based, using the saturation channel
        sat = hsv[:, :, 1]
        _, binary3 = cv2.threshold(sat, 30, 255, cv2.THRESH_BINARY)

        # Combine approaches: (Otsu AND adaptive) OR saturation
        binary = cv2.bitwise_and(binary1, binary2)
        binary = cv2.bitwise_or(binary, binary3)

        # Morphological cleaning
        kernel_small = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
        kernel_medium = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
        binary = cv2.morphologyEx(binary, cv2.MORPH_OPEN, kernel_small, iterations=2)   # remove noise
        binary = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel_medium, iterations=2)  # fill holes

        # Distance transform for watershed
        dist_transform = cv2.distanceTransform(binary, cv2.DIST_L2, 5)
        if dist_transform.max() <= 0:
            # No foreground detected
            return _empty_result()

        # Sure foreground: pixels far from any boundary (peaks)
        _, sure_fg = cv2.threshold(dist_transform, 0.3 * dist_transform.max(), 255, 0)
        sure_fg = sure_fg.astype(np.uint8)

        # Sure background and the undecided band in between
        sure_bg = cv2.dilate(binary, kernel_medium, iterations=3)
        unknown = cv2.subtract(sure_bg, sure_fg)

        # Markers: background = 1, each sure-fg component = 2.., unknown = 0
        _, markers = cv2.connectedComponents(sure_fg)
        markers = markers + 1
        markers[unknown == 255] = 0

        # Watershed; boundaries come back as -1
        markers = cv2.watershed(image, markers)

        # Instance map: drop background (1) and boundaries (-1), shift ids to 1..
        instance_map = np.zeros(image.shape[:2], dtype=np.int32)
        instance_map[markers > 1] = markers[markers > 1] - 1

        # Per-label pixel areas in a single pass (index = label id); replaces
        # one full-image scan per label.
        areas = np.bincount(instance_map.ravel())
        label_ids = np.arange(areas.size)
        keep = (label_ids > 0) & (areas > 0) & (areas >= min_area) & (areas <= max_area)
        valid_ids = label_ids[keep]

        # Relabel the survivors consecutively (1..K), keeping their ascending
        # id order; everything else maps to background (0).
        lut = np.zeros(areas.size, dtype=np.int32)
        lut[valid_ids] = np.arange(1, valid_ids.size + 1, dtype=np.int32)
        final_map = lut[instance_map]

        return {
            'instance_map': final_map,
            'num_nuclei': int(valid_ids.size)
        }

    except Exception as e:
        if DEBUG_MODE:
            print(f"CV segmentation error: {e}")
            traceback.print_exc()
        # Return empty result on error
        return _empty_result()


# HoverNet patch geometry per mode: (input window side, output window side)
# in pixels, as listed in the hover_net README (confirm for custom models).
# The network uses valid convolutions, so each input window yields a smaller
# output window centred on it.
_HOVERNET_PATCH_SHAPES = {'original': (270, 80), 'fast': (256, 164)}


def _hovernet_predict_tiled(image, batch_size=8):
    """Run the global HoverNet ``model`` over *image* tile by tile.

    The image is reflect-padded so the network's output windows tile it
    exactly. Each tile is fed at the patch size of the model's mode, and the
    outputs are stitched back to the full image size.

    Args:
        image: RGB image (H, W, 3) uint8.
        batch_size: number of tiles per forward pass.

    Returns:
        ``(np_prob, hv)`` or ``None``:
        - ``np_prob`` is the (H, W) foreground probability.
        - ``hv`` is the (H, W, C) raw HV output, or ``None`` when the model
          emits none.
        - ``None`` is returned instead of a tuple when the model output is not
          a dict containing an NP map.

    Raises:
        ValueError: if the network's output window differs from the expected
            size for its mode.
    """
    mode = getattr(model, 'mode', 'original')
    in_size, out_size = _HOVERNET_PATCH_SHAPES.get(mode, _HOVERNET_PATCH_SHAPES['original'])
    margin = (in_size - out_size) // 2

    height, width = image.shape[:2]
    n_rows = -(-height // out_size)  # ceil division
    n_cols = -(-width // out_size)
    # Pad so that tile (r, c) covering padded[r*out : r*out+in] produces the
    # output for image rows/cols [r*out, r*out+out).
    padded = np.pad(
        image.astype(np.float32),
        ((margin, n_rows * out_size - height + margin),
         (margin, n_cols * out_size - width + margin),
         (0, 0)),
        mode='reflect',
    )

    np_prob = np.zeros((n_rows * out_size, n_cols * out_size), dtype=np.float32)
    hv_map = None
    origins = [(r * out_size, c * out_size) for r in range(n_rows) for c in range(n_cols)]

    for start in range(0, len(origins), batch_size):
        chunk = origins[start:start + batch_size]
        tiles = np.stack([padded[y:y + in_size, x:x + in_size] for y, x in chunk])
        # Feed raw 0-255 values in NCHW. Upstream HoVerNet.forward() divides
        # by 255 itself, so pre-scaling would shrink the input range twice.
        batch = torch.from_numpy(tiles).permute(0, 3, 1, 2).contiguous().to(DEVICE)
        with torch.no_grad():
            output = model(batch)

        if not isinstance(output, dict):
            return None
        pred_np = output.get('np', output.get('NP'))
        pred_hv = output.get('hv', output.get('HV'))
        if pred_np is None:
            return None
        if tuple(pred_np.shape[-2:]) != (out_size, out_size):
            raise ValueError(
                f"Unexpected HoverNet output size {tuple(pred_np.shape[-2:])}; "
                f"expected {(out_size, out_size)} for mode '{mode}'"
            )

        # Channel 1 of the softmaxed NP head = nucleus probability.
        prob = torch.softmax(pred_np, dim=1)[:, 1].cpu().numpy()
        hv = pred_hv.permute(0, 2, 3, 1).cpu().numpy() if pred_hv is not None else None

        for k, (y, x) in enumerate(chunk):
            np_prob[y:y + out_size, x:x + out_size] = prob[k]
            if hv is not None:
                if hv_map is None:
                    hv_map = np.zeros(np_prob.shape + (hv.shape[-1],), dtype=np.float32)
                hv_map[y:y + out_size, x:x + out_size] = hv[k]

    np_prob = np_prob[:height, :width]
    if hv_map is not None:
        hv_map = hv_map[:height, :width]
    return np_prob, hv_map


def segment_nuclei_hovernet(image):
    """
    Segment nuclei using HoverNet model.

    Runs tiled inference so the instance map has the same height/width as
    *image*. Nuclei are then split with a marker-based watershed on an energy
    map. The energy is NP probability x (1 - |HV|), so it is high near
    nucleus centres when an HV map is available.

    Args:
        image: RGB image (H, W, 3) uint8

    Returns:
        dict with 'instance_map' ((H, W) int32, 0 = background) and
        'num_nuclei', or None if no model is loaded or inference failed.
    """
    if model is None:
        return None

    try:
        prediction = _hovernet_predict_tiled(image)
        if prediction is None:
            return None
        pred_np, pred_hv = prediction

        # Binary mask
        binary_mask = (pred_np > 0.5).astype(np.uint8)

        if pred_hv is not None:
            # HV magnitude ~0 at nucleus centres and ~1 at boundaries.
            grad_mag = np.sqrt(pred_hv[..., 0] ** 2 + pred_hv[..., 1] ** 2)
            energy = pred_np * (1 - np.clip(grad_mag, 0, 1))
        else:
            energy = pred_np

        # Find markers using local maxima
        from skimage.feature import peak_local_max
        from skimage.segmentation import watershed

        coordinates = peak_local_max(
            energy,
            min_distance=5,
            threshold_abs=0.1,
            labels=binary_mask
        )

        if len(coordinates) == 0:
            instance_map, _ = scipy_label(binary_mask)
        else:
            markers = np.zeros(pred_np.shape, dtype=np.int32)
            markers[coordinates[:, 0], coordinates[:, 1]] = np.arange(
                1, len(coordinates) + 1, dtype=np.int32
            )
            instance_map = watershed(-energy, markers, mask=binary_mask)

        instance_map = instance_map.astype(np.int32)
        # Count non-zero labels (correct even if no background pixel exists).
        num_nuclei = int(np.count_nonzero(np.unique(instance_map)))

        return {
            'instance_map': instance_map,
            'num_nuclei': num_nuclei
        }

    except Exception as e:
        if DEBUG_MODE:
            print(f"HoverNet inference error: {e}")
            traceback.print_exc()
        return None


def segment_nuclei(image):
    """
    Segment nuclei in one RGB image, preferring HoverNet.

    The HoverNet result is used when a model is loaded and its result
    contains at least one nucleus. Otherwise the classical CV pipeline is
    used: no model, HoverNet returning None, or zero detections.

    Args:
        image: RGB image (H, W, 3) uint8

    Returns:
        dict with 'instance_map' and 'num_nuclei'
    """
    hovernet_result = segment_nuclei_hovernet(image) if model is not None else None
    if hovernet_result is not None and hovernet_result['num_nuclei'] > 0:
        return hovernet_result
    return segment_nuclei_cv(image)


print("✓ Segmentation functions defined")

## 8. Test Segmentation on Single Image

In [None]:
# Test on a single image to verify segmentation works
def test_single_image():
    """Segment one sample image of the first class and plot the result.

    The image is the first ``.bmp`` by file name. The plot shows the
    original, the instance map and a green overlay side by side.

    Returns:
        True when at least one nucleus was detected. False when no image was
        found, the image could not be read, or nothing was detected.
    """
    test_class = CELL_CLASSES[0]
    test_dir = os.path.join(DATASET_BASE_PATH, test_class, IMAGE_SUBFOLDER)
    # Sorted so the same image is tested on every run (glob order is
    # filesystem-dependent); it is also the first image process_class handles.
    test_images = sorted(glob.glob(os.path.join(test_dir, '*.bmp')))

    if not test_images:
        print("❌ No test images found")
        return False

    test_path = test_images[0]
    print(f"Testing on: {os.path.basename(test_path)}")

    # Load image (OpenCV reads BGR; convert to RGB for the pipeline)
    img = cv2.imread(test_path)
    if img is None:
        print(f"❌ Failed to load image: {test_path}")
        return False

    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    print(f"Image shape: {img_rgb.shape}")

    # Run segmentation
    print("Running segmentation...")
    result = segment_nuclei(img_rgb)

    print(f"Nuclei detected: {result['num_nuclei']}")
    print(f"Instance map shape: {result['instance_map'].shape}")
    print(f"Unique labels: {np.unique(result['instance_map'])}")

    # Visualize
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    axes[0].imshow(img_rgb)
    axes[0].set_title('Original')
    axes[0].axis('off')

    axes[1].imshow(result['instance_map'], cmap='nipy_spectral')
    axes[1].set_title(f'Instance Map ({result["num_nuclei"]} nuclei)')
    axes[1].axis('off')

    # 50/50 blend of every nucleus pixel with pure green
    overlay = img_rgb.copy()
    mask = result['instance_map'] > 0
    overlay[mask] = overlay[mask] * 0.5 + np.array([0, 255, 0]) * 0.5
    axes[2].imshow(overlay.astype(np.uint8))
    axes[2].set_title('Overlay')
    axes[2].axis('off')

    plt.tight_layout()
    plt.show()

    return result['num_nuclei'] > 0


# Run the single-image smoke test (only when the dataset was found).
if not dataset_valid:
    print("❌ Cannot test - dataset not found")
else:
    success = test_single_image()
    print(
        "\n✓ Single image test PASSED"
        if success
        else "\n⚠ Single image test: No nuclei detected (may be normal for some images)"
    )

## 9. Image Processing and Saving

In [None]:
def load_image(image_path):
    """Read *image_path* from disk and return it as an RGB array.

    Raises:
        ValueError: if OpenCV cannot read the file.
    """
    bgr = cv2.imread(image_path)
    if bgr is not None:
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    raise ValueError(f"Failed to load: {image_path}")


def create_overlay(image, mask, alpha=0.4):
    """Blend a per-instance colour layer and yellow outlines onto *image*.

    Args:
        image: RGB image (H, W, 3) uint8.
        mask: instance label map (H, W). 0 is background; every other value
            is one instance.
        alpha: weight of the colour layer in the blend (0 = image only).

    Returns:
        A new RGB image; *image* itself is not modified. When *mask* has no
        instances, a plain copy of *image* is returned.

    Raises:
        ValueError: if *mask* and *image* differ in height/width.
    """
    mask = np.asarray(mask)
    if mask.shape[:2] != image.shape[:2]:
        raise ValueError(
            f"Mask shape {mask.shape[:2]} does not match image shape {image.shape[:2]}"
        )

    unique_ids = np.unique(mask)
    unique_ids = unique_ids[unique_ids > 0]

    if len(unique_ids) == 0:
        return image.copy()

    # Fixed seed gives the same palette on every call. The colours are drawn
    # from a private generator so the global NumPy RNG state is untouched;
    # this yields exactly the colours np.random.seed(42) + randint produced.
    colors = np.random.RandomState(42).randint(50, 255, size=(len(unique_ids), 3))

    # Paint every labelled pixel with its instance colour in one pass.
    # unique_ids is sorted, so searchsorted maps each label to its palette row.
    colored_mask = np.zeros_like(image)
    foreground = mask > 0
    colored_mask[foreground] = colors[np.searchsorted(unique_ids, mask[foreground])]

    # Blend
    overlay = cv2.addWeighted(image, 1 - alpha, colored_mask, alpha, 0)

    # Draw the outer contour of every instance in yellow (RGB)
    for uid in unique_ids:
        binary = (mask == uid).astype(np.uint8)
        contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        cv2.drawContours(overlay, contours, -1, (255, 255, 0), 1)

    return overlay


def _imwrite_checked(path, img):
    """Write *img* with OpenCV, raising instead of silently returning False."""
    if not cv2.imwrite(path, img):
        raise OSError(f"Failed to write image: {path}")


def save_results(output_dir, image_name, image, instance_map):
    """Save the original, label mask and overlay PNGs for one image.

    Args:
        output_dir: destination folder (created if missing).
        image_name: source file name; its stem prefixes the outputs.
        image: RGB image (H, W, 3) uint8.
        instance_map: (H, W) integer instance labels, 0 = background.

    Returns:
        dict with the 'original', 'mask' and 'overlay' file paths.

    Raises:
        ValueError: if labels fall outside the 16-bit range of the mask PNG.
        OSError: if OpenCV fails to write any of the files.
    """
    os.makedirs(output_dir, exist_ok=True)
    base_name = os.path.splitext(image_name)[0]

    instance_map = np.asarray(instance_map)
    # The mask is stored as a 16-bit PNG; labels outside [0, 65535] would wrap
    # around silently and corrupt the instance ids.
    if instance_map.size and (
        instance_map.min() < 0 or instance_map.max() > np.iinfo(np.uint16).max
    ):
        raise ValueError(
            f"Instance labels out of uint16 range for {image_name}: "
            f"[{instance_map.min()}, {instance_map.max()}]"
        )

    # Save original (converted back to OpenCV's BGR order)
    orig_path = os.path.join(output_dir, f"{base_name}_original.png")
    _imwrite_checked(orig_path, cv2.cvtColor(image, cv2.COLOR_RGB2BGR))

    # Save mask as a 16-bit label image
    mask_path = os.path.join(output_dir, f"{base_name}_mask.png")
    _imwrite_checked(mask_path, instance_map.astype(np.uint16))

    # Save overlay
    overlay = create_overlay(image, instance_map)
    overlay_path = os.path.join(output_dir, f"{base_name}_overlay.png")
    _imwrite_checked(overlay_path, cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))

    return {'original': orig_path, 'mask': mask_path, 'overlay': overlay_path}


print("✓ Image processing functions defined")

## 10. Batch Processing

In [None]:
def process_image(image_path, output_dir):
    """Load, segment and save one image without raising.

    Returns:
        dict with success / image_name / num_nuclei / saved_paths / error.
        Any exception is caught and reported through this dict (and printed
        with a traceback when DEBUG_MODE is set).
    """
    image_name = os.path.basename(image_path)

    try:
        image = load_image(image_path)
        segmentation = segment_nuclei(image)
        saved = save_results(output_dir, image_name, image, segmentation['instance_map'])
        nuclei_found = segmentation['num_nuclei']
    except Exception as exc:
        error_msg = str(exc)
        if DEBUG_MODE:
            print(f"\n❌ Error processing {image_name}: {error_msg}")
            traceback.print_exc()
        return {
            'success': False,
            'image_name': image_name,
            'num_nuclei': 0,
            'saved_paths': None,
            'error': error_msg
        }

    return {
        'success': True,
        'image_name': image_name,
        'num_nuclei': nuclei_found,
        'saved_paths': saved,
        'error': None
    }


def process_class(cell_class, max_images=None):
    """Segment the ``.bmp`` images of one class and summarise the run.

    Args:
        cell_class: class folder name under ``DATASET_BASE_PATH``.
        max_images: process at most this many images, taken in sorted
            file-name order. ``None`` means no limit; ``0`` processes nothing.

    Returns:
        dict with cell_class, total_images, successful, failed, total_nuclei
        and avg_nuclei_per_image. The average is over successful images only,
        and 0 when none succeeded.
    """
    input_dir = os.path.join(DATASET_BASE_PATH, cell_class, IMAGE_SUBFOLDER)
    output_dir = os.path.join(OUTPUT_BASE_PATH, cell_class)

    images = sorted(glob.glob(os.path.join(input_dir, '*.bmp')))
    # Explicit None check: with a plain truthiness test, max_images=0 fell
    # through to "no limit" and processed every image.
    if max_images is not None:
        images = images[:max_images]

    print(f"\n{'='*60}")
    print(f"Processing: {cell_class}")
    print(f"Images: {len(images)}")
    print(f"{'='*60}")

    results = []
    errors = []

    for img_path in tqdm(images, desc=cell_class):
        result = process_image(img_path, output_dir)
        results.append(result)
        if not result['success']:
            errors.append(result)

    successful = sum(1 for r in results if r['success'])
    total_nuclei = sum(r['num_nuclei'] for r in results if r['success'])
    avg_nuclei = total_nuclei / successful if successful > 0 else 0

    print(f"\nSummary: {successful}/{len(images)} successful")
    print(f"Total nuclei: {total_nuclei}, Avg: {avg_nuclei:.1f}")

    if errors and DEBUG_MODE:
        print(f"\nErrors ({len(errors)}):")
        for err in errors[:5]:  # Show first 5 errors
            print(f"  - {err['image_name']}: {err['error']}")

    return {
        'cell_class': cell_class,
        'total_images': len(images),
        'successful': successful,
        'failed': len(images) - successful,
        'total_nuclei': total_nuclei,
        'avg_nuclei_per_image': avg_nuclei
    }


def process_all(max_images_per_class=None):
    """Run process_class for every entry of CELL_CLASSES and save a summary.

    Args:
        max_images_per_class: per-class image limit, forwarded to process_class.

    Returns:
        dict with processing time, overall image/success/failure/nuclei totals
        and the per-class stats under 'class_statistics'. The same dict is
        written to ``OUTPUT_BASE_PATH/processing_statistics.json``.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("STARTING BATCH PROCESSING")
    print(rule)

    started = datetime.now()
    all_stats = {
        cell_class: process_class(cell_class, max_images_per_class)
        for cell_class in CELL_CLASSES
    }
    elapsed = (datetime.now() - started).total_seconds()

    def _total(field):
        return sum(class_stats[field] for class_stats in all_stats.values())

    total_images = _total('total_images')
    total_successful = _total('successful')
    total_nuclei = _total('total_nuclei')

    overall = {
        'processing_time_seconds': elapsed,
        'total_images': total_images,
        'total_successful': total_successful,
        'total_failed': total_images - total_successful,
        'total_nuclei_detected': total_nuclei,
        'class_statistics': all_stats
    }

    stats_file = os.path.join(OUTPUT_BASE_PATH, 'processing_statistics.json')
    with open(stats_file, 'w') as f:
        json.dump(overall, f, indent=2)

    print("\n" + rule)
    print("PROCESSING COMPLETE")
    print(rule)
    print(f"Total: {total_images} images, {total_successful} successful")
    print(f"Nuclei detected: {total_nuclei}")
    print(f"Time: {elapsed:.1f}s ({elapsed/60:.1f} min)")
    print(f"Stats saved: {stats_file}")
    print(rule)

    return overall


print("✓ Batch processing functions defined")

## 11. Run Processing

### Test Run (5 images)

In [None]:
# Smoke test: the first 5 images (sorted by file name) of the first class.
if not dataset_valid:
    print("❌ Cannot run - check configuration")
else:
    print("Running test with 5 images...\n")
    test_stats = process_class(CELL_CLASSES[0], max_images=5)
    _smoke_passed = test_stats['successful'] > 0
    print("\n✓ Test PASSED!" if _smoke_passed else "\n❌ Test FAILED - Check errors above")

### Visualize Test Results

In [None]:
# Visualize results from test
def visualize_samples(cell_class, n=3):
    """Show up to *n* randomly chosen saved overlay images for *cell_class*.

    Files that OpenCV cannot read are shown as an empty panel labelled with
    the file name, instead of aborting the whole figure. Nothing is drawn
    when *n* is not positive or no overlays exist.
    """
    output_dir = os.path.join(OUTPUT_BASE_PATH, cell_class)
    overlays = sorted(glob.glob(os.path.join(output_dir, '*_overlay.png')))

    if not overlays:
        print(f"No results for {cell_class}")
        return
    if n <= 0:
        return

    import random
    samples = random.sample(overlays, min(n, len(overlays)))

    fig, axes = plt.subplots(1, len(samples), figsize=(5 * len(samples), 5))
    if len(samples) == 1:
        axes = [axes]

    title = cell_class.replace('im_', '')
    for ax, path in zip(axes, samples):
        img = cv2.imread(path)
        if img is None:
            # cv2.imread returns None for unreadable files; cvtColor would raise.
            ax.set_title(f"{title} (unreadable: {os.path.basename(path)})")
        else:
            ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            ax.set_title(title)
        ax.axis('off')

    plt.tight_layout()
    plt.show()


# Show test results
if 'test_stats' in locals() and test_stats['successful'] > 0:
    visualize_samples(CELL_CLASSES[0], n=3)

### Process All Classes (50 images each)

In [None]:
# Batch run limited to the first 50 images (by file name) of every class.
if not dataset_valid:
    print("❌ Cannot run - check configuration")
else:
    overall_stats = process_all(max_images_per_class=50)

### Process All Images (Full Dataset)

In [None]:
# # Uncomment to process ALL images (may take a while)
# if dataset_valid:
#     overall_stats = process_all()
# else:
#     print("❌ Cannot run - check configuration")

## 12. Visualize All Results

In [None]:
# Visualize samples from each class, once a batch run has produced overall_stats.
if 'overall_stats' in globals():
    for class_name in CELL_CLASSES:
        print(f"\n{class_name}:")
        visualize_samples(class_name, n=3)

In [None]:
# Plot statistics
def plot_stats(stats):
    """Draw per-class bar charts of total and average nuclei counts.

    Args:
        stats: summary dict as returned by ``process_all``; only its
            ``'class_statistics'`` entry is read.
    """
    per_class = stats['class_statistics']
    names = list(per_class)
    tick_labels = [name.replace('im_', '') for name in names]
    positions = range(len(names))

    fig, (ax_total, ax_avg) = plt.subplots(1, 2, figsize=(12, 5))
    panels = (
        (ax_total, 'total_nuclei', 'steelblue', 'Total Nuclei', 'Total Nuclei per Class'),
        (ax_avg, 'avg_nuclei_per_image', 'coral', 'Avg Nuclei per Image', 'Average Nuclei per Image'),
    )
    for ax, field, color, y_label, panel_title in panels:
        ax.bar(positions, [per_class[name][field] for name in names], color=color)
        ax.set_xticks(positions)
        ax.set_xticklabels(tick_labels, rotation=45, ha='right')
        ax.set_ylabel(y_label)
        ax.set_title(panel_title)

    plt.tight_layout()
    plt.show()


if 'overall_stats' in globals():
    plot_stats(overall_stats)

## 13. Cleanup

In [None]:
# Remove the scratch directory (TEMP_DIR) and everything inside it.
_temp_dir_present = os.path.exists(TEMP_DIR)
if _temp_dir_present:
    shutil.rmtree(TEMP_DIR)
    print(f"✓ Cleaned: {TEMP_DIR}")

## Notes

### Output Structure
```
HoverNet_Segmentation_Results/
├── im_Dyskeratotic/
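│   ├── image_original.png   # input image re-saved as PNG
│   ├── image_mask.png       # 16-bit instance label map (0 = background)
│   └── image_overlay.png    # per-nucleus colour overlay with outlines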
├── im_Koilocytotic/
├── im_Metaplastic/
├── im_Parabasal/
├── im_Superficial-Intermediate/
└── processing_statistics.json
```

### Troubleshooting
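- `DEBUG_MODE` (enabled by default in the configuration cell) prints detailed error messages and tracebacks; set it to `False` for quieter batch runs
- CV-based segmentation is used when the HoverNet model fails to load, when HoverNet inference fails, or when HoverNet detects no nuclei in an image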
- Zero nuclei detected may be normal for some images

### References
- [HoverNet Paper](https://arxiv.org/abs/1812.06499)
- [HoverNet GitHub](https://github.com/vqdang/hover_net)