# Bedrock Batch Inference Introduction
In this notebook we take a quick look at the [Bedrock Batch API](https://docs.aws.amazon.com/bedrock/latest/userguide/batch-inference.html) and use it to solve an asynchronous task: summarization. Later notebooks will dive into more real-world use cases for this feature.


### Setup
We are working in a conda_python3 kernel on an ml.c5.xlarge SageMaker Classic Notebook Instance. You can also run this notebook in your own environment. Note that your execution role needs permission to read from and write to S3, permission to create the S3 bucket and IAM role used below, and Bedrock access to work with the Batch API.

### Additional Resources/Credits
- [Official Bedrock Batch Blog](https://aws.amazon.com/blogs/machine-learning/automate-amazon-bedrock-batch-inference-building-a-scalable-and-efficient-pipeline/)
- [Bedrock Batch Samples](https://github.com/aws-samples/amazon-bedrock-samples/tree/main/introduction-to-bedrock/batch_api)
- [Batch API Docs](https://docs.aws.amazon.com/bedrock/latest/userguide/batch-inference.html)

In [None]:
!pip install boto3==1.35.99

## Data Creation
Here we artificially generate a dataset from a dummy transcript, which we'll ask Claude 3 Haiku to summarize. Funnily enough, we use Haiku to generate the sample transcript as well. We then structure the payload in the format Claude expects: https://docs.aws.amazon.com/bedrock/latest/userguide/batch-inference-data.html.

In [None]:
import boto3
import json

# Initialize Bedrock Runtime Client
bedrock = boto3.client("bedrock-runtime")

# Dummy prompt
prompt = "Generate a realistic customer support call transcript about resolving a billing issue."

# Function to generate dummy data
def generate_data(input_prompt: str = prompt, model_id: str = "anthropic.claude-3-haiku-20240307-v1:0", 
                  anthropic_version: str = "bedrock-2023-05-31", max_tokens: int = 500,
                  mime_type: str = "application/json") -> str:

    response = bedrock.invoke_model(
        modelId=model_id,
        body=json.dumps({
            "anthropic_version": anthropic_version,
            "max_tokens":  max_tokens,
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            # Use the function argument, not the module-level prompt
                            "text": input_prompt
                        }
                    ]
                }
            ]
        }),
        contentType=mime_type,
        accept=mime_type
    )

    # Extract the transcript text
    transcript_text = json.loads(response["body"].read())["content"][0]["text"]
    return transcript_text

mock_transcript_text = generate_data(prompt)
print(mock_transcript_text)

### Build the JSONL Input File

In [None]:
# Generation config for claude
generation_config = {
    "anthropic_version": "bedrock-2023-05-31",
    "max_tokens": 512,
    "system": "You are a helpful assistant. Please summarize the transcript provided.",
    "temperature": 0.0,
    "top_p": 0.99,
    "top_k": 250
}

# Output file
output_file = "claude_haiku_batch_requests_summary.jsonl"

# Dummy file with 150 records
with open(output_file, "w") as f:
    for i in range(150):
        record = {
            "recordId": f"REC{i:08d}",
            "modelInput": {
                **generation_config,
                "messages": [
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "text",
                                "text": f"Summarize the following call transcript:\n\n{mock_transcript_text}"
                            }
                        ]
                    }
                ]
            }
        }
        f.write(json.dumps(record) + "\n")

print(f"Input data file created: {output_file}")
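
Before uploading, it can help to sanity-check the file. The sketch below is a hypothetical helper (`validate_batch_records` is not part of boto3 or Bedrock) that checks each line is valid JSON with a `modelInput` key and flags duplicate `recordId` values. The `min_records=100` default is an assumption about the batch inference minimum-records quota; confirm the value for your account and model.

```python
import json

def validate_batch_records(lines, min_records=100):
    """Return a list of problems found in an iterable of JSONL lines.

    min_records is an assumed quota value, not read from any AWS API.
    """
    problems = []
    seen_ids = set()
    count = 0
    for line_no, line in enumerate(lines, start=1):
        if not line.strip():
            continue  # ignore blank lines
        count += 1
        try:
            record = json.loads(line)
        except json.JSONDecodeError as exc:
            problems.append(f"line {line_no}: invalid JSON ({exc.msg})")
            continue
        if not isinstance(record, dict) or "modelInput" not in record:
            problems.append(f"line {line_no}: missing modelInput")
            continue
        record_id = record.get("recordId")
        if record_id is not None:
            if record_id in seen_ids:
                problems.append(f"line {line_no}: duplicate recordId {record_id}")
            seen_ids.add(record_id)
    if count < min_records:
        problems.append(f"only {count} records, expected at least {min_records}")
    return problems
```

In this notebook you could call it as `validate_batch_records(open(output_file))`; an empty list means none of these checks failed.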

### Upload Artifacts to S3

In [None]:
import boto3
import uuid
import os

region = "us-east-1"
file_key = f"input_data/{output_file}"

# Create bucket and upload input data jsonlines file
s3 = boto3.client("s3", region_name=region)
unique_suffix = str(uuid.uuid4())[:8]
bucket_name = f"bedrock-batch-{unique_suffix}"

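# us-east-1 rejects an explicit LocationConstraint; every other region requires one
if region == "us-east-1":
    s3.create_bucket(Bucket=bucket_name)
else:
    s3.create_bucket(
        Bucket=bucket_name,
        CreateBucketConfiguration={"LocationConstraint": region})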
print(f"Created bucket: {bucket_name}")

# Upload file
s3.upload_file(Filename=output_file, Bucket=bucket_name, Key=file_key)
input_s3_uri = f"s3://{bucket_name}/{file_key}"
print(f"Uploaded file to: {input_s3_uri}")

# Output folder for results
output_prefix = "output_results/"
s3.put_object(Bucket=bucket_name, Key=output_prefix)
output_s3_uri = f"s3://{bucket_name}/{output_prefix}"
print(f"Created output folder: {output_s3_uri}")

print(f"Input Data URI : {input_s3_uri}")
print(f"Output Results URI: {output_s3_uri}")

## Create IAM Role
The batch job runs under a service role. Make sure this role grants read and write access to the S3 bucket that holds the input and output data locations: https://docs.aws.amazon.com/bedrock/latest/userguide/batch-iam-sr.html

In [None]:
import uuid

iam = boto3.client('iam')

account_id = boto3.client('sts').get_caller_identity()['Account']
region = 'us-east-1'  # Replace with your AWS region

trust_policy = {
    "Version": "2012-10-17",
    "Statement": [
        {
            "Effect": "Allow",
            "Principal": {"Service": "bedrock.amazonaws.com"},
            "Action": "sts:AssumeRole",
            "Condition": {
                "StringEquals": {"aws:SourceAccount": account_id},
                "ArnEquals": {
                    "aws:SourceArn": f"arn:aws:bedrock:{region}:{account_id}:model-invocation-job/*"
                }
            }
        }
    ]
}

role_name = f"BedrockBatchInferenceRole-{uuid.uuid4().hex[:8]}"

response = iam.create_role(
    RoleName=role_name,
    AssumeRolePolicyDocument=json.dumps(trust_policy),
    Description="Service role for Amazon Bedrock batch inference"
)
role_arn = response['Role']['Arn']
print(f"Created role: {role_arn}")

#Attach policy
s3_policy = {
    "Version": "2012-10-17",
    "Statement": [
        {
            "Sid": "S3Access",
            "Effect": "Allow",
            "Action": [
                "s3:GetObject",
                "s3:PutObject",
                "s3:ListBucket"
            ],
            # Input and output share one bucket here, so each ARN is listed once
            "Resource": [
                f"arn:aws:s3:::{bucket_name}",
                f"arn:aws:s3:::{bucket_name}/*"
            ],
            "Condition": {
                "StringEquals": {
                    "aws:ResourceAccount": account_id
                }
            }
        }
    ]
}

iam.put_role_policy(
    RoleName=role_name,
    PolicyName="BedrockS3AccessPolicy",
    PolicyDocument=json.dumps(s3_policy)
)
print("Attached S3 access policy to the role.")

## Create Batch Job
Boto3 API Call: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_CreateModelInvocationJob.html#bedrock-CreateModelInvocationJob-request-inputDataConfig

In [None]:
inputDataConfig = {
    "s3InputDataConfig": {
        "s3Uri": input_s3_uri
    }
}

outputDataConfig = {
    "s3OutputDataConfig": {
        "s3Uri": output_s3_uri
    }
}

# client to create invocation job, different from runtime
bedrock = boto3.client('bedrock', region_name = 'us-east-1')

#model ID for the Batch Job
model_id = "anthropic.claude-3-haiku-20240307-v1:0"

# start job
response=bedrock.create_model_invocation_job(
        roleArn=role_arn,
        modelId=model_id,
        jobName=f"batch-claude-summarization-{str(uuid.uuid4())[:8]}",
        inputDataConfig=inputDataConfig,
        outputDataConfig=outputDataConfig
    )
jobArn = response.get('jobArn')

In [None]:
# Code Snippet borrowed from: https://research-it.wharton.upenn.edu/programming/using-aws-bedrocks-batch-api/
import time
print("Monitoring batch job...")
while True:
    job_status_response = bedrock.get_model_invocation_job(jobIdentifier=jobArn)
    # Possible status: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock/client/get_model_invocation_job.html
    status = job_status_response['status']
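    if status in ['InProgress', 'Initializing', 'Submitted', 'Validating', 'Scheduled', 'Stopping']:
        print(f"Job {jobArn} is {status}. Waiting for completion...")
        time.sleep(100)
    elif status == 'Completed':
        print(f"Job {jobArn} completed successfully.")
        break
    elif status == 'PartiallyCompleted':
        # Some records failed; the output files still hold the successful ones
        print(f"Job {jobArn} partially completed. Check the output for failed records.")
        break
    else:
        # Failed, Stopped, Expired, or an unrecognized status: stop polling instead of looping forever
        print(f"Job {jobArn} ended with status: {status}")
        raise RuntimeError(f"Job ended with status {status}.")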
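
Batch jobs can sit in the queue for a long time, so it may be worth bounding the wait. Here is a minimal sketch of the same polling idea as a reusable function with a timeout. The function name `wait_for_job`, the `IN_PROGRESS` set, and the default intervals are illustrative choices, not part of boto3; only `get_model_invocation_job` is a real client call. The `sleep` parameter exists so the function can be exercised without actually waiting.

```python
import time

# Statuses treated as "still running" in this sketch
IN_PROGRESS = {"Submitted", "Validating", "Scheduled", "InProgress", "Stopping"}

def wait_for_job(client, job_arn, poll_seconds=60, timeout_seconds=24 * 3600,
                 sleep=time.sleep):
    """Poll until the job leaves the in-progress statuses; return the final status."""
    waited = 0
    while True:
        status = client.get_model_invocation_job(jobIdentifier=job_arn)["status"]
        if status not in IN_PROGRESS:
            return status
        if waited >= timeout_seconds:
            raise TimeoutError(f"Job {job_arn} still {status} after {waited} seconds")
        sleep(poll_seconds)
        waited += poll_seconds
```

With the notebook's variables this would be called as `wait_for_job(bedrock, jobArn)`, and the caller decides what to do with a final status other than `Completed`.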

## Parse Output Results
You should see two files created: a manifest file with the cumulative job results, and an `.out` file that holds the individual output for each input record. Be sure to adjust the file paths below to match your own output files.

In [None]:
# Look the job up directly by ARN; the list call is paginated and may not include it
job_details = bedrock.get_model_invocation_job(jobIdentifier=jobArn)
output_results = job_details['outputDataConfig']['s3OutputDataConfig']['s3Uri']
output_results

In [None]:
!aws s3 cp {output_results} ./results --recursive #creates a results folder
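
The parsing cells in this notebook hardcode a job-specific folder name in their file paths. As an alternative, the sketch below searches the downloaded folder for the two kinds of output files. `find_batch_outputs` is a hypothetical helper, and it assumes the files keep the `manifest.json.out` and `*.jsonl.out` naming seen in this notebook.

```python
from pathlib import Path

def find_batch_outputs(results_dir="results"):
    """Return (manifest_paths, record_file_paths) found anywhere under results_dir."""
    root = Path(results_dir)
    # The manifest ends in .json.out, so it never matches the *.jsonl.out pattern
    manifests = sorted(root.rglob("manifest.json.out"))
    record_files = sorted(root.rglob("*.jsonl.out"))
    return manifests, record_files
```

If the folder holds a single job, `manifests[0]` and `record_files[0]` can stand in for the hardcoded `file_path` values.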

### Cumulative Job Metrics

In [None]:
import json

# replace with your file path
file_path = 'results/ykfptibinz1x/manifest.json.out'

with open(file_path, 'r') as f:
    data = json.load(f)
print(f"Overall metrics: {data}")

### Individual Outputs

In [None]:
# replace with your file path
file_path = 'results/ykfptibinz1x/claude_haiku_batch_requests_summary.jsonl.out'

parsed_data = []

with open(file_path, 'r') as file:
    for line in file:
        data = json.loads(line)
    
        # Input text
        input_text = ""
        try:
            input_text = data.get('modelInput', {}).get('messages', [])[0].get('content', [])[0].get('text', '')
        except (IndexError, AttributeError):
            pass
    
        # Output text
        output_text = ""
        try:
            output_text = data.get('modelOutput', {}).get('content', [])[0].get('text', '')
        except (IndexError, AttributeError):
            pass
    
        # Record ID
        record_id = data.get('recordId', '')
    
        # Store the extracted information
        parsed_data.append({
            'recordId': record_id,
            'input_text': input_text,
            'output_text': output_text
        })

for item in parsed_data:
    print(f"Record ID: {item['recordId']}")
    print("Input:", item['input_text'])
    print("Output:", item['output_text'])
    print("-" * 50)
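
With 150 records, printing every result is hard to scan, so a quick tally can help. The sketch below counts outcomes across the parsed output lines. `summarize_records` is a hypothetical helper, and it assumes that a failed record carries an `error` key in place of `modelOutput`; verify this against your own `.out` file before relying on it.

```python
from collections import Counter

def summarize_records(records):
    """Count records by outcome: success, error, or missing output."""
    outcome = Counter()
    for record in records:
        if "error" in record:
            # Assumed shape for failed records; check your own output file
            outcome["error"] += 1
        elif record.get("modelOutput"):
            outcome["success"] += 1
        else:
            outcome["missing"] += 1
    return dict(outcome)
```

To use it here, collect each `json.loads(line)` result from the output file into a list and pass that list in.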