In [None]:
#Cell 0:
!pip install -r requirements.txt -U
print("Complete")

In [None]:
# Cell 1: Setup and Authentication
%matplotlib widget

import os
from getpass import getpass

import huggingface_hub
import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import torch
from datasets import load_dataset
from IPython.display import display, clear_output
from PIL import Image
from segment_anything import sam_model_registry

from utils.demo import BboxPromptDemo

# Login to Hugging Face.
# Prefer the HF_TOKEN environment variable; otherwise prompt with getpass so the token
# is not echoed on screen or persisted in the saved notebook output (input() does both).
hf_token = os.environ.get("HF_TOKEN") or getpass("Enter your Hugging Face token: ")
# Strip stray whitespace/newlines that often come along when pasting a token
huggingface_hub.login(token=hf_token.strip())

print("Authentication successful!")

In [None]:
# Cell 2: Load Dataset and Model (From Hugging Face)

print("Loading lymph node dataset...")
# .strip(): a pasted repo id with stray whitespace would otherwise fail to resolve
dataset_input = input("Please enter dataset path: ").strip()

try:
    # First, analyze the dataset structure
    temp_dataset = load_dataset(dataset_input, token=True)
    # Snapshot the split names before the 'train' alias is added below
    split_names = list(temp_dataset.keys())

    print("🔍 Dataset Analysis:")
    print(f"Available splits: {split_names}")
    for split_name in split_names:
        split = temp_dataset[split_name]
        print(f"  - {split_name}: {len(split)} items")
        if len(split) > 0:
            print(f"    Features: {split.features}")

    # Choose which split to use as 'train' (whitespace-tolerant)
    chosen_split = input(f"Which split to use as 'train'? {split_names}: ").strip()
    # Pressing Enter when there is exactly one split selects it
    if not chosen_split and len(split_names) == 1:
        chosen_split = split_names[0]

    if chosen_split not in temp_dataset:
        raise ValueError(f" Split '{chosen_split}' not found in dataset")

    print(f"Will use '{chosen_split}' split")

    # Alias the chosen split as 'train'. NOTE: image_dataset is the same object as
    # temp_dataset, so this also adds (or overwrites) temp_dataset['train'].
    image_dataset = temp_dataset
    image_dataset['train'] = image_dataset[chosen_split]

    print("Dataset info:")
    print(f"Available splits: {list(image_dataset.keys())}")
    print(f"Number of train images: {len(image_dataset['train'])}")
    print(f"Features: {image_dataset['train'].features}")

except Exception as e:
    print(f"Error loading dataset: {e}")
    print("\n Troubleshooting:")
    print("1. Check if the dataset repository exists and is accessible")
    print("2. Verify your HuggingFace token has the correct permissions")
    print("3. Make sure the dataset path is correct")
    print("4. Try a different split name if available")
    print("\n Cannot continue without valid dataset. Please fix the issue and try again.")
    raise  # Re-raise the error to stop execution

# Load the MedSAM model: use a local checkpoint when present, otherwise download it
# from Hugging Face.
print("\nLoading MedSAM model from Hugging Face...")
from huggingface_hub import hf_hub_download

# Local checkpoint location. Override it with the MEDSAM_CKPT_PATH environment variable
# rather than editing this machine-specific absolute path.
LOCAL_MEDSAM_CKPT_PATH = os.environ.get(
    "MEDSAM_CKPT_PATH", "/home/medsam-vit-b/medsam_vit_b.pth"
)

try:
    if os.path.isfile(LOCAL_MEDSAM_CKPT_PATH):
        # Use your manually downloaded checkpoint
        MedSAM_CKPT_PATH = LOCAL_MEDSAM_CKPT_PATH
        print(f"Using local checkpoint: {MedSAM_CKPT_PATH}")
    else:
        # No local copy: download into the Hugging Face cache
        print(f"No checkpoint at {LOCAL_MEDSAM_CKPT_PATH}; downloading from Hugging Face...")
        MedSAM_CKPT_PATH = hf_hub_download(
            repo_id="GleghornLab/medsam-vit-b",
            filename="medsam_vit_b.pth",
            token=True,
        )
        print(f"Model downloaded to: {MedSAM_CKPT_PATH}")

    # Load the model on the GPU when available
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    medsam_model = sam_model_registry['vit_b'](checkpoint=MedSAM_CKPT_PATH)
    medsam_model = medsam_model.to(device)
    medsam_model.eval()

    print(f"MedSAM model loaded successfully on {device}")

    # Show device info
    print(f"🖥️ Using device: {device}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

except Exception as e:
    print(f"Error loading model: {e}")
    print("\nModel troubleshooting:")
    print("1. Check if the model repository 'GleghornLab/medsam-vit-b' exists")
    print("2. Verify you have access permissions to this repository")
    print("3. Ensure your HuggingFace token has the correct permissions")
    print("4. Or set MEDSAM_CKPT_PATH to a local medsam_vit_b.pth file")
    print("\n Cannot continue without valid model. Please fix the issue and try again.")
    raise  # Re-raise the error to stop execution

print("\n Setup complete! Dataset and model loaded successfully.")

In [None]:
# Cell 3: Pure MedSAM Approach (Closest to Original)
import numpy as np

class PureMedSAMInterface:
    """Browse a dataset's ``'train'`` split and segment images with ``BboxPromptDemo``.

    Indices passed to :meth:`load_image` and :meth:`preview_images` are 0-based
    positions in ``dataset['train']``.
    """

    def __init__(self, model, dataset):
        """
        Args:
            model: MedSAM model, handed unchanged to ``BboxPromptDemo``.
            dataset: mapping with a ``'train'`` split whose rows expose an ``'image'``
                field (presumably PIL images from a Hugging Face ``Image`` feature).
        """
        self.model = model
        self.dataset = dataset

    def load_image(self, image_index):
        """Open the image at 0-based ``image_index`` in the interactive bbox demo.

        Returns:
            Whatever ``BboxPromptDemo.show`` returns, or ``None`` (after printing a
            message) when ``image_index`` is outside ``[0, len - 1]``.
        """
        n_images = len(self.dataset['train'])
        # Reject negatives as well: sequence indexing would wrap them to the end and
        # the "Image #" label below would be wrong.
        if not 0 <= image_index < n_images:
            print(f"Index {image_index} out of range. Max: {n_images-1}")
            return None

        sample = self.dataset['train'][image_index]
        image = sample['image']

        print(f"Loading Image #{image_index + 1}")
        print(f"Size: {image.size}")

        # Convert PIL to numpy (this is the only adaptation needed)
        image_array = np.array(image)

        # Hand the array to MedSAM's bounding-box prompt demo
        demo = BboxPromptDemo(self.model)
        return demo.show(image_array)

    @staticmethod
    def _grid_shape(num_images, max_cols=3):
        """Return ``(rows, cols)`` of the smallest grid with <= ``max_cols`` columns
        that holds ``num_images`` (> 0) panels."""
        cols = min(max_cols, num_images)
        rows = -(-num_images // cols)  # ceiling division
        return rows, cols

    def preview_images(self, start_idx=0, num_images=9):
        """Show thumbnails to help choose an index for :meth:`load_image`.

        Args:
            start_idx: 0-based index of the first thumbnail.
            num_images: number of thumbnails; laid out with up to 3 columns and as
                many rows as needed (9 gives a 3x3 grid).
        """
        if num_images <= 0:
            print("Nothing to preview: num_images must be positive.")
            return

        print(f"Preview: Images {start_idx} to {start_idx + num_images - 1}")

        n_rows, n_cols = self._grid_shape(num_images)
        # squeeze=False keeps a 2-D array of Axes even for a 1x1 or 1xN grid
        fig, axs = plt.subplots(
            n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False
        )

        n_available = len(self.dataset['train'])
        for offset, ax in enumerate(axs.flatten()):
            idx = start_idx + offset
            ax.axis('off')  # also hides grid cells beyond num_images
            if offset < num_images and 0 <= idx < n_available:
                ax.imshow(self.dataset['train'][idx]['image'])
                # Show the 0-based index too: it is what load_image() expects
                ax.set_title(f"Image #{idx + 1} (index {idx})")

        plt.tight_layout()
        plt.show()

In [None]:
# Cell 4: Initialize Native Interface
# PureMedSAMInterface is the class defined in Cell 3 (no SimpleMedSAMInterface exists,
# so the previous name raised NameError on a fresh kernel).
interface = PureMedSAMInterface(medsam_model, image_dataset)

# Preview the first 9 images (0-based dataset indices 0-8)
interface.preview_images(0, 9)

print("\nNative MedSAM Commands:")
print("=" * 50)
print("• interface.load_image(n) - Load image for multi-region segmentation")
print("• interface.preview_images(start, count) - Preview images")

print("\nHow to Use:")
print("1. interface.load_image(5)")
print("2. Click and drag multiple bounding boxes")
print("3. Each box automatically segments")
print("4. All regions handled in one interface")

In [None]:
# Cell 5: Quick Functions
def load_image(n):
    """Shortcut for ``interface.load_image(n)``; ``n`` is the 0-based dataset index."""
    return interface.load_image(n)

def preview(start=0, count=9):
    """Shortcut for ``interface.preview_images(start, count)``.

    Args:
        start: 0-based index of the first thumbnail.
        count: number of thumbnails; defaults to 9 so ``preview()`` / ``preview(start)``
            behave exactly as before.
    """
    interface.preview_images(start, count)

print("Quick Commands:")
print("• load_image(n)")
print("• preview()")