In [None]:
!pip install huggingface_hub
!pip install langchain_community
!pip install sentence_transformers
!pip install faiss-cpu
!pip install bitsandbytes
!pip install torchmetrics --upgrade
!pip install bert-score
!pip install --upgrade datasets fsspec

# Import

In [None]:
import faiss
import numpy as np
import os
import json
import re
from langchain.text_splitter import RecursiveCharacterTextSplitter
from sentence_transformers import CrossEncoder, util, SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch
import time
from tqdm import tqdm

# Models

In [None]:
# Hugging Face access token used to download the Llama checkpoint. It is read
# from the HF_TOKEN environment variable so the secret does not have to be
# stored in the notebook; the original placeholder is kept as the fallback.
HF_TOKEN = os.environ.get("HF_TOKEN", "your_api_key")

# Checkpoint identifier shared by the tokenizer and the model so the two
# cannot drift apart.
MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"

# NOTE(review): trust_remote_code=True lets Python code shipped with the
# checkpoint run locally -- confirm it is actually required for this model.
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    token=HF_TOKEN
)

# 4-bit NF4 weights with double quantisation; computation runs in float16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16
)

# device_map="auto" leaves the placement of the weights to the library.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
    token=HF_TOKEN
)

# Create the PubMedQA test set

In [None]:
def extract_qa_and_joined_contexts(dataset_split):
    """Split PubMedQA-style records into QA pairs and one context string each.

    Args:
        dataset_split: Iterable of records. Each record is indexed with
            "question", "long_answer" and ["context"]["contexts"], the last
            being a sequence of paragraph strings.

    Returns:
        A tuple (qa_pairs, joined_contexts). qa_pairs is a list of
        (question, long_answer) tuples; joined_contexts is a list, in the same
        order, of each record's paragraphs joined with newlines.
    """
    qa_pairs, joined_contexts = [], []

    for record in dataset_split:
        pair = (record["question"], record["long_answer"])
        # Merge all paragraphs of the record into a single context block.
        merged_context = "\n".join(record["context"]["contexts"])

        qa_pairs.append(pair)
        joined_contexts.append(merged_context)

    return qa_pairs, joined_contexts

In [None]:
from datasets import load_dataset

# Load the "pqa_artificial" configuration of PubMedQA from the Hugging Face Hub.
dataset = load_dataset("qiaojin/PubMedQA", "pqa_artificial")
# Only the "train" split is used by the rest of the notebook.
train_split = dataset["train"]

In [None]:
# Build the (question, long_answer) pairs and, in the same order, one joined
# context string per record. `contexts` is the corpus that gets embedded and
# indexed below.
qa_pairs, contexts = extract_qa_and_joined_contexts(train_split)

# Embeddings PubMedQA

# Index FAISS

In [None]:
# Sentence-embedding model (mxbai-embed-large-v1) used for both the documents
# and the queries; embeddings are normalised at encode time, not here.
bert_model = SentenceTransformer('mixedbread-ai/mxbai-embed-large-v1')

In [None]:
# Encode every context into a unit-length embedding.
document_embeddings = bert_model.encode(
    contexts,
    normalize_embeddings=True,      # unit vectors: L2 ranking matches cosine ranking
    show_progress_bar=True,
    convert_to_numpy=True
)

# Create a flat (exact search) FAISS index with the L2 metric, sized to the
# embedding width.
dimension = document_embeddings.shape[1]
index = faiss.IndexFlatL2(dimension)

# Add the embeddings; row i of the index corresponds to contexts[i].
index.add(document_embeddings)

# Persist the index so it can be reloaded without re-encoding the corpus.
faiss.write_index(index, "faiss_index_flatl2_mxbai_pubmedqa.bin")

# Load the FAISS index

In [None]:
# Reload the FAISS index written by the previous cell. `index` is the
# notebook-level handle that the cache classes receive as `faiss_index`.
index = faiss.read_index("faiss_index_flatl2_mxbai_pubmedqa.bin")

# Cache construction at levels L1, L2, L3



In [None]:
class HierarchicalFaissCache:
    """Three-level (L1/L2/L3) cache of chunk indices in front of a FAISS index.

    Each level is a list of (score, idx) tuples, where idx is a position in
    both the FAISS index and `chunked_data`. A query is answered from the first
    level that yields a hit and falls back to a FAISS search otherwise.
    Whatever is returned is moved to the front of L1; entries that overflow a
    level are demoted to the next one.

    Attributes:
        THRESHOLD: Minimum dot-product similarity (cosine for unit vectors)
            a chunk must reach to be returned.
    """

    THRESHOLD = 0.60  # Minimum cosine similarity for a chunk to count as a match

    def __init__(self, faiss_index, chunked_data, embedding_model, top_k=10):
        """Store the collaborators and create the three empty levels.

        Args:
            faiss_index: Index object providing `search` and `reconstruct`.
            chunked_data: Sequence of texts; position i matches index row i.
            embedding_model: Object providing `encode(texts, normalize_embeddings=...)`.
            top_k: Maximum number of chunks returned per query.
        """
        self.index = faiss_index
        self.chunked_data = chunked_data
        self.embedding_model = embedding_model
        self.top_k = top_k

        self.L1_capacity = 25
        self.L2_capacity = 50
        self.L3_capacity = 100

        self.L1 = []  # [(score, idx)], most recently used first
        self.L2 = []  # [(score, idx)], sorted by score after each demotion
        self.L3 = []  # [(score, idx)], sorted by score after each demotion

    def _get_query_embedding(self, query_text):
        """Return the normalised float32 embedding of the query, shape (1, dim)."""
        encoded = self.embedding_model.encode([query_text], normalize_embeddings=True)
        return encoded.astype('float32')

    def _similarity(self, query_tensor, chunk_idx):
        """Dot product between the query vector and the indexed vector `chunk_idx`."""
        chunk_embedding = self.index.reconstruct(int(chunk_idx))  # numpy array
        return float(np.dot(query_tensor, chunk_embedding))

    def _rank_candidates(self, query_vector, candidate_idxs):
        """Score candidates and keep the best `top_k` above THRESHOLD.

        Candidates whose text was already accepted for this query are skipped,
        so the result never contains two chunks with identical text.
        """
        scored = []
        seen_texts = set()  # texts already accepted for this query
        for idx in candidate_idxs:
            sim = self._similarity(query_vector, idx)
            if sim < self.THRESHOLD:
                continue
            text = self.chunked_data[idx]
            if text in seen_texts:
                continue
            seen_texts.add(text)
            scored.append((sim, idx))
        scored.sort(key=lambda pair: pair[0], reverse=True)
        return scored[:self.top_k]

    def _retrieve_from_level(self, level, query_vector):
        """Return the matching (similarity, idx) entries of one cache level."""
        return self._rank_candidates(query_vector, (idx for _, idx in level))

    def _search_faiss(self, query_tensor, query_embedding):
        """Search the FAISS index and return matching (similarity, idx) entries.

        The distances returned by the index are ignored; similarities are
        recomputed with `_similarity` so they are comparable with cache hits.
        """
        _, neighbours = self.index.search(query_embedding, self.top_k)
        # -1 marks an empty result slot.
        valid = (idx for idx in neighbours[0] if idx != -1)
        return self._rank_candidates(query_tensor, valid)

    def retrieve_from_cache(self, query_text):
        """Retrieve chunks for a query, preferring the cache over FAISS.

        Returns:
            A tuple (texts, hit_level, retrieval_time): the retrieved chunk
            texts, one of "L1", "L2", "L3" or "FAISS", and the elapsed time in
            seconds rounded to 4 decimals.
        """
        # perf_counter is a monotonic, high-resolution clock meant for durations.
        start_time = time.perf_counter()

        self.demote_cache_levels()
        query_embedding = self._get_query_embedding(query_text)
        query_tensor = query_embedding[0]

        # L1 is a hit only when it can serve a full top_k; L2 and L3 are hits
        # as soon as they return anything.
        results, hit_level = self._retrieve_from_level(self.L1, query_tensor), "L1"
        if len(results) < self.top_k:
            results, hit_level = self._retrieve_from_level(self.L2, query_tensor), "L2"
            if not results:
                results, hit_level = self._retrieve_from_level(self.L3, query_tensor), "L3"
            if not results:
                results, hit_level = self._search_faiss(query_tensor, query_embedding), "FAISS"

        # On an L1 hit this only reorders L1, so the demotion pass inside
        # _promote_to_L1 has nothing to move.
        self._promote_to_L1(results)
        retrieval_time = round(time.perf_counter() - start_time, 4)
        return [self.chunked_data[i] for _, i in results], hit_level, retrieval_time

    def _promote_to_L1(self, results):
        """Move each result to the front of L1, then rebalance the levels.

        Entries are not removed from L2/L3 when promoted.
        """
        for score, idx in results:
            self.L1 = [(s, i) for (s, i) in self.L1 if i != idx]
            self.L1.insert(0, (score, idx))
        self.demote_cache_levels()

    def demote_cache_levels(self):
        """Push overflow from L1 into L2 and from L2 into L3; truncate L3."""
        # L1 -> L2: the tail of L1 (least recently promoted) is demoted.
        overflow = max(0, len(self.L1) - self.L1_capacity)
        if overflow:
            demoted = self.L1[-overflow:]
            self.L1 = self.L1[:-overflow]
            demoted_idxs = {idx for _, idx in demoted}
            # Drop stale copies so an idx appears at most once in L2.
            self.L2 = [(s, i) for (s, i) in self.L2 if i not in demoted_idxs]
            self.L2.extend(demoted)
            self.L2 = sorted(self.L2, key=lambda x: x[0], reverse=True)

        # L2 -> L3: the tail of L2 is demoted; it is the lowest-scored part
        # whenever L2 was just re-sorted above.
        overflow_L2 = max(0, len(self.L2) - self.L2_capacity)
        if overflow_L2:
            demoted_L3 = self.L2[-overflow_L2:]
            self.L2 = self.L2[:-overflow_L2]
            demoted_idxs = {idx for _, idx in demoted_L3}
            self.L3 = [(s, i) for (s, i) in self.L3 if i not in demoted_idxs]
            self.L3.extend(demoted_L3)
            self.L3 = sorted(self.L3, key=lambda x: x[0], reverse=True)[:self.L3_capacity]

# Cache L1

In [None]:
class SingleLevelFaissCache:
    """Single-level (L1 only) cache of chunk indices in front of a FAISS index.

    L1 is a list of (score, idx) tuples, most recently used first, where idx is
    a position in both the FAISS index and `chunked_data`. A query is served
    from L1 only when L1 yields at least `top_k` matches; otherwise the FAISS
    index is searched. Returned entries are moved to the front of L1, which is
    then truncated to `L1_capacity`.

    Attributes:
        THRESHOLD: Minimum dot-product similarity (cosine for unit vectors)
            a chunk must reach to be returned.
    """

    THRESHOLD = 0.50  # Minimum cosine similarity for a chunk to count as a match

    def __init__(self, faiss_index, chunked_data, embedding_model, top_k=10):
        """Store the collaborators and create an empty L1.

        Args:
            faiss_index: Index object providing `search` and `reconstruct`.
            chunked_data: Sequence of texts; position i matches index row i.
            embedding_model: Object providing `encode(texts, normalize_embeddings=...)`.
            top_k: Maximum number of chunks returned per query.
        """
        self.index = faiss_index
        self.chunked_data = chunked_data
        self.embedding_model = embedding_model
        self.top_k = top_k

        self.L1_capacity = 25
        self.L1 = []  # [(score, idx)], most recently used first

    def _get_query_embedding(self, query_text):
        """Return the normalised float32 embedding of the query, shape (1, dim)."""
        encoded = self.embedding_model.encode([query_text], normalize_embeddings=True)
        return encoded.astype('float32')

    def _similarity(self, query_tensor, chunk_idx):
        """Dot product between the query vector and the indexed vector `chunk_idx`."""
        chunk_embedding = self.index.reconstruct(int(chunk_idx))
        return float(np.dot(query_tensor, chunk_embedding))

    def _rank_candidates(self, query_tensor, candidate_idxs):
        """Score candidates and keep the best `top_k` above THRESHOLD.

        Candidates whose text was already accepted for this query are skipped,
        so the result never contains two chunks with identical text.
        """
        scored = []
        seen_texts = set()
        for idx in candidate_idxs:
            sim = self._similarity(query_tensor, idx)
            if sim < self.THRESHOLD:
                continue
            text = self.chunked_data[idx]
            if text in seen_texts:
                continue
            seen_texts.add(text)
            scored.append((sim, idx))
        scored.sort(key=lambda pair: pair[0], reverse=True)
        return scored[:self.top_k]

    def _retrieve_from_L1(self, query_tensor):
        """Return the matching (similarity, idx) entries currently held in L1."""
        return self._rank_candidates(query_tensor, (idx for _, idx in self.L1))

    def _search_faiss(self, query_tensor, query_embedding):
        """Search the FAISS index and return matching (similarity, idx) entries.

        The distances returned by the index are ignored; similarities are
        recomputed with `_similarity` so they are comparable with cache hits.
        """
        _, neighbours = self.index.search(query_embedding, self.top_k)
        # -1 marks an empty result slot.
        valid = (idx for idx in neighbours[0] if idx != -1)
        return self._rank_candidates(query_tensor, valid)

    def retrieve_from_cache(self, query_text):
        """Retrieve chunks for a query, preferring L1 over FAISS.

        Returns:
            A tuple (texts, hit_level, retrieval_time): the retrieved chunk
            texts, "L1" or "FAISS", and the elapsed time in seconds rounded to
            4 decimals.
        """
        # perf_counter is a monotonic, high-resolution clock meant for durations.
        start_time = time.perf_counter()
        query_embedding = self._get_query_embedding(query_text)
        query_tensor = query_embedding[0]

        # L1 is a hit only when it can serve a full top_k.
        results, hit_level = self._retrieve_from_L1(query_tensor), "L1"
        if len(results) < self.top_k:
            results, hit_level = self._search_faiss(query_tensor, query_embedding), "FAISS"

        self._update_L1(results)
        retrieval_time = round(time.perf_counter() - start_time, 4)
        return [self.chunked_data[i] for _, i in results], hit_level, retrieval_time

    def _update_L1(self, results):
        """Move each result to the front of L1 and evict beyond `L1_capacity`."""
        for score, idx in results:
            self.L1 = [(s, i) for (s, i) in self.L1 if i != idx]
            self.L1.insert(0, (score, idx))
        self.L1 = self.L1[:self.L1_capacity]

# Reranker

In [None]:
# Cross-encoder used by rerank_with_batching to re-score (query, chunk) pairs.
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

In [None]:
def rerank_with_batching(query_text, top_chunks, batch_size=8, top_n=10, model=None):
    """Re-order retrieved chunks by cross-encoder relevance to the query.

    Args:
        query_text: The user query.
        top_chunks: Sequence of candidate chunk texts.
        batch_size: Batch size forwarded to the scorer's `predict`.
        top_n: Maximum number of chunks to return (10, the previously
            hard-coded value, by default).
        model: Object providing `predict(pairs, batch_size=...)`. When None,
            the notebook-level `reranker` is used.

    Returns:
        Up to `top_n` chunks, highest score first; ties keep their input
        order. An empty input yields an empty list without calling the scorer.
    """
    if not top_chunks:
        return []

    scorer = reranker if model is None else model
    pairs = [(query_text, chunk) for chunk in top_chunks]

    # Batching is delegated to the scorer through `batch_size`.
    scores = scorer.predict(pairs, batch_size=batch_size)

    ranked = sorted(zip(scores, top_chunks), key=lambda pair: pair[0], reverse=True)
    return [chunk for _, chunk in ranked[:top_n]]

# Experiment with HieRAG 3 levels caching

In [None]:
# Three-level cache for the first experiment. top_k is the maximum number of
# chunks returned per query and also the size L1 must reach to count as a hit.
cache_manager = HierarchicalFaissCache(faiss_index=index, chunked_data=contexts, embedding_model=bert_model, top_k=20)

In [None]:
output_file = "your_output_file.ndjson"

questions = [qa[0] for qa in qa_pairs]
answer = [qa[1] for qa in qa_pairs]

# Only the first 1000 questions are evaluated; slice once and reuse.
eval_questions = questions[:1000]

# NOTE: the file is opened in append mode, so re-running this cell adds records
# to an existing file instead of replacing it.
with open(output_file, "a", encoding="utf-8") as f:
    # The counter is named q_idx (not `index`) so that it does not overwrite
    # the notebook-level FAISS `index`, which later cells still pass to the
    # cache classes.
    for q_idx, question in tqdm(enumerate(eval_questions), total=len(eval_questions), desc="Processing questions"):
        t1 = time.time()
        query = question
        ground_truth = answer[q_idx]

        # --- Retrieval (cache lookup or FAISS search, plus context assembly) ---
        t1_retrieval = time.time()
        top_chunks, hit_level, retrieval_time = cache_manager.retrieve_from_cache(query)

        # Uncomment this set of lines if you want HieRAG ++
        #top_chunks_rerank = rerank_with_batching(query, top_chunks, batch_size=2)

        #search_results = "\n".join(
        #    [chunk for chunk in top_chunks_rerank]
        #)

        search_results = "\n".join(top_chunks)

        t2_retrieval = time.time()

        prompt = f"""You are a biomedical assistant.
Analyze the context from medical literature and provide a well-reasoned, informative answer to the following research question.
Your answer should reflect the current scientific understanding, using the context information.
Context information is below.

{search_results}

Now answer this question:

Question: {query}

Answer:"""

        # --- Generation ---
        t1_answer = time.time()

        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        output = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=50,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id
        )

        # The decoded text contains the prompt followed by the continuation.
        response = tokenizer.decode(output[0], skip_special_tokens=True)
        t2_answer = time.time()

        # Keep the first line after the last "Answer:" marker.
        predicted_answer = (
            response.split("Answer:")[-1].strip().split("\n")[0].strip(" .")
        )

        t2 = time.time()

        record = {
            "prompts": query,
            "ground_truth": ground_truth,
            "responses": predicted_answer,
            "rag_time": round(t2 - t1, 3),
            "retrieval_time": round(t2_retrieval - t1_retrieval, 3),
            "answer_time": round(t2_answer - t1_answer, 3),
            "cache_hit_level": hit_level,
            "cache_L1_size": len(cache_manager.L1),
            "cache_L2_size": len(cache_manager.L2),
            "cache_L3_size": len(cache_manager.L3)
        }
        # One JSON object per line (NDJSON).
        f.write(json.dumps(record, ensure_ascii=False) + "\n")

# Experiment HieRAG 1 level caching

In [None]:
# Single-level cache for the second experiment; with top_k=1 one chunk is
# retrieved per query.
# NOTE(review): `index` must be the FAISS index object here, not an integer
# loop counter -- confirm no earlier cell has rebound the name.
cache_manager = SingleLevelFaissCache(faiss_index=index, chunked_data=contexts, embedding_model=bert_model, top_k=1)

In [None]:
# Save in your output file
# NOTE(review): the evaluation cells read "your_output_file.ndjson"; point them
# at this file (or change this path) before evaluating this experiment.
output_file = "you_file_path.json"

# Questions and reference answers come from qa_pairs, as in the three-level
# experiment. Iterating the DatasetDict `dataset` would only yield its split
# names, so it is deliberately left untouched.
questions = [qa[0] for qa in qa_pairs]
answer = [qa[1] for qa in qa_pairs]

# Only the first 500 questions are evaluated; slice once and reuse.
eval_questions = questions[:500]

# NOTE: the file is opened in append mode, so re-running this cell adds records
# to an existing file instead of replacing it.
with open(output_file, "a", encoding="utf-8") as f:
    # The counter is named q_idx (not `index`) so that it does not overwrite
    # the notebook-level FAISS `index`.
    for q_idx, question in tqdm(enumerate(eval_questions), total=len(eval_questions), desc="Processing questions"):
        t1 = time.time()
        query = question
        ground_truth = answer[q_idx]

        # --- Retrieval: the timer starts before the cache lookup so that
        # "retrieval_time" actually includes it ---
        t1_retrieval = time.time()
        top_chunks, hit_level, retrieval_time = cache_manager.retrieve_from_cache(query)

        # Uncomment this set of lines if you want HieRAG ++
        #top_chunks_rerank = rerank_with_batching(query, top_chunks, batch_size=2)

        #search_results = "\n".join(
        #    [chunk for chunk in top_chunks_rerank]
        #)

        search_results = "\n".join(top_chunks)

        t2_retrieval = time.time()

        prompt = f"""You are a biomedical assistant.
Analyze the context from medical literature and provide a well-reasoned, informative answer to the following research question.
Your answer should reflect the current scientific understanding, using the context information.
Context information is below.

{search_results}

Now answer this question:

Question: {query}

Answer:"""

        # --- Generation ---
        t1_answer = time.time()

        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        # NOTE(review): only 7 new tokens are generated here, versus 50 in the
        # three-level experiment -- confirm this is intended for long answers.
        output = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=7,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id
        )

        # The decoded text contains the prompt followed by the continuation.
        response = tokenizer.decode(output[0], skip_special_tokens=True)
        t2_answer = time.time()

        # Keep the text after the last "Answer:" marker, cut at the first
        # "(", "[Note:" or newline.
        predicted_answer = (
            response.split("Answer:")[-1]
            .split("(")[0]
            .split("[Note:")[0]
            .split("\n")[0]
            .strip(" .")
        )

        t2 = time.time()

        record = {
            "prompts": query,
            "ground_truth": ground_truth,
            "responses": predicted_answer,
            "rag_time": round(t2 - t1, 3),
            "retrieval_time": round(t2_retrieval - t1_retrieval, 3),
            "answer_time": round(t2_answer - t1_answer, 3),
            "cache_hit_level": hit_level,
            "cache_L1_size": len(cache_manager.L1)
        }
        # One JSON object per line (NDJSON).
        f.write(json.dumps(record, ensure_ascii=False) + "\n")

# BERTScore evaluation

In [None]:
from bert_score import score

# Load the generated responses: one JSON record per line.
path = "your_output_file.ndjson"
with open(path, "r", encoding="utf-8") as f:
    data = [json.loads(line) for line in f]

# Reference answers and model answers, aligned by position.
references = [entry["ground_truth"] for entry in data]
candidates = [entry["responses"] for entry in data]

# Compute BERTScore; P, R and F1 are reduced with .mean() below.
P, R, F1 = score(candidates, references, lang="en", model_type="bert-large-uncased")

# Report the mean of each metric.
print(f"BERTScore Precision: {P.mean().item():.4f}")
print(f"BERTScore Recall:    {R.mean().item():.4f}")
print(f"BERTScore F1:        {F1.mean().item():.4f}")

# Response time evaluation

In [None]:
def calculate_average_times(file_path):
    """Compute the mean RAG, retrieval and answer times stored in an NDJSON file.

    Args:
        file_path: Path to a UTF-8 file with one JSON object per line. Each
            object must contain the numeric keys "rag_time", "retrieval_time"
            and "answer_time". Blank lines are ignored.

    Returns:
        A tuple (avg_rag_time, avg_retrieval_time, avg_answer_time).

    Raises:
        ValueError: If the file contains no records.
    """
    total_rag_time = 0
    total_retrieval_time = 0
    total_answer_time = 0
    num_records = 0

    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            # Tolerate blank lines (e.g. left by manual edits).
            if not line.strip():
                continue
            record = json.loads(line)
            total_rag_time += record["rag_time"]
            total_retrieval_time += record["retrieval_time"]
            total_answer_time += record["answer_time"]
            num_records += 1

    # Fail with a clear message instead of a ZeroDivisionError.
    if num_records == 0:
        raise ValueError(f"No records found in {file_path!r}; cannot compute averages.")

    return (
        total_rag_time / num_records,
        total_retrieval_time / num_records,
        total_answer_time / num_records,
    )

# Path of the NDJSON results file. It uses the same name the experiment cell
# writes and the other evaluation cells read ("your_output_file.ndjson").
file_path = "your_output_file.ndjson"

# Compute and print the average timings (values are in seconds).
avg_rag_time, avg_retrieval_time, avg_answer_time = calculate_average_times(file_path)

print(f"Average RAG Time: {avg_rag_time:.3f}")
print(f"Average Retrieval Time: {avg_retrieval_time:.3f}")
print(f"Average Answer Time: {avg_answer_time:.3f}")

# ROUGE, BLEU and COMET

In [None]:
!pip install unbabel-comet rouge-score comet

In [None]:
import json
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer
from comet import download_model, load_from_checkpoint
import torch

# Path of the NDJSON results file
file_path = 'your_output_file.ndjson'

# Load the records. The file is written as UTF-8 with ensure_ascii=False, so it
# must be read back as UTF-8 regardless of the platform default encoding.
with open(file_path, 'r', encoding='utf-8') as f:
    data = [json.loads(line) for line in f if line.strip()]

# The averages below divide by the number of records.
if not data:
    raise ValueError(f"No records found in {file_path!r}; nothing to evaluate.")

# Initialise metrics
smoothie = SmoothingFunction().method4
rouge = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)

bleu_scores = []
rouge1_scores = []
rouge2_scores = []
rougeL_scores = []

# Texts collected for COMET (source = question, mt = prediction, ref = ground truth)
sources = []
predictions = []
references = []

for entry in data:
    pred = entry['responses'].strip()
    ref = entry['ground_truth'].strip()
    predictions.append(pred)
    references.append(ref)
    sources.append(entry.get('prompts', ''))

    # BLEU on whitespace tokens, with smoothing
    bleu = sentence_bleu([ref.split()], pred.split(), smoothing_function=smoothie)
    bleu_scores.append(bleu)

    # ROUGE F-measures (reference first, prediction second)
    scores = rouge.score(ref, pred)
    rouge1_scores.append(scores['rouge1'].fmeasure)
    rouge2_scores.append(scores['rouge2'].fmeasure)
    rougeL_scores.append(scores['rougeL'].fmeasure)

# Mean BLEU and ROUGE
print(f"Average BLEU: {sum(bleu_scores)/len(bleu_scores):.4f}")
print(f"Average ROUGE-1: {sum(rouge1_scores)/len(rouge1_scores):.4f}")
print(f"Average ROUGE-2: {sum(rouge2_scores)/len(rouge2_scores):.4f}")
print(f"Average ROUGE-L: {sum(rougeL_scores)/len(rougeL_scores):.4f}")

# COMET
print("Loading COMET model...")
model_path = download_model("Unbabel/wmt22-comet-da")
comet_model = load_from_checkpoint(model_path)
comet_data = [{"src": s, "mt": p, "ref": r} for s, p, r in zip(sources, predictions, references)]
# Use one GPU when available, otherwise run on CPU.
comet_scores = comet_model.predict(comet_data, batch_size=8, gpus=1 if torch.cuda.is_available() else 0)

comet_scores_list = comet_scores.scores

print(f"Average COMET: {sum(comet_scores_list)/len(comet_scores_list):.4f}")