# Amazon Bedrock Batch Inference for Model Distillation

## Learning Objectives

By the end of this notebook, you will be able to:
1. Design and implement efficient batch inference workflows for distilled models
2. Configure and optimize batch inference jobs for maximum throughput
3. Implement robust monitoring and error handling for batch processing
4. Compare performance characteristics across model variants using batch inference

## Introduction

Batch inference represents a critical deployment pattern for machine learning models, particularly in scenarios requiring high-throughput processing of large datasets. In the context of model distillation, batch inference serves two key purposes:

1. **Performance Validation**: Enables systematic comparison of teacher, student, and distilled models across large test sets
2. **Production Readiness**: Validates the distilled model's ability to handle production-scale workloads

This notebook demonstrates advanced batch inference patterns using Amazon Bedrock, focusing on:

- Structuring input data and job configuration for high-throughput processing
- Leveraging provisioned throughput endpoints for predictable performance
- Implementing robust error handling and retry mechanisms
- Gathering detailed performance metrics for model comparison

### Architecture Overview

The batch inference workflow implemented here follows a distributed processing architecture:

```
S3 Input Bucket → Bedrock Batch Processing → S3 Output Bucket
                     ↓
              Performance Metrics
                     ↓
             Evaluation Pipeline
```

This architecture enables:
- Horizontal scaling for large datasets
- Fault tolerance through automatic retries
- Detailed performance monitoring
- Cost optimization through batch processing

## Setup and Prerequisites

We'll configure our environment with the necessary dependencies and AWS client libraries. This setup assumes you have completed the previous notebooks and have a provisioned throughput endpoint available for your distilled model.

In [None]:
# upgrade boto3 
%pip install --upgrade pip --quiet
%pip install boto3 --upgrade --quiet

# restart kernel
# (works in classic Notebook, JupyterLab, and SageMaker Studio)
import IPython
IPython.Application.instance().kernel.do_shutdown(True)  # True = restart after shutdown

In [None]:
# load PT model id from previous notebook
%store -r provisioned_model_id
%store -r custom_model_id

In [None]:
import json
import sys
import os
import time

# Make the shared `utils` module (two directories up) importable
current_dir = os.getcwd()
parent_dir = os.path.dirname(current_dir)
skip_dir = os.path.dirname(parent_dir)
sys.path.append(skip_dir)

import boto3
from datetime import datetime
from botocore.exceptions import ClientError
from utils import create_s3_bucket

# Region and account ID
region = 'us-east-1'
session = boto3.session.Session(region_name=region)

# Create Bedrock client (job management and provisioned throughput)
bedrock_client = session.client(service_name='bedrock')

# Create runtime client for inference
bedrock_runtime = session.client(service_name='bedrock-runtime')

sts_client = session.client(service_name='sts')
account_id = sts_client.get_caller_identity()['Account']

# Define bucket and prefixes (using the same bucket as in distillation)
BUCKET_NAME = '<BUCKET_NAME>' # Same bucket used in distillation notebook
DATA_PREFIX = 'citations_distillation'  # Same prefix used in distillation notebook
batch_inference_prefix = f"{DATA_PREFIX}/batch_inference"  # New prefix for batch inference

## 1. Upload Batch Inference Data to S3

The first step in our batch inference pipeline is preparing and uploading the test dataset. For optimal performance, consider these best practices:

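- **Data Format**: Use JSONL, one record per line; the jobs below set `s3InputFormat` to `"JSONL"`. Each record wraps the model request in a `modelInput` field, optionally alongside a `recordId`
- **File Size**: Keep each input file, and each job as a whole, within the Bedrock batch inference quotas for file size and record count, which vary by model and Region
- **Compression**: Upload the JSONL uncompressed, since the job reads its input as plain JSONL
- **Data Validation**: Validate each record's schema before upload, so malformed lines don't surface as failed records hours into a job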
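
The data-validation practice above can be sketched as a small pre-upload check. This is a minimal sketch, not part of the original workflow: `validate_batch_records` is a hypothetical helper, and it assumes each line follows the Bedrock batch record shape (an optional string `recordId` plus a `modelInput` object) with a Nova-style `messages` list inside `modelInput`. Adjust the inner checks if your records use a different request schema.

```python
import json


def validate_batch_records(lines):
    """Hypothetical pre-upload check for Bedrock batch JSONL records.

    Assumes each record is {"recordId": <optional str>, "modelInput": {...}}
    and that modelInput uses a Nova-style "messages" list (an assumption).
    Returns (number_of_non_blank_records, list_of_error_strings).
    """
    errors = []
    seen_ids = set()
    count = 0
    for line_no, line in enumerate(lines, start=1):
        if not line.strip():
            continue  # ignore blank lines, e.g. a trailing newline
        count += 1
        try:
            record = json.loads(line)
        except json.JSONDecodeError as exc:
            errors.append(f"line {line_no}: invalid JSON ({exc.msg})")
            continue
        if not isinstance(record, dict):
            errors.append(f"line {line_no}: record is not a JSON object")
            continue
        # recordId is optional, but if present it should be a unique string
        record_id = record.get("recordId")
        if record_id is not None:
            if not isinstance(record_id, str):
                errors.append(f"line {line_no}: recordId must be a string")
            elif record_id in seen_ids:
                errors.append(f"line {line_no}: duplicate recordId {record_id!r}")
            else:
                seen_ids.add(record_id)
        model_input = record.get("modelInput")
        if not isinstance(model_input, dict):
            errors.append(f"line {line_no}: missing or non-object 'modelInput'")
            continue
        messages = model_input.get("messages")
        if not isinstance(messages, list) or not messages:
            errors.append(f"line {line_no}: 'modelInput.messages' must be a non-empty list")
    return count, errors
```

For example, open `batch_inference_file` and pass the file object to `validate_batch_records` before running the upload cell; an empty error list means every line passed these (assumed) checks.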

The following code uploads the file to S3, creating the bucket first if it doesn't already exist:

In [None]:
# Define the local path to the batch inference data file
batch_inference_file = 'batch_inf_data.jsonl'

# Upload the batch inference data to S3
def upload_batch_inference_data(bucket_name, file_name, prefix):
    """
    Upload batch inference data to S3 bucket
    """
    s3_client = boto3.client('s3')
    
    # Check if bucket exists, if not create it
    try:
        s3_client.head_bucket(Bucket=bucket_name)
        print(f"Bucket {bucket_name} exists.")
    except ClientError:
        print(f"Creating bucket {bucket_name}...")
        create_s3_bucket(bucket_name=bucket_name)
    
    # Upload file to S3
    s3_key = f"{prefix}/{file_name}"
    s3_client.upload_file(file_name, bucket_name, s3_key)
    print(f"Uploaded {file_name} to s3://{bucket_name}/{s3_key}")
    
    return f"s3://{bucket_name}/{s3_key}"

# Upload batch inference data to S3
batch_inference_s3_uri = upload_batch_inference_data(BUCKET_NAME, batch_inference_file, batch_inference_prefix)
print(f"Batch inference data uploaded to: {batch_inference_s3_uri}")

# Define the output location for batch inference results
batch_inference_output_prefix = f"{batch_inference_prefix}/outputs"
batch_inference_output_uri = f"s3://{BUCKET_NAME}/{batch_inference_output_prefix}/"
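
To make the architecture above concrete, the helper below traces where each job's result file lands in S3. It is a sketch built on two assumptions that match the code later in this notebook: `submit_batch_inference_job` appends the model's short name to the output URI, and Bedrock writes each job's results under a subfolder named after the last segment of the job ARN, as `<input file name>.out`. `expected_output_uri` is a hypothetical helper, and the bucket and job ARN in the example are made up.

```python
def expected_output_uri(output_uri, model_id, job_arn, input_file):
    """Hypothetical helper: predict the S3 URI of a batch job's result file.

    Mirrors this notebook's naming: <output_uri><model_short_name>/<job_id>/<input_file>.out,
    where job_id is the last segment of the job ARN (assumed Bedrock layout).
    """
    model_short_name = model_id.split('/')[-1].split(':')[0]
    job_id = job_arn.split('/')[-1]
    if not output_uri.endswith('/'):
        output_uri += '/'
    return f"{output_uri}{model_short_name}/{job_id}/{input_file}.out"


# Example with a made-up bucket name and job ARN
print(expected_output_uri(
    "s3://my-bucket/citations_distillation/batch_inference/outputs/",
    "amazon.nova-lite-v1:0",
    "arn:aws:bedrock:us-east-1:111122223333:model-invocation-job/abc123xyz",
    "batch_inf_data.jsonl",
))
```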

## 2. Submit Batch Inference Jobs

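When submitting batch inference jobs, a few configuration choices affect reliability and cost. Amazon Bedrock manages the compute for batch jobs itself, so there are no batch-size, concurrency, memory, or CPU/GPU settings to tune; throughput is bounded by service quotas instead.

1. **Job Configuration** (parameters of `create_model_invocation_job`)
   - `modelId`: The model or inference profile to run
   - `inputDataConfig` / `outputDataConfig`: S3 locations of the JSONL input and of the results
   - `roleArn`: Service role that Bedrock assumes to read the input and write the output
   - `timeoutDurationInHours` (optional): How long the job may run before it is stopped

2. **Service Quotas**
   - Records per job and input file size
   - Number of batch inference jobs that can be submitted or in progress at once
   - Quotas vary by model and Region; check them in the Service Quotas console

3. **Error Handling**
   - Retry strategies for throttled API calls
   - Job status monitoring (`Failed`, `PartiallyCompleted`)
   - Error logging for individual records in the output files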
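
The retry item above can be sketched as a small wrapper with capped exponential backoff and full jitter around job submission. This is a hypothetical helper, not part of the Bedrock API: `call_with_backoff` and the set of retryable error codes are assumptions (throttling and internal-server errors are the usual transient cases). It reads the error code from the `response` attribute that botocore's `ClientError` carries, so it needs no botocore import of its own. Alternatively, botocore can retry for you if you pass `config=botocore.config.Config(retries={"mode": "adaptive", "max_attempts": 10})` when creating the client.

```python
import random
import time

# Assumed set of transient error codes worth retrying (not an official list)
RETRYABLE_CODES = {"ThrottlingException", "InternalServerException"}


def error_code(exc):
    """Return the AWS error code from a botocore-style ClientError, else None."""
    response = getattr(exc, "response", None)
    if not isinstance(response, dict):
        return None
    return response.get("Error", {}).get("Code")


def call_with_backoff(fn, max_attempts=5, base_delay=1.0, max_delay=30.0, sleep=time.sleep):
    """Call fn() and retry retryable AWS errors with capped exponential backoff.

    Uses "full jitter": each wait is uniform in [0, min(max_delay, base_delay * 2**n)].
    Non-retryable errors, and the final failed attempt, are re-raised unchanged.
    """
    for attempt in range(1, max_attempts + 1):
        try:
            return fn()
        except Exception as exc:
            if error_code(exc) not in RETRYABLE_CODES or attempt == max_attempts:
                raise
            delay = min(max_delay, base_delay * 2 ** (attempt - 1))
            sleep(random.uniform(0, delay))
```

Usage would look like `call_with_backoff(lambda: submit_batch_inference_job(model, batch_inference_s3_uri, batch_inference_output_uri))`. Full jitter spreads retries from several callers apart, which helps when the throttling comes from a shared quota.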

We'll first run inference for the distilled model on its provisioned throughput endpoint using `batch_inference_simulator.py`, a script that simulates a Bedrock batch inference job: it accepts the same JSONL input format as a normal batch inference job and writes results in the same output format. Note that it is designed to work specifically with Nova models. You can also use it for the other models to speed up this process, or run real batch inference jobs for them to benefit from a lower cost per inference.


We'll then compare results with other model variants.

Once your distilled model batch inferences are complete, be sure to delete the provisioned throughput endpoint.

In [None]:
!python3 batch_inference_simulator.py --input batch_inf_data.jsonl --output batch_inference_results/distilled_results.jsonl  --model "arn:aws:bedrock:us-east-1:<account_id>:provisioned-model/pt_endpoint_id" # lite
# !python3 batch_inference_simulator.py --input batch_inf_data.jsonl --output batch_inference_results/nova_micro_results.jsonl  --model "us.amazon.nova-micro-v1:0" # micro
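
The simulator's source is not reproduced in this notebook. As a rough mental model only, a script like it could read each JSONL record, send its `modelInput` to the Bedrock Runtime `InvokeModel` API, and write one output line per record. The sketch below is hypothetical: `simulate_batch` is not the actual `batch_inference_simulator.py`, and the output fields (`recordId`, `modelInput`, `modelOutput`, plus an `error` field on failure) are our assumption of the batch output shape. It takes the runtime client as a parameter, so you could pass the `bedrock_runtime` client created earlier.

```python
import json


def simulate_batch(input_path, output_path, model_id, runtime_client):
    """Hypothetical sketch of a batch simulator built on InvokeModel.

    Sends each record's modelInput to runtime_client.invoke_model and writes
    {"recordId", "modelInput", "modelOutput"} (or "error") per line.
    Returns the number of records written.
    """
    written = 0
    with open(input_path) as src, open(output_path, "w") as dst:
        for index, line in enumerate(src):
            if not line.strip():
                continue  # skip blank lines
            record = json.loads(line)
            # Fall back to a generated, zero-padded ID when the record has none (assumption)
            out = {"recordId": record.get("recordId", f"{index:011d}"),
                   "modelInput": record["modelInput"]}
            try:
                response = runtime_client.invoke_model(
                    modelId=model_id, body=json.dumps(record["modelInput"])
                )
                # The response body is a stream; read and parse the JSON payload
                out["modelOutput"] = json.loads(response["body"].read())
            except Exception as exc:  # record the failure and continue with the next line
                out["error"] = {"errorMessage": str(exc)}
            dst.write(json.dumps(out) + "\n")
            written += 1
    return written
```

This version processes records one at a time; it trades speed for simplicity, and it does not create the output directory for you.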

### Delete PT Endpoint

Proper cleanup of resources is essential for cost management. Provisioned throughput is billed for every hour it exists, whether or not it is serving requests, so delete it as soon as the distilled model run above has finished.

In [None]:
# delete provisioned throughput:
response = bedrock_client.delete_provisioned_model_throughput(provisionedModelId=provisioned_model_id)
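
To double-check that nothing is still being billed, you can list the provisioned throughputs that remain in service. A minimal sketch, assuming the `list_provisioned_model_throughputs` response shape we expect (`provisionedModelSummaries` entries with `provisionedModelName` and `provisionedModelArn`, plus an optional `nextToken` for paging); `in_service_provisioned_models` is a hypothetical helper.

```python
def in_service_provisioned_models(client):
    """Hypothetical helper: return (name, arn) for every InService provisioned throughput."""
    found = []
    kwargs = {"statusEquals": "InService"}
    while True:
        response = client.list_provisioned_model_throughputs(**kwargs)
        for summary in response.get("provisionedModelSummaries", []):
            found.append((summary["provisionedModelName"], summary["provisionedModelArn"]))
        # Follow pagination until the service stops returning a token
        token = response.get("nextToken")
        if not token:
            return found
        kwargs["nextToken"] = token
```

Calling `in_service_provisioned_models(bedrock_client)` after the deletion should no longer list the distilled model's endpoint.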

Next, we'll submit batch inference jobs for our out-of-the-box models.
Before moving forward, create a batch inference service role that Amazon Bedrock can assume to read your input from S3 and write the results back: https://docs.aws.amazon.com/bedrock/latest/userguide/batch-iam-sr.html

In [None]:
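# Assumes a batch inference service role with this name already exists (see the guide linked above)
batch_inf_role_arn = f"arn:aws:iam::{account_id}:role/AmazonNovaBedrockBatchServiceRole"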
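
If you still need to create the role, the linked guide describes a trust policy that lets Bedrock assume it and a permissions policy for the S3 locations. The helper below builds both documents; it follows our reading of that guide and is only a sketch, so verify it against the current documentation (in particular, whether your models or inference profiles need extra permissions). `build_batch_role_policies` is hypothetical; its two JSON strings could be passed to IAM's `create_role` (as `AssumeRolePolicyDocument`) and `put_role_policy` (as `PolicyDocument`), using the `account_id`, `region`, and `BUCKET_NAME` values from the setup cells.

```python
import json


def build_batch_role_policies(account_id, region, bucket_name):
    """Hypothetical helper: trust and S3 permissions policies for a Bedrock batch role."""
    trust_policy = {
        "Version": "2012-10-17",
        "Statement": [{
            "Effect": "Allow",
            "Principal": {"Service": "bedrock.amazonaws.com"},
            "Action": "sts:AssumeRole",
            # Only batch inference jobs in this account and Region may assume the role
            "Condition": {
                "StringEquals": {"aws:SourceAccount": account_id},
                "ArnEquals": {
                    "aws:SourceArn": f"arn:aws:bedrock:{region}:{account_id}:model-invocation-job/*"
                },
            },
        }],
    }
    permissions_policy = {
        "Version": "2012-10-17",
        "Statement": [{
            "Effect": "Allow",
            # Read the JSONL input, list the bucket, and write results back
            "Action": ["s3:GetObject", "s3:PutObject", "s3:ListBucket"],
            "Resource": [f"arn:aws:s3:::{bucket_name}", f"arn:aws:s3:::{bucket_name}/*"],
            "Condition": {"StringEquals": {"aws:ResourceAccount": account_id}},
        }],
    }
    return json.dumps(trust_policy), json.dumps(permissions_policy)
```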

In [None]:
# Define the list of models to use for batch inference
# Teacher and student are included for comparison; the distilled model was
# already run on its provisioned throughput endpoint with the simulator above
models = [
    "us.amazon.nova-premier-v1:0",  # Teacher model (Nova Premier, US cross-region inference profile)
    "amazon.nova-lite-v1:0",   # Student model (Nova Lite)
    "amazon.nova-micro-v1:0",  # Additional baseline (Nova Micro)
]

# Function to submit a batch inference job
def submit_batch_inference_job(model_id, input_s3_uri, output_s3_uri):
    """
    Submit a batch inference job for the specified model
    """
    # Generate a unique job name
    timestamp = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    model_short_name = model_id.split('/')[-1].split(':')[0]
    job_name = f"distillation-bench-{model_short_name}-{timestamp}"
    
    # Create the batch inference job
    response = bedrock_client.create_model_invocation_job(
        jobName=job_name,
        modelId=model_id,
        inputDataConfig={
            "s3InputDataConfig": {
                "s3Uri": input_s3_uri,
                "s3InputFormat": "JSONL"
            }
        },
        outputDataConfig={
            "s3OutputDataConfig": {
                "s3Uri": f"{output_s3_uri}{model_short_name}/"
            }
        },
        roleArn=batch_inf_role_arn
    )
    
    job_id = response['jobArn']
    print(f"Submitted batch inference job for model {model_id}")
    print(f"Job ARN: {job_id}")
    
    return job_id

# Submit batch inference jobs for each model
job_ids = []
for model in models:
    job_id = submit_batch_inference_job(model, batch_inference_s3_uri, batch_inference_output_uri)
    job_ids.append(job_id)
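
One thing to watch in `submit_batch_inference_job`: the generated name grows with the model ID. With the Nova Premier inference profile ID, `distillation-bench-us.amazon.nova-premier-v1-<timestamp>` comes to 64 characters. Our understanding is that `create_model_invocation_job` limits `jobName` to 63 characters; please confirm this in the API reference. If a submission fails with a `ValidationException` about the name, a helper like the hypothetical `make_job_name` below keeps the timestamp intact (for uniqueness) and trims the model part to fit.

```python
from datetime import datetime


def make_job_name(model_id, prefix="distillation-bench", max_len=63, now=None):
    """Hypothetical helper: build a job name no longer than max_len characters.

    max_len=63 reflects our understanding of the jobName limit; verify it.
    The timestamp is kept intact and the model part is trimmed if needed.
    """
    now = now or datetime.now()
    timestamp = now.strftime('%Y-%m-%d-%H-%M-%S')
    model_short_name = model_id.split('/')[-1].split(':')[0]
    # Characters left for the model name after "<prefix>-" and "-<timestamp>"
    budget = max_len - len(prefix) - len(timestamp) - 2
    # Truncate, then drop a trailing '-' or '.' left by the cut
    model_part = model_short_name[:max(budget, 0)].rstrip('-.')
    return f"{prefix}-{model_part}-{timestamp}"
```

To use it, replace the `timestamp` and `job_name` lines in `submit_batch_inference_job` with `job_name = make_job_name(model_id)`, keeping `model_short_name` for the output path.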


## 3. Monitor Batch Inference Jobs
🕐 It's important to remember that batch inference jobs can take many hours to complete, in exchange for lower inference pricing. Expect roughly 12-24 hours, so come back to this notebook once the jobs have completed. Alternatively, you can run the batch simulator above against Nova on-demand inference to get results sooner, at on-demand pricing.

Let's check the status of our jobs:

In [None]:
# Function to check the status of a batch inference job
def check_job_status(job_id):
    """
    Check the status of a batch inference job
    """
    response = bedrock_client.get_model_invocation_job(jobIdentifier=job_id)
    status = response['status']
    model_id = response['modelId']
    
    print(f"Model: {model_id}")
    print(f"Status: {status}")
    
    # Bedrock reports statuses in PascalCase (e.g. 'InProgress', 'Completed', 'Failed')
    if status in ('Completed', 'PartiallyCompleted'):
        print(f"Output location: {response['outputDataConfig']['s3OutputDataConfig']['s3Uri']}")
    elif status == 'Failed':
        # The failure reason is returned in the 'message' field
        print(f"Failure reason: {response.get('message', 'Unknown')}")
    
    return status

# Check the status of all batch inference jobs
for job_id in job_ids:
    status = check_job_status(job_id)
    print("---")
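
Rather than re-running the cell above by hand, you can poll until every job reaches a terminal state. A minimal sketch: `wait_for_jobs` is a hypothetical helper, and the set of terminal statuses is our reading of the `get_model_invocation_job` status values. The client and sleep function are parameters so the loop can be exercised without AWS. Given the multi-hour runtimes mentioned above, a long poll interval (here 10 minutes) keeps API calls to a minimum.

```python
import time

# Assumed terminal states for Bedrock batch inference jobs
TERMINAL_STATUSES = {"Completed", "PartiallyCompleted", "Failed", "Stopped", "Expired"}


def wait_for_jobs(client, job_arns, poll_seconds=600, sleep=time.sleep):
    """Hypothetical helper: poll until every job is terminal; return {arn: status}."""
    final = {}
    pending = list(job_arns)
    while pending:
        still_running = []
        for arn in pending:
            status = client.get_model_invocation_job(jobIdentifier=arn)["status"]
            if status in TERMINAL_STATUSES:
                final[arn] = status
            else:
                still_running.append(arn)
        pending = still_running
        if pending:
            print(f"{len(pending)} job(s) still running; checking again in {poll_seconds}s")
            sleep(poll_seconds)
    return final
```

For example, `final_statuses = wait_for_jobs(bedrock_client, job_ids)` blocks until all three jobs finish, then returns each job's final status.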

## 4. Retrieve and Prepare Results for Evaluation

### Batch Results Downloader
Function: `download_batch_results`

This function retrieves the output files of an Amazon Bedrock batch inference job from S3.

**What it does:**
* Retrieves the job details from Amazon Bedrock to locate the job's S3 output location
* Lists all objects under the job's output prefix
* Downloads only result files (keys ending in `.jsonl.out`), skipping `manifest.json.out`
* Names each local file after the model ID and the original file name, dropping the `.out` extension
* Logs S3 and Bedrock client errors and re-raises them

In [None]:
# download outputs from jobs
from urllib.parse import urlparse
def download_batch_results(job_id, target_directory=None):
    """
    Download batch job results from S3 to a local directory.
    
    Args:
        job_id (str): The identifier for the Bedrock model invocation job
        target_directory (str, optional): Directory to save the downloaded file.
                                          Defaults to current directory if None.
    
    Returns:
        list: Paths to the downloaded files
    """
    try:
        # Get job details to find the S3 output location
        response = bedrock_client.get_model_invocation_job(jobIdentifier=job_id)
        output_s3_prefix = response['outputDataConfig']['s3OutputDataConfig']['s3Uri']
        job_prefix = job_id.split('/')[-1]
        model_id = response['modelId'].split('/')[-1].replace('.','-').replace(':', '-')
        output_location = f"{output_s3_prefix}{job_prefix}/"
        
        # Parse the S3 URI
        parsed_uri = urlparse(output_location)
        bucket_name = parsed_uri.netloc
        s3_prefix = parsed_uri.path.lstrip('/')
        
        # Set target directory if not provided
        if target_directory is None:
            target_directory = os.getcwd()
        
        # Create target directory if it doesn't exist
        os.makedirs(target_directory, exist_ok=True)
        
        # List objects in the output location
        s3_client = boto3.client('s3')
        downloaded_files = []
        
        paginator = s3_client.get_paginator('list_objects_v2')
        pages = paginator.paginate(Bucket=bucket_name, Prefix=s3_prefix)
        
        for page in pages:
            if 'Contents' not in page:
                continue
                
            for obj in page['Contents']:
                key = obj['Key']
                
                # Download only files ending with .out and not manifest.json.out
                if key.endswith('jsonl.out') and not key.endswith('manifest.json.out'):
                    # Create the output filename by removing .out extension
                    filename = os.path.basename(key)
                    output_filename = f"{model_id}-{filename[:-4]}" if filename.endswith('.out') else filename
                    local_file_path = os.path.join(target_directory, output_filename)
                    
                    print(f"Downloading {key} to {local_file_path}")
                    s3_client.download_file(bucket_name, key, local_file_path)
                    downloaded_files.append(local_file_path)
        
        return downloaded_files
    
    except ClientError as e:
        print(f"Error downloading batch results: {e}")
        raise
    except Exception as e:
        print(f"Unexpected error: {e}")
        raise

In [None]:
for job in job_ids:
    download_batch_results(job_id=job, target_directory="batch_inference_results/")
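
As a quick sanity check before moving to evaluation, you can load a downloaded results file and pull out the generated text. This sketch assumes each output line carries a `recordId` and, on success, a `modelOutput` in the Nova response format (`output.message.content[*].text`); lines without a usable `modelOutput` are counted as failures. `load_batch_outputs` is a hypothetical helper, and the evaluation notebook may parse the files differently.

```python
import json


def load_batch_outputs(path):
    """Hypothetical helper: map recordId -> generated text for one results file.

    Assumes Nova-style modelOutput: {"output": {"message": {"content": [{"text": ...}]}}}.
    Returns (texts_by_record_id, list_of_failed_record_ids).
    """
    texts, failed = {}, []
    with open(path) as f:
        for line in f:
            if not line.strip():
                continue
            row = json.loads(line)
            record_id = row.get("recordId")
            try:
                content = row["modelOutput"]["output"]["message"]["content"]
                # Join all text blocks; blocks without a "text" key contribute nothing
                texts[record_id] = "".join(block.get("text", "") for block in content)
            except (KeyError, TypeError, AttributeError):
                failed.append(record_id)
    return texts, failed
```

For example, loop over `glob.glob("batch_inference_results/*.jsonl")` and print `len(texts)` and `len(failed)` for each file to spot models with many failed records.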

## Conclusion and Next Steps

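In this notebook, we ran the distilled model through the batch inference simulator on its provisioned throughput endpoint, submitted Bedrock batch inference jobs for the base models, monitored them, and downloaded their results. These results are what we'll use to evaluate our distilled model's performance.
You should see the batch inference results under the `batch_inference_results/` directory.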


### Next Steps

Proceed to [04_evaluate.ipynb](04_evaluate.ipynb) to:
1. Analyze batch inference results across multiple dimensions
2. Compare performance metrics between model variants
3. Evaluate the success of the distillation process
4. Make data-driven decisions about production deployment