In [2]:
!wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

--2025-08-09 17:38:05--  https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 3.173.197.128, 3.173.197.49, 3.173.197.101, ...
Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|3.173.197.128|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2564550879 (2.4G) [binary/octet-stream]
Saving to: `sam_vit_h_4b8939.pth'


2025-08-09 17:42:25 (9.40 MB/s) - `sam_vit_h_4b8939.pth' saved [2564550879/2564550879]
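The checkpoint is about 2.4 GB and took over four minutes to fetch, so re-running this cell every session is costly. A minimal sketch, assuming the only goal is to skip the download when a complete copy already exists: compare the local file size with the `Length` the server reported in the log above (2564550879 bytes). The helper name `checkpoint_is_complete` is hypothetical, and a size check does not detect a corrupted file; comparing a hash would be stricter.

```python
from pathlib import Path

# Size reported by the server in the wget log above (bytes).
EXPECTED_SIZE = 2564550879


def checkpoint_is_complete(path, expected_size=EXPECTED_SIZE):
    """Hypothetical helper: True if `path` is a file of exactly `expected_size` bytes."""
    p = Path(path)
    return p.is_file() and p.stat().st_size == expected_size


if __name__ == "__main__":
    ckpt = "sam_vit_h_4b8939.pth"
    if checkpoint_is_complete(ckpt):
        print(f"{ckpt} already present, skipping download")
    else:
        print(f"{ckpt} missing or incomplete; download it with wget as above")
```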



In [1]:
from segment_anything import SamPredictor, sam_model_registry, SamAutomaticMaskGenerator
import cv2
import numpy as np
from pathlib import Path
import torch
from types import MethodType  # used to rebind apply_coords for the MPS workaround below

# Input / output directories
input_dir = Path("/Users/Kota/blended/Team3AmazonProject/data/temp/original")
output_dir = Path("/Users/Kota/blended/Team3AmazonProject/data/temp/cropped_sam")
output_dir.mkdir(parents=True, exist_ok=True)

# SAM model (ViT-H checkpoint downloaded above)
sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"

device = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

mask_generator = SamAutomaticMaskGenerator(sam)

# ★ MPS workaround: force apply_coords to return float32, since MPS does not support float64 tensors (the library itself is left unmodified)
if device == "mps":
    print("⚠️ MPS detected: Patching apply_coords to return float32")
    orig_apply_coords = mask_generator.predictor.transform.apply_coords

    def apply_coords_float32(self, coords, size):
        out = orig_apply_coords(coords, size)
        # Keep it a NumPy array; only cast the dtype to float32
        return out.astype(np.float32)

    mask_generator.predictor.transform.apply_coords = MethodType(
        apply_coords_float32, mask_generator.predictor.transform
    )

# Image extensions to process
image_exts = {".jpg", ".png"}

for img_path in input_dir.glob("*"):
    if img_path.suffix.lower() not in image_exts:
        continue

    original = cv2.imread(str(img_path))
    if original is None:
        print(f"⚠️ Failed to read: {img_path}")
        continue

    # SAM takes an RGB uint8 image as-is; casting to float32 is unnecessary and counterproductive
    image_rgb = cv2.cvtColor(original, cv2.COLOR_BGR2RGB)

    # Generate masks
    masks = mask_generator.generate(image_rgb)
    if not masks:
        print(f"❌ No mask detected: {img_path.name}")
        continue

    # Keep the largest mask (by pixel area)
    largest_mask = max(masks, key=lambda x: x["area"])
    mask = largest_mask["segmentation"].astype(np.uint8)

    # Apply the mask (black background)
    mask_3c = np.stack([mask] * 3, axis=-1)
    masked_img = original * mask_3c

    save_path = output_dir / f"{img_path.stem}_sam_cropped.png"
    cv2.imwrite(str(save_path), masked_img)
    print(f"✅ Saved: {save_path}")

print("SAM segmentation completed!")


⚠️ MPS detected: Patching apply_coords to return float32
✅ Saved: /Users/Kota/blended/Team3AmazonProject/data/temp/cropped_sam/Table--154-_jpg.rf.67264c7a7b156dd01c44e5e35de6cbe9_sam_cropped.png
✅ Saved: /Users/Kota/blended/Team3AmazonProject/data/temp/cropped_sam/Sofa--348-_jpg.rf.74a1bda29972fc468b85d6c9eab48bbd_sam_cropped.png


KeyboardInterrupt: 
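The MPS workaround in the cell above captures the transform's already-bound `apply_coords`, wraps it, and rebinds the wrapper on that one instance with `types.MethodType`, so the installed `segment_anything` package stays untouched. It is needed because the original method returns a float64 NumPy array, which cannot be turned into an MPS tensor. The sketch below reproduces the pattern on a stand-in class so the dtype change can be checked without loading SAM; `FakeTransform` is hypothetical and only mimics the float64 return.

```python
from types import MethodType

import numpy as np


class FakeTransform:
    """Hypothetical stand-in for SAM's coordinate transform."""

    def apply_coords(self, coords, size):
        # Like the real transform, return a float64 array (here simply scaled by 2).
        return coords.astype(np.float64) * 2


transform = FakeTransform()
orig_apply_coords = transform.apply_coords  # bound method, captured before patching


def apply_coords_float32(self, coords, size):
    # Delegate to the original bound method, then downcast for MPS.
    return orig_apply_coords(coords, size).astype(np.float32)


# Rebind on this instance only; the class and other instances are unaffected.
transform.apply_coords = MethodType(apply_coords_float32, transform)

out = transform.apply_coords(np.array([[1, 2]]), (4, 4))
print(out, out.dtype)
```

Because the wrapper never uses `self`, a plain closure assigned to the attribute would work just as well; `MethodType` mainly keeps the call signature identical to a regular method.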
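Despite the `_sam_cropped` suffix, the loop saves a full-size image in which everything outside the largest mask is black; nothing is actually cropped. If a real crop is wanted, here is a minimal sketch, assuming `mask` is the 0/1 `uint8` array built from `largest_mask["segmentation"]`: broadcast the mask over the channel axis (same result as multiplying by `np.stack([mask] * 3, axis=-1)`), then slice to the mask's bounding box. `apply_and_crop` is a hypothetical helper name.

```python
import numpy as np


def apply_and_crop(image, mask):
    """Black out pixels outside `mask`, then crop to the mask's bounding box.

    image: H x W x 3 uint8 array (BGR, as returned by cv2.imread)
    mask:  H x W array of 0/1 values (or bool)
    Returns None if the mask is empty.
    """
    mask = mask.astype(np.uint8)
    # mask[..., None] broadcasts over the channel axis, like np.stack([mask] * 3, axis=-1).
    masked = image * mask[..., None]

    rows = np.flatnonzero(mask.any(axis=1))  # row indices containing mask pixels
    cols = np.flatnonzero(mask.any(axis=0))  # column indices containing mask pixels
    if rows.size == 0:
        return None
    return masked[rows[0]:rows[-1] + 1, cols[0]:cols[-1] + 1]


# Toy example: a 4x5 image with a 2x2 mask in the middle.
image = np.full((4, 5, 3), 200, dtype=np.uint8)
mask = np.zeros((4, 5), dtype=np.uint8)
mask[1:3, 2:4] = 1
crop = apply_and_crop(image, mask)
print(crop.shape)  # (2, 2, 3)
```

In the loop, `apply_and_crop(original, mask)` could replace the two masking lines; keep the `None` check if very small or empty masks are possible.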
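The output filename is built from `img_path.stem` alone, so two inputs that differ only by extension (for example `a.jpg` and `a.png`) would both be written to `a_sam_cropped.png`, and the second would silently overwrite the first. A minimal sketch, assuming you want to catch this before starting the slow mask generation: filter with the same extension set as `image_exts` and group by stem. `find_stem_collisions` is a hypothetical helper name.

```python
from collections import defaultdict
from pathlib import PurePath

IMAGE_EXTS = {".jpg", ".png"}  # same set as image_exts in the notebook


def find_stem_collisions(paths, exts=IMAGE_EXTS):
    """Return {stem: [paths]} for image files whose output names would clash."""
    by_stem = defaultdict(list)
    for p in map(PurePath, paths):
        if p.suffix.lower() in exts:
            by_stem[p.stem].append(p)
    return {stem: ps for stem, ps in by_stem.items() if len(ps) > 1}


names = ["Sofa.jpg", "Sofa.PNG", "Table.jpg", "notes.txt"]
print(find_stem_collisions(names))
```

In the notebook it could be called as `find_stem_collisions(input_dir.glob("*"))`; an empty dict means every input maps to a distinct output file.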