In [1]:
#!/usr/bin/env python3
# sam2_keep_masks_by_label_rules.py

import os
from pathlib import Path
import numpy as np
import pandas as pd
from PIL import Image
import torch
from transformers import pipeline

# ---------------- CONFIG ----------------
# Input: CSV of detections, one row per bounding box.
CSV_PATH  = Path("/project/biocomplexity/gza5dr/CAFO_Test/yolo_img_preds_all/detections.csv")
IMG_COL   = "image"                 # column holding the image file path
BOX_COLS  = ["x1","y1","x2","y2"]   # change if your CSV uses different names

# The label column is auto-detected from this list (first match wins).
# If none of these columns exists, every detection gets an empty label.
LABEL_CANDIDATES = ["label"]

SKIP_DISCARD = False  # True → drop rows whose label == "discard" (compared case-insensitively)

# Output: one <image stem>.npz per image is written here.
# NOTE: the directory is created as a side effect of running this cell.
OUT_DIR   = Path("/project/biocomplexity/gza5dr/CAFO_Test/sam_on_yolo_all")
OUT_DIR.mkdir(parents=True, exist_ok=True)

MODEL_ID  = "facebook/sam2-hiera-large"  # or "facebook/sam2-hiera-small"

# Thresholds for the rules.
# NOTE(review): in the code visible here, the four thresholds below are only
# written into the saved `params` array; the selection rule itself
# (_label_rule_keep_indices) uses ANY_INSIDE_FRAC alone.
INSIDE_THRESH     = 0.90  # barn/building: fraction of the *mask* inside the bbox
COVER_THRESH      = 0.40  # manure_pond: min fraction of the *box* covered by best mask
OVERLAP_MIN_FRAC  = 0.80  # silo/feedlot/silage*/storage: min fraction of the *box* covered

# Barn additional constraint — mask bbox area vs det bbox area
BARN_MIN_SIZE_RATIO = 0.25  # keep only if area(mask_bbox)/area(det_bbox) >= this
# Threshold actually applied by the current rule, for every label.
ANY_INSIDE_FRAC = 0.20   # keep mask if ≥20% of its pixels lie inside the bbox

# ---------------------------------------


# ---------- helpers ----------
def _open_pil(p: Path) -> Image.Image:
    """Load the image at ``p`` and return it as an RGB PIL image.

    The file is opened inside a ``with`` block so its handle is released
    deterministically instead of relying on Pillow's implicit close; the
    image returned by ``convert`` is a separate, already-loaded object.
    """
    with Image.open(p) as src:
        return src.convert("RGB")

def _to_numpy(x):
    if x is None: return None
    if isinstance(x, torch.Tensor): return x.detach().cpu().numpy()
    return np.asarray(x)

def _extract_masks(pred):
    """
    pred from HF pipeline("mask-generation"). Return masks: (N,H,W) bool.
    """
    if pred is None:
        return None
    if isinstance(pred, (list, tuple)):
        if len(pred) == 0:
            return None
        pred = pred[0]
    if not isinstance(pred, dict):
        return None

    masks = None
    for k in ["masks", "segments", "segmentation"]:
        if k in pred and pred[k] is not None:
            masks = pred[k]
            break
    if masks is None:
        return None

    masks = _to_numpy(masks)
    if masks is None:
        return None
    if masks.ndim == 2:
        masks = masks[None, ...]
    return (masks > 0)

def _valid_box(x1,y1,x2,y2):
    try:
        x1=int(x1); y1=int(y1); x2=int(x2); y2=int(y2)
    except Exception:
        return False
    return (x2 > x1) and (y2 > y1)

def _clamp_xyxy(x1,y1,x2,y2,W,H):
    x1 = max(0, min(int(x1), W-1))
    y1 = max(0, min(int(y1), H-1))
    x2 = max(0, min(int(x2), W))
    y2 = max(0, min(int(y2), H))
    if x2 <= x1: x2 = min(W, x1+1)
    if y2 <= y1: y2 = min(H, y1+1)
    return [x1,y1,x2,y2]

def _area_box(bx):
    x1,y1,x2,y2 = bx
    return max(0, x2 - x1) * max(0, y2 - y1)

def _mask_bbox(mask_bool: np.ndarray):
    """Tight bbox around a binary mask; returns [x1,y1,x2,y2] or None if empty."""
    ys, xs = np.where(mask_bool)
    if ys.size == 0:
        return None
    y1, y2 = ys.min(), ys.max() + 1
    x1, x2 = xs.min(), xs.max() + 1
    return [int(x1), int(y1), int(x2), int(y2)]

def _frac_inside(mask_bool: np.ndarray, box_xyxy) -> float:
    """
    Share of the mask's own pixels that fall within the box
    (the box is clamped to the mask's extent first).
    0.0 → empty mask or no pixel in the box; 1.0 → mask entirely in the box.
    """
    height, width = mask_bool.shape
    left, top, right, bottom = _clamp_xyxy(*box_xyxy, width, height)

    n_mask = int(mask_bool.sum())
    if n_mask == 0:
        # Empty mask: nothing can be inside.
        return 0.0

    window = mask_bool[top:bottom, left:right]
    n_inside = int(window.sum())
    return n_inside / float(n_mask)

def _overlap_frac_of_box(mask_bool: np.ndarray, box_xyxy) -> float:
    """
    Share of the box's area (after clamping to the mask's extent) that the
    mask covers.
    0.0 → no mask pixel in the box; 1.0 → every box pixel is masked.
    """
    rows, cols = mask_bool.shape
    left, top, right, bottom = _clamp_xyxy(*box_xyxy, cols, rows)

    covered = int(mask_bool[top:bottom, left:right].sum())
    # Floor the denominator at 1 so a zero-area box cannot divide by zero.
    denom = max(1, (right - left) * (bottom - top))
    return covered / float(denom)

def _find_label_col(df: pd.DataFrame) -> str | None:
    """Return the first name in LABEL_CANDIDATES that is a column of ``df``,
    or None when none of them is present."""
    available = df.columns
    return next((name for name in LABEL_CANDIDATES if name in available), None)

def _label_rule_keep_indices(label, masks, box_xyxy, **_):
    return [i for i, m in enumerate(masks) if _frac_inside((m > 0), box_xyxy) >= ANY_INSIDE_FRAC]

# ---------- main ----------
def main(max_rows=1):
    """Run SAM2 mask generation for the images listed in the detections CSV.

    For every image, all masks are generated once. For each detection box of
    that image, the masks chosen by ``_label_rule_keep_indices`` are kept and
    written to ``OUT_DIR/<image stem>.npz`` together with the originating CSV
    row index, the clamped box, the label and the image size.

    Args:
        max_rows: Process only the first ``max_rows`` CSV rows. The default
            of 1 reproduces the previously hard-coded ``df[:1]`` limit.
            NOTE(review): that limit looks like a debugging leftover; pass
            ``None`` to process the whole CSV.

    Raises:
        SystemExit: if a required column is missing from the CSV, or if no
            usable image/box remains after filtering.
    """
    df = pd.read_csv(CSV_PATH)

    # sanity — checked before the (large) model is loaded so a malformed CSV
    # fails fast
    for c in [IMG_COL] + BOX_COLS:
        if c not in df.columns:
            raise SystemExit(f"CSV missing column: {c}")

    if max_rows is not None:
        df = df.iloc[:max_rows]

    # label column
    label_col = _find_label_col(df)

    # optional: remove discard-labeled rows
    if SKIP_DISCARD and label_col is not None:
        df = df[df[label_col].astype(str).str.lower() != "discard"]

    # Work on an owned copy: assigning columns on a slice of another frame
    # triggers pandas' chained-assignment warning.
    df = df.copy()

    # numeric boxes + drop bad
    for c in BOX_COLS:
        df[c] = pd.to_numeric(df[c], errors="coerce")
    df = df.dropna(subset=[IMG_COL] + BOX_COLS)

    # group bboxes by image (keep label per row too)
    groups = {}
    for i, r in df.iterrows():
        ip = Path(str(r[IMG_COL]))
        if not ip.exists():
            continue
        raw_box = [r[c] for c in BOX_COLS]
        # Validate first: int() on a non-finite value would raise, whereas
        # _valid_box simply reports such a box as invalid.
        if not _valid_box(*raw_box):
            continue
        x1,y1,x2,y2 = [int(v) for v in raw_box]
        label = str(r[label_col]) if label_col is not None else ""
        groups.setdefault(ip, []).append((i, [x1,y1,x2,y2], label))

    if not groups:
        raise SystemExit("No valid images/boxes found.")

    device = 0 if torch.cuda.is_available() else -1
    sam = pipeline("mask-generation", model=MODEL_ID, device=device)

    for img_path, rows in groups.items():
        im = _open_pil(img_path)
        W, H = im.size

        # Generate ALL masks once for this image
        try:
            pred = sam(image=im)
        except TypeError:
            # Fall back to a positional call if the keyword form is rejected.
            pred = sam(im)
        masks_all = _extract_masks(pred)
        if masks_all is None or masks_all.shape[0] == 0:
            print(f"[skip] no masks from SAM2: {img_path.name}")
            continue

        kept_masks   = []
        kept_rows    = []
        kept_boxes   = []
        kept_labels  = []

        # Apply the selection rule per bbox; one output entry per (mask, bbox) pair
        for rid, box, label in rows:
            bx = _clamp_xyxy(*box, W, H)
            keep_idx = _label_rule_keep_indices(label, masks_all, bx)
            for i_m in keep_idx:
                kept_masks.append(masks_all[i_m].astype(np.uint8))
                kept_rows.append(int(rid))
                kept_boxes.append(bx)
                kept_labels.append(str(label))

        if not kept_masks:
            print(f"[ok] {img_path.name}: kept 0 (SAM had {masks_all.shape[0]})")
            continue

        masks_arr   = np.stack(kept_masks, axis=0)               # (K,H,W) uint8
        rows_arr    = np.array(kept_rows, dtype=np.int32)        # (K,)
        boxes_arr   = np.array(kept_boxes, dtype=np.int32)       # (K,4)
        # Unicode dtype rather than object: the archive then contains no
        # pickled data and loads with np.load's default allow_pickle=False.
        labels_arr  = np.array(kept_labels, dtype=str)           # (K,)

        out_path = OUT_DIR / (img_path.stem + ".npz")
        # Write to a temporary name first, then rename into place.
        tmp = out_path.with_suffix(".tmp.npz")
        np.savez_compressed(
            tmp,
            masks=masks_arr,
            det_row_indices=rows_arr,     # which CSV row/bbox this mask is for
            det_boxes_xyxy=boxes_arr,     # the bbox used for selection
            det_labels=labels_arr,        # label string of the detection
            image_size=np.array([H, W], dtype=np.int32),
            params=np.array(
                [INSIDE_THRESH, COVER_THRESH, OVERLAP_MIN_FRAC, BARN_MIN_SIZE_RATIO],
                dtype=np.float32
            ),
            # The threshold the selection rule actually applies; stored under
            # its own key so the layout of `params` is unchanged.
            any_inside_frac=np.float32(ANY_INSIDE_FRAC),
        )
        os.replace(tmp, out_path)
        print(f"[ok] {img_path.name}: kept {masks_arr.shape[0]} masks (SAM had {masks_all.shape[0]})")

    print("Done.")


# Entry point: run the pipeline when this file/cell is executed directly
# rather than imported as a module.
if __name__ == "__main__":
    main()


Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

Device set to use cuda:0
  return np.asarray(x)


ValueError: setting an array element with a sequence. The requested array has an inhomogeneous shape after 1 dimensions. The detected shape was (59,) + inhomogeneous part.

In [2]:
!pip install --upgrade transformers

Defaulting to user installation because normal site-packages is not writeable
Collecting transformers
  Downloading transformers-4.56.1-py3-none-any.whl.metadata (42 kB)
Collecting tokenizers<=0.23.0,>=0.22.0 (from transformers)
  Downloading tokenizers-0.22.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.8 kB)
Downloading transformers-4.56.1-py3-none-any.whl (11.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m11.6/11.6 MB[0m [31m175.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading tokenizers-0.22.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.3/3.3 MB[0m [31m71.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: tokenizers, transformers
[2K  Attempting uninstall: tokenizers
[2K    Found existing installation: tokenizers 0.21.4
[2K    Uninstalling tokenizers-0.21.4:
[2K      Successfully uninstalled tokenizers-0.21.4
[2K  Attempting 