# # SAM2 & SigLIP load

In [None]:
# Pipeline: SAM2 masks → red overlay image → SigLIP pooled embedding (pooler_output) → save (mask_token, mask) pairs

import os, cv2, torch, numpy as np
from PIL import Image
from pathlib import Path
from tqdm import tqdm
from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from transformers import SiglipVisionModel, SiglipImageProcessor

# --- Paths and run configuration ---------------------------------------------
# Dataset root for one MVTec-AD category; the processing cell reads
# <root_ds>/test/<defect>/*.png.
root_ds   = Path("/home/s2behappy4/data/gyuhyeong/dataset/MMAD/MVTec-AD/hazelnut")
# Root for the saved (mask_token, mask) pair files.
out_root  = Path("/home/s2behappy4/data/gyuhyeong/code/siglip_token/hazelnut")
device    = "cuda" if torch.cuda.is_available() else "cpu"
alpha     = 0.5            # blend weight of the red tint inside each mask
save_overlay = False       # when True, every tinted image is also written as a PNG
defects   = ["crack", "cut", "hole", "print"]

# --- SAM2 --------------------------------------------------------------------
# build_sam2 moves the model to its own `device` argument (default "cuda")
# before returning, so the CPU fallback above only takes effect if the device
# is passed in here; a trailing `.to(device)` would run too late.
sam_model = build_sam2(
    "configs/sam2.1/sam2.1_hiera_l.yaml",
    "./checkpoints/sam2.1_hiera_large.pt",
    device=device,
)

# Automatic (prompt-free) mask generation; masks are returned as binary arrays.
sam_gen = SAM2AutomaticMaskGenerator(
    model=sam_model,
    points_per_side=64,
    pred_iou_thresh=0.6,
    stability_score_thresh=0.5,
    mask_threshold=0.5,
    box_nms_thresh=0.7,
    crop_n_layers=2,
    crop_overlap_ratio=0.2,
    crop_n_points_downscale_factor=1,
    min_mask_region_area=300,
    output_mode="binary_mask",
    multimask_output=True,
    use_m2m=False,
)

# --- SigLIP ------------------------------------------------------------------
# Processor and vision tower must come from the same checkpoint, so the id is
# named once.
siglip_ckpt = "google/siglip-so400m-patch14-384"
proc = SiglipImageProcessor.from_pretrained(siglip_ckpt)
vis  = SiglipVisionModel.from_pretrained(siglip_ckpt).to(device).eval()

In [None]:
# For every defect class: segment each test image with SAM2, tint one mask at a
# time in red on a copy of the image, embed the tinted image with SigLIP, and
# save the per-image list of {"mask_token", "mask"} dicts under
# <out_root>/<defect>/01/<stem>_pairs.pt.

# Loop-invariant part of the blend: alpha * pure red, kept in float32.
red = np.array([255, 0, 0], dtype=np.float32)
tint = alpha * red

for defect in defects:
    img_paths = sorted((root_ds / "test" / defect).glob("*.png"))
    out_dir = out_root / defect / "01"
    out_dir.mkdir(parents=True, exist_ok=True)
    if save_overlay:
        ov_dir = Path("/home/s2behappy4/data/gyuhyeong/code/overlays/hazelnut") / defect
        ov_dir.mkdir(parents=True, exist_ok=True)

    total = len(img_paths)
    for idx, img_path in enumerate(img_paths, 1):
        print(f"{defect:<5s} {idx:3d}/{total}  {img_path.name} 처리 중...")

        # Context manager releases the file handle deterministically.
        with Image.open(img_path) as im:
            rgb = np.array(im.convert("RGB"))
        masks = sam_gen.generate(rgb)
        if not masks:
            # No masks for this image: report and skip, nothing is saved.
            print("실패")
            continue

        # One float copy per image instead of two full-frame copies per mask.
        rgb_f32 = rgb.astype(np.float32)

        pairs = []
        for m_idx, ann in enumerate(masks):
            seg = ann["segmentation"].astype(bool)

            # Blend only the masked pixels; pixels outside the mask stay
            # byte-identical to the source image. The float result is
            # truncated to uint8 exactly as a whole-frame cast would do.
            overlay = rgb.copy()
            overlay[seg] = ((1 - alpha) * rgb_f32[seg] + tint).astype(np.uint8)
            overlay_pil = Image.fromarray(overlay)

            if save_overlay:
                ov_name = f"{img_path.stem}_m{m_idx:02d}_overlay.png"
                overlay_pil.save(ov_dir / ov_name)

            batch = proc(images=overlay_pil, return_tensors="pt").to(device)
            with torch.no_grad():
                # pooler_output is the model's pooled image embedding; the
                # batch dimension (size 1) is dropped before moving to CPU.
                cls = vis(**batch).pooler_output.squeeze(0).cpu()

            pairs.append({"mask_token": cls, "mask": torch.from_numpy(seg)})

        torch.save(pairs, out_dir / f"{img_path.stem}_pairs.pt")
        print(f"   ↳ mask {len(pairs):2d}개 → {img_path.stem}_pairs.pt 저장 완료")

# # Good Image

In [None]:
import os, torch, numpy as np
from PIL import Image
from pathlib import Path
from tqdm import tqdm
from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from transformers import SiglipVisionModel, SiglipImageProcessor

# --- Paths and run configuration (good / defect-free images) ------------------
root_ds   = Path("/home/s2behappy4/data/gyuhyeong/dataset/MMAD/MVTec-AD/hazelnut")
# NOTE: this rebinds `out_root` to the good-image output directory. If the
# defect-processing cell is re-run afterwards in the same kernel, it will write
# under this directory unless its own setup cell is re-run first.
out_root  = Path("/home/s2behappy4/data/gyuhyeong/code/siglip_token/hazelnut/good/01")
out_root.mkdir(parents=True, exist_ok=True)
device    = "cuda" if torch.cuda.is_available() else "cpu"
alpha     = 0.5            # blend weight of the red tint inside each mask

# --- SAM2 --------------------------------------------------------------------
# build_sam2 moves the model to its own `device` argument (default "cuda")
# before returning, so the CPU fallback above only takes effect if the device
# is passed in here; a trailing `.to(device)` would run too late.
sam_model = build_sam2(
    "configs/sam2.1/sam2.1_hiera_l.yaml",
    "./checkpoints/sam2.1_hiera_large.pt",
    device=device,
)

# Automatic (prompt-free) mask generation; masks are returned as binary arrays.
sam_gen = SAM2AutomaticMaskGenerator(
    model=sam_model,
    points_per_side=64,
    pred_iou_thresh=0.6,
    stability_score_thresh=0.5,
    mask_threshold=0.5,
    box_nms_thresh=0.7,
    crop_n_layers=2,
    crop_overlap_ratio=0.2,
    crop_n_points_downscale_factor=1,
    min_mask_region_area=300,
    output_mode="binary_mask",
    multimask_output=True,
    use_m2m=False,
)

# --- SigLIP ------------------------------------------------------------------
# Processor and vision tower must come from the same checkpoint, so the id is
# named once.
siglip_ckpt = "google/siglip-so400m-patch14-384"
proc = SiglipImageProcessor.from_pretrained(siglip_ckpt)
vis  = SiglipVisionModel.from_pretrained(siglip_ckpt).to(device).eval()

In [None]:
# Segment every defect-free test image with SAM2, tint one mask at a time in
# red on a copy of the image, embed the tinted image with SigLIP, and save the
# per-image list of {"mask_token", "mask"} dicts to <out_root>/<stem>_pairs.pt.
good_dir = root_ds / "test" / "good"
img_paths = sorted(good_dir.glob("*.png"))
total = len(img_paths)
print(f"정상 이미지 {total}장 처리 시작 ...")

# Loop-invariant part of the blend: alpha * pure red, kept in float32.
red = np.array([255, 0, 0], dtype=np.float32)
tint = alpha * red

for idx, img_path in enumerate(img_paths, 1):
    print(f"good {idx:3d}/{total}  {img_path.name} 처리 중...")

    # Context manager releases the file handle deterministically.
    with Image.open(img_path) as im:
        rgb = np.array(im.convert("RGB"))
    masks = sam_gen.generate(rgb)
    if not masks:
        # No masks for this image: report and skip, nothing is saved.
        print("실패")
        continue

    # One float copy per image instead of two full-frame copies per mask.
    rgb_f32 = rgb.astype(np.float32)

    pairs = []
    for m_idx, ann in enumerate(masks):
        seg = ann["segmentation"].astype(bool)

        # Blend only the masked pixels; pixels outside the mask stay
        # byte-identical to the source image. The float result is truncated
        # to uint8 exactly as a whole-frame cast would do.
        overlay = rgb.copy()
        overlay[seg] = ((1 - alpha) * rgb_f32[seg] + tint).astype(np.uint8)
        overlay_pil = Image.fromarray(overlay)

        batch = proc(images=overlay_pil, return_tensors="pt").to(device)
        with torch.no_grad():
            # pooler_output is the model's pooled image embedding; the batch
            # dimension (size 1) is dropped before moving to CPU.
            cls = vis(**batch).pooler_output.squeeze(0).cpu()

        pairs.append({"mask_token": cls, "mask": torch.from_numpy(seg)})

    torch.save(pairs, out_root / f"{img_path.stem}_pairs.pt")
    print(f"   ↳ mask {len(pairs):2d}개 → {img_path.stem}_pairs.pt 저장 완료")