# Nuclear Segmentation using HoverNet

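This notebook segments cell nuclei in cervical cytology images, as a step in cervical cancer cell classification, using the HoverNet model. If HoverNet cannot be loaded or run, it falls back to a classical OpenCV pipeline (Otsu thresholding + watershed).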

**Dataset Structure:**
- Base Directory: `Augmented Dataset - Limited Enhancement`
- Classes: `im_Dyskeratotic`, `im_Koilocytotic`, `im_Metaplastic`, `im_Parabasal`, `im_Superficial-Intermediate`
- Images Location: `<class_folder>/NLM_CLAHE/*.bmp`

**Reference:** [HoverNet GitHub](https://github.com/vqdang/hover_net)

## 1. Environment Setup

In [None]:
# Check if running on Colab.
# Only an ImportError means "google.colab is not installed". A bare `except:`
# would also swallow KeyboardInterrupt/SystemExit and hide genuine errors.
try:
    import google.colab  # noqa: F401 -- imported only to probe availability
    IN_COLAB = True
    print("✓ Running on Google Colab")
except ImportError:
    IN_COLAB = False
    print("ℹ Not running on Google Colab")

In [None]:
# Install required dependencies (Colab only).
# The `!` lines are IPython shell escapes, so this cell only runs inside a
# notebook kernel. Outside Colab nothing is installed; presumably the local
# environment already provides these packages -- confirm before running.
# NOTE(review): `!` commands do not raise when pip fails, so the success
# message below is printed regardless -- check the pip output.
if IN_COLAB:
    print("Installing dependencies...\n")
    
    # Core dependencies
    !pip install -q opencv-python-headless
    !pip install -q scikit-image
    !pip install -q scipy
    !pip install -q matplotlib
    !pip install -q tqdm
    !pip install -q pillow
    
    # PyTorch (HoverNet uses PyTorch)
    !pip install -q torch torchvision
    
    # Additional dependencies for HoverNet
    !pip install -q imgaug
    !pip install -q termcolor
    
    print("\n✓ Dependencies installed successfully")
else:
    print("Skipping dependency installation (not on Colab)")

In [None]:
# Clone the HoverNet repository (Colab only) and make it importable.
# On Colab the repository root (HOVERNET_DIR) is inserted at the front of
# sys.path. Outside Colab only HOVERNET_DIR is set ('./hover_net'): nothing is
# cloned and sys.path is left unchanged -- TODO confirm a local checkout is
# importable in that case.
import os
import sys

if IN_COLAB:
    HOVERNET_DIR = '/content/hover_net'
    
    # Clone only once; re-running the cell reuses the existing checkout.
    if not os.path.exists(HOVERNET_DIR):
        print("Cloning HoverNet repository...")
        # IPython substitutes {HOVERNET_DIR} into the shell command.
        !git clone https://github.com/vqdang/hover_net.git {HOVERNET_DIR}
        print("✓ HoverNet repository cloned")
    else:
        print("✓ HoverNet repository already exists")
    
    # Add HoverNet to Python path (guarded so re-runs don't add duplicates)
    if HOVERNET_DIR not in sys.path:
        sys.path.insert(0, HOVERNET_DIR)
    print("✓ HoverNet added to Python path")
else:
    HOVERNET_DIR = './hover_net'
    print("Skipping HoverNet clone (not on Colab)")

## 2. Mount Google Drive

In [None]:
# Mount Google Drive so the dataset, weights and output folders are reachable.
# Only possible inside Colab; elsewhere the cell just reports that it skipped.
if not IN_COLAB:
    print("Skipping Google Drive mount (not on Colab)")
else:
    from google.colab import drive
    drive.mount('/content/drive')
    print("\n✓ Google Drive mounted successfully")

## 3. Import Libraries

In [None]:
import os
import sys
import numpy as np
import cv2
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm.auto import tqdm
import json
from datetime import datetime
import glob
import shutil
import warnings
warnings.filterwarnings('ignore')

# PyTorch imports
import torch
import torch.nn as nn
import torch.nn.functional as F

# Report library versions and choose the compute device used by every later
# cell (the first CUDA GPU when one is visible, otherwise the CPU).
print(f"✓ PyTorch version: {torch.__version__}")
print(f"✓ CUDA available: {torch.cuda.is_available()}")
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if DEVICE.type == 'cuda':
    print(f"✓ CUDA device: {torch.cuda.get_device_name(0)}")
print(f"✓ Using device: {DEVICE}")

print(f"✓ OpenCV version: {cv2.__version__}")
print(f"✓ NumPy version: {np.__version__}")
print("\n✓ All libraries imported successfully")

## 4. Configuration

In [None]:
# ============================================================================
# CONFIGURATION - UPDATE THESE PATHS AS NEEDED
# ============================================================================

# Root of the mounted Google Drive; every other Drive path is built from it.
DRIVE_BASE_PATH = '/content/drive/MyDrive'

# Folder that holds the HoverNet .tar checkpoint. The folder names (including
# the 'Worflows' spelling) must match the Drive folders exactly -- do not
# "correct" them here.
WEIGHTS_DIR = (
    f"{DRIVE_BASE_PATH}/Projects/6_Project Phoenix_Cervical Cancer Cell Classification"
    "/Explainability Worflows/Nucleus Masking/Hovernet Weights"
)

# Dataset layout: <DATASET_BASE_PATH>/<class>/<IMAGE_SUBFOLDER>/*.bmp
DATASET_NAME = 'Augmented Dataset - Limited Enhancement'
DATASET_BASE_PATH = os.path.join(DRIVE_BASE_PATH, DATASET_NAME)

# One sub-folder per cell class under DATASET_BASE_PATH
CELL_CLASSES = [
    'im_Dyskeratotic',
    'im_Koilocytotic',
    'im_Metaplastic',
    'im_Parabasal',
    'im_Superficial-Intermediate'
]

# Per-class subfolder containing the images to segment
IMAGE_SUBFOLDER = 'NLM_CLAHE'

# Segmentation outputs go here, one sub-folder per class
OUTPUT_BASE_PATH = os.path.join(DRIVE_BASE_PATH, 'HoverNet_Segmentation_Results')
os.makedirs(OUTPUT_BASE_PATH, exist_ok=True)

# Scratch space (deleted again by the cleanup cell)
TEMP_DIR = '/content/temp_hovernet'
os.makedirs(TEMP_DIR, exist_ok=True)

# Echo the effective configuration as one block of text.
_config_lines = [
    "="*70,
    "CONFIGURATION",
    "="*70,
    f"Weights directory: {WEIGHTS_DIR}",
    f"Dataset base path: {DATASET_BASE_PATH}",
    f"Output base path: {OUTPUT_BASE_PATH}",
    f"HoverNet directory: {HOVERNET_DIR}",
    f"\nCell classes to process: {len(CELL_CLASSES)}",
    *(f"  {number}. {name}" for number, name in enumerate(CELL_CLASSES, 1)),
    "="*70,
]
print("\n".join(_config_lines))

## 5. Verify Weights and Dataset

In [None]:
def find_weight_file(weights_dir=None):
    """
    Locate the HoverNet checkpoint (``.tar``) inside the weights directory.

    Args:
        weights_dir: Directory to search. Defaults to the notebook-level
            ``WEIGHTS_DIR`` configuration value.

    Returns:
        Full path of the chosen ``.tar`` file, or None when the directory is
        missing or holds no ``.tar`` file. Entries are sorted, so when several
        checkpoints are present the alphabetically first one is used. The
        choice is reported instead of depending on ``os.listdir`` order.
    """
    if weights_dir is None:
        weights_dir = WEIGHTS_DIR

    print("Searching for HoverNet weights...\n")

    # isdir (not exists): a plain file at this path would make listdir fail.
    if not os.path.isdir(weights_dir):
        print(f"❌ ERROR: Weights directory not found: {weights_dir}")
        return None

    entries = sorted(os.listdir(weights_dir))
    print("Files in weights directory:")
    for name in entries:
        path = os.path.join(weights_dir, name)
        if os.path.isfile(path):
            size_mb = os.path.getsize(path) / (1024 * 1024)
            print(f"  - {name} ({size_mb:.1f} MB)")
        else:
            print(f"  - {name}/ (directory)")

    # Only regular files can be checkpoints; the extension check ignores case.
    tar_files = [
        name for name in entries
        if name.lower().endswith('.tar') and os.path.isfile(os.path.join(weights_dir, name))
    ]

    if not tar_files:
        print("\n❌ ERROR: No .tar weight files found!")
        print("Please download HoverNet weights from:")
        print("https://github.com/vqdang/hover_net#data-format")
        return None

    if len(tar_files) > 1:
        print(f"\nℹ Multiple .tar files found, using the first alphabetically: {tar_files}")

    weight_file = os.path.join(weights_dir, tar_files[0])
    print(f"\n✓ Found weight file: {tar_files[0]}")

    return weight_file


def verify_dataset():
    """
    Check that the dataset root exists and count ``.bmp`` images per class.

    Looks in ``DATASET_BASE_PATH/<class>/IMAGE_SUBFOLDER`` for every class in
    CELL_CLASSES and prints a per-class report.

    Returns:
        (is_valid, class_counts). ``is_valid`` is True when at least one image
        was found. ``class_counts`` maps each class to its image count (0 when
        its folder is missing). Returns ``(False, {})`` if the dataset root
        itself does not exist.
    """
    print("\n" + "="*70)
    print("Verifying dataset...")
    print("="*70 + "\n")

    if not os.path.exists(DATASET_BASE_PATH):
        print(f"❌ ERROR: Dataset not found: {DATASET_BASE_PATH}")
        return False, {}

    counts = {}
    grand_total = 0
    for class_name in CELL_CLASSES:
        class_dir = os.path.join(DATASET_BASE_PATH, class_name, IMAGE_SUBFOLDER)
        if os.path.exists(class_dir):
            n_images = len(glob.glob(os.path.join(class_dir, '*.bmp')))
            counts[class_name] = n_images
            grand_total += n_images
            print(f"✓ {class_name:40s}: {n_images:5d} images")
        else:
            print(f"❌ {class_name}: Directory not found")
            counts[class_name] = 0

    print(f"\n{'='*70}")
    print(f"Total images: {grand_total}")
    print(f"{'='*70}")

    return grand_total > 0, counts


# Run both checks once; later cells are gated on WEIGHT_FILE and dataset_valid.
WEIGHT_FILE = find_weight_file()
dataset_valid, image_counts = verify_dataset()

verification_message = (
    "✓ ALL VERIFICATIONS PASSED"
    if WEIGHT_FILE and dataset_valid
    else "❌ VERIFICATION FAILED - Please fix errors above"
)
print("\n" + "="*70)
print(verification_message)
print("="*70)

## 6. Load HoverNet Model

In [None]:
# Import HoverNet model components.
# The clone cell puts the repository root (HOVERNET_DIR) itself on sys.path, so
# the repo's top-level `models` package is what becomes importable. The
# `hover_net.models...` spelling only resolves when the repo's *parent*
# directory is importable, so it is kept as a second attempt, not the only one.
try:
    try:
        from models.hovernet.net_desc import create_model
        from models.hovernet.post_proc import process
    except ImportError:
        from hover_net.models.hovernet.net_desc import create_model
        from hover_net.models.hovernet.post_proc import process
    print("✓ HoverNet modules imported successfully")
    HOVERNET_MODULES_AVAILABLE = True
except ImportError as e:
    print(f"⚠ Could not import HoverNet modules: {e}")
    print("Will use alternative inference method")
    HOVERNET_MODULES_AVAILABLE = False

In [None]:
def _extract_state_dict(checkpoint):
    """Return the parameter mapping stored inside a loaded HoverNet checkpoint."""
    # HoverNet's released checkpoints keep the network weights under 'desc'
    # (per vqdang/hover_net's inference code, which loads
    # torch.load(path)['desc'] -- verify for your file). Common
    # 'state_dict' / 'model_state_dict' wrappers and bare state dicts are
    # accepted too.
    if not isinstance(checkpoint, dict):
        raise TypeError(
            f"Unsupported checkpoint type {type(checkpoint).__name__}; "
            "expected a dict-like state dict"
        )
    for key in ('desc', 'state_dict', 'model_state_dict'):
        candidate = checkpoint.get(key)
        if isinstance(candidate, dict):
            return candidate
    return checkpoint


def _strip_data_parallel_prefix(state_dict):
    """Drop the 'module.' prefix that nn.DataParallel adds, when every key has it."""
    prefix = 'module.'
    if state_dict and all(key.startswith(prefix) for key in state_dict):
        return {key[len(prefix):]: value for key, value in state_dict.items()}
    return state_dict


def _infer_nr_types(state_dict):
    """Return the output-channel count of the type ('tp') head, or 0 if absent."""
    for key, value in state_dict.items():
        if key.endswith('tp.u0.conv.weight'):
            return int(value.shape[0])
    return 0


def load_hovernet_model(weight_path, mode='fast'):
    """
    Load a HoverNet network from a ``.tar`` checkpoint file.

    Args:
        weight_path: Path to the ``.tar`` weight file.
        mode: HoverNet variant passed to ``create_model`` (``'fast'`` or
            ``'original'``). It must match the architecture the checkpoint was
            trained with; a mismatch presumably shows up as a tensor-size error
            while loading (caught below, so the classical fallback is used).
            NOTE(review): the HoverNet README lists the PanNuke/MoNuSAC
            checkpoints as 'fast' and CoNSeP/Kumar/CPM17 as 'original' --
            confirm for your file.

    Returns:
        (model, model_config).
        - On success, ``model`` is the network in eval mode on DEVICE and
          ``model_config`` is ``{'nr_types': ..., 'mode': ...}``.
        - If the HoverNet modules are unavailable, or building/loading the
          network fails, ``model`` is None (inference then falls back to
          classical segmentation). ``model_config`` then holds ``'nr_types'``
          and the extracted ``'state_dict'``.

    Raises:
        TypeError: if the checkpoint does not contain a dict-like state dict.
    """
    print(f"\nLoading HoverNet model from: {weight_path}")
    print("This may take a moment...\n")

    checkpoint = torch.load(weight_path, map_location=DEVICE)
    if isinstance(checkpoint, dict):
        print(f"Checkpoint keys: {list(checkpoint.keys())}")

    # Unwrap the real weights. Printing them would dump every tensor, so only
    # their count is reported.
    state_dict = _strip_data_parallel_prefix(_extract_state_dict(checkpoint))
    print(f"Parameter tensors in checkpoint: {len(state_dict)}")

    nr_types = _infer_nr_types(state_dict)
    if nr_types:
        print(f"Detected nr_types from weights: {nr_types}")

    print("\nModel configuration:")
    print(f"  - Mode: {mode}")
    print(f"  - Number of types: {nr_types}")
    print(f"  - Device: {DEVICE}")

    if HOVERNET_MODULES_AVAILABLE:
        try:
            model = create_model(
                mode=mode,
                nr_types=nr_types if nr_types > 0 else None
            )

            # strict=False tolerates harmless extras, but the outcome is checked
            # so that a checkpoint matching no parameter names is not silently
            # accepted as a randomly initialised network.
            load_result = model.load_state_dict(state_dict, strict=False)
            n_model_tensors = len(model.state_dict())
            n_missing = len(load_result.missing_keys)
            if n_missing == n_model_tensors:
                raise RuntimeError(
                    "none of the checkpoint tensors match the model's parameter names"
                )
            if n_missing:
                print(f"⚠ {n_missing} model tensors missing from checkpoint, "
                      f"e.g. {load_result.missing_keys[:3]}")
            if load_result.unexpected_keys:
                print(f"⚠ {len(load_result.unexpected_keys)} checkpoint tensors unused, "
                      f"e.g. {load_result.unexpected_keys[:3]}")

            model = model.to(DEVICE)
            model.eval()

            print("\n✓ HoverNet model loaded successfully!")
            return model, {'nr_types': nr_types, 'mode': mode}

        except Exception as e:
            print(f"⚠ Error creating model: {e}")
            print("Will use alternative method")

    # No usable network: return None so callers skip straight to the fallback.
    return None, {'nr_types': nr_types, 'state_dict': state_dict}


# Build the model only when find_weight_file() located a checkpoint.
if not WEIGHT_FILE:
    model = None
    model_config = None
    print("❌ Cannot load model - weight file not found")
else:
    model, model_config = load_hovernet_model(WEIGHT_FILE)

## 7. Inference Functions

In [None]:
from scipy.ndimage import label as scipy_label
from scipy.ndimage import binary_fill_holes
from skimage.segmentation import watershed
from skimage.feature import peak_local_max
from scipy import ndimage

def preprocess_image(image, device=None):
    """
    Convert an RGB uint8 image into a float32 NCHW tensor for HoverNet.

    Pixel values are deliberately kept in the 0-255 range. HoVerNet.forward()
    in vqdang/hover_net (models/hovernet/net_desc.py) divides its input by 255
    itself -- verify against the cloned revision. Normalising here as well
    would feed the network values in [0, 1/255].

    Args:
        image: RGB image (H, W, 3), typically uint8.
        device: Target torch device; defaults to the notebook's DEVICE.

    Returns:
        Tensor of shape (1, 3, H, W), dtype float32, on ``device``.
    """
    target = DEVICE if device is None else device
    pixels = np.ascontiguousarray(image, dtype=np.float32)
    img_tensor = torch.from_numpy(pixels).permute(2, 0, 1).unsqueeze(0).contiguous()
    return img_tensor.to(target)


def post_process_nuclei(pred_np, pred_hv, np_threshold=0.5, min_distance=5,
                        peak_threshold=0.1, sigma=1):
    """
    Turn HoverNet pixel predictions into an instance segmentation map.

    Args:
        pred_np: Nuclear-pixel probability map (H, W).
        pred_hv: Horizontal/vertical map (H, W, 2).
        np_threshold: Probability above which a pixel counts as nucleus.
        min_distance: Minimum distance between seed peaks (peak_local_max).
        peak_threshold: Minimum smoothed energy for a seed peak.
        sigma: Gaussian smoothing applied to the energy map.

    Returns:
        int32 instance map (H, W): 0 = background, 1..N = nuclei.
    """
    foreground = pred_np > np_threshold

    if not np.any(foreground):
        return np.zeros_like(pred_np, dtype=np.int32)

    # Assumes HoverNet's HV convention: values encode each pixel's offset from
    # its nucleus centre, so |hv| is small at centres and grows toward the
    # boundary. The energy below therefore peaks near nucleus centres.
    hv_magnitude = np.sqrt(pred_hv[..., 0] ** 2 + pred_hv[..., 1] ** 2)
    energy = ndimage.gaussian_filter(pred_np * (1 - hv_magnitude), sigma=sigma)

    peaks = peak_local_max(
        energy,
        min_distance=min_distance,
        threshold_abs=peak_threshold,
        labels=foreground.astype(np.int32)
    )

    if len(peaks) == 0:
        # No seeds at all: every connected foreground blob is one nucleus.
        instance_map, _ = scipy_label(foreground)
        return instance_map.astype(np.int32)

    markers = np.zeros(pred_np.shape, dtype=np.int32)
    markers[peaks[:, 0], peaks[:, 1]] = np.arange(1, len(peaks) + 1, dtype=np.int32)

    instance_map = watershed(-energy, markers, mask=foreground).astype(np.int32)

    # The watershed leaves foreground blobs without a seed at 0. Give each one
    # its own ID (as the no-peak branch does) instead of silently dropping it.
    orphans, n_orphans = scipy_label(foreground & (instance_map == 0))
    if n_orphans:
        orphan_pixels = orphans > 0
        instance_map[orphan_pixels] = orphans[orphan_pixels] + instance_map.max()

    return instance_map


def _center_fit(array, height, width):
    """Centre-crop and/or zero-pad the first two axes of ``array`` to (height, width)."""
    fitted = np.zeros((height, width) + array.shape[2:], dtype=array.dtype)
    src_h, src_w = array.shape[:2]
    copy_h, copy_w = min(src_h, height), min(src_w, width)
    src_top, src_left = (src_h - copy_h) // 2, (src_w - copy_w) // 2
    dst_top, dst_left = (height - copy_h) // 2, (width - copy_w) // 2
    source = array[src_top:src_top + copy_h, src_left:src_left + copy_w]
    fitted[dst_top:dst_top + copy_h, dst_left:dst_left + copy_w] = source
    return fitted


def run_hovernet_inference_direct(model, image):
    """
    Run HoverNet with a single, untiled forward pass over the whole image.

    HoverNet crops its decoder output, so the predicted maps can be smaller
    than the input. The reference repo pairs 256x256 input patches with
    164x164 outputs in 'fast' mode -- verify for your weights. The maps are
    re-aligned to the image grid by centring them on a zero canvas, which
    leaves a border band without predictions. HoverNet's own tiled inference
    (run_infer.py) is the reference for full coverage.

    Args:
        model: Loaded HoverNet model (in eval mode).
        image: RGB image (H, W, 3) uint8.

    Returns:
        dict with
        - 'instance_map': (H, W) int32, aligned with ``image``;
        - 'num_nuclei': number of non-zero instance IDs;
        - 'pred_np': (H, W) nucleus probability.
        Returns None if anything fails (e.g. an image size the network cannot
        handle); the caller then falls back to classical segmentation.
    """
    try:
        img_tensor = preprocess_image(image)

        with torch.no_grad():
            output = model(img_tensor)

        # Accept dict, sequence or bare-tensor outputs.
        if isinstance(output, dict):
            pred_np = output.get('np', output.get('NP', None))
            pred_hv = output.get('hv', output.get('HV', None))
        elif isinstance(output, (list, tuple)):
            pred_np = output[0]
            pred_hv = output[1] if len(output) > 1 else None
        else:
            pred_np = output
            pred_hv = None

        if pred_np is None:
            raise ValueError("model output contains no nuclear-pixel ('np') map")

        # Channel 1 of the softmaxed NP head is taken as the nucleus probability.
        prob_np = torch.softmax(pred_np, dim=1)[0, 1].float().cpu().numpy()

        if pred_hv is not None:
            hv = pred_hv[0].permute(1, 2, 0).float().cpu().numpy()
        else:
            hv = np.zeros(prob_np.shape + (2,), dtype=np.float32)

        # Align predictions with the input so downstream saving/overlay code
        # receives a map of the same size as the image.
        height, width = image.shape[:2]
        if prob_np.shape != (height, width):
            prob_np = _center_fit(prob_np, height, width)
        if hv.shape[:2] != (height, width):
            hv = _center_fit(hv, height, width)

        instance_map = post_process_nuclei(prob_np, hv)
        # Count non-zero IDs: "unique - 1" would be wrong when no pixel is background.
        num_nuclei = int(np.count_nonzero(np.unique(instance_map)))

        return {
            'instance_map': instance_map,
            'num_nuclei': num_nuclei,
            'pred_np': prob_np
        }

    except Exception as e:
        print(f"⚠ Direct inference failed: {e}")
        return None


def fallback_segmentation(image, min_area=50, max_area=5000):
    """
    Classical nuclear segmentation (CLAHE + Otsu + distance-transform watershed).

    Used when HoverNet is unavailable or fails on an image.

    Args:
        image: RGB image (H, W, 3) uint8.
        min_area: Smallest region area (pixels) kept as a nucleus.
        max_area: Largest region area (pixels) kept as a nucleus.

    Returns:
        dict with 'instance_map' (H, W) int32 with consecutive IDs 1..N
        (0 = background), 'num_nuclei' (N) and 'pred_np' (None).
    """
    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)

    # Local contrast enhancement before thresholding
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    enhanced = clahe.apply(gray)

    # Inverted Otsu threshold: dark pixels become foreground (255)
    _, binary = cv2.threshold(enhanced, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)

    # Remove specks, then fill small gaps
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    binary = cv2.morphologyEx(binary, cv2.MORPH_OPEN, kernel, iterations=2)
    binary = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel, iterations=2)

    # Sure foreground: pixels far from the blob edges
    dist_transform = cv2.distanceTransform(binary, cv2.DIST_L2, 5)
    _, sure_fg = cv2.threshold(dist_transform, 0.3 * dist_transform.max(), 255, 0)
    sure_fg = sure_fg.astype(np.uint8)

    # Unknown band between sure background and sure foreground
    sure_bg = cv2.dilate(binary, kernel, iterations=3)
    unknown = cv2.subtract(sure_bg, sure_fg)

    # Markers: 1 = sure background, 2.. = sure-foreground seeds, 0 = unknown
    _, markers = cv2.connectedComponents(sure_fg)
    markers = markers + 1
    markers[unknown == 255] = 0

    markers = cv2.watershed(image, markers)

    # After watershed: -1 = ridge pixels, 1 = background basin, >= 2 = nuclei.
    # The background basin must be dropped here too. Otherwise it survives the
    # relabel as "nucleus 1" and covers the whole background.
    labels = np.where(markers > 1, markers, 0).astype(np.int32)

    # Size filter and consecutive relabel via a lookup table (one pass each).
    areas = np.bincount(labels.ravel())
    keep = (areas >= min_area) & (areas <= max_area)
    keep[0] = False
    kept_ids = np.flatnonzero(keep)
    lookup = np.zeros(areas.size, dtype=np.int32)
    lookup[kept_ids] = np.arange(1, kept_ids.size + 1, dtype=np.int32)
    instance_map = lookup[labels]

    return {
        'instance_map': instance_map,
        'num_nuclei': int(kept_ids.size),
        'pred_np': None
    }


def run_inference(image, use_model=True):
    """
    Segment nuclei, preferring HoverNet and falling back to classical CV.

    HoverNet is tried only when ``use_model`` is true, a model is loaded and
    the HoverNet modules imported. If that path is skipped or returns None,
    fallback_segmentation() is used.

    Args:
        image: RGB image (H, W, 3) uint8.
        use_model: Set False to force the classical pipeline.

    Returns:
        dict with 'instance_map', 'num_nuclei' and 'pred_np'.
    """
    hovernet_ready = use_model and model is not None and HOVERNET_MODULES_AVAILABLE
    result = run_hovernet_inference_direct(model, image) if hovernet_ready else None
    return result if result is not None else fallback_segmentation(image)


print("✓ Inference functions defined")

## 8. Image Processing and Saving

In [None]:
def load_image(image_path):
    """
    Read an image file with OpenCV and return it in RGB channel order.

    Raises:
        ValueError: if OpenCV cannot read the file.
    """
    bgr = cv2.imread(image_path)
    if bgr is not None:
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    raise ValueError(f"Failed to load: {image_path}")


def create_overlay(image, mask, alpha=0.4, seed=42):
    """
    Blend a coloured instance map onto an image and outline every instance.

    Args:
        image: RGB image (H, W, 3) uint8.
        mask: Instance map (H, W); 0 = background, each positive ID one instance.
        alpha: Weight of the colour layer in the blend.
        seed: Seed for the per-instance colours (same seed -> same colours).

    Returns:
        RGB uint8 overlay with yellow contours around each instance.

    Raises:
        ValueError: if ``mask`` and ``image`` differ in height/width.
    """
    if mask.shape[:2] != image.shape[:2]:
        raise ValueError(
            f"mask shape {mask.shape[:2]} does not match image shape {image.shape[:2]}"
        )

    instance_ids = np.unique(mask)
    instance_ids = instance_ids[instance_ids > 0]

    # Local generator: reproducible colours without reseeding NumPy's global RNG
    # (np.random.seed would silently reset randomness for the whole process).
    rng = np.random.default_rng(seed)
    colors = rng.integers(50, 255, size=(len(instance_ids), 3))

    colored_mask = np.zeros_like(image)
    for color, uid in zip(colors, instance_ids):
        colored_mask[mask == uid] = color

    overlay = cv2.addWeighted(image, 1 - alpha, colored_mask, alpha, 0)

    for uid in instance_ids:
        # [-2] picks the contour list under both OpenCV 3 (3 returns) and 4 (2 returns).
        contours = cv2.findContours(
            (mask == uid).astype(np.uint8),
            cv2.RETR_EXTERNAL,
            cv2.CHAIN_APPROX_SIMPLE
        )[-2]
        cv2.drawContours(overlay, contours, -1, (255, 255, 0), 1)

    return overlay


def save_results(output_dir, image_name, image, instance_map):
    """
    Write the original image, the instance mask and an overlay as PNG files.

    Files are named ``<stem>_original.png``, ``<stem>_mask.png`` and
    ``<stem>_overlay.png``, where ``<stem>`` is ``image_name`` without its
    extension.

    Args:
        output_dir: Destination folder (created if missing).
        image_name: Source file name.
        image: RGB image (H, W, 3) uint8.
        instance_map: (H, W) integer instance labels, 0 = background.

    Returns:
        dict mapping 'original', 'mask' and 'overlay' to the written paths.

    Raises:
        ValueError: if an instance ID does not fit in the 16-bit mask PNG.
        IOError: if OpenCV fails to write a file.
    """
    os.makedirs(output_dir, exist_ok=True)
    base_name = os.path.splitext(image_name)[0]

    def _write(suffix, pixels):
        # cv2.imwrite reports failure only via its return value, so check it
        # to keep write errors visible to the caller.
        path = os.path.join(output_dir, f"{base_name}_{suffix}.png")
        if not cv2.imwrite(path, pixels):
            raise IOError(f"Failed to write: {path}")
        return path

    orig_path = _write('original', cv2.cvtColor(image, cv2.COLOR_RGB2BGR))

    # The mask is stored as a 16-bit PNG; larger IDs would silently wrap around.
    uint16_max = np.iinfo(np.uint16).max
    if instance_map.size and instance_map.max() > uint16_max:
        raise ValueError(
            f"instance ID {instance_map.max()} exceeds the 16-bit mask limit ({uint16_max})"
        )
    mask_path = _write('mask', instance_map.astype(np.uint16))

    overlay = create_overlay(image, instance_map)
    overlay_path = _write('overlay', cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))

    return {'original': orig_path, 'mask': mask_path, 'overlay': overlay_path}


print("✓ Image processing functions defined")

## 9. Batch Processing

In [None]:
def process_image(image_path, output_dir):
    """
    Segment one image and save its outputs, reporting failures in the result.

    Exceptions raised while loading, segmenting or saving are captured. They
    yield ``success=False`` with the message in ``'error'``.

    Returns:
        dict with 'success', 'image_name', 'num_nuclei', 'saved_paths', 'error'.
    """
    image_name = os.path.basename(image_path)
    try:
        image = load_image(image_path)
        result = run_inference(image)
        saved = save_results(output_dir, image_name, image, result['instance_map'])
        nuclei = result['num_nuclei']
    except Exception as e:
        return {
            'success': False,
            'image_name': image_name,
            'num_nuclei': 0,
            'saved_paths': None,
            'error': str(e)
        }
    return {
        'success': True,
        'image_name': image_name,
        'num_nuclei': nuclei,
        'saved_paths': saved,
        'error': None
    }


def process_class(cell_class, max_images=None):
    """
    Segment the ``.bmp`` images of one cell class and summarise the run.

    Args:
        cell_class: Class folder name (an entry of CELL_CLASSES).
        max_images: Optional cap on the number of images, taken in sorted
            filename order. ``None`` processes all images; ``0`` processes none.

    Returns:
        dict with image and nucleus counts for the class, plus
        ``'failed_images'``: a list of ``{'image_name', 'error'}`` entries.
    """
    input_dir = os.path.join(DATASET_BASE_PATH, cell_class, IMAGE_SUBFOLDER)
    output_dir = os.path.join(OUTPUT_BASE_PATH, cell_class)

    images = sorted(glob.glob(os.path.join(input_dir, '*.bmp')))
    # `is not None`, so that max_images=0 does not mean "all images".
    if max_images is not None:
        images = images[:max_images]

    print(f"\n{'='*60}")
    print(f"Processing: {cell_class}")
    print(f"Images: {len(images)}")
    print(f"{'='*60}")

    results = [process_image(img_path, output_dir) for img_path in tqdm(images, desc=cell_class)]

    succeeded = [r for r in results if r['success']]
    failed = [r for r in results if not r['success']]
    successful = len(succeeded)
    total_nuclei = sum(r['num_nuclei'] for r in succeeded)
    avg_nuclei = total_nuclei / successful if successful > 0 else 0

    print(f"\nSummary: {successful}/{len(images)} successful")
    print(f"Total nuclei: {total_nuclei}, Avg: {avg_nuclei:.1f}")
    # process_image() turns exceptions into result dicts; surface them here so
    # failures are not silent.
    if failed:
        print(f"⚠ {len(failed)} image(s) failed. First errors:")
        for r in failed[:5]:
            print(f"  - {r['image_name']}: {r['error']}")

    return {
        'cell_class': cell_class,
        'total_images': len(images),
        'successful': successful,
        'failed': len(images) - successful,
        'total_nuclei': total_nuclei,
        'avg_nuclei_per_image': avg_nuclei,
        'failed_images': [{'image_name': r['image_name'], 'error': r['error']} for r in failed]
    }


def process_all(max_images_per_class=None):
    """
    Run process_class() for every entry of CELL_CLASSES and save a JSON summary.

    Args:
        max_images_per_class: Optional cap forwarded to process_class().

    Returns:
        dict with overall counts, elapsed seconds and per-class statistics.
        The same dict is written to ``processing_statistics.json`` in
        OUTPUT_BASE_PATH.
    """
    divider = "=" * 60
    print("\n" + divider)
    print("STARTING BATCH PROCESSING")
    print(divider)

    started = datetime.now()
    per_class = {name: process_class(name, max_images_per_class) for name in CELL_CLASSES}
    elapsed = (datetime.now() - started).total_seconds()

    def _total(field):
        return sum(stats[field] for stats in per_class.values())

    n_images = _total('total_images')
    n_ok = _total('successful')
    n_nuclei = _total('total_nuclei')

    overall = {
        'processing_time_seconds': elapsed,
        'total_images': n_images,
        'total_successful': n_ok,
        'total_failed': n_images - n_ok,
        'total_nuclei_detected': n_nuclei,
        'class_statistics': per_class
    }

    stats_file = os.path.join(OUTPUT_BASE_PATH, 'processing_statistics.json')
    with open(stats_file, 'w') as fh:
        json.dump(overall, fh, indent=2)

    print("\n" + divider)
    print("PROCESSING COMPLETE")
    print(divider)
    print(f"Total: {n_images} images, {n_ok} successful")
    print(f"Nuclei detected: {n_nuclei}")
    print(f"Time: {elapsed:.1f}s ({elapsed/60:.1f} min)")
    print(f"Stats saved: {stats_file}")
    print(divider)

    return overall


print("✓ Batch processing functions defined")

## 10. Visualization

In [None]:
def visualize_samples(cell_class, n=3):
    """
    Show up to ``n`` randomly chosen overlay images for one class.

    Reads ``*_overlay.png`` files from ``OUTPUT_BASE_PATH/<cell_class>``. It
    prints a message and returns when there are none.
    """
    import random

    overlay_paths = sorted(glob.glob(os.path.join(OUTPUT_BASE_PATH, cell_class, '*_overlay.png')))
    if not overlay_paths:
        print(f"No results for {cell_class}")
        return

    chosen = random.sample(overlay_paths, min(n, len(overlay_paths)))

    fig, axes = plt.subplots(1, len(chosen), figsize=(5*len(chosen), 5))
    # With a single panel plt.subplots returns a bare Axes, not an array.
    panels = [axes] if len(chosen) == 1 else axes
    title = cell_class.replace('im_', '')

    for ax, path in zip(panels, chosen):
        ax.imshow(cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB))
        ax.set_title(title)
        ax.axis('off')

    plt.tight_layout()
    plt.show()


def plot_stats(stats):
    """
    Plot two per-class bar charts: total nuclei and average nuclei per image.

    Args:
        stats: Dict returned by process_all(); only its 'class_statistics'
            entry is used.
    """
    per_class = stats['class_statistics']
    class_names = list(per_class.keys())
    tick_labels = [name.replace('im_', '') for name in class_names]

    totals = [per_class[name]['total_nuclei'] for name in class_names]
    averages = [per_class[name]['avg_nuclei_per_image'] for name in class_names]

    fig, (left, right) = plt.subplots(1, 2, figsize=(12, 5))
    positions = range(len(class_names))

    panel_specs = [
        (left, totals, 'steelblue', 'Total Nuclei', 'Total Nuclei per Class'),
        (right, averages, 'coral', 'Avg Nuclei per Image', 'Average Nuclei per Image'),
    ]
    for ax, values, color, ylabel, title in panel_specs:
        ax.bar(positions, values, color=color)
        ax.set_xticks(positions)
        ax.set_xticklabels(tick_labels, rotation=45, ha='right')
        ax.set_ylabel(ylabel)
        ax.set_title(title)

    plt.tight_layout()
    plt.show()


print("✓ Visualization functions defined")

## 11. Run Processing

### Test Run

In [None]:
# Smoke test: segment the first 5 images of the first class before a bigger run.
if not (WEIGHT_FILE and dataset_valid):
    print("❌ Cannot run - check configuration")
else:
    print("Running test with 5 images...\n")
    test_stats = process_class(CELL_CLASSES[0], max_images=5)

    if test_stats['successful'] > 0:
        print("\n✓ Test successful!")
        visualize_samples(CELL_CLASSES[0], n=3)
    else:
        print("\n❌ Test failed")

### Process All Classes (Limited to 50 Images per Class)

In [None]:
# Process at most 50 images per class (see the commented-out cell below for a full run).
if not (WEIGHT_FILE and dataset_valid):
    print("❌ Cannot run - check configuration")
else:
    overall_stats = process_all(max_images_per_class=50)

### Process All Images (Full)

In [None]:
# # Uncomment to process ALL images
# if WEIGHT_FILE and dataset_valid:
#     overall_stats = process_all()
# else:
#     print("❌ Cannot run - check configuration")

## 12. Visualize Results

In [None]:
# Show a few overlays per class, once a batch run has produced `overall_stats`.
if 'overall_stats' in globals():
    for class_name in CELL_CLASSES:
        print(f"\n{class_name}:")
        visualize_samples(class_name, n=3)

In [None]:
# Plot per-class statistics, if a batch run has produced `overall_stats`.
if 'overall_stats' in globals():
    plot_stats(overall_stats)

## 13. Cleanup

In [None]:
# Remove the scratch directory (TEMP_DIR) created by the configuration cell.
if os.path.exists(TEMP_DIR):
    shutil.rmtree(TEMP_DIR)
    print("✓ Cleaned: " + TEMP_DIR)

## Notes

### Output Structure
```
HoverNet_Segmentation_Results/
├── im_Dyskeratotic/
│   ├── image_original.png
│   ├── image_mask.png
│   └── image_overlay.png
├── im_Koilocytotic/
├── im_Metaplastic/
├── im_Parabasal/
├── im_Superficial-Intermediate/
└── processing_statistics.json
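```

Each `*_mask.png` is a 16-bit PNG instance map: 0 is background and each detected nucleus has its own positive integer ID. Each `*_overlay.png` shows the nuclei as coloured regions with outlines drawn over the original image.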

### Troubleshooting
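- **Model loading fails**: Ensure the `.tar` file is a valid HoverNet checkpoint, and that the model mode (`fast` vs. `original`) matches the architecture the weights were trained with.
- **Out of memory**: Images go through the network one at a time, so there is no batch size to reduce. Very large images can still exhaust GPU memory; downscale them or run on the CPU.
- **Poor results**: If HoverNet cannot be loaded or fails on an image, the notebook falls back to a classical OpenCV pipeline (Otsu thresholding + watershed), which is usually less accurate. Look for `⚠` messages in the output to see when this happens.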

### References
- [HoverNet Paper](https://arxiv.org/abs/1812.06499)
- [HoverNet GitHub](https://github.com/vqdang/hover_net)