<a href="https://colab.research.google.com/github/DharshiBalasubramaniyam/super-duper-rotary-phone/blob/main/mbart-and-gemini/main.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Install libraries

In [None]:
!pip install langchain langchain-community langchain-core pinecone


In [None]:
!pip install fastembed langchain-google-genai langchain-pinecone langchain-text-splitters python-dotenv



In [None]:
import csv
import json
import os
import re
import time

from pathlib import Path

from langchain.chains.retrieval_qa.base import RetrievalQA
from langchain_community.document_loaders import TextLoader
from langchain_community.embeddings import FastEmbedEmbeddings
from langchain_core.prompts import PromptTemplate
from langchain_google_genai import GoogleGenerativeAI
from langchain_pinecone import PineconeVectorStore
from pinecone import Pinecone, ServerlessSpec
from google.cloud import translate_v2 as translate

In [None]:
def get_folder_path(folder_name):
    folder_path = os.path.join(os.getcwd(), folder_name)
    if not os.path.exists(folder_path):
        os.makedirs(folder_path)
    return folder_path


def detect_language(google_translate_client, text):
    language_response = google_translate_client.detect_language(text)
    return language_response["language"]


def translate_text(client, target_language, source_language, text):
    translate_response = client.translate(
        text,
        source_language=source_language,
        target_language=target_language,
    )
    return translate_response["translatedText"]
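
One caveat with `translate_text`: the Translation API v2 may return HTML-escaped text (for example `&#39;` for an apostrophe) when the request format is left at its default. The following is a minimal sketch of an unescaping step, assuming the response shape used above. `clean_translation` is a hypothetical helper, not part of the original notebook, and the sample response is a stand-in so the example runs without credentials.

```python
import html


def clean_translation(translate_response):
    """Return the translated text with HTML entities decoded.

    `translate_response` is assumed to be a dict with a "translatedText" key,
    as returned by `client.translate(...)` in the cell above.
    """
    return html.unescape(translate_response["translatedText"])


# Stand-in for an API response (hypothetical data, not a real API call).
sample_response = {"translatedText": "I&#39;m sorry, I can&#39;t help with that."}
print(clean_translation(sample_response))
```

If escaping turns out to be an issue in practice, the same call could be applied to the return value of `translate_text` before the text is shown to the user.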

## RAG evaluator

In [None]:
class RagEvaluator:
    def __init__(self, llm):
        self.llm = llm

    def run(self, query, context, ground_truth, generated_response):
        prompt = f"""
        You are an expert evaluator tasked with assessing the quality of a generated response from a
        Cross-lingual Retrieval-Augmented Generation (RAG) system.

        The retrieved context is in English, while the query, generated response, and ground truth answer are
        all in the same language — either Tamil (ta), Sinhala (si), or English (en).

        The RAG system is explicitly instructed to respond with “I'm sorry, I can’t help with that.” if the retrieved context does not
        contain sufficient information to answer the query. In such cases, this response should be considered
        faithful and potentially correct depending on whether the answer was indeed absent in the context.

        For each of the following three dimensions in evaluation criteria, provide a score from 1 (very poor) to 5 (excellent),
        and explain your reasoning concisely in **English** only.

        ### Evaluation Criteria:

        - **Faithfulness**: Does the generated response accurately reflect the information in the retrieved context without introducing unsupported facts?
          → A response like “I'm sorry, I can’t help with that.” is highly faithful if the context truly lacks the necessary information.

        - **Answer Correctness**: How well does the generated response match the ground truth in terms of factual correctness and completeness?
          → If the ground truth cannot be derived from the context, and the generated response correctly says “I'm sorry, I can’t help with that.”, consider it correct.

        - **Context Relevance**: Is the retrieved context relevant and appropriate for answering the query?
          → If the context does not help in answering the query, score this low even if the response is appropriate.

        ### Output format (example)
        {{
          "faithfulness_score": "SCORE_BETWEEN_1_AND_5",
          "faithfulness_reason": "Explain whether the generated response stays true to the retrieved context or introduces hallucinations.",
          "answer_correctness_score": "SCORE_BETWEEN_1_AND_5",
          "answer_correctness_reason": "Explain how factually correct and complete the response is when compared with the ground truth.",
          "context_relevance_score": "SCORE_BETWEEN_1_AND_5",
          "context_relevance_reason": "Explain whether the retrieved context was appropriate and useful for answering the query."
        }}
        Make sure to return a JSON string like above without any formatting.

        ### Inputs:

        Below is the data you need to evaluate:

        Query: {query}

        Retrieved Context: {context}

        Generated Response: {generated_response}

        Ground Truth Answer: {ground_truth}
        """

        response = self.llm.invoke(prompt)

        return response


# llm = GoogleGenerativeAI(model="gemini-2.0-flash", google_api_key="")
# evaluator = RagEvaluator(llm)
# print(evaluator.run(
#     query="What are the requirements to get a driving license in Sri Lanka?",
#     generated_response="To obtain a driving license card in Sri Lanka, applicants must be 16 years or older and provide documents such as the birth certificate, Form K, and Grama Niladhari letter. The application is submitted to the Department for Registration of Persons.",
#     ground_truth="To get a driving license in Sri Lanka, you need to be 18 or older and submit Form K and your birth certificate to the Department for Registration of Persons.",
#     context="To obtain a driving license in Sri Lanka, you must be at least 18 years old, pass a written test and a practical driving test, and provide your NIC and medical certificate. The application must be submitted to the Department of Motor Traffic."
# ))
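
The prompt asks for a bare JSON string, but the model may still wrap its answer in a Markdown code fence or return scores as strings. The sketch below shows one way to pull the three scores out of the raw text and convert them to integers. `parse_evaluation` and `SCORE_KEYS` are hypothetical names introduced here for illustration, and the sample string is made up rather than taken from a real run.

```python
import json
import re

SCORE_KEYS = (
    "faithfulness_score",
    "answer_correctness_score",
    "context_relevance_score",
)


def parse_evaluation(raw_text):
    """Extract the three scores from the evaluator's raw output.

    Returns a dict of integer scores, or None if no JSON object containing
    all three keys with integer-like values can be found.
    """
    # Grab everything from the first "{" to the last "}", ignoring any fence.
    match = re.search(r"\{.*\}", raw_text, re.DOTALL)
    if match is None:
        return None
    try:
        data = json.loads(match.group())
        return {key: int(data[key]) for key in SCORE_KEYS}
    except (json.JSONDecodeError, KeyError, ValueError, TypeError):
        return None


# Made-up evaluator output wrapped in a Markdown fence.
fence = "`" * 3
raw = (
    fence + "json\n"
    '{"faithfulness_score": "5", "answer_correctness_score": "4", '
    '"context_relevance_score": "3"}\n' + fence
)
print(parse_evaluation(raw))
```

Returning `None` instead of raising keeps a long evaluation run going when a single response is malformed.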


## mBART translator

In [None]:
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast

class MbartTranslator:
    def __init__(self, model_name="facebook/mbart-large-50-many-to-many-mmt"):
        self.tokenizer = MBart50TokenizerFast.from_pretrained(model_name)
        self.model = MBartForConditionalGeneration.from_pretrained(model_name)

    def translate(self, src_text, src_lang, to_lang):
        self.tokenizer.src_lang = src_lang
        encoded = self.tokenizer(src_text, return_tensors="pt")

        forced_bos_token_id = self.tokenizer.lang_code_to_id[to_lang]
        generated_tokens = self.model.generate(**encoded, forced_bos_token_id=forced_bos_token_id)
        translated = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)

        return translated[0]

## RagChain class

In [None]:
class RagChain:
    def __init__(self, data_folder_name, pc_index_name, pinecone_api_key, google_api_key, translate_client_config_file_path):
        self.data_folder_path = get_folder_path(data_folder_name)
        self.embedding_model = FastEmbedEmbeddings()
        self.vector_store = self.process_vector_store(pinecone_api_key, pc_index_name)
        self.llm = GoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=google_api_key)
        self.prompt_template = PromptTemplate(
            template="""You are a helpful and knowledgeable AI assistant that provides accurate and clear information about the registration of births, marriages, and deaths in Sri Lanka. You support responses in English (en), Sinhala (si), or Tamil (ta) based on the user's query language.

Use only the information from the retrieved documents to answer. Follow these instructions:

- Only answer questions related to registration information.
If the user's question is unrelated or the answer cannot be found in the retrieved documents, respond politely with:
"I'm sorry, I can’t help with that."

- Present the response in a clear, organized format using bullet points or numbered lists where appropriate.

- If relevant, include links from the 'Important Links' section (If exists) to guide users to official resources.

- At the end of the response, include a "For more information" link if available and relevant to the query.

Maintain a professional and helpful tone. Do not invent or assume information. Answer in the language used by the user in their query.
---
Context:
{context}

user: {question}
Assistant:
""",
            input_variables=["context", "question"],
        )
        self.qa_chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            retriever=self.vector_store.as_retriever(search_kwargs={"k": 3}),
            chain_type="stuff",
            return_source_documents=True,
            chain_type_kwargs={"prompt": self.prompt_template},
        )
        self.mbart_translater = MbartTranslator()
        self.google_translate_client = translate.Client.from_service_account_json(translate_client_config_file_path)

    def process_vector_store(self, pinecone_api_key, pc_index_name):
        pc = Pinecone(api_key=pinecone_api_key)

        existing_indexes = [index_info["name"] for index_info in pc.list_indexes()]

        if pc_index_name not in existing_indexes:
            print(f"Vector store index '{pc_index_name}' not found. Creating a new index...")

            pc.create_index(
                name=pc_index_name,
                dimension=384,
                metric="cosine",
                spec=ServerlessSpec(
                    cloud="aws",
                    region="us-east-1"
                )
            )

            while not pc.describe_index(pc_index_name).status['ready']:  # Wait for the index to be ready
                time.sleep(1)
            print(f"Successfully created vector store index '{pc_index_name}'.")

            index = pc.Index(pc_index_name)
            vector_store = PineconeVectorStore(index=index, embedding=self.embedding_model)

            print(f"Adding documents to index '{pc_index_name}'.")

            if os.path.exists(self.data_folder_path):
                md_files = list(Path(self.data_folder_path).rglob('*.md'))
                print(f"Found '{len(md_files)}' files to process.")
                for file_path in md_files:
                    print(f"Processing file: '{file_path}'.")
                    loader = TextLoader(file_path, encoding='utf-8')
                    documents = loader.load()

                    # splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
                    # chunks = splitter.split_documents(documents)
                    vector_store.add_documents(documents)
            else:
                print(f"Error: Folder '{self.data_folder_path}' not found.")
            return vector_store

        print(f"Existing vector store found. Returning it: '{pc_index_name}'.")
        index = pc.Index(pc_index_name)
        return PineconeVectorStore(index=index, embedding=self.embedding_model)

    def process_query(self, query, user_input_language):
        language_map = {
            "en": "en_XX", "ta": "ta_IN", "si": "si_LK"
        }
        english_query = query if user_input_language == "en" \
            else self.mbart_translater.translate(query, language_map[user_input_language], language_map["en"])
        return english_query

    def process_response(self, response, user_input_language):
        # Translate an English response back into the user's language.
        translated_response = response if user_input_language == "en" \
            else translate_text(self.google_translate_client, user_input_language, "en", response)
        return translated_response

    def run(self, user_input_query, ground_truth):
        user_input_language = detect_language(self.google_translate_client, user_input_query)
        query = self.process_query(user_input_query, user_input_language)
        llm_response = self.qa_chain.invoke({"query": f"{query}. Please provide your answer in {user_input_language}"})
        context = ""
        for doc in llm_response["source_documents"]:
            context += doc.page_content + "\n --- \n"
        evaluator = RagEvaluator(self.llm)
        eval_response = evaluator.run(
            query, context, ground_truth, llm_response['result']
        )

        return llm_response['result'], eval_response, context, query
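
`process_query` indexes `language_map` directly, so a detected language other than `en`, `ta`, or `si` would raise a bare `KeyError`. The sketch below shows a lookup that fails with a clearer message. `to_mbart_code` and `MBART_LANGUAGE_CODES` are hypothetical names, and the mapping is copied from `process_query`.

```python
# Mapping copied from RagChain.process_query.
MBART_LANGUAGE_CODES = {"en": "en_XX", "ta": "ta_IN", "si": "si_LK"}


def to_mbart_code(iso_code):
    """Map a two-letter language code to the mBART-50 code used above."""
    try:
        return MBART_LANGUAGE_CODES[iso_code]
    except KeyError:
        supported = ", ".join(sorted(MBART_LANGUAGE_CODES))
        raise ValueError(
            f"Unsupported language '{iso_code}'. Supported: {supported}."
        ) from None


print(to_mbart_code("ta"))
```

A caller could catch the `ValueError` and fall back to treating the query as English, though whether that is appropriate depends on the data.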

## Initialize RAG chain

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
folder_name = "rgd_data"
pc_index_name = "rgd-data"
google_api_key = input("Enter Google API key: ")
pinecone_api_key = input("Enter Pinecone API key: ")
translate_client_config_file_path = input("Enter translate client config file path: ")
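
Values typed into `input` are echoed in the cell output, which is saved with the notebook. A minimal sketch of an alternative, assuming the keys may also be supplied as environment variables: `read_secret` is a hypothetical helper, and the variable name `EXAMPLE_API_KEY` is a placeholder rather than a name the original notebook uses.

```python
import os
from getpass import getpass


def read_secret(env_name, prompt, ask=getpass):
    """Return a secret from the environment, or prompt for it without echo."""
    value = os.environ.get(env_name)
    if value:
        return value
    # `ask` defaults to getpass, which hides the typed characters.
    return ask(prompt)


# Placeholder variable set here so the example runs without a prompt.
os.environ["EXAMPLE_API_KEY"] = "dummy-value"
print(read_secret("EXAMPLE_API_KEY", "Enter API key: "))
```

The `ask` parameter exists only so the prompt can be replaced in tests. In normal use the default `getpass` is enough.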

In [None]:
rag_chain = RagChain(
    data_folder_name=folder_name,
    pc_index_name=pc_index_name,
    pinecone_api_key=pinecone_api_key,
    google_api_key=google_api_key,
    translate_client_config_file_path=translate_client_config_file_path
)


Existing vector store found. Returning it: 'rgd-data'.


## Evaluating the Tamil QA pairs

In [None]:
evaluation_data = [
    ['id', 'original_query', 'en_query', 'faithfulness_score', 'answer_correctness_score', 'context_relevance_score']
]

count = 1
qa_file_path = "/content/drive/My Drive/Colab Notebooks/fyrp/qa_data/tamil2.csv"
output_file_path = "/content/drive/My Drive/Colab Notebooks/fyrp/outputs/mbart_ta.csv"
with open(qa_file_path, newline='', encoding='utf-8') as csvfile:
    reader = csv.DictReader(csvfile)
    for row in reader:
        # Pause 120 s before every fifth question, presumably to stay under API rate limits.
        if count % 5 == 0:
            time.sleep(120)
        test_id = row['id']
        question = row['question']
        ground_truth = row['answer']

        print("\n----------------------\nprocessing question: ", question)

        generated_res, evaluation, context, en_query = rag_chain.run(question, ground_truth)
        print(f"""

English query:
{en_query}

Generate response:
{generated_res}

Evaluation:
{evaluation}

Context:
{context}
""")
        evaluation_match = re.search(r'\{.*\}', evaluation.strip(), re.DOTALL)
        if evaluation_match:
            try:
                result = json.loads(evaluation_match.group())
                faithfulness_score = result['faithfulness_score']
                answer_correctness_score = result['answer_correctness_score']
                context_relevance_score = result['context_relevance_score']
                test_res = [test_id, question, en_query, faithfulness_score, answer_correctness_score, context_relevance_score]
                evaluation_data.append(test_res)
            except (json.JSONDecodeError, KeyError) as e:
                print("Failed to parse evaluation:", e)
        else:
            print("Failed to match")
        count += 1
    with open(output_file_path, mode='w', newline='', encoding='utf-8') as file:
        writer = csv.writer(file)
        writer.writerows(evaluation_data)


----------------------
processing question:  ஒரு குழந்தையை தத்தெடுக்க தேவைப்பட்டால் தத்தெடுக்கக்கூடிய குழந்தையின் அதிகபட்ச வயது என்ன?


English query:
What is the maximum age of a child to be adopted if a child is required to be adopted?
              
Generate response:
பிள்ளையைத் தத்தெடுக்க வேண்டுமென்றால், தத்தெடுக்கப்படும் பிள்ளையின் அதிகபட்ச வயது 14 வருடங்கள்.

For more information: https://www.rgd.gov.lk/web/index.php?option=com_content&view=article&id=20&Itemid=151&lang=en#adoption-of-a-child

Evaluation:
```json
{
  "faithfulness_score": "5",
  "faithfulness_reason": "The generated response accurately reflects the information present in the retrieved context regarding the maximum age of a child to be adopted.",
  "answer_correctness_score": "5",
  "answer_correctness_reason": "The generated response matches the ground truth answer almost exactly. The answer is factually correct and complete based on the context.",
  "context_relevance_score": "5",
  "context_relevance_reason": 

## Evaluating the Sinhala QA pairs

In [None]:
evaluation_data = [
    ['id', 'original_query', 'en_query', 'faithfulness_score', 'answer_correctness_score', 'context_relevance_score']
]

count = 1
qa_file_path = "/content/drive/My Drive/Colab Notebooks/fyrp/qa_data/sinhala2.csv"
output_file_path = "/content/drive/My Drive/Colab Notebooks/fyrp/outputs/mbart_si.csv"
with open(qa_file_path, newline='', encoding='utf-8') as csvfile:
    reader = csv.DictReader(csvfile)
    for row in reader:
        # Pause 120 s before every fifth question, presumably to stay under API rate limits.
        if count % 5 == 0:
            time.sleep(120)
        test_id = row['id']
        question = row['question']
        ground_truth = row['answer']

        print("\n----------------------\nprocessing question: ", question)

        generated_res, evaluation, context, en_query = rag_chain.run(question, ground_truth)
        print(f"""
Generate response:
{generated_res}

Evaluation:
{evaluation}

Context:
{context}
""")
        evaluation_match = re.search(r'\{.*\}', evaluation.strip(), re.DOTALL)
        if evaluation_match:
            try:
                result = json.loads(evaluation_match.group())
                faithfulness_score = result['faithfulness_score']
                answer_correctness_score = result['answer_correctness_score']
                context_relevance_score = result['context_relevance_score']
                test_res = [test_id, question, en_query, faithfulness_score, answer_correctness_score, context_relevance_score]
                evaluation_data.append(test_res)
            except (json.JSONDecodeError, KeyError) as e:
                print("Failed to parse evaluation:", e)
        else:
            print("Failed to match")
        count += 1
    with open(output_file_path, mode='w', newline='', encoding='utf-8') as file:
        writer = csv.writer(file)
        writer.writerows(evaluation_data)


----------------------
processing question:  දරුවෙකු තාත්තා ගැනීම සඳහා උපරිම වයස කීයද?

Generate response:
දරුවෙකු හදා වඩා ගැනීමේදී දරුවාගේ උපරිම වයස අවුරුදු 14 නොඉක්මවිය යුතුය.

*   දරුවෙකු හදා වඩා ගැනීම සඳහා වැඩි විස්තර සඳහා: [https://www.rgd.gov.lk/web/index.php?option=com\_content&view=article&id=20&Itemid=151&lang=en#adoption-of-a-child](https://www.rgd.gov.lk/web/index.php?option=com_content&view=article&id=20&Itemid=151&lang=en#adoption-of-a-child)

Evaluation:
```json
{
  "faithfulness_score": "5",
  "faithfulness_reason": "The generated response accurately reflects the information in the retrieved context regarding the maximum age of a child to be adopted (14 years).",
  "answer_correctness_score": "5",
  "answer_correctness_reason": "The generated response is factually correct and complete, matching the ground truth. It correctly states that the maximum age for a child to be adopted is 14 years.",
  "context_relevance_score": "5",
  "context_relevance_reason": "The retrieved