# Amazon Bedrock Model Distillation for Citations

> ⚠️ **CRITICAL COST WARNING**
>
> This notebook uses Amazon Bedrock Provisioned Throughput (PT), which has significant cost implications. As of this writing, a PT endpoint is required to run inference against a distilled model.
>
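> - With a commitment term, provisioned units are billed for the entire term regardless of actual usage
> - With a commitment term, you cannot cancel provisioned capacity until the term ends; no-commitment capacity, which this notebook creates, is billed for as long as it exists and can be deleted at any time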
> 
> **Before running this notebook:**
> 1. Please read and understand the cost implications of [creating a Provisioned Throughput endpoint](https://docs.aws.amazon.com/bedrock/latest/userguide/prov-throughput.html)
> 2. Check current pricing at [AWS Bedrock Pricing](https://aws.amazon.com/bedrock/pricing/)
>
> 💡 **Best Practice**: Delete or release provisioned capacity when not actively needed to avoid unnecessary charges.

## Learning Objectives

After completing this notebook, you will be able to:
1. Implement advanced model distillation techniques using Amazon Bedrock's Distillation APIs
2. Configure and optimize teacher-student model architectures for citation tasks
3. Monitor and evaluate distillation performance metrics
4. Deploy and manage production-grade distilled models

## Introduction

Model distillation is an advanced knowledge transfer technique that enables the creation of efficient, production-ready models by distilling knowledge from larger foundation models into smaller, specialized ones. This notebook demonstrates enterprise-grade implementation of model distillation in Amazon Bedrock, focusing on citation generation use cases.

### Setup and Prerequisites
**By now you should have already run notebook `01_prepare_data.ipynb` to prepare all of the data for subsequent steps.**

Before proceeding, ensure you have:

- An active AWS account with appropriate permissions
- Amazon Bedrock access enabled in your preferred region
- An S3 bucket for storing training data and output
- Training data in JSONL format
- Sufficient service quota to use Provisioned Throughput in Bedrock
- An IAM role with the following permissions:

IAM Policy:
```json
{
    "Version": "2012-10-17",
    "Statement": [
        {
            "Effect": "Allow",
            "Action": [
                "s3:GetObject",
                "s3:PutObject",
                "s3:ListBucket"
            ],
            "Resource": [
                "arn:aws:s3:::YOUR_DISTILLATION_OUTPUT_BUCKET",
                "arn:aws:s3:::YOUR_DISTILLATION_OUTPUT_BUCKET/*"
            ]
        },
        {
            "Effect": "Allow",
            "Action": [
                "bedrock:CreateModelCustomizationJob",
                "bedrock:GetModelCustomizationJob",
                "bedrock:ListModelCustomizationJobs",
                "bedrock:StopModelCustomizationJob"
            ],
            "Resource": "arn:aws:bedrock:YOUR_REGION:YOUR_ACCOUNT_ID:model-customization-job/*"
        }
    ]
}
```

Trust Relationship:
```json
{
    "Version": "2012-10-17",
    "Statement": [
        {
            "Effect": "Allow",
            "Principal": {
                "Service": [
                    "bedrock.amazonaws.com"
                ]
            },
            "Action": "sts:AssumeRole",
            "Condition": {
                "StringEquals": {
                    "aws:SourceAccount": "YOUR_ACCOUNT_ID"
                },
                "ArnLike": {
                    "aws:SourceArn": "arn:aws:bedrock:YOUR_REGION:YOUR_ACCOUNT_ID:model-customization-job/*"
                }
            }
        }
    ]
}
```
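
The two JSON documents above contain placeholders (`YOUR_DISTILLATION_OUTPUT_BUCKET`, `YOUR_REGION`, `YOUR_ACCOUNT_ID`). As a minimal sketch, the same documents can be built in Python so that the placeholders are filled in consistently. The function name `build_distillation_policies` is hypothetical; it is not the `create_model_distillation_role_and_permissions` helper that this notebook imports from `utils`, and it only mirrors the policy text shown above.

```python
import json


def build_distillation_policies(account_id, region, bucket_name):
    """Return (permissions_policy, trust_policy) dicts mirroring the JSON above.

    Hypothetical helper for illustration only.
    """
    job_arn_pattern = (
        f"arn:aws:bedrock:{region}:{account_id}:model-customization-job/*"
    )

    permissions_policy = {
        "Version": "2012-10-17",
        "Statement": [
            {
                # Read training data from and write outputs to the bucket
                "Effect": "Allow",
                "Action": ["s3:GetObject", "s3:PutObject", "s3:ListBucket"],
                "Resource": [
                    f"arn:aws:s3:::{bucket_name}",
                    f"arn:aws:s3:::{bucket_name}/*",
                ],
            },
            {
                "Effect": "Allow",
                "Action": [
                    "bedrock:CreateModelCustomizationJob",
                    "bedrock:GetModelCustomizationJob",
                    "bedrock:ListModelCustomizationJobs",
                    "bedrock:StopModelCustomizationJob",
                ],
                "Resource": job_arn_pattern,
            },
        ],
    }

    trust_policy = {
        "Version": "2012-10-17",
        "Statement": [
            {
                "Effect": "Allow",
                "Principal": {"Service": ["bedrock.amazonaws.com"]},
                "Action": "sts:AssumeRole",
                # Restrict role assumption to customization jobs in this account
                "Condition": {
                    "StringEquals": {"aws:SourceAccount": account_id},
                    "ArnLike": {"aws:SourceArn": job_arn_pattern},
                },
            }
        ],
    }
    return permissions_policy, trust_policy


if __name__ == "__main__":
    # Example values only; substitute your own account, region, and bucket
    perms, trust = build_distillation_policies(
        "123456789012", "us-east-1", "my-distillation-bucket"
    )
    print(json.dumps(trust, indent=4))
```

The returned dictionaries can be serialized with `json.dumps` and passed to IAM, for example as the `AssumeRolePolicyDocument` argument of `create_role` and the `PolicyDocument` argument of `put_role_policy`.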

First, let's set up our environment and import required libraries.

In [None]:
# upgrade boto3 
%pip install --upgrade pip --quiet
%pip install boto3 --upgrade --quiet

In [None]:
# restart kernel
from IPython.display import HTML
HTML("<script>Jupyter.notebook.kernel.restart()</script>")

In [None]:
import json
import sys
import os

current_dir = os.getcwd()
parent_dir = os.path.dirname(current_dir)
skip_dir = os.path.dirname(parent_dir)
sys.path.append(skip_dir)

import boto3
from datetime import datetime
from botocore.exceptions import ClientError
from utils import (
    create_s3_bucket,
    upload_training_data_to_s3,
    delete_s3_bucket_and_contents,
    create_model_distillation_role_and_permissions,
    delete_role_and_attached_policies,
    delete_distillation_buckets,
)

# Create Bedrock client
bedrock_client = boto3.client(service_name="bedrock",region_name='us-east-1')

# Create runtime client for inference
bedrock_runtime = boto3.client(service_name='bedrock-runtime',region_name='us-east-1')

# Region and accountID
session = boto3.session.Session(region_name='us-east-1')
region =  'us-east-1' # session.region_name
sts_client = session.client(service_name='sts',region_name='us-east-1')
account_id = sts_client.get_caller_identity()['Account']

# Define the bucket to create and upload the dataset to
BUCKET_NAME = '<BUCKET_NAME>'  # Replace with your bucket name
DATA_PREFIX = 'citations_distillation'  # Replace with your chosen prefix

# Configure the teacher and student models
teacher_model = "us.amazon.nova-premier-v1:0"
student_model = "amazon.nova-lite-v1:0:300k"

In [None]:
# Generate unique names for the job and model
distillation_dataset = 'distillation_data.jsonl'
current_dt = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
job_name = f"distill-citations-{current_dt}"
model_name = f"distilled-citations-{current_dt}"


In [None]:

# Configure models and IAM role
role_name, role_arn = create_model_distillation_role_and_permissions(bucket_name=BUCKET_NAME, account_id=account_id)

# creating training data bucket
create_s3_bucket(bucket_name=BUCKET_NAME)

# Specify S3 locations
training_data_s3_uri = upload_training_data_to_s3(BUCKET_NAME, distillation_dataset, prefix=DATA_PREFIX)
output_path = f"s3://{BUCKET_NAME}/{DATA_PREFIX}/outputs/"

# Set maximum response length
max_response_length = 1000

# Starting the Distillation Job

With our environment configured and data prepared, we'll initiate the distillation process. This section covers:

1. Job Configuration
   - Model selection and parameters
   - Resource allocation
   - Output settings

2. Performance Optimization
   - Response length tuning
   - Batch size configuration
   - Resource utilization

3. Monitoring Setup
   - Metrics configuration
   - Logging settings
   - Alert thresholds

We'll use the `create_model_customization_job` API with production-optimized settings.

In [None]:
import time

# A newly created IAM role can take a few seconds to propagate; wait briefly before using it
time.sleep(15)

response = bedrock_client.create_model_customization_job(
    jobName=job_name,
    customModelName=model_name,
    roleArn=role_arn,
    baseModelIdentifier=student_model,
    customizationType="DISTILLATION",
    trainingDataConfig={
        "s3Uri": training_data_s3_uri
    },
    outputDataConfig={
        "s3Uri": output_path
    },
    customizationConfig={
        "distillationConfig": {
            "teacherModelConfig": {
                "teacherModelIdentifier": teacher_model,
                "maxResponseLengthForInference": max_response_length 
            }
        }
    }
)

In [None]:
# Record the distillation job arn
job_arn = response['jobArn']
print("job arn", job_arn)

# print job status
job_status = bedrock_client.get_model_customization_job(jobIdentifier=job_arn)["status"]
print(job_status)

A distillation job of this size (15,000 records) can take several hours to complete. Re-run the cell above to check the job status, and continue with this notebook once it reports `Completed`.
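
Rather than re-running the status cell by hand, the job can be polled until it reaches a terminal state. The following is a minimal sketch: `wait_for_customization_job` is a hypothetical helper, the five-minute polling interval is an arbitrary choice, and the terminal statuses (`Completed`, `Failed`, `Stopped`) are the ones that `get_model_customization_job` is documented to report.

```python
import time

# Statuses after which the job will not change any further
TERMINAL_JOB_STATUSES = {"Completed", "Failed", "Stopped"}


def wait_for_customization_job(client, job_identifier, poll_seconds=300,
                               max_polls=None, sleep=time.sleep):
    """Poll a model customization job until it reaches a terminal status.

    client         -- a boto3 "bedrock" client (or any object exposing
                      get_model_customization_job)
    job_identifier -- the job ARN or name
    max_polls      -- optional cap on status checks; TimeoutError is raised
                      if the job is still running after that many checks
    sleep          -- injectable for testing; defaults to time.sleep
    Returns the final get_model_customization_job response.
    """
    polls = 0
    while True:
        job = client.get_model_customization_job(jobIdentifier=job_identifier)
        status = job["status"]
        print(f"Job status: {status}")
        if status in TERMINAL_JOB_STATUSES:
            return job
        polls += 1
        if max_polls is not None and polls >= max_polls:
            raise TimeoutError(
                f"Job still '{status}' after {polls} status checks"
            )
        sleep(poll_seconds)
```

In this notebook the call would look like `job = wait_for_customization_job(bedrock_client, job_arn)`. Only a final status of `Completed` means that an output model is available for deployment.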

Once distillation is complete, we're ready to deploy to our PT endpoint.

⚠️ **Please understand the cost considerations for Provisioned Throughput (PT) endpoints before proceeding**


# Deploying the Distilled Model

Now we'll deploy our distilled model to a PT endpoint: we take the output model ARN from the distillation job and create a no-commitment PT endpoint with one model unit (MU).

In [None]:
# Deploy the distilled model
custom_model_id = bedrock_client.get_model_customization_job(jobIdentifier=job_arn)['outputModelArn']
distilled_model_name = model_name

provisioned_model_id = bedrock_client.create_provisioned_model_throughput(
    modelUnits=1,
    provisionedModelName=distilled_model_name,
    # commitmentDuration is omitted to create a no-commitment endpoint
    modelId=custom_model_id 
)['provisionedModelArn']

In [None]:
provisioned_model_id

Store the provisioned model endpoint ARN for subsequent inference operations and resource cleanup.

In [None]:
%store provisioned_model_id
%store custom_model_id
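
The stored provisioned model ARN is what later inference and cleanup steps need. As a minimal sketch of those steps, the helpers below check the endpoint status, send a single prompt through the Converse API, and delete the endpoint when it is no longer needed. The function names are hypothetical, the inference settings (`maxTokens`, `temperature`) are illustrative values, and the clients are passed in as arguments so that nothing runs against AWS until you call the functions yourself.

```python
def provisioned_throughput_status(bedrock_client, provisioned_model_arn):
    """Return the endpoint status, for example 'Creating' or 'InService'."""
    response = bedrock_client.get_provisioned_model_throughput(
        provisionedModelId=provisioned_model_arn
    )
    return response["status"]


def invoke_distilled_model(runtime_client, provisioned_model_arn, prompt,
                           max_tokens=512):
    """Send one user prompt to the provisioned model and return the text reply."""
    response = runtime_client.converse(
        modelId=provisioned_model_arn,
        messages=[{"role": "user", "content": [{"text": prompt}]}],
        # Illustrative settings; tune these for your citation task
        inferenceConfig={"maxTokens": max_tokens, "temperature": 0.0},
    )
    return response["output"]["message"]["content"][0]["text"]


def release_provisioned_throughput(bedrock_client, provisioned_model_arn):
    """Delete the provisioned throughput so that it stops accruing charges."""
    bedrock_client.delete_provisioned_model_throughput(
        provisionedModelId=provisioned_model_arn
    )
```

With the clients defined earlier, a typical sequence would be to wait until `provisioned_throughput_status(bedrock_client, provisioned_model_id)` returns `InService`, call `invoke_distilled_model(bedrock_runtime, provisioned_model_id, "...")`, and finish with `release_provisioned_throughput(bedrock_client, provisioned_model_id)` once the endpoint is no longer in use.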

# Conclusion

This notebook has demonstrated advanced implementation techniques for model distillation in Amazon Bedrock, with a focus on citation generation use cases. 

## Next Steps

In the next notebook ([03_batch_inference.ipynb](03_batch_inference.ipynb)), we'll explore:
1. Implementing batch inference with the distilled model
2. Evaluating citation accuracy and performance metrics
3. Optimizing throughput and latency
4. Monitoring production workloads

For additional resources:
- [Amazon Bedrock Model Distillation Documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/model-distillation.html)
- [Best Practices for Production Deployments](https://docs.aws.amazon.com/bedrock/latest/userguide/best-practices.html)
- [Advanced Monitoring and Optimization](https://docs.aws.amazon.com/bedrock/latest/userguide/monitoring-customization.html)