In [13]:
import os
import torch
from transformers import Sam3Model, Sam3Processor
from PIL import Image
import requests

# 1. Setup Device
# Prefer the Apple Metal (MPS) backend when PyTorch reports it as available;
# otherwise fall back to the CPU.
if torch.backends.mps.is_available():
    device = "mps"
    print("‚úÖ Using Apple Metal GPU (MPS)")
else:
    device = "cpu"
    print("‚ö†Ô∏è MPS not available. Using CPU")

# 2. Load Model from LOCAL FOLDER
# The path is relative to the notebook's working directory.
local_path = "./sam3_model"
print(f"üìÇ Loading SAM 3 Image Model from: '{local_path}'...")

try:
    # NOTE(review): trust_remote_code=True permits executing Python code shipped
    # with the checkpoint; only point local_path at a folder you trust.
    processor = Sam3Processor.from_pretrained(local_path, trust_remote_code=True)
    model = Sam3Model.from_pretrained(local_path, trust_remote_code=True).to(device)
    print("‚úÖ Model loaded successfully!")
except Exception as e:
    print(f"\n‚ùå LOAD ERROR: {e}")
    # Re-raise instead of calling exit(): inside a notebook exit() shuts the
    # kernel down and discards the traceback, whereas raising stops this cell
    # (and "Run All") while keeping the kernel and the full error visible.
    raise

‚úÖ Using Apple Metal GPU (MPS)
üìÇ Loading SAM 3 Image Model from: './sam3_model'...


Loading weights: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1468/1468 [00:01<00:00, 1430.22it/s, Materializing param=vision_encoder.neck.fpn_layers.3.proj2.weight]                       


‚úÖ Model loaded successfully!


In [14]:
import numpy as np

In [10]:
def segment_image(image_path, prompt, model, processor, device="mps",
                  threshold=0.5, mask_threshold=0.5, overlay_alpha=100):
    """
    Segments objects in an image based on a text prompt.

    The image and the prompt are preprocessed separately (image processor and
    tokenizer), merged into one input dict, moved to `device`, and run through
    the model without gradient tracking. All instance masks returned by the
    processor's post-processing are merged into a single mask, which is drawn
    as a semi-transparent red tint on top of the original image.

    Args:
        image_path (str): Path to the local image file.
        prompt (str): Text prompt (e.g., "cat").
        model: The loaded Sam3Model. It must already live on `device`.
        processor: The loaded Sam3Processor.
        device (str): "mps" for Mac or "cpu".
        threshold (float): Score threshold forwarded to
            `post_process_instance_segmentation`.
        mask_threshold (float): Mask threshold forwarded to
            `post_process_instance_segmentation`.
        overlay_alpha (int): Opacity of the red tint, 0 (invisible) to
            255 (solid red). Values outside that range are clamped.

    Returns:
        PIL.Image or None: An RGBA image with a red tint on the detected
        objects; the unmodified RGB image if nothing was detected; or None if
        the image could not be opened (the error is printed).
    """

    # 1. Load & Verify Image
    try:
        # The context manager closes the underlying file as soon as the pixel
        # data has been converted into an independent RGB copy.
        with Image.open(image_path) as opened:
            raw_image = opened.convert("RGB")
    except Exception as e:
        print(f"‚ùå Error opening image: {e}")
        return None

    # 2. Process Inputs (Manual Split for Stability)
    # A. Image
    image_inputs = processor.image_processor(
        raw_image,
        return_tensors="pt"
    )

    # B. Text (Strict max_length=32 for SAM 3)
    text_inputs = processor.tokenizer(
        [prompt],
        return_tensors="pt",
        padding="max_length",
        max_length=32,
        truncation=True
    )

    # C. Combine & Move to the target device
    inputs = dict(image_inputs)
    inputs.update(text_inputs)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # 3. Run Inference
    with torch.no_grad():
        outputs = model(**inputs)

    # 4. Post-Process results
    # Convert tensor size to list to avoid TypeError
    target_sizes = image_inputs["original_sizes"].tolist()

    results = processor.post_process_instance_segmentation(
        outputs,
        threshold=threshold,
        mask_threshold=mask_threshold,
        target_sizes=target_sizes
    )[0]

    # 5. Create Visualization
    if len(results['masks']) == 0:
        print(f"ü§∑ No objects found for prompt: '{prompt}'")
        return raw_image

    # Combine ALL found masks into one layer (in case there are 2 cats)
    all_masks_np = results['masks'].cpu().numpy()
    combined_mask = np.max(all_masks_np, axis=0)

    # Resize mask to match original image dimensions
    mask_image = Image.fromarray((combined_mask * 255).astype('uint8')).resize(raw_image.size)

    # Turn the 0..255 mask into the overlay's alpha channel: masked pixels get
    # `overlay_alpha`, background pixels stay fully transparent.
    overlay_alpha = max(0, min(255, int(overlay_alpha)))
    alpha_band = mask_image.point(lambda value: value * overlay_alpha // 255)

    red_layer = Image.new("RGBA", raw_image.size, (255, 0, 0, 0))
    red_layer.putalpha(alpha_band)

    # alpha_composite blends the red layer over the photo. Pasting an RGBA
    # layer through a mask would instead copy its pixels (alpha included) and
    # replace the image content inside the mask.
    return Image.alpha_composite(raw_image.convert("RGBA"), red_layer)

In [16]:
# Demo inputs for the cells below. The image path is relative to the
# notebook's working directory, so the file has to exist there.
my_prompt = "person"
my_image = "efficientsam3_arm/assets/persons.jpg"

In [17]:
import time
import torch
from PIL import Image
import numpy as np

def benchmark_inference(image_path, prompt, model, processor, runs=5):
    """
    Compares inference speed between MPS (GPU) and CPU.

    The image and prompt are preprocessed once on the CPU; the resulting
    tensors are then copied to each device under test. For every device the
    model is moved there, one untimed warm-up forward pass is executed, and
    `runs` forward passes are timed with `time.perf_counter`. On MPS the
    clock is only stopped after `torch.mps.synchronize()`.

    The model is moved between devices while benchmarking and is moved back
    to the device it started on before this function returns, even if a run
    fails. If MPS is not available only the CPU is timed and no speed-up is
    reported.

    Args:
        image_path (str): Path to the local image file.
        prompt (str): Text prompt passed to the tokenizer.
        model: The loaded model (a torch module).
        processor: The loaded processor providing `image_processor` and
            `tokenizer`.
        runs (int): Number of timed forward passes per device (>= 1).

    Returns:
        None. All results are printed.

    Raises:
        ValueError: If `runs` is smaller than 1.
    """
    if runs < 1:
        raise ValueError(f"runs must be at least 1, got {runs}")

    # 1. Prepare Data (Do this once)
    print(f"‚öôÔ∏è  Preparing input for '{prompt}'...")
    with Image.open(image_path) as opened:
        raw_image = opened.convert("RGB")

    # Process inputs (CPU side)
    image_inputs = processor.image_processor(raw_image, return_tensors="pt")
    text_inputs = processor.tokenizer(
        [prompt],
        return_tensors="pt",
        padding="max_length",
        max_length=32,
        truncation=True
    )

    # Combine inputs into a dictionary
    base_inputs = dict(image_inputs)
    base_inputs.update(text_inputs)

    # --- Helper Function to Run Timing ---
    def run_timing(device_name):
        """Time `runs` forward passes on `device_name`; return the mean in seconds."""
        print(f"\nüöÄ Testing on {device_name.upper()}...")

        # A. Move Model & Inputs to Device
        device = torch.device(device_name)
        model.to(device)
        inputs = {k: v.to(device) for k, v in base_inputs.items()}

        # B. Warm-up (Important!)
        # The first pass on a device is typically slower than the following
        # ones, so it is executed once here and excluded from the timings.
        print("   üî• Warming up...")
        with torch.no_grad():
            _ = model(**inputs)
            if device_name == "mps":
                torch.mps.synchronize() # Wait for GPU to finish

        # C. Measure Loop
        print(f"   ‚è±Ô∏è  Running {runs} loops...")
        times = []

        for i in range(runs):
            start = time.perf_counter()

            with torch.no_grad():
                _ = model(**inputs)

            # Crucial: GPU is async, we must wait for it to finish before stopping clock
            if device_name == "mps":
                torch.mps.synchronize()

            end = time.perf_counter()
            times.append(end - start)
            print(f"      Run {i+1}: {end - start:.4f}s")

        avg_time = sum(times) / len(times)
        print(f"   ‚úÖ {device_name.upper()} Average: {avg_time:.4f} seconds")
        return avg_time

    # Remember where the model lives so the benchmark does not leave it on the
    # CPU; later cells may still send their inputs to the original device.
    original_device = next(model.parameters()).device
    mps_time = None
    try:
        # 2. Run MPS Test (only when the backend exists on this machine)
        if torch.backends.mps.is_available():
            mps_time = run_timing("mps")
        else:
            print("\nMPS not available. Skipping the MPS run.")

        # 3. Run CPU Test
        cpu_time = run_timing("cpu")
    finally:
        model.to(original_device)

    # 4. Compare (needs both measurements)
    if mps_time is None:
        return

    speedup = cpu_time / mps_time
    print("\n" + "="*40)
    print(f"üèÜ RESULT: MPS is {speedup:.1f}x faster than CPU")
    print("="*40)

# --- EXECUTE ---
# Benchmark the image configured in `my_image` with the text prompt "persons".
benchmark_inference(
    image_path=my_image,
    prompt="persons",
    model=model,
    processor=processor,
    runs=5,
)

‚öôÔ∏è  Preparing input for 'persons'...

üöÄ Testing on MPS...
   üî• Warming up...
   ‚è±Ô∏è  Running 5 loops...
      Run 1: 11.8456s
      Run 2: 12.1344s
      Run 3: 10.6460s
      Run 4: 7.2120s
      Run 5: 4.7849s
   ‚úÖ MPS Average: 9.3246 seconds

üöÄ Testing on CPU...
   üî• Warming up...
   ‚è±Ô∏è  Running 5 loops...
      Run 1: 8.9002s
      Run 2: 7.6865s
      Run 3: 7.2332s
      Run 4: 7.1389s
      Run 5: 7.2409s
   ‚úÖ CPU Average: 7.6399 seconds

üèÜ RESULT: MPS is 0.8x faster than CPU


In [19]:
import time
import torch
from PIL import Image
import numpy as np

def benchmark_inference_fixed(image_path, prompt, model, processor, warmups=3, test_runs=10):
    """
    Compares steady-state inference speed between MPS (GPU) and CPU.

    The image and prompt are preprocessed once on the CPU and the tensors are
    copied to each device under test. For every device the model is moved
    there, `warmups` untimed forward passes are executed, and then
    `test_runs` forward passes are timed with `time.perf_counter`. On MPS the
    clock is only stopped after `torch.mps.synchronize()`. The average and
    the best time per device are printed.

    The model is moved between devices while benchmarking and is moved back
    to the device it started on before this function returns, even if a run
    fails. If MPS is not available only the CPU is timed and no speed-up is
    reported.

    Args:
        image_path (str): Path to the local image file.
        prompt (str): Text prompt passed to the tokenizer.
        model: The loaded model (a torch module).
        processor: The loaded processor providing `image_processor` and
            `tokenizer`.
        warmups (int): Number of untimed forward passes per device.
        test_runs (int): Number of timed forward passes per device (>= 1).

    Returns:
        None. All results are printed.

    Raises:
        ValueError: If `test_runs` is smaller than 1.
    """
    if test_runs < 1:
        raise ValueError(f"test_runs must be at least 1, got {test_runs}")

    # 1. Prepare Data
    print(f"‚öôÔ∏è  Preparing input for '{prompt}'...")
    with Image.open(image_path) as opened:
        raw_image = opened.convert("RGB")

    image_inputs = processor.image_processor(raw_image, return_tensors="pt")
    text_inputs = processor.tokenizer(
        [prompt],
        return_tensors="pt",
        padding="max_length",
        max_length=32,
        truncation=True
    )

    base_inputs = dict(image_inputs)
    base_inputs.update(text_inputs)

    def run_timing(device_name):
        """Warm up, then time `test_runs` passes on `device_name`; return the mean in seconds."""
        print(f"\nüöÄ Testing on {device_name.upper()}...")
        device = torch.device(device_name)
        model.to(device)
        inputs = {k: v.to(device) for k, v in base_inputs.items()}

        # --- AGGRESSIVE WARMUP ---
        # Several untimed passes so that one-off start-up cost on the device
        # does not leak into the measured runs.
        print(f"   üî• Warming up ({warmups} loops - ignored)...")
        for _ in range(warmups):
            with torch.no_grad():
                _ = model(**inputs)
                if device_name == "mps":
                    torch.mps.synchronize()

        # --- REAL TEST ---
        print(f"   ‚è±Ô∏è  Running {test_runs} loops (measuring)...")
        times = []

        for i in range(test_runs):
            start = time.perf_counter()
            with torch.no_grad():
                _ = model(**inputs)

            # MPS work is asynchronous: wait for it before stopping the clock.
            if device_name == "mps":
                torch.mps.synchronize()

            end = time.perf_counter()
            times.append(end - start)

        avg_time = sum(times) / len(times)
        best_time = min(times)

        print(f"   ‚úÖ {device_name.upper()} Average: {avg_time:.4f}s | Best: {best_time:.4f}s")
        return avg_time

    # Remember where the model lives so the benchmark does not leave it on the
    # CPU; later cells may still send their inputs to the original device.
    original_device = next(model.parameters()).device
    mps_time = None
    try:
        # Run Benchmark (MPS only when the backend exists on this machine)
        if torch.backends.mps.is_available():
            mps_time = run_timing("mps")
        else:
            print("\nMPS not available. Skipping the MPS run.")
        cpu_time = run_timing("cpu")
    finally:
        model.to(original_device)

    # The comparison needs both measurements.
    if mps_time is None:
        return

    speedup = cpu_time / mps_time
    print("\n" + "="*40)
    print(f"üèÜ STEADY STATE RESULT: MPS is {speedup:.1f}x faster")
    print("="*40)

# Use extra warm-up passes and more timed passes so the reported numbers
# reflect steady-state behaviour.
benchmark_inference_fixed(
    image_path=my_image,
    prompt="persons",
    model=model,
    processor=processor,
    warmups=5,
    test_runs=10,
)

‚öôÔ∏è  Preparing input for 'persons'...

üöÄ Testing on MPS...
   üî• Warming up (5 loops - ignored)...
   ‚è±Ô∏è  Running 10 loops (measuring)...
   ‚úÖ MPS Average: 4.9935s | Best: 4.7233s

üöÄ Testing on CPU...
   üî• Warming up (5 loops - ignored)...
   ‚è±Ô∏è  Running 10 loops (measuring)...
   ‚úÖ CPU Average: 7.7520s | Best: 7.1564s

üèÜ STEADY STATE RESULT: MPS is 1.6x faster
