## Install dependencies

In [None]:
!pip install langchain-community langchain-core pinecone fastembed langchain-pinecone

In [None]:
from langchain_community.embeddings import FastEmbedEmbeddings
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
from pinecone import Pinecone
from langchain_pinecone import PineconeVectorStore
from langchain_core.prompts import PromptTemplate
from google.colab import files

import pandas as pd
import json
from google.colab import userdata

## Initialize the NLLB translation client

In [None]:
class NLLBTranslator:
    """Wrapper around an NLLB seq2seq checkpoint for text translation.

    The tokenizer and model are loaded once in the constructor and reused by
    every call to :meth:`translate`.
    """

    def __init__(self, model_name="facebook/nllb-200-distilled-600M"):
        """Load the tokenizer and seq2seq model for ``model_name``.

        Args:
            model_name: Identifier passed to ``from_pretrained`` for both the
                tokenizer and the model.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

    def _language_token_id(self, lang_code):
        """Return the vocabulary id of ``lang_code``, raising if it is unknown.

        Raises:
            ValueError: If the tokenizer has no token for ``lang_code`` (it
                returns ``None`` or its unknown-token id).
        """
        token_id = self.tokenizer.convert_tokens_to_ids(lang_code)
        unk_id = getattr(self.tokenizer, "unk_token_id", None)
        if token_id is None or token_id == unk_id:
            raise ValueError(
                f"Unknown language code for this tokenizer: {lang_code!r}"
            )
        return token_id

    def translate(self, src_text, src_lang, tgt_lang):
        """Translate ``src_text`` from ``src_lang`` to ``tgt_lang``.

        Args:
            src_text: Text to translate. Input longer than 512 tokens is
                truncated by the tokenizer.
            src_lang: Language code of ``src_text`` (e.g. ``"tam_Taml"``).
            tgt_lang: Language code to translate into (e.g. ``"eng_Latn"``).

        Returns:
            The first decoded sequence produced by the model, with special
            tokens removed.

        Raises:
            ValueError: If ``src_lang`` or ``tgt_lang`` is not a token known
                to the tokenizer.
        """
        # Fail fast on mistyped codes: an unknown code would otherwise be
        # mapped to the unknown-token id and silently yield a bad translation.
        self._language_token_id(src_lang)
        forced_bos_token_id = self._language_token_id(tgt_lang)

        self.tokenizer.src_lang = src_lang

        encoded = self.tokenizer(
            src_text,
            return_tensors="pt",
            truncation=True,
            max_length=512
        )

        # The target language is selected by forcing its language token as the
        # first generated token; sampling is disabled.
        generated_tokens = self.model.generate(
            **encoded,
            forced_bos_token_id=forced_bos_token_id,
            max_length=512,
            do_sample=False
        )

        translated = self.tokenizer.batch_decode(
            generated_tokens,
            skip_special_tokens=True
        )

        return translated[0]


# Shared translator instance used by translate_text; constructing it loads the
# default checkpoint named in NLLBTranslator.__init__.
nllb_translater = NLLBTranslator()
# Maps the short language codes used in this notebook to the language codes
# that are passed to NLLBTranslator.translate.
language_map = {
    "en": "eng_Latn",
    "ta": "tam_Taml",
    "si": "sin_Sinh"
}

def translate_text(text, source_language, target_language="en"):
    """Translate ``text`` using the shared ``nllb_translater`` instance.

    Args:
        text: Text to translate.
        source_language: Key of ``language_map`` naming the language of ``text``.
        target_language: Key of ``language_map`` naming the output language;
            defaults to ``"en"``.

    Returns:
        The translated string returned by ``nllb_translater.translate``.

    Raises:
        KeyError: If either language is not a key of ``language_map``.
    """
    source_code = language_map[source_language]
    target_code = language_map[target_language]
    return nllb_translater.translate(text, source_code, target_code)

In [None]:
# Smoke test: translate one sample Tamil question into English and display the
# result as the cell output.
translate_text(
    "இந்த சேவைக்கு தேவையான ஆவணங்கள் என்ன?",
    source_language="ta",
    target_language="en"
)

## Implement the Retrieval Chain

In [None]:
class RetrievalChain:
    """Translate a query, retrieve contexts from a vector store, and score Recall@k.

    Attributes are set in ``__init__``: ``embedding_model`` (stored only; not
    used directly by the methods of this class), ``vector_store``, ``top_k``
    and ``recall_ks``.
    """

    def __init__(self, embedding_model, vector_store, top_k=15,
                 recall_ks=(1, 3, 5, 10, 15)):
        """Store the retrieval configuration.

        Args:
            embedding_model: Embedding model associated with the vector store.
            vector_store: Object exposing
                ``similarity_search_with_score(query, k=...)`` and returning
                ``(document, score)`` pairs.
            top_k: Number of results requested per query.
            recall_ks: Cutoffs for which Recall@k flags are computed. A cutoff
                larger than ``top_k`` is effectively evaluated at ``top_k``,
                because no more than ``top_k`` results are requested.
        """
        self.embedding_model = embedding_model
        self.vector_store = vector_store
        self.top_k = top_k
        # Copy into a tuple so later mutation of the argument has no effect.
        self.recall_ks = tuple(recall_ks)

    def get_retrieved_context_ids(self, query):
        """Return the ``context_id`` metadata of the top_k results, in ranked order."""
        results = self.vector_store.similarity_search_with_score(
            query,
            k=self.top_k
        )

        return [doc.metadata["context_id"] for doc, _ in results]

    def compute_recall_flags(self, gt_context_id, retrieved_ids):
        """Compute a 0/1 Recall@k flag for every cutoff in ``self.recall_ks``.

        Args:
            gt_context_id: Ground-truth context id for the query.
            retrieved_ids: Ranked list of retrieved context ids.

        Returns:
            Dict mapping ``"r@<k>"`` to 1 if ``gt_context_id`` appears in the
            first ``k`` retrieved ids, else 0. Keys follow the order of
            ``self.recall_ks``.
        """
        return {
            f"r@{k}": int(gt_context_id in retrieved_ids[:k])
            for k in self.recall_ks
        }

    def run(self, question_id, query, query_language, gt_context_id):
        """Evaluate retrieval for a single question.

        The query is translated with the module-level ``translate_text`` helper
        (using its default target language) before searching the vector store.

        Args:
            question_id: Identifier echoed back in the result.
            query: Question text in ``query_language``.
            query_language: Source-language key passed to ``translate_text``.
            gt_context_id: Ground-truth context id for the question.

        Returns:
            Dict with ``question_id``, ``ground_truth``, ``translated_question``,
            one ``r@<k>`` flag per cutoff, and ``contexts`` (the retrieved ids
            joined with ``"|"``).
        """
        en_query = translate_text(query, query_language)
        retrieved_ids = self.get_retrieved_context_ids(en_query)

        recall_flags = self.compute_recall_flags(
            gt_context_id,
            retrieved_ids
        )

        return {
            "question_id": question_id,
            "ground_truth": gt_context_id,
            "translated_question": en_query,
            **recall_flags,
            "contexts": "|".join(map(str, retrieved_ids))
        }


## Initialize the RetrievalChain

In [None]:
# Tag used in the file names of the result and summary CSVs written below.
experiment_id = "qt_nllb"
# Embedding model with its default settings; passed to the vector store below.
embedding_model = FastEmbedEmbeddings()

# Connect to the existing Pinecone index; the API key is read from the Colab
# user secrets.
pc_index_name = "gic-fastembed"
pc = Pinecone(api_key=userdata.get('PINECONE_API_KEY'))
index = pc.Index(pc_index_name)
# Wrap the index in a LangChain vector store that uses the same embedding model.
vector_store = PineconeVectorStore(index=index, embedding=embedding_model)

# top_k is left at the RetrievalChain default.
retrieval_chain = RetrievalChain(
    embedding_model=embedding_model,
    vector_store=vector_store,
)

## Load test data

In [None]:
# pandas is also imported in the first import cell of this notebook.
import pandas as pd
# Evaluation set; the cells below read the question_id, question_ta,
# question_si and context_id columns from it.
qa_data = pd.read_csv("gic_qa_with_ids.csv")
# Display the first row as a quick sanity check.
qa_data.head(1)

## Tamil questions

In [None]:
# One result dict per Tamil question, in the row order of qa_data.
results_ta = []

for _, row in qa_data.iterrows():
    # Debug aid: uncomment to stop early and run only a small subset.
    # if row['question_id'] == 2:
    #   break
    print(f"Processing question: {row['question_id']}")
    # Debug aid: uncomment to print the original Tamil question.
    # print(f"Question: {row['question_ta']}")
    # Translate the Tamil question, retrieve contexts and compute recall flags.
    result = retrieval_chain.run(
        question_id=row["question_id"],
        query=row["question_ta"],
        query_language="ta",
        gt_context_id=row["context_id"]
    )
    print(result)
    print("\n-------------------------------------- \n")
    results_ta.append(result)

In [None]:
# Tabulate the Tamil results. NOTE: `df` is rebound to the Sinhala results
# further down, so the Tamil cells below must run before that happens.
df = pd.DataFrame(results_ta)
# Show the last five rows.
df.tail(5)

Unnamed: 0,question_id,ground_truth,translated_question,r@1,r@3,r@5,r@10,r@15,contexts
495,496,4-91-643-1,Who provides Government Band Services and what...,1,1,1,1,1,4-91-643-1|5-32-645-1|12-71-767-1|1-20-2117-1|...
496,497,4-97-1481-2,What is the specific title of the list of Nati...,0,1,1,1,1,4-97-1484-1|11-67-1484-1|4-97-1481-2|4-97-1481...
497,498,4-91-638-1,Who is eligible to submit works for the Royal ...,1,1,1,1,1,4-91-638-1|4-91-637-1|4-91-639-1|4-91-637-2|4-...
498,499,4-91-955-1,Who will be allowed to visit the Haldummulla H...,1,1,1,1,1,4-91-955-1|4-91-1213-1|10-62-950-1|4-91-674-1|...
499,500,4-27-653-2,Under what specific conditions is compensation...,1,1,1,1,1,4-27-653-2|5-32-473-2|5-32-473-4|10-65-420-7|1...


In [None]:
# Persist the per-question Tamil results (without the DataFrame index).
df.to_csv(f"result_ta_{experiment_id}.csv", index=False)

In [None]:
# Mean of each 0/1 recall flag over the Tamil results, i.e. Recall@k as a
# fraction of questions.
summary_baseline_ta = {
    label: df[column].mean()
    for label, column in (
        ("Recall@1", "r@1"),
        ("Recall@3", "r@3"),
        ("Recall@5", "r@5"),
        ("Recall@10", "r@10"),
        ("Recall@15", "r@15"),
    )
}

summary_baseline_ta

In [None]:
# Express each Tamil recall fraction as a percentage.
summary_pct_baseline_ta = {
    metric: fraction * 100 for metric, fraction in summary_baseline_ta.items()
}

summary_pct_baseline_ta

In [None]:
# Tag the summary with its language so the rows can be told apart once the
# per-language summaries are combined.
summary_pct_baseline_ta['language'] = "ta"
summary_pct_baseline_ta

## Sinhala questions

In [None]:
# One result dict per Sinhala question, in the row order of qa_data.
results_si = []

for _, row in qa_data.iterrows():
    # Debug aid: uncomment to stop early and run only a small subset.
    # if row['question_id'] == 2:
    #   break
    print(f"Processing question: {row['question_id']}")
    # Debug aid: uncomment to print the original Sinhala question.
    # print(f"Question: {row['question_si']}")
    # Translate the Sinhala question, retrieve contexts and compute recall flags.
    result = retrieval_chain.run(
        question_id=row["question_id"],
        query=row["question_si"],
        query_language="si",
        gt_context_id=row["context_id"]
    )
    print(result)
    # Debug aids: uncomment to inspect individual fields of the result.
    # print("translated_question: ", result['translated_question'])
    # print("Ground truth: ", row["context_id"])
    # print("Retrieved: ", result["contexts"])
    results_si.append(result)

In [None]:
# Tabulate the Sinhala results. NOTE: this rebinds `df`, which previously held
# the Tamil results.
df = pd.DataFrame(results_si)
# Number of processed questions, then the last five rows.
print(len(df))
df.tail(5)

In [None]:
# Persist the per-question Sinhala results (without the DataFrame index).
df.to_csv(f"result_si_{experiment_id}.csv", index=False)

In [None]:
# Mean of each 0/1 recall flag over the Sinhala results, i.e. Recall@k as a
# fraction of questions.
summary_baseline_si = {
    label: df[column].mean()
    for label, column in (
        ("Recall@1", "r@1"),
        ("Recall@3", "r@3"),
        ("Recall@5", "r@5"),
        ("Recall@10", "r@10"),
        ("Recall@15", "r@15"),
    )
}

summary_baseline_si

In [None]:
# Express each Sinhala recall fraction as a percentage.
summary_pct_baseline_si = {
    metric: fraction * 100 for metric, fraction in summary_baseline_si.items()
}

summary_pct_baseline_si

In [None]:
# Tag the summary with its language so the rows can be told apart once the
# per-language summaries are combined.
summary_pct_baseline_si['language'] = "si"
summary_pct_baseline_si

## Overall summary

In [None]:
# Combine the per-language percentage summaries into one table: one row per
# language (Tamil first, then Sinhala).
summary_baseline = pd.DataFrame([summary_pct_baseline_ta, summary_pct_baseline_si])
summary_baseline

In [None]:
# Move "language" to the front; the remaining columns keep their current order.
cols = ["language", *(c for c in summary_baseline.columns if c != "language")]
summary_baseline = summary_baseline[cols]
summary_baseline

In [None]:
# Persist the combined summary (without the DataFrame index).
summary_baseline.to_csv(f"summary_{experiment_id}.csv", index=False)