## Build a Retrieval-Augmented Generation Solution with a Fine-Tuned Model
Now that we have a fine-tuned model, it's time to revisit our retrieval-augmented generation (RAG) solution. In this notebook, we will re-implement the initial pipeline, this time using the fine-tuned model to generate the responses. We will then evaluate the performance of the fine-tuned model and compare it with that of the initial model.

In [None]:
import sys
import os
# Put the directory two levels above the working directory on the import path
# so the `utils` package imported below can be resolved.
# NOTE(review): presumably this is the repository root — confirm.
module_path = "../.."
sys.path.append(os.path.abspath(module_path))
# Must come after the sys.path change above, otherwise `utils` is not importable.
from utils.environment_validation import validate_environment, validate_model_access
# Run the environment check provided by `utils` before continuing.
validate_environment()

In [None]:
# Model identifiers whose access is checked by `validate_model_access`
# before the rest of the notebook runs.
required_models = [
    "amazon.titan-embed-text-v2:0",  # same ID is used for the embedding model further down
    "mistral.mixtral-8x7b-instruct-v0:1",
    "mistral.mistral-7b-instruct-v0:2",
]
validate_model_access(required_models)

In [None]:
from pathlib import Path
from itertools import chain
from rich import print as rprint
import json
import os
from langchain_core.documents import Document
from langchain_aws.llms import SagemakerEndpoint
from langchain_aws.chat_models import ChatBedrockConverse
from langchain_aws.llms.sagemaker_endpoint import LLMContentHandler
from langchain_aws.embeddings import BedrockEmbeddings
from langchain_community.vectorstores import FAISS
import boto3
from concurrent.futures import ThreadPoolExecutor
from typing import Dict

import pickle
from io import BytesIO
from pathlib import Path

import asyncio
import nest_asyncio
# Patch asyncio to allow nested event loops: the evaluation cell further down
# calls `run_until_complete` on the current loop, which would otherwise fail
# when a loop is already running (as is the case inside Jupyter).
nest_asyncio.apply()

# Load the prepared train/test splits; every line of each JSONL file is one JSON record.
data_path = Path("data/prepared_data")
train_data = (data_path / "prepared_data_train.jsonl").read_text().splitlines()
test_data = (data_path / "prepared_data_test.jsonl").read_text().splitlines()

doc_ids = []
documents = []

# Create a list of LangChain documents that can then be used to ingest into a vector store.
# Several records can point at the same source section (`ref_doc_id`), so only the
# first occurrence of each id becomes a Document.
# Membership is tracked in a set: testing against the `doc_ids` list would rescan
# it for every record, which is quadratic in the number of sections.
# Assumes `ref_doc_id` is hashable (e.g. a string).
seen_doc_ids = set()

for record in chain(train_data, test_data):
    json_record = json.loads(record)
    ref_doc_id = json_record["ref_doc_id"]
    if ref_doc_id in seen_doc_ids:
        continue
    seen_doc_ids.add(ref_doc_id)
    doc_ids.append(ref_doc_id)  # kept as an ordered list, first-seen order
    documents.append(
        Document(page_content=json_record["context"], metadata=json_record["section_metadata"])
    )

print(f"Loaded {len(documents)} sections")

In [None]:
# A single boto3 session backs every AWS client used in this notebook.
boto3_session = boto3.session.Session()

# Clients are created in the same order as the service names listed here:
#   bedrock-runtime   -> bedrock_runtime (embeddings, guardrail application)
#   bedrock           -> bedrock_client  (listing guardrails)
#   sagemaker-runtime -> smr_client      (SageMaker endpoint invocation)
bedrock_runtime, bedrock_client, smr_client = (
    boto3_session.client(service_name)
    for service_name in ("bedrock-runtime", "bedrock", "sagemaker-runtime")
)

embedding_modelId = "amazon.titan-embed-text-v2:0"
embed_model = BedrockEmbeddings(
    client=bedrock_runtime,
    model_id=embedding_modelId,
    model_kwargs={"dimensions": 1024, "normalize": True},
)

# Smoke test: embed one query and show the vector size plus its first values.
query = "Do I really need to fine-tune the large language models?"
response = embed_model.embed_query(query)
rprint(
    f"Generated an embedding with {len(response)} dimensions. Sample of first 10 dimensions:\n",
    response[:10],
)

In [None]:
# We can reuse the vector DB from the initial model since we're keeping embeddings the same.
vector_store_file = "baseline_rag_vec_db.pkl"

if not Path(vector_store_file).exists():
    rprint(f"Vector store file {vector_store_file} does not exist. Will create a new vector store.")
    CREATE_NEW = True
else:
    rprint(f"Vector store file {vector_store_file} already exists. Delete it or change the file name above to create a new vector store.")
    CREATE_NEW = False

if CREATE_NEW:
    # Embed every document and persist the serialized index for later runs.
    vec_db = FAISS.from_documents(documents, embed_model)
    # Context manager: guarantees the file is flushed and closed, even on error.
    with open(vector_store_file, "wb") as vector_store_fh:
        pickle.dump(vec_db.serialize_to_bytes(), vector_store_fh)
else:
    # CREATE_NEW is False only when the file was found just above, so no second
    # existence check is needed; `open` raises FileNotFoundError if it vanished.
    #
    # SECURITY: both `pickle.load` and `allow_dangerous_deserialization=True`
    # can execute arbitrary code embedded in the file. Only load a vector
    # store file that this notebook (or another trusted source) produced.
    with open(vector_store_file, "rb") as vector_store_fh:
        serialized_vec_db = pickle.load(vector_store_fh)
    vec_db = FAISS.deserialize_from_bytes(
        serialized=serialized_vec_db,
        embeddings=embed_model,
        allow_dangerous_deserialization=True,
    )

In [None]:
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever

# Number of documents each individual retriever returns per query.
k = 3

# Keyword (BM25) retriever built directly from `documents`.
bm_25 = BM25Retriever.from_documents(documents)
bm_25.k = k

# Dense retriever backed by the FAISS vector store.
faiss_retriever = vec_db.as_retriever(search_kwargs=dict(k=k))

# Hybrid retriever; weights pair up positionally with `retrievers`
# (0.75 for FAISS, 0.25 for BM25).
ensemble_retriever = EnsembleRetriever(
    weights=[0.75, 0.25],
    retrievers=[faiss_retriever, bm_25],
)

### Using SageMaker endpoints with LangChain
To use a SageMaker endpoint with LangChain, we need to implement a `ContentHandler` class that handles the preprocessing of the input data and the postprocessing of the output data. This is necessary because, unlike Bedrock, SageMaker endpoints may have different and inconsistent input and output payload formats.

In [None]:
# Resolve the SageMaker endpoint name from the JSON config file that is
# expected in the working directory (key: "endpoint_name").
sagemaker_endpoint_config_path = Path("endpoint_config.json")
if not sagemaker_endpoint_config_path.exists():
    rprint(f"[bold red]Endpoint config file {sagemaker_endpoint_config_path} not found. Please make sure that you ran the prior training and deployment notebooks[/bold red]")
    # Fail fast: without this file `endpoint_name` would stay undefined and the
    # cells that build the LLM and delete the endpoint would fail later with a
    # much less helpful NameError.
    raise FileNotFoundError(f"Endpoint config file {sagemaker_endpoint_config_path} not found")

endpoint_name = json.loads(sagemaker_endpoint_config_path.read_text())["endpoint_name"]

In [None]:
class ContentHandler(LLMContentHandler):
    """Translate between LangChain prompts and the endpoint's JSON payloads.

    Requests are sent as, and responses requested as, ``application/json``.
    """

    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        """Serialize the prompt and generation parameters into the request body.

        Returns UTF-8 encoded JSON of the form
        ``{"inputs": <prompt>, "parameters": <model_kwargs>}``.
        """
        payload = {"inputs": prompt, "parameters": model_kwargs}
        return json.dumps(payload).encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        """Pull the generated text out of the endpoint response.

        ``output`` is consumed through ``.read()``, so despite the annotation
        it must be a readable, file-like object yielding UTF-8 encoded JSON
        with a top-level ``generated_text`` key.
        """
        response_body = json.loads(output.read().decode("utf-8"))
        return response_body["generated_text"]

# Handler that converts prompts to / responses from the endpoint's JSON format.
content_handler = ContentHandler()

# LangChain LLM wrapper around the SageMaker endpoint named in `endpoint_name`,
# invoked through the sagemaker-runtime client created earlier.
tuned_llm = SagemakerEndpoint(
        endpoint_name=endpoint_name,
        client=smr_client,
        # Sent to the endpoint as the "parameters" object by ContentHandler.transform_input.
        # NOTE(review): temperature 0 is presumably chosen for reproducible answers — confirm.
        model_kwargs={"temperature": 0, "max_new_tokens": 500},
        content_handler=content_handler,
    )

With the configured LLM, we can re-implement the retrieval-augmented generation pipeline using the fine-tuned model.

In [None]:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from operator import itemgetter

# Instruction-style prompt wrapped in [INST] ... [/INST] tags; `{context}` and
# `{question}` are filled in by the chain at run time.
tuned_prompt_template = (
    "[INST] You are a Banking Regulations expert.\n"
    "Given this context\n"
    "CONTEXT\n"
    "{context}\n"
    " Answer this question\n"
    "Question: {question} [/INST]"
)

prompt = PromptTemplate.from_template(tuned_prompt_template)
output_parser = StrOutputParser()

# Step 1: from the incoming question, produce the retrieved documents
# ("context") alongside the untouched question ("question").
setup_and_retrieval = RunnableParallel(
    {"context": ensemble_retriever, "question": RunnablePassthrough()}
)

# Step 2: generate the answer and pass the retrieved documents through, so the
# caller receives both the answer and the context it was based on.
answer_pipeline = prompt | tuned_llm | output_parser
generate_tuned_answer = {
    "answer": answer_pipeline,
    "context": itemgetter("context"),
}

tuned_chain = setup_and_retrieval | generate_tuned_answer

In [None]:
# Smoke test: run one record from the test split through the chain and show
# the generated answer next to the reference answer.
sample_record = json.loads(test_data[150])
sample_question, sample_answer = sample_record["question"], sample_record["answer"]
rprint(f"Sample question: {sample_question}")

response = tuned_chain.invoke(sample_question)
generated_answer = response["answer"]

rprint(f"\nGenerated answer: {generated_answer}")
rprint(f"\nGround truth answer: {sample_answer}")

In [None]:
# helper functions to speed up the inference and evaluation process

async def generate_answer_async(rag_chain, example):
    """Answer one JSON-encoded example with ``rag_chain``.

    ``example`` is a JSON string with ``question`` and ``answer`` keys. The
    question is sent to ``rag_chain.ainvoke``, whose result must provide an
    ``answer`` entry and a ``context`` entry (an iterable of objects with a
    ``page_content`` attribute).

    Returns a dict with the keys ``question``, ``answer`` (generated),
    ``contexts`` (list of page contents) and ``ground_truth`` (reference answer).
    """
    record = json.loads(example)
    chain_output = await rag_chain.ainvoke(record["question"])
    retrieved_texts = [document.page_content for document in chain_output["context"]]
    return {
        "question": record["question"],
        "answer": chain_output["answer"],
        "contexts": retrieved_texts,
        "ground_truth": record["answer"],
    }

async def evaluate_llm_async(metric, rows, max_concurrency=10):
    """Score every row with ``metric`` concurrently, with bounded parallelism.

    Args:
        metric: Object exposing an awaitable ``ainvoke(row)`` method.
        rows: Iterable of inputs; each one is passed unchanged to
            ``metric.ainvoke``.
        max_concurrency: Maximum number of ``ainvoke`` calls in flight at the
            same time. Defaults to 10, the limit that used to be hard-coded.

    Returns:
        A list with one result per row, in the same order as ``rows``.

    Raises:
        ValueError: If ``max_concurrency`` is smaller than 1. A limit of 0
            would make every call wait on the semaphore forever.
    """
    if max_concurrency < 1:
        raise ValueError(f"max_concurrency must be at least 1, got {max_concurrency}")

    sem = asyncio.Semaphore(max_concurrency)

    async def limited_invoke(row):
        # The semaphore caps how many metric calls run at once.
        async with sem:
            return await metric.ainvoke(row)

    # gather schedules every coroutine and returns the results in input order.
    return await asyncio.gather(*(limited_invoke(row) for row in rows))

In [None]:
# Only the first NUM_SAMPLE_LLM_EVALUATION test examples are evaluated.
NUM_SAMPLE_LLM_EVALUATION = 100

# Create (but do not yet run) one coroutine per sampled test example.
eval_rows = [
    generate_answer_async(tuned_chain, example)
    for example in test_data[:NUM_SAMPLE_LLM_EVALUATION]
]

# Drive all coroutines to completion on the current event loop; `gather`
# returns the rows in the same order as the examples.
event_loop = asyncio.get_event_loop()
eval_data = event_loop.run_until_complete(asyncio.gather(*eval_rows))

In [None]:
# The guardrail should be created as part of the workshop.
# If not, you can create a "rag_eval" guardrail in the console with only the
# contextual grounding check enabled.
# NOTE(review): only the first page returned by `list_guardrails` is inspected;
# confirm the account cannot have more guardrails than fit on one page.
eval_guardrail = [gr for gr in bedrock_client.list_guardrails()["guardrails"] if gr["name"] == "rag_eval"]
if not eval_guardrail:
    rprint("No RAG evaluation guardrail found. Please create one in the Bedrock console.")
    # Stop here with a clear error. Previously execution fell through and
    # indexed the empty list with "id", which raised an unrelated TypeError.
    raise LookupError('No Bedrock guardrail named "rag_eval" was found.')

# Several guardrails could share the name; the first match is used.
eval_guardrail = eval_guardrail[0]
eval_guardrail_id = eval_guardrail["id"]

In [None]:
def _first_filter_score(filters, filter_type):
    """Return the score of the first entry in ``filters`` whose type is ``filter_type``."""
    for entry in filters:
        if entry["type"] == filter_type:
            return entry["score"]
    # Same exception type as indexing an empty match list, but with a message.
    raise IndexError(f"No {filter_type} filter in the contextual grounding assessment")


def invoke_rag_eval_guardrail(
    guardrail_id, question, context, response, *, guardrail_version="1", client=None
):
    """Score one RAG response with a contextual grounding guardrail.

    Args:
        guardrail_id: Identifier of the guardrail to apply.
        question: The user query (sent with the ``query`` qualifier).
        context: The retrieved text the answer should be grounded in (sent
            with the ``grounding_source`` qualifier).
        response: The generated answer to assess.
        guardrail_version: Guardrail version to apply. Defaults to ``"1"``,
            the value that used to be hard-coded.
        client: Object exposing ``apply_guardrail``; defaults to the
            module-level ``bedrock_runtime`` client.

    Returns:
        ``{"grounding_score": ..., "relevance_score": ...}`` taken from the
        ``GROUNDING`` and ``RELEVANCE`` filters of the first assessment.

    Raises:
        IndexError: If the assessment has no filter of one of those types.
    """
    guardrail_payload = [
        {
            "text": {
                "text": context,
                "qualifiers": ["grounding_source"],
            }
        },
        {"text": {"text": question, "qualifiers": ["query"]}},
        {"text": {"text": response}},
    ]

    runtime_client = bedrock_runtime if client is None else client
    # Stored under its own name so the `response` argument is not overwritten.
    guardrail_result = runtime_client.apply_guardrail(
        guardrailIdentifier=guardrail_id,
        guardrailVersion=guardrail_version,
        source="OUTPUT",
        content=guardrail_payload,
    )
    filters = guardrail_result["assessments"][0]["contextualGroundingPolicy"]["filters"]
    return {
        "grounding_score": _first_filter_score(filters, "GROUNDING"),
        "relevance_score": _first_filter_score(filters, "RELEVANCE"),
    }

In [None]:
def evaluate_rag_guardrail(guardrail_id, questions, contexts, answers, max_workers=4):
    """Average the guardrail grounding and relevance scores over a set of examples.

    Each example is scored by ``invoke_rag_eval_guardrail`` on a thread pool.

    Args:
        guardrail_id: Identifier of the guardrail to apply.
        questions: Iterable of user queries.
        contexts: Iterable of string sequences; each sequence is joined with
            newlines into a single grounding source.
        answers: Iterable of generated answers.
        max_workers: Size of the thread pool. Defaults to 4, the value that
            used to be hard-coded.

    ``questions``, ``contexts`` and ``answers`` are consumed in lockstep; if
    their lengths differ, the extra items of the longer ones are ignored.

    Returns:
        Tuple ``(mean grounding score, mean relevance score)``.

    Raises:
        ValueError: If there is nothing to evaluate (previously this surfaced
            as a ZeroDivisionError from the averaging step).
    """
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(
                invoke_rag_eval_guardrail, guardrail_id, question, "\n".join(context), answer
            )
            for question, context, answer in zip(questions, contexts, answers)
        ]

    # Leaving the `with` block waits for all submitted calls; `result()`
    # re-raises the first failure, in submission order.
    eval_rows = [future.result() for future in futures]
    if not eval_rows:
        raise ValueError("evaluate_rag_guardrail needs at least one example to evaluate")

    grounding_scores = [row["grounding_score"] for row in eval_rows]
    relevance_scores = [row["relevance_score"] for row in eval_rows]
    grounding_score = sum(grounding_scores) / len(grounding_scores)
    relevance_score = sum(relevance_scores) / len(relevance_scores)
    return grounding_score, relevance_score

In [None]:
rprint("Evaluating the fine-tuned RAG responses using the RAG evaluation guardrail")

# Split the generated rows into the parallel lists the evaluator expects.
eval_questions = [row["question"] for row in eval_data]
eval_contexts = [row["contexts"] for row in eval_data]
eval_answers = [row["answer"] for row in eval_data]

ft_grounding_score, ft_relevance_score = evaluate_rag_guardrail(
    eval_guardrail_id, eval_questions, eval_contexts, eval_answers
)

rprint(f"Fine-tuned grounding score: {ft_grounding_score}\n")
rprint(f"Fine-tuned relevance score: {ft_relevance_score}")

In [None]:
# Scores of the fine-tuned pipeline, using the same key names as the stored
# baseline evaluation file read below.
tuned_evals = {"grounding_score": ft_grounding_score, "relevancy": ft_relevance_score}

# Scores saved to base_evaluation.json, with one entry for the "ground_truth"
# answers and one for the "baseline" pipeline.
original_evals = json.loads(Path("base_evaluation.json").read_text())

ground_truth_evals = original_evals["ground_truth"]
ground_truth_grounding_score = ground_truth_evals["grounding_score"]
ground_truth_relevance_score = ground_truth_evals["relevancy"]

baseline_evals = original_evals["baseline"]
baseline_grounding_score = baseline_evals["grounding_score"]
baseline_relevance_score = baseline_evals["relevancy"]

In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

def plot_evals(grounding_scores, relevance_scores, names=None):
    """Draw a grouped bar chart of grounding and relevance scores per model.

    Args:
        grounding_scores: Grounding score for each model, in ``names`` order.
        relevance_scores: Relevance score for each model, in ``names`` order.
        names: Model labels. Defaults to
            ``["Fine-tuned", "Baseline", "Ground Truth"]``. ``None`` is the
            sentinel so that no mutable list is shared between calls.

    The chart is shown with ``plt.show()``; nothing is returned.
    """
    if names is None:
        names = ["Fine-tuned", "Baseline", "Ground Truth"]
    names = list(names)

    evals = pd.DataFrame(
        {
            "Grounding Score": grounding_scores,
            "Relevance Score": relevance_scores,
            "Model": names,
        }
    )
    # Long format: one row per (model, metric) pair, as seaborn expects for `hue`.
    evals = evals.melt(id_vars="Model", var_name="Metric", value_name="Score")
    ax = sns.barplot(x="Metric", y="Score", hue="Model", data=evals)
    plt.title("Grounding and Relevance Scores")

    # Add value labels just above each bar.
    for p in ax.patches:
        ax.annotate(format(p.get_height(), '.2f'),
                   (p.get_x() + p.get_width() / 2., p.get_height()),
                   ha = 'center', va = 'center',
                   xytext = (0, 9),
                   textcoords = 'offset points')

    # Move legend to the bottom, one column per model so it stays on one row.
    ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.05), ncol=len(names))

    plt.show()

# Score order must match plot_evals' default labels:
# Fine-tuned, Baseline, Ground Truth.
plot_evals(
    [ft_grounding_score, baseline_grounding_score, ground_truth_grounding_score],
    [ft_relevance_score, baseline_relevance_score, ground_truth_relevance_score],
)


### Clean Up

In [None]:
import sagemaker
from sagemaker.predictor import Predictor
# Clean up: delete the SageMaker endpoint named in `endpoint_name`.
sagemaker_session = sagemaker.Session()
pred = Predictor(endpoint_name=endpoint_name, sagemaker_session=sagemaker_session)
# NOTE(review): by default `delete_endpoint` presumably also removes the
# endpoint configuration but not the model resource — confirm against the
# SageMaker SDK docs if the model should be removed as well.
pred.delete_endpoint()