In [32]:
import boto3
import sagemaker
import json
import httpx
import base64
from io import BytesIO
from typing import Dict, Any
from sagemaker import ModelPackage
from IPython.display import Markdown

import time
import concurrent.futures
import requests
from tqdm import tqdm

# Mistral OCR SageMaker Deployment

This notebook demonstrates how to deploy the Mistral OCR model to an Amazon SageMaker endpoint for real-time inference.

## Supported Instance Types

The Mistral OCR model requires GPU instances. The following instance types are supported:

### Quota increase via auto approval (Almost immediately approved)
- `ml.g6.2xlarge`
- `ml.g6.4xlarge`
- `ml.g6.8xlarge`
- `ml.g6.16xlarge`

### Quota increase via support ticket (May take a few days)
- `ml.g6e.xlarge`
- `ml.g6e.2xlarge`
- `ml.g6e.4xlarge`
- `ml.g6e.8xlarge`
- `ml.g6e.16xlarge`

## Configuration Parameters

Before running this notebook, you need to configure the following parameters:

| Parameter | Description | Example |
|-----------|-------------|---------|
| `MISTRAL_OCR_MODEL_PACKAGE_ARN` | The Amazon Resource Name (ARN) of the Mistral OCR model package from AWS Marketplace (Product ARN) | `arn:aws:sagemaker:us-west-2:123456789012:model-package/...` |
| `MISTRAL_OCR_MODEL_CONFIG_INSTANCE_TYPE` | The EC2 instance type to host the model (see supported types above) | `ml.g6.4xlarge` |
| `SAGEMAKER_EXECUTION_ROLE_ARN` | IAM role ARN with permissions to create and invoke SageMaker endpoints | `arn:aws:iam::123456789012:role/SageMakerExecutionRole` |
| `MISTRAL_OCR_ENDPOINT_NAME` | A unique name for your SageMaker endpoint | `mistral-ocr-endpoint` |

❗You can find the product ARN from the AWS Marketplace product detail page. Select the region you want to deploy the model first, then you will have the correct product ARN in that region.

<div style="background-color: #fff3cd; border-left: 6px solid #ffeb3b; padding: 10px;">
<strong>Note:</strong> Please contact the AWS 3P MP team to gain model access for your customer. Once the access is provided, you can search "Mistral OCR" on AWS Marketplace to subscribe to this model. 
<br><br>
You can find the product ARN from the AWS Marketplace product detail page. Select the region you want to deploy the model first, then you will have the correct product ARN in that region.
</div>

### Required IAM Permissions

The IAM role specified in `SAGEMAKER_EXECUTION_ROLE_ARN` should have the following policies attached:
- `AmazonSageMakerFullAccess`
- S3 read/write access to model artifacts
- CloudWatch Logs access for endpoint logging

In [None]:
MISTRAL_OCR_MODEL_PACKAGE_ARN = "<MISTRAL_OCR_MODEL_PACKAGE_ARN>" 
MISTRAL_OCR_MODEL_CONFIG_INSTANCE_TYPE = "ml.g6.4xlarge" 
SAGEMAKER_EXECUTION_ROLE_ARN = "<SAGEMAKER_EXECUTION_ROLE_ARN>" 

## Real-time Inference Endpoint Deployment

This section demonstrates how to deploy Mistral OCR as a real-time inference endpoint on SageMaker. 

### Deployment Steps:
1. Create a ModelPackage object from the Marketplace ARN
2. Deploy the model to a SageMaker endpoint 
3. Configure auto-scaling to optimize cost and performance

In [None]:
from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements

MISTRAL_OCR_ENDPOINT_NAME = "mistral-ocr-real-time-endpoint-1" # provide an unique endpoint name

session = sagemaker.Session()
role = SAGEMAKER_EXECUTION_ROLE_ARN
model_package = ModelPackage(
    role=role,
    model_package_arn=MISTRAL_OCR_MODEL_PACKAGE_ARN,
    sagemaker_session=session
)

# Deploy the model
model = model_package.deploy(
    initial_instance_count=1,
    instance_type=MISTRAL_OCR_MODEL_CONFIG_INSTANCE_TYPE,
    endpoint_name=MISTRAL_OCR_ENDPOINT_NAME,
    model_data_download_timeout=3600,
    container_startup_health_check_timeout=3600
    )

In [None]:
# Auto-Scaling Configuration

# Configure auto-scaling to handle variable traffic patterns
# Note: Below endpoints cannot scale to zero instances. Inference component doesn't support model package from MarketPlace
client = session.boto_session.client('application-autoscaling')

# For traditional endpoints.  
resource_id = f'endpoint/{MISTRAL_OCR_ENDPOINT_NAME}/variant/AllTraffic'
scalable_dimension = 'sagemaker:variant:DesiredInstanceCount'

# Register the endpoint as a scalable target
client.register_scalable_target(
    ServiceNamespace='sagemaker',
    ResourceId=resource_id,
    ScalableDimension=scalable_dimension,
    MinCapacity=1,  # Minimum 1 instance required 
    MaxCapacity=5   # Maximum 5 instances - adjust based on your expected traffic
)

# Configure scaling policy based on invocation count
client.put_scaling_policy(
    ServiceNamespace='sagemaker',
    ResourceId=resource_id,
    ScalableDimension=scalable_dimension,
    PolicyName=f'{MISTRAL_OCR_ENDPOINT_NAME}-scaling-policy',
    PolicyType='TargetTrackingScaling',
    TargetTrackingScalingPolicyConfiguration={
        'TargetValue': 5.0,  # Target 5 invocations per instance
        'PredefinedMetricSpecification': {
            'PredefinedMetricType': 'SageMakerVariantInvocationsPerInstance', 
        },
        'ScaleInCooldown': 300,  # Wait 5 minutes before scaling in (reduced from 30 to 300 seconds)
        'ScaleOutCooldown': 60,   # Wait 1 minute before scaling out (changed from 30 to 60 seconds)
        'DisableScaleIn': False    # Enable scale-in to minimum capacity
    }
)

In [None]:
def run_inference(client, endpoint_name: str, payload: dict[str,Any]) -> Dict[str, Any]:
    """
    Invoke the SageMaker endpoint for OCR inference.
    
    Args:
        client: SageMaker runtime client
        endpoint_name: Name of the deployed endpoint
        payload: JSON payload containing the image data
        
    Returns:
        Dictionary containing parsed OCR results
    """
    try:
        inference_out = client.invoke_endpoint(
            EndpointName=endpoint_name,
            ContentType="application/json",
            Body=json.dumps(payload)
        )
        inference_resp_str = inference_out["Body"].read().decode("utf-8")
        return json.loads(inference_resp_str)
    except Exception as e:
        print(f"Inference error: {e}")
        raise

def download_and_encode_image(url: str) -> str:
    """
    Download an image from a URL and encode it as base64.
    
    Args:
        url: URL of the image to download
        
    Returns:
        Base64-encoded image string
    """
    try:
        # Send a GET request to the URL
        with httpx.Client() as client:
            response = client.get(url, timeout=10)
            response.raise_for_status()  # Raise an exception for HTTP errors
        # Encode the image content to base64
        image_data = response.content
        base64_encoded_data = base64.b64encode(image_data).decode('utf-8')
        return base64_encoded_data
    except httpx.HTTPStatusError as exc:
        print(f"Error response {exc.response.status_code} while requesting {exc.request.url}")
        raise
    except httpx.RequestException as e:
        print(f"Error downloading image: {e}")
        raise

In [None]:
# Single image Inference Test

# Download sample receipt image and encode as base64
receipt_image_url = "https://cms.mistral.ai/assets/1d7df1b8-5caa-47b9-b6a1-666b05d38019"
receipt_image_b64 = download_and_encode_image(url=receipt_image_url)

# Prepare the payload for Mistral OCR model
payload = {
    "model": "mistral-ocr-2505",
    "document": {
        "type": "image_url",
        "image_url": f"data:image/jpeg;base64,{receipt_image_b64}"
    }
}

# Create a client and invoke the endpoint
sagemaker_client = boto3.client("sagemaker-runtime")
receipt_parsed = run_inference(client=sagemaker_client, endpoint_name=MISTRAL_OCR_ENDPOINT_NAME, payload=payload)

# Display the OCR results in markdown format
Markdown(receipt_parsed["pages"][0]["markdown"])

In [39]:
#Stress testing 

def send_request(args):
    """
    Send a single request to the endpoint and measure latency.
    
    Args:
        args: Tuple containing (client, endpoint_name, payload, request_id)
        
    Returns:
        Dictionary with success/failure status and latency information
    """
    client, endpoint_name, payload, request_id = args
    try:
        start_time = time.time()
        _ = run_inference(client, endpoint_name, payload)
        end_time = time.time()
        latency = end_time - start_time
        return {"success": True, "latency": latency}
    except Exception as e:
        return {"success": False, "error": str(e)}

def run_stress_test(image_url, endpoint_name, num_requests=100, max_workers=10):
    """
    Run a stress test against the endpoint with configurable concurrency.
    
    Args:
        image_url: URL of the image to use for testing
        endpoint_name: Name of the SageMaker endpoint
        num_requests: Total number of requests to send
        max_workers: Maximum number of concurrent requests
        
    Returns:
        Dictionary containing test statistics
    """
    # Setup client and prepare payload
    sagemaker_client = boto3.client("sagemaker-runtime")
    print(f"Downloading and encoding image from {image_url}...")
    receipt_image_b64 = download_and_encode_image(url=image_url)

    payload = {
        "model": "mistral-ocr-2505",
        "document": {
            "type": "image_url",
            "image_url": f"data:image/jpeg;base64,{receipt_image_b64}"
        }
    }

    # Create argument list for parallel requests
    args_list = [(sagemaker_client, endpoint_name, payload, i) for i in range(num_requests)]

    # Track metrics
    start_time = time.time()
    results = []

    print(f"Starting stress test - sending {num_requests} requests to {endpoint_name}...")

    # Run requests in parallel
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(send_request, arg) for arg in args_list]
        for future in tqdm(concurrent.futures.as_completed(futures), total=num_requests):
            results.append(future.result())

    # Calculate statistics
    end_time = time.time()
    total_time = end_time - start_time
    successful = sum(1 for r in results if r["success"])
    failed = num_requests - successful


    # Print results
    print("\nStress Test Results:")
    print(f"Total requests: {num_requests}")
    print(f"Successful: {successful} ({successful/num_requests*100:.1f}%)")
    print(f"Failed: {failed} ({failed/num_requests*100:.1f}%)")
    print(f"Total time: {total_time:.2f} seconds")
    print(f"Throughput: {successful/total_time:.2f} requests/second")


    if failed > 0:
        error_counts = {}
        for r in results:
            if not r["success"]:
                error_type = r["error"].split(':')[0]
                error_counts[error_type] = error_counts.get(error_type, 0) + 1

        print("\nError breakdown:")
        for error, count in error_counts.items():
            print(f"  {error}: {count} ({count/failed*100:.1f}%)")

    return {
        "total_requests": num_requests,
        "successful": successful,
        "failed": failed,
        "total_time": total_time,
        "throughput": successful/total_time,
    }

# Run the stress test with increasing concurrency
for concurrency in [5, 10, 20]:
    print(f"\n=== Testing with {concurrency} concurrent requests ===")
    stats = run_stress_test(
        image_url=receipt_image_url,
        endpoint_name=MISTRAL_OCR_ENDPOINT_NAME,
        num_requests=20,  # Adjust based on your needs
        max_workers=concurrency
    )


=== Testing with 5 concurrent requests ===
Downloading and encoding image from https://cms.mistral.ai/assets/1d7df1b8-5caa-47b9-b6a1-666b05d38019...


Starting stress test - sending 20 requests to mistral-ocr-notebook-endpoint-4...


100%|██████████| 20/20 [00:30<00:00,  1.55s/it]


Stress Test Results:
Total requests: 20
Successful: 20 (100.0%)
Failed: 0 (0.0%)
Total time: 31.13 seconds
Throughput: 0.64 requests/second

=== Testing with 10 concurrent requests ===
Downloading and encoding image from https://cms.mistral.ai/assets/1d7df1b8-5caa-47b9-b6a1-666b05d38019...





Starting stress test - sending 20 requests to mistral-ocr-notebook-endpoint-4...


100%|██████████| 20/20 [00:22<00:00,  1.13s/it]


Stress Test Results:
Total requests: 20
Successful: 20 (100.0%)
Failed: 0 (0.0%)
Total time: 22.88 seconds
Throughput: 0.87 requests/second

=== Testing with 20 concurrent requests ===
Downloading and encoding image from https://cms.mistral.ai/assets/1d7df1b8-5caa-47b9-b6a1-666b05d38019...





Starting stress test - sending 20 requests to mistral-ocr-notebook-endpoint-4...


 20%|██        | 4/20 [00:19<00:46,  2.90s/it]

 85%|████████▌ | 17/20 [00:19<00:01,  2.81it/s]

100%|██████████| 20/20 [00:20<00:00,  1.00s/it]



Stress Test Results:
Total requests: 20
Successful: 20 (100.0%)
Failed: 0 (0.0%)
Total time: 20.49 seconds
Throughput: 0.98 requests/second
