In [23]:
import os
import cv2
import numpy as np
from PIL import Image
import torch
from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
import logging
from tqdm import tqdm
import json

# Assuming sam_scripts is in the same parent directory or PYTHONPATH includes it
from sam2_scripts.util import generate_polygons_from_masks

copy code here from batch_procesor; see how it looks; and correct acc.

In [24]:
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

SUPPORTED_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')

In [25]:
def select_device(device_str='auto'):
    if device_str == 'cuda' and torch.cuda.is_available():
        return torch.device('cuda')
    # Add MPS check if needed
    elif device_str == 'mps' and torch.backends.mps.is_available():
        return torch.device('mps')
    else:
        return torch.device('cpu')

In [26]:
image_dir = '/Users/jamesemilian/Desktop/cronus-ai/figure8/anomaly_detect/data/coco_data_prep/coco_val_1pct/data'
model_config = "configs/sam2.1/sam2.1_hiera_t.yaml" # Default model config
model_checkpoint = "checkpoints/sam2.1_hiera_tiny.pt" # Default checkpoint
device_str = "auto"
mask_gen_kwargs = None

In [32]:
# Handle mask generator arguments - use defaults similar to segment.py 'dev'
if mask_gen_kwargs is None:
        mask_gen_kwargs = {
        "points_per_side": 32,
        "pred_iou_thresh": 0.86,
        "stability_score_thresh": 0.92,
        "crop_n_layers": 1,
        "crop_n_points_downscale_factor": 2,
        "min_mask_region_area": 100, # Default value from SAM
        "output_mode": "binary_mask" # Ensure we get masks for polygon conversion
    }

device = select_device(device_str=device_str)
logging.info(f"Using device: {device}")

try:
    logging.info("Loading SAM 2 model...")
    sam2 = build_sam2(model_config, model_checkpoint, device=device, apply_postprocessing=False)
    # mask_generator = Sam2AutomaticMaskGenerator(model, **mask_gen_kwargs)
    mask_generator = SAM2AutomaticMaskGenerator(sam2)
    logging.info("Model loaded successfully.")
except Exception as e:
    logging.error(f"Failed to load SAM 2 model: {e}")

all_masks_data = {}
image_files = [f for f in os.listdir(image_dir) if f.lower().endswith(SUPPORTED_EXTENSIONS)]

if not image_files:
    logging.warning(f"No supported image files found in {image_dir}")

logging.info(f"Found {len(image_files)} images to process.")


2025-04-20 12:04:48,615 - INFO - Using device: cpu
2025-04-20 12:04:48,616 - INFO - Loading SAM 2 model...
2025-04-20 12:04:48,911 - INFO - Loaded checkpoint sucessfully
2025-04-20 12:04:48,926 - INFO - Model loaded successfully.
2025-04-20 12:04:48,927 - INFO - Found 245 images to process.


In [33]:
image_files = image_files[0:3]

In [34]:
image_files

['000000107814.jpg', '000000443537.jpg', '000000127905.jpg']

In [35]:

for filename in tqdm(image_files, desc="Processing images"):
    image_path = os.path.join(image_dir, filename)
    image_name = os.path.basename(filename) # Use basename as key
    logging.info(f"Processing {image_name}...")
    print(image_name)
    # Load image using OpenCV (consistent with segment.py)
    image = cv2.imread(image_path)
    if image is None:
        logging.warning(f"Could not read image: {image_path}. Skipping.")
        continue
    # Convert BGR to RGB
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image_shape = image_rgb.shape[:2] # (height, width)

    logging.info(f"Generating masks for {image_name}...")
    # Generate raw masks
    raw_masks = mask_generator.generate(image_rgb)
    logging.info(f"Generated {len(raw_masks)} raw masks for {image_name}.")

    if not raw_masks:
        logging.warning(f"No masks generated for {image_name}. Skipping polygon conversion.")
        all_masks_data[image_name] = []
        continue

    # Process masks to get polygons, base64 etc.
    logging.info(f"Converting masks to polygons for {image_name}...")
    processed_masks = generate_polygons_from_masks(raw_masks, image_shape)
    logging.info(f"Successfully processed masks for {image_name}.")

    all_masks_data[image_name] = processed_masks

Processing images:   0%|          | 0/3 [00:00<?, ?it/s]2025-04-20 12:04:55,021 - INFO - Processing 000000107814.jpg...
2025-04-20 12:04:55,026 - INFO - Generating masks for 000000107814.jpg...
2025-04-20 12:04:55,026 - INFO - For numpy array image, we assume (HxWxC) format
2025-04-20 12:04:55,042 - INFO - Computing image embeddings for the provided image...


000000107814.jpg


2025-04-20 12:04:55,721 - INFO - Image embeddings computed.
2025-04-20 12:05:22,248 - INFO - Generated 1 raw masks for 000000107814.jpg.
2025-04-20 12:05:22,249 - INFO - Converting masks to polygons for 000000107814.jpg...
Processing images:   0%|          | 0/3 [00:27<?, ?it/s]


TypeError: generate_polygons_from_masks() missing 1 required positional argument: 'image_height'

In [37]:
# OK, this API call is just stuck! 
# clean up this code and run to see if it gets me masks

In [None]:

# environment = os.getenv("ENV", "prod")
# device = select_device()

# print("SEGMENT_SELECTED_DEVICE:", device)

# input = request.get_json()
# image_url = input["image"]
# storage_token = input["storageRefsToken"]

# cache_file_path = generate_cache_file_path(image_url)

# try:
#     response = requests.get(image_url)

#     print(f"Image URL Response Code: {response.status_code}")

image_dir = '/Users/jamesemilian/Desktop/cronus-ai/figure8/anomaly_detect/data/coco_data_prep/coco_val_1pct/data'
model_config = "configs/sam2.1/sam2.1_hiera_t.yaml" # Default model config
model_checkpoint = "checkpoints/sam2.1_hiera_tiny.pt" # Default checkpoint
device_str = "auto"
mask_gen_kwargs = None

image = image_rgb
image_width, image_height = image.size
image_np = np.array(image)

print("Start Mask Generator")
# print("Cache File Path", cache_file_path)

# has_cached_data = fetch_from_cds(cache_file_path, storage_token)
# if has_cached_data is not None:
#     sam_result = has_cached_data
# else:
    sam2 = build_sam2(
        model_config, model_checkpoint, device=device, apply_postprocessing=False
    )

    if environment == "dev":
        mask_generator = SAM2AutomaticMaskGenerator(sam2)
    else:
        # This configuration generates better masks
        # However, without using cuda there is a performance issue
        # Since not all devices support cuda
        # Only prod need to use this configuration
        mask_generator = SAM2AutomaticMaskGenerator(
            model=sam2,
            points_per_side=64,
            points_per_batch=128,
            pred_iou_thresh=0.7,
            stability_score_thresh=0.92,
            stability_score_offset=0.7,
            crop_n_layers=1,
            box_nms_thresh=0.7,
            crop_n_points_downscale_factor=2,
            min_mask_region_area=25.0,
            use_m2m=True,
        )

    sam_result = mask_generator.generate(image_np)
    save_to_cds(cache_file_path, sam_result, storage_token)

print("Mask Generator Finished")

mask_data_list = []
for idx, mask_data in enumerate(sam_result):
    mask_info = {
        "id": idx,
        "bounding_boxes": mask_data["bbox"],
        "point_coords": mask_data["point_coords"],
        "segmentation": mask_data["segmentation"],
    }
    mask_data_list.append(mask_info)

print("Generate polygons from masks", len(mask_data_list))
valid_polygons = generate_polygons_from_masks(
    mask_data_list, image_width, image_height
)

final_response = {"masks": valid_polygons}

return final_response

except Exception as e:
message = f"Failed to predict segmentation: {e}"
return build_response_error(message)

In [38]:
image_rgb.shape

(427, 640, 3)