In [1]:
import os
import cv2
import numpy as np
import torch
import pickle
import json
import random
import sys
from PIL import Image
import torchvision.transforms as T
from tqdm import tqdm

from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

In [None]:
def _build_mask_generator(config):
    """Load the SAM2 checkpoint named in ``config`` and wrap it in an automatic mask generator.

    Args:
        config: dict providing "sam_cfg_path", "sam_ckpt_path" and "device".

    Returns:
        A ``SAM2AutomaticMaskGenerator`` configured with ``output_mode="binary_mask"``.
    """
    sam_model = build_sam2(config["sam_cfg_path"], config["sam_ckpt_path"], device=config["device"])
    sam_model.eval()
    return SAM2AutomaticMaskGenerator(
        model=sam_model, points_per_side=32, points_per_batch=64, pred_iou_thresh=0.90,
        stability_score_thresh=0.95, box_nms_thresh=0.1, crop_n_layers=0, point_grids=None,
        min_mask_region_area=10, output_mode="binary_mask", use_m2m=False, multimask_output=True
    )


def _load_resized_rgb(path, resize_transform):
    """Open ``path``, convert it to RGB, apply ``resize_transform`` and close the file handle."""
    with Image.open(path) as img:
        return resize_transform(img.convert("RGB"))


def _disjoint_masks(raw_masks, min_background_area=10):
    """Turn possibly-overlapping boolean masks into non-overlapping segments.

    Masks are visited from smallest to largest pixel area. Each one keeps only
    the pixels not already claimed by a smaller mask, and masks that end up
    empty are dropped. If at least ``min_background_area`` pixels are still
    unclaimed afterwards, they are appended as one final "leftover" segment.

    Args:
        raw_masks: non-empty list of equally-shaped 2-D boolean arrays.
        min_background_area: minimum size of the leftover segment to keep it.

    Returns:
        List of disjoint boolean masks with the same shape as the inputs.
    """
    height, width = raw_masks[0].shape
    claimed = np.zeros((height, width), dtype=bool)
    segments = []
    for idx in np.argsort([m.sum() for m in raw_masks]):
        exclusive = raw_masks[idx] & ~claimed
        if exclusive.sum() > 0:
            segments.append(exclusive)
            claimed |= exclusive
    leftover = ~claimed
    if leftover.sum() >= min_background_area:
        segments.append(leftover)
    return segments


def _paste_random_patch(base_image, source_image, patch_size_range,
                        max_angle_deg=15.0, scale_range=(0.9, 1.1)):
    """Cut a random square patch from ``source_image``, rotate/scale it and paste it onto a copy of ``base_image``.

    Args:
        base_image: PIL RGB image that receives the patch (not modified).
        source_image: PIL RGB image the patch is cut from.
        patch_size_range: inclusive (min, max) side length of the square crop.
        max_angle_deg: rotation angle is drawn uniformly from [-max, max].
        scale_range: (min, max) uniform scale factor applied to the rotated patch.

    Returns:
        (anomalous_image, anomaly_mask): the pasted PIL image and a boolean
        (height, width) array that is True exactly where the patch's alpha is
        non-zero, i.e. where pixels were actually altered.
    """
    width, height = base_image.size

    # Clamp so the crop always fits inside the source image.
    patch_size = min(random.randint(*patch_size_range), source_image.width, source_image.height)
    left = random.randint(0, source_image.width - patch_size)
    top = random.randint(0, source_image.height - patch_size)
    patch = source_image.crop((left, top, left + patch_size, top + patch_size)).convert("RGBA")

    angle = random.uniform(-max_angle_deg, max_angle_deg)
    scale = random.uniform(*scale_range)
    # expand=True enlarges the canvas to hold the rotated square, and the new
    # corners are transparent. Scale that enlarged canvas rather than squashing
    # it back to patch_size, so the sampled scale is what the content receives.
    rotated = patch.rotate(angle, expand=True, resample=Image.BICUBIC)
    new_w = max(1, min(width, int(rotated.width * scale)))
    new_h = max(1, min(height, int(rotated.height * scale)))
    final_patch = rotated.resize((new_w, new_h), resample=Image.BICUBIC)

    paste_left = random.randint(0, width - new_w)
    paste_top = random.randint(0, height - new_h)
    anomalous_image = base_image.copy()
    # The RGBA patch doubles as its own paste mask, so transparent corners leave the base untouched.
    anomalous_image.paste(final_patch, (paste_left, paste_top), final_patch)

    # Footprint of pixels that really changed (alpha > 0), not the whole bounding box.
    anomaly_mask = np.zeros((height, width), dtype=bool)
    alpha = np.asarray(final_patch)[:, :, 3] > 0
    anomaly_mask[paste_top:paste_top + new_h, paste_left:paste_left + new_w] = alpha
    return anomalous_image, anomaly_mask


def _label_masks(masks, anomaly_mask):
    """Map each mask index (as a string) to "anomaly" if it overlaps ``anomaly_mask``, else "normal"."""
    return {
        str(idx): ("anomaly" if np.any(mask & anomaly_mask) else "normal")
        for idx, mask in enumerate(masks)
    }


def _save_variant(variant_dir, image, masks, labels):
    """Write image.png, original_masks.pkl and labels.json for one variant into ``variant_dir``."""
    os.makedirs(variant_dir, exist_ok=True)
    image.save(os.path.join(variant_dir, "image.png"))
    with open(os.path.join(variant_dir, "original_masks.pkl"), "wb") as f:
        pickle.dump(masks, f)
    with open(os.path.join(variant_dir, "labels.json"), "w") as f:
        json.dump(labels, f)


def main():
    """Generate synthetic copy-paste anomalies with per-segment anomaly labels.

    For every ``*.png`` in ``base_data_dir`` the image is resized, segmented
    with SAM2, and the masks are made disjoint. Then ``variants_per_image``
    variants are written to ``<output_dir>/<name>_variant<k>/``. Each variant
    has a random rotated/scaled patch from another good image pasted in, and
    each segment is labelled by whether it overlaps the pasted pixels.
    Failures on one base image are reported and that image is skipped.
    """
    print("--- Initializing Configuration and Models ---")

    CONFIG = {
        "sam_cfg_path": "configs/sam2.1/sam2.1_hiera_l.yaml",
        "sam_ckpt_path": "./checkpoints/sam2.1_hiera_large.pt",
        "device": "cuda" if torch.cuda.is_available() else "cpu",
        "base_data_dir": "/home/s2behappy4/data/gyuhyeong/dataset/MMAD/MVTec-AD/grid/train/good/",
        "output_dir": "/home/s2behappy4/data/gyuhyeong/MLLM_Anomaly/Demo_data/",
        "image_size": 512,
        "variants_per_image": 20,
        "patch_size_range": (64, 128),
        "rotation_range_deg": 15.0,   # patch rotated by U(-x, x) degrees
        "scale_range": (0.9, 1.1),    # patch scaled by U(min, max)
        "min_background_area": 10,    # minimum leftover pixels kept as their own segment
        "seed": 0,
    }
    print(f"Using device: {CONFIG['device']}")
    print(f"Output will be saved to: {CONFIG['output_dir']}")

    # Seed every RNG in use so the generated dataset is reproducible.
    random.seed(CONFIG["seed"])
    np.random.seed(CONFIG["seed"])
    torch.manual_seed(CONFIG["seed"])

    mask_generator = _build_mask_generator(CONFIG)
    resize_transform = T.Resize((CONFIG["image_size"], CONFIG["image_size"]), interpolation=T.InterpolationMode.BICUBIC)

    print("\n--- Starting Data Generation ---")
    good_image_paths = sorted([os.path.join(CONFIG["base_data_dir"], f) for f in os.listdir(CONFIG["base_data_dir"]) if f.endswith('.png')])
    os.makedirs(CONFIG["output_dir"], exist_ok=True)

    failures = 0
    for base_image_path in tqdm(good_image_paths, desc="Processing Base Images"):
        try:
            base_image_name = os.path.splitext(os.path.basename(base_image_path))[0]
            base_image = _load_resized_rgb(base_image_path, resize_transform)

            # SAM2 documents its image input as an RGB HWC uint8 array. The PIL
            # image is already RGB, so it is passed as-is (no BGR conversion).
            masks_data = mask_generator.generate(np.array(base_image))
            if not masks_data:
                print(f"Warning: SAM2 did not generate any masks for {base_image_name}. Skipping.")
                continue

            processed_masks = _disjoint_masks(
                [d['segmentation'] for d in masks_data], CONFIG["min_background_area"]
            )

            # Prefer patches from other images. With a single image available,
            # fall back to a self-patch instead of crashing on random.choice([]).
            source_candidates = [p for p in good_image_paths if p != base_image_path] or [base_image_path]

            for i in tqdm(range(CONFIG["variants_per_image"]), desc=f"  Creating Variants for {base_image_name}", leave=False):
                variant_name = f"{base_image_name}_variant{i+1}"
                source_image = _load_resized_rgb(random.choice(source_candidates), resize_transform)
                anomalous_image, anomaly_mask = _paste_random_patch(
                    base_image, source_image, CONFIG["patch_size_range"],
                    CONFIG["rotation_range_deg"], CONFIG["scale_range"],
                )
                labels = _label_masks(processed_masks, anomaly_mask)
                _save_variant(os.path.join(CONFIG["output_dir"], variant_name),
                              anomalous_image, processed_masks, labels)

        except Exception as e:
            # Best-effort: report and move on to the next base image.
            failures += 1
            print(f"Error processing {base_image_path}: {e}")

    if failures:
        print(f"\n--- Data generation finished with {failures} failed base image(s). ---")
    else:
        print("\n--- Data generation finished successfully! ---")

# Entry point. Inside a notebook kernel __name__ is "__main__", so running this
# cell starts generation; importing the code as a module does not.
if "__main__" == __name__:
    main()