# Exercício 10_11 RAGAS

matheusrdgsf@gmail.com / mrsf@cin.ufpe.br

In [1]:
# !pip install groq

In [2]:
import ast
import getpass
import os
import re
import string
import time

import matplotlib.pyplot as plt
import pandas as pd
import torch
from datasets import load_dataset
from groq import Groq
from sentence_transformers import SentenceTransformer
from tqdm.notebook import tqdm

### Parameters

In [4]:
# NOTE(review): in the visible notebook this value is never read before the
# "Generate Questions" cell reassigns N_QUESTIONS to 5 — confirm it is still needed.
N_QUESTIONS = 150

### LLM Inference

In [5]:
# Read the key from the environment first and only fall back to an interactive
# prompt when it is missing or empty. Passing getpass.getpass(...) as the
# os.getenv default would evaluate it eagerly and prompt on every run, even
# when GROQ_KEY is already set.
GROQ_KEY = os.getenv("GROQ_KEY") or getpass.getpass("Enter your Groq API key: ")
client = Groq(
    api_key=GROQ_KEY,
)
# Candidate Groq model ids; the first entry is the one used for inference.
MODELS = ["llama3-70b-8192", "llama3-8b-8192", "mixtral-8x7b-32768", "gemma-7b-it"]

In [6]:
def predict_groq(text, retry=10, temperature=0, model=None, backoff=1.0):
    """Send ``text`` as a single user message to the Groq chat API.

    Args:
        text: Prompt content; it is interpolated into a string before sending.
        retry: Maximum number of attempts.
        temperature: Sampling temperature forwarded to the API.
        model: Model id to query. Defaults to ``MODELS[0]`` when ``None``.
        backoff: Base delay in seconds between failed attempts. The delay
            doubles after each failure and is capped at 30 seconds; pass 0 to
            retry immediately.

    Returns:
        The message content of the first returned choice, or the sentinel
        string ``"Fail in GROQ API."`` when every attempt raised.
    """
    model_name = MODELS[0] if model is None else model

    for attempt in range(retry):
        try:
            chat_completion = client.chat.completions.create(
                messages=[
                    {
                        "role": "user",
                        "content": f"{text}",
                    }
                ],
                model=model_name,
                seed=42,
                temperature=temperature,
            )

            return chat_completion.choices[0].message.content
        except Exception as e:
            # Best-effort call: report the error and try again. Waiting between
            # attempts avoids burning all retries while the API is rate limiting.
            print(e)
            if backoff and attempt < retry - 1:
                time.sleep(min(backoff * 2**attempt, 30.0))

    return "Fail in GROQ API."

### Load data

In [7]:
# Load the local JSON file of test questions and keep its "train" split.
dataset_splits = load_dataset("json", data_files="data/test_questions.json")
data = dataset_splits["train"]

In [8]:
data[1]

{'answer': {'answer_spans': None,
  'answer_unit': 'years',
  'answer_value': '5',
  'type': 'value'},
 'question_links': ['World War I'],
 'title': 'Giovanni Messe',
 'context': [{'indices': [528, 615],
   'passage': 'main',
   'text': 'he became aide-de-camp to King Victor Emmanuel III, holding this post from 1923 to 1927'},
  {'indices': [178, 262],
   'passage': 'World War I',
   'text': 'a global war originating in Europe that lasted from 28 July 1914 to 11 November 1918'}],
 'question': 'How long had the First World War been over when Messe was named aide-de-camp?'}

### Process data

In [9]:
def _answer_to_text(answer):
    """Flatten an answer record into plain text.

    Handles the answer types "binary", "value", "span" and "none". Returns
    ``None`` for any other type so the caller can decide what to do with it.
    """
    answer_type = answer["type"]

    if answer_type == "binary":
        return answer["answer_value"]
    if answer_type == "value":
        # The unit may be missing; fall back to an empty string so the
        # concatenation cannot fail on None.
        unit = answer["answer_unit"] or ""
        return answer["answer_value"] + " " + unit
    if answer_type == "span":
        # Only the first span is used as the reference answer.
        return answer["answer_spans"][0]["text"]
    if answer_type == "none":
        return "none"
    return None


data_processed = []

for sample in data:
    answer_text = _answer_to_text(sample["answer"])

    if answer_text is None:
        # Skip the sample instead of silently reusing the answer text of the
        # previous iteration (or hitting a NameError on the first one).
        print("Unknown answer type")
        continue

    data_processed.append(
        {
            "question": sample["question"].strip(),
            "answer": answer_text.strip(),
            "context": [passage["text"] for passage in sample["context"]],
        }
    )

In [10]:
data_processed[1]

{'question': 'How long had the First World War been over when Messe was named aide-de-camp?',
 'answer': '5 years',
 'context': ['he became aide-de-camp to King Victor Emmanuel III, holding this post from 1923 to 1927',
  'a global war originating in Europe that lasted from 28 July 1914 to 11 November 1918']}

### RAGAS

### Faithfulness

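We can skip the first step (statement extraction) and go directly to the second step (verification), because the questions are short enough to be used as single statements.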

In [11]:
# Resume from a previous run when its checkpoint file exists; otherwise start
# with an empty frame.
faithfulness_checkpoint = (
    pd.read_csv("faithfulness_checkpoint.csv")
    if os.path.exists("faithfulness_checkpoint.csv")
    else pd.DataFrame()
)

In [12]:
# Prompt for the verification step. The instructions refer to "the given
# format", so the format is spelled out explicitly: the downstream parser looks
# for "verdict: yes" / "verdict: no" in the lower-cased model output.
prompt_faithfulness = """
Consider the given context and following statements, then determine whether they are supported by the information present in the context. Provide a brief explanation for each statement before arriving at the verdict (Yes/No). Provide a final verdict for each statement in order at the end in the given format. Do not deviate from the specified format.

Format of the final verdict (one line per statement):
Verdict: <Yes or No>

statement: {statement}
context: {context}
"""

In [13]:
# Resume support: rows are appended in data order with a fresh 0..n-1 index, so
# an index label present in the checkpoint marks a sample as already processed.
indices_to_process = [
    i for i in range(len(data_processed)) if i not in faithfulness_checkpoint.index
]

for i in tqdm(indices_to_process, total=len(indices_to_process)):

    sample = data_processed[i]

    question = sample["question"]
    answer = sample["answer"]
    context = sample["context"]

    # NOTE(review): the *question* is sent as the statement to verify and the
    # answer is only stored, never judged — confirm this is the intended
    # faithfulness setup.
    prompt = prompt_faithfulness.format(statement=question, context=" ".join(context))

    prediction = predict_groq(prompt)

    if prediction == "Fail in GROQ API.":
        # Do not checkpoint the failure sentinel as if it were a model answer:
        # it would be counted as a missing verdict and never retried. Stopping
        # here keeps the checkpoint aligned so a re-run resumes at this sample.
        print(f"Groq API failed for sample {i}; stopping so a re-run can resume here.")
        break

    prediction = prediction.lower().strip()

    faithfulness_checkpoint = pd.concat(
        [
            faithfulness_checkpoint,
            pd.DataFrame(
                {
                    "question": [question],
                    "answer": [answer],
                    "context": [context],
                    "prediction": [prediction],
                }
            ),
        ],
        ignore_index=True,
    )

    # Persist after every sample so an interrupted run loses at most one call.
    faithfulness_checkpoint.to_csv("faithfulness_checkpoint.csv", index=False)

0it [00:00, ?it/s]

In [14]:
# Pattern for the final verdict in the (already lower-cased) model output.
verdict_regex = re.compile(r"verdict: (yes|no)")


def extract_verdict(prediction):
    """Return "yes" or "no" from a prediction, or None when no verdict is found.

    Non-string values yield None instead of raising; an empty prediction, for
    example, is read back from the CSV checkpoint as NaN.
    """
    if not isinstance(prediction, str):
        return None
    match = verdict_regex.search(prediction)
    return match.group(1) if match else None


faithfulness_checkpoint["verdict"] = faithfulness_checkpoint["prediction"].apply(
    extract_verdict
)

In [15]:
# Share of samples whose parsed verdict is "yes", over all checkpointed rows
# (rows without a parsable verdict count against the score).
n_supported = faithfulness_checkpoint["verdict"].value_counts().get("yes", 0)
n_total = faithfulness_checkpoint.shape[0]
print(f"Faithfulness: {n_supported / n_total * 100:.2f}%")

Faithfulness: 70.00%


### Answer Relevance

In [118]:
# The "questions" column holds a Python list per row. to_csv stores each list
# as its repr, so parse it back into a list when resuming; otherwise the
# embedding step would receive one long string instead of a list of questions.
if os.path.exists("answer_relevance_checkpoint.csv"):
    answer_relevance_checkpoint = pd.read_csv(
        "answer_relevance_checkpoint.csv",
        converters={"questions": ast.literal_eval},
    )
else:
    answer_relevance_checkpoint = pd.DataFrame()

In [119]:
# Prompt that asks the model for N questions answerable by the given answer.
# The separator is written as "\\n" so the model sees the two characters "\n";
# a bare "\n" in a non-raw string is a real line break, which turned the
# instruction into "split by" followed by an empty line end.
prompt_answer_relevance = """
Generate {n_questions} different questions for the given answer.

Your answer need to be only the questions split by \\n. Don't add any other information or special characters.

Do not deviate from the specified format.
answer: {answer}
"""

# Fallback prompt used when the first reply did not contain exactly N lines.
retry_prompt = """
Given the following questions, split them in {n_questions} and return splitted by \\n.
Don't add any other information or special characters.

questions: {questions}
"""

#### Generate Questions

In [120]:
# Number of questions to generate per answer.
# NOTE(review): this reassigns the N_QUESTIONS defined in the "Parameters" cell.
N_QUESTIONS = 5

# Resume after the rows already present in the checkpoint.
start_index = len(answer_relevance_checkpoint)
total_process = len(data_processed) - start_index

for i in tqdm(range(start_index, len(data_processed)), total=total_process):
    row = data_processed[i]

    prompt = prompt_answer_relevance.format(
        answer=row["answer"], n_questions=N_QUESTIONS
    )

    raw_output = predict_groq(prompt, temperature=1)
    # Drop blank lines and surrounding whitespace so they are not counted (or
    # later embedded) as questions.
    questions = [line.strip() for line in raw_output.split("\n") if line.strip()]

    if raw_output != "Fail in GROQ API." and len(questions) != N_QUESTIONS:
        # Ask the model to re-split its own output. The questions are passed as
        # plain text rather than as a Python list repr. The call runs at the
        # default temperature: the original passed temperature=1 to
        # str.format(), where it was silently ignored.
        raw_output = predict_groq(
            retry_prompt.format(
                n_questions=N_QUESTIONS,
                questions="\n".join(questions),
            )
        )
        questions = [line.strip() for line in raw_output.split("\n") if line.strip()]

    if raw_output == "Fail in GROQ API.":
        # Do not checkpoint the failure sentinel as a generated question; stop
        # so a re-run resumes at this sample.
        print(f"Groq API failed for sample {i}; stopping so a re-run can resume here.")
        break

    answer_relevance_checkpoint = pd.concat(
        [
            answer_relevance_checkpoint,
            pd.DataFrame(
                {
                    "answer": [row["answer"]],
                    "questions": [questions],
                    "original_question": [row["question"]],
                }
            ),
        ],
        ignore_index=True,
    )

    # Persist after every sample so an interrupted run loses at most one call.
    answer_relevance_checkpoint.to_csv("answer_relevance_checkpoint.csv", index=False)

  0%|          | 0/50 [00:00<?, ?it/s]

#### Extract Answers

In [92]:
# Row-wise cosine similarity; dim and eps are spelled out at the values
# documented as the module defaults.
cosine_similarity = torch.nn.CosineSimilarity(dim=1, eps=1e-8)

In [93]:
# Sentence-embedding model used to compare generated and original questions.
EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
model = SentenceTransformer(EMBEDDING_MODEL_NAME)

In [129]:
# Answer relevance: mean cosine similarity between the original question and
# the questions generated from its answer, averaged over all samples.
similarities = []
for _, row in answer_relevance_checkpoint.iterrows():

    original_question = row["original_question"]
    questions = row["questions"]

    if isinstance(questions, str):
        # A checkpoint reloaded from CSV stores the list as its repr; parse it
        # back so each question is embedded separately.
        questions = ast.literal_eval(questions)

    # Shape the single question as one row so it broadcasts explicitly against
    # the one-row-per-question matrix.
    original_question_embedding = torch.tensor(
        model.encode(original_question)
    ).unsqueeze(0)
    questions_embedding = torch.tensor(model.encode(questions))

    similarity = cosine_similarity(
        original_question_embedding,
        questions_embedding,
    ).mean()

    similarities.append(similarity.item())

# Guard against an empty checkpoint instead of raising ZeroDivisionError.
answer_relevance = (
    sum(similarities) / len(similarities) if similarities else float("nan")
)
print(f"Answer Relevance: {answer_relevance:.2f}")

Answer Relevance: 0.20


### Context Relevance

In [151]:
# "context" and "context_relevance" hold a Python list per row. to_csv stores
# each list as its repr, so parse them back into lists when resuming; otherwise
# the scoring step would iterate over characters instead of sentences.
if os.path.exists("context_relevance_checkpoint.csv"):
    context_relevance_checkpoint = pd.read_csv(
        "context_relevance_checkpoint.csv",
        converters={
            "context": ast.literal_eval,
            "context_relevance": ast.literal_eval,
        },
    )
else:
    context_relevance_checkpoint = pd.DataFrame()

In [152]:
# Prompt that asks the model to extract the context sentences relevant to the
# question. The separator is written as "\\n" so the model sees the two
# characters "\n"; a bare "\n" in a non-raw string is a real line break, which
# left the instruction reading "split by" with nothing after it.
prompt_context_relevance = """
Return relevante sentences from the provided context that can potentially
help answer the following question. If no relevant sentences are found, or if you
believe the question cannot be answered from the given context, return the phrase
'Insufficient Information'. While extracting candidate sentences you’re not allowed
to make any changes to sentences from given context.

Your answer need to be only strings of the sentences split by \\n. Don't add any other information or special characters.

Question:
{question}

Context:
{context}
"""

In [153]:
# Resume after the rows already present in the checkpoint.
start_index = len(context_relevance_checkpoint)
total_process = len(data_processed) - start_index

for i in tqdm(range(start_index, len(data_processed)), total=total_process):

    row = data_processed[i]

    question = row["question"]
    context = row["context"]

    # Number the passages so the model receives one passage per line.
    numbered_context = "\n".join(
        f"{position + 1} - {passage}" for position, passage in enumerate(context)
    )
    prompt = prompt_context_relevance.format(
        question=question,
        context=numbered_context,
    )

    raw_output = predict_groq(prompt, temperature=1)

    if raw_output == "Fail in GROQ API.":
        # Do not checkpoint the failure sentinel as an extracted sentence; stop
        # so a re-run resumes at this sample.
        print(f"Groq API failed for sample {i}; stopping so a re-run can resume here.")
        break

    # Drop blank lines so they are not later counted as relevant sentences.
    context_relevance = [
        line.strip() for line in raw_output.split("\n") if line.strip()
    ]

    context_relevance_checkpoint = pd.concat(
        [
            context_relevance_checkpoint,
            pd.DataFrame(
                {
                    "question": [question],
                    "context": [context],
                    "context_relevance": [context_relevance],
                }
            ),
        ],
        ignore_index=True,
    )

    # Persist after every sample so an interrupted run loses at most one call.
    context_relevance_checkpoint.to_csv("context_relevance_checkpoint.csv", index=False)

  0%|          | 0/50 [00:00<?, ?it/s]

In [164]:
# Translation table that deletes every character in string.punctuation.
_PUNCTUATION_TABLE = str.maketrans("", "", string.punctuation)


def remove_punctuation(text: str):
    """Return ``text`` with all ``string.punctuation`` characters removed."""
    return text.translate(_PUNCTUATION_TABLE)


# Context relevance per sample: number of sentences the model extracted as
# relevant divided by the number of context passages, then averaged.
context_relevance_score = []

for _, row in context_relevance_checkpoint.iterrows():

    context = row["context"]
    extracted = row["context_relevance"]

    # A checkpoint reloaded from CSV stores lists as their repr; parse them back
    # so we count passages/sentences rather than characters.
    if isinstance(context, str):
        context = ast.literal_eval(context)
    if isinstance(extracted, str):
        extracted = ast.literal_eval(extracted)

    normalized = [
        remove_punctuation(sentence.strip().lower()) for sentence in extracted
    ]

    # Ignore blank lines and the explicit "no relevant sentence" marker.
    relevant_sentences = [
        sentence
        for sentence in normalized
        if sentence and sentence != "insufficient information"
    ]

    # A sample without context passages has nothing relevant to extract.
    score = len(relevant_sentences) / len(context) if len(context) else 0.0
    context_relevance_score.append(score)

# Guard against an empty checkpoint instead of raising ZeroDivisionError.
mean_context_relevance = (
    sum(context_relevance_score) / len(context_relevance_score)
    if context_relevance_score
    else float("nan")
)
print(f"Context Relevance: {mean_context_relevance:.2f}")

Context Relevance: 0.50
