# SageMaker JumpStart - deploy VS model

This notebook demonstrates how to use the SageMaker Python SDK to deploy a SageMaker JumpStart VS model and invoke the endpoint.

The model used in this notebook is Meta's SAM 2.1 model, which performs segmentation over images and video. This notebook demonstrates the video segmentation use case. We use both point and box prompts in various contexts to infer segmentation masks over a sample video. Single-object prompts, multiple-object prompts, and combinations of points and a box are supported, as is the ability to correct masks and re-infer them.

NOTE: Masks are generated frame by frame, and the inference for each frame uses the inferences on the surrounding frames. As a result, adding points after propagating the prompts across the video only corrects the already inferred masks. To re-mask from scratch, start a new inference session.

---

This notebook's CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook.

![This us-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-2/generative_ai|sm-jumpstart_foundation_sam_2_1_video_segmentation.ipynb)

---

In [None]:
!pip install opencv-python-headless

In [None]:
!pip install -U sagemaker

In [None]:
%conda install -y ffmpeg

In [None]:
from sagemaker.jumpstart.model import JumpStartModel

Select your desired model ID. You can search for available models in the [Built-in Algorithms with pre-trained Model Table](https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html).

In [None]:
model_id = "meta-vs-sam-2-1-hiera-small"

## Deploy model

Using the model ID, define your model as a JumpStart model. You can deploy the model on other instance types by passing `instance_type` to `JumpStartModel`. See [Deploy publicly available foundation models with the JumpStartModel class](https://docs.aws.amazon.com/sagemaker/latest/dg/jumpstart-foundation-models-use-python-sdk.html#jumpstart-foundation-models-use-python-sdk-model-class) for more configuration options.

In [None]:
model = JumpStartModel(model_id=model_id)

You can now deploy your JumpStart model. The deployment might take a few minutes.

In [None]:
predictor = model.deploy()

# SAM2 Endpoint Testing

The rest of this notebook tests the SAM 2.1 endpoint functionality using modular helper classes.

In [None]:
import base64
import json
import logging
import os
import sys
import time
import zlib
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Union

import boto3
import cv2
import numpy as np
from IPython.display import HTML, Video, display

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)

logger = logging.getLogger(__name__)

### Stream Parser

Because SAM 2.1 returns masks for every tracked object in every frame, the response for a whole video can be large enough to run into SageMaker's payload size limits. To avoid this, the endpoint uses SageMaker response streaming to send the payload back in chunks.
The endpoint emits JSON Lines items, so the `StreamParser` class below buffers the incoming chunks and extracts each complete JSON object as soon as it has fully arrived.
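The `StreamParser` below scans its buffer for complete JSON objects wherever they start. If every record from the endpoint is terminated by a newline, as the JSON Lines format requires, a simpler line-based parser is enough, and it never has to guess whether an object is complete. The sketch below assumes that newline framing; `JsonLinesBuffer` is a hypothetical name (not part of the SageMaker SDK), and it reads the same `{"PayloadPart": {"Bytes": ...}}` event shape that `StreamParser.write` reads.

```python
import json
from typing import Dict, List


class JsonLinesBuffer:
    """Minimal newline-delimited JSON parser for streamed byte chunks.

    Assumes each record is terminated by a newline (JSON Lines framing).
    """

    def __init__(self) -> None:
        self._pending = bytearray()
        self.records: List[Dict] = []

    def write(self, event: dict) -> None:
        # Same event shape that StreamParser.write reads
        payload = event.get("PayloadPart", {}).get("Bytes")
        if not payload:
            return
        self._pending.extend(payload)
        # Everything before the last newline is complete; the tail may still be arriving
        *complete, rest = self._pending.split(b"\n")
        self._pending = bytearray(rest)
        for line in complete:
            line = line.strip()
            if line:
                self.records.append(json.loads(line))


# Records split at arbitrary byte boundaries, as a stream may deliver them
buf = JsonLinesBuffer()
for chunk in [b'{"frame_idx": 0, "obj', b'_ids": [1]}\n{"frame_idx": 1,', b' "obj_ids": [1]}\n']:
    buf.write({"PayloadPart": {"Bytes": chunk}})
print([r["frame_idx"] for r in buf.records])  # [0, 1]
```

Splitting on the newline byte is safe for UTF-8 input, because that byte never appears inside a multi-byte character, so no partial-character handling is needed.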

In [None]:
class StreamParser:
    """
    Parses streaming JSON responses with nested structure support.
    Maintains a buffer (max 10MB) and accumulates valid JSON dictionaries.
    """
    def __init__(self, max_size: int = 10 * 1024 * 1024):
        self.buffer = bytearray()
        self.parsed_responses: List[Dict] = []
        self.max_size = max_size
        self.decoder = json.JSONDecoder()

    def write(self, event: dict) -> None:
        """Processes streaming event and extracts JSON objects."""
        try:
            if payload := event.get("PayloadPart", {}).get("Bytes"):
                self.buffer.extend(payload)
                self._process_buffer()
        except Exception as e:
            logger.error(f"Processing error: {e}")

    def _process_buffer(self) -> None:
        """Extracts and parses complete JSON objects from buffer."""
        if len(self.buffer) > self.max_size:
            logger.warning("Buffer exceeded max size, clearing")
            self.buffer.clear()
            return

        try:
            index = 0
            buffer_str = self.buffer.decode('utf-8')
            buffer_len = len(buffer_str)
            complete_objects = 0
            incomplete_objects = 0

            while index < buffer_len:
                try:
                    # Look for start of JSON object
                    while index < buffer_len and buffer_str[index] not in '{[':
                        index += 1
                    
                    if index >= buffer_len:
                        break

                    # Try to parse JSON object
                    try:
                        result, end = self.decoder.raw_decode(buffer_str[index:])
                        if isinstance(result, dict):                            
                            self.parsed_responses.append(result)
                            complete_objects += 1
                        index += end
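                    except json.JSONDecodeError:
                        # A decode failure at an object boundary means the object is still
                        # arriving. Wait for more bytes instead of stepping inside it, which
                        # could mis-parse a nested dict (e.g. a mask) as a top-level response.
                        # If the stream were malformed, the max_size guard eventually clears it.
                        incomplete_objects += 1
                        break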

                except Exception as e:
                    logger.error(f"Error processing buffer at position {index}: {e}")
                    index += 1

            # Remove processed data if we found complete objects
            if complete_objects > 0:
                processed_bytes = len(buffer_str[:index].encode('utf-8'))
                del self.buffer[:processed_bytes]

        except UnicodeDecodeError as e:
            logger.warning(f"Unicode decode error: {e}. Buffer might contain incomplete UTF-8 sequences.")

    def get_responses(self) -> List[Dict]:
        """Returns accumulated parsed JSON responses."""
        return self.parsed_responses

### Utility Functions

Computer vision tasks require transforming videos into formats that are easy to transmit and visualize. The following functions handle those intermediate steps.

Functions:
* `encode_video` (Mandatory)
  * Encodes the binary data of a video as base64 so that the endpoint can decode it and run inference on it.
* `decompress_mask` (Mandatory)
  * The endpoint compresses the masks before streaming them back, so this function restores each mask to its original boolean array. The expected shape is a single channel with the same height and width as the input frame. The endpoint compresses the NumPy array with zlib and then base64-encodes the result; decompression reverses this by base64-decoding and then zlib-decompressing the data. The resulting bytes are loaded into a NumPy array with `np.frombuffer` and converted to booleans.
* `save_visualization` (Customizable)
  * Uses OpenCV to overlay the masks on each frame, then stitches the frames back together into a new video with the masks applied. It then calls `convert_to_web_compatible_format`, which re-encodes the video with FFmpeg so that it plays in web browsers.
  * NOTE: this function, as is, requires FFmpeg and OpenCV.
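To make the mask format concrete, the sketch below builds a payload with the fields `decompress_mask` reads (`counts`, `shape`, `dtype`, `compression`) and decodes it again. The endpoint's encoder is not shown in this notebook, so `compress_mask` is a hypothetical mirror of the decoding steps, assuming the endpoint zlib-compresses the raw array bytes and then base64-encodes them. `decode_mask` repeats the core of `decompress_mask` so that the block runs on its own.

```python
import base64
import zlib

import numpy as np


def compress_mask(mask: np.ndarray) -> dict:
    """Hypothetical encoder: raw array bytes -> zlib -> base64 text."""
    return {
        "counts": base64.b64encode(zlib.compress(mask.tobytes())).decode("ascii"),
        "shape": list(mask.shape),
        "dtype": str(mask.dtype),
        "compression": "zlib_base64",
    }


def decode_mask(payload: dict) -> np.ndarray:
    """Reverse the steps: base64 text -> zlib -> array, thresholded to bool."""
    raw = zlib.decompress(base64.b64decode(payload["counts"]))
    array = np.frombuffer(raw, dtype=np.dtype(payload["dtype"]))
    return array.reshape(payload["shape"]) > 0.0


# A 1-channel mask with the same height and width as a (tiny) 4x6 frame
mask = np.zeros((1, 4, 6), dtype=np.float32)
mask[0, 1:3, 2:5] = 1.0

payload = compress_mask(mask)
restored = decode_mask(payload)
print(restored.shape, restored.dtype, int(restored.sum()))  # (1, 4, 6) bool 6
```

Because masks are mostly large uniform regions, zlib shrinks them well; base64 then adds roughly a third back so that the bytes can travel inside JSON text.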

In [None]:
def encode_video(video_path: str) -> str:
    """Encode video to base64."""
    try:
        with open(video_path, 'rb') as f:
            return base64.b64encode(f.read()).decode('utf-8')
    except Exception as e:
        logger.error(f"Failed to encode video {video_path}: {str(e)}")
        raise

def decompress_mask(mask_data: Union[Dict, np.ndarray]) -> Optional[np.ndarray]:
    """Decompress mask from zlib+base64 format."""
    if isinstance(mask_data, np.ndarray):
        return mask_data.astype(bool)
    
    if not isinstance(mask_data, dict):
        logger.warning(f"Invalid mask data type: {type(mask_data)}")
        return None

    if mask_data.get("compression") != "zlib_base64":
        logger.warning(f"Mask data is not in compressed format: {mask_data.keys()}")
        return None

    counts, shape, dtype = (mask_data.get(key) for key in ("counts", "shape", "dtype"))
    if not all((counts, shape, dtype)):
        logger.warning(f"Missing required mask data fields: counts={bool(counts)}, shape={bool(shape)}, dtype={bool(dtype)}")
        return None

    try:
        decompressed = zlib.decompress(base64.b64decode(counts))
        array = np.frombuffer(decompressed, dtype=np.dtype(dtype))
        return array.reshape(shape) > 0.0
    except Exception as e:
        logger.error(f"Failed to decompress mask: {e}")
        return None

def convert_to_web_compatible_format(output_path: str):
    """Re-encode a video with FFmpeg (H.264) so browsers can play it, replacing the file in place."""
    import shutil
    import subprocess

    try:
        output_web_path = str(output_path).replace('.mp4', '_web.mp4')

        # Use ffmpeg to convert to a web-compatible format
        cmd = [
            'ffmpeg', '-i', str(output_path),
            '-vcodec', 'libx264',
            '-pix_fmt', 'yuv420p',  # the H.264 pixel format browsers broadly support
            '-acodec', 'aac',
            '-y',  # Overwrite output file if it exists
            output_web_path
        ]

        subprocess.run(cmd, check=True)

        # Replace original file with converted file
        shutil.move(output_web_path, output_path)

    except Exception as e:
        logger.error(f"Failed to convert video to web-compatible format: {e}")
        logger.warning("The video file may not play in web browsers")
        
def save_visualization(video_path, results, output_dir, debug_dir=None):
    """Save visualization of masks overlaid on video frames."""
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Convert list results to dictionary if needed
    if isinstance(results, list):
        results_dict = {}
        for frame_data in results:
            if "frame_idx" in frame_data:
                results_dict[str(frame_data["frame_idx"])] = {
                    "masks": frame_data.get("masks", []),
                    "obj_ids": frame_data.get("obj_ids", [])
                }
        results = results_dict

    # Generate colors for objects
    all_obj_ids = set()
    for frame_data in results.values():
        if "obj_ids" in frame_data:
            all_obj_ids.update(frame_data["obj_ids"])
    object_colors = {
        obj_id: cv2.cvtColor(np.uint8([[[i / len(all_obj_ids) * 180, 204, 255]]]), cv2.COLOR_HSV2RGB)[0][0] / 255.0 
        for i, obj_id in enumerate(sorted(all_obj_ids))
    }

    # Open video and get properties
    cap = cv2.VideoCapture(str(video_path))
    fps, width, height, total_frames = [cap.get(prop) for prop in [
        cv2.CAP_PROP_FPS, cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FRAME_COUNT
    ]]
    width, height, total_frames = map(int, [width, height, total_frames])

    # Set up video writer
    output_path = output_dir / "visualization.mp4"
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(str(output_path), fourcc, fps, (width, height))

    try:
        for frame_idx in range(total_frames):
            ret, frame = cap.read()
            if not ret:
                break

            frame_data = results.get(str(frame_idx), {"masks": [], "obj_ids": []})
            masks_data = frame_data["masks"]
            obj_ids = frame_data["obj_ids"]

            if len(obj_ids) > 0:
                processed_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

                # Decompress masks: either one compressed payload or a list of them
                decompressed_masks = None
                if isinstance(masks_data, dict):
                    masks_data = [masks_data]
                if isinstance(masks_data, list):
                    masks = [decompress_mask(mask_data) for mask_data in masks_data]
                    if masks and all(m is not None for m in masks):
                        # Flatten each payload to a stack of (H, W) masks before joining them
                        decompressed_masks = np.concatenate(
                            [m.reshape(-1, *m.shape[-2:]) for m in masks], axis=0
                        )

                # Apply masks if decompression was successful
                if decompressed_masks is not None:
                    for obj_id, mask in zip(obj_ids, decompressed_masks):
                        color = object_colors[obj_id]
                        # Apply mask with alpha blending
                        mask = np.squeeze(mask) > 0
                        mask_image = np.zeros((*processed_frame.shape[:2], 4), dtype=np.float32)
                        mask_image[mask] = [*color, 0.6]
                        alpha = np.repeat(mask_image[..., 3:4], 3, axis=2)
                        processed_frame = ((processed_frame.astype(np.float32) / 255.0 * (1 - alpha) + mask_image[..., :3] * alpha) * 255).astype(np.uint8)

                # Write every frame, even one whose masks failed to decode, so the output
                # keeps the same number of frames (and timing) as the input
                out.write(cv2.cvtColor(processed_frame, cv2.COLOR_RGB2BGR))
            else:
                out.write(frame)
    finally:
        cap.release()
        out.release()
    convert_to_web_compatible_format(output_path)

### SageMaker Endpoint Interaction Abstraction

The following functions are optional helpers for interacting with the SageMaker endpoint: they wrap the repetitive code needed to send both streaming and non-streaming requests. The endpoint is built around sticky (stateful) sessions, so what each request does depends on the action given in its `type` field and on the session it is sent to.

NOTE: Requests that start and end a session should be non-streaming requests. When starting a new session, the `X-Amzn-SageMaker-Session-Id` header, or the `SessionId` parameter of `invoke_endpoint` / `invoke_endpoint_with_response_stream` in the SageMaker Runtime client, must be set to `NEW_SESSION`. This [Blog Post](https://aws.amazon.com/blogs/machine-learning/build-ultra-low-latency-multimodal-generative-ai-applications-using-sticky-session-routing-in-amazon/) discusses SageMaker sticky sessions.
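The routing rules in the note above can be summarized as a small pure function: which runtime client method a request type goes through, and which `SessionId` value to send. This is a sketch of those rules, with `close_session` (the request type `VideoPredictor._end_session` sends) treated as the non-streaming end-of-session call; `route_request` and `parse_new_session_header` are hypothetical helpers, and the `"<id>; Expires=<time>"` header layout is taken from `_get_session_id`.

```python
from typing import Optional, Tuple

# Request types sent with invoke_endpoint; everything else is streamed
NON_STREAMING_TYPES = {"start_session", "close_session"}


def route_request(request_type: str, session_id: Optional[str]) -> Tuple[str, str]:
    """Return (runtime client method, SessionId value) for a request type."""
    if request_type == "start_session":
        # New sticky sessions are requested with the literal value NEW_SESSION
        session_id = "NEW_SESSION"
    elif not session_id:
        raise ValueError(f"{request_type} needs the session ID returned by start_session")
    method = (
        "invoke_endpoint"
        if request_type in NON_STREAMING_TYPES
        else "invoke_endpoint_with_response_stream"
    )
    return method, session_id


def parse_new_session_header(value: str) -> Tuple[str, Optional[str]]:
    """Split an x-amzn-sagemaker-new-session-id value into (session ID, expiry)."""
    session_id, sep, expiry = value.partition("; Expires=")
    return session_id, (expiry if sep else None)


# Typical lifecycle: start, prompt, propagate, close
print(route_request("start_session", None))  # ('invoke_endpoint', 'NEW_SESSION')
# Made-up header value, only to show the "<id>; Expires=<time>" layout
sid, expiry = parse_new_session_header("abc123; Expires=2025-01-01T00:00:00Z")
for request_type in ["add_points", "propagate_in_video", "close_session"]:
    print(request_type, route_request(request_type, sid))
```

Every request after `start_session` must reuse the returned session ID; that is what routes it to the instance holding the session's state.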

In [None]:
def send_and_check_request(runtime_client, endpoint_name, request, session_id=None):
    """Send request to endpoint and check response."""
    request_type = request.get("type", "unknown")
    logger.info(f"Sending {request_type} request to endpoint {endpoint_name}")
    try:
        session_id, response = _handle_request(runtime_client, endpoint_name, request, session_id)
        return response, session_id
    except Exception as e:
        _handle_error(e, request, request_type, session_id)
        raise

def _handle_request(runtime_client, endpoint_name, request, session_id):
    request_type = request.get("type", "unknown")
    
    if request_type == "start_session":
        session_id = "NEW_SESSION"
    elif request_type == "add_points" and "clear_old_points" not in request:
        request["clear_old_points"] = False

    # close_session is the request type VideoPredictor._end_session sends
    if request_type in ["start_session", "close_session", "end_session"]:
        return _handle_non_streaming_request(runtime_client, endpoint_name, request, session_id)
    else:
        return _handle_streaming_request(runtime_client, endpoint_name, request, session_id)

def _handle_streaming_request(runtime_client, endpoint_name, request, session_id):
    response = runtime_client.invoke_endpoint_with_response_stream(
        EndpointName=endpoint_name,
        ContentType='application/json',
        Body=json.dumps(request),
        SessionId=session_id,
        Accept="application/jsonlines"
    )
    
    parser = StreamParser()
    for event in response['Body']:
        parser.write(event)
    
    parsed_responses = parser.get_responses()
    logger.info(f"Total responses received: {len(parsed_responses)}")
    
    headers = response['ResponseMetadata']['HTTPHeaders']
    session_id = _get_session_id(headers, request.get("type"), session_id)
    
    return session_id, parsed_responses

def _handle_non_streaming_request(runtime_client, endpoint_name, request, session_id):
    response = runtime_client.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType='application/json',
        Body=json.dumps(request),
        SessionId=session_id
    )
    
    response_body = json.loads(response['Body'].read())
    headers = response['ResponseMetadata']['HTTPHeaders']
    logger.info(f"Response headers: {headers}")
    
    session_id = _get_session_id(headers, request.get("type"), session_id)
    
    return session_id, response_body

def _get_session_id(headers, request_type, current_session_id):
    if request_type == "start_session":
        session_id = headers.get('x-amzn-sagemaker-new-session-id', '')
        if session_id and '; Expires=' in session_id:
            session_id, expiry = session_id.split('; Expires=')
            logger.info(f"New session created: {session_id} (expires: {expiry})")
    else:
        session_id = headers.get('x-amzn-sagemaker-session-id', current_session_id)
        if headers.get('x-amzn-sagemaker-closed-session-id'):
            logger.info(f"Session closed: {session_id}")
    
    if not session_id and request_type == "start_session":
        raise RuntimeError("No session ID returned for new session request")
    
    return session_id

def _handle_error(error, request, request_type, session_id):
    logger.error(f"Request failed: {str(error)}")
    logger.error(f"Request details: type={request_type}, session_id={session_id}")

### SAM2.1 Video Predictor

### Session Management

#### 1. Start Session (`start_session`)
- **Purpose**: Initializes a new inference session
- **Parameters**:
  - `path`: The input image/video; this notebook sends the base64-encoded file contents (see `encode_video`)
  - `input_type`: 'image' or 'video'
  - `session_id`: Optional custom identifier
- **Returns**:
  - Session ID
  - Success status
  - Image/video dimensions

#### 2. Close Session (`close_session`)
- **Purpose**: Terminates an active session
- **Parameter**:
  - `session_id`: Session to close
- **Returns**:
  - Success status

### Video Segmentation Controls

#### 3. Add Points (`add_points`)
- **Purpose**: Adds point prompts for video segmentation
- **Parameters**:
  - `frame_index`: Frame number
  - `object_id`: Object tracking identifier
  - `points`: Point coordinates
  - `labels`: Point labels
  - `clear_old_points`: Whether to clear existing points
- **Returns**:
  - Frame index
  - Object IDs list
  - Predicted masks

#### 4. Add Box (`add_box`)
- **Purpose**: Adds box prompt with optional points
- **Parameters**:
  - `frame_index`: Frame number
  - `object_id`: Object tracking identifier
  - `box`: Box coordinates [x1,y1,x2,y2]
  - `points`: Optional point coordinates
  - `labels`: Optional point labels
- **Returns**:
  - Frame index
  - Object IDs list
  - Predicted masks
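The prompt requests map directly onto these parameters. The sketch below builds `add_points` and `add_box` bodies from the `{"type": "point", ...}` prompt dictionaries used later in this notebook, the same way `VideoPredictor._add_points` and `_add_box` do. The helper names are hypothetical, and the corner sorting in `build_add_box_request` is an added assumption (that the endpoint expects `x1 <= x2` and `y1 <= y2`), not documented endpoint behavior.

```python
from typing import Dict, List, Optional


def split_point_prompts(prompts: List[Dict]):
    """Separate point prompts into parallel coordinate and label lists."""
    points = [p["coordinates"] for p in prompts if p["type"] == "point"]
    labels = [p["label"] for p in prompts if p["type"] == "point"]  # 1 = include, 0 = exclude
    return points, labels


def build_add_points_request(session_id, prompts, object_id=1, frame_index=0, clear_old_points=True) -> Dict:
    points, labels = split_point_prompts(prompts)
    return {
        "type": "add_points",
        "session_id": session_id,
        "frame_index": frame_index,
        "object_id": object_id,
        "points": points,
        "labels": labels,
        "clear_old_points": clear_old_points,
    }


def build_add_box_request(session_id, box, prompts: Optional[List[Dict]] = None, object_id=1, frame_index=0) -> Dict:
    # Assumption: the box is [x1, y1, x2, y2] with (x1, y1) the top-left corner;
    # sorting keeps a box whose corners were given in reverse order valid
    x1, x2 = sorted((box[0], box[2]))
    y1, y2 = sorted((box[1], box[3]))
    points, labels = split_point_prompts(prompts or [])
    return {
        "type": "add_box",
        "session_id": session_id,
        "frame_index": frame_index,
        "object_id": object_id,
        "box": [x1, y1, x2, y2],
        "points": points,
        "labels": labels,
    }


# Box given bottom-right first, plus one negative point
req = build_add_box_request(
    "abc123", [1531, 872, 1392, 562],
    [{"type": "point", "coordinates": [1433, 689], "label": 0}],
)
print(req["box"], req["points"], req["labels"])  # [1392, 562, 1531, 872] [[1433, 689]] [0]
```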

### State Management

#### 5. Clear Points in Frame (`clear_points_in_frame`)
- **Purpose**: Clears prompts for specific frame
- **Parameters**:
  - `frame_index`: Frame number
  - `object_id`: Object tracking identifier
- **Returns**:
  - Success status

#### 6. Clear Points in Video (`clear_points_in_video`)
- **Purpose**: Clears all prompts in video
- **Parameters**: None
- **Returns**:
  - Success status

#### 7. Propagate in Video (`propagate_in_video`)
- **Purpose**: Propagates masks through video frames
- **Parameter**:
  - `start_frame_index`: Starting frame number
- **Returns**:
  - Frame index
  - Object IDs list
  - Predicted masks


**NOTE**: Add points and boxes before propagating to get full masks. Points added after propagation only edit the masks that were already generated, and those edits may be minor. If the focus of the prompt needs to change, end the session and start a new one.

In [None]:
class VideoPredictor:
    def __init__(self, runtime_client, endpoint_name, video_path):
        self.runtime_client = runtime_client
        self.endpoint_name = endpoint_name
        self.video_path = video_path

    def save_prompt_viz(self, frame_index: int, test_type: str, points=None, box=None, response=None):
        """Save visualization of prompts and masks on a video frame.
        
        Args:
            frame_index: Frame index to visualize
            test_type: Type of test being run (e.g. "single_object", "multiple_objects")
            points: List of point dictionaries with coordinates and labels
            box: List of box coordinates [x1,y1,x2,y2]
            response: Response from endpoint containing masks
        """
        # Read the requested frame
        cap = cv2.VideoCapture(self.video_path)
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
        ret, frame = cap.read()
        cap.release()
        if not ret:
            logger.warning(f"Could not read frame {frame_index} from {self.video_path}")
            return

        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        # Draw points if provided
        if points:
            for point in points:
                x, y = map(int, point["coordinates"])
                color = (0, 255, 0) if point["label"] == 1 else (255, 0, 0)  # Green for positive, red for negative
                cv2.circle(frame_rgb, (x, y), 5, color, -1)  # Filled circle
                cv2.circle(frame_rgb, (x, y), 7, color, 2)   # Border

        # Draw box if provided
        if box:
            x1, y1, x2, y2 = map(int, box)
            cv2.rectangle(frame_rgb, (x1, y1), (x2, y2), (255, 128, 0), 2)  # Orange box

        # Streaming requests return a list of parsed JSON Lines items; use the first one with masks
        if isinstance(response, list):
            response = next((r for r in response if isinstance(r, dict) and "masks" in r), None)

        # Draw masks if available in response
        if response and "masks" in response:
            masks = response["masks"]
            if isinstance(masks, dict):
                masks = [masks]
            if isinstance(masks, list):
                for mask_data in masks:
                    mask = decompress_mask(mask_data)
                    if mask is None:
                        continue
                    # Treat the mask as a stack of (H, W) object masks, whatever its leading dims
                    for obj_mask in mask.reshape(-1, *mask.shape[-2:]):
                        mask_overlay = np.zeros_like(frame_rgb)
                        mask_overlay[obj_mask] = (0, 0, 255)  # Blue for mask
                        frame_rgb = cv2.addWeighted(frame_rgb, 1, mask_overlay, 0.5, 0)
        
        # Save in test-specific directory without timestamp
        output_dir = os.path.join("outputs", "video_predictor", f"{test_type}")
        os.makedirs(output_dir, exist_ok=True)
        
        # Save frame with descriptive name
        frame_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)
        cv2.imwrite(os.path.join(output_dir, f"prompt_frame_{frame_index:04d}_{datetime.now().timestamp()}.jpg"), frame_bgr)

    def _start_session(self):
        start_request = {
            "type": "start_session",
            "input_type": "video",
            "path": encode_video(self.video_path),
        }
        _, session_id = send_and_check_request(
            self.runtime_client, self.endpoint_name, start_request
        )
        return session_id

    def _end_session(self, session_id):
        close_request = {"type": "close_session", "session_id": session_id}
        send_and_check_request(
            self.runtime_client, self.endpoint_name, close_request, session_id
        )

    def _propagate_video(self, session_id, test_type: str, start_frame_index=0):
        """Propagate tracking in video and return results."""
        propagate_request = {
            "type": "propagate_in_video",
            "session_id": session_id,
            "start_frame_index": start_frame_index,
        }

        # Time video propagation
        prop_start_time = datetime.now()
        response, _ = send_and_check_request(
            self.runtime_client, self.endpoint_name, propagate_request, session_id
        )

        results, frame_count = response, len(response)
        prop_duration = (datetime.now() - prop_start_time).total_seconds()
        logger.info(f"Video propagation completed: {frame_count} frames in {prop_duration:.2f} seconds")

        # Time visualization saving
        vis_start_time = datetime.now()
        output_dir = os.path.join("outputs", "video_predictor", f"{test_type}")
        save_visualization(self.video_path, results, output_dir)
        vis_duration = (datetime.now() - vis_start_time).total_seconds()
        logger.info(f"Visualization saving completed in {vis_duration:.2f} seconds")

        return results

    def _add_points(self, session_id, prompts, test_type: str, object_id=1, frame_index=0, clear_old_points=True):
        """Add points for video segmentation."""
        # Extract points and labels
        points = []
        labels = []
        for prompt in prompts:
            if prompt["type"] == "point":
                points.append(prompt["coordinates"])
                labels.append(prompt["label"])

        if points:
            request = {
                "type": "add_points",
                "session_id": session_id,
                "frame_index": frame_index,
                "object_id": object_id,
                "points": points,
                "labels": labels,
                "clear_old_points": clear_old_points,
            }
            response, _ = send_and_check_request(self.runtime_client, self.endpoint_name, request, session_id)
            self.save_prompt_viz(frame_index, test_type, points=prompts, response=response)

    def _add_box(self, session_id, box_prompt, points, test_type: str, object_id=1, frame_index=0):
        """Add box for video segmentation, optionally with points."""
        # Extract points and labels if provided
        point_coords = []
        point_labels = []
        if points:
            for point in points:
                if point["type"] == "point":
                    point_coords.append(point["coordinates"])
                    point_labels.append(point["label"])

        request = {
            "type": "add_box",
            "session_id": session_id,
            "frame_index": frame_index,
            "object_id": object_id,
            "box": box_prompt["coordinates"],
            "points": point_coords,
            "labels": point_labels
        }
        response, _ = send_and_check_request(self.runtime_client, self.endpoint_name, request, session_id)
        self.save_prompt_viz(frame_index, test_type, points=points, box=box_prompt["coordinates"], response=response)

    def _add_refinement_points(self, session_id, points, test_type: str, object_id=1, frame_index=0):
        """Add refinement points to tweak the segmentation after initial propagation.
        These points help refine the object's segmentation without redefining what object
        is being tracked. They can be added on any frame and will attempt to maintain consistency in the object definition."""
        self._add_points(session_id, points, test_type, object_id, frame_index, clear_old_points=False)

## Video Predictor Tests

Test video segmentation functionality with various tracking scenarios.

In [None]:
s3 = boto3.client('s3')
region = boto3.Session().region_name

Download a sample video from the JumpStart assets bucket.

In [None]:
s3_bucket = f"jumpstart-cache-prod-{region}"
key_prefix = "inference-notebook-assets"

def download_from_s3(key_filenames):
    for key_filename in key_filenames:
        s3.download_file(s3_bucket, f"{key_prefix}/{key_filename}", key_filename)

basketball_layup_mp4 = "basketball-layup.mp4"

# Download the sample video.
download_from_s3(key_filenames=[basketball_layup_mp4])

Video(filename=basketball_layup_mp4, width=500)

In [None]:
os.makedirs("outputs/video_predictor", exist_ok=True)
runtime_client = boto3.client('sagemaker-runtime')
endpoint_name = predictor.endpoint_name
video_predictor = VideoPredictor(runtime_client, endpoint_name, basketball_layup_mp4)

In [None]:
"""Test tracking single object."""
logger.info("\n=== Testing Single Object Tracking ===")
test_start_time = datetime.now()
try:
    # Start session
    session_id = video_predictor._start_session()

    # Add initial points
    points = [
        {"type": "point", "coordinates": [1478, 649], "label": 1},
        {"type": "point", "coordinates": [1433, 689], "label": 0},
    ]
    video_predictor._add_points(session_id, points, "single_object")

    # Propagate tracking and get results
    results = video_predictor._propagate_video(session_id, "single_object")

    # End session
    video_predictor._end_session(session_id)
    test_duration = (datetime.now() - test_start_time).total_seconds()
    logger.info(f"Single object tracking test completed successfully in {test_duration:.2f} seconds")
except Exception as e:
    logger.error(f"Single object tracking test failed: {str(e)}")
    raise

Video(filename="outputs/video_predictor/single_object/visualization.mp4", width=500)
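Beyond the rendered video, you can inspect the `results` list returned by `_propagate_video` directly. Each item is expected to carry `frame_idx`, `obj_ids`, and `masks`, the keys `save_visualization` reads. The hypothetical helpers below group frame indices by object ID so that tracking gaps are easy to spot. They do not decode the masks, so an object listed in `obj_ids` with an empty mask still counts as present.

```python
from collections import defaultdict
from typing import Dict, List


def frames_per_object(results: List[Dict]) -> Dict[int, List[int]]:
    """Map each object ID to the sorted frame indices whose obj_ids list it."""
    frames = defaultdict(set)
    for item in results:
        if "frame_idx" not in item:
            continue  # skip anything that is not a per-frame result
        for obj_id in item.get("obj_ids", []):
            frames[obj_id].add(int(item["frame_idx"]))
    return {obj_id: sorted(idx) for obj_id, idx in sorted(frames.items())}


def find_gaps(frame_indices: List[int]) -> List[int]:
    """Frame indices missing between the first and last tracked frame (input sorted)."""
    if not frame_indices:
        return []
    present = set(frame_indices)
    return [i for i in range(frame_indices[0], frame_indices[-1] + 1) if i not in present]


# Toy results in the same shape as the propagation output (masks omitted)
toy = [
    {"frame_idx": 0, "obj_ids": [1, 2]},
    {"frame_idx": 1, "obj_ids": [1]},
    {"frame_idx": 3, "obj_ids": [1, 2]},
]
summary = frames_per_object(toy)
print(summary)                # {1: [0, 1, 3], 2: [0, 3]}
print(find_gaps(summary[1]))  # [2]
```

Calling `frames_per_object(results)` after one of the tests below gives the same summary for the real propagation output.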

In [None]:
"""Test tracking multiple objects."""
logger.info("\n=== Testing Multiple Object Tracking ===")
test_start_time = datetime.now()
try:
    # Start session
    session_id = video_predictor._start_session()

    # Add first object
    points1 = [
        {"type": "point", "coordinates": [1478, 649], "label": 1},
        {"type": "point", "coordinates": [1433, 689], "label": 0},
    ]
    video_predictor._add_points(session_id, points1, "multiple_objects", object_id=1)

    # Add second object
    points2 = [{"type": "point", "coordinates": [1433, 689], "label": 1}]
    video_predictor._add_points(session_id, points2, "multiple_objects", object_id=2)

    # Propagate tracking and get results
    results = video_predictor._propagate_video(session_id, "multiple_objects")

    # End session
    video_predictor._end_session(session_id)
    test_duration = (datetime.now() - test_start_time).total_seconds()
    logger.info(f"Multiple object tracking test completed successfully in {test_duration:.2f} seconds")
except Exception as e:
    logger.error(f"Multiple object tracking test failed: {str(e)}")
    raise

Video(filename="outputs/video_predictor/multiple_objects/visualization.mp4", width=500)

In [None]:
"""Test tracking with box prompt."""
logger.info("\n=== Testing Box Tracking ===")
test_start_time = datetime.now()
try:
    # Start session
    session_id = video_predictor._start_session()

    # Add box
    box_prompt = {"type": "box", "coordinates": [1392, 562, 1531, 872]}
    points = []
    video_predictor._add_box(session_id, box_prompt, points, "box_prompt")

    # Propagate tracking and get results
    results = video_predictor._propagate_video(session_id, "box_prompt")

    # End session
    video_predictor._end_session(session_id)
    test_duration = (datetime.now() - test_start_time).total_seconds()
    logger.info(f"Box prompt tracking test completed successfully in {test_duration:.2f} seconds")
except Exception as e:
    logger.error(f"Box prompt tracking test failed: {str(e)}")
    raise

Video(filename="outputs/video_predictor/box_prompt/visualization.mp4", width=500)

In [None]:
"""Test tracking with multiple box prompt."""
logger.info("\n=== Testing Multiple Box Tracking ===")
test_start_time = datetime.now()
try:
    # Start session
    session_id = video_predictor._start_session()

    # Add first box (object 1)
    box_prompt = {"type": "box", "coordinates": [1392, 562, 1531, 872]}
    points = []
    video_predictor._add_box(session_id, box_prompt, points, "multiple_box_prompt", object_id=1)
    
    # Add second box (object 2)
    box_prompt = {"type": "box", "coordinates": [1195, 216, 1485, 440]}
    points = []
    video_predictor._add_box(session_id, box_prompt, points, "multiple_box_prompt", object_id=2)

    # Propagate tracking and get results
    results = video_predictor._propagate_video(session_id, "multiple_box_prompt")

    # End session
    video_predictor._end_session(session_id)
    test_duration = (datetime.now() - test_start_time).total_seconds()
    logger.info(f"Multiple Box prompt tracking test completed successfully in {test_duration:.2f} seconds")
except Exception as e:
    logger.error(f"Multiple Box prompt tracking test failed: {str(e)}")
    raise

Video(filename="outputs/video_predictor/multiple_box_prompt/visualization.mp4", width=500)

In [None]:
"""Test tracking with combined box and point prompts."""
logger.info("\n=== Testing Combined Box and Point Tracking ===")
test_start_time = datetime.now()
try:
    # Start session
    session_id = video_predictor._start_session()

    # Add box with points
    box_prompt = {"type": "box", "coordinates": [1392, 562, 1531, 872]}
    points = [{"type": "point", "coordinates": [1433, 689], "label": 0}]
    video_predictor._add_box(session_id, box_prompt, points, "combined_prompts")

    # Propagate tracking and get results
    results = video_predictor._propagate_video(session_id, "combined_prompts")

    # End session
    video_predictor._end_session(session_id)
    test_duration = (datetime.now() - test_start_time).total_seconds()
    logger.info(f"Combined prompts tracking test completed successfully in {test_duration:.2f} seconds")
except Exception as e:
    logger.error(f"Combined prompts tracking test failed: {str(e)}")
    raise

Video(filename="outputs/video_predictor/combined_prompts/visualization.mp4", width=500)
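
Point labels follow the SAM convention: `1` marks a foreground point to include in the mask and `0` marks a background point to exclude. The combined cell therefore outlines the object with a box and uses a label-0 point to carve a region out of it. The sketch below assumes the endpoint eventually forwards these prompts to the open-source SAM 2 video predictor, whose `add_new_points_or_box` method takes separate `points`, `labels`, and `box` arrays. It shows how the notebook's dictionaries could map onto those arrays. `prompts_to_arrays` is a hypothetical helper, not part of the endpoint.

```python
from typing import Dict, List, Optional, Tuple

import numpy as np


def prompts_to_arrays(
    points: List[Dict[str, object]],
    box_prompt: Optional[Dict[str, object]] = None,
) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]:
    """Split notebook-style prompt dicts into coordinate, label, and box arrays."""
    # (N, 2) array of x, y pixel coordinates; reshape keeps (0, 2) for no points.
    coords = np.array([p["coordinates"] for p in points], dtype=np.float32).reshape(-1, 2)
    # (N,) array of labels: 1 = foreground, 0 = background.
    labels = np.array([p["label"] for p in points], dtype=np.int32)
    # Optional (4,) box in x_min, y_min, x_max, y_max order.
    box = None
    if box_prompt is not None:
        box = np.array(box_prompt["coordinates"], dtype=np.float32)
    return coords, labels, box
```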

In [None]:
"""Test tracking with prompt refinement on different frames."""
logger.info("\n=== Testing Prompt Refinement in Video ===")
test_start_time = datetime.now()
try:
    # Start session
    session_id = video_predictor._start_session()

    # Add initial points on frame 0: a negative point (label 0) and a positive point (label 1)
    points = [
        {"type": "point", "coordinates": [1478, 649], "label": 0},
        {"type": "point", "coordinates": [1433, 689], "label": 1},
    ]
    video_predictor._add_points(session_id, points, "prompt_refinement")
    
    # Propagate tracking and get results
    results = video_predictor._propagate_video(session_id, "prompt_refinement_initial")

    # Add a negative (label 0) refinement point on frame 114, keeping the frame-0 points
    refine_points = [
        {"type": "point", "coordinates": [940, 155], "label": 0},
    ]
    video_predictor._add_points(session_id, refine_points, "prompt_refinement", frame_index=114, clear_old_points=False)

    # Propagate tracking and get results
    results = video_predictor._propagate_video(session_id, "prompt_refinement_final")

    # End session
    video_predictor._end_session(session_id)
    test_duration = (datetime.now() - test_start_time).total_seconds()
    logger.info(f"Prompt refinement test completed successfully in {test_duration:.2f} seconds")
except Exception as e:
    logger.error(f"Prompt refinement test failed: {str(e)}")
    raise

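# Compare the masks before and after the frame-114 refinement
from IPython.display import Video, display

display(Video(filename="outputs/video_predictor/prompt_refinement_initial/visualization.mp4", width=500))
display(Video(filename="outputs/video_predictor/prompt_refinement_final/visualization.mp4", width=500))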
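
The test cells above call `_end_session` only when every step succeeds. If `_add_box`, `_add_points`, or `_propagate_video` raises, the session is never closed. A context manager can release the session on both paths and absorb the repeated timing and logging code. This is a minimal sketch: `tracking_session` is a hypothetical helper, and it assumes `_start_session` returns an ID that `_end_session` accepts, as in the cells above.

```python
import logging
import time
from contextlib import contextmanager
from typing import Any, Callable, Iterator

session_logger = logging.getLogger("tracking_session")


@contextmanager
def tracking_session(
    start_session: Callable[[], Any],
    end_session: Callable[[Any], None],
    test_name: str,
) -> Iterator[Any]:
    """Open a tracking session, log how long it ran, and always close it."""
    start = time.perf_counter()
    # If starting fails there is no session to close, so this stays outside try.
    session_id = start_session()
    try:
        yield session_id
    except Exception as exc:
        session_logger.error(f"{test_name} failed: {exc}")
        raise
    else:
        duration = time.perf_counter() - start
        session_logger.info(f"{test_name} completed successfully in {duration:.2f} seconds")
    finally:
        # Runs on success and on failure, so the endpoint-side session is released.
        end_session(session_id)
```

For example, the box-prompt cell could place its `_add_box` and `_propagate_video` calls inside `with tracking_session(video_predictor._start_session, video_predictor._end_session, "Box prompt tracking test") as session_id:`. Exceptions still propagate, so a failing cell stops as it did before.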

In [None]:
# Delete the model and endpoint so the deployment stops incurring charges
predictor.delete_model()
predictor.delete_endpoint()

## Notebook CI Test Results

This notebook was tested in multiple regions. The results for each region are shown below, except for us-west-2, whose result is shown at the top of the notebook.


![This us-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-1/generative_ai|sm-jumpstart_foundation_sam_2_1_video_segmentation.ipynb)

![This us-east-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-2/generative_ai|sm-jumpstart_foundation_sam_2_1_video_segmentation.ipynb)

![This us-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-1/generative_ai|sm-jumpstart_foundation_sam_2_1_video_segmentation.ipynb)

![This ca-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ca-central-1/generative_ai|sm-jumpstart_foundation_sam_2_1_video_segmentation.ipynb)

![This sa-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/sa-east-1/generative_ai|sm-jumpstart_foundation_sam_2_1_video_segmentation.ipynb)

![This eu-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-1/generative_ai|sm-jumpstart_foundation_sam_2_1_video_segmentation.ipynb)

![This eu-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-2/generative_ai|sm-jumpstart_foundation_sam_2_1_video_segmentation.ipynb)

![This eu-west-3 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-3/generative_ai|sm-jumpstart_foundation_sam_2_1_video_segmentation.ipynb)

![This eu-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-central-1/generative_ai|sm-jumpstart_foundation_sam_2_1_video_segmentation.ipynb)

![This eu-north-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-north-1/generative_ai|sm-jumpstart_foundation_sam_2_1_video_segmentation.ipynb)

![This ap-southeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-1/generative_ai|sm-jumpstart_foundation_sam_2_1_video_segmentation.ipynb)

![This ap-southeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-2/generative_ai|sm-jumpstart_foundation_sam_2_1_video_segmentation.ipynb)

![This ap-northeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-1/generative_ai|sm-jumpstart_foundation_sam_2_1_video_segmentation.ipynb)

![This ap-northeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-2/generative_ai|sm-jumpstart_foundation_sam_2_1_video_segmentation.ipynb)

![This ap-south-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-south-1/generative_ai|sm-jumpstart_foundation_sam_2_1_video_segmentation.ipynb)

