In [5]:
# Debug cell - check mask structure
# NOTE: although this cell sits first, it needs `processor` (model-loading cell) and
# `set_multi_text_prompts` / `PROMPTS` / `CATEGORIES` (helper-definition cell).
# On a fresh kernel, run those cells before this one.
import os
from PIL import Image

_missing = [
    name
    for name in ("processor", "set_multi_text_prompts", "PROMPTS", "CATEGORIES")
    if name not in globals()
]
if _missing:
    raise RuntimeError(
        f"Run the model-loading and helper cells before this debug cell; missing: {_missing}"
    )

# Load one test image; sorted() makes the choice deterministic (os.listdir order is arbitrary).
test_img_name = sorted(os.listdir('USIS10K/test'))[0]
with Image.open(os.path.join('USIS10K/test', test_img_name)) as raw_img:
    # Force 3-channel RGB (grayscale/RGBA files would otherwise reach the processor as-is)
    # and release the file handle as soon as the pixels are copied.
    test_img = raw_img.convert('RGB')
test_state = processor.set_image(test_img)
test_outputs = set_multi_text_prompts(processor, test_state, PROMPTS, CATEGORIES)

# Check the first output
print("Number of outputs:", len(test_outputs))
print("\nFirst output keys:", test_outputs[0].keys())
print("Category:", test_outputs[0]['category'])
print("\nMasks type:", type(test_outputs[0]['masks']))
print("Number of masks:", len(test_outputs[0]['masks']))

if len(test_outputs[0]['masks']) > 0:
    first_mask = test_outputs[0]['masks'][0]
    print("\nFirst mask type:", type(first_mask))
    print("First mask shape:", first_mask.shape if hasattr(first_mask, 'shape') else 'No shape attribute')
    print("First mask dtype:", first_mask.dtype if hasattr(first_mask, 'dtype') else 'No dtype attribute')
    print("First mask min/max:", first_mask.min().item() if hasattr(first_mask, 'min') else 'N/A', 
          first_mask.max().item() if hasattr(first_mask, 'max') else 'N/A')

Number of outputs: 7

First output keys: dict_keys(['category', 'boxes', 'scores', 'masks'])
Category: fish

Masks type: <class 'torch.Tensor'>
Number of masks: 2

First mask type: <class 'torch.Tensor'>
First mask shape: torch.Size([1, 480, 640])
First mask dtype: torch.bool
First mask min/max: False True


In [1]:
import os
from getpass import getpass

from huggingface_hub import login

# Never hardcode access tokens in a notebook: they leak through version control and shared copies.
# NOTE(review): the token previously committed here must be revoked in the Hugging Face account settings.
# Prefer exporting HF_TOKEN in the environment; otherwise prompt for it without echoing.
hf_token = os.environ.get("HF_TOKEN") or getpass("Hugging Face token: ")
login(token=hf_token)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from PIL import Image
from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor
def _load_sam3():
    """Build the SAM3 image model and wrap it in a Sam3Processor; returns (model, processor)."""
    sam3_model = build_sam3_image_model()
    return sam3_model, Sam3Processor(sam3_model)


# Load the model; `processor` is the handle the inference cells use.
model, processor = _load_sam3()

In [None]:
import torch
import json
from tqdm import tqdm
import numpy as np
from pycocotools import mask as mask_util

# Each category is paired with the descriptive text prompt used to query SAM3 for it.
# Keeping them as pairs guarantees CATEGORIES[i] always labels PROMPTS[i].
_CATEGORY_PROMPTS = [
    ("fish", "Underwater vertebrates, e.g., fish, turtles"),
    ("wrecks/ruins", "Wrecks, ruins, and damaged artifacts underwater"),
    ("reefs", "Underwater invertebrates and coral reefs"),
    ("aquatic plants", "Aquatic plants and flora underwater"),
    ("human divers", "Human divers and their equipment underwater, scuba divers, human body"),
    ("robots", "Underwater robots, ROVs, and submersibles"),
    ("sea-floor", "Rocks and bottom substrate on the sea floor"),
]
CATEGORIES = [category for category, _ in _CATEGORY_PROMPTS]
PROMPTS = [prompt for _, prompt in _CATEGORY_PROMPTS]

@torch.inference_mode()
def set_multi_text_prompts(processor, state, prompts: list, categories: list):
    """
    Run SAM3 grounding for several text prompts against one already-encoded image.

    The text encoder is called once on the whole prompt list (batched). Grounding then runs
    once per prompt, sequentially, each time on a fresh state dict that shallow-copies
    ``state["backbone_out"]``, so the caller's ``state`` is not mutated.

    Args:
        processor: processor whose ``set_image`` produced ``state``; must expose
            ``model.backbone.forward_text``, ``model._get_dummy_prompt``, ``device`` and
            ``_forward_grounding``.
        state: dict from ``processor.set_image``; must contain ``backbone_out``,
            ``original_height`` and ``original_width``.
        prompts: text prompts to ground.
        categories: label for each prompt (same length and order as ``prompts``).

    Returns:
        One dict per prompt, in prompt order, with keys ``category``, ``boxes``,
        ``scores`` and ``masks`` (the latter three taken from the grounding output).

    Raises:
        ValueError: if ``state`` has no ``backbone_out`` or the two lists differ in length.
    """
    if "backbone_out" not in state:
        raise ValueError("You must call set_image before set_multi_text_prompts")
    if len(prompts) != len(categories):
        # zip() would silently drop the unmatched tail, losing categories without any error.
        raise ValueError(
            f"prompts ({len(prompts)}) and categories ({len(categories)}) must have the same length"
        )
    if not prompts:
        return []

    # One batched text-encoder call for all prompts.
    text_outputs = processor.model.backbone.forward_text(prompts, device=processor.device)

    results = []
    for idx, category in enumerate(categories):
        # Fresh per-prompt state: shallow-copying backbone_out keeps this prompt's text
        # features out of the caller's dict and out of the other prompts' states.
        prompt_state = {
            "original_height": state["original_height"],
            "original_width": state["original_width"],
            "backbone_out": {**state["backbone_out"]},
            "geometric_prompt": processor.model._get_dummy_prompt(),
        }

        # Slice out this prompt's text features. NOTE(review): features/embeds are sliced on
        # dim 1 and the mask on dim 0, i.e. this assumes features are (seq, batch, dim) and the
        # mask is (batch, seq) - confirm against forward_text.
        prompt_state["backbone_out"].update({
            "language_features": text_outputs["language_features"][:, idx:idx + 1],
            "language_mask": text_outputs["language_mask"][idx:idx + 1],
            "language_embeds": text_outputs["language_embeds"][:, idx:idx + 1],
        })

        output = processor._forward_grounding(prompt_state)
        results.append({
            "category": category,
            "boxes": output["boxes"],
            "scores": output["scores"],
            "masks": output["masks"],
        })

    return results

def _mask_to_rle(mask):
    """Encode one boolean mask tensor as a JSON-friendly COCO RLE dict."""
    arr = mask.cpu().numpy()
    # The debug cell shows masks arriving as (1, H, W); pycocotools expects a 2-D (H, W) array.
    if arr.ndim == 3:
        arr = arr.squeeze(0)
    # pycocotools needs uint8 data in Fortran (column-major) order.
    rle = mask_util.encode(np.asfortranarray(arr.astype(np.uint8)))
    counts = rle['counts']
    # json.dump cannot serialize bytes, so store the RLE counts as text.
    if isinstance(counts, bytes):
        rle['counts'] = counts.decode('utf-8')
    return rle


def convert_multi_output_to_serializable(multi_outputs):
    """Convert multi-prompt tensor outputs to JSON-serializable format with RLE encoding.

    Args:
        multi_outputs: iterable of dicts with keys 'category', 'boxes', 'scores', 'masks'
            (tensors for the last three), as returned by set_multi_text_prompts.

    Returns:
        dict keyed by category, each value holding 'boxes' and 'scores' as nested lists and
        'masks' as a list of RLE dicts. A repeated category keeps only its last entry.
    """
    return {
        out['category']: {
            'boxes': out['boxes'].cpu().numpy().tolist(),
            'scores': out['scores'].cpu().numpy().tolist(),
            'masks': [_mask_to_rle(m) for m in out['masks']],
        }
        for out in multi_outputs
    }

import os  # kept local to this cell so it also works on a fresh kernel

TEST_DIR = 'USIS10K/test'
PREDS_PATH = 'sam3_usis10k_preds_rle.json'
CHECKPOINT_EVERY = 50  # images between intermediate dumps of `preds`


def _dump_preds_atomically(records, path):
    """Write `records` as JSON to `path` via a temp file + os.replace.

    A crash or interrupt mid-write then leaves the previous checkpoint intact
    instead of a truncated, unparseable JSON file.
    """
    tmp_path = path + '.tmp'
    with open(tmp_path, 'w') as f:
        json.dump(records, f)
    os.replace(tmp_path, path)


preds = []
# sorted(): os.listdir order is arbitrary, so sort for a reproducible prediction order.
imgs = sorted(os.listdir(TEST_DIR))
for i, img_name in enumerate(tqdm(imgs)):
    img_path = os.path.join(TEST_DIR, img_name)

    # Load image: force 3-channel RGB (grayscale/RGBA files would otherwise reach the
    # processor as-is) and close the file handle as soon as the pixels are copied.
    with Image.open(img_path) as raw_img:
        img = raw_img.convert('RGB')

    # Encode the image once, then ground every category prompt against it.
    inference_state = processor.set_image(img)
    multi_outputs = set_multi_text_prompts(processor, inference_state, PROMPTS, CATEGORIES)

    # Convert tensors to serializable format
    serializable_output = convert_multi_output_to_serializable(multi_outputs)
    serializable_output['img_path'] = img_path
    preds.append(serializable_output)

    if i % CHECKPOINT_EVERY == 0:
        _dump_preds_atomically(preds, PREDS_PATH)

_dump_preds_atomically(preds, PREDS_PATH)

print(f"Saved {len(preds)} predictions to {PREDS_PATH}")

100%|██████████| 1596/1596 [55:36<00:00,  2.09s/it]

Saved 1596 predictions to sam3_usis10k_preds_rle.json





In [None]:
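# NOTE(review): abandoned batched experiment kept for reference only. `data` is not defined in this
# notebook, and a single OR-joined prompt would lose the per-category labels produced above.
# import torch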
# from tqdm import tqdm

# preds = []
# batch_size = 8  # Adjust based on your GPU memory

# for i in tqdm(range(0, len(data), batch_size)):
#     batch = data[i:i + batch_size]
    
#     # Load all images in the batch
#     imgs = [Image.open(item['img_path']) for item in batch]
    
#     # Set image for the first one (if needed by processor)
#     inference_state = processor.set_image_batch(imgs)

#     outputs = processor.set_text_prompt(state=inference_state, prompt="fish or wrecks/ruins or reefs or aquatic plants or human divers or robots or sea-floor")
#     preds.extend(outputs)

# with open('sam3_usis10k_preds.json', 'w') as f:
#     json.dump(preds, f)