In [None]:
!pip install -r requirements.txt -U
print("Complete")

In [None]:
# Cell 1: Setup and Authentication
import matplotlib.pyplot as plt
from datasets import load_dataset
import huggingface_hub
from PIL import Image
import numpy as np
import torch
from segment_anything import sam_model_registry
from utils.demo import BboxPromptDemo
import ipywidgets as widgets
from IPython.display import display, clear_output
import os

# Login to Hugging Face
# getpass is used instead of input() so the access token is not echoed into
# the cell output (and therefore not saved with the notebook).
import getpass

# Strip surrounding whitespace that often comes along when a token is pasted.
hf_token = getpass.getpass("Enter your Hugging Face token: ").strip()
huggingface_hub.login(token=hf_token)

print("Authentication successful!")

In [None]:
# Quick debug cell - run this first to see available splits
dataset_input = input("Please enter dataset path: ").strip()
temp_dataset = load_dataset(dataset_input, token=True)

available_splits = list(temp_dataset.keys())

print("🔍 Dataset Analysis:")
print(f"Available splits: {available_splits}")
for split_name in available_splits:
    split_data = temp_dataset[split_name]
    print(f"  - {split_name}: {len(split_data)} items")
    # Only describe the schema of splits that actually contain rows
    if len(split_data) > 0:
        print(f"    Features: {split_data.features}")

# Choose which split to use as 'train'
# Keep asking until the answer names a split that exists, so a typo is caught
# here rather than being accepted silently.
# NOTE(review): the cells below construct the testers without passing this
# value, so they still read the 'train' split - confirm the intended wiring.
chosen_split = input(f"Which split to use as 'train'? {available_splits}: ").strip()
while chosen_split not in available_splits:
    print(f"❌ '{chosen_split}' is not one of {available_splits}")
    chosen_split = input(f"Which split to use as 'train'? {available_splits}: ").strip()
print(f"Will use '{chosen_split}' split")

In [None]:
# Cell 2: Load Dataset and Model (Hugging Face Only)
print("Loading lymph node dataset...")
# Strip stray whitespace so a pasted path with a trailing space still resolves
dataset_input = input("Please enter dataset path: ").strip()
image_dataset = load_dataset(dataset_input, token=True)

print("Dataset info:")
print(f"Available splits: {list(image_dataset.keys())}")

# The rest of this notebook reads the 'train' split, so stop here with an
# actionable message instead of an opaque KeyError further down.
if 'train' not in image_dataset:
    raise KeyError(
        f"Dataset '{dataset_input}' has no 'train' split; "
        f"available splits: {list(image_dataset.keys())}"
    )

print(f"Number of train images: {len(image_dataset['train'])}")
print(f"Features: {image_dataset['train'].features}")

# Fetch the MedSAM checkpoint file from the Hugging Face Hub
print("\nLoading MedSAM model from Hugging Face...")
from huggingface_hub import hf_hub_download

# The returned value is kept as the checkpoint path that the model-loading
# code below passes to the SAM registry.
MedSAM_CKPT_PATH = hf_hub_download(repo_id="GleghornLab/medsam-vit-b", filename="medsam_vit_b.pth", token=True)

print(f"✅ Model downloaded to: {MedSAM_CKPT_PATH}")

# Pick the first CUDA device when one is available, otherwise fall back to CPU
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Build the 'vit_b' variant from the downloaded checkpoint, move it to the
# selected device and switch it to evaluation mode
medsam_model = sam_model_registry['vit_b'](checkpoint=MedSAM_CKPT_PATH).to(device)
medsam_model.eval()
print(f"✅ MedSAM model loaded successfully on {device}")

In [None]:
# Cell 3: Lymph Node Dataset Explorer Class

# Enhanced Multi-Selection Demo
class MultiRegionMedSAMTester:
    """Notebook helper for collecting several region selections on one image.

    Regions are currently tracked as placeholder strings (``"Region_<n>"`` and
    ``"Segmentation_<n>"``); the box drawn in the demo widget is not captured
    (see the note in ``add_region``).

    Args:
        model: Model handed to ``BboxPromptDemo`` whenever a region is added.
        dataset: Mapping from split name to an indexable collection whose
            items expose the picture under the ``'image'`` key.
        split: Name of the split images are read from. Defaults to ``'train'``.
    """

    def __init__(self, model, dataset, split='train'):
        self.model = model
        self.dataset = dataset
        self.split = split  # Which dataset split to read images from
        self.current_image = None
        self.current_index = 0
        self.bounding_boxes = []  # Store multiple bounding boxes
        self.segmentation_results = []  # Store segmentation results

    def test_single_image_multi_region(self, image_index):
        """Load image ``image_index`` and start a fresh multi-region session.

        Prints an error and leaves the current session untouched when the
        index is negative or past the end of the split.
        """
        images = self.dataset[self.split]
        # Negative indices are rejected as well: they would otherwise wrap
        # around (or raise on an empty split) instead of reporting a bad index.
        if image_index < 0 or image_index >= len(images):
            print(f"❌ Index {image_index} out of range.")
            return

        sample = images[image_index]
        self.current_image = sample['image']
        self.current_index = image_index
        self.bounding_boxes = []  # Reset for new image
        self.segmentation_results = []

        print(f"🧪 Multi-Region MedSAM Testing - Image #{image_index + 1}")
        print(f"Image size: {self.current_image.size}")
        print("=" * 60)

        # Show the original image
        self._display_current_state()

        print("\n📝 Instructions:")
        print("• add_region() - Add a new bounding box selection")
        print("• remove_region(n) - Remove region number n")
        print("• show_regions() - Display all current regions")
        print("• clear_all() - Remove all regions")
        print("• finish() - Complete and show final result")

    def add_region(self):
        """Open a bounding-box demo on the current image and record a region.

        Requires an image to have been loaded first; otherwise an error is
        printed and nothing is recorded.
        """
        if self.current_image is None:
            print("❌ No image loaded. Use test_single_image_multi_region(n) first.")
            return

        region_id = len(self.bounding_boxes) + 1
        print(f"\n🎯 Adding region #{region_id}")
        print("Draw a bounding box around the follicle/structure...")

        # Create a new demo for this region
        temp_demo = BboxPromptDemo(self.model)
        temp_demo.show(self.current_image)

        # Note: In a real implementation, you'd need to capture the bounding box coordinates
        # For now, we'll simulate storing the region
        self.bounding_boxes.append(f"Region_{region_id}")
        self.segmentation_results.append(f"Segmentation_{region_id}")

        print(f"✅ Region #{region_id} added!")
        self._show_region_summary()

    def remove_region(self, region_number):
        """Remove region ``region_number`` (1-based) and renumber the rest."""
        if not self.bounding_boxes:
            print("❌ No regions to remove.")
            return

        if region_number < 1 or region_number > len(self.bounding_boxes):
            print(f"❌ Invalid region number. Choose 1-{len(self.bounding_boxes)}")
            return

        # Both lists are kept index-aligned, so drop the same position in each
        del self.bounding_boxes[region_number - 1]
        del self.segmentation_results[region_number - 1]

        print(f"🗑️ Removed region #{region_number}")

        # Renumber remaining regions
        remaining = len(self.bounding_boxes)
        self.bounding_boxes = [f"Region_{n}" for n in range(1, remaining + 1)]
        self.segmentation_results = [f"Segmentation_{n}" for n in range(1, remaining + 1)]

        self._show_region_summary()
        self._display_current_state()

    def show_regions(self):
        """Print a numbered list of the regions recorded for the current image."""
        if not self.bounding_boxes:
            print("📭 No regions selected yet.")
            return

        print(f"\n📊 Current Regions for Image #{self.current_index + 1}:")
        for i, region in enumerate(self.bounding_boxes, 1):
            print(f"  {i}. {region}")

    def clear_all(self):
        """Drop every recorded region and redraw the current image, if any."""
        self.bounding_boxes = []
        self.segmentation_results = []
        print("🧹 All regions cleared!")
        self._display_current_state()

    def finish(self):
        """Print the session summary and show the final-result figure."""
        if not self.bounding_boxes:
            print("❌ No regions selected. Add some regions first!")
            return

        print(f"\n🎉 Segmentation Complete!")
        print(f"Image #{self.current_index + 1}: {len(self.bounding_boxes)} regions segmented")

        # Display summary
        self._show_region_summary()
        self._display_final_results()

    def _display_current_state(self):
        """Plot the current image with a caption listing the region numbers."""
        if self.current_image is None:
            # Nothing to draw yet (e.g. clear_all() called before any image
            # was loaded); plt.imshow(None) would raise.
            return

        region_count = len(self.bounding_boxes)

        plt.figure(figsize=(10, 8))
        plt.imshow(self.current_image)
        plt.title(f"Lymph Node Image #{self.current_index + 1} - {region_count} regions selected")
        plt.axis('off')

        # Add text showing current regions
        if self.bounding_boxes:
            numbers = ', '.join(str(n) for n in range(1, region_count + 1))
            region_text = f"Regions: {numbers}"
            plt.figtext(0.02, 0.02, region_text, fontsize=10, bbox=dict(boxstyle="round", facecolor='wheat', alpha=0.8))

        plt.show()

    def _show_region_summary(self):
        """Print how many regions are currently recorded."""
        if self.bounding_boxes:
            print(f"📋 Total regions: {len(self.bounding_boxes)}")
        else:
            print("📋 No regions selected")

    def _display_final_results(self):
        """Plot the current image with a completion banner."""
        # This would show the combined segmentation results
        # For now, just show the original image with region count
        plt.figure(figsize=(12, 8))
        plt.imshow(self.current_image)
        plt.title(f"Final Result: {len(self.bounding_boxes)} Regions Segmented")
        plt.axis('off')

        # Add completion message
        completion_text = f"✅ Segmentation Complete\n{len(self.bounding_boxes)} follicles/structures identified"
        plt.figtext(0.02, 0.98, completion_text, fontsize=12,
                   bbox=dict(boxstyle="round", facecolor='lightgreen', alpha=0.8),
                   verticalalignment='top')
        plt.show()


class LymphNodeMedSAMTester:
    """Notebook helper for browsing dataset images and running the bbox demo.

    Args:
        model: Model handed to ``BboxPromptDemo``.
        dataset: Mapping from split name to an indexable collection whose
            items expose the picture under the ``'image'`` key.
        split: Name of the split images are read from. Defaults to ``'train'``.
    """

    def __init__(self, model, dataset, split='train'):
        self.model = model
        self.dataset = dataset
        self.split = split  # Which dataset split to read images from
        self.bbox_demo = BboxPromptDemo(model)
        self.current_index = 0

    def show_dataset_overview(self, num_samples=6):
        """Show a grid of up to ``num_samples`` images drawn with a fixed shuffle seed.

        The request is clamped to the number of images in the split, and the
        grid (three columns at most) is sized to fit it.
        """
        print("🔬 Lymph Node Dataset Overview")
        print("=" * 50)

        images = self.dataset[self.split]
        # Never ask for more samples than the split holds
        num_samples = min(num_samples, len(images))
        if num_samples <= 0:
            print("❌ No images available to preview.")
            return

        samples = images.shuffle(seed=42).select(range(num_samples))

        cols = min(3, num_samples)
        rows = (num_samples + cols - 1) // cols  # ceiling division
        # squeeze=False keeps a 2-D axes array even for a single row/column
        fig, axs = plt.subplots(rows, cols, figsize=(5 * cols, 5 * rows), squeeze=False)
        axs = axs.flatten()

        for i, sample in enumerate(samples):
            image = sample['image']
            axs[i].imshow(image)
            axs[i].set_title(f"Sample {i+1}\nSize: {image.size}")
            axs[i].axis('off')

        # Hide grid cells that did not receive an image
        for ax in axs[num_samples:]:
            ax.axis('off')

        plt.tight_layout()
        plt.suptitle("Lymph Node Histology Images", fontsize=16, y=1.02)
        plt.show()

    def test_single_image(self, image_index):
        """Show image ``image_index`` and open the bounding-box demo on it.

        Prints an error and returns when the index is negative or past the
        end of the split.
        """
        images = self.dataset[self.split]
        # Negative indices are rejected too, rather than wrapping around
        if image_index < 0 or image_index >= len(images):
            print(f"❌ Index {image_index} out of range. Dataset has {len(images)} images.")
            return

        sample = images[image_index]
        image = sample['image']

        print(f"Testing MedSAM on Lymph Node Image #{image_index + 1}")
        print(f"Image size: {image.size}")
        print(f"Image mode: {image.mode}")

        # Show the image first
        plt.figure(figsize=(8, 6))
        plt.imshow(image)
        plt.title(f"Lymph Node Section #{image_index + 1}")
        plt.axis('off')
        plt.show()

        print("\n📝 Instructions:")
        print("1. Draw a bounding box around the area you want to segment")
        print("2. Could be: lymphoid follicles, germinal centers, cortex, medulla, etc.")
        print("3. Use tight bounding boxes for best results")

        # Run MedSAM demo
        self.bbox_demo.show(image)

    def interactive_tester(self):
        """Display a slider and a "Random Image" button for browsing images."""
        total = len(self.dataset[self.split])
        if total == 0:
            # An IntSlider with max < min cannot be constructed
            print("❌ Dataset is empty - nothing to browse.")
            return

        # Cap the selectable index for performance
        max_index = min(100, total - 1)

        # Create slider
        image_slider = widgets.IntSlider(
            value=0,
            min=0,
            max=max_index,
            step=1,
            description='Image #:',
            style={'description_width': 'initial'},
            layout=widgets.Layout(width='500px')
        )

        # Create button for random selection
        random_button = widgets.Button(
            description="Random Image",
            button_style='info',
            layout=widgets.Layout(width='150px')
        )

        def random_image(b):
            import random
            image_slider.value = random.randint(0, max_index)

        random_button.on_click(random_image)

        # interactive_output wires the slider to the callback without drawing
        # the control a second time (interact() would render its own copy of
        # the slider below the HBox). It also clears the previous output
        # before each call.
        output = widgets.interactive_output(
            self.test_single_image, {'image_index': image_slider}
        )

        # Display widgets
        display(widgets.HBox([image_slider, random_button]), output)

    def batch_preview(self, start_idx=0, num_images=9):
        """Preview ``num_images`` consecutive images, starting at ``start_idx``.

        Positions that fall outside the split are left blank.
        """
        print(f"🔍 Batch Preview: Images {start_idx} to {start_idx + num_images - 1}")

        if num_images <= 0:
            print("❌ Nothing to preview.")
            return

        images = self.dataset[self.split]
        total = len(images)

        cols = min(3, num_images)
        rows = (num_images + cols - 1) // cols  # ceiling division
        fig, axs = plt.subplots(rows, cols, figsize=(4 * cols, 4 * rows), squeeze=False)
        axs = axs.flatten()

        for i, ax in enumerate(axs):
            idx = start_idx + i
            if i < num_images and 0 <= idx < total:
                ax.imshow(images[idx]['image'])
                ax.set_title(f"Image #{idx + 1}")
            # Axes are switched off for filled and empty cells alike
            ax.axis('off')

        plt.tight_layout()
        plt.show()

    def compare_segmentations(self, image_indices=[0, 1, 2]):
        """Show the given images side by side and list their indices.

        Indices outside the split get a blank panel.
        """
        print("🔬 Lymph Node Segmentation Comparison")

        # Copy into a list so any iterable is accepted and the (shared)
        # default argument is never touched
        image_indices = list(image_indices)
        if not image_indices:
            # plt.subplots cannot build a grid with zero columns
            print("❌ No image indices given.")
            return

        images = self.dataset[self.split]
        total = len(images)

        fig, axs = plt.subplots(
            1, len(image_indices), figsize=(5 * len(image_indices), 5), squeeze=False
        )
        axs = axs.flatten()

        for ax, idx in zip(axs, image_indices):
            if 0 <= idx < total:
                ax.imshow(images[idx]['image'])
                ax.set_title(f"Image #{idx + 1}")
            ax.axis('off')

        plt.tight_layout()
        plt.show()

        print("Select an image number to test with MedSAM:")
        for i, idx in enumerate(image_indices):
            print(f"Image {i+1}: Index {idx}")

In [None]:
# # Cell 4: Initialize Tester and Show Overview
# # Create the lymph node tester
# ln_tester = LymphNodeMedSAMTester(medsam_model, image_dataset)

# # Show dataset overview
# ln_tester.show_dataset_overview()


# Cell 4: Initialize Both Testers
# Single-region tester, built on the loaded model and dataset
ln_tester = LymphNodeMedSAMTester(model=medsam_model, dataset=image_dataset)

# Multi-region tester sharing the same model and dataset
multi_tester = MultiRegionMedSAMTester(model=medsam_model, dataset=image_dataset)

# Render the sample grid
ln_tester.show_dataset_overview()

# One print call, one line per argument
print(
    "\n🎯 Two Testing Modes Available:",
    "1. Single Region: Use quick_test() for simple one-region testing",
    "2. Multi Region: Use multi_region_test() for selecting multiple follicles",
    sep="\n",
)

In [None]:
# Cell 5: Interactive Testing Interface
# Banner and usage hints, one line per argument
print(
    "Interactive Lymph Node MedSAM Tester",
    "=" * 50,
    "Use the slider below to select different lymph node images",
    "Then draw bounding boxes around structures you want to segment",
    sep="\n",
)

# Hand over to the widget-based browser
ln_tester.interactive_tester()

In [None]:
# # Cell 6: Quick Test Functions
# # Quick test on first few images
# def quick_test(image_num=0):
#     """Quick test function for specific image"""
#     ln_tester.test_single_image(image_num)

# # Batch preview function
# def preview_batch(start=0):
#     """Preview a batch of 9 images"""
#     ln_tester.batch_preview(start, 9)

# print("Quick Test Functions Available:")
# print("• quick_test(image_num) - Test specific image")
# print("• preview_batch(start) - Preview 9 images starting from index")
# print("• ln_tester.interactive_tester() - Full interactive interface")

# # Show first batch
# preview_batch(0)


# Cell 6: Enhanced Quick Test Functions with Multi-Region Support
def quick_test(image_num=0):
    """Run the single-region tester (``ln_tester``) on image index ``image_num``."""
    ln_tester.test_single_image(image_index=image_num)

def multi_region_test(image_num=0):
    """Begin a multi-region session on image index ``image_num`` via ``multi_tester``."""
    multi_tester.test_single_image_multi_region(image_index=image_num)

def add_region():
    """Shortcut for ``multi_tester.add_region()`` on the image currently loaded there."""
    multi_tester.add_region()

def remove_region(region_num):
    """Shortcut for ``multi_tester.remove_region``; ``region_num`` is passed through unchanged."""
    multi_tester.remove_region(region_number=region_num)

def show_regions():
    """Shortcut for ``multi_tester.show_regions()``."""
    multi_tester.show_regions()

def clear_all():
    """Shortcut for ``multi_tester.clear_all()``."""
    multi_tester.clear_all()

def finish_segmentation():
    """Shortcut for ``multi_tester.finish()``, which closes out the multi-region session."""
    multi_tester.finish()

def preview_batch(start=0):
    """Preview nine images beginning at index ``start`` using ``ln_tester``."""
    ln_tester.batch_preview(start_idx=start, num_images=9)

# Usage cheat-sheet: one print call, one output line per argument
# (the empty string reproduces the blank separator line)
print(
    "Enhanced Functions Available:",
    "=" * 50,
    "Single Region Testing:",
    "• quick_test(image_num) - Test specific image (single region)",
    "• preview_batch(start) - Preview 9 images starting from index",
    "",
    "Multi-Region Testing:",
    "• multi_region_test(image_num) - Start multi-region testing",
    "• add_region() - Add new follicle/structure",
    "• remove_region(n) - Remove region number n",
    "• show_regions() - List all current regions",
    "• clear_all() - Remove all regions",
    "• finish_segmentation() - Complete and show results",
    sep="\n",
)

# Example session walking through the multi-region helpers
print(
    "\n Multi-Region Workflow:",
    "1. multi_region_test(5)  # Start with image 5",
    "2. add_region()          # Select first follicle",
    "3. add_region()          # Select second follicle",
    "4. remove_region(1)      # Remove first selection if wrong",
    "5. add_region()          # Add replacement selection",
    "6. finish_segmentation() # Complete",
    sep="\n",
)