In [2]:
# Download the SAM ViT-H checkpoint (~2.4 GB) into the working directory.
# -nc (--no-clobber) skips the download when the file already exists; without it,
# every re-run fetches the file again and saves it as sam_vit_h_4b8939.pth.1, .2, ...
!wget -nc https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

--2025-08-09 17:38:05--  https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
dl.fbaipublicfiles.com (dl.fbaipublicfiles.com) をDNSに問いあわせています... 3.173.197.128, 3.173.197.49, 3.173.197.101, ...
dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|3.173.197.128|:443 に接続しています... 接続しました。
HTTP による接続要求を送信しました、応答を待っています... 200 OK
長さ: 2564550879 (2.4G) [binary/octet-stream]
`sam_vit_h_4b8939.pth' に保存中


2025-08-09 17:42:25 (9.40 MB/s) - `sam_vit_h_4b8939.pth' へ保存完了 [2564550879/2564550879]



In [None]:
import os
from pathlib import Path
from types import MethodType

import clip
import cv2
import numpy as np
import PIL.Image
import torch
from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry

# --- Input/output directories ------------------------------------------------
# Project data root. Set the TEAM3_DATA_DIR environment variable to run on another
# machine; the default reproduces the original (machine-specific) location.
data_root = Path(
    os.environ.get("TEAM3_DATA_DIR", "/Users/Kota/blended/Team3AmazonProject/data")
)
input_dir = data_root / "temp" / "original"                # source images
output_dir = data_root / "temp" / "cropped_sam_with_clip"  # masked results
output_dir.mkdir(parents=True, exist_ok=True)

# --- SAM model ---------------------------------------------------------------
# Relative path: resolved against the kernel's working directory, which is where
# the wget cell saves the checkpoint.
sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"

# Prefer CUDA, then Apple MPS, then CPU.
device = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

mask_generator = SamAutomaticMaskGenerator(sam)

# --- CLIP (zero-shot gate before running SAM) --------------------------------
# CLIP uses CUDA when available and otherwise stays on CPU (it is not put on MPS).
clip_device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model, clip_preprocess = clip.load("ViT-B/32", device=clip_device)
clip_labels = ["sofa", "chair", "table"]

# For MPS compatibility: make predictor.transform.apply_coords return float32.
# NOTE(review): presumably needed because the MPS backend rejects the float64
# coordinates the unpatched transform produces — confirm for the SAM/PyTorch
# versions in use.
if device == "mps":
    print("⚠️ MPS detected: Patching apply_coords to return float32")
    orig_apply_coords = mask_generator.predictor.transform.apply_coords

    def apply_coords_float32(self, coords, size, _orig=orig_apply_coords):
        """Call the unpatched ``apply_coords`` and cast its result to float32.

        ``_orig`` is captured when the function is defined (default argument)
        instead of being looked up as the global ``orig_apply_coords`` at call
        time. With a global lookup, re-running this patch against an already
        patched transform rebinds that global to the patched method itself, and
        every call would then recurse forever.
        """
        out = _orig(coords, size)
        # Stays a numpy array; only the dtype changes.
        return out.astype(np.float32)

    mask_generator.predictor.transform.apply_coords = MethodType(
        apply_coords_float32, mask_generator.predictor.transform
    )

# Image extensions to process (compared case-insensitively).
image_exts = {".jpg", ".jpeg", ".png"}

# Minimum CLIP probability (over clip_labels) required before running SAM.
clip_min_prob = 0.5

# The text side of the CLIP comparison depends only on clip_labels, so encode it
# once here rather than once per image. Normalising both embeddings and scaling by
# exp(logit_scale) is what clip_model(image, text) computes internally, so the
# per-image logits below match the previous clip_model(...) call.
with torch.no_grad():
    text_tokens = clip.tokenize(clip_labels).to(clip_device)
    text_features = clip_model.encode_text(text_tokens)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    logit_scale = clip_model.logit_scale.exp()

# sorted() makes the processing order (and therefore the log) deterministic.
for img_path in sorted(input_dir.glob("*")):
    if img_path.suffix.lower() not in image_exts:
        continue

    original = cv2.imread(str(img_path))  # BGR; None if the file cannot be decoded
    if original is None:
        print(f"⚠️ Failed to read: {img_path}")
        continue

    # CLIP (via PIL) and SAM both take RGB, so convert once and reuse it.
    # SAM accepts the RGB uint8 array directly; converting to float32 is unnecessary.
    image_rgb = cv2.cvtColor(original, cv2.COLOR_BGR2RGB)

    # Zero-shot CLIP check: sofa / chair / table?
    image_input = clip_preprocess(PIL.Image.fromarray(image_rgb)).unsqueeze(0).to(clip_device)
    with torch.no_grad():
        image_features = clip_model.encode_image(image_input)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        logits_per_image = logit_scale * image_features @ text_features.t()
        probs = logits_per_image.softmax(dim=-1).cpu().numpy()[0]

    # Run SAM only if some label reaches clip_min_prob.
    # NOTE(review): the softmax is over clip_labels only, so the probabilities
    # always sum to 1 and the max is at least 1/len(clip_labels). An image showing
    # none of these objects is therefore not reliably rejected; consider adding a
    # negative label (e.g. "other object") and skipping when it wins.
    if probs.max() < clip_min_prob:
        print(f"⏩ Skip (not sofa/chair/table): {img_path.name}")
        continue

    # Generate candidate masks.
    masks = mask_generator.generate(image_rgb)
    if not masks:
        print(f"❌ No mask detected: {img_path.name}")
        continue

    # Keep the mask with the largest "area".
    # NOTE(review): the largest segment may be the background rather than the
    # product — spot-check outputs, or rank candidate masks with CLIP instead.
    largest_mask = max(masks, key=lambda m: m["area"])
    # Presumably a boolean HxW array, giving 0/1 after the cast — confirm.
    mask = largest_mask["segmentation"].astype(np.uint8)

    # Apply the mask on a black background; mask[:, :, None] broadcasts over channels.
    masked_img = original * mask[:, :, None]

    save_path = output_dir / f"{img_path.stem}_sam_cropped.png"
    # cv2.imwrite signals (most) failures through its return value, not an exception.
    if not cv2.imwrite(str(save_path), masked_img):
        print(f"⚠️ Failed to write: {save_path}")
        continue
    print(f"✅ Saved: {save_path}")

print("SAM segmentation completed!")

Python(60893) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
100%|███████████████████████████████████████| 338M/338M [00:06<00:00, 57.5MiB/s]



⚠️ MPS detected: Patching apply_coords to return float32


KeyboardInterrupt: 