# Amazon Bedrock Model Distillation Guide using Meta Llama models

## Introduction
Model distillation is the process of transferring knowledge from a larger, more intelligent teacher model to a smaller, faster, and more cost-efficient student model. To use Amazon Bedrock Model Distillation, first select a teacher model whose accuracy you want to approach. Then provide use-case-specific prompts as the fine-tuning dataset. You can either provide your own prompts in JSON Lines (JSONL) format or use Bedrock invocation logs. With invocation logs, Bedrock can use either prompt-response pairs or prompts alone to fine-tune the student model. Amazon Bedrock doesn’t use your data to train any other teacher or student model for public use, and only you can access the final distilled model.

Depending on the use case, Amazon Bedrock Model Distillation may use proprietary data synthesis techniques to generate high-quality responses from the teacher model. Data synthesis might increase the size of the fine-tuning dataset to a maximum of 15K prompt-response pairs. The teacher model inference involved is billed at the teacher model's on-demand inference rate. See the [Amazon Bedrock User Guide](https://docs.aws.amazon.com/bedrock/latest/userguide/model-distillation.html) to read more about model distillation, and this [AWS Blog post](https://aws.amazon.com/blogs/machine-learning/a-guide-to-amazon-bedrock-model-distillation-preview/) for a walkthrough of the distillation workflow. This notebook covers Bedrock Model Distillation using **invocation logs**, with Llama 3.1 405B as the teacher model and Llama 3.1 8B as the student model.

The guide covers essential API operations including:
- Creating and configuring distillation jobs
- Invoking a model to generate invocation logs using the Converse API
- Working with historical invocation logs in your account to create a distillation job
- Managing model provisioning and deployment
- Running inference with distilled models
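
The introduction mentions that you can supply your own prompts in JSONL format instead of invocation logs. This notebook does not use that path, but as a rough sketch, one record could be built as shown here. The `bedrock-conversation-2024` schema name and field layout are an assumption based on the Bedrock User Guide and should be checked against the current documentation. `to_distillation_record` is a hypothetical helper, not part of any SDK.

```python
import json


def to_distillation_record(prompt, system_prompt=None, response=None):
    """Build one JSONL record in the (assumed) bedrock-conversation-2024 layout."""
    record = {
        "schemaVersion": "bedrock-conversation-2024",  # assumed schema identifier
        "messages": [{"role": "user", "content": [{"text": prompt}]}],
    }
    if system_prompt:
        record["system"] = [{"text": system_prompt}]
    if response is not None:
        # Optional: include a known-good answer as the assistant turn
        record["messages"].append(
            {"role": "assistant", "content": [{"text": response}]}
        )
    return record


# Each record becomes exactly one line of the JSONL file
example_line = json.dumps(
    to_distillation_record("What did the author work on?", system_prompt="You are a historian.")
)
print(example_line)
```

Leaving out `response` produces a prompt-only record, which corresponds to the case where the teacher model generates the responses.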

## Best Practices and Considerations

When using model distillation:
1. Ensure your training data is diverse and representative of your use case
2. Monitor distillation metrics in the S3 output location
3. Evaluate the distilled model's performance against your requirements
4. Consider cost-performance tradeoffs when selecting model units for deployment

The distilled model should provide faster responses and lower costs while maintaining acceptable performance for your specific use case.
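
One lightweight way to approach point 3 is to score a model's answer against a reference answer. The sketch below uses token-level F1, which is a swapped-in, generic text-overlap metric and not something Bedrock computes for you. `token_f1` is a hypothetical helper name.

```python
import re
from collections import Counter


def token_f1(prediction, reference):
    """Token-overlap F1 between a model answer and a reference answer (0.0 to 1.0)."""
    pred_tokens = re.findall(r"\w+", prediction.lower())
    ref_tokens = re.findall(r"\w+", reference.lower())
    if not pred_tokens or not ref_tokens:
        return 0.0
    # Multiset intersection counts each shared token at most as often as it appears in both
    overlap = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)


print(token_f1("He studied painting", "He studied painting in Florence"))
```

Averaging this score over a held-out set of prompts for the student, the distilled model, and the teacher gives a coarse first comparison. It rewards word overlap only, so treat it as a complement to human review or an LLM-based judge.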

## Setup and Prerequisites

Before starting with model distillation, ensure you have the following:

#### Required AWS Resources:
- An AWS account with appropriate permissions
- Amazon Bedrock access enabled in your preferred region
- An S3 bucket for storing invocation logs 
- An S3 bucket to store output metrics
- Sufficient service quota to use Provisioned Throughput in Bedrock
- An IAM role with the following permissions:

IAM Policy:
```json
{
    "Version": "2012-10-17",
    "Statement": [
        {
            "Effect": "Allow",
            "Action": [
                "s3:GetObject",
                "s3:PutObject",
                "s3:ListBucket"
            ],
            "Resource": [
                "arn:aws:s3:::YOUR_DISTILLATION_OUTPUT_BUCKET",
                "arn:aws:s3:::YOUR_DISTILLATION_OUTPUT_BUCKET/*",
                "arn:aws:s3:::YOUR_INVOCATION_LOG_BUCKET",
                "arn:aws:s3:::YOUR_INVOCATION_LOG_BUCKET/*"
            ]
        },
        {
            "Effect": "Allow",
            "Action": [
                "bedrock:CreateModelCustomizationJob",
                "bedrock:GetModelCustomizationJob",
                "bedrock:ListModelCustomizationJobs",
                "bedrock:StopModelCustomizationJob"
            ],
            "Resource": "arn:aws:bedrock:YOUR_REGION:YOUR_ACCOUNT_ID:model-customization-job/*"
        }
    ]
}
```

Trust Relationship:
```json
{
    "Version": "2012-10-17",
    "Statement": [
        {
            "Effect": "Allow",
            "Principal": {
                "Service": [
                    "bedrock.amazonaws.com"
                ]
            },
            "Action": "sts:AssumeRole",
            "Condition": {
                "StringEquals": {
                    "aws:SourceAccount": "YOUR_ACCOUNT_ID"
                },
                "ArnLike": {
                    "aws:SourceArn": "arn:aws:bedrock:YOUR_REGION:YOUR_ACCOUNT_ID:model-customization-job/*"
                }
            }
        }
    ]
}
```
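
If you prefer to generate these two documents in code rather than editing the placeholders by hand, a minimal sketch could look like the following. The helper names `build_trust_policy` and `build_s3_statement` are hypothetical, and the example account ID and bucket names are placeholders.

```python
import json


def build_trust_policy(account_id, region):
    """Trust relationship that lets Bedrock customization jobs assume the role."""
    job_arn_pattern = f"arn:aws:bedrock:{region}:{account_id}:model-customization-job/*"
    return {
        "Version": "2012-10-17",
        "Statement": [
            {
                "Effect": "Allow",
                "Principal": {"Service": ["bedrock.amazonaws.com"]},
                "Action": "sts:AssumeRole",
                "Condition": {
                    "StringEquals": {"aws:SourceAccount": account_id},
                    "ArnLike": {"aws:SourceArn": job_arn_pattern},
                },
            }
        ],
    }


def build_s3_statement(bucket_names):
    """S3 statement covering each bucket and every object inside it."""
    resources = []
    for name in bucket_names:
        resources += [f"arn:aws:s3:::{name}", f"arn:aws:s3:::{name}/*"]
    return {
        "Effect": "Allow",
        "Action": ["s3:GetObject", "s3:PutObject", "s3:ListBucket"],
        "Resource": resources,
    }


# Placeholder values for illustration only
example_trust = build_trust_policy("123456789012", "us-west-2")
example_s3 = build_s3_statement(["my-output-bucket", "my-log-bucket"])
print(json.dumps(example_trust, indent=2))
```

The resulting dictionaries can be serialized with `json.dumps` and passed to the IAM `create_role` (as `AssumeRolePolicyDocument`) and `put_role_policy` (as `PolicyDocument`) calls.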




#### Dataset:
In this notebook we use a modified version of the [Paul Graham Essay dataset](https://github.com/run-llama/llama_index/tree/main/llama-datasets/paul_graham_essay) from [LlamaIndex](https://docs.llamaindex.ai/en/latest/). LlamaIndex is a framework for building LLM-powered agents, and Retrieval Augmented Generation (RAG) is a core technique for building data-backed LLM applications. LlamaIndex publishes datasets consisting of queries, reference contexts, and responses. Open the file named paul_graham_rag_dataset.json (the file loaded by the code below) to inspect the queries, contexts, and responses. Open paul_graham_source.txt to get a feel for the source text about [Paul Graham](https://www.paulgraham.com/articles.html). For your own distillation jobs, you can use datasets that are specific to your use case.

We also cover the creation of a RAG dataset using a Llama model. Readers can use this technique to generate their own datasets for distillation using Amazon Bedrock.

First, let's set up our environment and import required libraries.

In [None]:
# upgrade boto3 
%pip install --upgrade pip --quiet
%pip install boto3 --upgrade --quiet

In [None]:
# restart kernel
from IPython.core.display import HTML
HTML("<script>Jupyter.notebook.kernel.restart()</script>")

In [None]:
import json
import boto3
from datetime import datetime

# Create Bedrock client
bedrock_client = boto3.client(service_name="bedrock")

# Create runtime client for inference
bedrock_runtime = boto3.client(service_name='bedrock-runtime')

# Region and accountID
session = boto3.session.Session()
region = session.region_name
sts_client = session.client('sts')
account_id = sts_client.get_caller_identity()['Account']

Let's first test the response from the student model before it is distilled. We take a question from one of the prompts in the paul_graham_rag_dataset.json file.

In [None]:
from botocore.exceptions import ClientError

# Specify the model ID as Meta Llama 3.1 8B (the student model before distillation)
model_id = "meta.llama3-1-8b-instruct-v1:0"

# Define the message with the prompt
message = {
    "role": "user",
    "content": [
        {"text": "In the essay, the author discusses his initial interest in AI and his eventual disillusionment with it. According to the author, what were the two main influences that initially drew him to AI and what realization led him to believe that the approach to AI during his time was a hoax?"}
    ]
}

try:
    # Keep the API call inside the try block so that invocation errors are caught
    response = bedrock_runtime.converse(
        modelId=model_id,
        messages=[message],
        inferenceConfig={
            "maxTokens": 150,
            "temperature": 0,
            "topP": 0.5
        }
    )

    # Extract the 'text' content from the response
    text_content = response['output']['message']['content'][0]['text']

    # Replace literal '\n' escape sequences with actual newline characters
    formatted_text = text_content.replace('\\n', '\n')

    # Print the formatted text
    print(formatted_text)

except ClientError as e:
    print(f"ERROR: Unable to invoke '{model_id}'. Reason: {e}")


As expected, the response does not reflect the information in the Paul Graham essay, because the model was given no context. Compare the response above with the reference answer from the paul_graham_rag_dataset.json file.

In [None]:
with open('paul_graham_rag_dataset.json', 'r') as file:
    data = json.load(file)

print(data["examples"][1]["reference_answer"])

#### Model selection
When selecting models for distillation, consider:
1. Performance targets
2. Latency requirements
3. Total Cost of Ownership

In [None]:
# Setup teacher and student model pairs
teacher_model_id = "meta.llama3-1-405b-instruct-v1:0"
student_model = "meta.llama3-1-8b-instruct-v1:0:128k"

### Step 1. Configure Model Invocation Logging using the API

In this example, we store logs only in an S3 bucket, but you can optionally enable logging to CloudWatch as well.

In [None]:
# S3 bucket and prefix to store invocation logs
s3_bucket_for_log = ""
prefix_for_log = "" # Optional

In [None]:
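def setup_s3_bucket_policy(bucket_name, prefix, account_id, region):
    s3_client = boto3.client('s3')

    # The prefix is optional: avoid a double slash in the ARN when it is empty
    log_root = f"{prefix.strip('/')}/AWSLogs" if prefix else "AWSLogs"

    # Note: put_bucket_policy replaces any policy already attached to the bucket
    bucket_policy = {
        "Version": "2012-10-17",
        "Statement": [
            {
                "Sid": "AmazonBedrockLogsWrite",
                "Effect": "Allow",
                "Principal": {
                    "Service": "bedrock.amazonaws.com"
                },
                "Action": [
                    "s3:PutObject"
                ],
                "Resource": [
                     f"arn:aws:s3:::{bucket_name}/{log_root}/{account_id}/BedrockModelInvocationLogs/*"
                ],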
                "Condition": {
                    "StringEquals": {
                        "aws:SourceAccount": account_id
                    },
                    "ArnLike": {
                        "aws:SourceArn": f"arn:aws:bedrock:{region}:{account_id}:*"
                    }
                }
            }
        ]
    }
    
    bucket_policy_string = json.dumps(bucket_policy)
    
    try:
        response = s3_client.put_bucket_policy(
            Bucket=bucket_name,
            Policy=bucket_policy_string
        )
        print("Successfully set bucket policy")
        return True
    except Exception as e:
        print(f"Error setting bucket policy: {str(e)}")
        return False

In [None]:
# Setup bucket policy
setup_s3_bucket_policy(s3_bucket_for_log, prefix_for_log, account_id, region)

# Setup logging configuration
bedrock_client.put_model_invocation_logging_configuration(
    loggingConfig={
        's3Config': {
            'bucketName': s3_bucket_for_log,
            'keyPrefix': prefix_for_log
        },
        'textDataDeliveryEnabled': True,
        'imageDataDeliveryEnabled': True,
        'embeddingDataDeliveryEnabled': True
    }
)
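
Before generating logs, it can be worth confirming that the configuration was actually stored. The sketch below reads it back with `get_model_invocation_logging_configuration`. `s3_logging_target` is a hypothetical helper, and the stub class only stands in for the real `bedrock_client` so the snippet runs without AWS credentials.

```python
def s3_logging_target(bedrock_client):
    """Return (bucket, prefix) of the active S3 logging target, or None if unset."""
    response = bedrock_client.get_model_invocation_logging_configuration()
    s3_config = response.get("loggingConfig", {}).get("s3Config")
    if not s3_config:
        return None
    return s3_config.get("bucketName"), s3_config.get("keyPrefix", "")


class _StubBedrock:
    """Stand-in for the real client, used here only to demonstrate the helper."""

    def get_model_invocation_logging_configuration(self):
        return {"loggingConfig": {"s3Config": {"bucketName": "my-log-bucket", "keyPrefix": "logs"}}}


demo_target = s3_logging_target(_StubBedrock())
print(demo_target)
```

In the notebook you would call `s3_logging_target(bedrock_client)` and compare the result with `s3_bucket_for_log` and `prefix_for_log`.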

### Step 2. Prepare the input dataset in JSONL format

In [None]:
# Specify that the JSONL fine-tuning dataset will be named paul_graham.jsonl
output_jsonl = "paul_graham.jsonl"

# Load the original JSON file
with open('paul_graham_rag_dataset.json', 'r') as f:
    data = json.load(f)['examples']

with open(output_jsonl, mode="w", encoding="utf-8") as jsonl_file:
    for row in data:
        # Construct the JSON object for each line
        json_line = {
            "prompt": (
                "You are a historian charged with answering questions about Paul Graham.\n"
                "Given the context below, answer the following question.\n\n"
                f"<context>\nDocument 1: {row.get('reference_contexts', '')}\n</context>\n\n"
                f"<question>{row.get('query', '')}</question>"
            ),
            "completion": row.get('reference_answer', '')
        }
        # Write the JSON object as a line in the JSONL file
        jsonl_file.write(json.dumps(json_line) + "\n")


Note that you need [at least 100 prompt-response pairs](https://docs.aws.amazon.com/bedrock/latest/userguide/distillation-data-prep-option-2.html). paul_graham.jsonl currently has only 44 prompt-response pairs. Continue below to generate more synthetic prompt-response pairs.
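
A quick way to check the dataset against the 100-pair minimum is to count the usable lines in the JSONL file. This is a minimal sketch that assumes the `prompt` and `completion` keys written by the cell above. `count_valid_pairs` is a hypothetical helper, and the temporary sample file exists only to make the snippet self-contained.

```python
import json
import os
import tempfile

MIN_PAIRS = 100  # minimum number of prompt-response pairs stated above


def count_valid_pairs(path):
    """Count JSONL lines that carry both a non-empty prompt and completion."""
    valid = 0
    with open(path, encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                continue  # ignore blank lines
            record = json.loads(line)
            if record.get("prompt") and record.get("completion"):
                valid += 1
            else:
                print(f"Line {line_no}: missing prompt or completion")
    return valid


# Demonstration on a tiny temporary file with one good and one incomplete record
with tempfile.TemporaryDirectory() as tmp_dir:
    sample_path = os.path.join(tmp_dir, "sample.jsonl")
    with open(sample_path, "w", encoding="utf-8") as f:
        f.write(json.dumps({"prompt": "q1", "completion": "a1"}) + "\n")
        f.write(json.dumps({"prompt": "q2", "completion": ""}) + "\n")
    sample_count = count_valid_pairs(sample_path)

print(f"{sample_count} valid pairs, enough for distillation: {sample_count >= MIN_PAIRS}")
```

Running `count_valid_pairs("paul_graham.jsonl")` after the generation step further down shows whether the dataset has reached the minimum.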

In [None]:
def read_text_file(file_path, encoding='utf-8'):
    """Reads a text file and returns its content."""
    with open(file_path, 'r', encoding=encoding) as file:
        return file.read()

def split_text(text, chunk_size=1500, chunk_overlap=100, separators=None):
    """Splits text into chunks with optional overlap and custom separators."""
    if separators is None:
        separators = ["\n\n", "\n", ".", " ", ""]
    
    chunks = []
    current_chunk = ""
    
    # Split text into sentences or paragraphs first
    for sep in separators:
        if sep in text:
            parts = text.split(sep)
            for part in parts:
                if len(current_chunk) + len(part) <= chunk_size:
                    current_chunk += part + sep
                else:
                    if current_chunk:
                        chunks.append(current_chunk.strip())
                        # Handle overlap
                        overlap_start = max(0, len(current_chunk) - chunk_overlap)
                        current_chunk = current_chunk[overlap_start:] + part + sep
                    else:
                        current_chunk = part + sep
            break  # Use the first separator that matches
    
    # Add the last chunk if it exists
    if current_chunk:
        chunks.append(current_chunk.strip())
    
    return chunks

# Load the text file
text_content = read_text_file("paul_graham_source.txt")

# Split the text into chunks
essay_chunks = split_text(
    text_content,
    chunk_size=1200,
    chunk_overlap=100,
    separators=["\n\n", "\n", ".", " ", ""]
)

# Print the first few chunks for verification
for i in range(0,3):
    print(f"Chunk number {i+1}:\n" + essay_chunks[i] + '\n*****\n')



Next, we need to generate sample prompts and answers. For simplicity, we generate 67 prompt-answer pairs, one for each of the first 67 essay chunks, which brings the dataset above the 100-pair minimum. This process can be refined using the techniques in this [AWS Blog](https://aws.amazon.com/blogs/machine-learning/generate-synthetic-data-for-evaluating-rag-systems-using-amazon-bedrock/).

In [None]:
import time

# Specify the model ID
# model_id = "meta.llama3-1-405b-instruct-v1:0"
model_id = "meta.llama3-70b-instruct-v1:0"

output_jsonl = "paul_graham.jsonl"

# We append the results to our JSONL dataset
# (bounded by the number of available chunks to avoid an IndexError)
for i in range(min(67, len(essay_chunks))):
    # Modify the sleep time to avoid throttling
    time.sleep(15)
    context = essay_chunks[i]
    prompt_message = {
        "role": "user",
        "content": [
            {"text": 
            f"""
            <Instructions>
            Here is some context:
            <context>
            {context}
            </context>
        
            Your task is to generate 1 question that can be answered using the provided context, following these rules:
        
            <rules>
            1. The question should make sense to humans even when read without the given context.
            2. The question should be fully answered from the given context.
            3. The question should be framed from a part of context that contains important information. It can also be from tables, code, etc.
            4. The answer to the question should not contain any links.
            5. The question should be of moderate difficulty.
            6. The question must be reasonable and must be understood and responded by humans.
            7. Do not use phrases like 'provided context', etc. in the question.
            8. Avoid framing questions using the word "and" that can be decomposed into more than one question.
            9. The question should not contain more than 10 words, make use of abbreviations wherever possible.
            </rules>
        
            To generate the question, first identify the most important or relevant part of the context. Then frame a question around that part that satisfies all the rules above.
        
            Output only the generated question with a "?" at the end, no other text or characters.
            </Instructions>
            """
            }
        ]
    }
    
    response = bedrock_runtime.converse(
        modelId=model_id,
        messages=[prompt_message],  
        inferenceConfig={
            "maxTokens": 1000,
            "temperature": 0.2,
            "topP": .5
        }
    )
    generated_prompt = response['output']['message']['content'][0]['text']
    # Replace newline escape sequences ('\\n') with actual newline characters
    question = generated_prompt.replace('\\n', '\n')
    # Print the formatted text
    print(question)

    # Generate answer
    answer_prompt_message = {
        "role": "user",
        "content": [
            {"text": 
            f"""
            <Instructions>
            <Task>
            <role>You are an experienced QA Engineer for building large language model applications.</role>
            <task>It is your task to generate an answer to the following question <question>{question}</question> only based on the <context>{context}</context></task>
            The output should be only the answer generated from the context.
        
            <rules>
            1. Only use the given context as a source for generating the answer.
            2. Be as precise as possible with answering the question.
            3. Be concise in answering the question and only answer the question at hand rather than adding extra information.
            </rules>
        
            Only output the generated answer as a sentence. No extra characters.
            </Task>
            </Instructions>
            """
            }
        ]
    }
    
    response = bedrock_runtime.converse(
        modelId=model_id,
        messages=[answer_prompt_message],  
        inferenceConfig={
            "maxTokens": 1000,
            "temperature": 0.2,
            "topP": .5
        }
    )
    
    generated_answer = response['output']['message']['content'][0]['text']
        
    # Replace newline escape sequences ('\\n') with actual newline characters
    answer = generated_answer.replace('\\n', '\n')
        
    # Print the formatted text
    print(answer)

    with open(output_jsonl, mode="a", encoding="utf-8") as jsonl_file:
        # Construct the JSON object for each line
        json_line = {
        "prompt": (
            "You are a historian charged with answering questions about Paul Graham.\n"
            "Given the context below, answer the following question.\n\n"
            f"<context>\nDocument 1: {context}\n</context>\n\n"
            f"<question>{question}</question>"
        ),
        "completion": answer
        }
        # Write the JSON object as a line in the JSONL file
        jsonl_file.write(json.dumps(json_line) + "\n")




### Step 3. Invoke teacher model to generate logs

We're using the [Converse API](https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html) in this example, but you can also use the [InvokeModel API](https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_InvokeModel.html) in Bedrock.

We will invoke `Llama 3.1 405B` to generate a response for each input prompt in the `Paul Graham Essay` dataset.

In [None]:
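# Setup inference params
inference_config = {"maxTokens": 2048, "temperature": 0.1, "topP": 0.9}

# Request metadata is attached to each call and is used later to filter the invocation logs
request_metadata = {
    "job_type": "paulgraham",
    "use_case": "RAG",
    # Option to avoid throttling: invoke a smaller model and tag it accordingly,
    # for example "invoke_model": "llama31-70b"
    "invoke_model": "llama31-405b",
}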

<div class="alert alert-block alert-warning">
The following code sample invokes the teacher model to generate invocation logs and takes about 30 minutes to complete.
</div>

In [None]:
with open('paul_graham.jsonl', 'r', encoding='utf-8') as file:
    for line in file:
        data = json.loads(line)

        prompt = data['prompt']

        conversation = [
            {
                "role": "user",
                "content": [{"text": prompt}]
            }
        ]

        response = bedrock_runtime.converse(
            modelId=teacher_model_id,
            messages=conversation,
            inferenceConfig=inference_config,
            requestMetadata=request_metadata
        )

        response_text = response["output"]["message"]["content"][0]["text"]
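
A long loop of teacher invocations can hit throttling limits. One possible mitigation, sketched below, is to retry with exponential backoff. `converse_with_retry` is a hypothetical helper, the delay values are arbitrary assumptions, and the fake client and exception exist only so the snippet runs without AWS access. The error code is read from the `response` attribute that botocore's `ClientError` carries.

```python
import time


def converse_with_retry(client, max_attempts=5, base_delay=2.0, sleep=time.sleep, **kwargs):
    """Call client.converse(**kwargs), backing off exponentially when throttled."""
    for attempt in range(1, max_attempts + 1):
        try:
            return client.converse(**kwargs)
        except Exception as exc:
            # botocore's ClientError exposes the error code under exc.response
            code = getattr(exc, "response", {}).get("Error", {}).get("Code")
            if code != "ThrottlingException" or attempt == max_attempts:
                raise
            sleep(base_delay * 2 ** (attempt - 1))  # 2s, 4s, 8s, ...


class _FakeThrottle(Exception):
    """Mimics the shape of a throttling ClientError for the demonstration."""
    response = {"Error": {"Code": "ThrottlingException"}}


class _FlakyClient:
    """Fails twice with throttling, then succeeds."""

    def __init__(self):
        self.calls = 0

    def converse(self, **kwargs):
        self.calls += 1
        if self.calls < 3:
            raise _FakeThrottle()
        return {"output": {"message": {"content": [{"text": "ok"}]}}}


demo_delays = []
flaky = _FlakyClient()
demo_result = converse_with_retry(flaky, sleep=demo_delays.append, modelId="m", messages=[])
print(flaky.calls, demo_delays)
```

In the loop above, the direct call could be swapped for `converse_with_retry(bedrock_runtime, modelId=teacher_model_id, messages=conversation, inferenceConfig=inference_config, requestMetadata=request_metadata)`.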

### Step 4. Configure and submit distillation job using historical invocation logs

Now that we have enough logs in our S3 bucket, let's configure and submit our distillation job using historical invocation logs.

<div class="alert alert-block alert-warning">
Please make sure to update <b>role_arn</b> and <b>output_path</b> in the following code sample
</div>

In [None]:
# Generate unique names for the job and model
job_name = f"distillation-job-{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}"
model_name = f"distilled-model-{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}"

# Set maximum response length
max_response_length = 1000

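# Setup IAM role
role_arn = "" # Replace with the ARN of the IAM role configured for the distillation job

# Invocation logs location (the prefix is optional, so avoid a double slash when it is empty)
log_prefix = f"{prefix_for_log.strip('/')}/" if prefix_for_log else ""
invocation_logs_data = f"s3://{s3_bucket_for_log}/{log_prefix}AWSLogs"

# S3 URI where the distillation output metrics will be written
output_path = ""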

In [None]:
# Configure training data using invocation logs
training_data_config = {
    'invocationLogsConfig': {
        'usePromptResponse': True, # Defaults to False (prompts only)
        'invocationLogSource': {
            's3Uri': invocation_logs_data
        },
        # Match the request metadata attached when invoking the teacher model.
        # A dict cannot repeat the 'equals' key, so the conditions are combined with 'andAll'.
        'requestMetadataFilters': { # Replace with your own filter
            'andAll': [
                {'equals': {"job_type": "paulgraham"}},
                {'equals': {"use_case": "RAG"}},
                {'equals': {"invoke_model": "llama31-405b"}},
            ]
        }
    }
}
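
Filtering on several metadata keys at once is easy to get wrong in Python, because a dict literal that repeats a key silently keeps only the last value. The sketch below shows that pitfall and a small hypothetical helper, `build_metadata_filter`, that combines conditions under `andAll`. The `andAll` and `equals` field names are taken from the Bedrock API reference as I understand it, so verify them against the current `CreateModelCustomizationJob` documentation.

```python
def build_metadata_filter(metadata):
    """Turn a {key: value} dict into a requestMetadataFilters structure."""
    if not metadata:
        raise ValueError("at least one metadata key is required")
    conditions = [{"equals": {key: value}} for key, value in metadata.items()]
    if len(conditions) == 1:
        return conditions[0]
    # Several conditions must all match, so wrap them in 'andAll'
    return {"andAll": conditions}


# Pitfall: the repeated 'equals' key collapses to the last entry
collapsed = {"equals": {"job_type": "paulgraham"}, "equals": {"use_case": "RAG"}}
print(collapsed)  # {'equals': {'use_case': 'RAG'}}

demo_filter = build_metadata_filter({"job_type": "paulgraham", "use_case": "RAG"})
print(demo_filter)
```

Passing the same dictionary that was used as `requestMetadata` during invocation keeps the filter and the logged metadata in sync.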

In [None]:
# Create distillation job with invocation logs
response = bedrock_client.create_model_customization_job(
    jobName=job_name,
    customModelName=model_name,
    roleArn=role_arn,
    baseModelIdentifier=student_model,
    customizationType="DISTILLATION",
    trainingDataConfig=training_data_config,
    outputDataConfig={
        "s3Uri": output_path
    },
    customizationConfig={
        "distillationConfig": {
            "teacherModelConfig": {
                "teacherModelIdentifier": teacher_model_id,
                "maxResponseLengthForInference": max_response_length
            }
        }
    }
)

### Step 5. Monitoring distillation job status

After submitting your distillation job, you can run the following code to monitor the job status.

<div class="alert alert-block alert-warning">
Please be aware that a distillation job can run for up to 7 days
</div>

In [None]:
# Record the distillation job ARN
job_arn = response['jobArn']

# Print the job status
job_status = bedrock_client.get_model_customization_job(jobIdentifier=job_arn)["status"]
print(job_status)

# Print the job ARN for reference
print(job_arn)
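
Rather than re-running the status cell by hand, you could poll until the job reaches a terminal state. The following is a minimal sketch: `wait_for_job` is a hypothetical helper, the five-minute default interval is an arbitrary assumption, and the stub client replaces `bedrock_client` only so the snippet runs offline.

```python
import time

# Statuses after which the job will not change any further
TERMINAL_STATUSES = {"Completed", "Failed", "Stopped"}


def wait_for_job(bedrock_client, job_arn, poll_seconds=300, sleep=time.sleep):
    """Poll the customization job until it reaches a terminal status, then return it."""
    while True:
        status = bedrock_client.get_model_customization_job(jobIdentifier=job_arn)["status"]
        print(f"Job status: {status}")
        if status in TERMINAL_STATUSES:
            return status
        sleep(poll_seconds)


class _StubJobClient:
    """Returns a scripted sequence of statuses for the demonstration."""

    def __init__(self, statuses):
        self._statuses = iter(statuses)

    def get_model_customization_job(self, jobIdentifier):
        return {"status": next(self._statuses)}


demo_sleeps = []
demo_status = wait_for_job(
    _StubJobClient(["InProgress", "InProgress", "Completed"]),
    "arn:example",
    poll_seconds=1,
    sleep=demo_sleeps.append,
)
```

Since a job can run for days, a long polling interval is usually enough. Calling `wait_for_job(bedrock_client, job_arn)` blocks the notebook kernel until the job finishes.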

<div class="alert alert-block alert-warning">
Proceed to the following sections only when the status shows <b>Completed</b>
</div>

### Step 6. Deploying the Distilled Model

After distillation is complete, you'll need to set up Provisioned Throughput to use the model.

In [None]:
# Deploy the distilled model
custom_model_id = bedrock_client.get_model_customization_job(jobIdentifier=job_arn)['outputModelArn']
distilled_model_name = f"distilled-model-{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}"

provisioned_model_id = bedrock_client.create_provisioned_model_throughput(
    modelUnits=1,
    provisionedModelName=distilled_model_name,
    modelId=custom_model_id 
)['provisionedModelArn']

Check the Provisioned Throughput status, and proceed only once it shows **InService**.

In [None]:
# print pt status
pt_status = bedrock_client.get_provisioned_model_throughput(provisionedModelId=provisioned_model_id)['status']
print(pt_status)

### Step 7. Run inference with provisioned throughput units

In this example, we use the Converse API to invoke the distilled model. You can use either the InvokeModel API or the Converse API to generate responses.

In [None]:
# Example inference with the distilled model
input_prompt = "<Your input prompt here>"  # Replace by your input prompt

In [None]:
conversation = [ 
    {
        "role": "user", 
        "content": [{"text": input_prompt}], 
    } 
]
inferenceConfig = {
    "maxTokens": 2048, 
    "temperature": 0.1, 
    "topP": 0.9
    }

# Test the deployed model
response = bedrock_runtime.converse(
    modelId=provisioned_model_id,
    messages=conversation,
    inferenceConfig=inferenceConfig,
)
response_text = response["output"]["message"]["content"][0]["text"]
print(response_text)

### (Optional) Model Copy and Share

If you want to deploy the model to a `different AWS Region` or a `different AWS account`, you can use the `Model Share` and `Model Copy` features of Amazon Bedrock. See the following notebook for more information.

[Sample notebook](https://github.com/aws-samples/amazon-bedrock-samples/blob/main_archieve_10_06_2024/custom_models/model_copy/cross-region-copy.ipynb)

### Step 8. Cleanup

After you're done with the experiment, be sure to **delete** the Provisioned Throughput to avoid unnecessary cost.

In [None]:
response = bedrock_client.delete_provisioned_model_throughput(provisionedModelId=provisioned_model_id)

## Conclusion

In this guide, we've walked through the entire process of model distillation using Amazon Bedrock with historical model invocation logs. We covered:

1. Setting up the environment and configuring necessary AWS resources
2. Configuring model invocation logging using the API
3. Preparing the input dataset in JSONL format
4. Invoking the teacher model to generate logs
5. Configuring and submitting a distillation job using historical invocation logs
6. Monitoring the distillation job's progress
7. Deploying the distilled model using Provisioned Throughput
8. Running inference with the distilled model
9. Optional model copy and share procedures
10. Cleaning up resources

Remember to always consider your specific use case requirements when selecting models, configuring the distillation process, and filtering invocation logs. The ability to use actual production data from your model invocations can lead to distilled models that are highly optimized for your particular applications.

With these tools and techniques at your disposal, you're well-equipped to leverage the power of model distillation to optimize your AI/ML workflows in Amazon Bedrock.

**Happy distilling!**