In [None]:
#Cell 0:
!pip install -r requirements.txt -U
print("Complete")

In [None]:
# Cell 1: Setup and Authentication
%matplotlib widget

import getpass
import os

import huggingface_hub
import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import torch
from datasets import load_dataset
from IPython.display import clear_output, display
from PIL import Image
from segment_anything import sam_model_registry

from utils.demo import BboxPromptDemo

# Login to Hugging Face. Use the HF_TOKEN environment variable when it is set;
# otherwise prompt with getpass so the token is neither echoed nor saved in the
# notebook's cell output (as input() would do).
hf_token = os.environ.get("HF_TOKEN") or getpass.getpass("Enter your Hugging Face token: ")
huggingface_hub.login(token=hf_token)

print("Authentication successful!")

In [None]:
# Cell 2: Load Dataset and Model (From Hugging Face)

print("Loading lymph node dataset...")
# Strip whitespace: a pasted repo id with a trailing space would not resolve.
dataset_input = input("Please enter dataset path: ").strip()

try:
    # First, analyze the dataset structure
    temp_dataset = load_dataset(dataset_input, token=True)
    split_names = list(temp_dataset.keys())

    print("🔍 Dataset Analysis:")
    print(f"Available splits: {split_names}")
    for split_name in split_names:
        print(f"  - {split_name}: {len(temp_dataset[split_name])} items")
        if len(temp_dataset[split_name]) > 0:
            print(f"    Features: {temp_dataset[split_name].features}")

    # Choose which split to use as 'train'. An empty answer falls back to an
    # existing 'train' split, or to the only split when there is just one.
    chosen_split = input(f"Which split to use as 'train'? {split_names}: ").strip()
    if not chosen_split:
        if "train" in temp_dataset:
            chosen_split = "train"
        elif len(split_names) == 1:
            chosen_split = split_names[0]

    if chosen_split not in temp_dataset:
        raise ValueError(f" Split '{chosen_split}' not found in dataset")

    print(f"Will use '{chosen_split}' split")

    # The rest of the notebook always reads image_dataset['train'], so expose
    # the chosen split under that key (the DatasetDict is aliased, not copied).
    image_dataset = temp_dataset
    image_dataset['train'] = image_dataset[chosen_split]

    print("Dataset info:")
    print(f"Available splits: {list(image_dataset.keys())}")
    print(f"Number of train images: {len(image_dataset['train'])}")
    print(f"Features: {image_dataset['train'].features}")

except Exception as e:
    print(f"Error loading dataset: {e}")
    print("\n Troubleshooting:")
    print("1. Check if the dataset repository exists and is accessible")
    print("2. Verify your HuggingFace token has the correct permissions")
    print("3. Make sure the dataset path is correct")
    print("4. Try a different split name if available")
    print("\n Cannot continue without valid dataset. Please fix the issue and try again.")
    raise  # Re-raise the error to stop execution

# Load MedSAM model: use a local checkpoint when one exists, otherwise download
# it from the Hugging Face Hub into the local cache.
print("\nLoading MedSAM model from Hugging Face...")
from huggingface_hub import hf_hub_download

# Manually downloaded checkpoint; override with the MEDSAM_CKPT_PATH env variable.
LOCAL_MEDSAM_CKPT_PATH = os.environ.get("MEDSAM_CKPT_PATH", "/home/medsam-vit-b/medsam_vit_b.pth")
MEDSAM_REPO_ID = "GleghornLab/medsam-vit-b"
MEDSAM_FILENAME = "medsam_vit_b.pth"

try:
    if os.path.isfile(LOCAL_MEDSAM_CKPT_PATH):
        MedSAM_CKPT_PATH = LOCAL_MEDSAM_CKPT_PATH
        print(f"Using local checkpoint: {MedSAM_CKPT_PATH}")
    else:
        MedSAM_CKPT_PATH = hf_hub_download(
            repo_id=MEDSAM_REPO_ID,
            filename=MEDSAM_FILENAME,
            token=True,
        )
        print(f"Model downloaded to: {MedSAM_CKPT_PATH}")

    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    # Build the architecture without weights, then load the state dict with
    # map_location so a checkpoint saved from a GPU also loads on the CPU fallback.
    medsam_model = sam_model_registry['vit_b'](checkpoint=None)
    medsam_model.load_state_dict(torch.load(MedSAM_CKPT_PATH, map_location="cpu"))
    medsam_model = medsam_model.to(device)
    medsam_model.eval()

    print(f"MedSAM model loaded successfully on {device}")

    # Show device info
    print(f"🖥️ Using device: {device}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

except Exception as e:
    print(f"Error loading model: {e}")
    print("\nModel troubleshooting:")
    print(f"1. Check the local checkpoint '{LOCAL_MEDSAM_CKPT_PATH}' or that the repository '{MEDSAM_REPO_ID}' exists")
    print("2. Verify you have access permissions to this repository")
    print("3. Ensure your HuggingFace token has the correct permissions")
    print("\n Cannot continue without valid model. Please fix the issue and try again.")
    raise  # Re-raise the error to stop execution

print("\n Setup complete! Dataset and model loaded successfully.")

In [None]:
# Cell 3: Optimized MedSAM Interface - Starts Small for Better Performance
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import LassoSelector, Button, Slider
from matplotlib.path import Path
import torch
from PIL import Image, ImageEnhance

class AdjustableViewLassoMedSAMInterface:
    """Interactive lasso-prompted MedSAM segmentation of one dataset image at a time.

    The image on screen is a resized, brightness/contrast-adjusted copy used only
    for viewing. Lasso vertices are mapped back to the original pixel grid, and
    MedSAM always runs on the unmodified original image after the standard MedSAM
    preprocessing (3 channels, resized to 1024x1024, min-max normalised to [0, 1]).

    Args:
        model: SAM/MedSAM model exposing ``device``, ``image_encoder``,
            ``prompt_encoder`` and ``mask_decoder`` (as built by
            ``segment_anything.sam_model_registry``).
        dataset: mapping with a ``'train'`` split whose items are dicts holding a
            PIL image under ``'image'``.

    Accepted masks are appended to ``accepted_masks`` as ``(image_index, mask)``
    tuples, ``mask`` being a boolean array of the original image's height/width.
    """

    # Side length of the square input the SAM ViT image encoder expects.
    MEDSAM_INPUT_SIZE = 1024
    # Initial view scale: large slides (e.g. 5632x5632) redraw slowly at 100%.
    DEFAULT_SIZE_FACTOR = 0.3

    def __init__(self, model, dataset):
        self.model = model
        self.dataset = dataset
        self.image_index = None         # Dataset index of the loaded image
        self.original_image = None      # Original pixels for MedSAM (never modified)
        self.original_pil_image = None
        self.display_image = None       # Adjusted copy for viewing only
        self.lasso_points = []          # Lasso vertices in ORIGINAL image coordinates
        self.current_mask = None        # Last MedSAM result (bool, original size)
        self.accepted_masks = []        # (image_index, mask) for each accepted region
        self._image_embedding = None    # Cached MedSAM embedding of the loaded image
        self.fig = None
        self.ax_image = None
        self.lasso = None
        self.info_text = None

        # Display adjustment parameters (viewing only)
        self.brightness = 1.0
        self.contrast = 1.0
        self.size_factor = self.DEFAULT_SIZE_FACTOR

    def load_image(self, image_index):
        """Load dataset image ``image_index`` and open the interactive viewer.

        Resets the viewing adjustments, any previous lasso/mask and the cached
        MedSAM embedding. Prints a message and returns None when the index is
        past the end of the ``'train'`` split.
        """
        num_images = len(self.dataset['train'])
        if image_index >= num_images:
            print(f"Index {image_index} out of range. Max: {num_images-1}")
            return

        sample = self.dataset['train'][image_index]
        pil_image = sample['image']

        print(f"Loading Image #{image_index + 1}")
        print(f"Original size: {pil_image.size}")

        # Store original image for MedSAM processing (never modified)
        self.image_index = image_index
        self.original_image = np.array(pil_image)
        self.original_pil_image = pil_image

        # State tied to the previous image must not leak into this one.
        self.lasso_points = []
        self.current_mask = None
        self._image_embedding = None

        # Start with small size for better performance
        self.brightness = 1.0
        self.contrast = 1.0
        self.size_factor = self.DEFAULT_SIZE_FACTOR

        print(f"Display will start at {self.size_factor:.1%} size for smooth interaction")

        self.update_display_image()
        self.setup_adjustable_interface()

    def update_display_image(self):
        """Rebuild ``display_image`` from the original using the current view settings.

        ``Image.resize`` and the enhancers return new images, so the original PIL
        image is never modified and no extra full-resolution copy is needed.
        Resizing happens first so brightness/contrast run on the smaller image.
        """
        # max(1, ...) keeps tiny images from collapsing to a 0-pixel side.
        new_size = tuple(max(1, int(dim * self.size_factor)) for dim in self.original_pil_image.size)
        adjusted_image = self.original_pil_image.resize(new_size, Image.Resampling.LANCZOS)

        if self.brightness != 1.0:
            adjusted_image = ImageEnhance.Brightness(adjusted_image).enhance(self.brightness)
        if self.contrast != 1.0:
            adjusted_image = ImageEnhance.Contrast(adjusted_image).enhance(self.contrast)

        self.display_image = np.array(adjusted_image)

    def _base_title(self):
        """Title shown above the plain (un-overlaid) image."""
        return (f"Lasso Tool (Display: {self.display_image.shape[:2]}, "
                f"Original: {self.original_image.shape[:2]})")

    def _size_info(self):
        """Label with the current display resolution as WIDTHxHEIGHT."""
        return f"Current: {self.display_image.shape[1]}x{self.display_image.shape[0]}"

    def _draw_base_image(self):
        """Redraw the adjusted display image, without overlays, on the image axes."""
        self.ax_image.clear()
        self.ax_image.imshow(self.display_image)
        self.ax_image.set_title(self._base_title())

    def _reset_lasso(self):
        """Attach a fresh LassoSelector to the image axes, detaching the previous one.

        Clearing the axes also removes the selector's line artist, so every code
        path that clears the axes calls this afterwards.
        """
        old_lasso = getattr(self, "lasso", None)
        if old_lasso is not None:
            old_lasso.disconnect_events()
        self.lasso = LassoSelector(self.ax_image, self.on_lasso_select)

    def _display_to_original_scale(self):
        """Return ``(sx, sy)``: original pixels per display pixel along x and y.

        Uses the actual array sizes instead of ``size_factor`` because the display
        size is rounded down to whole pixels.
        """
        orig_h, orig_w = self.original_image.shape[:2]
        disp_h, disp_w = self.display_image.shape[:2]
        return orig_w / disp_w, orig_h / disp_h

    def setup_adjustable_interface(self):
        """Create the figure: image axes with a lasso, view sliders and buttons."""
        if self.fig is not None:
            plt.close(self.fig)

        self.fig = plt.figure(figsize=(15, 10))

        # Main image axes
        self.ax_image = self.fig.add_axes([0.1, 0.3, 0.6, 0.6])
        self._draw_base_image()
        self._reset_lasso()

        self.add_viewing_sliders()
        self.add_control_buttons()

        print("=" * 60)
        print("🚀 OPTIMIZED Interface:")
        print(f"📱 Display: {self.display_image.shape[:2]} (for smooth interaction)")
        print(f"🔬 Processing: {self.original_image.shape[:2]} (resized to 1024x1024 for MedSAM; mask returned at this size)")
        print("💡 Use View Size slider to zoom in when needed")
        print("=" * 60)

        plt.show()

    def add_viewing_sliders(self):
        """Add View Size / Brightness / Contrast sliders (display only) and a size label."""
        ax_size = self.fig.add_axes([0.75, 0.8, 0.15, 0.03])
        self.slider_size = Slider(ax_size, 'View Size', 0.1, 1.0, valinit=self.size_factor, valfmt='%.1f')
        self.slider_size.on_changed(self.update_view_size)

        ax_brightness = self.fig.add_axes([0.75, 0.75, 0.15, 0.03])
        self.slider_brightness = Slider(ax_brightness, 'Brightness', 0.5, 2.0, valinit=1.0, valfmt='%.1f')
        self.slider_brightness.on_changed(self.update_view_brightness)

        ax_contrast = self.fig.add_axes([0.75, 0.7, 0.15, 0.03])
        self.slider_contrast = Slider(ax_contrast, 'Contrast', 0.5, 2.0, valinit=1.0, valfmt='%.1f')
        self.slider_contrast.on_changed(self.update_view_contrast)

        # Current display resolution; kept up to date by refresh_display().
        ax_info = self.fig.add_axes([0.75, 0.65, 0.15, 0.03])
        self.info_text = ax_info.text(0.1, 0.5, self._size_info(),
                                      transform=ax_info.transAxes, fontsize=9)
        ax_info.set_xticks([])
        ax_info.set_yticks([])

    def add_control_buttons(self):
        """Add the Test / Accept / Clear / New Region buttons below the image."""
        ax_test = self.fig.add_axes([0.1, 0.15, 0.12, 0.04])
        ax_accept = self.fig.add_axes([0.23, 0.15, 0.12, 0.04])
        ax_clear = self.fig.add_axes([0.36, 0.15, 0.12, 0.04])
        ax_new_region = self.fig.add_axes([0.49, 0.15, 0.12, 0.04])

        self.btn_test = Button(ax_test, 'Test Segmentation')
        self.btn_accept = Button(ax_accept, 'Accept')
        self.btn_clear = Button(ax_clear, 'Clear')
        self.btn_new = Button(ax_new_region, 'New Region')

        self.btn_test.on_clicked(self.test_segmentation)
        self.btn_accept.on_clicked(self.accept_segmentation)
        self.btn_clear.on_clicked(self.clear_selection)
        self.btn_new.on_clicked(self.new_region)

    def update_view_size(self, val):
        """Slider callback: change the viewing scale only."""
        self.size_factor = val
        self.refresh_display()

    def update_view_brightness(self, val):
        """Slider callback: change the viewing brightness only."""
        self.brightness = val
        self.refresh_display()

    def update_view_contrast(self, val):
        """Slider callback: change the viewing contrast only."""
        self.contrast = val
        self.refresh_display()

    def refresh_display(self):
        """Re-render the display image after a view change; clears the lasso and mask."""
        self.update_display_image()
        self._draw_base_image()
        self._reset_lasso()

        # Clear selection when view changes
        self.lasso_points = []
        self.current_mask = None

        if getattr(self, "info_text", None) is not None:
            self.info_text.set_text(self._size_info())

        self.fig.canvas.draw_idle()

    def on_lasso_select(self, verts):
        """Lasso callback: store the vertices mapped to original-image coordinates.

        Selections with fewer than three vertices are ignored.
        """
        if len(verts) > 2:
            sx, sy = self._display_to_original_scale()
            self.lasso_points = [(x * sx, y * sy) for x, y in verts]
            print(f"✓ Lasso drawn with {len(verts)} points (scaled to original coordinates)")

    def test_segmentation(self, event):
        """Run MedSAM with the lasso's bounding box as prompt, then clip to the lasso."""
        if not self.lasso_points:
            print("Please draw a lasso first!")
            return

        # Bounding box of the lasso in original coordinates, rounded outward so
        # it encloses the whole outline, then clipped to the image.
        points = np.array(self.lasso_points)
        x_min, y_min = np.floor(points.min(axis=0)).astype(int)
        x_max, y_max = np.ceil(points.max(axis=0)).astype(int)

        x_min = max(0, x_min)
        y_min = max(0, y_min)
        x_max = min(self.original_image.shape[1], x_max)
        y_max = min(self.original_image.shape[0], y_max)

        # A zero-area box is not a meaningful prompt (e.g. a flat stroke or a
        # lasso drawn entirely outside the image).
        if x_max <= x_min or y_max <= y_min:
            print("Lasso region is empty after clipping to the image; please draw a larger region.")
            return

        print("🔬 Running MedSAM on ORIGINAL full-resolution image...")
        bbox = np.array([x_min, y_min, x_max, y_max])
        print(f"Bounding box on original image: {bbox}")

        self.current_mask = self.get_medsam_segmentation(bbox)
        self.current_mask = self.refine_with_lasso(self.current_mask)

        self.show_segmentation_result()

    def _preprocess_for_medsam(self, device):
        """Return the original image as a (1, 3, 1024, 1024) float tensor in [0, 1].

        2-D (grayscale) and 1-/2-channel images are replicated from their first
        channel; channels beyond the third (e.g. alpha) are dropped.
        """
        img = np.asarray(self.original_image)
        if img.ndim == 2:
            img = img[:, :, None]
        if img.shape[2] >= 3:
            img = img[:, :, :3]
        else:
            img = np.repeat(img[:, :, :1], 3, axis=2)
        if img.dtype != np.uint8:
            # torch.from_numpy does not accept every numpy dtype (e.g. uint16 on older torch).
            img = img.astype(np.float32)

        # Move the compact array first, then convert to float on the target device.
        tensor = torch.from_numpy(np.ascontiguousarray(img)).to(device)
        tensor = tensor.permute(2, 0, 1).unsqueeze(0).float()

        size = self.MEDSAM_INPUT_SIZE
        tensor = torch.nn.functional.interpolate(
            tensor, size=(size, size), mode="bilinear", align_corners=False, antialias=True,
        )
        lo, hi = tensor.min(), tensor.max()
        return (tensor - lo) / torch.clamp(hi - lo, min=1e-8)

    def get_medsam_segmentation(self, bbox):
        """Segment the original image with MedSAM using a box prompt.

        Follows the MedSAM reference inference: the encoder receives a
        (1, 3, 1024, 1024) tensor scaled to [0, 1], the box is rescaled into that
        1024x1024 frame, and the low-resolution logits go through a sigmoid and
        are upsampled to the original size before thresholding at 0.5.

        Args:
            bbox: ``[x_min, y_min, x_max, y_max]`` in original-image pixels.

        Returns:
            Boolean mask with the original image's ``(height, width)``.
        """
        h, w = self.original_image.shape[:2]
        device = self.model.device

        with torch.no_grad():
            # The embedding depends only on the image, so compute it once per image.
            image_embeddings = getattr(self, "_image_embedding", None)
            if image_embeddings is None:
                image_embeddings = self.model.image_encoder(self._preprocess_for_medsam(device))
                self._image_embedding = image_embeddings

            scale = np.array([w, h, w, h], dtype=np.float32)
            box_1024 = np.asarray(bbox, dtype=np.float32) / scale * self.MEDSAM_INPUT_SIZE
            box_tensor = torch.as_tensor(box_1024, dtype=torch.float32, device=device).reshape(1, 1, 4)

            sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
                points=None, boxes=box_tensor, masks=None,
            )

            low_res_logits, _ = self.model.mask_decoder(
                image_embeddings=image_embeddings,
                image_pe=self.model.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=False,
            )

            probs = torch.sigmoid(low_res_logits)
            probs = torch.nn.functional.interpolate(
                probs, size=(h, w), mode="bilinear", align_corners=False,
            )
            mask = (probs[0, 0] > 0.5).cpu().numpy()

        print(" MedSAM segmentation complete!")
        return mask

    def refine_with_lasso(self, mask):
        """Zero every mask pixel outside the lasso polygon (original coordinates).

        Containment is only tested inside the lasso's bounding box: pixels outside
        that box cannot lie inside the polygon, and testing every pixel of a
        5632x5632 slide would allocate roughly 1 GB of coordinate arrays.
        """
        if not self.lasso_points:
            return mask

        path = Path(self.lasso_points)
        h, w = mask.shape
        vertices = np.asarray(self.lasso_points, dtype=float)
        x0 = max(int(np.floor(vertices[:, 0].min())), 0)
        y0 = max(int(np.floor(vertices[:, 1].min())), 0)
        x1 = min(int(np.ceil(vertices[:, 0].max())) + 1, w)
        y1 = min(int(np.ceil(vertices[:, 1].max())) + 1, h)

        inside_lasso = np.zeros((h, w), dtype=bool)
        if x1 > x0 and y1 > y0:
            y, x = np.mgrid[y0:y1, x0:x1]
            points = np.column_stack((x.ravel(), y.ravel()))
            inside_lasso[y0:y1, x0:x1] = path.contains_points(points).reshape(y1 - y0, x1 - x0)
        return mask & inside_lasso

    def show_segmentation_result(self):
        """Overlay the current mask and the lasso outline on the display image."""
        if self.current_mask is None:
            return

        # Downscale the full-resolution mask to the display grid (uint8 keeps it compact).
        mask_pil = Image.fromarray(self.current_mask.astype(np.uint8) * 255)
        display_size = (self.display_image.shape[1], self.display_image.shape[0])  # PIL wants (W, H)
        display_mask = np.array(mask_pil.resize(display_size, Image.Resampling.NEAREST)) > 127

        self.ax_image.clear()
        self.ax_image.imshow(self.display_image)

        # Pin vmin/vmax: only the value 1 is left unmasked, and auto-scaling a
        # constant array maps it to the palest colour of 'Reds' (nearly white).
        masked = np.ma.masked_where(~display_mask, display_mask.astype(np.uint8))
        self.ax_image.imshow(masked, alpha=0.6, cmap='Reds', vmin=0, vmax=1)

        if self.lasso_points:
            sx, sy = self._display_to_original_scale()
            outline = np.array([(x / sx, y / sy) for x, y in self.lasso_points])
            outline = np.vstack([outline, outline[:1]])  # close the loop
            self.ax_image.plot(outline[:, 0], outline[:, 1], 'r-', linewidth=2)

        self.ax_image.set_title("Segmentation Result (processed on full-resolution)")
        self._reset_lasso()
        self.fig.canvas.draw_idle()

    def accept_segmentation(self, event):
        """Store the current mask in ``accepted_masks`` and get ready for another region."""
        if self.current_mask is None:
            print(" No segmentation to accept!")
            return

        self.accepted_masks.append((self.image_index, self.current_mask))
        print("Segmentation accepted!")
        print(f"Mask: {self.current_mask.shape}, Pixels: {np.sum(self.current_mask)}")
        print(f"Accepted masks stored: {len(self.accepted_masks)} (see .accepted_masks)")
        self.new_region(event)

    def clear_selection(self, event):
        """Drop the current lasso and mask and redraw the plain image."""
        self.lasso_points = []
        self.current_mask = None

        self._draw_base_image()
        self._reset_lasso()
        self.fig.canvas.draw_idle()
        print("🧹 Selection cleared")

    def new_region(self, event):
        """Start selecting a new region; any existing overlay stays visible for reference."""
        self.lasso_points = []
        # Drop the previous result so the same mask cannot be accepted twice.
        self.current_mask = None
        self._reset_lasso()
        self.ax_image.set_title(" Draw another region")
        self.fig.canvas.draw_idle()
        print("Ready for new region selection")

    def preview_images(self, start_idx=0, num_images=9):
        """Show thumbnails (longest side <= 400 px) of ``num_images`` images from ``start_idx``.

        Images are laid out three per row; grid cells without an image stay blank.
        """
        print(f"Preview: Images {start_idx} to {start_idx + num_images - 1}")
        if num_images <= 0:
            return

        ncols = 3
        nrows = -(-num_images // ncols)  # ceiling division
        fig, axs = plt.subplots(nrows, ncols, figsize=(12, 4 * nrows), squeeze=False)
        total = len(self.dataset['train'])

        for i, ax in enumerate(axs.flatten()):
            idx = start_idx + i
            if i < num_images and idx < total:
                image = self.dataset['train'][idx]['image']
                # Show at reduced size for preview
                if max(image.size) > 400:
                    preview_size = tuple(max(1, int(dim * 400 / max(image.size))) for dim in image.size)
                    image = image.resize(preview_size, Image.Resampling.LANCZOS)
                ax.imshow(image)
                ax.set_title(f"Image #{idx + 1}")
            ax.axis('off')

        plt.tight_layout()
        plt.show()

In [None]:
# Cell 4: Initialize Adjustable View Interface
interface = AdjustableViewLassoMedSAMInterface(medsam_model, image_dataset)

# Thumbnails of the first nine images, to help choose an index for load_image().
interface.preview_images(0, 9)

# Usage cheat-sheet: one line of output per argument.
print(
    "\nAdjustable View Lasso MedSAM Commands:",
    "=" * 60,
    "• interface.load_image(n) - Load image with viewing adjustments",
    "• interface.preview_images(start, count) - Preview images",
    "\nKey Features:",
    "✓ Viewing adjustments: size, brightness, contrast (temporary)",
    "✓ MedSAM processing: always uses original unmodified image",
    "✓ Coordinate mapping: lasso coordinates mapped to original image",
    "✓ Visual feedback: see adjustments while maintaining data integrity",
    "\nHow to Use:",
    "1. interface.load_image(5)",
    "2. Adjust viewing parameters as needed",
    "3. Draw a lasso around your region",
    "4. Click 'Test Segmentation' (processes original image)",
    "5. Click 'Accept' or 'Clear' to continue",
    sep="\n",
)

In [None]:
# Cell 5: Quick Functions
def load_image(n):
    """Open dataset image ``n`` in the interactive viewer (wraps ``interface.load_image``)."""
    viewer = interface
    return viewer.load_image(n)

def preview(start=0):
    """Show nine thumbnails beginning at dataset index ``start``."""
    thumbnails_per_page = 9
    interface.preview_images(start, thumbnails_per_page)

# List the shortcut helpers defined above, one per line.
print("Quick Commands:", "• load_image(n)", "• preview()", sep="\n")

In [None]:
# Open the first dataset image in the interactive lasso viewer.
load_image(n=0)