# SAM / SAM‑HQ Batch Inference
This notebook runs Segment Anything (SAM) or SAM‑HQ on every image in a folder and saves the masks, blends, and probability maps to the `label/`, `blend/`, and `label_p/` subfolders of the output directory.


## 1. Imports

In [None]:
import os
import torch
from segment_anything import sam_model_registry, get_sam_label
from PIL import Image
from tqdm import tqdm
import numpy as np

## 2. Helper functions

In [None]:
@torch.no_grad()
def get_sam_info(image,
                 box_nms=0.7,
                 min_mask_region_area=100,
                 pred_iou_thresh=0.88,
                 stability_score_thresh=0.92,
                 model=None):
    """Run SAM/SAM-HQ on a single image and return ``(label, blend, label_P)``.

    Gradients are disabled for the whole call (``@torch.no_grad``).

    Args:
        image: Input image, passed unchanged to ``get_sam_label`` (a PIL
            image when called from ``process_data_img``).
        box_nms: Forwarded to ``get_sam_label`` as ``box_nms``.
        min_mask_region_area: Forwarded as ``min_mask_region_area``.
        pred_iou_thresh: Forwarded as ``pred_iou_thresh``.
        stability_score_thresh: Forwarded as ``stability_score_thresh``.
        model: Model to run. ``None`` (the default) uses the notebook-global
            ``sam`` built in the "Load SAM/SAM-HQ model" cell.

    Returns:
        The ``(label, blend, label_P)`` tuple returned by ``get_sam_label``.
    """
    import contextlib

    if model is None:
        model = sam  # notebook-global model

    # Only request float16 autocast when the model actually lives on CUDA.
    # The model-loading cell falls back to CPU when no GPU is present; there,
    # autocast(device_type='cuda') would just warn and disable itself, so we
    # run in default precision instead.
    device_type = next(model.parameters()).device.type
    if device_type == 'cuda':
        amp_ctx = torch.autocast(device_type='cuda', dtype=torch.float16)
    else:
        amp_ctx = contextlib.nullcontext()

    with amp_ctx:
        label, blend, label_P = get_sam_label(
            model, image,
            box_nms=box_nms,
            min_mask_region_area=min_mask_region_area,
            pred_iou_thresh=pred_iou_thresh,
            stability_score_thresh=stability_score_thresh
        )
    return label, blend, label_P


def process_data_img(image_dir, img_name):
    """Run inference on *img_name* and write its three outputs to disk.

    All outputs are named after the input file's stem:

    * ``<lbl_dir>/<stem>.png``  -- ``label``, written with its ``.save`` method
    * ``<bld_dir>/<stem>.png``  -- ``blend``, written with its ``.save`` method
    * ``<lblp_dir>/<stem>.npy`` -- ``label_P``, written with ``np.save``

    The hyper-parameters (``box_nms``, ``min_region``, ``pred_iou_thresh``,
    ``stability_score_thresh``) and output directories (``lbl_dir``,
    ``bld_dir``, ``lblp_dir``) are read from the notebook globals defined in
    the "Parameters" cell.

    Args:
        image_dir: Directory containing the input image.
        img_name: File name of the image inside *image_dir*.
    """
    stem = os.path.splitext(img_name)[0]
    png_name = stem + '.png'
    # Build the .npy name from the stem. str.replace('.png', '.npy') would
    # also rewrite any ".png" that happens to occur inside the stem itself.
    npy_name = stem + '.npy'

    # Image.open is lazy and keeps the file handle open; the context manager
    # closes it once inference is done, so long batches don't leak handles.
    with Image.open(os.path.join(image_dir, img_name)) as img:
        label, blend, label_P = get_sam_info(
            img,
            box_nms=box_nms,
            min_mask_region_area=min_region,
            pred_iou_thresh=pred_iou_thresh,
            stability_score_thresh=stability_score_thresh,
        )

    label.save(os.path.join(lbl_dir, png_name))
    blend.save(os.path.join(bld_dir, png_name))
    np.save(os.path.join(lblp_dir, npy_name), label_P)


def make_dirs():
    """Ensure the output root and its label/blend/label_p sub-folders exist.

    Already-existing directories are left untouched (``exist_ok=True``);
    missing parents are created as needed.
    """
    targets = [sam_img_dir, lbl_dir, bld_dir, lblp_dir]
    for path in targets:
        os.makedirs(path, exist_ok=True)

## 3. Parameters
Adjust these values (in particular the checkpoint and the input/output paths) before running this cell and the cells below it.

In [None]:
# Inference hyper-parameters, passed to get_sam_info() by process_data_img().
min_region = 500                # forwarded as min_mask_region_area
box_nms = 0.6
pred_iou_thresh = 0.85
stability_score_thresh = 0.85

# model_type picks both the registry entry and the weights file below.
model_type = "vit_h"
sam_checkpoint = {
    'vit_h': "SAM/ckpt/sam_vit_h_4b8939.pth",  # or point at a SAM-HQ checkpoint
}

# Input folder and output root -- edit both before running.
image_dir = "/path/to/images"
sam_img_dir = "/path/to/output"

# One sub-folder of the output root per output kind.
lbl_dir, bld_dir, lblp_dir = (
    os.path.join(sam_img_dir, sub) for sub in ("label", "blend", "label_p")
)

## 4. Load SAM/SAM‑HQ model

In [None]:
# Use the GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Build the architecture named by model_type with its checkpoint, then move it
# to the chosen device and switch to inference mode (nn.Module.to / .eval
# both return the module itself, so chaining keeps the same object).
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint[model_type])
sam = sam.to(device).eval()
print("Model loaded on", device)

## 5. Run batch inference

In [None]:
make_dirs()

# Only hand regular image files to SAM. os.listdir also returns
# sub-directories and stray files (e.g. .DS_Store, notes), which would make
# Image.open raise and abort the whole batch. Extend image_exts if your data
# uses other formats. Sorting makes the processing order reproducible;
# os.listdir returns entries in arbitrary order.
image_exts = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff', '.webp')
val_lines = sorted(
    f for f in os.listdir(image_dir)
    if f.lower().endswith(image_exts)
    and os.path.isfile(os.path.join(image_dir, f))
)

# NOTE(review): inputs that share a stem (e.g. a.jpg and a.png) map to the
# same output file names and will overwrite each other.
for fname in tqdm(val_lines, desc="Processing"):
    process_data_img(image_dir, fname)