# Amazon Bedrock Batch Inference for Model Distillation

## Introduction

This notebook demonstrates how to use Amazon Bedrock's batch inference capabilities to process multiple inputs at once using your distilled models. Batch inference is useful when you need to process a large number of inputs efficiently without the overhead of making individual API calls.

In this notebook, we will:
1. Upload batch inference data to an S3 bucket
2. Run batch inference for multiple models: Bedrock batch inference jobs for the comparison models, and a batch inference simulator for our distilled model's provisioned throughput endpoint
3. Monitor the status of these batch inference jobs
4. Prepare for evaluation of accuracy improvements in our distilled model

The batch inference results will be used to evaluate the accuracy improvements achieved through model distillation. By running batch inference on both our distilled model (via the provisioned throughput endpoint created in the previous notebook) and other models for comparison, we can quantitatively assess the performance of our distilled model.

## Setup and Prerequisites

First, let's set up our environment and import required libraries.

In [None]:
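# Upgrade pip and boto3
%pip install --upgrade pip --quiet
%pip install --upgrade boto3 --quiet

# Restart the kernel so the upgraded boto3 is loaded.
# This JavaScript restart works in the classic Jupyter Notebook; in JupyterLab, use Kernel > Restart Kernel instead.
from IPython.display import HTML
HTML("<script>Jupyter.notebook.kernel.restart()</script>")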

In [None]:
# Load the provisioned throughput (PT) model ID and the custom model ID stored by the previous notebook
%store -r provisioned_model_id
%store -r custom_model_id

In [1]:
import json
import sys
import os
import time

current_dir = os.getcwd()
parent_dir = os.path.dirname(current_dir)
sys.path.append(parent_dir)

import boto3
from datetime import datetime
from botocore.exceptions import ClientError
from utils import create_s3_bucket

# Region used for every client in this notebook
region = 'us-east-1'
session = boto3.session.Session(region_name=region)

# Bedrock control-plane client (creates and monitors batch inference jobs)
bedrock_client = session.client(service_name='bedrock')

# Bedrock runtime client (real-time InvokeModel calls)
bedrock_runtime = session.client(service_name='bedrock-runtime')

# Account ID of the current credentials
sts_client = session.client(service_name='sts')
account_id = sts_client.get_caller_identity()['Account']

# Define bucket and prefixes (same bucket as in the distillation notebook; adjust if yours differs)
bucket_name = f"sample-data-{region}-{account_id}-1"
data_prefix = 'citations_distillation'  # Same prefix used in the distillation notebook
batch_inference_prefix = f"{data_prefix}/batch_inference"  # New prefix for batch inference

## 1. Upload Batch Inference Data to S3

First, we'll upload our batch inference data file to the S3 bucket with the batch inference prefix.
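Bedrock batch inference reads a JSONL file in which each line is one JSON object holding a `modelInput` (the same request body you would send to `InvokeModel` for the target model), usually alongside a `recordId`. Before uploading, a quick local check can catch malformed lines. The sketch below assumes a Nova-style `messages`/`inferenceConfig` request body; `make_batch_record` and `validate_batch_lines` are hypothetical helpers, not part of this notebook's `utils`.

```python
import json

def make_batch_record(record_id, prompt, max_tokens=1024):
    """Build one batch record with a Nova-style request body (assumed format)."""
    return {
        "recordId": record_id,
        "modelInput": {
            "messages": [{"role": "user", "content": [{"text": prompt}]}],
            "inferenceConfig": {"maxTokens": max_tokens},
        },
    }

def validate_batch_lines(lines):
    """Return (line_number, problem) pairs for lines that are not valid batch records."""
    problems = []
    for line_number, line in enumerate(lines, start=1):
        if not line.strip():
            continue  # ignore blank lines
        try:
            record = json.loads(line)
        except json.JSONDecodeError as exc:
            problems.append((line_number, f"invalid JSON: {exc.msg}"))
            continue
        if not isinstance(record, dict) or "modelInput" not in record:
            problems.append((line_number, "missing 'modelInput'"))
    return problems

# Example: two well-formed records followed by a truncated line
sample_lines = [json.dumps(make_batch_record(f"REC{n:08d}", f"Question {n}")) for n in range(2)]
sample_lines.append('{"recordId": "truncated"')
print(validate_batch_lines(sample_lines))
```

To check the real file, open `batch_inf_data.jsonl` with a `with open(...) as f:` block and pass `f` to `validate_batch_lines`; an empty list means every line parsed and carried a `modelInput`.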

In [2]:
# Define the local path to the batch inference data file
batch_inference_file = 'batch_inf_data.jsonl'

# Upload the batch inference data to S3
def upload_batch_inference_data(bucket_name, file_name, prefix):
    """
    Upload batch inference data to S3 bucket
    """
    s3_client = boto3.client('s3')
    
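    # Check whether the bucket exists; create it only if it is missing
    try:
        s3_client.head_bucket(Bucket=bucket_name)
        print(f"Bucket {bucket_name} exists.")
    except ClientError as err:
        # head_bucket returns 404 for a missing bucket; re-raise anything else (e.g. 403 Forbidden)
        if err.response['Error']['Code'] not in ('404', 'NoSuchBucket'):
            raise
        print(f"Creating bucket {bucket_name}...")
        create_s3_bucket(bucket_name=bucket_name)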
    
    # Upload file to S3
    s3_key = f"{prefix}/{file_name}"
    s3_client.upload_file(file_name, bucket_name, s3_key)
    print(f"Uploaded {file_name} to s3://{bucket_name}/{s3_key}")
    
    return f"s3://{bucket_name}/{s3_key}"

# Upload batch inference data to S3
batch_inference_s3_uri = upload_batch_inference_data(bucket_name, batch_inference_file, batch_inference_prefix)
print(f"Batch inference data uploaded to: {batch_inference_s3_uri}")

# Define the output location for batch inference results
batch_inference_output_prefix = f"{batch_inference_prefix}/outputs"
batch_inference_output_uri = f"s3://{bucket_name}/{batch_inference_output_prefix}/"

Bucket sample-data-us-east-1-228707323172-1 exists.
Uploaded batch_inf_data.jsonl to s3://sample-data-us-east-1-228707323172-1/citations_distillation/batch_inference/batch_inf_data.jsonl
Batch inference data uploaded to: s3://sample-data-us-east-1-228707323172-1/citations_distillation/batch_inference/batch_inf_data.jsonl


## 2. Submit Batch Inference Jobs
First, we'll run the batch inference simulator against the Provisioned Throughput endpoint that hosts our distilled model. Replace the `--model` placeholder with your provisioned throughput model ARN.

Then, we'll define a list of models and submit a Bedrock batch inference job for each one.

In [None]:
#  !python3 batch_inference_simulator.py --input distilling_for_citations/batch_inf_data.jsonl --output distilling_for_citations/batch_inference_results/distilled_results.jsonl  --model "<provisioned throughput endpoint ARN>"
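The simulator script is not shown in this notebook. As a rough sketch of what such a script could do, assuming it sends each record's `modelInput` to the provisioned throughput endpoint through the runtime `InvokeModel` API and keeps one result per record, the functions below separate the loop from the Bedrock call so the loop can be tested without AWS access. `simulate_batch` and `make_bedrock_invoke_fn` are hypothetical names, and the output record shape only mimics Bedrock batch output.

```python
import json

def simulate_batch(input_lines, invoke_fn):
    """Send each record's modelInput through invoke_fn and collect output records.

    invoke_fn takes a modelInput dict and returns the parsed model response.
    Output records loosely mirror Bedrock batch output (recordId, modelInput,
    modelOutput); the "error" field used for failed calls is an assumption.
    """
    results = []
    for line in input_lines:
        if not line.strip():
            continue  # ignore blank lines
        record = json.loads(line)
        result = {"recordId": record.get("recordId"), "modelInput": record["modelInput"]}
        try:
            result["modelOutput"] = invoke_fn(record["modelInput"])
        except Exception as exc:  # record the failure and keep going
            result["error"] = str(exc)
        results.append(result)
    return results

def make_bedrock_invoke_fn(runtime_client, model_id):
    """Wrap bedrock-runtime InvokeModel for one model ID or provisioned model ARN."""
    def invoke(model_input):
        response = runtime_client.invoke_model(
            modelId=model_id,
            body=json.dumps(model_input),
            contentType="application/json",
            accept="application/json",
        )
        return json.loads(response["body"].read())
    return invoke

# Usage in this notebook (not run here), assuming provisioned_model_id holds the provisioned model ARN:
# with open('batch_inf_data.jsonl') as f:
#     results = simulate_batch(f, make_bedrock_invoke_fn(bedrock_runtime, provisioned_model_id))
```

Calls run one at a time here; a real simulator would likely add retries for throttling and some concurrency sized to the endpoint's provisioned model units.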

In [None]:
# Models to run through Bedrock batch inference jobs.
# The distilled model (provisioned throughput) runs through the simulator above; uncomment the teacher model to include it.
models = [
    # "us.amazon.nova-premier-v1:0",  # Teacher model (Nova Premier)
    "amazon.nova-lite-v1:0:300k",   # Student model (Nova Lite)
]

# Function to submit a batch inference job
def submit_batch_inference_job(model_id, input_s3_uri, output_s3_uri):
    """
    Submit a batch inference job for the specified model
    """
    # Generate a unique job name
    timestamp = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    model_short_name = model_id.split('/')[-1].split(':')[0]
    job_name = f"batch-inference-{model_short_name}-{timestamp}"
    
    # Create the batch inference job
    response = bedrock_client.create_model_invocation_job(
        jobName=job_name,
        modelId=model_id,
        inputDataConfig={
            "s3InputDataConfig": {
                "s3Uri": input_s3_uri,
                "s3InputFormat": "JSONL"
            }
        },
        outputDataConfig={
            "s3OutputDataConfig": {
                "s3Uri": f"{output_s3_uri}{model_short_name}/"
            }
        },
        roleArn=f"arn:aws:iam::{account_id}:role/service-role/batch-inf-service-role" # TODO: create this IAM role; Bedrock must be able to assume it and read/write the input and output S3 locations
    )
    
    job_id = response['jobArn']
    print(f"Submitted batch inference job for model {model_id}")
    print(f"Job ARN: {job_id}")
    
    return job_id

# Submit batch inference jobs for each model
job_ids = []
for model in models:
    job_id = submit_batch_inference_job(model, batch_inference_s3_uri, batch_inference_output_uri)
    job_ids.append(job_id)

Submitted batch inference job for model amazon.nova-lite-v1:0:300k
Job ARN: arn:aws:bedrock:us-east-1:228707323172:model-invocation-job/99sm7iehijui
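`submit_batch_inference_job` builds the job name from a fixed prefix, a short model name, and a timestamp. The `CreateModelInvocationJob` API caps `jobName` at 63 characters (worth confirming in the API reference), and longer identifiers such as the commented-out cross-region Nova Premier profile leave little headroom. Below is a hedged sketch of a length-safe variant that uses the same short-name rule plus truncation; `build_job_name` is a hypothetical helper.

```python
import re
from datetime import datetime

MAX_JOB_NAME_LEN = 63  # jobName limit for CreateModelInvocationJob (confirm in the API reference)

def build_job_name(model_id, now=None, prefix="batch-inference"):
    """Build '<prefix>-<short model name>-<timestamp>' within MAX_JOB_NAME_LEN characters."""
    timestamp = (now or datetime.now()).strftime("%Y-%m-%d-%H-%M-%S")
    # Same rule as submit_batch_inference_job: last ARN path segment, version suffix dropped
    short_name = model_id.split("/")[-1].split(":")[0]
    # Conservatively replace anything other than letters, digits, '.', '+', '-' with '-'
    short_name = re.sub(r"[^A-Za-z0-9.+-]", "-", short_name)
    budget = MAX_JOB_NAME_LEN - len(prefix) - len(timestamp) - 2  # two '-' separators
    return f"{prefix}-{short_name[:budget]}-{timestamp}"

print(build_job_name("us.amazon.nova-premier-v1:0", now=datetime(2025, 1, 2, 3, 4, 5)))
# → batch-inference-us.amazon.nova-premier-v1-2025-01-02-03-04-05
```

Truncating only the model part keeps the timestamp intact, so names stay unique per submission even when two long model IDs share a prefix.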


## 3. Monitor Batch Inference Jobs

Next, we'll monitor the status of the batch inference jobs.

In [7]:
# Function to check the status of a batch inference job
def check_job_status(job_id):
    """
    Check the status of a batch inference job
    """
    response = bedrock_client.get_model_invocation_job(jobIdentifier=job_id)
    status = response['status']
    model_id = response['modelId']
    
    print(f"Model: {model_id}")
    print(f"Status: {status}")
    
    # Bedrock reports statuses in mixed case, e.g. 'Submitted', 'InProgress', 'Completed'
    if status in ('Completed', 'PartiallyCompleted'):
        print(f"Output location: {response['outputDataConfig']['s3OutputDataConfig']['s3Uri']}")
    elif status == 'Failed':
        print(f"Failure reason: {response.get('message', 'Unknown')}")
    
    return status

# Check the status of all batch inference jobs
for job_id in job_ids:
    status = check_job_status(job_id)
    print("---")

Model: arn:aws:bedrock:us-east-1::foundation-model/amazon.nova-lite-v1:0
Status: Submitted
---


## 4. Wait for Jobs to Complete

If you want to wait for all jobs to complete before proceeding, you can use the following cell.

In [None]:
# Wait for all jobs to complete
def wait_for_jobs_completion(job_ids, check_interval=60, max_wait_time=3600):
    """
    Wait for all batch inference jobs to complete
    """
    start_time = time.time()
    all_completed = False
    
    while not all_completed and (time.time() - start_time) < max_wait_time:
        print(f"Checking job status at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
        
        statuses = []
        for job_id in job_ids:
            # Job status comes from the Bedrock control-plane client, not bedrock-runtime
            response = bedrock_client.get_model_invocation_job(jobIdentifier=job_id)
            status = response['status']
            model_id = response['modelId']
            
            print(f"Model: {model_id} - Status: {status}")
            statuses.append(status)
        
        # Check whether every job has reached a terminal state (statuses are mixed case)
        all_completed = all(status in ['Completed', 'PartiallyCompleted', 'Failed', 'Stopped', 'Expired'] for status in statuses)
        
        if not all_completed:
            print(f"Waiting {check_interval} seconds for next check...")
            time.sleep(check_interval)
    
    if all_completed:
        print("All batch inference jobs have completed.")
    else:
        print(f"Maximum wait time of {max_wait_time} seconds exceeded.")

# Uncomment the following line to wait for all jobs to complete
# wait_for_jobs_completion(job_ids)
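The loop above polls every job at a fixed interval and gives up after an hour, while batch jobs can run much longer. A variant sketch with capped exponential backoff follows. It assumes the terminal status values `Completed`, `PartiallyCompleted`, `Failed`, `Stopped`, and `Expired`, and takes the status lookup, sleep, and clock as parameters so it can be exercised with fakes; `poll_until_done` is a hypothetical name.

```python
import time

# Terminal states for Bedrock batch jobs; the API reports them in mixed case
TERMINAL_STATUSES = {"Completed", "PartiallyCompleted", "Failed", "Stopped", "Expired"}

def poll_until_done(job_ids, get_status, initial_interval=30, max_interval=600,
                    max_wait_time=24 * 3600, sleep_fn=time.sleep, clock=time.monotonic):
    """Poll get_status(job_id) with capped exponential backoff.

    Returns a dict of job_id -> last observed status, either once every job is
    terminal or when max_wait_time seconds have elapsed.
    """
    start = clock()
    interval = initial_interval
    while True:
        statuses = {job_id: get_status(job_id) for job_id in job_ids}
        if all(status in TERMINAL_STATUSES for status in statuses.values()):
            return statuses
        if clock() - start >= max_wait_time:
            return statuses
        sleep_fn(interval)
        interval = min(interval * 2, max_interval)  # double the wait, up to the cap

# Usage in this notebook (not run here):
# final = poll_until_done(job_ids, lambda j: bedrock_client.get_model_invocation_job(jobIdentifier=j)["status"])
```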

## 5. Retrieve and Prepare Results for Evaluation

Once the batch inference jobs are complete, you can retrieve the results and prepare them for evaluation in the next notebook (04_evaluate.ipynb).
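To find a job's results, it helps to know where Bedrock writes them. Assuming the usual layout, in which each job writes into a folder named after its job ID (the last segment of the job ARN) under the configured output URI, with `.out` appended to the input file name, the hypothetical helpers below compute the expected results URI. Verify the layout by listing the output prefix before relying on it.

```python
def job_id_from_arn(job_arn):
    """Return the job ID, i.e. the last path segment of a model-invocation-job ARN."""
    return job_arn.rsplit("/", 1)[-1]

def expected_output_uri(output_prefix_uri, model_short_name, job_arn, input_file_name):
    """Assumed results URI: <output prefix>/<model short name>/<job ID>/<input file>.out"""
    base = output_prefix_uri.rstrip("/")
    return f"{base}/{model_short_name}/{job_id_from_arn(job_arn)}/{input_file_name}.out"

print(job_id_from_arn("arn:aws:bedrock:us-east-1:228707323172:model-invocation-job/99sm7iehijui"))
# → 99sm7iehijui
```

In this notebook, the Nova Lite results would then be expected at `expected_output_uri(batch_inference_output_uri, "amazon.nova-lite-v1", job_ids[0], batch_inference_file)`, matching the per-model subfolder that `submit_batch_inference_job` appends to the output URI.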

The batch inference results will be crucial for evaluating the performance of our distilled model compared to the teacher and student models. We'll analyze metrics such as:

1. **Accuracy**: How well does the distilled model match the expected outputs?
2. **Consistency**: Does the distilled model produce consistent results across similar inputs?
3. **Efficiency**: How does the performance compare to the computational resources required?
4. **Specific Task Performance**: For citation generation, we'll evaluate the quality and accuracy of citations produced.
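As a minimal illustration of the accuracy metric, the sketch below scores a results file by exact match against reference answers keyed by `recordId`. Exact match is a deliberately crude stand-in; citation quality usually needs a more forgiving comparison, and the evaluation in 04_evaluate.ipynb may use something different. The Nova-style response path `output.message.content[0].text` and the `references` mapping are assumptions, and `exact_match_accuracy` is a hypothetical helper.

```python
import json

def extract_text(model_output):
    """Pull the first text block from a Nova-style response (assumed shape)."""
    try:
        return model_output["output"]["message"]["content"][0]["text"]
    except (KeyError, IndexError, TypeError):
        return None

def exact_match_accuracy(output_lines, references):
    """Fraction of referenced records whose generated text equals the reference.

    output_lines: JSONL lines from a results file; references: recordId -> expected text.
    Records without a reference are skipped; failed or unparseable records count as misses.
    """
    scored = correct = 0
    for line in output_lines:
        if not line.strip():
            continue
        record = json.loads(line)
        expected = references.get(record.get("recordId"))
        if expected is None:
            continue
        scored += 1
        text = extract_text(record.get("modelOutput"))
        if text is not None and text.strip() == expected.strip():
            correct += 1
    return correct / scored if scored else 0.0
```

Counting failed records as misses, rather than dropping them, keeps a model that errors often from looking more accurate than it is.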

In [None]:
# Function to list batch inference output files
def list_batch_inference_outputs(bucket_name, prefix):
    """
    List the batch inference output files in the S3 bucket
    """
    s3_client = boto3.client('s3')
    # Paginate so prefixes with more than 1,000 objects are fully listed
    paginator = s3_client.get_paginator('list_objects_v2')
    found = False
    for page in paginator.paginate(Bucket=bucket_name, Prefix=prefix):
        for obj in page.get('Contents', []):
            found = True
            print(f"s3://{bucket_name}/{obj['Key']}")
    
    if not found:
        print(f"No objects found in s3://{bucket_name}/{prefix}")

# List the batch inference output files
# Uncomment the following line to list the output files
# list_batch_inference_outputs(bucket_name, batch_inference_output_prefix)
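Once the outputs exist, the per-record results can be streamed straight from S3. This sketch assumes results files end in `.out` and skips `manifest.json.out`, treating it as a job-level summary; it takes an S3 client as a parameter (any boto3 S3 client works) so it can be tested with a fake. `read_batch_outputs` is a hypothetical helper.

```python
import json

def read_batch_outputs(s3_client, bucket_name, prefix):
    """Yield parsed records from every results (*.out) file under prefix.

    Skips manifest.json.out (assumed to be a job summary, not per-record results)
    and uses the list_objects_v2 paginator so large prefixes are fully covered.
    """
    paginator = s3_client.get_paginator("list_objects_v2")
    for page in paginator.paginate(Bucket=bucket_name, Prefix=prefix):
        for obj in page.get("Contents", []):
            key = obj["Key"]
            if not key.endswith(".out") or key.endswith("manifest.json.out"):
                continue
            body = s3_client.get_object(Bucket=bucket_name, Key=key)["Body"].read()
            for line in body.decode("utf-8").splitlines():
                if line.strip():
                    yield json.loads(line)

# Usage in this notebook (not run here):
# records = list(read_batch_outputs(boto3.client('s3'), bucket_name, batch_inference_output_prefix))
```

Because it is a generator, large results files are parsed record by record instead of building one big list up front.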

## Conclusion

In this notebook, we've demonstrated how to use Amazon Bedrock's batch inference capabilities to process multiple inputs at once using different models, including our provisioned throughput endpoint for the distilled model. This approach allows us to efficiently compare the performance of different models on the same dataset.

Key steps covered:
1. Uploading batch inference data to an S3 bucket
2. Running batch inference for multiple models: Bedrock batch jobs for the comparison models and the simulator for our distilled model
3. Monitoring the status of batch inference jobs
4. Preparing results for evaluation

### Benefits of Using Provisioned Throughput for Batch Inference

Using a provisioned throughput endpoint for batch inference offers several advantages:

1. **Consistent Performance**: Dedicated capacity ensures consistent performance without competing for resources
2. **Cost Efficiency**: For large batch jobs, provisioned throughput can be more cost-effective than on-demand pricing
3. **Higher Throughput**: Capacity scales with the number of model units purchased, so more requests can be processed in parallel, reducing overall processing time
4. **Predictable Latency**: More stable and predictable response times

The results from this batch inference process will be used in the next notebook (04_evaluate.ipynb) to quantitatively assess the improvements achieved through model distillation. This evaluation will help determine if the distilled model meets the performance requirements while providing the efficiency benefits of a smaller model.