In [None]:
#Cell 0:
!pip install -r requirements.txt -U
print("Complete")

In [None]:
# Cell 1: Setup and Authentication
%matplotlib widget

import matplotlib.pyplot as plt
from datasets import load_dataset
import huggingface_hub
from PIL import Image
import numpy as np
import torch
from segment_anything import sam_model_registry
from utils.demo import BboxPromptDemo
import ipywidgets as widgets
from IPython.display import display, clear_output
import os

# Login to Hugging Face.
# Prefer the HF_TOKEN environment variable; otherwise prompt with getpass so the
# token is not echoed in plain text into the notebook (plain input() shows it).
from getpass import getpass

hf_token = os.environ.get("HF_TOKEN") or getpass("Enter your Hugging Face token: ")
huggingface_hub.login(token=hf_token)

print("Authentication successful!")

In [None]:
# Cell 2: Load Dataset and Model (From Hugging Face)

print("Loading lymph node dataset...")
# .strip(): a pasted repo id with a stray space/newline would otherwise be "not found".
dataset_input = input("Please enter dataset path: ").strip()

try:
    # First, analyze the dataset structure
    temp_dataset = load_dataset(dataset_input, token=True)
    split_names = list(temp_dataset.keys())

    print("🔍 Dataset Analysis:")
    print(f"Available splits: {split_names}")
    for split_name in split_names:
        print(f"  - {split_name}: {len(temp_dataset[split_name])} items")
        if len(temp_dataset[split_name]) > 0:
            print(f"    Features: {temp_dataset[split_name].features}")

    # Choose which split the interface reads as 'train'. Pressing Enter picks
    # 'train' when it exists, otherwise the first available split.
    default_split = 'train' if 'train' in temp_dataset else (split_names[0] if split_names else '')
    chosen_split = input(
        f"Which split to use as 'train'? {split_names} [{default_split}]: "
    ).strip() or default_split

    if chosen_split in temp_dataset:
        print(f"Will use '{chosen_split}' split")

        # The interface always reads dataset['train'], so alias the chosen split there.
        image_dataset = temp_dataset
        image_dataset['train'] = image_dataset[chosen_split]

        print("Dataset info:")
        print(f"Available splits: {list(image_dataset.keys())}")
        print(f"Number of train images: {len(image_dataset['train'])}")
        print(f"Features: {image_dataset['train'].features}")

    else:
        raise ValueError(f" Split '{chosen_split}' not found in dataset")

except Exception as e:
    print(f"Error loading dataset: {e}")
    print("\n Troubleshooting:")
    print("1. Check if the dataset repository exists and is accessible")
    print("2. Verify your HuggingFace token has the correct permissions")
    print("3. Make sure the dataset path is correct")
    print("4. Try a different split name if available")
    print("\n Cannot continue without valid dataset. Please fix the issue and try again.")
    raise  # Re-raise the error to stop execution

# Load MedSAM model: use a local checkpoint when present, otherwise download it
# from the Hugging Face Hub into the local cache.
print("\nLoading MedSAM model from Hugging Face...")
from huggingface_hub import hf_hub_download

# Default location of a manually downloaded checkpoint. Set the MEDSAM_CKPT_PATH
# environment variable to point somewhere else without editing the notebook.
DEFAULT_MEDSAM_CKPT_PATH = "/home/medsam-vit-b/medsam_vit_b.pth"

try:
    MedSAM_CKPT_PATH = os.environ.get("MEDSAM_CKPT_PATH", DEFAULT_MEDSAM_CKPT_PATH)

    if os.path.isfile(MedSAM_CKPT_PATH):
        print(f"Using local checkpoint: {MedSAM_CKPT_PATH}")
    else:
        print(f"Checkpoint not found at {MedSAM_CKPT_PATH}; downloading from Hugging Face...")
        MedSAM_CKPT_PATH = hf_hub_download(
            repo_id="GleghornLab/medsam-vit-b",
            filename="medsam_vit_b.pth",
            token=True,
        )
        print(f"Model downloaded to: {MedSAM_CKPT_PATH}")

    # Load the model
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    medsam_model = sam_model_registry['vit_b'](checkpoint=MedSAM_CKPT_PATH)
    medsam_model = medsam_model.to(device)
    medsam_model.eval()

    print(f"MedSAM model loaded successfully on {device}")

    # Show device info
    print(f"🖥️ Using device: {device}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

except Exception as e:
    print(f"Error loading model: {e}")
    print("\nModel troubleshooting:")
    print("1. Check if the model repository 'GleghornLab/medsam-vit-b' exists")
    print("2. Verify you have access permissions to this repository")
    print("3. Ensure your HuggingFace token has the correct permissions")
    print("\n Cannot continue without valid model. Please fix the issue and try again.")
    raise  # Re-raise the error to stop execution

print("\n Setup complete! Dataset and model loaded successfully.")

In [None]:
# Cell 3: Compact Optimized MedSAM Interface - Fixed Layout
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import LassoSelector, Button, Slider
from matplotlib.path import Path
import torch
from PIL import Image, ImageEnhance

class AdjustableViewLassoMedSAMInterface:
    """Interactive lasso-prompted MedSAM segmentation for a Hugging Face image dataset.

    The figure shows a *display copy* of each image that can be resized and
    brightness/contrast-adjusted for viewing only. Lasso vertices drawn on that
    copy are mapped back to full-resolution pixel coordinates, and MedSAM always
    runs on the unmodified original image.

    Args:
        model: loaded (Med)SAM model exposing ``image_encoder``, ``prompt_encoder``,
            ``mask_decoder`` and ``device``.
        dataset: dataset dict whose ``'train'`` split yields samples with an
            ``'image'`` entry (a PIL image).
    """

    # Side length (pixels) of the square input the MedSAM ViT-B image encoder expects.
    MEDSAM_INPUT_SIZE = 1024
    # The display copy is never scaled so that its shorter side drops below this.
    MIN_DISPLAY_SIDE = 200
    # Initial display scale; small so redraws of large images stay fast.
    DEFAULT_SIZE_FACTOR = 0.15

    def __init__(self, model, dataset):
        self.model = model
        self.dataset = dataset
        self.original_image = None       # full-resolution array fed to MedSAM (never modified)
        self.original_pil_image = None
        self.display_image = None        # resized/adjusted copy used for viewing only
        self.lasso_points = []           # lasso vertices in ORIGINAL-image pixel coordinates
        self.current_mask = None         # boolean mask at original resolution
        self.fig = None
        self.ax_image = None
        self.lasso = None                # the single active LassoSelector
        self.lasso_line = None           # persistent lasso outline artist
        self.lasso_markers = None        # lasso start-point marker artist
        self.info_text = None
        self.goto_box = None
        self.current_image_index = None

        # Display adjustment parameters (viewing only)
        self.brightness = 1.0
        self.contrast = 1.0
        self.size_factor = self.DEFAULT_SIZE_FACTOR

        # Accepted segmentations, in acceptance order
        self.segmentation_results = []

    # ------------------------------------------------------------------ loading

    def load_image(self, image_index):
        """Load image ``image_index`` (0-based) and open a new interactive figure.

        Out-of-range indices print a message and leave the current state unchanged.
        """
        if self._load_sample(image_index):
            self.setup_adjustable_interface()

    def _load_sample(self, image_index):
        """Read one sample, reset view settings and selection; return False if out of range."""
        n_images = len(self.dataset['train'])
        if not 0 <= image_index < n_images:
            print(f"Index {image_index} out of range. Max: {n_images - 1}")
            return False

        self.current_image_index = image_index
        pil_image = self.dataset['train'][image_index]['image']

        print(f"Loading Image #{image_index + 1}")
        print(f"Original size: {pil_image.size}")

        # Original kept untouched for MedSAM processing
        self.original_pil_image = pil_image
        self.original_image = np.array(pil_image)

        # Reset view settings AND any selection left over from the previous image,
        # otherwise "Test Seg" would reuse a lasso drawn on a different image.
        self.brightness = 1.0
        self.contrast = 1.0
        self.size_factor = self.DEFAULT_SIZE_FACTOR
        self.lasso_points = []
        self.current_mask = None

        print(f"Display starting at {self.size_factor:.1%} size")
        self.update_display_image()
        return True

    def _show_in_current_figure(self, image_index):
        """Load ``image_index`` into the already-open figure (used by navigation buttons).

        A new figure created from inside a widget callback may not appear in the
        notebook, so navigation reuses the existing canvas and only builds a new
        figure when the old one has been closed.
        """
        if not self._load_sample(image_index):
            return
        if self.fig is None or not plt.fignum_exists(self.fig.number):
            self.setup_adjustable_interface()
            return
        # Move the sliders back to the defaults without firing three refreshes.
        for slider, value in ((self.slider_size, self.size_factor),
                              (self.slider_brightness, self.brightness),
                              (self.slider_contrast, self.contrast)):
            slider.eventson = False
            try:
                slider.set_val(value)
            finally:
                slider.eventson = True
        self.refresh_display()

    def update_display_image(self):
        """Rebuild ``self.display_image`` from the original using the current view settings.

        The copy is scaled by ``size_factor`` with its aspect ratio preserved (the
        shorter side is kept >= ``MIN_DISPLAY_SIDE`` px), then brightness/contrast
        are applied to the smaller image. The original image is never modified.
        """
        width, height = self.original_pil_image.size
        scale = max(self.size_factor, self.MIN_DISPLAY_SIDE / max(1, min(width, height)))
        new_size = (max(1, round(width * scale)), max(1, round(height * scale)))
        adjusted_image = self.original_pil_image.resize(new_size, Image.Resampling.LANCZOS)

        if self.brightness != 1.0:
            adjusted_image = ImageEnhance.Brightness(adjusted_image).enhance(self.brightness)
        if self.contrast != 1.0:
            adjusted_image = ImageEnhance.Contrast(adjusted_image).enhance(self.contrast)

        self.display_image = np.array(adjusted_image)

    # ------------------------------------------------------- coordinate mapping

    def _display_scale(self):
        """Return (sx, sy): display pixels per original pixel along x and y.

        Derived from the actual array shapes, so it stays exact after clamping
        and integer rounding of the display size.
        """
        orig_h, orig_w = self.original_image.shape[:2]
        disp_h, disp_w = self.display_image.shape[:2]
        return disp_w / orig_w, disp_h / orig_h

    def _display_to_original(self, verts):
        """Map display-axes (x, y) data coordinates to original-image pixel coordinates.

        ``imshow`` puts pixel centres on integer coordinates (image spans
        [-0.5, size - 0.5]); the +/-0.5 terms keep that extent aligned between scales.
        """
        sx, sy = self._display_scale()
        return [((float(x) + 0.5) / sx - 0.5, (float(y) + 0.5) / sy - 0.5) for x, y in verts]

    def _original_to_display(self, points):
        """Inverse of ``_display_to_original``; returns an (N, 2) float array."""
        sx, sy = self._display_scale()
        return np.array([((x + 0.5) * sx - 0.5, (y + 0.5) * sy - 0.5) for x, y in points],
                        dtype=float).reshape(-1, 2)

    # --------------------------------------------------------------- interface

    def setup_adjustable_interface(self):
        """Create the figure: image axes with a lasso selector, plus sliders and buttons."""
        if self.fig is not None:
            plt.close(self.fig)

        self.fig = plt.figure(figsize=(12, 8))
        # Explicit fig.add_axes (not plt.axes) so nothing depends on pyplot's
        # notion of the "current" figure.
        self.ax_image = self.fig.add_axes([0.05, 0.25, 0.6, 0.7])
        self._draw_base_image()
        self.add_compact_controls()

        print("✅ Interface loaded!")
        print(f"📱 Display: {self.display_image.shape[:2]} (use sliders to adjust)")
        print(f"🔬 Processing: {self.original_image.shape[:2]} (full resolution)")

        plt.show()

    def _default_title(self):
        """Title shown while the user is drawing a lasso."""
        return (f"Image #{self.current_image_index + 1} - Lasso Tool\n"
                f"Display: {self.display_image.shape[:2]}")

    def _draw_base_image(self, title=None, lasso_color='red'):
        """Clear the image axes, redraw the display image and attach a fresh lasso selector.

        ``ax.clear()`` detaches any previously plotted lasso outline, so the stored
        artist references are dropped as well.
        """
        self.ax_image.clear()
        # cmap only affects 2-D (grayscale) images; RGB arrays ignore it.
        self.ax_image.imshow(self.display_image, cmap='gray')
        self.ax_image.set_title(title if title is not None else self._default_title(), fontsize=10)
        self.ax_image.axis('off')
        self.lasso_line = None
        self.lasso_markers = None
        self._new_lasso_selector(lasso_color)

    def _new_lasso_selector(self, color='red'):
        """Replace the active LassoSelector, disconnecting the old one so each lasso fires once."""
        old_selector = getattr(self, 'lasso', None)
        if old_selector is not None:
            old_selector.disconnect_events()
        self.lasso = LassoSelector(self.ax_image, self.on_lasso_select,
                                   props=dict(color=color, linewidth=2))

    def _remove_lasso_artists(self):
        """Remove the persistent lasso outline and start marker if they are drawn."""
        for name in ('lasso_line', 'lasso_markers'):
            artist = getattr(self, name, None)
            if artist is not None:
                artist.remove()
            setattr(self, name, None)

    def _info_label(self):
        """Text for the small info box: current display size as WIDTHxHEIGHT."""
        return f"Current: {self.display_image.shape[1]}x{self.display_image.shape[0]}"

    def add_compact_controls(self):
        """Add view sliders and action buttons (right side) and navigation controls (bottom)."""
        from matplotlib.widgets import TextBox  # only this method needs it

        add_axes = self.fig.add_axes
        slider_width = 0.25
        slider_height = 0.02
        slider_x = 0.7

        # View sliders (display only; never affect MedSAM input)
        self.slider_size = Slider(add_axes([slider_x, 0.85, slider_width, slider_height]),
                                  'Size', 0.1, 0.6, valinit=self.size_factor, valfmt='%.2f')
        self.slider_size.on_changed(self.update_view_size)

        self.slider_brightness = Slider(add_axes([slider_x, 0.8, slider_width, slider_height]),
                                        'Bright', 0.5, 2.0, valinit=1.0, valfmt='%.1f')
        self.slider_brightness.on_changed(self.update_view_brightness)

        self.slider_contrast = Slider(add_axes([slider_x, 0.75, slider_width, slider_height]),
                                      'Contrast', 0.5, 2.0, valinit=1.0, valfmt='%.1f')
        self.slider_contrast.on_changed(self.update_view_contrast)

        # Info display (kept as an attribute so refresh_display can update it)
        ax_info = add_axes([slider_x, 0.7, slider_width, 0.03])
        self.info_text = ax_info.text(0.1, 0.5, self._info_label(),
                                      transform=ax_info.transAxes, fontsize=8)
        ax_info.set_xticks([])
        ax_info.set_yticks([])

        # Action buttons, stacked vertically on the right
        button_width = 0.12
        button_height = 0.04
        button_x = slider_x + 0.02

        def make_button(y, label, callback):
            button = Button(add_axes([button_x, y, button_width, button_height]), label)
            button.on_clicked(callback)
            return button

        self.btn_test = make_button(0.55, 'Test Seg', self.test_segmentation)
        self.btn_accept = make_button(0.5, 'Accept', self.accept_segmentation)
        self.btn_clear = make_button(0.45, 'Clear', self.clear_selection)
        self.btn_new = make_button(0.4, 'New Region', self.new_region)
        self.btn_save = make_button(0.35, 'Save All', self.save_all_results)

        # Navigation at the bottom
        nav_y = 0.05
        self.btn_prev = Button(add_axes([0.05, nav_y, 0.08, 0.04]), 'Previous')
        self.btn_next = Button(add_axes([0.15, nav_y, 0.08, 0.04]), 'Next')
        self.btn_goto = Button(add_axes([0.25, nav_y, 0.08, 0.04]), 'Go to...')
        # input() cannot reliably prompt from inside a widget callback, so the
        # target index is typed into this box (Enter or the button jumps).
        self.goto_box = TextBox(add_axes([0.35, nav_y, 0.08, 0.04]), '', initial='')

        self.btn_prev.on_clicked(self.previous_image)
        self.btn_next.on_clicked(self.next_image)
        self.btn_goto.on_clicked(self.goto_image)
        # Submit also fires when the box loses focus; ignore an empty box then.
        self.goto_box.on_submit(lambda text: self.goto_image(None) if text.strip() else None)

    def update_view_size(self, val):
        """Slider callback: change the display scale only."""
        self.size_factor = val
        self.refresh_display()

    def update_view_brightness(self, val):
        """Slider callback: change display brightness only."""
        self.brightness = val
        self.refresh_display()

    def update_view_contrast(self, val):
        """Slider callback: change display contrast only."""
        self.contrast = val
        self.refresh_display()

    def refresh_display(self):
        """Rebuild and redraw the display copy; the current selection is cleared."""
        self.update_display_image()
        self.lasso_points = []
        self.current_mask = None
        self._draw_base_image()
        if getattr(self, 'info_text', None) is not None:
            self.info_text.set_text(self._info_label())
        self.fig.canvas.draw_idle()

    # ------------------------------------------------------------ segmentation

    def on_lasso_select(self, verts):
        """LassoSelector callback: store vertices in original coordinates and keep the outline visible."""
        if len(verts) <= 2:
            return

        self.lasso_points = self._display_to_original(verts)
        # A new lasso invalidates any mask computed for the previous one.
        self.current_mask = None

        display_lasso = self._original_to_display(self.lasso_points)
        closed_loop = np.vstack([display_lasso, display_lasso[:1]])

        self._remove_lasso_artists()
        self.lasso_line = self.ax_image.plot(closed_loop[:, 0], closed_loop[:, 1],
                                             'r-', linewidth=2, alpha=0.8)[0]
        self.lasso_markers = self.ax_image.plot(display_lasso[0, 0], display_lasso[0, 1],
                                                'ro', markersize=6, alpha=0.8)[0]
        self.fig.canvas.draw_idle()

        print(f"✓ Lasso drawn with {len(verts)} points (staying visible)")

    def test_segmentation(self, event):
        """Run MedSAM with the lasso's bounding box, keep only pixels inside the lasso, show it."""
        if not self.lasso_points:
            print("❌ Please draw a lasso first!")
            return

        print("🔬 Running MedSAM...")
        bbox = np.array(self.get_bbox_from_lasso())
        print(f"📦 Bbox: {bbox}")

        x_min, y_min, x_max, y_max = bbox
        if x_max <= x_min or y_max <= y_min:
            print("❌ Lasso does not overlap the image; please draw it over the image.")
            return

        mask = self.get_medsam_segmentation(bbox)
        self.current_mask = self.refine_with_lasso(mask)
        self.show_segmentation_result()

    def get_medsam_segmentation(self, bbox):
        """Segment the full-resolution original image with a MedSAM box prompt.

        Follows the MedSAM reference inference recipe: the image is made
        3-channel, resized to MEDSAM_INPUT_SIZE x MEDSAM_INPUT_SIZE and min-max
        normalised to [0, 1]; the box is rescaled into that frame; the decoder's
        low-resolution logits go through a sigmoid, are upsampled back to the
        original (H, W) and thresholded at 0.5.

        Args:
            bbox: [x_min, y_min, x_max, y_max] in original-image pixel coordinates.

        Returns:
            Boolean mask with the same height/width as ``self.original_image``.
        """
        size = self.MEDSAM_INPUT_SIZE
        image = self.original_image
        if image.ndim == 2:                       # grayscale -> 3 identical channels
            image = np.stack([image] * 3, axis=-1)
        elif image.shape[2] == 1:
            image = np.repeat(image, 3, axis=2)
        else:                                     # drop alpha if present
            image = image[..., :3]
        height, width = image.shape[:2]
        device = self.model.device

        img = torch.from_numpy(np.ascontiguousarray(image, dtype=np.float32))
        img = img.permute(2, 0, 1).unsqueeze(0)   # (1, 3, H, W)
        img = torch.nn.functional.interpolate(img, size=(size, size), mode='bilinear',
                                              align_corners=False, antialias=True)
        img = (img - img.min()) / torch.clamp(img.max() - img.min(), min=1e-8)
        img = img.to(device)

        box = torch.as_tensor(np.asarray(bbox, dtype=np.float32)).reshape(1, 1, 4)
        box = box / torch.tensor([width, height, width, height], dtype=torch.float32) * size
        box = box.to(device)

        with torch.no_grad():
            image_embeddings = self.model.image_encoder(img)
            sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
                points=None, boxes=box, masks=None,
            )
            low_res_logits, _ = self.model.mask_decoder(
                image_embeddings=image_embeddings,
                image_pe=self.model.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=False,
            )
            probs = torch.sigmoid(low_res_logits)
            probs = torch.nn.functional.interpolate(probs, size=(height, width),
                                                    mode='bilinear', align_corners=False)

        mask = probs[0, 0].cpu().numpy() > 0.5
        print("✅ Segmentation complete!")
        return mask

    def refine_with_lasso(self, mask):
        """Zero every mask pixel that lies outside the lasso polygon (original coordinates).

        Only the lasso's bounding rectangle is point-tested; pixels outside it
        cannot be inside the polygon, which avoids testing every pixel of a
        large image.
        """
        if not self.lasso_points:
            return mask

        mask = np.asarray(mask, dtype=bool)
        h, w = mask.shape
        points = np.asarray(self.lasso_points, dtype=float)
        x0 = max(0, int(np.floor(points[:, 0].min())))
        x1 = min(w, int(np.ceil(points[:, 0].max())) + 1)
        y0 = max(0, int(np.floor(points[:, 1].min())))
        y1 = min(h, int(np.ceil(points[:, 1].max())) + 1)

        refined = np.zeros_like(mask)
        if x1 <= x0 or y1 <= y0:
            return refined

        ys, xs = np.mgrid[y0:y1, x0:x1]
        inside = Path(points).contains_points(
            np.column_stack((xs.ravel(), ys.ravel()))
        ).reshape(ys.shape)
        refined[y0:y1, x0:x1] = mask[y0:y1, x0:x1] & inside
        return refined

    def show_segmentation_result(self):
        """Overlay the current mask (scaled to the display copy) and the lasso outline."""
        if self.current_mask is None:
            return

        disp_h, disp_w = self.display_image.shape[:2]
        mask_pil = Image.fromarray(self.current_mask.astype(np.uint8) * 255)
        display_mask = np.array(mask_pil.resize((disp_w, disp_h), Image.Resampling.NEAREST)) > 127

        # Redraw base image (also re-attaches the lasso selector cleared by ax.clear()).
        self._draw_base_image(
            title=f"✅ Segmentation Result - Image #{self.current_image_index + 1}")

        # vmin/vmax pin the colour: a constant-valued overlay would otherwise map
        # to the colormap's lowest (near-white) end.
        overlay = np.ma.masked_where(~display_mask, display_mask.astype(float))
        self.ax_image.imshow(overlay, alpha=0.6, cmap='Reds', vmin=0.0, vmax=1.0)

        if self.lasso_points:
            outline = self._original_to_display(self.lasso_points)
            outline = np.vstack([outline, outline[:1]])
            self.ax_image.plot(outline[:, 0], outline[:, 1], 'r-', linewidth=2)

        self.fig.canvas.draw_idle()

    def accept_segmentation(self, event):
        """Store the current mask with its lasso, bbox and view settings, then start a new region."""
        if self.current_mask is None:
            print("❌ No segmentation to accept!")
            return

        result = {
            'image_index': self.current_image_index,
            'mask': self.current_mask.copy(),
            'lasso_points': list(self.lasso_points),
            'bbox': self.get_bbox_from_lasso(),
            'mask_pixels': int(np.sum(self.current_mask)),
            'settings': {
                'brightness': self.brightness,
                'contrast': self.contrast,
                'size_factor': self.size_factor
            }
        }
        self.segmentation_results.append(result)

        print("✅ Segmentation accepted!")
        print(f"📊 Pixels: {np.sum(self.current_mask)}")
        print(f"💾 Total saved: {len(self.segmentation_results)}")
        self.new_region(event)

    def get_bbox_from_lasso(self):
        """Return the lasso bounding box [x_min, y_min, x_max, y_max] in original pixels.

        Coordinates are truncated to int and clipped to the image bounds, i.e. the
        same box that is sent to MedSAM. Returns None when no lasso is drawn.
        """
        if not self.lasso_points:
            return None
        points = np.asarray(self.lasso_points, dtype=float)
        x_min, y_min = points.min(axis=0).astype(int)
        x_max, y_max = points.max(axis=0).astype(int)
        height, width = self.original_image.shape[:2]
        return [max(0, int(x_min)), max(0, int(y_min)),
                min(width, int(x_max)), min(height, int(y_max))]

    def clear_selection(self, event):
        """Discard the current lasso and mask and redraw the plain image."""
        self.lasso_points = []
        self.current_mask = None
        self._draw_base_image()
        self.fig.canvas.draw_idle()
        print("🧹 Selection cleared")

    def new_region(self, event):
        """Prepare for another lasso on the same image (previous overlays stay visible)."""
        self.lasso_points = []
        # Reset the mask too, so clicking Accept twice cannot store a duplicate.
        self.current_mask = None
        self._new_lasso_selector('blue')
        self.ax_image.set_title(f"🆕 Draw another region - Image #{self.current_image_index + 1}",
                                fontsize=10)
        self.fig.canvas.draw_idle()
        print("🎯 Ready for new region")

    # -------------------------------------------------------------- navigation

    def previous_image(self, event):
        """Load the previous image into the current figure."""
        if self.current_image_index > 0:
            self._show_in_current_figure(self.current_image_index - 1)
        else:
            print("Already at first image")

    def next_image(self, event):
        """Load the next image into the current figure."""
        if self.current_image_index < len(self.dataset['train']) - 1:
            self._show_in_current_figure(self.current_image_index + 1)
        else:
            print("Already at last image")

    def goto_image(self, event):
        """Jump to the 0-based index typed in the box next to 'Go to...'.

        Falls back to ``input()`` only when no text box exists (e.g. when called
        directly from a cell before the interface was built).
        """
        last_index = len(self.dataset['train']) - 1
        box = getattr(self, 'goto_box', None)
        raw = box.text if box is not None else input(f"Enter image index (0-{last_index}): ")
        try:
            idx = int(raw.strip())
        except ValueError:
            print(f"Invalid index {raw!r}: enter a number between 0 and {last_index}")
            return
        self._show_in_current_figure(idx)

    # ------------------------------------------------------------------ output

    def save_all_results(self, event):
        """Write accepted segmentations to ``lymph_node_segmentations/segmentations_<timestamp>.npz``.

        Archive layout: one boolean array per result under ``mask_0000``,
        ``mask_0001``, ... (masks of different images may differ in shape, so they
        cannot be stacked into one array), plus ``metadata_json``: a JSON string
        with the remaining fields of each result, in the same order. Loadable
        with plain ``np.load(path)`` -- no ``allow_pickle`` needed.
        """
        if not self.segmentation_results:
            print("❌ No results to save!")
            return

        import json
        import os
        from datetime import datetime

        output_dir = "lymph_node_segmentations"
        os.makedirs(output_dir, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

        def to_builtin(value):
            # json cannot encode numpy scalars (e.g. int64 bbox entries); unwrap them.
            if isinstance(value, np.generic):
                return value.item()
            raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")

        arrays = {f"mask_{i:04d}": r['mask'] for i, r in enumerate(self.segmentation_results)}
        metadata = [{k: v for k, v in r.items() if k != 'mask'} for r in self.segmentation_results]
        arrays['metadata_json'] = np.array(json.dumps(metadata, default=to_builtin))

        filename = os.path.join(output_dir, f"segmentations_{timestamp}.npz")
        np.savez_compressed(filename, **arrays)

        print(f"💾 Saved {len(self.segmentation_results)} segmentations to {filename}")

    def preview_images(self, start_idx=0, num_images=9):
        """Show ``num_images`` thumbnails (longest side <= 300 px) in a 3-column grid.

        Thumbnail titles are 1-based (``#1`` is index 0).
        """
        print(f"🖼️ Preview: Images {start_idx} to {start_idx + num_images - 1}")

        n_cols = 3
        n_rows = max(1, -(-num_images // n_cols))   # ceil division
        fig, axs = plt.subplots(n_rows, n_cols, figsize=(10, 10 * n_rows / 3), squeeze=False)
        n_available = len(self.dataset['train'])

        for i, ax in enumerate(axs.flat):
            idx = start_idx + i
            if i < num_images and 0 <= idx < n_available:
                image = self.dataset['train'][idx]['image']
                if max(image.size) > 300:
                    preview_size = tuple(int(dim * 300 / max(image.size)) for dim in image.size)
                    image = image.resize(preview_size, Image.Resampling.LANCZOS)
                ax.imshow(image, cmap='gray')
                ax.set_title(f"#{idx + 1}")
            ax.axis('off')

        plt.tight_layout()
        plt.show()

In [None]:
# Cell 4: Initialize Adjustable View Interface
interface = AdjustableViewLassoMedSAMInterface(medsam_model, image_dataset)

# Preview images (thumbnail titles are 1-based: "#1" is load_image(0))
interface.preview_images(0, 9)

print("\nAdjustable View Lasso MedSAM Commands:")
print("=" * 60)
print("• interface.load_image(n) - Load image with viewing adjustments")
print("• interface.preview_images(start, count) - Preview images")

print("\nKey Features:")
print("✓ Viewing adjustments: size, brightness, contrast (temporary)")
print("✓ MedSAM processing: always uses original unmodified image")
print("✓ Coordinate mapping: lasso coordinates mapped to original image")
print("✓ Visual feedback: see adjustments while maintaining data integrity")

print("\nHow to Use:")
print("1. interface.load_image(5)")
print("2. Adjust viewing parameters as needed")
print("3. Draw a lasso around your region")
# Button labels below must match the ones created in add_compact_controls.
print("4. Click 'Test Seg' (processes original image)")
print("5. Click 'Accept' or 'Clear' to continue")
print("6. Click 'Save All' to write accepted masks to lymph_node_segmentations/")

In [None]:
# Cell 5: Quick Functions
def load_image(n):
    """Shortcut for ``interface.load_image(n)``; ``n`` is a 0-based dataset index."""
    return interface.load_image(n)


def preview(start=0, count=9):
    """Shortcut for ``interface.preview_images(start, count)`` (defaults: first nine images)."""
    interface.preview_images(start, count)


print("Quick Commands:")
print("• load_image(n)")
print("• preview()")

In [None]:
#Cell 6: open the interactive tool on the first image of the chosen split
FIRST_IMAGE_INDEX = 0  # 0-based dataset index
load_image(FIRST_IMAGE_INDEX)