In [1]:
# upgrade boto3 
# %pip install --upgrade pip --quiet
# %pip install boto3 --upgrade --quiet

In [2]:
# restart kernel
from IPython.core.display import HTML
HTML("<script>Jupyter.notebook.kernel.restart()</script>")

# Model Distillation for Question Answering with Cited Text

This notebook is part of a series demonstrating advanced model distillation techniques for creating specialized, citation-aware question-answering models. The goal is to distill the knowledge from a large language model (Amazon Nova Premier) into a smaller, more efficient model while maintaining high-quality citation capabilities.

## Learning Objectives
- Prepare training data for citation model distillation
- Design structured XML output formats for consistent answer generation
- Implement source citation tracking in model responses
- Create evaluation datasets for measuring citation accuracy

## Dataset: SQuAD v2.0
We use the [Stanford Question Answering Dataset (SQuAD) v2.0](https://rajpurkar.github.io/SQuAD-explorer/) as our base dataset. SQuAD v2.0 is particularly suitable for citation-aware model training because:

1. Contains explicit answer spans within source text
2. Includes "impossible" questions to test model reliability
3. Provides diverse question types and domains
4. Enables source attribution tracking

The dataset is hosted on the Hugging Face Hub in Parquet format (see the [Hugging Face Datasets library](https://huggingface.co/docs/datasets/)), and we read it directly into pandas DataFrames through the `hf://` path.

In [4]:
import json
import sys
import os
import re
import pandas as pd
import numpy as np

current_dir = os.getcwd()
parent_dir = os.path.dirname(current_dir)
skip_dir = os.path.dirname(parent_dir)
sys.path.append(skip_dir)
from utils import read_jsonl_to_dataframe

splits = {'train': 'squad_v2/train-00000-of-00001.parquet', 'validation': 'squad_v2/validation-00000-of-00001.parquet'}
df_train = pd.read_parquet("hf://datasets/rajpurkar/squad_v2/" + splits["train"])
df_eval = pd.read_parquet("hf://datasets/rajpurkar/squad_v2/" + splits["validation"])


## Advanced System Prompt Engineering

This section implements a specialized system prompt for [Amazon Nova Foundation models](https://docs.aws.amazon.com/nova/latest/userguide/prompting.html). The prompt engineering focuses on:

1. **Context-Bounded Responses**: Answers must be derived solely from provided context
2. **Source Attribution**: Mandatory citation of source text for verification
3. **Structured Output**: XML-based response format for consistent parsing
4. **Edge Case Handling**: Explicit handling of unanswerable questions

### System Prompt XML Schema
The system prompt uses the following format, which supports N answer parts, each with its own sources. In the following cells we'll build helper functions to parse these out so that we can measure citation accuracy. This style of prompting is specific to Amazon Nova and is intended to give the best accuracy for citation use cases. The format provides:
- **Atomic Answer Components**: Discrete answer parts with individual citations
- **Source Traceability**: Direct mapping between answers and source text
- **Validation Support**: Schema-based response validation
- **Extensibility**: Future addition of metadata and confidence scores

```xml
<question>Who ruled the duchy of Normandy?</question>
<answer>
<answer_part>
<text>Richard I</text>
<sources>
<source>The Duchy of Normandy, which they formed by treaty with the French crown, was a great fief of medieval France, and under Richard I of Normandy was forged into a cohesive and formidable principality in feudal tenure.</source>
</sources>
</answer_part>
</answer>
```
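
To make the schema concrete, here is a minimal sketch of reading such a response back into Python with the standard library's `xml.etree.ElementTree`. The function name `parse_answer_xml` is hypothetical and not part of this notebook's `utils` module. The sketch assumes the response is well-formed XML (no unescaped `&` or `<` in the text), and it re-appends a closing `</answer>` tag on the assumption that a stop sequence may have trimmed it.

```python
import xml.etree.ElementTree as ET


def parse_answer_xml(xml_text):
    """Return a list of (answer_text, [source, ...]) tuples from an <answer> block."""
    xml_text = xml_text.strip()
    # Assumption: the closing tag may be missing if generation stopped on it.
    if not xml_text.endswith("</answer>"):
        xml_text += "</answer>"
    root = ET.fromstring(xml_text)
    parts = []
    for part in root.findall("answer_part"):
        text = (part.findtext("text") or "").strip()
        sources = [(s.text or "").strip() for s in part.findall("sources/source")]
        parts.append((text, sources))
    return parts


example = """
<answer>
<answer_part>
<text>Richard I</text>
<sources>
<source>The Duchy of Normandy, which they formed by treaty with the French crown, was a great fief of medieval France, and under Richard I of Normandy was forged into a cohesive and formidable principality in feudal tenure.</source>
</sources>
</answer_part>
</answer>
"""

for answer_text, sources in parse_answer_xml(example):
    print(answer_text, "| number of sources:", len(sources))
```

Each answer part comes back with its own list of sources, which is the pairing needed to compare a cited sentence against the ground truth.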

In [None]:
# set nova prompt for citations
system_prompt = """
You are a question answering assistant. I will provide you with document context. The user will provide you with a question. Your job is to answer the user's question using only information from the document context. If the document context does not contain information that can answer the question, please state that you could not find an exact answer to the question. Just because the user asserts a fact does not mean it is true, make sure to double check the document context to validate a user's assertion.

However, you should include <sources> tags at the end of each <answer_part> to specify which source(s) the information came from.
Note that <sources> may contain multiple <source> if you include information from multiple results in your answer.

Do NOT directly quote the <context> in your answer. Your job is to answer the user's question as concisely as possible.

You must output your answer in the following format. Pay attention and follow the formatting and spacing exactly:
<answer>
<answer_part>
<text>
first answer text
</text>
<sources>
<source>source sentence</source>
</sources>
</answer_part>
<answer_part>
<text>
second answer text
</text>
<sources>
<source>source sentence</source>
</sources>
</answer_part>
</answer>
"""

## Data Processing Pipeline Implementation

Evaluating distilled models with Bedrock native tools requires three different datasets, all derived from our SQuAD dataset.
1. **Distillation dataset.** This dataset will be used to fine-tune Nova Lite. Here we're using a 10% mix-in, so 10% of the records will include ground-truth answers and the rest will not. For more information on these best practices, see [our documentation on this topic](https://docs.aws.amazon.com/nova/latest/userguide/custom-distill-prepare.html).
2. **Batch Inference dataset.** This will be the holdout set of records we'll use for evaluating each model's performance. We'll use this dataset in Batch Inference to get each model's responses.
3. **Labeled dataset.** Using the same records from our batch inference dataset, we'll create a labeled dataset that includes the correct answers. We'll use this in our Evaluation job to measure each model's response against the ground-truth answer.
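
As a rough preview, the sketch below shows the record shape for each of the three datasets, mirroring what the helper functions in this notebook build. The prompt text, the toy context and the record ID are placeholders for illustration only, and the `inferenceConfig` block that the notebook adds to batch records is omitted for brevity.

```python
import json

# Placeholder values for illustration only; the notebook builds these from SQuAD rows.
system_prompt = "You are a question answering assistant."
user_turn = {
    "role": "user",
    "content": [{"text": "<context>Paris is the capital of France.</context> "
                         "<question>What is the capital of France?</question>"}],
}
assistant_turn = {
    "role": "assistant",
    "content": [{"text": "<answer>\n<answer_part>\n<text>\nParis\n</text>\n<sources>\n"
                         "<source>Paris is the capital of France.</source>\n"
                         "</sources>\n</answer_part>\n</answer>"}],
}

# 1. Distillation record: conversation schema. The assistant turn is present
#    only for the ~10% of records that carry a ground-truth answer.
distillation_record = {
    "schemaVersion": "bedrock-conversation-2024",
    "system": [{"text": system_prompt}],
    "messages": [user_turn],
}
distillation_record_with_gt = {**distillation_record, "messages": [user_turn, assistant_turn]}

# 2. Batch inference record: a record ID plus the model input.
batch_record = {
    "recordId": "toy-0001",  # hypothetical ID; the notebook uses the SQuAD row's 'id'
    "modelInput": {
        "system": [{"text": system_prompt}],
        "messages": [user_turn],
    },
}

# 3. Labeled record: the same record with the ground-truth answer appended.
labeled_record = {
    "recordId": batch_record["recordId"],
    "modelInput": {**batch_record["modelInput"], "messages": [user_turn, assistant_turn]},
}

for name, record in [
    ("distillation", distillation_record),
    ("distillation + ground truth", distillation_record_with_gt),
    ("batch inference", batch_record),
    ("labeled", labeled_record),
]:
    print(name, "->", sorted(record.keys()))
```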

In [None]:
def parse_answer_structure(answers_dict):
    """
    Parse different formats of answer dictionaries and extract text and start positions.
    Returns lists of texts and start positions.
    """
    # Case 1: NumPy arrays with direct keys
    if 'text' in answers_dict and isinstance(answers_dict['text'], np.ndarray):
        texts = answers_dict['text'].tolist()
        starts = answers_dict['answer_start'].tolist()
        
    # Case 2: Lists or single values with direct keys
    elif 'text' in answers_dict:
        texts = answers_dict['text'] if isinstance(answers_dict['text'], list) else [answers_dict['text']]
        starts = answers_dict['answer_start'] if isinstance(answers_dict['answer_start'], list) else [answers_dict['answer_start']]
        
    # Any other format is unsupported (JSON strings are decoded by the caller before this point)
    else:
        raise ValueError(f"Unknown answer format: {answers_dict}")
        
    return texts, starts

def create_xml_answer(row, no_answer_text='I could not find an exact answer to the question.'):
    """
    Take a pandas DataFrame row and build the XML-formatted answer from its
    'answers' and 'context' columns.
    """
    try:
        # Handle answers as string (JSON) if needed; json is imported at the top of the notebook
        answers_dict = row['answers']
        if isinstance(answers_dict, str):
            answers_dict = json.loads(answers_dict)
            
        # Parse answer structure using our helper function
        texts, starts = parse_answer_structure(answers_dict)
        context = row['context']
        
        # Split context into sentences more accurately
        sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?|\!)\s', context)
        
        # Build XML structure
        xml_parts = ['<answer>']
        
        if len(texts) > 0:
            for i, (text, start) in enumerate(zip(texts, starts)):
                xml_parts.append('<answer_part>')
                xml_parts.append('<text>')
                xml_parts.append(str(text))
                xml_parts.append('</text>')
                xml_parts.append('<sources>')
                
                # Find the sentence containing the answer based on the start position
                char_count = 0
                source_sentence = "No relevant source found"
                for sentence in sentences:
                    sentence_len = len(sentence) + 1  # +1 for the space after sentence
                    if char_count <= int(start) < (char_count + sentence_len):
                        source_sentence = sentence.strip()
                        break
                    char_count += sentence_len
                
                xml_parts.append(f'<source>{source_sentence}</source>')
                xml_parts.append('</sources>')
                xml_parts.append('</answer_part>')
        
            xml_parts.append('</answer>')
        else: # use no answer text
            xml_parts.append(f"<answer_part>\n<text>\n{no_answer_text}\n</text>\n</answer_part>\n</answer>")
        return '\n'.join(xml_parts)
    except Exception as e:
        return f"<answer>\n<error>Error generating XML: {str(e)}</error>\n</answer>"
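
The source lookup in `create_xml_answer` walks the split sentences while accumulating character counts. As an illustration of the same idea, the sketch below records explicit (start, end) offsets for each sentence and then looks up the one containing `answer_start`. It uses a simplified boundary pattern and hypothetical helper names (`sentence_spans`, `find_source_sentence`), so treat it as a sketch under those assumptions rather than a drop-in replacement for the cell above.

```python
import re


def sentence_spans(context):
    """Return (start, end) character offsets for each sentence in context."""
    spans, start = [], 0
    # Simplified boundary: whitespace that follows '.', '?' or '!'.
    for match in re.finditer(r"(?<=[.?!])\s+", context):
        spans.append((start, match.start()))
        start = match.end()
    spans.append((start, len(context)))
    return spans


def find_source_sentence(context, answer_start):
    """Return the sentence whose span contains answer_start, or None."""
    for begin, end in sentence_spans(context):
        if begin <= answer_start < end:
            return context[begin:end]
    return None


context = "The Normans settled in France.  They gave their name to Normandy. Who ruled it?"
answer_start = context.index("Normandy")
print(find_source_sentence(context, answer_start))
```

Keeping the offsets explicit makes it easy to spot-check a few SQuAD rows by hand: `context[answer_start:]` should begin with the answer text, and the returned sentence should contain it.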



In [None]:
def create_bedrock_payload(row, model_type="conversation", system_prompt=None, include_answer=False, additional_params=None):
    """
    Create a payload dictionary for Amazon Bedrock batch inference API requests.
    Batch inference uses the invoke api.
    
    Args:
        row: A row from the pandas DataFrame containing context, question, and optionally answers
        model_type: The type of model payload to create ("conversation" or "invoke")
        system_prompt: The system message to include (for conversation-based models)
        include_answer: Whether to include the answer in the conversation (for evaluation)
        additional_params: Dictionary of additional parameters to include in the payload
    
    Returns:
        dict: A formatted payload dictionary ready for Bedrock batch inference API
    """
    try:
        # Extract needed information
        context = row['context']
        question = row['question']
        
        # Create the user prompt with context and question
        user_prompt = f"""<context>{context}</context> <question>{question}</question>"""
        
        # Get the answer if needed
        assistant_response = create_xml_answer(row) if include_answer else None
        
        # Create appropriate payload based on model_type
        if model_type == "conversation":
            
            payload = {
                "schemaVersion": "bedrock-conversation-2024",
                "system": [{"text": system_prompt}] if system_prompt else [],
                "messages": [
                    {
                        "role": "user",
                        "content": [{"text": user_prompt}]
                    }
                ]
            }
            
            # Add assistant response if needed (for evaluation)
            if include_answer and assistant_response:
                payload["messages"].append({
                    "role": "assistant",
                    "content": [{"text": assistant_response}]
                })
                
        elif model_type == "invoke":
            # For InvokeModel-style requests (the format batch inference uses)
            payload = {
                "system": [{"text": system_prompt}] if system_prompt else [],
                "messages": [
                    {
                        "role": "user",
                        "content": [{"text": user_prompt}]
                    }
                ],
                "inferenceConfig":{ 
                    # "maxTokens": int, 
                    "temperature": .1, 
                    "topP": .9, 
                    "topK": 50, 
                    "stopSequences": ['</answer>']
                }
            }
            if include_answer and assistant_response:
                payload["messages"].append({
                    "role": "assistant",
                    "content": [{"text": assistant_response}]
                })
            
            # Add optional parameters specific to invoke requests
            if additional_params:
                payload.update(additional_params)
                
        else:
            raise ValueError(f"Unsupported model_type: {model_type}")
            
        # Add any additional parameters passed
        if additional_params and model_type == "conversation":
            # For conversation models, additional params might need to be added at the root level
            for key, value in additional_params.items():
                if key not in payload:
                    payload[key] = value
                    
        return payload
        
    except Exception as e:
        print(f"Error creating payload for row: {str(e)}")
        return None

In [None]:
def create_batch_inf_record(row, system_prompt, include_answer=False):
    """
    Takes a pandas df row and creates a jsonl record for batch inference
    """
    conversation = create_bedrock_payload(
                                row=row, 
                                system_prompt=system_prompt, 
                                model_type="invoke", 
                                additional_params={},
                                include_answer=include_answer)
    return {
        "recordId": row['id'],
        "modelInput": conversation
    }

### Including non-answers
Any citation use case needs to support scenarios where a correct answer is not possible given the passages available.
To this end, half of our training dataset will consist of unanswerable questions and half of questions with answers.
Bedrock distillation jobs can have a maximum of 15,000 records, so we sample 7,500 of each.

In [None]:
# Filter for rows with empty answers (unanswerable questions)
empty_answers_df = df_train[df_train['answers'].apply(lambda x: 
    len(x['text']) == 0 and len(x['answer_start']) == 0)]

# Filter for rows with actual answers
with_answers_df = df_train[df_train['answers'].apply(lambda x: len(x['text']) > 0)]

# take 7500 of each dataframe and combine to use in distillation. 
df_train_revised = pd.concat([
    empty_answers_df.sample(n=7500, random_state=42), 
    with_answers_df.sample(n=7500, random_state=42)], ignore_index=True) # max 15k for bedrock distillation

As stated earlier, it is a best practice to include a ground-truth answer for ~10% of the total training set. We will take a random sample of 10% and call `create_bedrock_payload` with `include_answer` set to `True`. For the remaining 90%, we leave out the ground-truth answer.

In [None]:
row_count = len(df_train_revised)
ground_truth_included = 0.1 * row_count

# here we'll take 10% of our training data set and add the answers
# (.copy() guards against pandas' SettingWithCopyWarning when we add a column below)
training_data_with_gt_df = df_train_revised.sample(n=int(ground_truth_included), random_state=17).copy()

# next we'll drop the ground truth examples so they don't also appear in the set without answers
training_data_without_gt_df = df_train_revised.drop(training_data_with_gt_df.index)

# next we'll build our training data with ground truth
training_data_with_gt_df['conversation'] = training_data_with_gt_df.apply(lambda row: create_bedrock_payload(row=row, model_type="conversation", system_prompt=system_prompt, include_answer=True), axis=1)


In [None]:
# then we'll build our training data without ground truth
training_data_without_gt_df['conversation'] = training_data_without_gt_df.apply(lambda row: create_bedrock_payload(row=row, model_type="conversation", system_prompt=system_prompt, include_answer=False), axis=1)

In [None]:
# Now we'll concatenate the dataframes
final_training_dataset = pd.concat([training_data_with_gt_df, training_data_without_gt_df], axis=0, ignore_index=True).sort_index()

In [None]:
# now we'll output to .jsonl to use in distillation job
final_training_dataset['conversation'].to_json('distillation_data.jsonl', orient='records', lines=True)

## Batch Inference Dataset Creation

Now that our distillation data set is created, we'll move on to creating our batch inference dataset.
We'll use the same records (with labeled answers) for our evaluation jobs, and Bedrock Evaluations handles a maximum of 1,000 records.
We'll use 500 total rows from our dataset, which balances processing time against evaluation accuracy.

In [None]:
eval_empty_answers_df = df_eval[df_eval['answers'].apply(lambda x: 
    len(x['text']) == 0 and len(x['answer_start']) == 0)]

# Filter for rows with actual answers
eval_with_answers_df = df_eval[df_eval['answers'].apply(lambda x: len(x['text']) > 0)]

batch_inf_df = pd.concat([
    eval_empty_answers_df.sample(n=250, random_state=15), 
    eval_with_answers_df.sample(n=250, random_state=15)], ignore_index=True)


batch_inf_df.apply(lambda row: create_batch_inf_record(row, system_prompt), axis=1).to_json('batch_inf_data.jsonl', orient='records', lines=True)

## Labeled Dataset for BYOI Bedrock Evaluation
This section creates a labeled dataset by applying our `create_batch_inf_record` function to each row with `include_answer` set to `True`.

In [None]:
batch_inf_df.apply(lambda row: create_batch_inf_record(row, system_prompt=system_prompt, include_answer=True), axis=1).to_json('labeled_data.jsonl', orient='records', lines=True)

### Datasets Created
By now you should see three `.jsonl` files:
1. `distillation_data.jsonl`. This is the dataset we'll use for distillation.
2. `batch_inf_data.jsonl`. This is the dataset we'll use to run inference on all of our models, including our distilled one.
3. `labeled_data.jsonl`. This is the dataset we'll use to evaluate each model's performance against the ground truth.
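
Before moving on, it can be worth sanity-checking the files. The sketch below is one way to do that with the standard library: it counts the records in a JSONL file and raises if a line is not a JSON object or lacks an expected top-level key. The helper name `count_jsonl_records` is hypothetical, and the demo runs on a throwaway file rather than the notebook's outputs.

```python
import json
import os
import tempfile


def count_jsonl_records(path, required_keys):
    """Count records in a JSONL file, raising ValueError on malformed records."""
    count = 0
    with open(path, encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                continue  # ignore blank lines
            record = json.loads(line)
            # A payload that failed to build is serialized as null, not an object.
            if not isinstance(record, dict):
                raise ValueError(f"{path}:{line_no} is not a JSON object")
            missing = [key for key in required_keys if key not in record]
            if missing:
                raise ValueError(f"{path}:{line_no} is missing keys {missing}")
            count += 1
    return count


# Demo on a throwaway file with two toy records.
with tempfile.TemporaryDirectory() as tmp:
    demo_path = os.path.join(tmp, "demo.jsonl")
    with open(demo_path, "w", encoding="utf-8") as f:
        f.write(json.dumps({"recordId": "a", "modelInput": {}}) + "\n")
        f.write(json.dumps({"recordId": "b", "modelInput": {}}) + "\n")
    demo_count = count_jsonl_records(demo_path, ("recordId", "modelInput"))

print(demo_count)  # 2
```

For the files in this notebook, the expected top-level keys would be `schemaVersion`, `system` and `messages` for `distillation_data.jsonl`, and `recordId` and `modelInput` for the other two. The object check matters because `create_bedrock_payload` returns `None` when it hits an error, which ends up as a `null` line in the output.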

### Next Steps

Proceed to [02_distill.ipynb](02_distill.ipynb) to:
1. Submit a distillation job using our distillation dataset
2. Create a provisioned throughput endpoint to host our distilled model.

You can now move on to `02_distill.ipynb`, where we'll kick off our distillation job!