# Response-Based Knowledge Distillation with QA Specialization

## Using Amazon SageMaker JumpStart for LLM distillation (70B → 3B)

## 1. Introduction

### Overview of Knowledge Distillation 
### SageMaker JumpStart Benefits 
### Project Goals and Objectives 

In [None]:
%pip install --quiet --upgrade sagemaker jmespath datasets transformers jinja2 fmeval ipywidgets

## 2. Environment Setup

### AWS Account Configuration 

In [None]:
import sagemaker
import boto3
sess = sagemaker.Session()
# sagemaker session bucket -> used for uploading data, models and logs
# sagemaker will automatically create this bucket if it does not exist
sagemaker_session_bucket = None
# fall back to the session's default bucket if a bucket name is not given,
# so that `bucket` is always defined
bucket = sagemaker_session_bucket or sess.default_bucket()

try:
    role = sagemaker.get_execution_role()
except ValueError:
    iam = boto3.client('iam')
    #change the name of the role if you are running locally
    role = iam.get_role(RoleName='AmazonSageMaker-ExecutionRole-20220929T161862')['Role']['Arn']

sess = sagemaker.Session(default_bucket=bucket)
region=sess.boto_region_name

print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {bucket}")
print(f"sagemaker session region: {sess.boto_region_name}")


In [None]:
prefix = "llama-qa-distillation"

# Print AWS configuration
print(f"SageMaker Session: {sess}")
print(f"Role: {role}")
print(f"Region: {region}")
print(f"Bucket: {bucket}")

### Role configuration

[TODO] This section covers adding the permissions that Bedrock requires to the current role.
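
As a starting point for this TODO, here is a minimal sketch of the two policy documents a role would likely need before Bedrock batch inference can use it: a trust policy that lets the Bedrock service assume the role, and an S3 policy for the batch input and output prefix. The helper names, the account ID and the exact action list are assumptions for illustration, not a verified least-privilege setup; check them against the Bedrock batch inference documentation before attaching anything.

```python
import json


def bedrock_trust_policy(account_id):
    """Trust policy allowing the Bedrock service to assume the role (assumed shape)."""
    return {
        "Version": "2012-10-17",
        "Statement": [
            {
                "Effect": "Allow",
                "Principal": {"Service": "bedrock.amazonaws.com"},
                "Action": "sts:AssumeRole",
                # Restrict the trust to jobs started from this account
                "Condition": {"StringEquals": {"aws:SourceAccount": account_id}},
            }
        ],
    }


def batch_s3_policy(bucket_name, prefix):
    """S3 permissions for reading batch input and writing batch output (assumed actions)."""
    return {
        "Version": "2012-10-17",
        "Statement": [
            {
                "Effect": "Allow",
                "Action": ["s3:GetObject", "s3:PutObject"],
                "Resource": f"arn:aws:s3:::{bucket_name}/{prefix}/*",
            },
            {
                "Effect": "Allow",
                "Action": "s3:ListBucket",
                "Resource": f"arn:aws:s3:::{bucket_name}",
            },
        ],
    }


# Placeholder values, replace with your own account, bucket and prefix
print(json.dumps(bedrock_trust_policy("111122223333"), indent=2))
print(json.dumps(batch_s3_policy("my-example-bucket", "distillation/batch/data"), indent=2))
```

The functions only build dictionaries, so nothing is changed in the account until the documents are attached to the role by hand or through IAM.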

## 3. Teacher Model (LLaMA 3.3 70B)
### Selecting Model in Bedrock 

In [None]:
import logging
import json
import boto3
import pandas as pd
from botocore.exceptions import ClientError

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def list_foundation_models(bedrock_client):
    """
    Gets a list of available Amazon Bedrock foundation models.

    :return: The list of available bedrock foundation models.
    """
    try:
        response = bedrock_client.list_foundation_models()
        models = response["modelSummaries"]
        logger.info("Got %s foundation models.", len(models))
        return models

    except ClientError:
        logger.error("Couldn't list foundation models.")
        raise

def create_models_dataframe(models):
    """
    Creates a pandas DataFrame with relevant model information.
    
    :param models: List of model summaries from Bedrock
    :return: pandas DataFrame with model information
    """
    model_data = []
    
    for model in models:
        model_info = {
            'Model Name': model['modelName'],
            'Provider': model['providerName'],
            'Model ID': model['modelId'],
            'Input Modalities': ', '.join(model['inputModalities']),
            'Output Modalities': ', '.join(model['outputModalities']),
            'Customizations Supported': ', '.join(model['customizationsSupported']) if 'customizationsSupported' in model else 'None',
            'Inference Types': ', '.join(model['inferenceTypesSupported'])
        }
        model_data.append(model_info)
    
    df = pd.DataFrame(model_data)
    return df

In [None]:
bedrock_client = boto3.client(service_name="bedrock",region_name="us-east-2")
fm_models = list_foundation_models(bedrock_client)

# Create DataFrame
models_df = create_models_dataframe(fm_models)

# Display the DataFrame
print("\nAmazon Bedrock Foundation Models:")
print(models_df.to_string(index=False))

# Optionally, you can also save to CSV
# models_df.to_csv('bedrock_models.csv', index=False)

logger.info("Done.")

In [None]:

models_df['Model ID'].to_list()

In [None]:
models_df[(models_df['Provider']=='Meta') & (models_df['Inference Types']=='INFERENCE_PROFILE')]['Model ID']

### Bedrock client setup

To use models whose inference type is 'INFERENCE_PROFILE' you need to create an inference profile; 'ON_DEMAND' models do not need one.

[Note] Llama 3.3 70B only worked in us-east-2 during testing; it is not clear whether this is a known issue.
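
The rule above can be sketched as a small helper that decides which identifier to pass when invoking a model. It assumes that profile-only models are reachable through an identifier made of a geography prefix plus the model ID, which is the pattern this notebook uses later ("us.meta.llama3-3-70b-instruct-v1:0"). The function name and the `geo_prefix` parameter are hypothetical.

```python
from typing import List


def resolve_invocation_id(model_id: str, inference_types: List[str], geo_prefix: str = "us") -> str:
    """Pick the identifier to invoke, based on the model's supported inference types."""
    if "ON_DEMAND" in inference_types:
        # On-demand models can be invoked directly by model ID
        return model_id
    if "INFERENCE_PROFILE" in inference_types:
        # Assumption: the inference profile ID is "<geo>.<model id>"
        return f"{geo_prefix}.{model_id}"
    raise ValueError(f"No supported inference type for {model_id}: {inference_types}")


print(resolve_invocation_id("meta.llama3-3-70b-instruct-v1:0", ["INFERENCE_PROFILE"]))
```

The `inference_types` argument corresponds to the `inferenceTypesSupported` field returned by `list_foundation_models`.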

In [None]:
import boto3
# only us-east-2 allowed Llama 3.3 70B during testing; us-west-2 failed
bedrock_client = boto3.client('bedrock', region_name="us-east-2")

In [None]:
model_id='meta.llama3-3-70b-instruct-v1:0'
inference_profile_name='llama3-3-70b-inference'
inf_profile_response = bedrock_client.create_inference_profile(
    inferenceProfileName=inference_profile_name,
    description='Teacher model used for synthetic data generation in a Llama distillation project',
    modelSource={
        'copyFrom': f'arn:aws:bedrock:us-east-2::foundation-model/{model_id}'
    },
    tags=[
        {
        'key': 'project',
            'value': 'Llama-model-distillation'
        },
        {
        'key': 'model-id',
            'value': 'meta.llama3-3-70b-instruct'
        },
    ]
)

In [None]:
print(f"Inference profile created successfully. ARN: {inf_profile_response['inferenceProfileArn']}")
model_arn=inf_profile_response['inferenceProfileArn']

In [None]:
inf_profile_response

### Testing Inference 

In [None]:
brt = boto3.client(service_name='bedrock-runtime',region_name='us-east-2')
def invoke_model(body, model_id, accept, content_type):
    try:
        response = brt.invoke_model(
            body=json.dumps(body), 
            modelId=model_id, 
            
            accept=accept, 
            contentType=content_type
        )

        return response

    except Exception as e:
        print(f"Couldn't invoke {model_id}")
        raise e

In [None]:
# If you'd like to try your own prompt, edit this parameter!
prompt_data = """<s>[INST] <<SYS>>
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.  Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
<</SYS>>

Write me a blog about making strong business decisions as a leader. [/INST]"""

body = {
    "prompt": prompt_data,
    "temperature": 0.5,
    "top_p": 0.9,
    "max_gen_len": 512,
}

modelId = "us.meta.llama3-3-70b-instruct-v1:0"
accept = "application/json"
contentType = "application/json"

response = invoke_model(body, modelId, accept, contentType)
response_body = json.loads(response.get("body").read())

print(response_body["generation"])
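
The test prompt above uses the Llama 2 style `[INST]` and `<<SYS>>` tags. Llama 3 instruct models were trained on a different chat template based on header tokens, so the following sketch shows how the same system and user text could be wrapped for Llama 3.3. The function name is hypothetical, and whether this changes output quality for this use case has not been measured here.

```python
def build_llama3_prompt(system: str, user: str) -> str:
    """Wrap a system and a user message in the Llama 3 instruct chat template."""
    return (
        "<|begin_of_text|>"
        "<|start_header_id|>system<|end_header_id|>\n\n"
        f"{system}<|eot_id|>"
        "<|start_header_id|>user<|end_header_id|>\n\n"
        f"{user}<|eot_id|>"
        # The trailing assistant header asks the model to produce the reply
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )


example_prompt = build_llama3_prompt(
    "You are a helpful, respectful and honest assistant.",
    "Write me a blog about making strong business decisions as a leader.",
)
print(example_prompt)
```

The resulting string can be passed as the `"prompt"` field of the request body in place of `prompt_data`.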

## 4. Dataset Generation

### Corpus Preparation (Prepare QA + Context dataset)

Pre-approved dataset details:

https://github.com/pubmedqa/pubmedqa/tree/master




> **PubMedQA: A Dataset for Biomedical Research Question Answering**
>
> Jin, Q., Dhingra, B., Liu, Z., Cohen, W., & Lu, X. (2019). In *Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP)*, pp. 2567-2577.

In [None]:
import requests

def get_github_json(url):
    try:
        # Convert regular GitHub URL to raw content URL
        raw_url = url.replace("github.com", "raw.githubusercontent.com").replace("/blob/", "/")
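        response = requests.get(raw_url, timeout=30)
        response.raise_for_status()  # surface HTTP errors instead of parsing an error page
        return response.json()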
    except Exception as e:
        print(f"Error: {e}")
        return None

# Example usage:
url = "https://github.com/pubmedqa/pubmedqa/blob/master/data/ori_pqal.json"
data = get_github_json(url)

In [None]:
qa_index=list(data.keys())

In [None]:
print(len(qa_index))

In [None]:
data[qa_index[0]]

In [None]:
dataset=[]
for i in qa_index:
    keys_to_get = ['QUESTION', 'CONTEXTS','LONG_ANSWER']
    result = {k: data[i].get(k) for k in keys_to_get}
    dataset.append(result)

In [None]:
dataset[0]

### Using Teacher Model for QA Generation (create more data based on the questions and the context)

After some testing, the best option is Bedrock batch inference: sending real-time invocations is slower, can throttle the API, and can be more expensive than sending a batch job.
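
For comparison, this is roughly what the real-time path needs in order to cope with throttling: a retry loop with exponential backoff and jitter around each call. It is a generic sketch, not code from this project; `call_with_backoff`, `is_throttled` and the delay values are all assumptions.

```python
import random
import time


def call_with_backoff(fn, is_throttled, max_attempts=5, base_delay=1.0, sleep=time.sleep):
    """Call fn(), retrying with exponential backoff while is_throttled(exc) is true."""
    for attempt in range(max_attempts):
        try:
            return fn()
        except Exception as exc:
            # Give up on non-throttling errors or when attempts are exhausted
            if not is_throttled(exc) or attempt == max_attempts - 1:
                raise
            # Delay doubles each attempt, plus random jitter to spread retries
            delay = base_delay * (2 ** attempt) + random.uniform(0, base_delay)
            sleep(delay)


# Demo with a fake call that is "throttled" twice before succeeding
attempts = {"count": 0}


def fake_call():
    attempts["count"] += 1
    if attempts["count"] < 3:
        raise RuntimeError("throttled")
    return "done"


print(call_with_backoff(fake_call, lambda exc: "throttled" in str(exc), sleep=lambda s: None))
```

With boto3, `is_throttled` would typically inspect the error code of a `ClientError` (for example `exc.response["Error"]["Code"]`). Running this for every one of the questions is the overhead that the batch job avoids.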

In [None]:
import json
from datetime import datetime
import uuid

def create_bedrock_batch_dataset(dataset, output_file='bedrock_batch_dataset.jsonl'):
    system_message = """You are a specialized biomedical research assistant trained to analyze and answer questions about medical and scientific literature. Your role is to:
        Extract and interpret key information from biomedical research papers, clinical studies, and medical literature
        Provide accurate, evidence-based responses based solely on the provided research context
        Focus on specific medical findings, methodologies, and clinical outcomes
        Present complex medical information in clear, understandable terms
        Maintain precision when discussing medical terminology, study results, and statistical data
        Distinguish between preliminary findings and established conclusions
        Reference specific sections of the provided research when answering questions
        Acknowledge limitations in studies when relevant
        Avoid making medical recommendations or providing diagnosis When responding, only use information explicitly stated in the provided biomedical context."""

    prompt_template = """System: {system}

Question: {question}

Provide a clear and concise answer."""
    
    with open(output_file, 'w') as outfile:
        for sample in dataset:
            # Generate a unique record ID (11 alphanumeric characters, no hyphens)
            record_id = uuid.uuid4().hex[:11]
            
            # Format the prompt
            formatted_prompt = prompt_template.format(
                system=system_message,
                question=sample["QUESTION"]
            )

            # Create the model input body for Llama 3.3
            body = {
                "prompt": formatted_prompt,
                "max_gen_len": 512,
                "temperature": 0.7,
                "top_p": 0.9
            }

            # Create the complete record for batch inference
            batch_record = {
                "recordId": record_id,
                "modelInput": body
            }
            
            outfile.write(json.dumps(batch_record) + '\n')

# Usage
create_bedrock_batch_dataset(dataset)

In [None]:
import sagemaker
from sagemaker.s3 import S3Uploader
# Define source and destination paths
local_path_batch_file = 'bedrock_batch_dataset.jsonl'
s3_prefix_batch = 'distillation/batch/data'  # This will be the folder in S3

# Upload the file
s3_path_batch = S3Uploader.upload(
    local_path=local_path_batch_file,
    desired_s3_uri=f's3://{bucket}/{s3_prefix_batch}',
)

print(f"File uploaded successfully to: {s3_path_batch}")

The next cells create the Bedrock batch job that builds the updated dataset with the teacher's knowledge: Llama 3.3 provides the answers to the questions.

[TODO] Create more data with other synthetic generation methods, or add validation of the answers generated by Bedrock using the context/answer as ground truth
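
One simple way to approach the validation idea in this TODO is a token-overlap F1 score between the teacher's answer and the dataset's `LONG_ANSWER`, in the style of extractive QA metrics. This is a sketch of one possible check and not the method used in this notebook; the function names and the 0.2 threshold are assumptions, and a low score only flags an answer for review rather than proving it wrong.

```python
import re
from collections import Counter


def _tokens(text):
    """Lowercase the text and split it into alphanumeric tokens."""
    return re.findall(r"[a-z0-9]+", text.lower())


def token_f1(prediction, reference):
    """Token-level F1 between two strings (1.0 = same bag of tokens)."""
    pred, ref = _tokens(prediction), _tokens(reference)
    if not pred or not ref:
        return float(pred == ref)
    overlap = sum((Counter(pred) & Counter(ref)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred)
    recall = overlap / len(ref)
    return 2 * precision * recall / (precision + recall)


def flag_low_overlap(samples, threshold=0.2):
    """Return indices of samples whose teacher answer barely overlaps the reference."""
    return [
        i for i, s in enumerate(samples)
        if token_f1(s["TEACHER_ANSWER"], s["LONG_ANSWER"]) < threshold
    ]


# Made-up example rows, only to show the call shape
demo = [
    {"TEACHER_ANSWER": "The treatment reduced mortality.", "LONG_ANSWER": "Treatment reduced mortality in patients."},
    {"TEACHER_ANSWER": "Unrelated text.", "LONG_ANSWER": "Treatment reduced mortality in patients."},
]
print(flag_low_overlap(demo))
```

It would run on `dataset` once the `TEACHER_ANSWER` field has been attached to each sample.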

In [None]:
output_prefix="output"
inputDataConfig=({
    "s3InputDataConfig": {
        "s3Uri": s3_path_batch
    }
})

outputDataConfig=({
    "s3OutputDataConfig": {
        "s3Uri": f"s3://{bucket}/{s3_prefix_batch}/{output_prefix}/"
    }
})

In [None]:
jobName = 'batch-job-ga' + str(int(datetime.now().timestamp()))
response=bedrock_client.create_model_invocation_job(
    roleArn=role,
    #modelId='meta.llama3-3-70b-instruct-v1:0',
    modelId='us.meta.llama3-3-70b-instruct-v1:0',
    jobName=jobName,
    inputDataConfig=inputDataConfig,
    outputDataConfig=outputDataConfig
)

In [None]:
import time
jobArn = response.get('jobArn')
job_id = jobArn.split('/')[1]

print(jobArn)

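# Statuses after which the job will not make further progress
# (assumed from the Bedrock model invocation job status values)
terminal_statuses = {'Completed', 'PartiallyCompleted', 'Failed', 'Stopped', 'Expired'}
status = ''
while status not in terminal_statuses:
    job_response = bedrock_client.get_model_invocation_job(jobIdentifier=jobArn)
    status = job_response['status']
    print(datetime.now(), ": ", status)
    if status in terminal_statuses:
        if status != 'Completed':
            # Show the full response (including any failure message) for debugging
            print(job_response)
        break
    time.sleep(300)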

### Dataset Formatting for JumpStart/Bedrock (create a dataset compatible with JumpStart and Bedrock, LLM template format)

In [None]:
# Create an S3 client
s3 = boto3.client('s3')
prefix = f"{s3_prefix_batch}/{output_prefix}/{job_id}/"
print(f"prefix: {bucket}/{prefix}")
object_key = f"{prefix}{local_path_batch_file}.out"

In [None]:
response = s3.get_object(Bucket=bucket, Key=object_key)

In [None]:
json_data = response['Body'].read().decode('utf-8')

In [None]:
teacher_answer = []
for line in json_data.splitlines():
    # use a separate name so the PubMedQA `data` dict is not overwritten
    record = json.loads(line)
    print(record['modelOutput']['generation'])
    teacher_answer.append(record['modelOutput']['generation'])

In [None]:
len(teacher_answer)

In [None]:
teacher_answer[0]

In [None]:
dataset[0]['LONG_ANSWER']

In [None]:
for data_item, teacher in zip(dataset, teacher_answer):
    data_item['TEACHER_ANSWER'] = teacher
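
The `zip` above assumes that the batch output lines come back in the same order as the input records and that none are missing. A more defensive variant matches answers by `recordId`, sketched below. It assumes that the record IDs were kept in a list parallel to `dataset` when the batch file was written (the notebook does not do this, so `record_ids` is hypothetical) and that each output line echoes its `recordId` next to `modelOutput`.

```python
import json


def attach_teacher_answers(dataset, record_ids, output_jsonl):
    """Attach generations to dataset items by recordId; return IDs with no answer."""
    answers = {}
    for line in output_jsonl.splitlines():
        if not line.strip():
            continue
        record = json.loads(line)
        output = record.get("modelOutput")
        # Records that failed may have no modelOutput; skip them
        if output and "generation" in output:
            answers[record["recordId"]] = output["generation"]

    missing = []
    for item, record_id in zip(dataset, record_ids):
        if record_id in answers:
            item["TEACHER_ANSWER"] = answers[record_id]
        else:
            missing.append(record_id)
    return missing


# Made-up data: the output arrives out of order and one record has no output
demo_dataset = [{"QUESTION": "q1"}, {"QUESTION": "q2"}, {"QUESTION": "q3"}]
demo_ids = ["aaaaaaaaaaa", "bbbbbbbbbbb", "ccccccccccc"]
demo_output = "\n".join([
    json.dumps({"recordId": "bbbbbbbbbbb", "modelOutput": {"generation": "answer 2"}}),
    json.dumps({"recordId": "aaaaaaaaaaa", "modelOutput": {"generation": "answer 1"}}),
])
print(attach_teacher_answers(demo_dataset, demo_ids, demo_output))
print(demo_dataset)
```

Items listed in the returned `missing` list keep no `TEACHER_ANSWER` key, so the `KeyError` handling in the training-data function would skip them.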

The next function creates the JSONL dataset compatible with the SageMaker JumpStart input format for chat templates.

In [None]:
import json

def create_qa_training_data(dataset, output_file='train.jsonl', max_samples=5000):
    """
    Transform the dataset into JSONL format with a dialog structure including a system message for Llama fine-tuning.

    Args:
        dataset: List of dictionaries containing 'QUESTION', 'CONTEXTS', 'LONG_ANSWER', 'TEACHER_ANSWER'
        output_file: Output JSONL file path
        max_samples: Maximum number of samples to include
    """
    # Define the system message
    system_message = """You are a specialized biomedical research assistant trained to analyze and answer questions about medical and scientific literature. Your role is to:
    - Extract and interpret key information from biomedical research papers, clinical studies, and medical literature
    - Provide accurate, evidence-based responses based solely on the provided research context
    - Focus on specific medical findings, methodologies, and clinical outcomes
    - Present complex medical information in clear, understandable terms
    - Maintain precision when discussing medical terminology, study results, and statistical data
    - Distinguish between preliminary findings and established conclusions
    - Reference specific sections of the provided research when answering questions
    - Acknowledge limitations in studies when relevant
    - Avoid making medical recommendations or providing diagnosis
    When responding, only use information explicitly stated in the provided biomedical context."""

    # Limit the number of samples if specified
    dataset = dataset[:max_samples] if max_samples else dataset

    with open(output_file, 'w', encoding='utf-8') as f:
        for item in dataset:
            try:
                # Create the dialog structure with system message
                dialog = [
                    {
                        "content": f"<<SYS>>\n{system_message}\n<</SYS>>\n\n{item['QUESTION']}",
                        "role": "user"
                    },
                    {
                        "content": item['TEACHER_ANSWER'],
                        "role": "assistant"
                    }
                ]
                
                # Create the JSON object
                json_object = {
                    "dialog": dialog
                }
                
                # Write the JSON line
                f.write(json.dumps(json_object) + '\n')
            except KeyError as e:
                print(f"Skipping item due to missing key: {e}")
                continue

def verify_jsonl(filename):
    with open(filename, 'r', encoding='utf-8') as f:
        for i, line in enumerate(f):
            try:
                data = json.loads(line)
                if i == 0:  # Print first example
                    print("Sample entry:")
                    print(json.dumps(data, indent=2))
                # keep looping so every line is checked for valid JSON
            except json.JSONDecodeError as e:
                print(f"Error in line {i+1}: {e}")


# Usage example
create_qa_training_data(dataset, output_file='train.jsonl', max_samples=5000)
verify_jsonl('train.jsonl')
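
Beyond checking that each line is valid JSON, it can help to confirm that every record has the dialog shape written by `create_qa_training_data`: a non-empty `dialog` list with alternating `user` and `assistant` turns and non-empty `content`. The sketch below encodes those checks. They mirror what this notebook writes and are an assumption about what the training job expects; the function name is hypothetical.

```python
import json

EXPECTED_ROLES = ("user", "assistant")


def validate_dialog_line(line):
    """Return a list of problems found in one JSONL line (empty list means OK)."""
    try:
        record = json.loads(line)
    except json.JSONDecodeError as exc:
        return [f"invalid JSON: {exc}"]

    dialog = record.get("dialog") if isinstance(record, dict) else None
    if not isinstance(dialog, list) or not dialog:
        return ["missing or empty 'dialog' list"]

    problems = []
    for i, turn in enumerate(dialog):
        if not isinstance(turn, dict):
            problems.append(f"turn {i}: not an object")
            continue
        expected_role = EXPECTED_ROLES[i % 2]  # user, assistant, user, ...
        if turn.get("role") != expected_role:
            problems.append(f"turn {i}: expected role '{expected_role}', got {turn.get('role')!r}")
        content = turn.get("content")
        if not isinstance(content, str) or not content.strip():
            problems.append(f"turn {i}: empty content")
    if len(dialog) % 2 != 0:
        problems.append("dialog does not end with an assistant turn")
    return problems


good = json.dumps({"dialog": [
    {"role": "user", "content": "question"},
    {"role": "assistant", "content": "answer"},
]})
print(validate_dialog_line(good))
```

Applied to `train.jsonl`, it would be called once per line, collecting any non-empty result together with the line number.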

## 5. Student Model (LLaMA 3B)

### Upload dataset to S3 bucket

In [None]:
from sagemaker.s3 import S3Uploader
import sagemaker
import random


default_bucket_prefix = sagemaker.Session().default_bucket_prefix
default_bucket_prefix_path = ""

# If a default bucket prefix is specified, append it to the s3 path
if default_bucket_prefix:
    default_bucket_prefix_path = f"/{default_bucket_prefix}"

local_data_file = "train.jsonl"
train_data_location = f"s3://{bucket}{default_bucket_prefix_path}/oasst_top1"
S3Uploader.upload(local_data_file, train_data_location)
print(f"Training data: {train_data_location}")

### Selecting Student Model in JumpStart 

In [None]:
from ipywidgets import Dropdown
from sagemaker.jumpstart.notebook_utils import list_jumpstart_models


try:
    dropdown = Dropdown(
        options=list_jumpstart_models("search_keywords includes Text Generation"),
        value="meta-textgeneration-llama-3-2-1b",
        description="Select a JumpStart text generation model:",
        style={"description_width": "initial"},
        layout={"width": "max-content"},
    )
    display(dropdown)
except Exception:
    # fall back to the hard-coded model id if the widget cannot be rendered
    dropdown = None

In [None]:
if dropdown:
    student_model_id = dropdown.value
else:
    # Fall back to the default student model when the dropdown is unavailable.
    # NOTE: this default is the 1B model; choose the 3B model to match the 70B → 3B goal.
    student_model_id = "meta-textgeneration-llama-3-2-1b"
model_version_student = "*"

### Configuring Training Job 

In [None]:
from sagemaker import hyperparameters

my_hyperparameters_student = hyperparameters.retrieve_default(
    model_id=student_model_id, model_version=model_version_student,
)

print(my_hyperparameters_student)

### Hyperparameters  

In [None]:
my_hyperparameters_student["epoch"] = "1"
my_hyperparameters_student['chat_dataset']="True"
my_hyperparameters_student['instruction_tuned']="False"
print(my_hyperparameters_student)

hyperparameters.validate(
    model_id=student_model_id, model_version=model_version_student, hyperparameters=my_hyperparameters_student
)

### Launching Training Job 

In [None]:
from sagemaker.jumpstart.estimator import JumpStartEstimator


estimator = JumpStartEstimator(
    model_id=student_model_id,
    model_version=model_version_student,
    hyperparameters=my_hyperparameters_student,
    role=role,
    disable_output_compression=True,
    environment={
        "accept_eula": "true"
    },  # please change `accept_eula` to be `true` to accept EULA.
)

estimator.fit({"training": train_data_location})

## 6. Evaluation

### Deploying Student Model Endpoint or Bedrock CMI 
### Comparative Testing (Teacher vs Student) 
### Performance Metrics Analysis

## 7. Optimization and Tuning (this section can be omitted)

### Fine-tuning Hyperparameters 
### JumpStart Model Retraining 
### Performance Improvement Strategies 

## 8. Production Deployment (JumpStart or Bedrock)

### Endpoint Configuration 
### Scaling and Cost Management 
### Monitoring Setup 


## 9. Cleanup and Best Practices

### Resource Termination 
### Cost Optimization Tips 
### JumpStart Best Practices 

## 10. Conclusion and Next Steps

### Summary of Results 
### Lessons Learned 
### Future Improvements 