## Build a Retrieval-Augmented Generation Solution with a Fine-Tuned Model
Now that we have a fine-tuned model, it's time to revisit our retrieval-augmented generation (RAG) solution. In this notebook, we will re-implement the initial pipeline, but this time using the fine-tuned model to generate the responses. We will then evaluate the performance of the fine-tuned model and compare it with the initial model.

In [None]:
import sys
import os
# Add the directory two levels above this notebook to the import path so the
# `utils` package imported below can be resolved.
module_path = "../.."
sys.path.append(os.path.abspath(module_path))
from utils.environment_validation import validate_environment, validate_model_access
# Run the project's environment check (defined in utils.environment_validation)
# before any other cell does real work.
validate_environment()

In [None]:
# Bedrock model IDs handed to validate_model_access below.
# In the cells visible here, the Titan model is used for embeddings and
# Claude 3 Haiku as the evaluator LLM; the two Mistral IDs are not referenced
# again in the visible cells.
required_models = [
    "amazon.titan-embed-text-v2:0",
    "mistral.mixtral-8x7b-instruct-v0:1",
    "mistral.mistral-7b-instruct-v0:2",
    "anthropic.claude-3-haiku-20240307-v1:0"
]
validate_model_access(required_models)

In [None]:
from pathlib import Path
from itertools import chain
from rich import print as rprint
import json
import os
from langchain_core.documents import Document
from langchain_aws.llms import SagemakerEndpoint
from langchain_aws.chat_models import BedrockChat
from langchain_aws.llms.sagemaker_endpoint import LLMContentHandler
from langchain_aws.embeddings import BedrockEmbeddings
from langchain_community.vectorstores import FAISS
import boto3
from typing import Dict

import pickle
from io import BytesIO
from pathlib import Path

import asyncio
import nest_asyncio
nest_asyncio.apply()

data_path = Path("data/prepared_data")
# Each file is JSONL: one JSON record per line, kept here as raw strings.
train_data = (data_path / "prepared_data_train.jsonl").read_text().splitlines()
test_data = (data_path / "prepared_data_test.jsonl").read_text().splitlines()

doc_ids = []
documents = []

# Create a list of LangChain documents that can then be used to ingest into a vector store.
# Several records can point at the same source section, so each `ref_doc_id`
# is ingested only once. A set gives O(1) membership checks; `doc_ids` is still
# kept as a list in first-seen order.
_seen_doc_ids = set()

for record in chain(train_data, test_data):
    json_record = json.loads(record)
    ref_doc_id = json_record["ref_doc_id"]
    if ref_doc_id in _seen_doc_ids:
        continue
    _seen_doc_ids.add(ref_doc_id)
    doc_ids.append(ref_doc_id)
    doc = Document(page_content=json_record["context"], metadata=json_record["section_metadata"])
    documents.append(doc)

print(f"Loaded {len(documents)} sections")

In [None]:
# One boto3 session shared by both runtime clients.
boto3_session=boto3.session.Session()

# Bedrock runtime serves the embedding model and the evaluator LLM;
# SageMaker runtime serves the fine-tuned model endpoint.
bedrock_runtime = boto3_session.client("bedrock-runtime")
smr_client = boto3_session.client("sagemaker-runtime")

embedding_modelId = "amazon.titan-embed-text-v2:0"

# Request 1024-dimensional, normalized embeddings from the Titan model.
embed_model = BedrockEmbeddings(
    model_id=embedding_modelId,
    model_kwargs={"dimensions": 1024, "normalize": True},
    client=bedrock_runtime,
)

# Smoke test: embed one query and show the vector length and first values.
query = "Do I really need to fine-tune the large language models?"
response = embed_model.embed_query(query)
rprint(f"Generated an embedding with {len(response)} dimensions. Sample of first 10 dimensions:\n", response[:10])

In [None]:
# We can reuse the vector db from the initial model since we're keeping embeddings the same.
vector_store_file = "baseline_rag_vec_db.pkl"

# Build a new index only when no serialized one is present on disk.
CREATE_NEW = not Path(vector_store_file).exists()

if CREATE_NEW:
    rprint(f"Vector store file {vector_store_file} does not exist. Will create a new vector store.")
    vec_db = FAISS.from_documents(documents, embed_model)
    # Context manager guarantees the file is flushed and closed.
    with open(vector_store_file, "wb") as vector_store_fh:
        pickle.dump(vec_db.serialize_to_bytes(), vector_store_fh)
else:
    rprint(f"Vector store file {vector_store_file} already exists. Delete it or change the file name above to create a new vector store.")
    # SECURITY NOTE: both pickle.load and allow_dangerous_deserialization=True
    # can execute arbitrary code from the file. Only load a vector store file
    # that was produced by this project on a trusted machine.
    with open(vector_store_file, "rb") as vector_store_fh:
        serialized_index = pickle.load(vector_store_fh)
    vec_db = FAISS.deserialize_from_bytes(serialized=serialized_index, embeddings=embed_model, allow_dangerous_deserialization=True)

In [None]:
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever

# Number of documents each individual retriever returns.
k = 3
# Dense retriever backed by the FAISS vector store.
faiss_retriever = vec_db.as_retriever(search_kwargs={"k": k})

# Keyword (BM25) retriever built over the same documents.
bm_25 = BM25Retriever.from_documents(documents)
bm_25.k = k


# Combine both retrievers, weighting FAISS at 0.75 and BM25 at 0.25.
ensemble_retriever = EnsembleRetriever(
    retrievers=[faiss_retriever, bm_25], weights=[0.75, 0.25]
)

### Using SageMaker endpoints with LangChain
To use a SageMaker endpoint with LangChain, we need to implement a `ContentHandler` class that handles the preprocessing of the input data and the postprocessing of the output data. This is necessary because, unlike Bedrock, SageMaker endpoints may each use different and inconsistent input and output payload formats.

In [None]:
# Read the name of the deployed endpoint from the local config file.
sagemaker_endpoint_config_path = Path("endpoint_config.json")
if not sagemaker_endpoint_config_path.exists():
    # NOTE: `endpoint_name` is left undefined on this path, so any later cell
    # that uses it will fail until the config file exists.
    rprint(f"[bold red]Endpoint config file {sagemaker_endpoint_config_path} not found. Please make sure that you ran the prior training and deployment notebooks[/bold red]")
else:
    endpoint_config = json.loads(sagemaker_endpoint_config_path.read_text())
    endpoint_name = endpoint_config["endpoint_name"]

In [None]:
class ContentHandler(LLMContentHandler):
    """Translate between LangChain prompts and the endpoint's JSON payloads."""

    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        """Build the UTF-8 encoded JSON request body.

        The prompt is sent under "inputs" and the generation settings under
        "parameters".
        """
        payload = {"inputs": prompt, "parameters": model_kwargs}
        return json.dumps(payload).encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        """Return the "generated_text" field of the endpoint's JSON response.

        `output` is consumed with `.read()`, so it must be a readable
        stream-like object despite the `bytes` annotation.
        """
        response_body = json.loads(output.read().decode("utf-8"))
        return response_body["generated_text"]

content_handler = ContentHandler()

# LangChain LLM wrapper around the fine-tuned model's SageMaker endpoint.
# temperature 0 and a 500 new-token cap are forwarded to the endpoint as the
# "parameters" of each request (see ContentHandler.transform_input).
tuned_llm = SagemakerEndpoint(
        endpoint_name=endpoint_name,
        client=smr_client,
        model_kwargs={"temperature": 0, "max_new_tokens": 500},
        content_handler=content_handler,
    )

With the LLM configured, we can re-implement the retrieval-augmented generation pipeline using the fine-tuned model.

In [None]:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from operator import itemgetter

# Prompt wrapped in [INST] ... [/INST] markers, with slots for the retrieved
# context and the user question.
tuned_prompt_template = "[INST] You are a Banking Regulations expert.\nGiven this context\nCONTEXT\n{context}\n Answer this question\nQuestion: {question} [/INST]"


prompt = PromptTemplate.from_template(tuned_prompt_template)
output_parser = StrOutputParser()

# Step 1: from the incoming question, fetch context with the ensemble retriever
# and pass the question through unchanged.
# NOTE(review): `context` reaches the prompt as the retriever's list of
# Document objects, so `{context}` is presumably filled with that list's string
# form rather than plain passage text - confirm this matches the prompt format
# used during fine-tuning.
setup_and_retrieval = RunnableParallel(
    {"context": ensemble_retriever, "question": RunnablePassthrough()}
)

# Step 2: generate the answer, and also return the retrieved context so it can
# be used later for evaluation.
generate_tuned_answer = {"answer": prompt | tuned_llm | output_parser,
                   "context": itemgetter("context")}

tuned_chain = setup_and_retrieval | generate_tuned_answer

In [None]:
# Get a sample response from the fine-tuned chain and show it next to the
# ground-truth answer.

# Record 150 of the test split; this raises IndexError if the test file has
# fewer than 151 records.
sample_record = json.loads(test_data[150])
sample_question = sample_record["question"]
sample_answer = sample_record["answer"]
rprint(f"Sample question: {sample_question}")
response = tuned_chain.invoke(sample_question)
generated_answer = response["answer"]
rprint(f"\nGenerated answer: {generated_answer}")
rprint(f"\nGround truth answer: {sample_answer}")

In [None]:
# helper functions to speed up the inference and evaluation process

async def generate_answer_async(rag_chain, example):
    """Run one test example through `rag_chain` and build an evaluation row.

    Args:
        rag_chain: Object with an async `ainvoke(question)` method that returns
            a mapping with "answer" and "context" (items exposing
            `page_content`).
        example: JSON string containing "question" and "answer" fields.

    Returns:
        Dict with "question", "answer" (generated), "contexts" (list of the
        retrieved items' page content) and "ground_truth" (reference answer).
    """
    record = json.loads(example)
    question = record["question"]
    result = await rag_chain.ainvoke(question)
    retrieved_texts = [doc.page_content for doc in result["context"]]
    return {
        "question": question,
        "answer": result["answer"],
        "contexts": retrieved_texts,
        "ground_truth": record["answer"],
    }

async def evaluate_llm_async(metric, rows):
    """Score every row with `metric` concurrently.

    Args:
        metric: Object with an async `acall(row)` method.
        rows: Iterable of evaluation rows to score.

    Returns:
        List of `metric.acall` results, in the same order as `rows`.
    """
    pending = (metric.acall(row) for row in rows)
    return await asyncio.gather(*pending)

In [None]:
# Generate answers for the first NUM_SAMPLE_LLM_EVALUATION test examples.
# The coroutines are created up front and then awaited together so the
# endpoint calls overlap instead of running one after another.
NUM_SAMPLE_LLM_EVALUATION = 100
eval_rows = [
    generate_answer_async(tuned_chain, example)
    for example in test_data[:NUM_SAMPLE_LLM_EVALUATION]
]
event_loop = asyncio.get_event_loop()
eval_data = event_loop.run_until_complete(asyncio.gather(*eval_rows))

In [None]:
from ragas.metrics import faithfulness, answer_similarity, answer_relevancy, answer_correctness
from ragas.integrations.langchain import EvaluatorChain
import math

# Placeholder value, not a real key.
# NOTE(review): presumably set because ragas expects this variable to exist
# even though the evaluator used here is a Bedrock model - confirm.
os.environ["OPENAI_API_KEY"] = "12345"
# Evaluator ("judge") LLM for the ragas metrics: Claude 3 Haiku on Bedrock,
# with temperature 0.
eval_llm = BedrockChat(
    model_id="anthropic.claude-3-haiku-20240307-v1:0",
    model_kwargs={
        "temperature": 0
    },
    client=bedrock_runtime,
)

In [None]:
# One ragas evaluator chain per metric; all four share the same evaluator LLM
# and the same embedding model.
faithfulness_metric = EvaluatorChain(metric=faithfulness, llm=eval_llm, embeddings=embed_model)
answer_relevancy_metric = EvaluatorChain(metric=answer_relevancy, llm=eval_llm, embeddings=embed_model)
answer_similarity_metric = EvaluatorChain(metric=answer_similarity, llm=eval_llm, embeddings=embed_model)
answer_correctness_metric = EvaluatorChain(metric=answer_correctness, llm=eval_llm, embeddings=embed_model)

[**Faithfulness:**](https://docs.ragas.io/en/stable/concepts/metrics/faithfulness.html) measures the extent to which the claims in the generated answer are supported by the context. It is calculated as the ratio of the number of claims in the generated answer that are supported by the context to the total number of claims in the generated answer. In other words, it helps us detect hallucinations, as we would expect all claims in the generated answer to be supported by the context.
It does not reflect on the accuracy or correctness of the claims, only that they are supported by the context.

**NOTE:** If you see a message `Failed to parse output. Returning None.` during the evaluation, it simply means that ragas was unable to parse the output from the model. This can happen if the model generates an output that is not in the expected format. These samples will be ignored when calculating the aggregate metric.

In [None]:
# Score every generated answer for faithfulness.
faithfulness_evals = event_loop.run_until_complete(evaluate_llm_async(faithfulness_metric, eval_data))
# Keep only usable scores: samples the evaluator could not parse come back as
# NaN (or possibly None) and are excluded from the aggregate.
faithfulness_scores = [
    score
    for score in (result["faithfulness"] for result in faithfulness_evals)
    if score is not None and not math.isnan(score)
]
# Mean over the valid samples; NaN instead of ZeroDivisionError when none is valid.
faithfulness_score = sum(faithfulness_scores) / len(faithfulness_scores) if faithfulness_scores else math.nan

print("Faithfulness Score: ", faithfulness_score)

[**Answer Relevancy:**](https://docs.ragas.io/en/stable/concepts/metrics/answer_relevance.html) attempts to measure how pertinent the generated answer is to the given prompt. It works by having the evaluator LLM generate synthetic questions based on the generated answer and then calculating the average semantic similarity between the given question and the synthetic questions. The idea is that a more complete and pertinent answer should yield synthetic questions that are more similar to the given question. 

In [None]:
# Score every generated answer for relevancy to its question.
relevancy_evals = event_loop.run_until_complete(evaluate_llm_async(answer_relevancy_metric, eval_data))
# Keep only usable scores: samples the evaluator could not parse come back as
# NaN (or possibly None) and are excluded from the aggregate.
relevancy_scores = [
    score
    for score in (result["answer_relevancy"] for result in relevancy_evals)
    if score is not None and not math.isnan(score)
]
# Mean over the valid samples; NaN instead of ZeroDivisionError when none is valid.
relevancy_score = sum(relevancy_scores) / len(relevancy_scores) if relevancy_scores else math.nan

print("Answer Relevancy Score: ", relevancy_score)

[**Answer semantic similarity:**](https://docs.ragas.io/en/stable/concepts/metrics/semantic_similarity.html) measures the cosine similarity between the ground truth answer and the generated answer.

In [None]:
# Score every generated answer for semantic similarity to the ground truth.
# The full evaluation sample is used (previously only the first 10 rows), so
# this aggregate covers the same samples as the other three metrics.
answer_similarity_evals = event_loop.run_until_complete(evaluate_llm_async(answer_similarity_metric, eval_data))
# Keep only usable scores: NaN (or possibly None) results are excluded.
similarity_scores = [
    score
    for score in (result["answer_similarity"] for result in answer_similarity_evals)
    if score is not None and not math.isnan(score)
]
# Mean over the valid samples; NaN instead of ZeroDivisionError when none is valid.
similarity_score = sum(similarity_scores) / len(similarity_scores) if similarity_scores else math.nan

print("Answer Similarity Score: ", similarity_score)

[**Answer Correctness**](https://docs.ragas.io/en/stable/concepts/metrics/answer_correctness.html): Combines factual similarity assessed by the evaluator LLM with the semantic similarity between the generated answer and the ground truth. It is calculated as a weighted average of the factual similarity and the semantic similarity. Factual similarity is calculated similarly to Faithfulness, but also considers overlapping claims between the generated answer and the ground truth.

In [None]:
# Score every generated answer for correctness against the ground truth.
answer_correctness_evals = event_loop.run_until_complete(evaluate_llm_async(answer_correctness_metric, eval_data))
# Keep only usable scores: samples the evaluator could not parse come back as
# NaN (or possibly None) and are excluded from the aggregate.
correctness_scores = [
    score
    for score in (result["answer_correctness"] for result in answer_correctness_evals)
    if score is not None and not math.isnan(score)
]
# Mean over the valid samples; NaN instead of ZeroDivisionError when none is valid.
correctness_score = sum(correctness_scores) / len(correctness_scores) if correctness_scores else math.nan

print("Answer Correctness Score: ", correctness_score)

In [None]:
# Aggregate scores of the fine-tuned pipeline, keyed by metric name.
tuned_evals = {
    "faithfulness": faithfulness_score,
    "relevancy": relevancy_score,
    "similarity": similarity_score,
    "correctness": correctness_score,
}
# Scores to compare against, read from a local JSON file.
# NOTE(review): base_evaluation.json is presumably written by the baseline
# evaluation notebook with the same four keys - confirm.
original_evals = json.loads(Path("base_evaluation.json").read_text())
    

In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

def plot_evals(tuned_evals, original_evals):
    """Draw a grouped bar chart comparing tuned and original metric scores.

    Args:
        tuned_evals: Mapping of metric name to score for the fine-tuned model.
        original_evals: Mapping of metric name to score for the original model.

    Each score is labelled with the key it came from in its own mapping, so
    the two mappings may list their metrics in different orders without the
    bars being mislabelled.
    """
    tuned_labels = list(tuned_evals.keys())
    original_labels = list(original_evals.keys())
    tuned_scores = list(tuned_evals.values())
    original_scores = list(original_evals.values())

    # Long-format table: one row per (metric, model type) pair.
    data = {"Labels": tuned_labels + original_labels,
            "Scores": tuned_scores + original_scores,
            "Type": ["Tuned"]*len(tuned_scores) + ["Original"]*len(original_scores)}

    df = pd.DataFrame(data)

    plt.figure(figsize=(10, 6))
    barplot = sns.barplot(x="Labels", y="Scores", hue="Type", data=df)

    # Write each bar's value (3 decimals) just above the bar.
    for p in barplot.patches:
        barplot.annotate(format(p.get_height(), '.3f'),
                         (p.get_x() + p.get_width() / 2., p.get_height()),
                         ha = 'center', va = 'center',
                         xytext = (0, 10),
                         textcoords = 'offset points')

    plt.show()

plot_evals(tuned_evals, original_evals)


### Clean Up

In [None]:
import sagemaker
from sagemaker.predictor import Predictor
# Delete the SageMaker endpoint that served the fine-tuned model.
# After this cell runs, `tuned_llm` can no longer be invoked.
sagemaker_session = sagemaker.Session()
pred = Predictor(endpoint_name=endpoint_name, sagemaker_session=sagemaker_session)
pred.delete_endpoint()