# Segment Anything in Medical Images ([colab](https://colab.research.google.com/drive/1N4wv9jljtEZ_w-f92iOLXCdkD-KJlsJH?usp=sharing))

In [None]:
!pip install -r requirements.txt -U
print("Complete")

On the local device:
- Create a fresh environment with `conda create -n medsam python=3.10 -y` and activate it with `conda activate medsam`
- Install PyTorch 2.6 built for CUDA 12.6:
  `pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu126`
- Continue to the next cell

## 2. Load pre-trained model

Please download the checkpoint [here](https://drive.google.com/drive/folders/1ETWmi4AiniJeWOt6HAsYgTjYv_fkgzoN?usp=drive_link) and set `MedSAM_CKPT_PATH` in the next cell to its location. This pre-trained model can be loaded directly with SAM's checkpoint loader.

In [None]:
%matplotlib widget
from segment_anything import sam_model_registry
from utils.demo import BboxPromptDemo
MedSAM_CKPT_PATH = "/home/medsam-vit-b/medsam_vit_b.pth"
device = "cuda:0"
medsam_model = sam_model_registry['vit_b'](checkpoint=MedSAM_CKPT_PATH)
medsam_model = medsam_model.to(device)
medsam_model.eval()

Sam(
  (image_encoder): ImageEncoderViT(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
    (blocks): ModuleList(
      (0-11): 12 x Block(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (proj): Linear(in_features=768, out_features=768, bias=True)
        )
        (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (lin1): Linear(in_features=768, out_features=3072, bias=True)
          (lin2): Linear(in_features=3072, out_features=768, bias=True)
          (act): GELU(approximate='none')
        )
      )
    )
    (neck): Sequential(
      (0): Conv2d(768, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): LayerNorm2d()
      (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (3): LayerNorm2d()
    )


In [None]:
%matplotlib widget
import os
import numpy as np
import torch
from PIL import Image
from datasets import load_dataset, DatasetDict, concatenate_datasets
from skimage import transform
import matplotlib.pyplot as plt
from matplotlib.widgets import RectangleSelector
from sklearn.model_selection import train_test_split
from medsam_inference import medsam_inference


# ─── LOAD & CONCAT ALL HF SPLITS ─────────────────────────────────────────────
DATASET_NAME = "GleghornLab/full_LN_6-1"
raw = load_dataset(DATASET_NAME, token=True)
if isinstance(raw, DatasetDict):
    parts = []
    for split_name, ds_split in raw.items():
        ds_split = ds_split.add_column("section", [split_name]*len(ds_split))
        parts.append(ds_split)
    ds = concatenate_datasets(parts)
else:
    ds = raw
print("Total examples:", len(ds))

# ─── GLOBAL RANDOM SPLIT 75/12.5/12.5% ────────────────────────────────────
all_idx = list(range(len(ds)))
train_idx, rest   = train_test_split(all_idx, train_size=0.75, random_state=42)
val_idx, test_idx = train_test_split(rest,   train_size=0.5,  random_state=42)
print(f"→ {len(train_idx)} train, {len(val_idx)} val, {len(test_idx)} test examples")


In [None]:
from utils.demo import BboxPromptDemo
import os
from PIL import Image
import tempfile


# Make output folders if not already present
BASE = "/home/MedSAM/data/follicle"
ROOTS = {
    "train": BASE + "_train",
    "val":   BASE + "_val",
    "test":  BASE + "_test",
}
for split, root in ROOTS.items():
    for sub in ("images", "masks"):
        os.makedirs(os.path.join(root, sub), exist_ok=True)

# Create the demo object
bbox_prompt_demo = BboxPromptDemo(medsam_model)


def _show_via_tempfile(demo, pil_img):
    """Save `pil_img` to a temporary PNG, pass its path to `demo.show`, and return the result.

    The demo is called with a file path (not a PIL image). The temporary file is
    always deleted afterwards, so annotating many images does not leave one stray
    PNG per image behind in the system temp directory.
    NOTE(review): assumes `demo.show` has finished reading the file by the time it
    returns — confirm against utils/demo.py.
    """
    fd, tmp_path = tempfile.mkstemp(suffix=".png")
    # Release our handle before PIL/the demo reopen the file by name; keeping it
    # open is not portable (Windows refuses a second open of the same file).
    os.close(fd)
    try:
        pil_img.save(tmp_path)
        return demo.show(tmp_path)
    finally:
        os.remove(tmp_path)


# Loop through your splits and annotate.
# NOTE(review): with `%matplotlib widget`, matplotlib figures generally do not block.
# This loop assumes `bbox_prompt_demo.show` waits for the user and returns the mask
# (or None). If it returns immediately, every image will be reported as skipped —
# confirm against utils/demo.py.
for split, idxs in [("train", train_idx), ("val", val_idx), ("test", test_idx)]:
    print(f"\n=== {split.upper()} ({len(idxs)} images) ===")
    for idx in idxs:
        img = ds[idx]["image"]  # PIL image
        print(f"Annotating {split} image #{idx}")
        mask = _show_via_tempfile(bbox_prompt_demo, img)

        # Save the image and mask if a mask was returned.
        # Assumes mask is a 2-D array with values in [0, 1] — TODO confirm.
        if mask is not None:
            name = f"{idx:04d}.png"
            img.save(os.path.join(ROOTS[split], "images", name))
            Image.fromarray((mask * 255).astype("uint8")).save(os.path.join(ROOTS[split], "masks", name))
            print(f"✓ saved {name} (image & mask) to {ROOTS[split]}")
        else:
            print("⚠ No mask returned, skipped.")

In [None]:
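# NOTE: Alternative annotation workflow (manual RectangleSelector boxes + medsam_inference,
# one mask file per box). It is disabled because every line is commented out; it is kept for reference only.
# %matplotlib widget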
# import os
# import numpy as np
# import torch
# import matplotlib.pyplot as plt
# from PIL import Image
# from skimage import transform
# from matplotlib.widgets import RectangleSelector


# # ─── MAKE OUTPUT FOLDERS ─────────────────────────────────────────────────────
# BASE = "/home/MedSAM/data/follicle"
# ROOTS = {
#     "train": BASE + "_train",
#     "val":   BASE + "_val",
#     "test":  BASE + "_test",
# }
# for split, root in ROOTS.items():
#     for sub in ("images", "masks"):
#         os.makedirs(os.path.join(root, sub), exist_ok=True)

# # ─── ANNOTATION FUNCTION ────────────────────────────────────────────────────
# def annotate_and_save_multi(idx, root_img, root_mask):
#     # close any leftover figures
#     plt.close("all")

#     pil_img = ds[idx]["image"]
#     img = np.array(pil_img)
#     H, W = img.shape[:2]
#     bboxes = []

#     fig, ax = plt.subplots(figsize=(6,6))
#     ax.imshow(img)
#     ax.set_title(f"Image #{idx}: draw boxes, ENTER when done")

#     def onselect(e0, e1):
#         x1,y1 = int(e0.xdata), int(e0.ydata)
#         x2,y2 = int(e1.xdata), int(e1.ydata)
#         bboxes.append([x1,y1,x2,y2])
#         ax.add_patch(plt.Rectangle((x1,y1), x2-x1, y2-y1,
#                                    edgecolor="yellow", fill=False, lw=2))
#         fig.canvas.draw()

#     selector = RectangleSelector(
#         ax, onselect,
#         interactive=True,
#         useblit=True,
#         button=[1],
#         minspanx=5, minspany=5
#     )
#     fig.canvas.mpl_connect(
#         "key_press_event",
#         lambda ev: plt.close(fig) if ev.key=="enter" else None
#     )
#     plt.show()

#     if not bboxes:
#         print("⚠ no boxes → skipped")
#         return

#     # preprocess & encode once
#     img1024 = transform.resize(
#         img, (1024,1024),
#         order=3, preserve_range=True, anti_aliasing=True
#     ).astype(np.uint8)
#     norm = (img1024 - img1024.min()) / np.clip(
#         img1024.max() - img1024.min(), 1e-8, None
#     )
#     tensor = (
#         torch.tensor(norm)
#              .float()
#              .permute(2,0,1)
#              .unsqueeze(0)
#              .to(device)
#     )
#     with torch.no_grad():
#         embedding = medsam_model.image_encoder(tensor)

#     # run & save one mask per box
#     for i, box in enumerate(bboxes):
#         box1024 = np.array([box]) / np.array([W, H, W, H]) * 1024
#         mask = medsam_inference(medsam_model, embedding, box1024, H, W)

#         name = f"{idx:04d}_{i:02d}.png"
#         pil_img.save(os.path.join(root_img, name))
#         Image.fromarray((mask * 255).astype("uint8")) \
#              .save(os.path.join(root_mask, name))
#         print(f"✓ saved {name} → {root_mask}")

# # ─── PROMPT & ANNOTATE SPLITS ────────────────────────────────────────────────
# bbox_prompt_demo = BboxPromptDemo(medsam_model)

# for split, idxs in [("train", train_idx), ("val", val_idx), ("test", test_idx)]:
#     print(f"\n=== {split.upper()} ({len(idxs)} images) ===")
#     for idx in idxs:
#         img = ds[idx]["image"]
#         print(f"Annotating {split} image #{idx}")
#         bbox_prompt_demo.show(img)
# # 
# #  for split, idxs in [("train", train_idx),
# #                     ("val",   val_idx),
# #                     ("test",  test_idx)]:
# #     print(f"\n=== {split.upper()} ({len(idxs)} images) ===")
# #     for idx in idxs:
# #         input(f"\nPress ENTER to annotate {split} image #{idx} ")
# #         annotate_and_save_multi(
# #             idx,
# #             os.path.join(ROOTS[split], "images"),
# #             os.path.join(ROOTS[split], "masks")
# #         )