In [1]:
import os, glob, json
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
from scipy import ndimage as ndi

from segment_anything import sam_model_registry,  SamAutomaticMaskGenerator, SamPredictor

# =========================
# CONFIG (edit these)
# =========================
# Every path hangs off one project root, so moving the notebook to another
# machine only needs STRUCTURE_AWARE_GEN_ROOT to be set (or this one default
# edited) instead of three separate absolute paths.
PROJECT_ROOT = os.environ.get(
    "STRUCTURE_AWARE_GEN_ROOT",
    "/scratch/gilbreth/abelde/Thesis/StructureAwareGen",
)
IN_DIR = f"{PROJECT_ROOT}/dataset/val2017"                                   # input images
OUT_ROOT = f"{PROJECT_ROOT}/scripts/segProto/out_amg_prod-2-trial"           # all outputs go below here
CHECKPOINT = f"{PROJECT_ROOT}/scripts/segProto/checkpoints/sam_vit_b_01ec64.pth"
MODEL_TYPE = "vit_b"       # vit_b / vit_l / vit_h (must match CHECKPOINT)
MAX_IMAGES = 5             # -1 for all

# AMG "segment everything" knobs (balanced)
POINTS_PER_SIDE = 64       # 128 is heavier; 256 is often overkill
CROP_N_LAYERS = 1          # 0 = no crops, 1-2 helps small objects a lot
CROP_OVERLAP_RATIO = 0.35
CROP_N_POINTS_DOWNSCALE = 2

PRED_IOU_THRESH = 0.80
STABILITY_SCORE_THRESH = 0.85
BOX_NMS_THRESH = 0.70

# OpenCV-free post filter (applied in this notebook, not inside SAM)
MIN_MASK_REGION_AREA = 300     # raise (500~2000) to remove tiny junk; 0 disables

# Post-selection (makes it "production")
MAX_KEEP = 250                 # final mask bank size per image
DEDUP_IOU_THRESH = 0.90        # drop masks that overlap too much with kept ones

# Outputs
TOP_K_PNG = 25                 # number of best masks exported as individual PNGs
OVERLAY_TOP_K = 15             # number of best masks drawn in the contour overlay
SAVE_MASK_EMB = True           # per-mask SAM embedding (very useful for conditioning)
# =========================


def load_image_rgb(path: str) -> np.ndarray:
    """Read an image file and return its pixels as a uint8 RGB array.

    The file is opened in a ``with`` block so the handle is released as soon
    as the pixels have been copied into the array, rather than whenever the
    garbage collector gets to it (this function is called once per image in
    a long loop).
    """
    with Image.open(path) as im:
        # convert() loads the pixel data, so the array is complete before
        # the context manager closes the file.
        return np.array(im.convert("RGB"), dtype=np.uint8)


def mask_stats(mask_bool: np.ndarray):
    """Summarise a 2-D binary mask as image-normalised geometry.

    Returns ``None`` when the mask has no foreground pixel. Otherwise returns
    a dict holding the covered area fraction, the centroid (``cx``, ``cy``),
    the bounding-box width / height / area (all as fractions of the image
    size), the inclusive pixel box ``[x0, y0, x1, y1]`` and ``fill_frac``,
    the share of that box actually covered by the mask.
    """
    grid = np.asarray(mask_bool, dtype=bool)
    rows, cols = np.nonzero(grid)
    if cols.size == 0:
        return None

    height, width = grid.shape
    left, right = int(cols.min()), int(cols.max())
    top, bottom = int(rows.min()), int(rows.max())

    # Box extents are inclusive on both ends, hence the +1.
    box_w_px = right - left + 1
    box_h_px = bottom - top + 1
    box_px = box_w_px * box_h_px

    summary = {}
    summary["area_frac"] = float(grid.mean())
    summary["cx"] = float(cols.mean() / width)
    summary["cy"] = float(rows.mean() / height)
    summary["bbox_w"] = float(box_w_px / width)
    summary["bbox_h"] = float(box_h_px / height)
    summary["bbox_area_frac"] = float(box_px / (height * width))
    summary["bbox_xyxy"] = [left, top, right, bottom]
    summary["fill_frac"] = float(grid.sum() / max(1, box_px))
    return summary


def bbox_xywh_from_mask(mask_bool: np.ndarray):
    """Return the tight box ``[x, y, width, height]`` around the non-zero pixels.

    Coordinates are in pixels; width and height count both end pixels.
    Returns ``None`` if the mask has no non-zero pixel.
    """
    rows, cols = np.nonzero(mask_bool)
    if cols.size == 0:
        return None
    left, top = int(cols.min()), int(rows.min())
    width = int(cols.max()) - left + 1
    height = int(rows.max()) - top + 1
    return [left, top, width, height]


# ---------- OpenCV-free min_mask_region_area replacement ----------
def _remove_small_islands(mask: np.ndarray, min_area: int) -> np.ndarray:
    if min_area <= 0:
        return mask
    lab, n = ndi.label(mask)
    if n == 0:
        return mask
    sizes = ndi.sum(mask, lab, index=np.arange(1, n + 1))
    keep = np.zeros(n + 1, dtype=bool)
    keep[1:] = sizes >= min_area
    return keep[lab]


def _fill_small_holes(mask: np.ndarray, min_area: int) -> np.ndarray:
    if min_area <= 0:
        return mask
    inv = ~mask
    lab, n = ndi.label(inv)
    if n == 0:
        return mask

    border = np.zeros_like(inv, dtype=bool)
    border[0, :] = border[-1, :] = True
    border[:, 0] = border[:, -1] = True
    border_labels = np.unique(lab[border])

    hole_labels = np.setdiff1d(np.arange(1, n + 1), border_labels, assume_unique=False)
    if hole_labels.size == 0:
        return mask

    hole_sizes = ndi.sum(inv, lab, index=hole_labels)
    small_holes = hole_labels[hole_sizes < min_area]
    if small_holes.size == 0:
        return mask

    filled = mask.copy()
    for hl in small_holes:
        filled[lab == hl] = True
    return filled


def filter_mask_like_sam(mask: np.ndarray, min_area: int) -> np.ndarray:
    """Clean a mask by filling small holes, then dropping small islands.

    The mask is first cast to bool. With ``min_area <= 0`` that cast is the
    only thing applied; otherwise ``_fill_small_holes`` runs before
    ``_remove_small_islands``, both with the same ``min_area`` threshold.
    """
    cleaned = mask.astype(bool)
    if min_area > 0:
        cleaned = _remove_small_islands(
            _fill_small_holes(cleaned, min_area), min_area
        )
    return cleaned


def iou(a: np.ndarray, b: np.ndarray) -> float:
    """Intersection-over-union of two masks, treating non-zero as foreground.

    Returns 0.0 whenever the masks share no pixel (including when both are
    empty), so the union is only computed for overlapping pairs.
    """
    overlap = np.count_nonzero(np.logical_and(a, b))
    if not overlap:
        return 0.0
    combined = np.count_nonzero(np.logical_or(a, b))
    return float(overlap / max(1, combined))


def overlay_contours(img_rgb: np.ndarray, masks: np.ndarray, scores: np.ndarray, out_path: str, top_k=10):
    """Save the image with the outlines of its best-scoring masks drawn on top.

    The ``top_k`` masks with the highest ``scores`` are contoured at the 0.5
    level and listed in a text legend (one line per mask, highest score
    first). If ``masks`` is empty the unmodified image is written instead.
    """
    if masks.size == 0:
        Image.fromarray(img_rgb).save(out_path)
        return

    n_show = int(min(top_k, masks.shape[0]))
    best_first = np.argsort(-scores)[:n_show]

    fig, ax = plt.subplots(figsize=(10, 8))
    ax.imshow(img_rgb)
    ax.axis("off")
    for rank, idx in enumerate(best_first, start=1):
        ax.contour(masks[idx].astype(float), levels=[0.5], linewidths=2)
        # Legend lines are stacked 20 data-units apart down the left edge.
        ax.text(
            10,
            20 * rank,
            f"mask {int(idx)} score={scores[idx]:.3f}",
            color="white",
            bbox=dict(facecolor="black", alpha=0.5, pad=2),
        )
    fig.tight_layout()
    fig.savefig(out_path, dpi=150)
    # Close explicitly: this runs once per image and figures would pile up.
    plt.close(fig)


def panoptic_viz(img_rgb: np.ndarray, label_map: np.ndarray, out_path: str, alpha=0.45, seed=0):
    """Save the image alpha-blended with one random colour per segment id.

    Every non-negative id in ``label_map`` gets a colour drawn from a
    generator seeded with ``seed``; negative ids keep the original pixels.
    The tinted image is mixed with the original using weight ``alpha`` and
    written to ``out_path``.
    """
    _rows, _cols, _chans = img_rgb.shape  # unpacking rejects anything but a rank-3 image
    rng = np.random.default_rng(seed)

    segment_ids = np.unique(label_map)
    segment_ids = segment_ids[segment_ids >= 0]

    tinted = img_rgb.copy()
    for seg_id in segment_ids:
        # Colours are drawn in ascending id order, one draw per id.
        tinted[label_map == seg_id] = rng.integers(0, 255, size=(3,), dtype=np.uint8)

    base = img_rgb.astype(np.float32)
    mixed = (1 - alpha) * base + alpha * tinted.astype(np.float32)
    Image.fromarray(np.clip(mixed, 0, 255).astype(np.uint8)).save(out_path)


# =========================
# RUN
# =========================
# Only files with these extensions are treated as inputs, so stray files or
# sub-directories inside IN_DIR cannot crash the loop in the image loader.
_IMAGE_EXTS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp")


def _valid_feature_hw(input_hw, enc_size, feat_hw):
    """Return (rows, cols) of the feature map that cover real image content.

    Assumes the stock segment_anything preprocessing: the image is resized so
    its longest side equals ``enc_size`` and then zero-padded at the bottom /
    right to a square (TODO confirm if a modified SAM build is used). Only the
    top-left ``input_hw / enc_size`` share of the feature map therefore
    corresponds to the image; the rest is padding.
    """
    (in_h, in_w), (hf, wf) = input_hw, feat_hw
    hv = min(hf, max(1, int(np.ceil(in_h * hf / enc_size))))
    wv = min(wf, max(1, int(np.ceil(in_w * wf / enc_size))))
    return hv, wv


def _pool_mask_embedding(seg, feat, valid_hw):
    """Average the image-encoder features under one mask.

    Args:
        seg: (H, W) boolean mask in original image coordinates.
        feat: (1, C, hf, wf) image embedding from the predictor.
        valid_hw: feature rows/cols that cover the un-padded image.

    Returns:
        (C,) float16 numpy vector.
    """
    hf, wf = feat.shape[-2:]
    hv, wv = valid_hw
    mask_t = torch.from_numpy(seg[None, None].astype(np.float32)).to(feat.device)
    # "area" resampling gives each feature cell its fractional mask coverage,
    # so small or thin masks keep a non-zero weight instead of vanishing the
    # way they can under nearest-neighbour sampling.
    weights = F.interpolate(mask_t, size=(hv, wv), mode="area")
    # Zero weight on the padded part of the feature map (right / bottom).
    weights = F.pad(weights, (0, wf - wv, 0, hf - hv))
    denom = weights.sum(dim=(2, 3)) + 1e-6
    emb_t = (feat * weights).sum(dim=(2, 3)) / denom  # (1, C)
    return emb_t.squeeze(0).detach().cpu().to(torch.float16).numpy()


os.makedirs(OUT_ROOT, exist_ok=True)
out_masks_dir = os.path.join(OUT_ROOT, "masks_npz")
out_meta_dir  = os.path.join(OUT_ROOT, "meta")
out_png_dir   = os.path.join(OUT_ROOT, "mask_png")
out_viz_dir   = os.path.join(OUT_ROOT, "viz")

for d in [out_masks_dir, out_meta_dir, out_png_dir, out_viz_dir]:
    os.makedirs(d, exist_ok=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
print("device:", device)
if device == "cuda":
    # Report free memory up front: on a shared GPU another process may leave
    # too little room for the image encoder, which otherwise only shows up
    # as an OutOfMemoryError on the first image.
    free_b, total_b = torch.cuda.mem_get_info()
    print(f"GPU memory free/total: {free_b / 2**30:.2f} / {total_b / 2**30:.2f} GiB")

if not os.path.isfile(CHECKPOINT):
    raise FileNotFoundError(f"Checkpoint not found: {CHECKPOINT}")

sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT).to(device)
sam.eval()

# Predictor = lets us compute image encoder features once per image (for mask embeddings)
predictor = SamPredictor(sam)

# IMPORTANT: keep min_mask_region_area=0 to avoid cv2 import
mask_generator = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=POINTS_PER_SIDE,
    pred_iou_thresh=PRED_IOU_THRESH,
    stability_score_thresh=STABILITY_SCORE_THRESH,
    box_nms_thresh=BOX_NMS_THRESH,
    crop_n_layers=CROP_N_LAYERS,
    crop_overlap_ratio=CROP_OVERLAP_RATIO,
    crop_n_points_downscale_factor=CROP_N_POINTS_DOWNSCALE,
    min_mask_region_area=0,
)

paths = sorted(
    path for path in glob.glob(os.path.join(IN_DIR, "*"))
    if os.path.isfile(path) and path.lower().endswith(_IMAGE_EXTS)
)
# Any negative MAX_IMAGES means "all"; a plain slice with e.g. -2 would
# silently drop the last files instead.
if MAX_IMAGES >= 0:
    paths = paths[:MAX_IMAGES]

for p in tqdm(paths, desc="AMG-prod"):
    name = os.path.splitext(os.path.basename(p))[0]
    img_rgb = load_image_rgb(p)
    H, W, _ = img_rgb.shape

    # 1) Generate proposals
    amg = mask_generator.generate(img_rgb)

    # 2) Compute image embedding once (optional but powerful)
    if SAVE_MASK_EMB:
        predictor.set_image(img_rgb)
        feat = predictor.get_image_embedding()  # (1,C,hf,wf), vit_b -> usually C=256, hf=wf=64
        emb_dim = int(feat.shape[1])
        valid_hw = _valid_feature_hw(
            predictor.input_size, sam.image_encoder.img_size, feat.shape[-2:]
        )
    else:
        feat, emb_dim, valid_hw = None, 0, None

    candidates = []
    for orig_i, m in enumerate(amg):
        seg = m["segmentation"].astype(bool)
        if MIN_MASK_REGION_AREA > 0:
            seg = filter_mask_like_sam(seg, MIN_MASK_REGION_AREA)

        area_px = int(seg.sum())
        if area_px == 0:
            continue  # the post filter removed the whole mask

        pred_iou = float(m.get("predicted_iou", 0.0))
        stab = float(m.get("stability_score", 0.0))
        score = pred_iou * stab

        st = mask_stats(seg)
        if st is None:
            continue

        # per-mask embedding from SAM image encoder
        emb = _pool_mask_embedding(seg, feat, valid_hw) if SAVE_MASK_EMB else None

        candidates.append({
            "orig_amg_index": int(orig_i),
            "seg": seg,
            "score": float(score),
            "predicted_iou": pred_iou,
            "stability_score": stab,
            "area_px": area_px,
            "bbox_xywh": bbox_xywh_from_mask(seg),
            "stats": st,
            "emb": emb,
        })

    # 3) Sort and deduplicate (greedy IoU): best score first, a candidate is
    #    dropped when it overlaps an already kept mask too strongly.
    candidates.sort(key=lambda x: x["score"], reverse=True)

    kept = []
    for cand in candidates:
        if len(kept) >= MAX_KEEP:
            break
        if any(iou(cand["seg"], prev["seg"]) >= DEDUP_IOU_THRESH for prev in kept):
            continue
        kept.append(cand)

    # 4) Build final arrays
    N = len(kept)
    if N == 0:
        masks = np.zeros((0, H, W), dtype=np.bool_)
        scores = np.zeros((0,), dtype=np.float32)
        embs = np.zeros((0, emb_dim), dtype=np.float16)
        label_map = -np.ones((H, W), dtype=np.int32)
        meta_masks = []
    else:
        masks = np.stack([k["seg"] for k in kept], axis=0).astype(np.bool_)
        scores = np.asarray([k["score"] for k in kept], dtype=np.float32)

        if SAVE_MASK_EMB:
            embs = np.stack([k["emb"] for k in kept], axis=0)  # (N,C) float16
        else:
            # Always store a real array: saving None would be written as a
            # pickled object array that np.load refuses to read by default.
            embs = np.zeros((N, 0), dtype=np.float16)

        # Non-overlapping label map: assign pixels to best mask first
        label_map = -np.ones((H, W), dtype=np.int32)
        occupied = np.zeros((H, W), dtype=bool)
        order = np.argsort(-scores)
        for new_id in order:
            pix = masks[new_id] & (~occupied)
            if pix.sum() == 0:
                continue
            label_map[pix] = int(new_id)
            occupied[pix] = True

        # Metadata
        meta_masks = []
        for new_id, k in enumerate(kept):
            st = dict(k["stats"])
            st.update({
                "mask_index": int(new_id),
                "orig_amg_index": int(k["orig_amg_index"]),
                "score": float(k["score"]),
                "predicted_iou": float(k["predicted_iou"]),
                "stability_score": float(k["stability_score"]),
                "area_px": int(k["area_px"]),
                "bbox_xywh": k["bbox_xywh"],
            })
            meta_masks.append(st)

    # 5) Save compressed masks (+ embeddings); masks are bit-packed per row,
    #    "shape" holds (N, H, W) so they can be unpacked again.
    packed = np.packbits(masks.reshape(masks.shape[0], -1), axis=1) if masks.shape[0] > 0 else np.zeros((0, 0), dtype=np.uint8)
    np.savez_compressed(
        os.path.join(out_masks_dir, f"{name}.npz"),
        packed=packed,
        shape=np.array(masks.shape, dtype=np.int32),
        scores=scores,
        label_map=label_map,
        emb=embs,
    )

    with open(os.path.join(out_meta_dir, f"{name}.json"), "w") as f:
        json.dump({
            "image": p,
            "device": device,
            "num_masks": int(masks.shape[0]),
            "amg_params": {
                "points_per_side": POINTS_PER_SIDE,
                "crop_n_layers": CROP_N_LAYERS,
                "crop_overlap_ratio": CROP_OVERLAP_RATIO,
                "crop_n_points_downscale_factor": CROP_N_POINTS_DOWNSCALE,
                "pred_iou_thresh": PRED_IOU_THRESH,
                "stability_score_thresh": STABILITY_SCORE_THRESH,
                "box_nms_thresh": BOX_NMS_THRESH,
                "min_mask_region_area_post": MIN_MASK_REGION_AREA,
                "dedup_iou_thresh": DEDUP_IOU_THRESH,
                "max_keep": MAX_KEEP,
            },
            "masks": meta_masks,
        }, f, indent=2)

    # 6) PNG exports for debugging
    if masks.shape[0] > 0:
        order = np.argsort(-scores)[:min(TOP_K_PNG, len(scores))]
        for rank, idx in enumerate(order, start=1):
            msk = (masks[idx].astype(np.uint8) * 255)
            Image.fromarray(msk).save(os.path.join(out_png_dir, f"{name}_mask{idx:04d}_rank{rank}_score{scores[idx]:.3f}.png"))

    overlay_contours(img_rgb, masks, scores, os.path.join(out_viz_dir, f"{name}_contours.png"), top_k=OVERLAY_TOP_K)
    panoptic_viz(img_rgb, label_map, os.path.join(out_viz_dir, f"{name}_panoptic.png"), alpha=0.45, seed=0)

print(f"[DONE] outputs at: {OUT_ROOT}")


  backends.update(_get_backends("networkx.backends"))


device: cuda


AMG-prod:   0%|          | 0/5 [00:00<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 768.00 MiB. GPU 0 has a total capacity of 23.60 GiB of which 720.19 MiB is free. Process 1728920 has 396.00 MiB memory in use. Process 775485 has 224.00 MiB memory in use. Process 3628798 has 21.38 GiB memory in use. Including non-PyTorch memory, this process has 798.00 MiB memory in use. Of the allocated memory 451.72 MiB is allocated by PyTorch, and 70.28 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
# NOTE(review): looks like a leftover scratch cell (it has no execution count) — consider deleting.
print("h")

In [2]:
# Shell out to nvidia-smi to show the GPU's current memory usage and utilisation.
!nvidia-smi

Sun Dec 21 15:12:10 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 565.57.01              Driver Version: 565.57.01      CUDA Version: 12.7     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A30                     On  |   00000000:21:00.0 Off |                    0 |
| N/A   36C    P0             31W /  165W |   21782MiB /  24576MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                