# Building a Retrieval Augmented Fine-Tuning (RAFT) dataset

This notebook uses the PubMedQA dataset:
- Jin, Q., Dhingra, B., Liu, Z., Cohen, W., & Lu, X. (2019). PubMedQA: A Dataset for Biomedical Research Question Answering. Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP), 2567–2577.


The RAFT approach is a product of this research paper:
- Zhang, T., Patil, S. G., Jain, N., Shen, S., Zaharia, M., Stoica, I., & Gonzalez, J. E. (2024). RAFT: Adapting Language Model to Domain Specific RAG. https://arxiv.org/abs/2403.10131

> Note: this notebook series does not yet incorporate chain-of-thought (CoT) reasoning; adding it should further improve the results.

## Introduction to RAFT

Retrieval Augmented Fine-Tuning (RAFT) aims to improve the performance of large language models on domain-specific, open-book question answering tasks.

![](images/x1.png)

RAFT trains the language model to ignore "distractor" documents that do not contain relevant information to answer the given question, and instead focus on extracting the answer from the "golden" relevant documents. It also encourages the model to generate chain-of-thought style responses that cite relevant quotes from the documents, which improves the model's reasoning abilities. Experiments show that RAFT consistently outperforms standard fine-tuning and retrieval-augmented generation approaches across several specialized domains like medical literature, coding APIs, and multi-hop reasoning.

As the following charts show, training a model with distractor documents has a material impact on the final accuracy of its responses.

Natural Questions | HotpotQA
--- | ---
![](images/x5.png) | ![](images/x6.png)


## Outcome

In this series of notebooks, you will build a RAFT dataset from PubMedQA, train a new generation model in a SageMaker managed training job, and host the fine-tuned model on SageMaker hosting or on Bedrock via Custom Model Import. After hosting the model, you will run some quick evaluations to quantify the improvement on held-out data.

## Dependencies

In [1]:
!pip uninstall -q -y autogluon-multimodal autogluon-timeseries autogluon-features autogluon-common autogluon-core
!pip install -Uq pathos==0.3.2
!pip install -Uq datasets==2.19.2
!pip install -Uq transformers==4.40.2
!pip install -Uq transformers[torch]==4.40.2
!pip install -Uq sentence_transformers==3.1.1
!pip install -Uq accelerate==1.0.0
!pip install -Uq sagemaker==2.224.1

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
jupyter-ai 2.29.1 requires faiss-cpu!=1.8.0.post0,<2.0.0,>=1.8.0, which is not installed.
s3fs 2024.10.0 requires fsspec==2024.10.0.*, but you have fsspec 2024.3.1 which is incompatible.
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
jupyter-ai 2.29.1 requires faiss-cpu!=1.8.0.post0,<2.0.0,>=1.8.0, which is not installed.
dask 2025.2.0 requires cloudpickle>=3.0.0, but you have cloudpickle 2.2.1 which is incompatible.
distributed 2025.2.0 requires cloudpickle>=3.0.0, but you have cloudpickle 2.2.1 which is incompatible.

In [11]:
import boto3
import json
import torch
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import (
    InformationRetrievalEvaluator,
    SequentialEvaluator,
)
from sentence_transformers.util import cos_sim
from datasets import load_dataset, concatenate_datasets

from botocore.exceptions import ClientError


import sagemaker

sess = sagemaker.Session()
region = sess.boto_region_name

Process the raw PubMedQA dataset for RAFT. This involves building a set of question/answer/context elements, each with oracle context (all the context needed to answer the question) and distractor context (irrelevant data). The dataset contains a set of "distracted" examples, where the oracle context is absent entirely, alongside standard examples, where the oracle contexts are present but shuffled among the distractors so the model does not learn a bias toward early positions.

The combination of these elements allows the model to better discern the correct way to answer a given user query when presented with a mixed corpus of content to work with.
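To make the two kinds of examples concrete, here is a minimal sketch on toy data. The function name `build_context`, the toy passages, and the use of sampling without replacement are illustrative assumptions; the notebook's own `process_dataset` function draws its distractors from PubMedQA records instead.

```python
import random

def build_context(oracle_passages, distractor_pool, p=0.7, num_distractors=3, rng=random):
    """Return (passages, distracted) for one question.

    With probability p the oracle passages are kept and mixed with distractors;
    otherwise the context consists of distractors only.
    """
    distractors = rng.sample(distractor_pool, num_distractors)
    distracted = rng.random() > p
    if distracted:
        passages = list(distractors)
    else:
        passages = list(oracle_passages) + distractors
    # Shuffle so oracle passages do not always appear first
    rng.shuffle(passages)
    return passages, distracted

# Toy data, for illustration only
oracle = ["Oracle passage A.", "Oracle passage B."]
pool = [f"Unrelated passage {i}." for i in range(10)]

rng = random.Random(0)
for _ in range(3):
    passages, distracted = build_context(oracle, pool, rng=rng)
    print(distracted, len(passages))
```

A distracted example here always has three passages and none of the oracle text, while a standard example has five passages with both oracle passages somewhere in the shuffled order.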

First, you'll download PubMedQA from Hugging Face Datasets, then build a base dataset that you will use for a variety of tasks.

In [12]:
source_dataset = load_dataset("qiaojin/PubMedQA", "pqa_artificial")
source_dataset["train"][0]

{'pubid': 25429730,
 'question': 'Are group 2 innate lymphoid cells ( ILC2s ) increased in chronic rhinosinusitis with nasal polyps or eosinophilia?',
 'context': {'contexts': ['Chronic rhinosinusitis (CRS) is a heterogeneous disease with an uncertain pathogenesis. Group 2 innate lymphoid cells (ILC2s) represent a recently discovered cell population which has been implicated in driving Th2 inflammation in CRS; however, their relationship with clinical disease characteristics has yet to be investigated.',
   'The aim of this study was to identify ILC2s in sinus mucosa in patients with CRS and controls and compare ILC2s across characteristics of disease.',
   'A cross-sectional study of patients with CRS undergoing endoscopic sinus surgery was conducted. Sinus mucosal biopsies were obtained during surgery and control tissue from patients undergoing pituitary tumour resection through transphenoidal approach. ILC2s were identified as CD45(+) Lin(-) CD127(+) CD4(-) CD8(-) CRTH2(CD294)(+) CD

In [13]:
import random

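def process_dataset(input_dataset, output_filename, p=0.7, distract=3, max_items=-1):
    """Build RAFT base data: one record per question with mixed or distractor-only context.

    p is the probability that the oracle contexts are kept in the context;
    distract is the number of distractor passages sampled per question.
    """
    output_data = []

    if max_items > -1:
        print(f"max_items set, reducing input to {max_items} items.")
    else:
        max_items = len(input_dataset)

    for idx, item in enumerate(input_dataset.select(range(max_items))):

        # Sample one passage from each of `distract` randomly chosen records.
        # Note: a record may occasionally be drawn as its own distractor.
        distractor_docs = []
        for _ in range(distract):
            distractor_element = input_dataset[random.randint(0, len(input_dataset) - 1)]
            distractor_contexts = distractor_element["context"]["contexts"]
            distractor_docs.append(random.choice(distractor_contexts))

        # With probability 1 - p, drop the oracle contexts entirely
        full_distractor = random.uniform(0, 1) > p

        if full_distractor:
            contexts = distractor_docs
        else:
            contexts = item["context"]["contexts"] + distractor_docs

        # Shuffle so the oracle passages do not always come first
        random.shuffle(contexts)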
        
        data_item = {
            "question": item["question"],
            "context": "\n\n".join(contexts),
            "oracle": "\n\n".join(item["context"]["contexts"]),
            "distracted": full_distractor,
            "original_answer": item["long_answer"]
        }
        output_data.append(data_item)
        
        print(f"item: {idx+1}", end="\r")
        
    #write training data to an output file
    with open(output_filename, 'w', encoding='utf-8') as f:
        json.dump(output_data, f, ensure_ascii=False, indent=4)

In [14]:
process_dataset(source_dataset["train"].shuffle(),"./data/base_data/base_data.json", max_items=40000)

max_items set, reducing input to 40000 items.
item: 40000

Then load the processed file into a dataset object and inspect one of the elements. 

The dataset has five properties:
- `question` - The user query for this entry.
- `oracle` - The oracle context for the question: all of the context needed to generate the correct answer. It is used to generate synthetic data in a later step and can also be used to measure the factual accuracy of generated responses.
- `context` - The combined context elements. It consists either entirely of distractor documents or of a mix of oracle and distractor documents.
- `distracted` - A boolean flag indicating whether the context consists entirely of distractor documents.
- `original_answer` - The source PubMedQA answer. Because these answers are typically short, you will generate longer versions from the oracle context.

In [15]:
dataset = load_dataset("json", data_files="./data/base_data/base_data.json", split="train")

dataset[0]

Generating train split: 0 examples [00:00, ? examples/s]

{'oracle': 'The aim of the study was to search for serologic, immunopathologic, and morphologic evidence of antibody-mediated rejection (AMR) among patients with acute renal allograft dysfunction. The study included 19 patients with episodes of acute rejection (ARE) within the first year after transplantation. All patients had negative crossmatch tests before transplantation. Patients underwent biopsy for histologic and C4d examinations. All patients were monitored for donor-specific HLA alloantibodies during the first posttransplant year. Complement-dependent cytotoxic crossmatches were performed with donor lymphocytes. In eight patients, the crossmatch test results changed to positive during ARE. In all biopsies except one with cortical infarction, we observed C4d staining (group 1). The biopsies of four patients showed histologic changes of AMR, and all of their grafts were lost. In one patient, cellular and vascular rejection (Banff II) were present; in two, Banff I; and in one, bo
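Before moving on, it can be worth confirming that the share of distracted examples is close to what you expect. The sketch below is one way to do that; the helper name `distracted_share` and the sample flags are hypothetical.

```python
from collections import Counter

def distracted_share(flags):
    """Return the fraction of records whose `distracted` flag is true."""
    counts = Counter(bool(flag) for flag in flags)
    total = sum(counts.values())
    return counts[True] / total if total else 0.0

# Hypothetical flags standing in for the dataset's `distracted` column
sample_flags = [False, True, False, False, True, False, False, False, True, False]
print(f"{distracted_share(sample_flags):.0%} distracted")
```

With the dataset loaded above, you could call `distracted_share(dataset["distracted"])`. With the default `p=0.7`, roughly 30% of the records should be distracted.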

The `create_oracle_rag_prompts` function takes in an element from the base dataset and generates a prompt consisting of only oracle context so you can generate training data with longer answers.

In [16]:
# Build an oracle-only question-answering prompt for each record
def create_oracle_rag_prompts(data_point):
    full_prompt = f"""
        <|begin_of_text|>
        <|start_header_id|>system<|end_header_id|>
        You are an assistant for question-answering tasks. Answer the following question in 5 sentences using the provided context. If you don't know the answer, just say "I don't know.".
        <|start_header_id|>user<|end_header_id|>
        Context: {data_point["oracle"]}
        
        Question: {data_point["question"]}
        <|start_header_id|>assistant<|end_header_id|>
        Answer:"""
    return {"prompt": full_prompt}
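The special tokens in this template match the Llama 3 instruct chat format. If your endpoint hosts a Llama 3 instruct model (an assumption; the notebook does not name the model), you may want to experiment with a variant that closes each turn with `<|eot_id|>` and avoids the leading indentation that the triple-quoted string introduces. The sketch below is one such variant; the names `SYSTEM_PROMPT` and `create_oracle_rag_prompt_compact` are hypothetical, and it drops the trailing `Answer:` cue, so generated answers may differ from those shown in this notebook.

```python
SYSTEM_PROMPT = (
    "You are an assistant for question-answering tasks. "
    "Answer the following question in 5 sentences using the provided context. "
    "If you don't know the answer, just say \"I don't know.\"."
)

def create_oracle_rag_prompt_compact(data_point):
    """Build an oracle-only prompt without indentation, closing each turn explicitly."""
    user_turn = f"Context: {data_point['oracle']}\n\nQuestion: {data_point['question']}"
    full_prompt = (
        "<|begin_of_text|>"
        f"<|start_header_id|>system<|end_header_id|>\n\n{SYSTEM_PROMPT}<|eot_id|>"
        f"<|start_header_id|>user<|end_header_id|>\n\n{user_turn}<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )
    return {"prompt": full_prompt}

# Toy record, for illustration only
example = {"oracle": "Passage one.\n\nPassage two.", "question": "Is this a test?"}
print(create_oracle_rag_prompt_compact(example)["prompt"])
```

It returns the same `{"prompt": ...}` shape as `create_oracle_rag_prompts`, so it can be passed to `dataset.map` in the same way.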

Use `dataset.map` to run the `create_oracle_rag_prompts` function on all the rows of the dataset, creating a `prompt` feature for each element that you'll use for generation. After the mapping, dump to a `generation_data.json` file in the data directory.

In [17]:
dataset = dataset.map(
    create_oracle_rag_prompts,
    batched=False
)

dataset.to_json("./data/generation_data/generation_data.json", orient="records")

Map:   0%|          | 0/40000 [00:00<?, ? examples/s]

Creating json from Arrow format:   0%|          | 0/40 [00:00<?, ?ba/s]

248235961

If you have already generated the prompt data, you can load it directly from the filesystem.

In [18]:
generation_dataset = load_dataset("json", data_files="./data/generation_data/generation_data.json", split="train")
generation_dataset[0]

Generating train split: 0 examples [00:00, ? examples/s]

{'oracle': 'The aim of the study was to search for serologic, immunopathologic, and morphologic evidence of antibody-mediated rejection (AMR) among patients with acute renal allograft dysfunction. The study included 19 patients with episodes of acute rejection (ARE) within the first year after transplantation. All patients had negative crossmatch tests before transplantation. Patients underwent biopsy for histologic and C4d examinations. All patients were monitored for donor-specific HLA alloantibodies during the first posttransplant year. Complement-dependent cytotoxic crossmatches were performed with donor lymphocytes. In eight patients, the crossmatch test results changed to positive during ARE. In all biopsies except one with cortical infarction, we observed C4d staining (group 1). The biopsies of four patients showed histologic changes of AMR, and all of their grafts were lost. In one patient, cellular and vascular rejection (Banff II) were present; in two, Banff I; and in one, bo

## Generate Oracle Summary Data

In [23]:
from sagemaker.deserializers import JSONDeserializer
from sagemaker.serializers import JSONSerializer

generation_base_predictor = sagemaker.Predictor(
    endpoint_name="<<YOUR ENDPOINT HERE>>",
    sagemaker_session=sess,
    serializer=JSONSerializer(),
    deserializer=JSONDeserializer(),
)

Preview the prompt being used for generation.

In [24]:
test_item = generation_dataset[2]
test_item["prompt"]

'\n        <|begin_of_text|>\n        <|start_header_id|>system<|end_header_id|>\n        You are an assistant for question-answering tasks. Answer the following question in 5 sentences using the provided context. If you don\'t know the answer, just say "I don\'t know.".\n        <|start_header_id|>user<|end_header_id|>\n        Context: To determine the long-term coronary heart disease (CHD) mortality in women and men with symptoms, according to the Rose Angina Questionnaire at a relatively young age.\n\nCohort study with the baseline survey conducted during 1974-8. Information on symptoms was collected by a short, three-item version of the Rose Angina Questionnaire. Participants were re-invited to a similar survey five years later and followed for mortality throughout 2000.\n\nThree counties in Norway (the Norwegian Counties Study).\n\n16 616 men and 16 265 women aged 40-49 years and denying CHD in 1974-8.\n\nCHD mortality during 23 years.\n\nBy the end of follow-up 1316 men (7.9%) a

Next, set the generation parameters and insert your prompt. Then call the SageMaker endpoint through the predictor's `predict` method to check what a generated response looks like before running against the whole dataset.

In [25]:
def send_prompt(predictor, prompt, parameters):
    # Wrap the prompt and generation parameters in the request payload
    payload = {
        "inputs": prompt,
        "parameters": parameters
    }
    response = predictor.predict(payload)
    return response
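When you later run tens of thousands of requests, an occasional transient endpoint error is likely. The following is a minimal sketch of a retry wrapper with exponential backoff; the function name `send_prompt_with_retry`, its parameters, and the `FlakyPredictor` stand-in are assumptions for illustration. It catches every exception for brevity; in practice you would probably narrow that to the errors you consider retryable, such as `ClientError`.

```python
import time

def send_prompt_with_retry(predictor, prompt, parameters, max_attempts=3, base_delay=1.0):
    """Call predictor.predict, retrying failures with exponential backoff."""
    payload = {"inputs": prompt, "parameters": parameters}
    for attempt in range(1, max_attempts + 1):
        try:
            return predictor.predict(payload)
        except Exception:
            # Re-raise once the retry budget is exhausted
            if attempt == max_attempts:
                raise
            time.sleep(base_delay * 2 ** (attempt - 1))

class FlakyPredictor:
    """Stand-in predictor that fails on the first call, then succeeds."""
    def __init__(self):
        self.calls = 0

    def predict(self, payload):
        self.calls += 1
        if self.calls == 1:
            raise RuntimeError("simulated transient error")
        return {"generated_text": "ok"}

demo_predictor = FlakyPredictor()
print(send_prompt_with_retry(demo_predictor, "hello", {"max_new_tokens": 8}, base_delay=0.01))
```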

In [26]:
%%time
prompt = test_item["prompt"]

base_response = send_prompt(
    generation_base_predictor,
    prompt,
    parameters={
        "temperature": 0.9, 
        "max_new_tokens": 512,
        "top_p": 0.9
    }
)

base_response

CPU times: user 13.1 ms, sys: 0 ns, total: 13.1 ms
Wall time: 716 ms


{'generated_text': ' Yes, the study found that Rose angina predicted 23-year coronary heart disease mortality in both women and men aged 40-49 years. The risk of CHD mortality was significantly higher in individuals with Rose angina compared to those without. In men, the adjusted hazard ratio for CHD mortality was 1.50, while in women, it was 1.98. These findings suggest that Rose angina is a significant predictor of long-term CHD mortality in this age group. The risk associated with Rose angina was comparable to the risk associated with elevated cholesterol levels and blood pressure.'}

## Generate summaries based on oracle contexts for supervised fine tuning

This step iterates over the generation dataset and builds longer summaries than the standard PubMedQA answers, using only oracle contexts. Each summary is then joined with the existing data fields to build the training dataset.

In [None]:
from IPython.display import clear_output

with open("../data/synthetic_data/synthetic_training_data.json", "w") as output_file:
    output_json = []
    for idx, data in enumerate(generation_dataset):
    
        model_response = send_prompt(
            generation_base_predictor,
            data["prompt"],
            parameters={
                "temperature": 0.9, 
                "max_new_tokens": 512,
                "top_p": 0.9
            }
        )
        
        # Extract the generated text from the response.
        response_text = model_response["generated_text"]

        # Reuse the record, swapping the prompt for the synthetic answer.
        output_item = data
        del output_item["prompt"]
        output_item["synthetic_answer"] = response_text

        output_json.append(output_item)

        clear_output()
        print(f"{idx+1} of {len(generation_dataset)}\n\n{output_item}")
        
    output_file.write(json.dumps(output_json))
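Because the loop above writes the file only after every item has been generated, an interruption partway through loses all completed work. One possible alternative, sketched below, appends each result to a JSON Lines file and skips already-written rows when restarted. The helper names, the temporary demo path, and the `generate_fn` callback are hypothetical; in the notebook, `generate_fn` would wrap the call to `send_prompt`.

```python
import json
import os
import tempfile

def append_jsonl(path, record):
    """Append one record to a JSON Lines file."""
    with open(path, "a", encoding="utf-8") as f:
        f.write(json.dumps(record, ensure_ascii=False) + "\n")

def count_jsonl(path):
    """Count the records already written, or 0 if the file does not exist."""
    if not os.path.exists(path):
        return 0
    with open(path, encoding="utf-8") as f:
        return sum(1 for line in f if line.strip())

def generate_all(records, generate_fn, path):
    """Generate an answer per record, resuming after rows already on disk."""
    done = count_jsonl(path)
    for idx, record in enumerate(records):
        if idx < done:
            continue  # already generated in an earlier run
        item = {k: v for k, v in record.items() if k != "prompt"}
        item["synthetic_answer"] = generate_fn(record["prompt"])
        append_jsonl(path, item)
    return count_jsonl(path)

# Toy records and a fake generator, for illustration only
demo_records = [{"question": f"q{i}", "prompt": f"p{i}"} for i in range(3)]
demo_path = os.path.join(tempfile.mkdtemp(), "synthetic.jsonl")
print(generate_all(demo_records, lambda prompt: prompt.upper(), demo_path))
```

Resuming by row count assumes the records are iterated in the same order on every run. The resulting file can be read with `load_dataset("json", ...)`, which accepts JSON Lines input.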