# Segment Anything in Medical Images ([colab](https://colab.research.google.com/drive/1N4wv9jljtEZ_w-f92iOLXCdkD-KJlsJH?usp=sharing))

In [None]:
!pip install -r requirements.txt -U
print("Complete")

On the local device:
- Create a fresh environment `conda create -n medsam python=3.10 -y` and activate it `conda activate medsam`
- Install PyTorch: `pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu126`
- Continue to next cell

## 2. Load pre-trained model

Please download the checkpoint [here](https://drive.google.com/drive/folders/1ETWmi4AiniJeWOt6HAsYgTjYv_fkgzoN?usp=drive_link). This pre-trained model can be loaded directly with SAM's checkpoint loader.

In [None]:
%matplotlib widget
from segment_anything import sam_model_registry
from utils.demo import BboxPromptDemo
MedSAM_CKPT_PATH = "/home/medsam-vit-b/medsam_vit_b.pth"
device = "cuda:0"
medsam_model = sam_model_registry['vit_b'](checkpoint=MedSAM_CKPT_PATH)
medsam_model = medsam_model.to(device)
medsam_model.eval()

Sam(
  (image_encoder): ImageEncoderViT(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
    (blocks): ModuleList(
      (0-11): 12 x Block(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (proj): Linear(in_features=768, out_features=768, bias=True)
        )
        (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (lin1): Linear(in_features=768, out_features=3072, bias=True)
          (lin2): Linear(in_features=3072, out_features=768, bias=True)
          (act): GELU(approximate='none')
        )
      )
    )
    (neck): Sequential(
      (0): Conv2d(768, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): LayerNorm2d()
      (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (3): LayerNorm2d()
    )


In [None]:
# in tutorial_quickstart_holk.ipynb, new cell
%matplotlib widget
import os, numpy as np, torch
from PIL import Image
from datasets import load_dataset
from skimage import transform
import matplotlib.pyplot as plt
from matplotlib.widgets import RectangleSelector
from medsam_inference import medsam_inference

# --- Configuration: Hugging Face dataset and output locations ---------------
DATASET_NAME = "GleghornLab/full_LN_6-1"
SPLIT        = "train"

# Annotated images and their masks are written side by side under ROOT.
ROOT     = "/home/MedSAM/data/follicle/train"
IMG_DIR  = os.path.join(ROOT, "images")
os.makedirs(IMG_DIR, exist_ok=True)
MASK_DIR = os.path.join(ROOT, "masks")
os.makedirs(MASK_DIR, exist_ok=True)

# Load every split of the dataset, then keep only the one named by SPLIT.
# token=True presumably picks up the locally stored Hugging Face login -- confirm.
ds = load_dataset(DATASET_NAME, token=True)[SPLIT]

# Build a map: section id -> list of dataset row indices in that section.
# Only the "section" column is read here; iterating over whole rows would
# materialise every field of every example (including the image) just to look
# at one value.
section_map = {}
for i, sec in enumerate(ds["section"]):  # or whatever your field is
    section_map.setdefault(sec, []).append(i)

# How many images of each section go to each split, taken in dataset order:
# the first N_TRAIN_PER_SECTION -> train, the next N_VAL_PER_SECTION -> val,
# the next N_TEST_PER_SECTION -> test. Anything beyond that is left unused, and
# sections with fewer images simply contribute less to the later splits.
N_TRAIN_PER_SECTION = 6
N_VAL_PER_SECTION   = 1
N_TEST_PER_SECTION  = 1

_val_start  = N_TRAIN_PER_SECTION
_test_start = _val_start + N_VAL_PER_SECTION
_test_stop  = _test_start + N_TEST_PER_SECTION

train_idx = []
val_idx   = []
test_idx  = []
for inds in section_map.values():
    train_idx += inds[:_val_start]
    val_idx   += inds[_val_start:_test_start]
    test_idx  += inds[_test_start:_test_stop]

print("Train:", len(train_idx), "Val:", len(val_idx), "Test:", len(test_idx))


def _segment_and_save_boxes(idx, pil_img, img, bboxes):
    """Encode `img` once, then predict and save one mask per bounding box.

    Parameters
    ----------
    idx : int
        Dataset index of the image; used as the file-name prefix.
    pil_img
        The image object as returned by the dataset; saved unchanged next to
        every mask so image and mask share a file name.
    img : np.ndarray
        The same image as an (H, W, 3) array.
    bboxes : list of [x1, y1, x2, y2]
        Boxes in pixel coordinates of `img`.
    """
    H, W = img.shape[:2]

    # preprocess & encode once
    resized = transform.resize(img, (1024, 1024),
                               order=3, preserve_range=True,
                               anti_aliasing=True)
    # Bicubic interpolation can overshoot the input range; clip before the
    # uint8 cast so overshoots saturate instead of wrapping around.
    img1024 = np.clip(resized, 0, 255).astype(np.uint8)
    # Min-max normalise to [0, 1]; the clip on the denominator guards against
    # division by zero for a constant image.
    norm = (img1024 - img1024.min()) / np.clip(img1024.max() - img1024.min(), 1e-8, None)
    tensor = (torch.tensor(norm).float()
                   .permute(2, 0, 1)   # (H, W, C) -> (C, H, W)
                   .unsqueeze(0)       # add batch dimension
                   .to(device))
    with torch.no_grad():
        embedding = medsam_model.image_encoder(tensor)

    # run and save one mask per box
    for i, box in enumerate(bboxes):
        # Rescale the box from original pixel coordinates to the 1024x1024 frame.
        box1024 = np.array([box]) / np.array([W, H, W, H]) * 1024
        mask = medsam_inference(medsam_model, embedding, box1024, H, W)
        name = f"{idx:04d}_{i:02d}.png"
        pil_img.save(os.path.join(IMG_DIR, name))
        Image.fromarray((mask * 255).astype("uint8")).save(os.path.join(MASK_DIR, name))
        print("✓", name)


def annotate_and_save_multi(idx):
    """Interactively draw boxes on dataset image `idx` and save one mask per box.

    A figure with the image is shown; drag with the left mouse button to add a
    box and press ENTER (with the figure focused) to finish. On ENTER the image
    is encoded once with `medsam_model`, a mask is predicted for every box, and
    each image/mask pair is written to IMG_DIR / MASK_DIR as
    ``{idx:04d}_{box:02d}.png``.

    The segmentation is triggered from the ENTER key handler rather than from
    code placed after ``plt.show()``: with the notebook widget backend
    ``plt.show()`` returns immediately, so code after it would run before any
    box has been drawn. If ``plt.show()`` does block and the window is closed
    without ENTER, the boxes drawn so far are still processed.

    Relies on the notebook globals `ds`, `medsam_model`, `device`, `IMG_DIR`
    and `MASK_DIR`. Returns None.
    """
    pil_img = ds[idx]["image"]
    img = np.array(pil_img)
    # The encoder input is built with permute(2, 0, 1), so make sure there are
    # exactly three channels: replicate grayscale, drop any alpha channel.
    if img.ndim == 2:
        img = np.repeat(img[:, :, None], 3, axis=-1)
    elif img.shape[2] > 3:
        img = img[:, :, :3]
    H, W = img.shape[:2]
    bboxes = []
    finished = False

    fig, ax = plt.subplots(figsize=(6,6))
    ax.imshow(img)
    ax.set_title(f"Image #{idx}: draw boxes, ENTER when done")

    def onselect(e0, e1):
        # Events without data coordinates cannot be turned into a box.
        if None in (e0.xdata, e0.ydata, e1.xdata, e1.ydata):
            return
        # Order the corners and clamp them to the image so that every stored
        # box is a valid [x_min, y_min, x_max, y_max] inside the image.
        xa, xb = sorted((int(e0.xdata), int(e1.xdata)))
        ya, yb = sorted((int(e0.ydata), int(e1.ydata)))
        x1, x2 = max(0, xa), min(W, xb)
        y1, y2 = max(0, ya), min(H, yb)
        if x2 <= x1 or y2 <= y1:
            return
        bboxes.append([x1, y1, x2, y2])
        ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                   edgecolor="yellow", fill=False, lw=2))
        fig.canvas.draw_idle()

    # `drawtype` is intentionally not passed: it was removed from
    # RectangleSelector in recent Matplotlib releases and a box is the default.
    sel = RectangleSelector(ax, onselect,
                            useblit=True, button=[1],
                            minspanx=5, minspany=5)

    def finish():
        # Run at most once, whichever path (ENTER or window close) gets here first.
        nonlocal finished
        if finished:
            return
        finished = True
        sel.set_active(False)
        if not bboxes:
            print("no boxes → skipped"); return
        _segment_and_save_boxes(idx, pil_img, img, bboxes)

    def on_key(evt):
        # This closure references `sel` (via finish) and stays registered on the
        # canvas, which keeps the selector alive after this function returns.
        # NOTE(review): some front-ends route prints made inside GUI callbacks
        # to a log console instead of the cell output -- check there if the
        # "✓" lines do not appear.
        if evt.key == "enter":
            finish()
            plt.close(fig)

    fig.canvas.mpl_connect("key_press_event", on_key)
    plt.show()

    # Only reached with the figure already gone when plt.show() blocked until
    # the window was closed; process whatever was drawn.
    if not plt.fignum_exists(fig.number):
        finish()


In [None]:
# Annotate image 0: draw as many boxes as needed, then press ENTER to finish.
annotate_and_save_multi(0)

# Once that has finished, repeat for image 1, 2, … by changing the index:
#annotate_and_save_multi(1)