In [None]:
import os
import glob
import random
import torch
import numpy as np
from PIL import Image
from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

# --- Dataset / output configuration -----------------------------------------
base_dir      = "/home/s2behappy4/data/gyuhyeong/dataset/MMAD/MVTec-AD"
save_root     = "/home/s2behappy4/data/gyuhyeong/code/bridge_data"
# Keep only sub-directories: a raw os.listdir() also returns any stray files
# sitting in the dataset root (e.g. a readme or license file), which would
# then be treated as classes.
classes       = sorted(
    entry for entry in os.listdir(base_dir)
    if os.path.isdir(os.path.join(base_dir, entry))
)
n_per_class   = 10    # upper bound on images taken per class
alpha         = 0.5   # weight of the red tint when blending a mask overlay

# --- Reproducibility ---------------------------------------------------------
# The augmentation draws from NumPy's global random state; seed every RNG
# imported by this cell so a fresh "Restart & Run All" produces the same result.
seed          = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# --- SAM2 model --------------------------------------------------------------
checkpoint    = "./checkpoints/sam2.1_hiera_large.pt"
model_cfg     = "configs/sam2.1/sam2.1_hiera_l.yaml"
# Fall back to CPU when CUDA is unavailable instead of crashing on model load.
device        = "cuda" if torch.cuda.is_available() else "cpu"
# Pass the device to the builder: its own default is "cuda", so relying on a
# later .to(device) would still fail on CPU-only machines.
model         = build_sam2(model_cfg, checkpoint, device=device)
# Generator settings are forwarded to SAM2 unchanged; see the SAM2
# documentation for the meaning of each parameter.
mask_generator = SAM2AutomaticMaskGenerator(
    model=model,
    points_per_side=32,
    pred_iou_thresh=0.4,
    stability_score_thresh=0.4,
    mask_threshold=0.3,
    box_nms_thresh=0.8,
    crop_n_layers=2,
    crop_overlap_ratio=0.2,
    crop_n_points_downscale_factor=1,
    min_mask_region_area=300,
    output_mode="binary_mask",
    multimask_output=True,
    use_m2m=False
)

# CutPaste Augmentation (random patch from image)
def cutpaste(img_np, rng=None):
    """Copy a random rectangular patch of an image onto another random spot.

    The patch covers between 2% and 15% of the image area and its
    width/height ratio is drawn from [0.3, 3.0). Source and destination
    corners are drawn independently, so the two rectangles may overlap.

    Args:
        img_np: Image array of shape (H, W) or (H, W, C). It is not modified.
        rng: Optional random source exposing ``uniform`` plus either
            ``integers`` (``np.random.Generator``) or ``randint``
            (``np.random.RandomState``). Defaults to NumPy's global random
            state, so ``np.random.seed`` controls the result.

    Returns:
        A new array with the same shape and dtype as ``img_np`` in which the
        patch has been pasted.
    """
    if rng is None:
        rng = np.random  # legacy global state, same draws as before
    # Generator calls it `integers`, the legacy API calls it `randint`;
    # both treat the upper bound as exclusive.
    draw_int = getattr(rng, "integers", None) or rng.randint

    # Only the two leading axes matter, so grayscale (H, W) input works too.
    h, w = img_np.shape[:2]
    area = h * w

    patch_area = rng.uniform(0.02, 0.15) * area
    aspect = rng.uniform(0.3, 3.0)
    # Clamp to at least one pixel: on small images the truncated square root
    # can be 0, which would silently turn the augmentation into a no-op.
    ph = min(max(int(np.sqrt(patch_area / aspect)), 1), h)
    pw = min(max(int(np.sqrt(patch_area * aspect)), 1), w)

    # Draw order (source y, x, then destination y, x) is kept stable so a
    # given seed always yields the same augmentation.
    y1 = draw_int(0, h - ph + 1)
    x1 = draw_int(0, w - pw + 1)
    patch = img_np[y1:y1 + ph, x1:x1 + pw].copy()
    y2 = draw_int(0, h - ph + 1)
    x2 = draw_int(0, w - pw + 1)

    out = img_np.copy()
    out[y2:y2 + ph, x2:x2 + pw] = patch
    return out

# Loop over classes and images: augment -> SAM2 -> overlay -> save
os.makedirs(save_root, exist_ok=True)

# Red tint contribution; identical for every mask, so it is computed once.
tint = alpha * np.array([255, 0, 0])

for cls in classes:
    good_dir = os.path.join(base_dir, cls, "train", "good")
    image_paths = sorted(glob.glob(os.path.join(good_dir, "*.png")))[:n_per_class]
    if not image_paths:
        # Entries without any train/good PNGs are skipped so that no empty
        # output folder is created for them.
        print(f"skip {cls}: no PNG images found in {good_dir}")
        continue

    cls_save = os.path.join(save_root, cls)
    os.makedirs(cls_save, exist_ok=True)

    for img_path in image_paths:
        # 1) load & augment (the context manager closes the file handle
        #    as soon as the pixels have been copied into the array)
        with Image.open(img_path) as img:
            img_np = np.array(img.convert("RGB"))
        aug_np = cutpaste(img_np)

        # 2) SAM2 mask
        masks = mask_generator.generate(aug_np)

        # 3) save one overlay image per predicted mask
        base_name = os.path.splitext(os.path.basename(img_path))[0]
        aug_f32 = aug_np.astype(np.float32)  # converted once per image
        for i, ann in enumerate(masks):
            seg = ann["segmentation"]
            overlay = aug_f32.copy()
            # alpha blending, applied only to the pixels selected by the mask
            overlay[seg] = (1 - alpha) * aug_np[seg] + tint
            overlay = overlay.astype(np.uint8)

            out_path = os.path.join(
                cls_save,
                f"{cls}_{base_name}_mask{i:02d}.png"
            )
            Image.fromarray(overlay).save(out_path)

# # Save overlay images and masks

In [None]:
import os
import glob
import numpy as np
from PIL import Image

from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

# Dataset location and the classes handled by this cell.
base_dir = "/home/s2behappy4/data/gyuhyeong/dataset/MMAD/MVTec-AD"
classes = ["bottle", "cable", "capsule"]
n_per_class = 10
alpha = 0.5

# Output folders: one for overlay images, one for the binary masks.
overlay_root = "/home/s2behappy4/data/gyuhyeong/code/bridge_data"
mask_root = "/home/s2behappy4/data/gyuhyeong/code/bridge_mask"
os.makedirs(overlay_root, exist_ok=True)
os.makedirs(mask_root, exist_ok=True)

# Build the SAM2 model and move it to the target device.
checkpoint = "./checkpoints/sam2.1_hiera_large.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
device = "cuda"
model = build_sam2(model_cfg, checkpoint).to(device)

# Automatic mask generator wrapping the model built above.
mask_generator = SAM2AutomaticMaskGenerator(
    model=model,
    points_per_side=32,
    pred_iou_thresh=0.4,
    stability_score_thresh=0.4,
    mask_threshold=0.3,
    box_nms_thresh=0.8,
    crop_n_layers=2,
    crop_overlap_ratio=0.2,
    crop_n_points_downscale_factor=1,
    min_mask_region_area=300,
    output_mode="binary_mask",
    multimask_output=True,
    use_m2m=False,
)

def cutpaste(img_np, rng=None):
    """Copy a random rectangular patch of an image onto another random spot.

    The patch covers between 2% and 15% of the image area and its
    width/height ratio is drawn from [0.3, 3.0). Source and destination
    corners are drawn independently, so the two rectangles may overlap.

    Args:
        img_np: Image array of shape (H, W) or (H, W, C). It is not modified.
        rng: Optional random source exposing ``uniform`` plus either
            ``integers`` (``np.random.Generator``) or ``randint``
            (``np.random.RandomState``). Defaults to NumPy's global random
            state, so ``np.random.seed`` controls the result.

    Returns:
        A new array with the same shape and dtype as ``img_np`` in which the
        patch has been pasted.
    """
    if rng is None:
        rng = np.random  # legacy global state, same draws as before
    # Generator calls it `integers`, the legacy API calls it `randint`;
    # both treat the upper bound as exclusive.
    draw_int = getattr(rng, "integers", None) or rng.randint

    # Only the two leading axes matter, so grayscale (H, W) input works too.
    h, w = img_np.shape[:2]
    area = h * w

    patch_area = rng.uniform(0.02, 0.15) * area
    aspect = rng.uniform(0.3, 3.0)
    # Clamp to at least one pixel: on small images the truncated square root
    # can be 0, which would silently turn the augmentation into a no-op.
    ph = min(max(int(np.sqrt(patch_area / aspect)), 1), h)
    pw = min(max(int(np.sqrt(patch_area * aspect)), 1), w)

    # Draw order (source y, x, then destination y, x) is kept stable so a
    # given seed always yields the same augmentation.
    y1 = draw_int(0, h - ph + 1)
    x1 = draw_int(0, w - pw + 1)
    patch = img_np[y1:y1 + ph, x1:x1 + pw].copy()
    y2 = draw_int(0, h - ph + 1)
    x2 = draw_int(0, w - pw + 1)

    out = img_np.copy()
    out[y2:y2 + ph, x2:x2 + pw] = patch
    return out

# Red tint contribution; identical for every mask, so it is computed once.
tint = alpha * np.array([255, 0, 0])

for cls in classes:
    good_dir = os.path.join(base_dir, cls, "train", "good")
    paths = sorted(glob.glob(os.path.join(good_dir, "*.png")))[:n_per_class]
    if not paths:
        # Nothing to process for this class: skip it so that no empty
        # output folders are created.
        print(f"skip {cls}: no PNG images found in {good_dir}")
        continue

    cls_ov_dir = os.path.join(overlay_root, cls)
    cls_ms_dir = os.path.join(mask_root, cls)
    os.makedirs(cls_ov_dir, exist_ok=True)
    os.makedirs(cls_ms_dir, exist_ok=True)

    for img_path in paths:
        base = os.path.splitext(os.path.basename(img_path))[0]

        # 1) load & augment (the context manager closes the file handle
        #    as soon as the pixels have been copied into the array)
        with Image.open(img_path) as img:
            img_np = np.array(img.convert("RGB"))
        aug_np = cutpaste(img_np)

        # 2) SAM2 mask
        masks = mask_generator.generate(aug_np)

        # 3) save one overlay image and one binary mask per predicted mask
        aug_f32 = aug_np.astype(np.float32)  # converted once per image
        for i, ann in enumerate(masks):
            seg = ann["segmentation"]

            # Blend the red tint into the pixels selected by the mask.
            overlay = aug_f32.copy()
            overlay[seg] = (1 - alpha) * aug_np[seg] + tint
            overlay = overlay.astype(np.uint8)

            ov_path = os.path.join(cls_ov_dir, f"{cls}_{base}_mask{i:02d}.png")
            Image.fromarray(overlay).save(ov_path)

            # Store the mask itself as an 8-bit image (0 or 255 per pixel).
            mask_img = seg.astype(np.uint8) * 255
            ms_path = os.path.join(cls_ms_dir, f"{cls}_{base}_mask{i:02d}_mask.png")
            Image.fromarray(mask_img).save(ms_path)

In [2]:
import os

# Root folder holding one sub-folder of overlay PNGs per class.
BASE_DIR = "/home/s2behappy4/data/gyuhyeong/code/bridge_data"
classes = ["bottle", "cable", "capsule"]

# Number of PNG files per class; a missing class folder is reported and
# counted as zero.
counts = {}
for class_name in classes:
    folder = os.path.join(BASE_DIR, class_name)
    if os.path.isdir(folder):
        png_names = [
            name for name in os.listdir(folder)
            if name.lower().endswith(".png")
        ]
        counts[class_name] = len(png_names)
    else:
        print(f"오류: {folder}")
        counts[class_name] = 0

total_count = sum(counts.values())

# Per-class summary followed by the grand total.
for class_name, num_images in counts.items():
    print(f"{class_name:>8s} 클래스 오버레이 이미지 개수: {num_images}")
print(f"\n▶ 총 overlay 이미지 개수: {total_count}")

  bottle 클래스 오버레이 이미지 개수: 144
   cable 클래스 오버레이 이미지 개수: 1851
 capsule 클래스 오버레이 이미지 개수: 319

▶ 총 overlay 이미지 개수: 2314
