# Nuclear Segmentation using HoverNet

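This notebook sets up a nuclear segmentation pipeline for cervical cancer cell classification built around the HoverNet model. The inference step (`run_hovernet_inference()` in Section 8) is still a placeholder and must be implemented before the results are meaningful.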

**Dataset Structure:**
- Base Directory: `Augmented Dataset - Limited Enhancement`
- Classes: `im_Dyskeratotic`, `im_Koilocytotic`, `im_Metaplastic`, `im_Parabasal`, `im_Superficial-Intermediate`
- Images Location: `<class_folder>/NLM_CLAHE/*.bmp`

**Reference:** [HoverNet GitHub](https://github.com/vqdang/hover_net)

## 1. Environment Setup

In [None]:
# Check if running on Colab
try:
    import google.colab
    IN_COLAB = True
    print("Running on Google Colab")
except ImportError:
    IN_COLAB = False
    print("Not running on Google Colab")

In [None]:
# Install required dependencies
if IN_COLAB:
    !pip install opencv-python-headless
    !pip install scikit-image
    !pip install scipy
    !pip install matplotlib
    !pip install tqdm
    !pip install imageio
    
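    # Install TensorFlow (imported in Section 3 for version reporting).
    # Note: the `*_tf2pytorch.tar` checkpoint referenced in Section 6 is a PyTorch file,
    # so real HoverNet inference will likely also need the repository's own dependencies.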
    !pip install tensorflow==2.12.0
    
    print("\n✓ Dependencies installed successfully")

In [None]:
# Clone HoverNet repository
if IN_COLAB:
    import os
    if not os.path.exists('hover_net'):
        !git clone https://github.com/vqdang/hover_net.git
        print("\n✓ HoverNet repository cloned")
    else:
        print("\n✓ HoverNet repository already exists")
    
    # Add HoverNet to Python path
    import sys
    sys.path.insert(0, '/content/hover_net')
    print("✓ HoverNet added to Python path")

## 2. Mount Google Drive

In [None]:
# Mount Google Drive
if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive')
    print("\n✓ Google Drive mounted successfully")
else:
    print("Skipping Google Drive mount (not on Colab)")

## 3. Import Libraries

In [None]:
import os
import numpy as np
import cv2
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm import tqdm
import json
from datetime import datetime
import glob
import tensorflow as tf

print(f"TensorFlow version: {tf.__version__}")
print(f"OpenCV version: {cv2.__version__}")
print("\n✓ All libraries imported successfully")

## 4. Configuration and Dataset Paths

In [None]:
# Configure dataset paths
# IMPORTANT: Update this path to match your Google Drive structure
DRIVE_BASE_PATH = '/content/drive/MyDrive'  # Default Google Drive path

# Dataset configuration
DATASET_NAME = 'Augmented Dataset - Limited Enhancement'
DATASET_BASE_PATH = os.path.join(DRIVE_BASE_PATH, DATASET_NAME)

# Cell classes
CELL_CLASSES = [
    'im_Dyskeratotic',
    'im_Koilocytotic',
    'im_Metaplastic',
    'im_Parabasal',
    'im_Superficial-Intermediate'
]

# Subfolder containing images
IMAGE_SUBFOLDER = 'NLM_CLAHE'

# Output directory for segmentation results
OUTPUT_BASE_PATH = os.path.join(DRIVE_BASE_PATH, 'HoverNet_Segmentation_Results')
os.makedirs(OUTPUT_BASE_PATH, exist_ok=True)

print(f"Dataset base path: {DATASET_BASE_PATH}")
print(f"Output base path: {OUTPUT_BASE_PATH}")
print(f"\nCell classes to process: {len(CELL_CLASSES)}")
for i, cell_class in enumerate(CELL_CLASSES, 1):
    print(f"  {i}. {cell_class}")

In [None]:
# Verify dataset exists and count images
def verify_dataset():
    """Verify dataset structure and count images per class."""
    print("Verifying dataset structure...\n")
    
    if not os.path.exists(DATASET_BASE_PATH):
        print(f"❌ ERROR: Dataset base path not found: {DATASET_BASE_PATH}")
        print("\nPlease update DRIVE_BASE_PATH and DATASET_NAME in the previous cell.")
        return False, {}  # Same (valid, counts) shape as the success path so the caller can unpack it
    
    total_images = 0
    class_image_counts = {}
    
    for cell_class in CELL_CLASSES:
        class_path = os.path.join(DATASET_BASE_PATH, cell_class)
        image_path = os.path.join(class_path, IMAGE_SUBFOLDER)
        
        if not os.path.exists(image_path):
            print(f"❌ WARNING: Image path not found: {image_path}")
            class_image_counts[cell_class] = 0
            continue
        
        # Count .bmp files
        bmp_files = glob.glob(os.path.join(image_path, '*.bmp'))
        count = len(bmp_files)
        class_image_counts[cell_class] = count
        total_images += count
        
        print(f"✓ {cell_class}: {count} images")
    
    print(f"\n{'='*50}")
    print(f"Total images to process: {total_images}")
    print(f"{'='*50}\n")
    
    return total_images > 0, class_image_counts

# Run verification
dataset_valid, image_counts = verify_dataset()
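
The check above matches `*.bmp` exactly, so on a case-sensitive file system files named `.BMP` would be skipped. A minimal sketch of a case-insensitive listing follows. `list_images` is a hypothetical helper, not part of this notebook, and the file names in the demonstration are made up.

```python
import tempfile
from pathlib import Path


def list_images(folder, extensions=(".bmp",)):
    """Return sorted image paths in `folder`, matching extensions case-insensitively.

    Hypothetical helper for illustration only.
    """
    folder = Path(folder)
    if not folder.is_dir():
        return []
    wanted = {ext.lower() for ext in extensions}
    return sorted(
        p for p in folder.iterdir()
        if p.is_file() and p.suffix.lower() in wanted
    )


# Demonstration on a throwaway directory (file names are made up)
with tempfile.TemporaryDirectory() as tmp:
    for name in ("001.bmp", "002.BMP", "notes.txt"):
        (Path(tmp) / name).touch()
    found_names = [p.name for p in list_images(tmp)]

print(found_names)
```

Sorting the result also makes the processing order reproducible between runs.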

## 5. Download HoverNet Pre-trained Weights

In [None]:
# Download pre-trained weights for HoverNet
# The original repository provides pre-trained weights

WEIGHTS_DIR = '/content/hover_net_weights'
os.makedirs(WEIGHTS_DIR, exist_ok=True)

print("HoverNet pre-trained weights are not downloaded automatically by this cell.")
print("\nNote: HoverNet provides multiple pre-trained models:")
print("  - CoNSeP dataset model")
print("  - Kumar dataset model")
print("  - CPM dataset model")
print("\nThese models were trained on H&E histology images, so cervical cytology images may need fine-tuning.")
print("\nPlease download weights using the links in the repository README: https://github.com/vqdang/hover_net")
print(f"And place them in: {WEIGHTS_DIR}")
print("\nAlternatively, you can use the model checkpoint directly from the repository.")
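
Since the weights have to be placed in the directory by hand, it can help to confirm that a checkpoint is actually there before going further. The sketch below assumes a hypothetical helper `find_checkpoints` and assumes checkpoints end in `.tar` or `.pth`; the demonstration uses a dummy file in a temporary directory rather than real weights.

```python
import tempfile
from pathlib import Path


def find_checkpoints(weights_dir, patterns=("*.tar", "*.pth")):
    """List (file name, size in MB) for candidate checkpoint files in `weights_dir`.

    Hypothetical helper; the extension patterns are an assumption.
    """
    weights_dir = Path(weights_dir)
    found = []
    for pattern in patterns:
        for path in sorted(weights_dir.glob(pattern)):
            found.append((path.name, path.stat().st_size / 1e6))
    return found


with tempfile.TemporaryDirectory() as tmp:
    empty_result = find_checkpoints(tmp)
    # Dummy 1 KiB file standing in for a real checkpoint
    (Path(tmp) / "example_checkpoint.tar").write_bytes(b"\0" * 1024)
    checkpoints = find_checkpoints(tmp)

print(empty_result)
print(checkpoints)
```

In the notebook this would be called as `find_checkpoints(WEIGHTS_DIR)`; an empty list means the manual download step has not been done yet.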

## 6. HoverNet Model Setup

In [None]:
# Import HoverNet modules
try:
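    # These imports depend on the HoverNet repository structure.
    # The clone cell puts the repository root itself on sys.path, so if these fail,
    # try `from infer.tile import InferManager` and `from config import Config` instead.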
    from hover_net.infer.tile import InferManager
    from hover_net.config import Config
    print("✓ HoverNet modules imported successfully")
except ImportError as e:
    print(f"⚠ Import warning: {e}")
    print("\nNote: HoverNet may require manual configuration.")
    print("Please check the repository structure and adjust imports accordingly.")

In [None]:
# Alternative: Use HoverNet inference script directly
# This approach uses the command-line interface

def setup_hovernet_config():
    """
    Setup HoverNet configuration for inference.
    """
    config = {
        'model_path': os.path.join(WEIGHTS_DIR, 'hovernet_original_consep_notype_tf2pytorch.tar'),
        'nr_inference_workers': 4,
        'nr_post_proc_workers': 4,
        'batch_size': 8,
        'input_shape': [270, 270],  # 'original' mode takes 270x270 patches ('fast' mode takes 256x256)
        'output_types': ['instance'],  # A 'notype' checkpoint segments instances only, without type prediction
    }
    
    return config

config = setup_hovernet_config()
print("HoverNet configuration:")
for key, value in config.items():
    print(f"  {key}: {value}")
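
The settings above map naturally onto the repository's command-line inference script. The following is a sketch only: `build_tile_command` is a hypothetical helper, the paths are placeholders, and the flag names are taken from the `run_infer.py` usage text as I understand it, so they should be checked with `python run_infer.py --help` in the cloned repository before use.

```python
import shlex


def build_tile_command(config, input_dir, output_dir, model_mode="original", gpu="0"):
    """Assemble an argument list for HoverNet's `run_infer.py` in tile mode.

    Hypothetical helper; verify every flag against `run_infer.py --help`.
    """
    return [
        "python", "run_infer.py",
        f"--gpu={gpu}",
        "--nr_types=0",  # 0 = no nucleus type prediction, matching a 'notype' checkpoint
        f"--model_mode={model_mode}",
        f"--model_path={config['model_path']}",
        f"--batch_size={config['batch_size']}",
        f"--nr_inference_workers={config['nr_inference_workers']}",
        f"--nr_post_proc_workers={config['nr_post_proc_workers']}",
        "tile",
        f"--input_dir={input_dir}",
        f"--output_dir={output_dir}",
    ]


# Placeholder values mirroring the keys of the configuration dictionary
example_config = {
    "model_path": "/path/to/weights/checkpoint.tar",
    "nr_inference_workers": 4,
    "nr_post_proc_workers": 4,
    "batch_size": 8,
}
command = build_tile_command(example_config, "/path/to/input tiles", "/path/to/pred")
print(shlex.join(command))
```

Keeping the command as a list means it can be passed to `subprocess.run(command, cwd="hover_net")` without a shell, which avoids quoting problems with directory names that contain spaces.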

## 7. Image Processing Functions

In [None]:
def load_image(image_path):
    """
    Load a .bmp image and convert to RGB.
    
    Args:
        image_path: Path to the image file
        
    Returns:
        numpy.ndarray: RGB image
    """
    img = cv2.imread(image_path)
    if img is None:
        raise ValueError(f"Failed to load image: {image_path}")
    
    # Convert BGR to RGB
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return img_rgb


def preprocess_for_hovernet(image):
    """
    Preprocess image for HoverNet inference.
    
    Args:
        image: RGB image array
        
    Returns:
        numpy.ndarray: Preprocessed image
    """
    # HoverNet expects images in RGB format with values in [0, 255]
    # Ensure image is in the correct format
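    if image.dtype != np.uint8:
        # Assumes a float image scaled to [0, 1]; clip first so out-of-range values cannot wrap around
        image = (np.clip(image, 0.0, 1.0) * 255).astype(np.uint8)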
    
    return image


def save_segmentation_result(output_dir, image_name, original_image, segmentation_mask):
    """
    Save segmentation results including original image, mask, and overlay.
    
    Args:
        output_dir: Directory to save results
        image_name: Name of the original image
        original_image: Original RGB image
        segmentation_mask: Segmentation mask from HoverNet
    """
    os.makedirs(output_dir, exist_ok=True)
    
    base_name = os.path.splitext(image_name)[0]
    
    # Save original image
    original_path = os.path.join(output_dir, f"{base_name}_original.png")
    cv2.imwrite(original_path, cv2.cvtColor(original_image, cv2.COLOR_RGB2BGR))
    
    # Save segmentation mask
    mask_path = os.path.join(output_dir, f"{base_name}_mask.png")
    cv2.imwrite(mask_path, segmentation_mask.astype(np.uint16))
    
    # Create and save overlay
    overlay = create_overlay(original_image, segmentation_mask)
    overlay_path = os.path.join(output_dir, f"{base_name}_overlay.png")
    cv2.imwrite(overlay_path, cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))
    
    return {
        'original': original_path,
        'mask': mask_path,
        'overlay': overlay_path
    }


def create_overlay(image, mask, alpha=0.5):
    """
    Create an overlay of the segmentation mask on the original image.
    
    Args:
        image: Original RGB image
        mask: Segmentation mask
        alpha: Transparency factor
        
    Returns:
        numpy.ndarray: Overlay image
    """
    # Create a colorized version of the mask
    colored_mask = np.zeros_like(image)
    
    # Assign different colors to different nuclei instances
    unique_instances = np.unique(mask)
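    rng = np.random.default_rng(42)  # Local generator: reproducible colors without reseeding NumPy globally
    
    for instance_id in unique_instances:
        if instance_id == 0:  # Skip background
            continue
        color = rng.integers(0, 256, size=3)  # Upper bound is exclusive, so 256 allows the value 255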
        colored_mask[mask == instance_id] = color
    
    # Blend original image with colored mask
    overlay = cv2.addWeighted(image, 1-alpha, colored_mask, alpha, 0)
    
    return overlay

print("✓ Image processing functions defined")
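
`create_overlay` blends the whole frame, so background pixels are mixed with black and come out darker than in the original. One possible alternative, sketched below with NumPy only, colors all instances at once through a lookup table and blends only the foreground. `blend_instances` is a hypothetical name and the demonstration image is synthetic.

```python
import numpy as np


def blend_instances(image, mask, alpha=0.5, seed=42):
    """Blend per-instance colors into `image` only where `mask` is non-zero."""
    rng = np.random.default_rng(seed)
    # One random color per label; row 0 (background) is never used for blending
    palette = rng.integers(0, 256, size=(int(mask.max()) + 1, 3), dtype=np.uint8)
    colored = palette[mask]  # fancy indexing maps every label to its color
    foreground = mask > 0

    overlay = image.copy()
    blended = (1.0 - alpha) * image[foreground] + alpha * colored[foreground]
    overlay[foreground] = np.round(blended).astype(np.uint8)
    return overlay


# Synthetic example: uniform gray image with one 2x2 "nucleus"
demo_image = np.full((4, 4, 3), 100, dtype=np.uint8)
demo_mask = np.zeros((4, 4), dtype=np.uint16)
demo_mask[:2, :2] = 1
demo_overlay = blend_instances(demo_image, demo_mask)

print(demo_overlay[3, 3])  # background pixel, left untouched
```

The lookup-table indexing replaces the per-instance Python loop, which may matter for images containing many nuclei.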

## 8. HoverNet Inference Function

In [None]:
def run_hovernet_inference(image):
    """
    Run HoverNet inference on a single image.
    
    This is a placeholder function. The actual implementation depends on
    how you set up HoverNet (using the Python API or command-line interface).
    
    Args:
        image: RGB image array
        
    Returns:
        dict: Segmentation results containing:
            - 'instance_map': Instance segmentation map
            - 'type_map': Cell type classification map (if available)
            - 'contours': List of nucleus contours
            - 'num_nuclei': Number of detected nuclei
    """
    # TODO: Implement actual HoverNet inference
    # This will depend on the HoverNet setup
    
    # Option 1: Use HoverNet Python API
    # result = infer_manager.run_inference(image)
    
    # Option 2: Use command-line interface
    # Save image temporarily, run inference, load results
    
    # For now, return a placeholder
    print("⚠ Warning: Using placeholder inference. Implement actual HoverNet inference.")
    
    # Placeholder: Otsu thresholding, inverted on the assumption that nuclei
    # are darker than the background (as in stained cytology images)
    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    
    # Find contours
    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    
    # Create instance map
    instance_map = np.zeros(image.shape[:2], dtype=np.uint16)
    for idx, contour in enumerate(contours, start=1):
        cv2.drawContours(instance_map, [contour], -1, idx, -1)
    
    return {
        'instance_map': instance_map,
        'type_map': None,
        'contours': contours,
        'num_nuclei': len(contours)
    }

print("✓ HoverNet inference function defined")
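
Whatever produces the instance map, threshold-based placeholders in particular tend to emit many tiny spurious regions that inflate `num_nuclei`. A minimal post-processing sketch, assuming a hypothetical helper `filter_small_instances` and an arbitrary `min_area` threshold in pixels, drops small instances and relabels the rest consecutively:

```python
import numpy as np


def filter_small_instances(instance_map, min_area=20):
    """Remove instances smaller than `min_area` pixels and relabel the rest 1..N.

    Hypothetical helper; `min_area` is an arbitrary example threshold.
    """
    labels, counts = np.unique(instance_map, return_counts=True)
    keep = labels[(counts >= min_area) & (labels != 0)]

    # Lookup table from old label to new label; dropped labels map to 0
    lut = np.zeros(int(labels.max()) + 1, dtype=instance_map.dtype)
    lut[keep] = np.arange(1, len(keep) + 1, dtype=instance_map.dtype)
    return lut[instance_map], len(keep)


# Synthetic example with three instances of 4, 1 and 6 pixels
demo_map = np.zeros((5, 5), dtype=np.uint16)
demo_map[0:2, 0:2] = 1
demo_map[4, 4] = 2
demo_map[2:4, 2:5] = 3

filtered_map, n_kept = filter_small_instances(demo_map, min_area=2)
print(n_kept)
print(filtered_map)
```

The returned count could replace `len(contours)` as the `num_nuclei` value if this filtering is adopted.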

## 9. Batch Processing Pipeline

In [None]:
def process_single_image(image_path, output_dir):
    """
    Process a single image through the HoverNet pipeline.
    
    Args:
        image_path: Path to input image
        output_dir: Directory to save results
        
    Returns:
        dict: Processing results and statistics
    """
    try:
        # Load image
        image = load_image(image_path)
        
        # Preprocess
        preprocessed = preprocess_for_hovernet(image)
        
        # Run HoverNet inference
        results = run_hovernet_inference(preprocessed)
        
        # Save results
        image_name = os.path.basename(image_path)
        saved_paths = save_segmentation_result(
            output_dir,
            image_name,
            image,
            results['instance_map']
        )
        
        return {
            'success': True,
            'image_path': image_path,
            'num_nuclei': results['num_nuclei'],
            'saved_paths': saved_paths,
            'error': None
        }
        
    except Exception as e:
        return {
            'success': False,
            'image_path': image_path,
            'num_nuclei': 0,
            'saved_paths': None,
            'error': str(e)
        }


def process_cell_class(cell_class, progress_bar=True):
    """
    Process all images for a specific cell class.
    
    Args:
        cell_class: Name of the cell class
        progress_bar: Whether to show progress bar
        
    Returns:
        dict: Processing statistics
    """
    # Setup paths
    input_dir = os.path.join(DATASET_BASE_PATH, cell_class, IMAGE_SUBFOLDER)
    output_dir = os.path.join(OUTPUT_BASE_PATH, cell_class)
    
    # Get all .bmp files
    image_files = sorted(glob.glob(os.path.join(input_dir, '*.bmp')))  # Sorted for a reproducible order
    
    print(f"\nProcessing {cell_class}: {len(image_files)} images")
    print(f"Output directory: {output_dir}\n")
    
    # Process each image
    results = []
    iterator = tqdm(image_files) if progress_bar else image_files
    
    for image_path in iterator:
        result = process_single_image(image_path, output_dir)
        results.append(result)
        
        if not result['success']:
            tqdm.write(f"\n❌ Error processing {os.path.basename(image_path)}: {result['error']}")
    
    # Calculate statistics
    successful = sum(1 for r in results if r['success'])
    failed = len(results) - successful
    total_nuclei = sum(r['num_nuclei'] for r in results if r['success'])
    avg_nuclei = total_nuclei / successful if successful > 0 else 0
    
    stats = {
        'cell_class': cell_class,
        'total_images': len(image_files),
        'successful': successful,
        'failed': failed,
        'total_nuclei': total_nuclei,
        'avg_nuclei_per_image': avg_nuclei,
        'results': results
    }
    
    return stats


def process_all_classes():
    """
    Process all cell classes in the dataset.
    
    Returns:
        dict: Complete processing statistics
    """
    print("="*70)
    print("Starting batch processing of all cell classes")
    print("="*70)
    
    start_time = datetime.now()
    all_stats = {}
    
    for cell_class in CELL_CLASSES:
        stats = process_cell_class(cell_class)
        all_stats[cell_class] = stats
        
        # Print summary for this class
        print(f"\n{'='*70}")
        print(f"Summary for {cell_class}:")
        print(f"  Total images: {stats['total_images']}")
        print(f"  Successful: {stats['successful']}")
        print(f"  Failed: {stats['failed']}")
        print(f"  Total nuclei detected: {stats['total_nuclei']}")
        print(f"  Average nuclei per image: {stats['avg_nuclei_per_image']:.2f}")
        print(f"{'='*70}\n")
    
    end_time = datetime.now()
    processing_time = (end_time - start_time).total_seconds()
    
    # Overall statistics
    total_images = sum(stats['total_images'] for stats in all_stats.values())
    total_successful = sum(stats['successful'] for stats in all_stats.values())
    total_failed = sum(stats['failed'] for stats in all_stats.values())
    total_nuclei = sum(stats['total_nuclei'] for stats in all_stats.values())
    
    overall_stats = {
        'start_time': start_time.isoformat(),
        'end_time': end_time.isoformat(),
        'processing_time_seconds': processing_time,
        'total_images': total_images,
        'total_successful': total_successful,
        'total_failed': total_failed,
        'total_nuclei_detected': total_nuclei,
        'class_statistics': all_stats
    }
    
    # Save overall statistics
    stats_file = os.path.join(OUTPUT_BASE_PATH, 'processing_statistics.json')
    with open(stats_file, 'w') as f:
        json.dump(overall_stats, f, indent=2)
    
    print("\n" + "="*70)
    print("OVERALL PROCESSING COMPLETE")
    print("="*70)
    print(f"Total images processed: {total_images}")
    print(f"Successful: {total_successful}")
    print(f"Failed: {total_failed}")
    print(f"Total nuclei detected: {total_nuclei}")
    print(f"Processing time: {processing_time:.2f} seconds ({processing_time/60:.2f} minutes)")
    print(f"Statistics saved to: {stats_file}")
    print("="*70 + "\n")
    
    return overall_stats

print("✓ Batch processing functions defined")
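
Colab sessions can disconnect part-way through a long run, and the pipeline above would then start again from the first image. A small sketch of a resume step, assuming a hypothetical helper `pending_images` and relying on the `<name>_mask.png` naming used by `save_segmentation_result`:

```python
import os
import tempfile


def pending_images(image_paths, output_dir):
    """Return the images whose `<name>_mask.png` is not yet present in `output_dir`."""
    todo = []
    for image_path in image_paths:
        base_name = os.path.splitext(os.path.basename(image_path))[0]
        if not os.path.exists(os.path.join(output_dir, f"{base_name}_mask.png")):
            todo.append(image_path)
    return todo


# Demonstration: pretend image 001 was already processed (paths are made up)
with tempfile.TemporaryDirectory() as out_dir:
    open(os.path.join(out_dir, "001_mask.png"), "wb").close()
    remaining = pending_images(["/data/001.bmp", "/data/002.bmp"], out_dir)

print(remaining)
```

Inside `process_cell_class`, filtering `image_files` through such a helper would skip finished images, at the cost of those images no longer appearing in the per-run statistics.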

## 10. Visualization Functions

In [None]:
def visualize_sample_results(cell_class, num_samples=3):
    """
    Visualize sample segmentation results for a cell class.
    
    Args:
        cell_class: Name of the cell class
        num_samples: Number of samples to visualize
    """
    output_dir = os.path.join(OUTPUT_BASE_PATH, cell_class)
    
    # Get sample overlay images
    overlay_files = glob.glob(os.path.join(output_dir, '*_overlay.png'))
    
    if not overlay_files:
        print(f"No results found for {cell_class}")
        return
    
    # Select random samples
    import random
    samples = random.sample(overlay_files, min(num_samples, len(overlay_files)))
    
    # Create visualization
    fig, axes = plt.subplots(1, len(samples), figsize=(15, 5))
    if len(samples) == 1:
        axes = [axes]
    
    for idx, overlay_path in enumerate(samples):
        overlay = cv2.imread(overlay_path)
        overlay_rgb = cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)
        
        axes[idx].imshow(overlay_rgb)
        axes[idx].set_title(f"{cell_class}\n{os.path.basename(overlay_path)}")
        axes[idx].axis('off')
    
    plt.tight_layout()
    plt.show()


def plot_statistics(stats):
    """
    Plot processing statistics across all cell classes.
    
    Args:
        stats: Overall statistics dictionary
    """
    class_stats = stats['class_statistics']
    
    cell_classes = list(class_stats.keys())
    nuclei_counts = [class_stats[c]['total_nuclei'] for c in cell_classes]
    avg_nuclei = [class_stats[c]['avg_nuclei_per_image'] for c in cell_classes]
    
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
    
    # Plot 1: Total nuclei per class
    ax1.bar(range(len(cell_classes)), nuclei_counts)
    ax1.set_xlabel('Cell Class')
    ax1.set_ylabel('Total Nuclei Detected')
    ax1.set_title('Total Nuclei Detected per Cell Class')
    ax1.set_xticks(range(len(cell_classes)))
    ax1.set_xticklabels([c.replace('im_', '') for c in cell_classes], rotation=45, ha='right')
    
    # Plot 2: Average nuclei per image
    ax2.bar(range(len(cell_classes)), avg_nuclei)
    ax2.set_xlabel('Cell Class')
    ax2.set_ylabel('Average Nuclei per Image')
    ax2.set_title('Average Nuclei per Image by Cell Class')
    ax2.set_xticks(range(len(cell_classes)))
    ax2.set_xticklabels([c.replace('im_', '') for c in cell_classes], rotation=45, ha='right')
    
    plt.tight_layout()
    plt.show()

print("✓ Visualization functions defined")

## 11. Run Processing

**⚠️ IMPORTANT:** Before running this cell, make sure:
1. HoverNet repository is cloned
2. Pre-trained weights are downloaded
3. `run_hovernet_inference()` function is properly implemented
4. Dataset paths are correctly configured

In [None]:
# Process all images
if dataset_valid:
    # Run batch processing
    overall_stats = process_all_classes()
else:
    print("❌ Dataset validation failed. Please check your dataset paths.")

## 12. Visualize Results

In [None]:
# Visualize sample results for each class
for cell_class in CELL_CLASSES:
    print(f"\nSample results for {cell_class}:")
    visualize_sample_results(cell_class, num_samples=3)

In [None]:
# Plot overall statistics
if 'overall_stats' in locals():
    plot_statistics(overall_stats)

## 13. Export Results Summary

In [None]:
# Create a summary report
def create_summary_report(stats):
    """
    Create a formatted summary report of the processing results.
    """
    report = []
    report.append("="*80)
    report.append("HoverNet Nuclear Segmentation - Processing Report")
    report.append("="*80)
    report.append(f"\nProcessing Date: {stats['start_time']}")
    report.append(f"Total Processing Time: {stats['processing_time_seconds']:.2f} seconds")
    report.append(f"\nDataset: {DATASET_NAME}")
    report.append(f"Output Location: {OUTPUT_BASE_PATH}")
    
    report.append("\n" + "="*80)
    report.append("OVERALL STATISTICS")
    report.append("="*80)
    report.append(f"Total Images Processed: {stats['total_images']}")
    report.append(f"Successful: {stats['total_successful']}")
    report.append(f"Failed: {stats['total_failed']}")
    report.append(f"Total Nuclei Detected: {stats['total_nuclei_detected']}")
    
    report.append("\n" + "="*80)
    report.append("PER-CLASS STATISTICS")
    report.append("="*80)
    
    for cell_class, class_stats in stats['class_statistics'].items():
        report.append(f"\n{cell_class}:")
        report.append(f"  Images: {class_stats['total_images']}")
        report.append(f"  Successful: {class_stats['successful']}")
        report.append(f"  Failed: {class_stats['failed']}")
        report.append(f"  Total Nuclei: {class_stats['total_nuclei']}")
        report.append(f"  Avg Nuclei/Image: {class_stats['avg_nuclei_per_image']:.2f}")
    
    report.append("\n" + "="*80)
    
    return "\n".join(report)


if 'overall_stats' in locals():
    # Generate and save report
    report = create_summary_report(overall_stats)
    
    # Print report
    print(report)
    
    # Save report to file
    report_file = os.path.join(OUTPUT_BASE_PATH, 'processing_report.txt')
    with open(report_file, 'w') as f:
        f.write(report)
    
    print(f"\n✓ Report saved to: {report_file}")
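
The text report is convenient to read but awkward to load into a spreadsheet or a DataFrame. As a possible complement, the per-class statistics can be written as CSV with the standard library. `write_class_stats_csv` is a hypothetical helper and the numbers in the demonstration are invented.

```python
import csv
import io

CSV_FIELDS = [
    "cell_class", "total_images", "successful",
    "failed", "total_nuclei", "avg_nuclei_per_image",
]


def write_class_stats_csv(stats, handle):
    """Write one CSV row per cell class to an open text handle."""
    # extrasaction="ignore" skips the bulky per-image 'results' list
    writer = csv.DictWriter(handle, fieldnames=CSV_FIELDS, extrasaction="ignore")
    writer.writeheader()
    for class_stats in stats["class_statistics"].values():
        writer.writerow(class_stats)


# Invented statistics with the same keys as the dictionaries built in Section 9
demo_stats = {
    "class_statistics": {
        "im_Parabasal": {
            "cell_class": "im_Parabasal",
            "total_images": 2,
            "successful": 2,
            "failed": 0,
            "total_nuclei": 6,
            "avg_nuclei_per_image": 3.0,
            "results": [],
        }
    }
}

buffer = io.StringIO()
write_class_stats_csv(demo_stats, buffer)
csv_lines = buffer.getvalue().splitlines()
print("\n".join(csv_lines))
```

When writing to a real file, open it with `newline=""` as the `csv` module documentation recommends, for example `open(path, "w", newline="")`.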

## 14. Additional Notes and Next Steps

### Important Implementation Notes:

1. **HoverNet Setup:** The `run_hovernet_inference()` function currently contains placeholder code. You need to implement the actual HoverNet inference based on how you configure the model.

2. **Model Weights:** Download the appropriate pre-trained weights using the links in the HoverNet repository README and place them in `WEIGHTS_DIR` (or update that path to point at them).

3. **Memory Management:** For large datasets, consider processing images in batches and clearing memory periodically.

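4. **GPU Acceleration:** If a GPU is available, select a GPU runtime and confirm that the inference framework can see it. The `*_tf2pytorch.tar` checkpoint referenced in Section 6 is a PyTorch file, so this means PyTorch rather than TensorFlow.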
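
The memory-management note can be sketched as follows: work through the image list in fixed-size chunks and trigger garbage collection after each one. `iter_chunks` and `process_in_chunks` are hypothetical helpers, and the squaring function merely stands in for `process_single_image`.

```python
import gc


def iter_chunks(items, chunk_size):
    """Yield consecutive slices of `items` with at most `chunk_size` elements."""
    for start in range(0, len(items), chunk_size):
        yield items[start:start + chunk_size]


def process_in_chunks(items, process_fn, chunk_size=32):
    """Apply `process_fn` to every item, collecting garbage after each chunk."""
    results = []
    for chunk in iter_chunks(items, chunk_size):
        for item in chunk:
            # Keep only small summaries here, not full-size image arrays
            results.append(process_fn(item))
        gc.collect()
    return results


chunk_sizes = [len(chunk) for chunk in iter_chunks(list(range(7)), 3)]
squares = process_in_chunks(list(range(7)), lambda x: x * x, chunk_size=3)
print(chunk_sizes)
print(squares)
```

Garbage collection only helps if nothing still references the large arrays, so the per-image result dictionaries should hold paths and counts rather than the images themselves.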

### Recommended Next Steps:

1. **Model Evaluation:** Validate segmentation quality on a subset of images
2. **Fine-tuning:** Consider fine-tuning HoverNet on cervical cell images if needed
3. **Feature Extraction:** Extract morphological features from segmented nuclei
4. **Integration:** Integrate segmentation results with your classification pipeline
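
As a starting point for the feature extraction step, basic morphological descriptors can be read directly from an instance map with NumPy. This is a minimal sketch on a synthetic map: `nucleus_features` is a hypothetical helper, and only area, centroid and bounding box are computed.

```python
import numpy as np


def nucleus_features(instance_map):
    """Compute area, centroid (row, col) and bounding box for each labelled nucleus."""
    features = []
    for label in np.unique(instance_map):
        if label == 0:  # background
            continue
        rows, cols = np.nonzero(instance_map == label)
        features.append({
            "label": int(label),
            "area": int(rows.size),
            "centroid": (float(rows.mean()), float(cols.mean())),
            # (min_row, min_col, max_row, max_col) with exclusive upper bounds
            "bbox": (int(rows.min()), int(cols.min()),
                     int(rows.max()) + 1, int(cols.max()) + 1),
        })
    return features


# Synthetic instance map with two rectangular "nuclei"
demo_instances = np.zeros((6, 6), dtype=np.uint16)
demo_instances[1:3, 1:4] = 1
demo_instances[4:6, 4:6] = 2

demo_features = nucleus_features(demo_instances)
for entry in demo_features:
    print(entry)
```

Shape descriptors such as perimeter or eccentricity would need contour information, for which a library like scikit-image (already installed in Section 1) is a more natural fit.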

### References:

- HoverNet Paper: https://arxiv.org/abs/1812.06499
- HoverNet Repository: https://github.com/vqdang/hover_net
- Documentation: Check the repository README for detailed usage instructions