# 1. Install required packages

We install all the dependencies needed for building a
Retrieval-Augmented Generation (RAG) pipeline.
These include LangChain components, Hugging Face models,
ChromaDB for vector storage, and PyTorch for GPU acceleration.

In [None]:
# Install the notebook's dependencies; -qq keeps pip's output quiet.
# NOTE(review): `somepackage` looks like a placeholder entry — confirm it is intended.
%pip install somepackage -qq langchain langchain-community langchain-core langchain-text-splitters langchain-huggingface sentence-transformers chromadb transformers torch accelerate unstructured codecarbon

# 2. Imports and Configuration
Imports necessary libraries, sets constants for models, embeddings, chunking, and loads prompt templates from a JSON file for various NLP tasks.


In [None]:
from pathlib import Path
import json
import time
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
from langchain_core.prompts import PromptTemplate
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
import torch
import warnings
from codecarbon import OfflineEmissionsTracker

# Suppress all Python warnings for the rest of the notebook.
warnings.filterwarnings("ignore")

# --- Files and storage ---
PROMPTS_FILE = "data/test_data.json"  # JSON read for instructions, source text and prompts
PERSIST_DIR = "data/chroma_db"  # directory passed to Chroma as persist_directory

# --- Embedding / chunking ---
EMBED_MODEL = "all-MiniLM-L6-v2"
CHUNK_SIZE = 400  # chunk_size given to the text splitter
CHUNK_OVERLAP = 50  # chunk_overlap given to the text splitter

# --- Retrieval ---
TOP_K_RESULTS = 5  # default number of chunks requested per query
RELEVANCE_THRESHOLD = 0.3  # default minimum relevance score of the best hit

# --- Generation ---
LLM_MODEL = "MBZUAI/LaMini-Flan-T5-248M"
MAX_NEW_TOKENS = 200
USE_GPU = torch.cuda.is_available()

# --- Emissions tracking ---
COUNTRY_ISO_CODE = "EGY"  # passed to OfflineEmissionsTracker

# --- Answer refinement ---
ENABLE_RECURSIVE_EDITING = True
MAX_EDIT_ITERATIONS = 2

# Read with an explicit encoding so the result does not depend on the
# platform default (the document loader reads this same file as UTF-8).
with open(PROMPTS_FILE, "r", encoding="utf-8") as f:
    config_data = json.load(f)

# Instruction texts used to assemble the per-task prompt templates.
MASTER_INSTRUCTION = config_data["metadata"]["master_instruction"]
TASK_INSTRUCTIONS = config_data["metadata"]["task_instructions"]


def build_prompt_template(task_type):
    """Assemble the prompt template text for one task type.

    The result is the master instruction, the task-specific instruction,
    and the Context / Question / Answer scaffold, each separated by a blank
    line. ``{context}`` and ``{question}`` are left as literal placeholders
    to be filled in later.

    Args:
        task_type: Key into ``TASK_INSTRUCTIONS``.

    Returns:
        The template string.

    Raises:
        KeyError: If ``task_type`` has no entry in ``TASK_INSTRUCTIONS``.
    """
    sections = (
        MASTER_INSTRUCTION,
        TASK_INSTRUCTIONS[task_type],
        "Context:\n{context}",
        "Question: {question}",
        "Answer:",
    )
    # format(..., "") mirrors how an f-string renders each interpolated value.
    return "\n\n".join(format(section, "") for section in sections)


# One ready-built template string per supported task type, keyed by task name.
PROMPT_TEMPLATES = {
    task_name: build_prompt_template(task_name)
    for task_name in (
        "summarization",
        "reasoning",
        "rag",
        "paraphrasing",
        "creative_generation",
    )
}

# 3. Initialize embedding model and text splitter

The embedding model converts text into numeric vectors, while the text
splitter breaks long documents into manageable chunks for retrieval.

In [None]:
# Embedding model; handed to the Chroma store as its embedding function.
embedder = HuggingFaceEmbeddings(model_name=EMBED_MODEL)

# Recursive character splitter configured from the chunking constants.
splitter = RecursiveCharacterTextSplitter(
    chunk_overlap=CHUNK_OVERLAP,
    chunk_size=CHUNK_SIZE,
)

# 4. Load the local language model

Initializes the HuggingFace Seq2Seq model and tokenizer, wraps it in a pipeline for text generation, and tracks energy usage with `OfflineEmissionsTracker`.


In [None]:
# Measure the emissions attributable to loading the tokenizer and model.
tracker_loading = OfflineEmissionsTracker(
    country_iso_code=COUNTRY_ISO_CODE, project_name="model_loading", log_level="error"
)
tracker_loading.start()

try:
    tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL)
    model = AutoModelForSeq2SeqLM.from_pretrained(LLM_MODEL, low_cpu_mem_usage=True)

    # Greedy decoding (do_sample=False) with a repetition penalty.
    # NOTE(review): both max_new_tokens and max_length are passed — confirm
    # which limit the pipeline actually applies.
    pipe = pipeline(
        "text2text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=False,
        repetition_penalty=1.3,
        device=0 if USE_GPU else -1,  # GPU 0 when CUDA is available, else -1
        truncation=True,
        max_length=512,
    )
    llm = HuggingFacePipeline(pipeline=pipe)
finally:
    # Always stop the tracker, even if loading fails, so it is not left
    # running when the cell is re-executed.
    emissions_loading = tracker_loading.stop()

# 5. Load documents from JSON

We read the context and metadata directly from a JSON file.
We also clean metadata and split text into chunks.

In [None]:
def load_documents_from_json(json_path=PROMPTS_FILE, topic="Apollo 11"):
    """Load the source text from a JSON file and split it into chunks.

    The file is expected to hold a ``source_text`` string and an optional
    ``metadata`` object. List and dict metadata values are converted to
    strings, and every chunk receives a copy of the cleaned metadata plus
    ``topic``, ``section`` and (if not already present) ``source`` entries.

    Args:
        json_path: Path to the JSON file.
        topic: Value stored under the ``topic`` metadata key of every chunk.

    Returns:
        The list of split documents, or an empty list when the file is
        missing or contains no source text.
    """
    data_path = Path(json_path)
    if not data_path.exists():
        print(f"JSON file not found at: {json_path}")
        return []

    with open(data_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # `or` also covers an explicit null in the JSON.
    source_text = data.get("source_text") or ""
    raw_metadata = data.get("metadata") or {}

    if not source_text.strip():
        print("No source text found in JSON.")
        return []

    # Build the section label from the ORIGINAL value. Doing this after the
    # list has been stringified would join the characters of its repr.
    sections = raw_metadata.get("sections", ["General"])
    if isinstance(sections, (list, tuple)):
        section_label = ", ".join(str(item) for item in sections)
    else:
        section_label = str(sections)

    # Flatten list/dict values to strings; other values are kept as-is.
    metadata = {
        key: str(value) if isinstance(value, (list, dict)) else value
        for key, value in raw_metadata.items()
    }

    split_docs = splitter.create_documents([source_text])

    for doc in split_docs:
        doc.metadata = metadata.copy()
        doc.metadata["topic"] = topic
        doc.metadata["section"] = section_label
        # The response builder reads a "source" key; without it every
        # source is reported as "Unknown".
        doc.metadata.setdefault("source", str(data_path))

    print(f"Loaded and split {len(split_docs)} chunks from JSON.")
    return split_docs

# 6. Build Chroma vector store

Here we embed the document chunks and save them into a local vector database (Chroma).
This enables fast similarity-based retrieval of relevant context later.

In [None]:
def build_chroma_store(docs, persist_dir=PERSIST_DIR):
    """Embed the given documents into a Chroma store and persist it.

    Args:
        docs: Documents to embed and store.
        persist_dir: Directory handed to Chroma as ``persist_directory``.

    Returns:
        The Chroma vector store.
    """
    # NOTE(review): confirm whether calling this again on an existing
    # persist_dir replaces the stored chunks or adds to them.
    store = Chroma.from_documents(
        documents=docs,
        embedding=embedder,
        persist_directory=persist_dir,
    )
    store.persist()
    return store

# 7. Calling the Load Document Function 

This cell loads the source document (text and metadata) from the JSON file, and
splits it into smaller chunks for embedding.

In [None]:
# Load and chunk the source text using the default JSON path; the result is
# an empty list if the file is missing or has no source text.
documents = load_documents_from_json()

# 8. Calling the Build Chroma Function

This cell builds a Chroma vector database
that stores those embeddings for efficient similarity search.

Once the database is built, it's saved to disk,
so you only need to run this cell once, unless you change or add new data.

Running it again will overwrite the existing database.

In [None]:
# Measure the emissions attributable to embedding the chunks. Configured like
# the notebook's other trackers (own project name, errors-only logging).
tracker_embeddings = OfflineEmissionsTracker(
    country_iso_code=COUNTRY_ISO_CODE, project_name="embeddings", log_level="error"
)
tracker_embeddings.start()

try:
    db = build_chroma_store(documents)
finally:
    # Always stop the tracker, even if building the store fails.
    emissions_embeddings = tracker_embeddings.stop()

print(f"Embeddings creation emissions: {emissions_embeddings:.6f} kg CO2")

# 9. RAG Response Generation and Answer Refinement

Defines functions to generate answers using Retrieval-Augmented Generation (RAG), detect issues in responses, and refine them:

- `detect_answer_issues()`: Checks if an answer is incomplete, repetitive, cut off, or disclaimer-only.  
- `retry_with_better_retrieval()`: Performs additional retrieval when issues are detected.  
- `refine_failed_answer()`: Re-generates answers based on improved context and task-specific instructions.  
- `generate_rag_response()`: Combines retrieval, LLM generation, and recursive refinement to produce a final answer.  
- `ask()`: Simple wrapper to query the system and print the result.


In [None]:
def detect_answer_issues(
    answer, question, tokens_generated=None, max_tokens=MAX_NEW_TOKENS
):
    """Heuristically check a generated answer for common failure modes.

    Checks run in this order and the first match wins:

    * ``"no_info"``: the answer states that no information is available.
    * ``"repetitive"``: the same sentence (case-insensitive) occurs twice.
    * ``"token_limit"``: the answer used nearly all of ``max_tokens``.
    * ``"incomplete"``: an answer longer than 20 characters does not end
      with closing punctuation.
    * ``"cutoff"``: the last sentence is short and starts in lower case.
    * ``"disclaimer_only"``: a short answer containing a disclaimer phrase
      and no substantial non-disclaimer sentence.

    Args:
        answer: The generated answer text.
        question: The original question (currently unused; kept so the
            signature stays stable for callers).
        tokens_generated: Token count of ``answer``, if known.
        max_tokens: Generation budget that ``tokens_generated`` is compared
            against.

    Returns:
        ``(True, issue_label)`` when an issue is found, else ``(False, None)``.
    """
    stripped = answer.strip()
    answer_lower = answer.lower().strip()

    if (
        "no relevant information" in answer_lower
        or "does not provide information" in answer_lower
    ):
        return True, "no_info"

    # Only fragments longer than 10 characters count as sentences.
    sentences = [s.strip() for s in answer.split(".") if len(s.strip()) > 10]

    # Any duplicate makes the set smaller than the list.
    normalized = [sent.lower().strip() for sent in sentences]
    if len(set(normalized)) < len(normalized):
        return True, "repetitive"

    if tokens_generated and tokens_generated >= max_tokens - 5:
        return True, "token_limit"

    # `stripped` guards against a long whitespace-only answer, which
    # previously raised IndexError on the last-character lookup.
    if len(answer) > 20 and stripped:
        if stripped[-1] not in '.!?":)]}':
            return True, "incomplete"

        if len(sentences) > 1:
            last_sentence = sentences[-1].strip()
            if 0 < len(last_sentence) < 20 and last_sentence[0].islower():
                return True, "cutoff"

    disclaimer_phrases = [
        "i'm sorry",
        "i cannot",
        "i don't have",
        "not possible to determine",
        "context does not",
    ]

    if len(answer) < 150 and any(
        phrase in answer_lower for phrase in disclaimer_phrases
    ):
        substantial_sentences = [
            s
            for s in sentences
            if len(s.strip()) > 30
            and not any(phrase in s.lower() for phrase in disclaimer_phrases)
        ]
        if not substantial_sentences:
            return True, "disclaimer_only"

    return False, None


def retry_with_better_retrieval(query_text, task_type, original_context, issue_type):
    """Run a second retrieval tuned to the kind of issue that was detected.

    Length/repetition issues use a small, tightly capped context; "no_info"
    and "disclaimer_only" widen the search and lower the score threshold;
    anything else falls back to the default retrieval settings.

    Args:
        query_text: The question to search for.
        task_type: Unused here; part of the call signature.
        original_context: Unused here; part of the call signature.
        issue_type: Issue label that selects the retrieval settings.

    Returns:
        ``(context_text, scores)``, or ``(None, None)`` when nothing is
        retrieved or the best score is below the threshold.
    """
    # Issues where a shorter, more focused context is wanted.
    wants_focused_context = issue_type in [
        "incomplete",
        "cutoff",
        "token_limit",
        "repetitive",
    ]

    if wants_focused_context:
        k, threshold = 3, RELEVANCE_THRESHOLD
    else:
        widened = {"no_info": (8, 0.15), "disclaimer_only": (7, 0.2)}
        k, threshold = widened.get(issue_type, (TOP_K_RESULTS, RELEVANCE_THRESHOLD))

    hits = db.similarity_search_with_relevance_scores(query_text, k=k)
    if not hits or hits[0][1] < threshold:
        return None, None

    if wants_focused_context:
        selected, char_limit = hits[:3], 600
    else:
        selected, char_limit = hits, 1200

    context_text = "\n\n".join([doc.page_content for doc, _ in selected])[:char_limit]
    return context_text, [score for _, score in hits]


def refine_failed_answer(original_answer, context, query_text, issue_type):
    """Re-generate an answer with a prompt tailored to the detected issue.

    Args:
        original_answer: The answer that was flagged; returned unchanged
            when ``issue_type`` is not one of the handled labels.
        context: Context text; only a leading slice is put in the prompt.
        query_text: The question being answered.
        issue_type: Issue label selecting the context length and the
            instruction sentence.

    Returns:
        The stripped output of the language model, or ``original_answer``
        for an unhandled issue type.
    """
    complete_it = (500, "Provide a concise, complete answer in 2-3 sentences:")
    # issue label -> (context characters to keep, instruction line)
    recipes = {
        "no_info": (
            700,
            "Answer the question using only the information from the context above.",
        ),
        "repetitive": (
            500,
            "Provide a concise answer without repeating information. 2-3 distinct sentences:",
        ),
        "incomplete": complete_it,
        "cutoff": complete_it,
        "token_limit": complete_it,
        "disclaimer_only": (
            700,
            "Answer the question directly using information from the context. Be specific and factual.",
        ),
    }

    recipe = recipes.get(issue_type)
    if recipe is None:
        return original_answer

    context_chars, instruction = recipe
    refine_prompt = (
        f"Context: {context[:context_chars]}\n\n"
        f"Question: {query_text}\n\n"
        f"{instruction}\n\n"
        "Answer:"
    )

    # Over-long prompts are cut to 450 tokens and the question is re-appended.
    token_ids = tokenizer.encode(refine_prompt, truncation=False)
    if len(token_ids) > 450:
        shortened = tokenizer.decode(token_ids[:450], skip_special_tokens=True)
        refine_prompt = shortened + f"\n\nQuestion: {query_text}\n\nAnswer:"

    return llm.invoke(refine_prompt).strip()


def generate_rag_response(
    query_text, task_type, k=TOP_K_RESULTS, threshold=RELEVANCE_THRESHOLD
):
    """Answer a question with retrieval, generation and optional refinement.

    Steps:
      1. Retrieve up to ``k`` chunks with relevance scores.
      2. If the best score is below ``threshold``, return a "no relevant
         information" result. Creative generation uses a threshold of 0.1
         and falls back to the single best chunk instead.
      3. Fill the task's prompt template, truncating over-long prompts.
      4. Generate an answer and, when ``ENABLE_RECURSIVE_EDITING`` is set,
         try up to ``MAX_EDIT_ITERATIONS`` refinements if issues are found.

    Args:
        query_text: The question to answer.
        task_type: Key into ``PROMPT_TEMPLATES``.
        k: Number of chunks to retrieve.
        threshold: Minimum relevance score required of the best chunk.

    Returns:
        A dict with ``answer``, ``sources``, ``task_type``, ``scores``,
        ``iterations``, ``issue_detected`` and ``fixed``. The early
        "no relevant information" result also carries ``should_retry``.

    Raises:
        KeyError: If ``task_type`` has no prompt template.
    """
    # One limit for both the check and the cut. Previously the check used
    # 400 but the cut used 450, so 401-450 token prompts were kept whole
    # and merely got a duplicated Question/Answer suffix.
    max_prompt_tokens = 450

    results = db.similarity_search_with_relevance_scores(query_text, k=k)

    if task_type == "creative_generation":
        threshold = 0.1

    if not results or results[0][1] < threshold:
        if task_type == "creative_generation" and results:
            context_text = results[0][0].page_content
        else:
            return {
                "answer": "No relevant information found.",
                "sources": [],
                "task_type": task_type,
                "scores": [],
                "iterations": 0,
                "issue_detected": "no_info",
                "fixed": False,
                "should_retry": True,
            }
    else:
        context_text = "\n\n".join([doc.page_content for doc, _score in results])

    prompt_template = PromptTemplate.from_template(PROMPT_TEMPLATES[task_type])
    prompt = prompt_template.format(context=context_text, question=query_text)

    tokens = tokenizer.encode(prompt, truncation=False)
    if len(tokens) > max_prompt_tokens:
        truncated_prompt = tokenizer.decode(
            tokens[:max_prompt_tokens], skip_special_tokens=True
        )
        # The cut may have removed the question, so append it again.
        prompt = truncated_prompt + f"\n\nQuestion: {query_text}\n\nAnswer:"
        # Keep context_text in sync with what the model actually sees. The
        # membership test guarantees index 1 exists, so no try/except is needed.
        if "Context:" in truncated_prompt:
            context_text = (
                truncated_prompt.split("Context:")[1].split("Question:")[0].strip()
            )

    answer = llm.invoke(prompt)

    answer_tokens = len(tokenizer.encode(answer, truncation=False))

    sources = [doc.metadata.get("source", "Unknown") for doc, _score in results]

    iterations = 0
    issue_detected = None
    fixed = False

    if ENABLE_RECURSIVE_EDITING:
        has_issues, issue_type = detect_answer_issues(
            answer, query_text, answer_tokens, MAX_NEW_TOKENS
        )

        if has_issues:
            issue_detected = issue_type

            for _ in range(MAX_EDIT_ITERATIONS):
                new_context, new_scores = retry_with_better_retrieval(
                    query_text, task_type, context_text, issue_type
                )

                if new_context:
                    context_text = new_context
                    if new_scores:
                        # NOTE(review): sources are sliced from the ORIGINAL
                        # results because the retry returns only scores —
                        # confirm whether they should reflect the new hits.
                        sources = [
                            doc.metadata.get("source", "Unknown")
                            for doc, _ in results[: len(new_scores)]
                        ]

                iterations += 1
                refined_answer = refine_failed_answer(
                    answer, context_text, query_text, issue_type
                )

                refined_tokens = len(tokenizer.encode(refined_answer, truncation=False))
                still_has_issues, _issue = detect_answer_issues(
                    refined_answer, query_text, refined_tokens, MAX_NEW_TOKENS
                )

                if refined_answer == answer:
                    # No change: further attempts would not help.
                    break

                answer = refined_answer
                if not still_has_issues:
                    fixed = True
                    break

    return {
        "answer": answer,
        "sources": sources,
        "task_type": task_type,
        "scores": [score for _, score in results],
        "iterations": iterations,
        "issue_detected": issue_detected,
        "fixed": fixed,
    }


def ask(question, task_type="rag"):
    """Answer a single question, print the exchange, and return the answer.

    Args:
        question: The question text.
        task_type: Task type forwarded to ``generate_rag_response``.

    Returns:
        The answer string.
    """
    answer = generate_rag_response(question, task_type)["answer"]
    print(f"\nQ: {question}")
    print(f"A: {answer}")
    return answer

# 10. Load evaluation prompts

We load a list of test questions from a JSON file.
Each question is labeled with a category (e.g., summarization, reasoning, or RAG).

In [None]:
# Read with an explicit encoding so the result does not depend on the
# platform default (the document loader reads this same file as UTF-8).
with open(PROMPTS_FILE, "r", encoding="utf-8") as f:
    prompts_data = json.load(f)

# Evaluation prompts to run through the pipeline.
prompts = prompts_data["prompts"]
print(f"Loaded {len(prompts)} evaluation prompts")

# 11. Run Prompts and Track Metrics

Iterates over all prompts, generates RAG responses, retries if necessary, and collects metrics:

- Measures **latency** and **energy usage** using `OfflineEmissionsTracker`.  
- Retries queries with a lower threshold if no relevant information is found and recursive editing is enabled.  
- Tracks **task-level success rates** and response metadata (`iterations`, `issue_detected`, `fixed`).  
- Stores all results in a list for later analysis.


In [None]:
# One record per evaluated prompt.
results = []

# Per-category counters: prompts seen ("total") and prompts answered ("success").
task_metrics = {
    category: {"total": 0, "success": 0}
    for category in (
        "summarization",
        "reasoning",
        "rag",
        "paraphrasing",
        "creative_generation",
    )
}

# Latency bookkeeping: individual timings plus their running sum.
latencies = []
total_latency = 0

# Measure the emissions attributable to answering all evaluation prompts.
tracker_inference = OfflineEmissionsTracker(
    country_iso_code=COUNTRY_ISO_CODE, project_name="inference", log_level="error"
)
tracker_inference.start()

try:
    for p in prompts:
        start_time = time.time()
        result = generate_rag_response(p["prompt"], task_type=p["category"])

        # Second chance for prompts rejected by the relevance threshold:
        # retrieve more chunks and accept a much lower best score.
        if (
            result["answer"] == "No relevant information found."
            and ENABLE_RECURSIVE_EDITING
        ):
            print("  [Retrying with lower threshold.]")
            retry_results = db.similarity_search_with_relevance_scores(
                p["prompt"], k=8
            )

            if len(retry_results) > 0 and retry_results[0][1] > 0.1:
                # Top five chunks, capped at 800 characters.
                retry_context = "\n\n".join(
                    [doc.page_content for doc, _ in retry_results[:5]]
                )[:800]

                retry_prompt = (
                    f"Context: {retry_context}\n\n"
                    f"Question: {p['prompt']}\n\n"
                    "Answer:"
                )

                # Over-long prompts are cut to 450 tokens and the question
                # is re-appended.
                retry_tokens = tokenizer.encode(retry_prompt, truncation=False)
                if len(retry_tokens) > 450:
                    retry_prompt = tokenizer.decode(
                        retry_tokens[:450], skip_special_tokens=True
                    )
                    retry_prompt += f"\n\nQuestion: {p['prompt']}\n\nAnswer:"

                retry_answer = llm.invoke(retry_prompt)

                if retry_answer and retry_answer != "No relevant information found.":
                    result["answer"] = retry_answer
                    result["fixed"] = True
                    result["iterations"] = 1
                    result["issue_detected"] = "no_info"

        end_time = time.time()

        latency = end_time - start_time
        total_latency += latency
        latencies.append(latency)

        # Short suffix describing any refinement that took place.
        edit_info = ""
        if result.get("iterations", 0) > 0:
            status = "Fixed" if result.get("fixed") else "Attempted"
            issue = result.get("issue_detected", "unknown")
            edit_info = f" [{status}: {issue}, {result['iterations']}x]"

        print(f"\nPrompt {p['id']}: {p['prompt']}")
        print(
            f"Answer: {result['answer'][:150]}{'...' if len(result['answer']) > 150 else ''}{edit_info}"
        )
        print(f"Latency: {latency:.3f}s")

        # A prompt counts as a success when any answer other than the
        # "no relevant information" fallback was produced.
        task_metrics[p["category"]]["total"] += 1
        if result["answer"] != "No relevant information found.":
            task_metrics[p["category"]]["success"] += 1

        results.append(
            {
                "id": p["id"],
                "category": p["category"],
                "difficulty": p["difficulty"],
                "prompt": p["prompt"],
                "answer": result["answer"],
                "expected": p.get("expected_answer"),
                "scores": result["scores"],
                "iterations": result.get("iterations", 0),
                "issue_detected": result.get("issue_detected"),
                "fixed": result.get("fixed", False),
                "latency": latency,
            }
        )
finally:
    # Always stop the tracker, even if a prompt fails midway.
    emissions_inference = tracker_inference.stop()

# `_total_energy` is a private tracker attribute, so fall back to 0 when it
# (or its kWh field) is not available.
energy_inference = getattr(
    getattr(tracker_inference, "_total_energy", None), "kWh", 0
)

# 12. Performance and Carbon Emissions

Calculates and prints performance metrics for all prompts:

- **Latency:** total, average, minimum, and maximum per query.  
- **Carbon emissions and energy consumption:** for model loading, embeddings, and inference.  
- Computes total and per-query values for both CO2 emissions and energy usage.


In [None]:
print("PERFORMANCE METRICS")

# Guard the averages and extrema so an empty prompt list reports zeros
# instead of raising ZeroDivisionError / ValueError.
num_queries = len(prompts)
avg_latency = total_latency / num_queries if num_queries else 0.0
min_latency = min(latencies, default=0.0)
max_latency = max(latencies, default=0.0)

print(f"Total latency:       {total_latency:.3f}s")
print(f"Average per query:   {avg_latency:.3f}s")
print(f"Min latency:         {min_latency:.3f}s")
print(f"Max latency:         {max_latency:.3f}s")

print("\n CARBON EMISSIONS & ENERGY")
total_emissions = emissions_loading + emissions_embeddings + emissions_inference
per_query_emissions = emissions_inference / num_queries if num_queries else 0.0

# Energy figures come from a private tracker attribute, so only that lookup
# is guarded; the report falls back to emissions-only output without it.
try:
    energy_loading = tracker_loading._total_energy.kWh
    energy_embeddings = tracker_embeddings._total_energy.kWh
except AttributeError:
    energy_available = False
else:
    energy_available = True

if energy_available:
    energy_inference_val = energy_inference
    total_energy = energy_loading + energy_embeddings + energy_inference_val
    per_query_energy = energy_inference_val / num_queries if num_queries else 0.0

    print(f"Model Loading: {emissions_loading:.6f} kg CO2  |  {energy_loading:.6f} kWh")
    print(
        f"Embeddings:    {emissions_embeddings:.6f} kg CO2  |  {energy_embeddings:.6f} kWh"
    )
    print(
        f"Inference:     {emissions_inference:.6f} kg CO2  |  {energy_inference_val:.6f} kWh"
    )
    print(f"TOTAL:         {total_emissions:.6f} kg CO2  |  {total_energy:.6f} kWh")
    print(
        f"Per query:     {per_query_emissions:.6f} kg CO2  |  {per_query_energy:.6f} kWh"
    )
else:
    print(f"Model Loading: {emissions_loading:.6f} kg CO2")
    print(f"Embeddings:    {emissions_embeddings:.6f} kg CO2")
    print(f"Inference:     {emissions_inference:.6f} kg CO2")
    print(f"TOTAL:         {total_emissions:.6f} kg CO2")
    print(f"Per query:     {per_query_emissions:.6f} kg CO2")