In [23]:
import os
import cv2
import numpy as np
from PIL import Image
import torch
from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
import logging
from tqdm import tqdm
import json

# Assuming sam_scripts is in the same parent directory or PYTHONPATH includes it
from sam2_scripts.util import generate_polygons_from_masks

copy code here from batch_procesor; see how it looks; and correct acc.

In [24]:
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

SUPPORTED_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')

In [25]:
def select_device(device_str='auto'):
    if device_str == 'cuda' and torch.cuda.is_available():
        return torch.device('cuda')
    # Add MPS check if needed
    elif device_str == 'mps' and torch.backends.mps.is_available():
        return torch.device('mps')
    else:
        return torch.device('cpu')

In [26]:
image_dir = '/Users/jamesemilian/Desktop/cronus-ai/figure8/anomaly_detect/data/coco_data_prep/coco_val_1pct/data'
model_config = "configs/sam2.1/sam2.1_hiera_t.yaml" # Default model config
model_checkpoint = "checkpoints/sam2.1_hiera_tiny.pt" # Default checkpoint
device_str = "auto"
mask_gen_kwargs = None

In [32]:
# Handle mask generator arguments - use defaults similar to segment.py 'dev'
if mask_gen_kwargs is None:
        mask_gen_kwargs = {
        "points_per_side": 32,
        "pred_iou_thresh": 0.86,
        "stability_score_thresh": 0.92,
        "crop_n_layers": 1,
        "crop_n_points_downscale_factor": 2,
        "min_mask_region_area": 100, # Default value from SAM
        "output_mode": "binary_mask" # Ensure we get masks for polygon conversion
    }

device = select_device(device_str=device_str)
logging.info(f"Using device: {device}")

try:
    logging.info("Loading SAM 2 model...")
    sam2 = build_sam2(model_config, model_checkpoint, device=device, apply_postprocessing=False)
    # mask_generator = Sam2AutomaticMaskGenerator(model, **mask_gen_kwargs)
    mask_generator = SAM2AutomaticMaskGenerator(sam2)
    logging.info("Model loaded successfully.")
except Exception as e:
    logging.error(f"Failed to load SAM 2 model: {e}")

all_masks_data = {}
image_files = [f for f in os.listdir(image_dir) if f.lower().endswith(SUPPORTED_EXTENSIONS)]

if not image_files:
    logging.warning(f"No supported image files found in {image_dir}")

logging.info(f"Found {len(image_files)} images to process.")


2025-04-20 12:04:48,615 - INFO - Using device: cpu
2025-04-20 12:04:48,616 - INFO - Loading SAM 2 model...
2025-04-20 12:04:48,911 - INFO - Loaded checkpoint sucessfully
2025-04-20 12:04:48,926 - INFO - Model loaded successfully.
2025-04-20 12:04:48,927 - INFO - Found 245 images to process.


In [33]:
image_files = image_files[0:3]

In [34]:
image_files

['000000107814.jpg', '000000443537.jpg', '000000127905.jpg']

In [35]:

for filename in tqdm(image_files, desc="Processing images"):
    image_path = os.path.join(image_dir, filename)
    image_name = os.path.basename(filename) # Use basename as key
    logging.info(f"Processing {image_name}...")
    print(image_name)
    # Load image using OpenCV (consistent with segment.py)
    image = cv2.imread(image_path)
    if image is None:
        logging.warning(f"Could not read image: {image_path}. Skipping.")
        continue
    # Convert BGR to RGB
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image_shape = image_rgb.shape[:2] # (height, width)

    logging.info(f"Generating masks for {image_name}...")
    # Generate raw masks
    raw_masks = mask_generator.generate(image_rgb)
    logging.info(f"Generated {len(raw_masks)} raw masks for {image_name}.")

    if not raw_masks:
        logging.warning(f"No masks generated for {image_name}. Skipping polygon conversion.")
        all_masks_data[image_name] = []
        continue

    # Process masks to get polygons, base64 etc.
    logging.info(f"Converting masks to polygons for {image_name}...")
    processed_masks = generate_polygons_from_masks(raw_masks, image_shape)
    logging.info(f"Successfully processed masks for {image_name}.")

    all_masks_data[image_name] = processed_masks

Processing images:   0%|          | 0/3 [00:00<?, ?it/s]2025-04-20 12:04:55,021 - INFO - Processing 000000107814.jpg...
2025-04-20 12:04:55,026 - INFO - Generating masks for 000000107814.jpg...
2025-04-20 12:04:55,026 - INFO - For numpy array image, we assume (HxWxC) format
2025-04-20 12:04:55,042 - INFO - Computing image embeddings for the provided image...


000000107814.jpg


2025-04-20 12:04:55,721 - INFO - Image embeddings computed.
2025-04-20 12:05:22,248 - INFO - Generated 1 raw masks for 000000107814.jpg.
2025-04-20 12:05:22,249 - INFO - Converting masks to polygons for 000000107814.jpg...
Processing images:   0%|          | 0/3 [00:27<?, ?it/s]


TypeError: generate_polygons_from_masks() missing 1 required positional argument: 'image_height'

In [37]:
# OK, this API call is just stuck! 
# clean up this code and run to see if it gets me masks

In [None]:

# environment = os.getenv("ENV", "prod")
# device = select_device()

# print("SEGMENT_SELECTED_DEVICE:", device)

# input = request.get_json()
# image_url = input["image"]
# storage_token = input["storageRefsToken"]

# cache_file_path = generate_cache_file_path(image_url)

# try:
#     response = requests.get(image_url)

#     print(f"Image URL Response Code: {response.status_code}")

image_dir = '/Users/jamesemilian/Desktop/cronus-ai/figure8/anomaly_detect/data/coco_data_prep/coco_val_1pct/data'
model_config = "configs/sam2.1/sam2.1_hiera_t.yaml" # Default model config
model_checkpoint = "checkpoints/sam2.1_hiera_tiny.pt" # Default checkpoint
device_str = "auto"
mask_gen_kwargs = None

image = image_rgb
image_width, image_height = image.size
image_np = np.array(image)

print("Start Mask Generator")
# print("Cache File Path", cache_file_path)

# has_cached_data = fetch_from_cds(cache_file_path, storage_token)
# if has_cached_data is not None:
#     sam_result = has_cached_data
# else:
    sam2 = build_sam2(
        model_config, model_checkpoint, device=device, apply_postprocessing=False
    )

    if environment == "dev":
        mask_generator = SAM2AutomaticMaskGenerator(sam2)
    else:
        # This configuration generates better masks
        # However, without using cuda there is a performance issue
        # Since not all devices support cuda
        # Only prod need to use this configuration
        mask_generator = SAM2AutomaticMaskGenerator(
            model=sam2,
            points_per_side=64,
            points_per_batch=128,
            pred_iou_thresh=0.7,
            stability_score_thresh=0.92,
            stability_score_offset=0.7,
            crop_n_layers=1,
            box_nms_thresh=0.7,
            crop_n_points_downscale_factor=2,
            min_mask_region_area=25.0,
            use_m2m=True,
        )

    sam_result = mask_generator.generate(image_np)
    save_to_cds(cache_file_path, sam_result, storage_token)

print("Mask Generator Finished")

mask_data_list = []
for idx, mask_data in enumerate(sam_result):
    mask_info = {
        "id": idx,
        "bounding_boxes": mask_data["bbox"],
        "point_coords": mask_data["point_coords"],
        "segmentation": mask_data["segmentation"],
    }
    mask_data_list.append(mask_info)

print("Generate polygons from masks", len(mask_data_list))
valid_polygons = generate_polygons_from_masks(
    mask_data_list, image_width, image_height
)

final_response = {"masks": valid_polygons}

return final_response

except Exception as e:
message = f"Failed to predict segmentation: {e}"
return build_response_error(message)

In [38]:
image_rgb.shape

(427, 640, 3)

In [3]:
!curl -X POST http://localhost:8000/api/v1/generate_masks -H "Content-Type: application/json" -d '{"image": "http://localhost:9000/000000274036.jpg","storageRefsToken": "test-token"}' > segmentation_result.json

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 51780  100 51696  100    84   3624      5  0:00:16  0:00:14  0:00:02 12869    0     13  0:00:06  0:00:06 --:--:--     00:00:12  0:00:02     0


let's take a look at the segmentation:

In [4]:
!pip install requests Pillow numpy opencv-python matplotlib



In [9]:
2+2

4

In [1]:
import json
import requests
from PIL import Image, UnidentifiedImageError
import numpy as np
import cv2 # OpenCV for drawing
import matplotlib.pyplot as plt
import io
import os
import argparse # Added for command-line arguments

def visualize_segmentation(json_path, image_url_or_path, show_bboxes=False):
    """
    Loads segmentation results from a JSON file and visualizes the masks
    (and optionally bounding boxes) on the corresponding image.

    Args:
        json_path (str): Path to the segmentation_result.json file.
        image_url_or_path (str): The URL or local file path of the original image.
        show_bboxes (bool): If True, draw bounding boxes around the masks.
    """
    # --- 1. Load the Segmentation Data ---
    try:
        with open(json_path, 'r') as f:
            data = json.load(f)
        masks_data = data.get("masks")
        if masks_data is None:
            print(f"Error: 'masks' key not found in {json_path}")
            return
        if not isinstance(masks_data, list):
             print(f"Error: 'masks' key in {json_path} is not a list.")
             return
        print(f"Loaded {len(masks_data)} mask entries from {json_path}")
    except FileNotFoundError:
        print(f"Error: JSON file not found at {json_path}")
        return
    except json.JSONDecodeError:
        print(f"Error: Could not decode JSON from {json_path}")
        return
    except Exception as e:
        print(f"An unexpected error occurred loading JSON: {e}")
        return

    # --- 2. Load the Original Image ---
    try:
        if image_url_or_path.startswith(('http://', 'https://')):
            print(f"Fetching image from URL: {image_url_or_path}")
            response = requests.get(image_url_or_path, stream=True, timeout=10)
            response.raise_for_status()
            img_pil = Image.open(io.BytesIO(response.content))
        else:
            print(f"Loading image from local path: {image_url_or_path}")
            if not os.path.exists(image_url_or_path):
                 print(f"Error: Image file not found at {image_url_or_path}")
                 return
            img_pil = Image.open(image_url_or_path)

        image_np = np.array(img_pil.convert('RGB'))
        print(f"Image loaded successfully. Shape: {image_np.shape}")

    except requests.exceptions.RequestException as e:
        print(f"Error fetching image URL {image_url_or_path}: {e}")
        return
    except FileNotFoundError:
        print(f"Error: Image file not found at {image_url_or_path}")
        return
    except UnidentifiedImageError:
        print(f"Error: Could not identify or open image file/data from {image_url_or_path}")
        return
    except Exception as e:
        print(f"An unexpected error occurred loading the image: {e}")
        return

    # --- 3. Draw Masks and BBoxes on the Image ---
    overlay = image_np.copy()
    output = image_np.copy() # Draw bboxes directly onto this later if needed
    alpha = 0.5
    num_masks = len(masks_data)
    colors = plt.cm.viridis(np.linspace(0, 1, num_masks))[:, :3]
    bbox_color_bgr = (0, 255, 0) # Green for BBoxes (BGR format for OpenCV)
    bbox_thickness = 2

    mask_count = 0
    bbox_count = 0
    for i, mask_info in enumerate(masks_data):
        polygon_points = mask_info.get("polygon")
        # Use "bounding_boxes" key, expecting [xmin, ymin, width, height] format
        bbox_data = mask_info.get("bounding_boxes")

        # Draw Polygon (Mask)
        if polygon_points and isinstance(polygon_points, list) and len(polygon_points) >= 3:
            try:
                pts = np.array(polygon_points, dtype=np.int32).reshape((-1, 1, 2))
                color_rgb_0_1 = colors[i]
                color_bgr_0_255 = tuple(int(c * 255) for c in color_rgb_0_1[::-1])
                cv2.fillPoly(overlay, [pts], color_bgr_0_255)
                mask_count += 1
            except ValueError as e:
                 print(f"Warning: Skipping mask {i} polygon due to error converting points: {e}")
                 continue # Skip to next mask if polygon is bad
            except Exception as e:
                 print(f"Warning: An unexpected error occurred drawing mask {i} polygon: {e}")
                 continue # Skip to next mask if polygon drawing fails
        else:
            print(f"Warning: Skipping mask {i} polygon due to missing or invalid 'polygon' data.")
            # Continue to potentially draw bbox even if polygon is missing/invalid

        # Draw Bounding Box (if requested and available)
        if show_bboxes:
            if bbox_data and isinstance(bbox_data, list) and len(bbox_data) == 4:
                try:
                    x_min, y_min, width, height = map(int, bbox_data)
                    x_max = x_min + width
                    y_max = y_min + height
                    # Draw rectangle directly on the 'output' image (before blending)
                    cv2.rectangle(output, (x_min, y_min), (x_max, y_max), bbox_color_bgr, bbox_thickness)
                    bbox_count += 1
                except (ValueError, TypeError) as e:
                    print(f"Warning: Skipping bbox for mask {i} due to invalid data format: {e}. Data: {bbox_data}")
                except Exception as e:
                     print(f"Warning: An unexpected error occurred drawing bbox {i}: {e}")
            else:
                # Only warn if bbox was expected but not found/valid
                print(f"Warning: Skipping bbox for mask {i} due to missing or invalid 'bounding_boxes' data. Found: {bbox_data}")

    # Blend the overlay (masks) with the output image (original + bboxes)
    cv2.addWeighted(overlay, alpha, output, 1 - alpha, 0, output)

    print(f"Drew {mask_count} valid masks" + (f" and {bbox_count} valid bounding boxes." if show_bboxes else " onto the image."))

    # --- 4. Display the Result ---
    plt.figure(figsize=(12, 10))
    plt.imshow(output)
    title = f"Segmentation Masks from {os.path.basename(json_path)} on {os.path.basename(image_url_or_path)}"
    if show_bboxes:
        title += " (with Bounding Boxes)"
    plt.title(title)
    plt.axis('off')
    plt.tight_layout()
    plt.show()

# --- Main execution block with argparse ---
if __name__ == "__main__":
    # --- Argument Parser Setup ---
    parser = argparse.ArgumentParser(description="Visualize segmentation masks (and optionally bounding boxes) from a JSON file on an image.")
    # Required positional arguments
    parser.add_argument("json_path", help="Path to the segmentation_result.json file.")
    parser.add_argument("image_location", help="URL or local file path of the original image.")
    # Optional flag
    parser.add_argument("--show-bboxes", action="store_true", help="Display bounding boxes in addition to masks.")

    args = parser.parse_args()

    # --- Run Visualization --- Pass the command line arg to the function
    # visualize_segmentation('/Users/jamesemilian/Desktop/cronus-ai/figure8/anomaly_detect/sam2/segmentation_result.json', 'http://localhost:9000/000000274036.jpg', True)

usage: ipykernel_launcher.py [-h] [--show-bboxes] json_path image_location
ipykernel_launcher.py: error: the following arguments are required: json_path, image_location


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [17]:
!python visualise_segmentation.py segmentation_result.json http://localhost:9000/000000274036.jpg --show-bboxes

An unexpected error occurred loading JSON: load() missing 1 required positional argument: 'fp'
