# Building a Retrieval Augmented Fine-Tuning (RAFT) dataset

This notebook uses the PubMedQA dataset:
- Jin, Q., Dhingra, B., Liu, Z., Cohen, W., & Lu, X. (2019). PubMedQA: A Dataset for Biomedical Research Question Answering. Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP), 2567–2577.


The RAFT approach is a product of this research paper:
- Zhang, T., Patil, S. G., Jain, N., Shen, S., Zaharia, M., Stoica, I., & Gonzalez, J. E. (2024). RAFT: Adapting Language Model to Domain Specific RAG. https://arxiv.org/abs/2403.10131

> Note that this notebook series doesn't yet incorporate chain-of-thought (CoT) reasoning in the training answers; the RAFT paper reports that adding CoT further improves results.

## Introduction to RAFT

Retrieval Augmented Fine-Tuning (RAFT) aims to improve the performance of large language models on domain-specific, open-book question answering tasks.

![](images/x1.png)

RAFT trains the language model to ignore "distractor" documents that do not contain information relevant to the question, and instead to extract the answer from the "golden" (oracle) documents. It also encourages the model to generate chain-of-thought style responses that cite relevant quotes from the documents, which improves its reasoning. Experiments in the paper show that RAFT consistently outperforms standard fine-tuning and retrieval-augmented generation approaches across several specialized domains, such as medical literature, coding APIs, and multi-hop reasoning.
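The RAFT paper describes its training data in terms of a fraction $P$ of questions $Q$ whose context includes the oracle (golden) document $D^*$ alongside distractors, while the remaining $1-P$ fraction sees only distractor documents $D_i$. In both cases the training target is the answer $A^*$:

$$
\begin{aligned}
P \text{ fraction:}\quad & Q + D^{*} + D_{2} + \dots + D_{k} \;\rightarrow\; A^{*} \\
(1-P) \text{ fraction:}\quad & Q + D_{1} + D_{2} + \dots + D_{k} \;\rightarrow\; A^{*}
\end{aligned}
$$

Later in this notebook, the `p` argument of `process_dataset` plays the role of $P$; note that the notebook adds the same number of distractors (`distract`) in both cases rather than $k-1$ and $k$.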

As the following charts from the RAFT paper show, including distractor documents during training has a material impact on final answer accuracy.

Natural Questions | HotpotQA
--- | ---
![](images/x5.png) | ![](images/x6.png)


## Outcome

In this series of notebooks, you will build a RAFT dataset from PubMedQA, train a new generation model in a SageMaker managed training job, and host the fine-tuned model on SageMaker hosting or on Amazon Bedrock via Custom Model Import. After hosting the model, you will run some quick evaluations to quantify the improvement on held-out data.

## Dependencies

In [None]:
%pip uninstall -q -y autogluon-multimodal autogluon-timeseries autogluon-features autogluon-common autogluon-core
%pip install -Uq pathos==0.3.2
%pip install -Uq datasets==2.19.2
%pip install -Uq transformers==4.40.2
%pip install -Uq transformers[torch]==4.40.2
%pip install -Uq sentence_transformers==3.1.1
%pip install -Uq accelerate==1.0.0
%pip install -Uq sagemaker==2.224.1

In [None]:
import json

import boto3
import torch
from botocore.exceptions import ClientError
from datasets import load_dataset, concatenate_datasets
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import (
    InformationRetrievalEvaluator,
    SequentialEvaluator,
)
from sentence_transformers.util import cos_sim

import sagemaker

sess = sagemaker.Session()
region = sess.boto_region_name
region = sess.boto_region_name

Next, process the raw PubMedQA dataset for RAFT. This involves building a set of question/answer/context elements with oracle context (all the context needed to answer the question correctly) and distractor context (irrelevant passages). Some examples are "distracted": the oracle context is absent entirely and the context consists only of distractor documents. The remaining examples contain the oracle contexts mixed with distractors, shuffled so the model doesn't learn a positional bias toward early contexts.

Training on both kinds of example teaches the model to identify which passages are relevant to a given query, and to answer from those passages, when it is presented with a mixed set of retrieved documents.
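The mixing rule described above can be sketched as a small, pure function. This is a minimal illustration under stated assumptions, not the notebook's implementation: the name `build_raft_context` and its arguments are hypothetical, distractors are drawn from a fixed pool of passages, and a seeded `random.Random` instance is used so the behavior is reproducible.

```python
import random


def build_raft_context(oracle_docs, distractor_pool, p=0.7, num_distractors=3, rng=None):
    """Hypothetical helper: return (context_docs, distracted) for one training example.

    With probability p the oracle documents are kept alongside the distractors;
    otherwise the context contains distractors only. The result is shuffled so
    the oracle documents do not always appear first.
    """
    rng = rng or random.Random()
    # Never draw a distractor that is actually one of this example's oracle docs
    candidates = [doc for doc in distractor_pool if doc not in oracle_docs]
    distractors = rng.sample(candidates, num_distractors)

    distracted = rng.random() >= p  # True with probability 1 - p
    context = list(distractors) if distracted else list(oracle_docs) + distractors
    rng.shuffle(context)
    return context, distracted


# Usage: estimate how often an example ends up fully distracted
rng = random.Random(0)
pool = [f"unrelated passage {i}" for i in range(50)]
oracle = ["passage that answers the question"]
results = [build_raft_context(oracle, pool, p=0.7, rng=rng) for _ in range(10_000)]
distracted_share = sum(flag for _, flag in results) / len(results)
print(f"fully distracted: {distracted_share:.3f}")  # close to 1 - p = 0.3
```

Excluding the oracle passages from the distractor candidates matters: if an example's own oracle text slipped in as a "distractor", a supposedly distracted example would still contain the answer.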

First, you'll download the `pqa_artificial` subset of PubMedQA from Hugging Face Datasets, then build a base dataset that you will use for a variety of tasks.

In [None]:
source_dataset = load_dataset("qiaojin/PubMedQA", "pqa_artificial")
source_dataset["train"][0]

In [None]:
import os
import random

def process_dataset(input_dataset, output_filename, p=0.7, distract=3, max_items=-1):
    """Build RAFT examples: with probability p keep the oracle contexts alongside
    `distract` distractor passages; otherwise use distractor passages only."""

    output_data = []

    if max_items > -1:
        max_items = min(max_items, len(input_dataset))
        print(f"max_items set, reducing input to {max_items} items.")
    else:
        max_items = len(input_dataset)

    for idx, item in enumerate(input_dataset.select(range(max_items))):

        # Sample one random passage from each of `distract` other items,
        # skipping the current item so its own oracle text can't become a distractor
        distractor_docs = []
        while len(distractor_docs) < distract:
            distractor_idx = random.randrange(len(input_dataset))
            if distractor_idx == idx:
                continue
            distractor_contexts = input_dataset[distractor_idx]["context"]["contexts"]
            distractor_docs.append(random.choice(distractor_contexts))

        # With probability 1 - p, the example contains distractors only
        full_distractor = random.uniform(0, 1) > p

        if full_distractor:
            contexts = distractor_docs
        else:
            contexts = item["context"]["contexts"] + distractor_docs

        # Shuffle so the oracle contexts don't always appear in the same position
        random.shuffle(contexts)

        data_item = {
            "question": item["question"],
            "context": "\n\n".join(contexts),
            "oracle": "\n\n".join(item["context"]["contexts"]),
            "distracted": full_distractor,
            "original_answer": item["long_answer"],
        }
        output_data.append(data_item)

        print(f"item: {idx+1}", end="\r")

    # Write the processed examples to the output file, creating its directory if needed
    os.makedirs(os.path.dirname(output_filename) or ".", exist_ok=True)
    with open(output_filename, "w", encoding="utf-8") as f:
        json.dump(output_data, f, ensure_ascii=False, indent=4)

In [None]:
process_dataset(source_dataset["train"].shuffle(),"./data/base_data/base_data.json", max_items=40000)

Then load the processed file into a dataset object and inspect one of the elements. 

Each element has five properties:
- `question` - The user query for this entry.
- `oracle` - The oracle context for the question, i.e. all of the context needed to generate the correct answer. It is used to generate synthetic answers in a later step and can also be used to measure the factual accuracy of generated responses.
- `context` - The combined context passages. It consists either entirely of distractor documents or of a shuffled mix of oracle and distractor documents.
- `distracted` - A boolean flag indicating whether the context consists entirely of distractor documents.
- `original_answer` - The source PubMedQA long answer. These answers are typically short, so later in this notebook you will generate longer versions from the oracle context.
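Before moving on, it can help to sanity-check these fields. The sketch below is a hypothetical helper (`summarize_base_records` is not part of any library) that confirms every record has the five fields, reports the fraction of distracted examples (which should be near `1 - p`), and counts non-distracted records whose context is missing an oracle passage. It only assumes the records are dicts, so it should also accept the `dataset` loaded in the next cell, since iterating a `datasets.Dataset` yields one dict per row.

```python
def summarize_base_records(records):
    """Hypothetical check: validate the five expected fields and report simple statistics."""
    expected = {"question", "context", "oracle", "distracted", "original_answer"}
    distracted = 0
    oracle_missing = 0
    for i, rec in enumerate(records):
        missing = expected - set(rec)
        if missing:
            raise ValueError(f"record {i} is missing fields: {sorted(missing)}")
        if rec["distracted"]:
            distracted += 1
        else:
            # Non-distracted examples should contain every oracle passage in their context
            oracle_parts = rec["oracle"].split("\n\n")
            if not all(part in rec["context"] for part in oracle_parts):
                oracle_missing += 1
    return {
        "total": len(records),
        "distracted_fraction": distracted / len(records) if len(records) else 0.0,
        "oracle_missing": oracle_missing,
    }


# Usage with two hand-written records
sample = [
    {"question": "Q1", "context": "B\n\nA", "oracle": "A", "distracted": False, "original_answer": "yes"},
    {"question": "Q2", "context": "C\n\nD", "oracle": "E", "distracted": True, "original_answer": "no"},
]
stats = summarize_base_records(sample)
print(stats)  # {'total': 2, 'distracted_fraction': 0.5, 'oracle_missing': 0}
```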

In [None]:
dataset = load_dataset("json", data_files="./data/base_data/base_data.json", split="train")

dataset[0]

The `create_oracle_rag_prompts` function takes in an element from the base dataset and generates a prompt consisting of only oracle context so you can generate training data with longer answers.
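The special tokens in the prompt (`<|begin_of_text|>`, `<|start_header_id|>`, `<|end_header_id|>`) come from the Llama 3 chat format, which suggests the generation endpoint serves a Llama 3 family model; treat that as an assumption. A hand-rolled formatter for a list of role/content messages might look like the following sketch (the helper `format_llama3_prompt` is hypothetical). If you have access to the model's tokenizer, `tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)` in `transformers` builds the prompt from the tokenizer's own template, which avoids hand-maintaining these tokens.

```python
def format_llama3_prompt(messages, add_generation_prompt=True):
    """Hypothetical helper: render [{"role": ..., "content": ...}] in the Llama 3 chat layout."""
    prompt = "<|begin_of_text|>"
    for message in messages:
        # Each turn: role header, blank line, content, end-of-turn token
        prompt += (
            f"<|start_header_id|>{message['role']}<|end_header_id|>\n\n"
            f"{message['content'].strip()}<|eot_id|>"
        )
    if add_generation_prompt:
        # Open an assistant turn so the model continues with its answer
        prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n"
    return prompt


messages = [
    {"role": "system", "content": "Answer the question using the provided context."},
    {"role": "user", "content": "Context: Aspirin inhibits COX enzymes.\n\nQuestion: What does aspirin inhibit?"},
]
print(format_llama3_prompt(messages))
```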

In [None]:
# Build a Llama 3 chat-formatted prompt that contains only the oracle context.
# Each turn ends with <|eot_id|>, and no indentation leaks into the prompt text.
def create_oracle_rag_prompts(data_point):
    system_message = (
        "You are an assistant for question-answering tasks. Answer the following question "
        "in 5 sentences using the provided context. If you don't know the answer, "
        'just say "I don\'t know."'
    )
    user_message = f"Context: {data_point['oracle']}\n\nQuestion: {data_point['question']}"
    full_prompt = (
        "<|begin_of_text|>"
        f"<|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>"
        f"<|start_header_id|>user<|end_header_id|>\n\n{user_message}<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\nAnswer:"
    )
    return {"prompt": full_prompt}

Use `dataset.map` to run `create_oracle_rag_prompts` on every row of the dataset, adding a `prompt` feature to each element for generation. After the mapping, write the result to `./data/generation_data/generation_data.json`.

In [None]:
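import os

# Add a "prompt" column to every row of the base dataset
dataset = dataset.map(
    create_oracle_rag_prompts,
    batched=False
)

# Make sure the output directory exists before writing
os.makedirs("./data/generation_data", exist_ok=True)
dataset.to_json("./data/generation_data/generation_data.json", orient="records")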
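With `orient="records"`, `Dataset.to_json` writes JSON Lines by default (one JSON object per line), which is why `load_dataset("json", ...)` can read the file straight back. If you want to inspect or post-process the file without `datasets`, a standard-library sketch could look like this; `write_jsonl` and `read_jsonl` are hypothetical helpers, not library functions.

```python
import json
import os
import tempfile


def write_jsonl(records, path):
    """Write one JSON object per line, creating the parent directory if needed."""
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        for record in records:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")


def read_jsonl(path):
    """Read a JSON Lines file back into a list of dicts, skipping blank lines."""
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


# Usage: round-trip two records through a temporary file
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "generation_data", "generation_data.json")
    rows = [{"question": "Q1", "prompt": "P1"}, {"question": "Q2", "prompt": "P2"}]
    write_jsonl(rows, path)
    loaded = read_jsonl(path)

print(loaded == rows)  # True
```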

If you've already generated the prompt data, you can load it directly from the filesystem.

In [None]:
generation_dataset = load_dataset("json", data_files="./data/generation_data/generation_data.json", split="train")
generation_dataset[0]

## Generate Oracle Summary Data

In [None]:
from sagemaker.deserializers import JSONDeserializer
from sagemaker.serializers import JSONSerializer

generation_base_predictor = sagemaker.Predictor(
    endpoint_name="<<YOUR ENDPOINT HERE>>",
    sagemaker_session=sess,
    serializer=JSONSerializer(),
    deserializer=JSONDeserializer(),
)

Preview the prompt being used for generation.

In [None]:
test_item = generation_dataset[2]
test_item["prompt"]

Here, you set the generation parameters and insert your prompt. Then call the SageMaker endpoint through the predictor's `predict` method to see what a generated response looks like before running against the whole dataset.

In [None]:
def send_prompt(predictor, prompt, parameters):
    # Request body: the prompt under "inputs" and generation settings under "parameters"
    payload = {
        "inputs": prompt,
        "parameters": parameters
    }
    response = predictor.predict(payload)
    return response

In [None]:
%%time
prompt = test_item["prompt"]

# Use the predictor created above for the generation endpoint
base_response = send_prompt(
    generation_base_predictor,
    prompt,
    parameters={
        "temperature": 0.9,
        "max_new_tokens": 512,
        "top_p": 0.9
    }
)

base_response
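The `temperature` and `top_p` parameters control how the model samples tokens. As a rough illustration only (model servers apply this to logits internally, and the function `sample_top_p` below is a hypothetical, simplified sketch over a toy distribution), temperature rescales the logits before the softmax, and top-p (nucleus) sampling keeps the smallest set of most-likely tokens whose cumulative probability reaches `top_p`, then samples from that set.

```python
import math
import random


def sample_top_p(logits, temperature=0.9, top_p=0.9, rng=None):
    """Hypothetical sketch: sample a token index with temperature scaling and nucleus filtering."""
    rng = rng or random.Random()
    # Temperature scaling: values below 1 sharpen the distribution, above 1 flatten it
    scaled = [x / temperature for x in logits]
    max_logit = max(scaled)
    exps = [math.exp(x - max_logit) for x in scaled]  # subtract max for numerical stability
    total = sum(exps)
    probs = [e / total for e in exps]

    # Keep the most likely tokens until their cumulative probability reaches top_p
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    nucleus, cumulative = [], 0.0
    for i in ranked:
        nucleus.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break

    # Sample within the nucleus; random.choices normalizes the weights itself
    weights = [probs[i] for i in nucleus]
    return rng.choices(nucleus, weights=weights, k=1)[0]


# Usage: with a dominant first logit and a small top_p, only token 0 survives the filter
rng = random.Random(0)
print({sample_top_p([5.0, 1.0, 0.5], temperature=1.0, top_p=0.5, rng=rng) for _ in range(100)})  # {0}
```

Lower `temperature` and `top_p` values make generations more deterministic; the 0.9 values used here leave room for varied phrasing in the synthetic answers.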

## Generate summaries based on oracle contexts for supervised fine-tuning

This step iterates over the generation dataset and, using only the oracle contexts, generates answers that are longer than the standard PubMedQA answers. Each generated answer is stored with the existing data fields to build the training dataset.

In [None]:
import os
from IPython.display import clear_output

generation_params = {
    "temperature": 0.9,
    "max_new_tokens": 512,
    "top_p": 0.9,
}

output_json = []
for idx, data in enumerate(generation_dataset):
    # Generate a longer answer using only the oracle context in the prompt
    model_response = send_prompt(
        generation_base_predictor,
        data["prompt"],
        parameters=generation_params,
    )

    # Extract the generated text from the endpoint response
    response_text = model_response["generated_text"]

    # Keep the original fields, drop the prompt, and attach the synthetic answer
    output_item = dict(data)
    output_item.pop("prompt")
    output_item["synthetic_answer"] = response_text
    output_json.append(output_item)

    clear_output(wait=True)
    print(f"{idx+1} of {len(generation_dataset)}\n\n{output_item}")

# Write all generated examples once the loop finishes
os.makedirs("./data/synthetic_data", exist_ok=True)
with open("./data/synthetic_data/synthetic_training_data.json", "w", encoding="utf-8") as output_file:
    json.dump(output_json, output_file)
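Generating answers for 40,000 prompts takes a long time, and the loop above writes its output only once it finishes, so an interrupted run loses all progress. One option, sketched below under stated assumptions, is to append each result to a JSON Lines checkpoint file as it is produced, retry transient failures with exponential backoff, and skip already-processed rows on restart. The function `generate_with_checkpoint`, its `generate_fn` argument (standing in for something like `lambda p: send_prompt(generation_base_predictor, p, generation_params)["generated_text"]`), the `idx` key, and the retry settings are all hypothetical.

```python
import json
import os
import tempfile
import time


def generate_with_checkpoint(rows, generate_fn, checkpoint_path, max_retries=3, backoff_s=2.0):
    """Hypothetical sketch: append one JSON line per processed row, skipping rows already saved.

    rows: iterable of dicts with a "prompt" key.
    generate_fn: callable mapping a prompt string to generated text.
    """
    done = set()
    if os.path.exists(checkpoint_path):
        with open(checkpoint_path, encoding="utf-8") as f:
            done = {json.loads(line)["idx"] for line in f if line.strip()}

    with open(checkpoint_path, "a", encoding="utf-8") as out:
        for idx, row in enumerate(rows):
            if idx in done:
                continue
            for attempt in range(max_retries):
                try:
                    text = generate_fn(row["prompt"])
                    break
                except Exception:  # broad catch for illustration; narrow this in practice
                    if attempt == max_retries - 1:
                        raise
                    time.sleep(backoff_s * 2 ** attempt)  # exponential backoff between retries
            item = {k: v for k, v in row.items() if k != "prompt"}
            item.update({"idx": idx, "synthetic_answer": text})
            out.write(json.dumps(item, ensure_ascii=False) + "\n")
            out.flush()  # persist progress before moving to the next row


# Usage with a fake generator that fails once, then succeeds
calls = {"n": 0}

def flaky_generate(prompt):
    calls["n"] += 1
    if calls["n"] == 1:
        raise RuntimeError("transient endpoint error")
    return f"answer to {prompt}"

rows = [{"question": "Q1", "prompt": "P1"}, {"question": "Q2", "prompt": "P2"}]
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "synthetic.jsonl")
    generate_with_checkpoint(rows, flaky_generate, path, backoff_s=0.0)
    generate_with_checkpoint(rows, flaky_generate, path, backoff_s=0.0)  # second run skips both rows
    with open(path, encoding="utf-8") as f:
        results = [json.loads(line) for line in f]
```

Because the checkpoint is JSON Lines, it can be loaded later with `load_dataset("json", data_files=...)` in the same way as the generation data.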