## Imports/Installations

In [None]:
%pip install sentence-transformers
%pip install beir
%pip install groq
%pip install ujson
%pip install textstat
%pip install openai



In [None]:
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import io
import os
import time
import pickle
import torch
from tqdm import tqdm
import numpy as np
from google.colab import drive
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.model_selection import KFold
import random
import requests
import json
from groq import Groq
from beir.datasets.data_loader import GenericDataLoader
import torch.nn.functional as F
import textstat
import matplotlib.pyplot as plt
from openai import OpenAI



import sys
import ujson as json
import re
import string
from collections import Counter
from collections import defaultdict


# Mount Google Drive at /content/drive so the Drive paths configured in this notebook are readable.
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


## Fill in this cell with your data paths

In [None]:
# Enter your data paths here
# NOTE(review): most paths below use "My Drive" while full_wiki_path and
# hotpot_prompt_path use "MyDrive"; confirm both spellings resolve on your mount.

# Dataset folder (presumably the BEIR-format HotpotQA directory passed to load_data — confirm).
data_path = "/content/drive/My Drive/Imperial/Dissertation/datasets/hotpotqa"
# Destination/source for the precomputed corpus embedding matrix and its id array.
corpus_embeddings_path = "/content/drive/My Drive/Imperial/Dissertation/hotpot_corpus_embeddings.npy"
corpus_ids_path = "/content/drive/My Drive/Imperial/Dissertation/hotpot_corpus_ids.npy"
# NOTE(review): load_embeddings reads ./query_embeddings_*.npy from the working
# directory rather than this path in the code visible here — confirm which is intended.
query_embedding_path = "/content/drive/My Drive/Imperial/Dissertation/embeddings/query_embeddings.npy"
# JSONL of queries; map_jsons reads "_id" and metadata["answer"] from each line.
gold_truth_path = "/content/drive/My Drive/Imperial/Dissertation/datasets/hotpotqa/queries.jsonl"
# Pickled classifier loaded by load_classifier_model (only load trusted pickles).
classifier_model_path = "/content/drive/My Drive/Imperial/Dissertation/classifier_model.pkl"
# HotpotQA dev full-wiki JSON (presumably the input for the overlap/type sampling helpers).
full_wiki_path = "/content/drive/MyDrive/Imperial/Dissertation/hotpot_dev_fullwiki_v1.json"
# Prompt template read by generate_answer; formatted with {search_results} and {question}.
hotpot_prompt_path = "/content/drive/MyDrive/Imperial/Dissertation/hotpotqa_prompt.txt"

## Evaluation functions

In [None]:
# Compute device for the classifier (used by load_classifier_model); CPU when CUDA is unavailable.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def precision_at_k(ranking_scores, relevant_docs, k):
    """Return the fraction of the top-*k* ranked documents that are relevant.

    Args:
        ranking_scores: Sequence of ``(doc_id, score)`` pairs, best first.
        relevant_docs: Set of relevant document ids.
        k: Cut-off rank. The denominator is always *k*, even when fewer than
            *k* documents were ranked.
    """
    top_ids = {doc_id for doc_id, _ in ranking_scores[:k]}
    hits = len(top_ids & relevant_docs)
    return hits / k

def recall_at_k(ranking_scores, relevant_docs, k):
    """Return the share of *relevant_docs* that appear among the top-*k* results.

    Args:
        ranking_scores: Sequence of ``(doc_id, score)`` pairs, best first.
        relevant_docs: Set of relevant document ids.
        k: Cut-off rank.

    Returns:
        ``hits / len(relevant_docs)``, or ``0`` when there are no relevant docs.
    """
    top_ids = {doc_id for doc_id, _ in ranking_scores[:k]}
    hits = len(top_ids & relevant_docs)
    if not relevant_docs:
        return 0
    return hits / len(relevant_docs)

def ndcg_at_k(ranking_scores, relevant_docs, k):
    """Binary-relevance nDCG over the top-*k* of *ranking_scores*.

    A relevant document at 0-based rank ``r`` contributes ``1 / log2(r + 2)``.
    The ideal DCG assumes ``min(len(relevant_docs), k)`` relevant documents at
    the top. Returns ``0.0`` when that ideal is zero (no relevant docs, or k <= 0).
    """
    ideal_hits = min(len(relevant_docs), k)
    idcg = sum([1.0 / np.log2(rank + 2) for rank in range(ideal_hits)])
    if idcg == 0:
        return 0.0

    dcg = sum(
        1.0 / np.log2(rank + 2)
        for rank, (doc_id, _) in enumerate(ranking_scores[:k])
        if doc_id in relevant_docs
    )
    return dcg / idcg

def load_classifier_model():
    """Load the pickled question classifier and a ``bert-base-uncased`` tokenizer.

    The classifier is unpickled from the module-level ``classifier_model_path``
    and moved to the module-level ``device``.
    NOTE: ``pickle.load`` can execute arbitrary code; only point
    ``classifier_model_path`` at files you created yourself.

    Returns:
        ``(classifier_model, tokenizer)``
    """
    print(f'Using device: {device}')

    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    with open(classifier_model_path, 'rb') as handle:
        model = pickle.load(handle)

    model.to(device)
    return model, tokenizer


def create_corpus_embeddings(corpus, model):
    """Encode *corpus* with *model* and save the embeddings and ids as ``.npy`` files.

    The output locations come from the module-level ``corpus_embeddings_path``
    and ``corpus_ids_path`` settings, so changing the data-paths cell actually
    changes where files are written. Previously the same paths were duplicated
    here as literals.

    Args:
        corpus: Mapping of doc id -> document dict (must contain ``'text'``).
        model: Encoder exposing ``encode`` (as used by ``encode_corpus``).
    """
    corpus_ids, corpus_embeddings = encode_corpus(corpus, model)
    # encode_corpus already returns a CPU tensor; .cpu() is a cheap no-op safeguard.
    np.save(corpus_embeddings_path, corpus_embeddings.cpu().numpy())
    np.save(corpus_ids_path, np.array(corpus_ids))


def create_query_embeddings(queries, model):
    """Encode *queries* and persist them under the ``query_embeddings`` file prefix."""
    ids, embeddings = encode_queries(queries, model)
    save_embeddings(ids, embeddings, "query_embeddings")

def create_embeddings(corpus, queries, model):
    """Encode both corpus and queries, then save each under its own file prefix.

    Both encodings run before anything is written, corpus first.
    """
    encoded = {
        "corpus_embeddings": encode_corpus(corpus, model),
        "query_embeddings": encode_queries(queries, model),
    }
    for prefix, (ids, embeddings) in encoded.items():
        save_embeddings(ids, embeddings, prefix)

def load_embedding_model(model_name='all-MiniLM-L6-v2'):
    """Instantiate and return a ``SentenceTransformer`` for *model_name*."""
    encoder = SentenceTransformer(model_name)
    return encoder

def load_data(data_path, split):
    """Load one *split* of a BEIR-format dataset stored under *data_path*.

    Returns whatever ``GenericDataLoader.load`` returns (in BEIR this is
    ``(corpus, queries, qrels)``).
    """
    loader = GenericDataLoader(data_folder=data_path)
    return loader.load(split=split)

def calculate_linguistic_complexity(query):
    """Score *query* as the mean of two textstat readability measures.

    The two measures are the Flesch-Kincaid grade level and the Dale-Chall
    readability score. They are averaged as-is, without rescaling.
    """
    readability = (
        textstat.flesch_kincaid_grade(query),
        textstat.dale_chall_readability_score(query),
    )
    return sum(readability) / 2

def encode_corpus(corpus, model, batch_size=16):
    """Encode every document's ``'text'`` in batches and return ``(ids, embeddings)``.

    Each batch is moved to the CPU as soon as it is encoded, and the CUDA cache
    is cleared, to keep GPU memory bounded. The ids follow the insertion order
    of *corpus*.
    """
    ids = list(corpus.keys())
    texts = [document['text'] for document in corpus.values()]

    chunks = []
    for start in tqdm(range(0, len(texts), batch_size), desc="Encoding Corpus"):
        chunk = model.encode(texts[start:start + batch_size], convert_to_tensor=True)
        chunks.append(chunk.cpu())
        torch.cuda.empty_cache()

    return ids, torch.cat(chunks, dim=0)

def encode_queries(queries, model, batch_size=32):
    """Encode query strings in batches and return ``(query_ids, embeddings)``.

    *queries* maps query id -> text. Unlike ``encode_corpus``, the batches are
    left on whatever device ``model.encode`` produced them on.
    """
    ids = list(queries.keys())
    texts = list(queries.values())
    batches = [
        model.encode(texts[start:start + batch_size], convert_to_tensor=True)
        for start in tqdm(range(0, len(texts), batch_size), desc="Encoding Queries")
    ]
    return ids, torch.cat(batches, dim=0)

def save_embeddings(ids, embeddings, file_prefix):
    """Write ids and embeddings to two ``.npy`` files in the working directory.

    The files are ``./{file_prefix}_ids.npy`` and ``./{file_prefix}_embeddings.npy``.
    *embeddings* is a torch tensor; it is copied to the CPU before saving.
    """
    ids_file = f"./{file_prefix}_ids.npy"
    embeddings_file = f"./{file_prefix}_embeddings.npy"
    np.save(ids_file, np.array(ids))
    np.save(embeddings_file, embeddings.cpu().numpy())

def load_embeddings(file_prefix):
    """Read back the files written by ``save_embeddings`` for *file_prefix*.

    Returns:
        ``(ids, embeddings)``: the ids as a Python list and the embedding
        matrix as a torch tensor.
    """
    id_array = np.load(f"./{file_prefix}_ids.npy", allow_pickle=True)
    matrix = np.load(f"./{file_prefix}_embeddings.npy")
    return id_array.tolist(), torch.tensor(matrix)

def load_corpus_embeddings_for_evaluation():
    """Load the ``corpus_embeddings`` files from the working directory, printing progress."""
    print("Loading embeddings...")
    ids, matrix = load_embeddings("corpus_embeddings")
    print("Embeddings loaded.")
    return ids, matrix

def load_query_embeddings_for_evaluation():
    """Load the ``query_embeddings`` files from the working directory, printing progress."""
    print("Loading embeddings...")
    ids, matrix = load_embeddings("query_embeddings")
    print("Embeddings loaded.")
    return ids, matrix

def save_results(file_path, data):
    """Serialize *data* with pickle to *file_path*, replacing any existing file."""
    payload = pickle.dumps(data)
    with open(file_path, 'wb') as sink:
        sink.write(payload)

def load_results(file_path):
    """Unpickle and return the object stored at *file_path*.

    NOTE: unpickling can execute arbitrary code; only load files this notebook wrote.
    """
    with open(file_path, 'rb') as source:
        raw = source.read()
    return pickle.loads(raw)

def retrieve(corpus_ids, corpus_embeddings, query_ids, query_embeddings, top_k):
    """Dense retrieval: rank the corpus by dot product for every query.

    The i-th row of *query_embeddings* is attributed to ``query_ids[i]``.

    Returns:
        dict mapping query id -> list of up to *top_k* ``(corpus_id, score)``
        pairs, highest score first.
    """
    corpus_t = corpus_embeddings.T
    results = {}
    progress = tqdm(enumerate(query_embeddings), desc="Retrieving Documents", total=len(query_embeddings))
    for position, q_vec in progress:
        scores = torch.matmul(q_vec, corpus_t)
        best = torch.argsort(scores, descending=True)[:top_k]
        results[query_ids[position]] = [(corpus_ids[idx], scores[idx].item()) for idx in best]
    return results

def retrieve_single_query_complex(corpus_ids, corpus_embeddings, query_id, query_embedding, top_k, similarity_threshold=0.5):
    """Return up to *top_k* ``(corpus_id, score)`` pairs whose score meets a threshold.

    Scores are dot products with each corpus row, visited in descending order.
    Collection stops at the first score below *similarity_threshold*, since all
    later scores are lower. Ranked indices with no matching entry in
    *corpus_ids* are reported and skipped. If nothing passes the threshold,
    the single best in-range document is returned anyway, so the result is
    never empty. *query_id* is not used.
    """
    scores = torch.matmul(query_embedding, corpus_embeddings.T)
    ranking = torch.argsort(scores, descending=True)
    n_ids = len(corpus_ids)

    hits = []
    for idx in ranking:
        if idx >= n_ids:
            print(f"Index {idx} is out of bounds for corpus_ids with length {n_ids}")
            continue

        score = scores[idx].item()
        # Written as "not >=" so that a NaN score also stops the scan.
        if not (score >= similarity_threshold):
            break
        hits.append((corpus_ids[idx], score))
        if len(hits) == top_k:
            break

    if not hits:
        fallback = ranking[ranking < n_ids][0]
        hits.append((corpus_ids[fallback], scores[fallback].item()))

    return hits

def retrieve_single_query_simple(corpus_ids, corpus_embeddings, query_id, query_embedding, top_k, similarity_threshold=0.5):
    """Return the *top_k* highest dot-product ``(corpus_id, score)`` pairs for one query.

    *query_id* and *similarity_threshold* are accepted only to match
    ``retrieve_single_query_complex``'s signature; neither is used.
    """
    similarity = query_embedding @ corpus_embeddings.T
    ranked = similarity.argsort(descending=True)[:top_k]
    return [(corpus_ids[i], similarity[i].item()) for i in ranked]


# NOTE: never hard-code API keys in this notebook; provide GROQ_API_KEY via the environment (e.g. Colab secrets). Revoke any key that was previously committed here.


def format_input(passages_texts, question):
    """Build a plain prompt: numbered ``Document [n]:`` lines, a blank line, then the question.

    The prompt ends with ``Answer:`` so the model continues from there.
    """
    document_block = "".join(
        f"Document [{number}]: {text}\n" for number, text in enumerate(passages_texts, 1)
    )
    return f"{document_block}\nQuestion: {question}\nAnswer:"

def generate_answer(question, passages_texts):
    """Answer *question* with a Groq-hosted LLM, grounded on *passages_texts*.

    The prompt template is read from the module-level ``hotpot_prompt_path``.
    It is filled with ``str.format`` using the ``search_results`` and
    ``question`` fields. Passages are numbered from 1 as ``Document [n] <text>``,
    one per line.

    The API key is read from the ``GROQ_API_KEY`` environment variable rather
    than being embedded in the notebook.

    Args:
        question: The natural-language question.
        passages_texts: Iterable of passage strings.

    Returns:
        The text content of the first completion choice.

    Raises:
        RuntimeError: If ``GROQ_API_KEY`` is not set.
    """
    with open(hotpot_prompt_path, 'r') as file:
        prompt_template = file.read()

    formatted_passages = "".join(
        f"Document [{i}] {passage}\n" for i, passage in enumerate(passages_texts, 1)
    )
    input_text = prompt_template.format(search_results=formatted_passages, question=question)

    api_key = os.environ.get("GROQ_API_KEY")
    if not api_key:
        raise RuntimeError("GROQ_API_KEY environment variable is not set")
    client = Groq(api_key=api_key)

    chat_completion = client.chat.completions.create(
        messages=[
            {
                "role": "user",
                "content": input_text,
            }
        ],
        model="llama3-8b-8192",
    )

    return chat_completion.choices[0].message.content

def save_answers_to_json(generated_answers, filename, update_existing=True):
    """Write *generated_answers* to *filename* as indented JSON.

    When *update_existing* is true and the file already exists, its contents
    are loaded first and *generated_answers* is merged on top: new keys win,
    other existing keys are kept. Otherwise the file is overwritten.
    """
    payload = generated_answers
    if update_existing and os.path.exists(filename):
        with open(filename, 'r') as file:
            payload = json.load(file)
        payload.update(generated_answers)

    with open(filename, 'w') as file:
        json.dump(payload, file, indent=4)

def map_jsons(generated_answers_filename, gold_truth_answers_filename, title):
    """Pair generated answers with gold answers and write ``./{title}_updated_answers.json``.

    The gold file is JSONL. Each line must have ``"_id"`` and
    ``["metadata"]["answer"]``; when an id repeats, the later line wins.
    Generated entries with a matching gold id become
    ``{"generated_answer": ..., "gold_truth_answer": ...}``. Entries without
    a match are left as the bare generated string.
    """
    with open(generated_answers_filename, 'r') as file:
        generated_answers = json.load(file)

    with open(gold_truth_answers_filename, 'r') as file:
        gold_by_id = {
            record["_id"]: record["metadata"]["answer"]
            for record in map(json.loads, file)
        }

    # Only values are replaced (no keys added or removed), so mutating during iteration is safe.
    for key, generated in generated_answers.items():
        if key in gold_by_id:
            generated_answers[key] = {
                "generated_answer": generated,
                "gold_truth_answer": gold_by_id[key],
            }

    with open(f'./{title}_updated_answers.json', 'w') as file:
        json.dump(generated_answers, file, indent=4)

def normalize_answer(s):
    """Normalize an answer string for HotpotQA/SQuAD-style comparison.

    Steps, in order: lower-case, remove ASCII punctuation, replace the
    articles a/an/the with spaces, and collapse runs of whitespace.
    """
    punctuation = set(string.punctuation)
    text = s.lower()
    text = ''.join(ch for ch in text if ch not in punctuation)
    text = re.sub(r'\b(a|an|the)\b', ' ', text)
    return ' '.join(text.split())


def f1_score(prediction, ground_truth):
    """Token-level F1, precision and recall between normalized answers.

    If either side normalizes to 'yes', 'no' or 'noanswer' and the two sides
    differ, the result is zero. The result is also zero when no tokens
    overlap.

    Returns:
        ``(f1, precision, recall)``; the zero result is the int tuple ``(0, 0, 0)``.
    """
    pred_norm = normalize_answer(prediction)
    gold_norm = normalize_answer(ground_truth)
    zero = (0, 0, 0)

    special_answers = ('yes', 'no', 'noanswer')
    if pred_norm != gold_norm and (pred_norm in special_answers or gold_norm in special_answers):
        return zero

    pred_tokens = pred_norm.split()
    gold_tokens = gold_norm.split()
    overlap = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())
    if overlap == 0:
        return zero

    precision = overlap / len(pred_tokens)
    recall = overlap / len(gold_tokens)
    return (2 * precision * recall) / (precision + recall), precision, recall


def exact_match_score(prediction, ground_truth):
    """Return True iff both strings are identical after ``normalize_answer``."""
    normalized_prediction = normalize_answer(prediction)
    normalized_truth = normalize_answer(ground_truth)
    return normalized_prediction == normalized_truth

def update_answer(metrics, prediction, gold):
    """Add one prediction's EM/F1/precision/recall into the *metrics* totals, in place.

    Returns:
        ``(em, prec, recall)`` for this pair. F1 is accumulated but not returned.
    """
    em = exact_match_score(prediction, gold)
    f1, prec, recall = f1_score(prediction, gold)
    for key, value in (('em', float(em)), ('f1', f1), ('prec', prec), ('recall', recall)):
        metrics[key] += value
    return em, prec, recall

def eval(prediction_file, title):
    """Average EM/F1/precision/recall over a paired-answers JSON file.

    NOTE: this name shadows the builtin ``eval`` in this notebook; it is kept
    for existing callers.

    Args:
        prediction_file: JSON object mapping query id ->
            ``{"generated_answer": ..., "gold_truth_answer": ...}``
            (the format written by ``map_jsons``).
        title: Used to name the output file ``./metrics_{title}.json``.

    Returns:
        The metrics dict. It is also printed and written to disk. An empty
        file yields all-zero metrics instead of raising ZeroDivisionError.
    """
    with open(prediction_file) as f:
        data = json.load(f)

    metrics = {'em': 0, 'f1': 0, 'prec': 0, 'recall': 0}
    for values in data.values():
        update_answer(metrics, values['generated_answer'], values['gold_truth_answer'])

    num_predictions = len(data)
    if num_predictions:  # guard: averaging over zero predictions is undefined
        for key in metrics:
            metrics[key] /= num_predictions

    with open(f'./metrics_{title}.json', 'w') as metrics_file:
        json.dump(metrics, metrics_file, indent=4)

    print(metrics)
    return metrics

def random_sampling(queries, sample_size):
    """Return a dict with a random subset of *queries* (id -> text), without replacement.

    A *sample_size* larger than the number of queries is capped, so every
    query is returned (in random order).
    """
    n = min(sample_size, len(queries))
    chosen_ids = random.sample(list(queries.keys()), n)
    return {qid: queries[qid] for qid in chosen_ids}


def compute_query_length(query):
    """Return the number of characters in *query* (not the number of words)."""
    character_count = len(query)
    return character_count


def stratified_sampling(queries, sample_size_per_bin, bin_size):
    """Sample up to *sample_size_per_bin* queries from each character-length bin.

    Queries are grouped into fixed-width bins of *bin_size* characters. Each
    bin is keyed by its lower bound: lengths 1..bin_size go to bin 1, lengths
    bin_size+1..2*bin_size to bin bin_size+1, and so on. Bins with more than
    *sample_size_per_bin* members are randomly sampled; smaller bins are kept
    whole.

    Args:
        queries: Mapping of query id -> query text.
        sample_size_per_bin: Maximum number of queries taken from each bin.
        bin_size: Width of each length bin, in characters.

    Returns:
        dict of sampled query id -> text, ordered by bin first appearance.
    """
    bins = defaultdict(list)
    for qid, query in queries.items():
        bin_start = (len(query) - 1) // bin_size * bin_size + 1
        bins[bin_start].append(qid)

    sampled_queries = {}
    for qids in bins.values():
        if len(qids) > sample_size_per_bin:
            sampled_qids = random.sample(qids, sample_size_per_bin)
        else:
            sampled_qids = qids  # bin is small enough: take everything
        sampled_queries.update({qid: queries[qid] for qid in sampled_qids})

    return sampled_queries

def sample_query_ids_by_overlap(json_file_path, num_samples=100):
    """
    Randomly sample query IDs from each of ten query/supporting-fact term-overlap bins.

    A query's overlap is the fraction of its whitespace-separated terms that
    also occur in its supporting-fact sentences. Bin ``i`` (1-based) holds
    overlaps in ``[(i-1)/10, i/10)``. An overlap of exactly 1.0 goes to
    "Bin 10" so that fully covered queries are not dropped.

    Args:
    - json_file_path (str): Path to a HotpotQA-format JSON list. Each entry needs
      '_id', 'question', 'supporting_facts' ([title, sentence_index] items) and
      'context' ([title, sentences] items).
    - num_samples (int): Number of random samples to return from each bin.

    Returns:
    - dict: {"Bin 1": [query_id, ...], ..., "Bin 10": [...]}, with each list
      holding at most ``num_samples`` sampled query IDs.
    """

    def calculate_term_overlap(query, facts):
        query_terms = set(query.split())
        fact_terms = set(" ".join(facts).split())
        return len(query_terms & fact_terms) / len(query_terms) if query_terms else 0

    def supporting_sentences(entry):
        # Resolve each supporting fact to its sentence text, using the first
        # context paragraph whose title matches. Out-of-range indices are skipped.
        facts = []
        for fact in entry['supporting_facts']:
            title, sent_index = fact[0], fact[1]
            for context_title, sentences in entry['context']:
                if context_title == title:
                    if sent_index < len(sentences):
                        facts.append(sentences[sent_index])
                    break
        return facts

    def bin_index(overlap):
        # Bins are half-open [i/10, (i+1)/10). Overlap can reach 1.0 (every
        # query term found), which lies outside the last interval, so it is
        # clamped into the final bin.
        for i in range(len(bin_edges) - 1):
            if bin_edges[i] <= overlap < bin_edges[i + 1]:
                return i
        return len(bin_edges) - 2

    with open(json_file_path, 'r') as file:
        data = json.load(file)

    bin_edges = [i / 10 for i in range(11)]
    bins = [[] for _ in range(len(bin_edges) - 1)]
    for entry in data:
        overlap = calculate_term_overlap(entry['question'], supporting_sentences(entry))
        bins[bin_index(overlap)].append(entry['_id'])

    return {
        f"Bin {i + 1}": random.sample(bin_ids, min(num_samples, len(bin_ids)))
        for i, bin_ids in enumerate(bins)
    }


def print_results_with_passages_and_answers(results, queries, corpus, qrels, max_entries=5):
    """Print up to *max_entries* queries together with their retrieved passages.

    For each query, in ``results`` order, it prints:
    - the question text (or "Unknown Question");
    - the ``qrels`` entry, labelled "Gold Answer" (or "No gold answer found");
    - each retrieved ``(doc_id, score)`` with its corpus text (or "No passage found").
    Passages are separated by dashes and queries by equals signs.
    """
    passage_rule = "-" * 80
    query_rule = "=" * 80

    for shown, (query_id, passages) in enumerate(results.items()):
        if shown >= max_entries:
            break

        question_text = queries.get(query_id, "Unknown Question")
        gold_answer = qrels.get(query_id, "No gold answer found")

        print(f"Question: {question_text}")
        print(f"Gold Answer: {gold_answer}")
        print("Retrieved Passages:")

        for doc_id, score in passages:
            passage_text = corpus.get(doc_id, {}).get('text', "No passage found")
            print(f"Passage ID: {doc_id}")
            print(f"Score: {score}")
            print(f"Passage Text: {passage_text}")
            print(passage_rule)

        print(query_rule)

def sample_query_texts_by_overlap(json_file_path, num_samples=100):
    """
    Randomly sample query IDs and texts from each of ten term-overlap bins.

    A query's overlap is the fraction of its whitespace-separated terms that
    also occur in its supporting-fact sentences. Bin ``i`` (1-based) holds
    overlaps in ``[(i-1)/10, i/10)``. An overlap of exactly 1.0 goes to
    "Bin 10" so that fully covered queries are not dropped.

    Args:
    - json_file_path (str): Path to a HotpotQA-format JSON list. Each entry needs
      '_id', 'question', 'supporting_facts' ([title, sentence_index] items) and
      'context' ([title, sentences] items).
    - num_samples (int): Number of random samples to return from each bin.

    Returns:
    - dict: {"Bin 1": {query_id: query_text, ...}, ..., "Bin 10": {...}}.
    """

    def calculate_term_overlap(query, facts):
        query_terms = set(query.split())
        fact_terms = set(" ".join(facts).split())
        return len(query_terms & fact_terms) / len(query_terms) if query_terms else 0

    def supporting_sentences(entry):
        # Resolve each supporting fact to its sentence text, using the first
        # context paragraph whose title matches. Out-of-range indices are skipped.
        facts = []
        for fact in entry['supporting_facts']:
            title, sent_index = fact[0], fact[1]
            for context_title, sentences in entry['context']:
                if context_title == title:
                    if sent_index < len(sentences):
                        facts.append(sentences[sent_index])
                    break
        return facts

    def bin_index(overlap):
        # Bins are half-open [i/10, (i+1)/10). Overlap can reach 1.0 (every
        # query term found), which lies outside the last interval, so it is
        # clamped into the final bin.
        for i in range(len(bin_edges) - 1):
            if bin_edges[i] <= overlap < bin_edges[i + 1]:
                return i
        return len(bin_edges) - 2

    with open(json_file_path, 'r') as file:
        data = json.load(file)

    bin_edges = [i / 10 for i in range(11)]
    bins = [[] for _ in range(len(bin_edges) - 1)]
    for entry in data:
        query = entry['question']
        overlap = calculate_term_overlap(query, supporting_sentences(entry))
        bins[bin_index(overlap)].append((entry['_id'], query))

    return {
        f"Bin {i + 1}": {
            query_id: query_text
            for query_id, query_text in random.sample(bin_entries, min(num_samples, len(bin_entries)))
        }
        for i, bin_entries in enumerate(bins)
    }

def sample_query_texts_by_type(json_file_path, num_samples=100):
    """
    Randomly sample query IDs and texts for each HotpotQA query type.

    Only the 'bridge' and 'comparison' types are collected; entries of any
    other type are ignored. Each type contributes at most ``num_samples``
    queries, chosen without replacement.

    Args:
    - json_file_path (str): Path to a JSON list whose entries have '_id', 'question' and 'type'.
    - num_samples (int): Maximum number of queries sampled per type.

    Returns:
    - dict: {'bridge': {query_id: query_text, ...}, 'comparison': {...}}.
    """
    with open(json_file_path, 'r') as file:
        data = json.load(file)

    grouped = {'bridge': [], 'comparison': []}
    for entry in data:
        pair = (entry['_id'], entry['question'])
        kind = entry['type']
        if kind in grouped:
            grouped[kind].append(pair)

    sampled = {}
    for kind, pairs in grouped.items():
        picked = random.sample(pairs, min(num_samples, len(pairs)))
        sampled[kind] = dict(picked)
    return sampled

def evaluate_results_at_k(results, queries, corpus, top_k, qrels):
    """Print mean Precision@k, Recall@k and nDCG@k over all queries in *results*.

    The relevant set for each query is the keys of ``qrels[query_id]`` (empty
    when the query has no qrels). *queries* and *corpus* are unused.
    """
    precision_scores, recall_scores, ndcg_scores = [], [], []
    for query_id, ranking in results.items():
        relevant = set(qrels.get(query_id, {}))
        precision_scores.append(precision_at_k(ranking, relevant, top_k))
        recall_scores.append(recall_at_k(ranking, relevant, top_k))
        ndcg_scores.append(ndcg_at_k(ranking, relevant, top_k))

    averages = [
        ("Precision", np.mean(precision_scores)),
        ("Recall", np.mean(recall_scores)),
        ("NDCG", np.mean(ndcg_scores)),
    ]
    for label, value in averages:
        print(f"Average {label}@{top_k}: {value}")

def evaluate_results_multiple_ks(results, queries, corpus, top_ks, qrels):
    """Run ``evaluate_results_at_k`` once for each cut-off in *top_ks*, in order."""
    for cutoff in top_ks:
        evaluate_results_at_k(results, queries, corpus, cutoff, qrels)


def generate_save_answers(results, queries, corpus, title, max_retries=5, backoff_factor=1):
    """Generate an LLM answer for every query in *results* and checkpoint them to JSON.

    After each query, the answers so far are written to
    ``./{title}_generated_answers.json``, overwriting the file, so a crash
    loses at most the current query.

    Transient API failures are retried up to *max_retries* times with
    exponential backoff (``backoff_factor * 2**attempt`` seconds between
    attempts). Connection problems, rate limits and server errors count as
    transient. Queries that still fail get the placeholder
    ``"Error: Unable to generate answer"``.

    This function no longer sets ``GROQ_API_KEY`` itself. Provide the key
    through the environment (e.g. Colab secrets) instead of hard-coding it.

    Args:
        results: Mapping query id -> list of ``(doc_id, score)`` pairs.
        queries: Mapping query id -> question text (must contain every id in *results*).
        corpus: Mapping doc id -> document dict with a ``'text'`` field.
        title: Prefix for the output file name.
        max_retries: Maximum attempts per query.
        backoff_factor: Base delay, in seconds, for the exponential backoff.
    """
    import groq  # local import: only the SDK's exception classes are needed here

    # The Groq SDK raises its own (httpx-based) exceptions, so catching only
    # requests.RequestException, as before, never triggered a retry.
    retryable_errors = (
        groq.APIConnectionError,
        groq.RateLimitError,
        groq.InternalServerError,
        requests.exceptions.RequestException,
    )

    generated_answers_file = f'./{title}_generated_answers.json'
    print(f"Generating answers for {len(results)} queries")
    generated_answers = {}

    for query_id, passages in tqdm(results.items(), desc="Generating Answers", total=len(results)):
        passages_text = [corpus[doc_id]['text'] for doc_id, _ in passages]

        for attempt in range(max_retries):
            try:
                generated_answers[query_id] = generate_answer(queries[query_id], passages_text)
                break  # success: stop retrying
            except retryable_errors as e:
                if attempt + 1 < max_retries:
                    wait_time = backoff_factor * (2 ** attempt)
                    print(f"Error: {e}. Retrying in {wait_time} seconds...")
                    time.sleep(wait_time)
                else:
                    # Last attempt: no point sleeping before giving up.
                    print(f"Error: {e}. No retries left.")
        else:
            print(f"Failed to generate answer for query {query_id} after {max_retries} retries.")
            generated_answers[query_id] = "Error: Unable to generate answer"

        save_answers_to_json(generated_answers, generated_answers_file, False)

    print("All answers generated and saved.")

def get_results(corpus, queries, corpus_ids, corpus_embeddings, model, top_k):
    """Run dense retrieval for every saved query embedding and print timing.

    Query ids and embeddings are loaded from the ``query_embeddings`` files in
    the working directory (via ``load_embeddings``), not derived from
    *queries*. The results therefore cover every saved query, even when
    *queries* is a subset. *corpus* and *model* are unused; they are kept for
    interface compatibility.

    Returns:
        dict mapping query id -> list of up to *top_k* ``(corpus_id, score)`` pairs.
    """
    query_ids, query_embeddings = load_embeddings("query_embeddings")

    start_time = time.time()
    results = retrieve(corpus_ids, corpus_embeddings, query_ids, query_embeddings, top_k)
    retrieval_time = time.time() - start_time

    print(f"Retrieval Time: {retrieval_time:.2f} seconds")
    # Throughput is measured over the queries actually retrieved. len(queries)
    # can differ from the number of saved embeddings.
    if retrieval_time > 0:
        print(f"Queries Processed per Second: {len(query_ids) / retrieval_time:.2f}")
    return results

def preprocess_question(question, tokenizer):
    """Tokenize *question* into padded, truncated PyTorch tensors for the classifier."""
    tokenizer_options = {'return_tensors': 'pt', 'padding': True, 'truncation': True}
    return tokenizer(question, **tokenizer_options)

def classify_question(question, model, tokenizer):
    """Return the classifier's softmax probabilities for one question.

    Returns:
        ``(p_open_ended, p_not_open_ended)``: the probabilities of class 1 and
        class 0. ``.item()`` requires a single question, i.e. a batch of one.
    """
    encoded = preprocess_question(question, tokenizer)
    batch = {name: tensor.to(model.device) for name, tensor in encoded.items()}

    with torch.no_grad():  # inference only; no autograd graph
        logits = model(**batch).logits

    probabilities = F.softmax(logits, dim=-1)
    return probabilities[:, 1].item(), probabilities[:, 0].item()

def interpret_classification(query, model, tokenizer):
    """Label *query* 'Open-ended' or 'Not open-ended' and return the label with both confidences.

    A tie goes to 'Not open-ended'.
    """
    p_open, p_closed = classify_question(query, model, tokenizer)
    label = 'Open-ended' if p_open > p_closed else 'Not open-ended'
    return {
        'question': query,
        'classification': label,
        'confidence_open_ended': p_open,
        'confidence_not_open_ended': p_closed,
    }


def _top_k_from_classification(label, confidence_open_ended, confidence_not_open_ended):
    """Choose retrieval depth from classifier output alone: confident open-ended -> more, confident closed -> fewer."""
    if label == 'Open-ended':
        if confidence_open_ended > 0.8:
            return 10
        if confidence_open_ended > 0.6:
            return 7
        return 5
    if confidence_not_open_ended > 0.8:
        return 3
    if confidence_not_open_ended > 0.6:
        return 5
    return 7


def _top_k_from_complexity(complexity_score):
    """Choose retrieval depth from linguistic complexity alone: simpler queries get more passages."""
    if complexity_score < 10:
        return 10
    if complexity_score < 20:
        return 7
    return 5


def _top_k_from_classification_and_complexity(label, confidence_open_ended, confidence_not_open_ended, complexity_score):
    """Choose retrieval depth from classifier confidence, adjusted by whether the query is simple (complexity < 10)."""
    simple = complexity_score < 10
    if label == 'Open-ended':
        if confidence_open_ended > 0.8:
            return 10 if simple else 7
        if confidence_open_ended > 0.6:
            return 7 if simple else 5
        return 5 if simple else 3
    if confidence_not_open_ended > 0.8:
        return 3 if simple else 5
    if confidence_not_open_ended > 0.6:
        return 5 if simple else 7
    return 7 if simple else 10


def get_results_oe(queries, query_ids, query_embeddings, corpus_ids,
                   corpus_embeddings, corpus, retriever_model, classifier_model, classifier_tokenizer,
                   top_k, strategy="classification"):
    """Retrieve passages with a per-query depth chosen from the open-ended classifier.

    Every query is classified first. Its retrieval depth (top_k) then comes
    from *strategy*:
      - "classification" (default): classifier label and confidence only. This
        is the policy the previous version effectively applied, once it got
        past the NameError on the undefined ``complexity_score``.
      - "complexity": textstat-based linguistic complexity only.
      - "classification+complexity": confidence tiers shifted by complexity.

    Args:
        queries: Mapping query id -> text. NOTE(review): its values are zipped
            positionally with *query_ids* / *query_embeddings*, which assumes
            the same order as the saved embeddings — confirm.
        query_ids, query_embeddings: Saved query ids and their embeddings.
        corpus_ids, corpus_embeddings: Corpus ids and embedding matrix.
        corpus, retriever_model: Unused; kept for interface compatibility.
        classifier_model, classifier_tokenizer: Open-ended question classifier.
        top_k: Unused; the depth is decided per query. Kept for compatibility.
        strategy: Depth-selection policy (see above).

    Returns:
        ``(results, top_k_values, retrieval_time)``: per-query
        ``(corpus_id, score)`` lists, the depth used for each query, and the
        elapsed seconds.

    Raises:
        ValueError: If *strategy* is not one of the supported names.
    """
    valid_strategies = ("classification", "complexity", "classification+complexity")
    if strategy not in valid_strategies:
        raise ValueError(f"Unknown strategy {strategy!r}; expected one of {valid_strategies}")

    print("Retrieving for open-ended retrieval method...")
    classifier_model.eval()

    start_time = time.time()

    results = {}
    top_k_values = {}
    for query_id, query_text, query_embedding in tqdm(zip(query_ids, queries.values(), query_embeddings), desc="Retrieving Documents", total=len(queries)):
        classification = interpret_classification(query_text, classifier_model, classifier_tokenizer)
        label = classification['classification']
        confidence_open_ended = classification['confidence_open_ended']
        confidence_not_open_ended = classification['confidence_not_open_ended']

        if strategy == "classification":
            query_top_k = _top_k_from_classification(label, confidence_open_ended, confidence_not_open_ended)
        else:
            complexity_score = calculate_linguistic_complexity(query_text)
            if strategy == "complexity":
                query_top_k = _top_k_from_complexity(complexity_score)
            else:
                query_top_k = _top_k_from_classification_and_complexity(
                    label, confidence_open_ended, confidence_not_open_ended, complexity_score)

        results[query_id] = retrieve_single_query_simple(corpus_ids, corpus_embeddings, query_id, query_embedding, query_top_k)
        top_k_values[query_id] = query_top_k  # record the depth actually used

    retrieval_time = time.time() - start_time

    print(f"Retrieval Time: {retrieval_time:.2f} seconds")
    print(f"Queries Processed per Second: {len(queries) / retrieval_time:.2f}")

    return results, top_k_values, retrieval_time

_RETRIEVAL_SYSTEM_TYPES = ("base-retrieval-system", "open-ended-retrieval-system")


def _save_and_reload(path, obj):
    """Persist ``obj`` with ``save_results`` and return the copy read back by ``load_results``."""
    save_results(path, obj)
    return load_results(path)


def _answer_and_score(results, queries, corpus, run_title):
    """Generate answers for ``results``, score them, and return the metrics dict.

    Runs ``generate_save_answers`` -> ``map_jsons`` (against ``gold_truth_path``)
    -> ``eval``, then reads back ``./metrics_{run_title}.json``, the file ``eval``
    is expected to write for ``run_title``.
    """
    generate_save_answers(results, queries, corpus, run_title)
    map_jsons(f'./{run_title}_generated_answers.json', gold_truth_path, run_title)
    eval(f'./{run_title}_updated_answers.json', run_title)
    with open(f'./metrics_{run_title}.json', 'r') as metrics_file:
        return json.load(metrics_file)


def evaluate_retrieval_system(type, corpus, queries, corpus_ids, corpus_embeddings, qrels, model, classifier_model, classifier_tokenizer, title, top_k):
    """Evaluate a retrieve-then-generate system end to end on ``queries``.

    Two system types are supported:

    - ``"base-retrieval-system"``: retrieve ``top_k`` passages for every query,
      generate answers for all queries and score them as one group.
    - ``"open-ended-retrieval-system"``: let ``get_results_oe`` pick a per-query
      ``top_k``, then report retrieval metrics (``evaluate_results_at_k``) and
      answer metrics separately for each group of queries sharing a ``top_k``.

    Artefacts are written to the working directory, prefixed with ``title``:
    ``{title}_retrieval_results.pkl``, ``{title}_top_k_values.pkl`` (open-ended
    only), the generated/updated answer JSON files and the metrics JSON files.

    Parameters:
    - type: one of "base-retrieval-system", "open-ended-retrieval-system"
    - corpus, queries: documents and queries used for retrieval and answer generation
    - corpus_ids, corpus_embeddings: corpus index used for retrieval
    - qrels: relevance judgements (used for the per-group retrieval metrics)
    - model: embedding model
    - classifier_model, classifier_tokenizer: query classifier (open-ended system only)
    - title: prefix for every file written by this run
    - top_k: passages per query for the base system; forwarded to get_results_oe,
      which chooses its own per-query value

    Returns:
    - base system: the metrics dict loaded from ``./metrics_{title}.json``
    - open-ended system: {top_k: metrics dict with an added 'num_queries' entry}

    Raises:
    - ValueError: if ``type`` is not a supported system type
    - RuntimeError: if the ``GROQ_API_KEY`` environment variable is not set
    """
    if type not in _RETRIEVAL_SYSTEM_TYPES:
        raise ValueError(
            f"Unknown retrieval system type {type!r}; expected one of {_RETRIEVAL_SYSTEM_TYPES}"
        )
    # Answer generation needs GROQ_API_KEY in the environment. Never hard-code the
    # key in the notebook; load it from a secret store beforehand (e.g. Colab's
    # google.colab.userdata.get("GROQ_API_KEY")). Checked up front so a missing key
    # does not surface only after the (slow) retrieval step.
    if not os.environ.get("GROQ_API_KEY"):
        raise RuntimeError(
            "GROQ_API_KEY is not set; add it to the environment before evaluating."
        )

    if type == "open-ended-retrieval-system":
        # Only the adaptive retriever consumes the precomputed query embeddings.
        query_ids, query_embeddings = load_embeddings("query_embeddings")

    start_time = time.time()

    if type == "base-retrieval-system":
        results = get_results(corpus, queries, corpus_ids, corpus_embeddings, model, top_k)

        # Continue from the on-disk copy so the pickled artefact is what gets scored.
        print("Saving retrieval results")
        results = _save_and_reload(f'./{title}_retrieval_results.pkl', results)
        print("Retrieval results saved and reloaded")

        print("\n\nGenerating answers")
        metrics = _answer_and_score(results, queries, corpus, title)
        print("Answers generated, saved and evaluated")
    else:
        results, top_k_values, _ = get_results_oe(queries, query_ids, query_embeddings, corpus_ids,
                                                  corpus_embeddings, corpus,
                                                  model, classifier_model, classifier_tokenizer, top_k)

        print("Saving retrieval results")
        results = _save_and_reload(f'./{title}_retrieval_results.pkl', results)
        print("Retrieval results saved and reloaded")

        print("Saving top_k values")
        top_k_values = _save_and_reload(f'./{title}_top_k_values.pkl', top_k_values)
        print("top_k values saved and reloaded")

        # Bucket queries by the top_k chosen for them; each bucket is scored on its own.
        grouped_results = defaultdict(dict)
        for query_id, passages in results.items():
            grouped_results[top_k_values[query_id]][query_id] = passages

        metrics = {}
        for k, group_results in grouped_results.items():
            print(f"Evaluating for top-{k}...")
            evaluate_results_at_k(group_results, queries, corpus, k, qrels)
            print("number of queries: ", len(group_results))

            group_metrics = _answer_and_score(group_results, queries, corpus, f"{title}_top{k}")
            group_metrics['num_queries'] = len(group_results)
            metrics[k] = group_metrics

        # Save all per-group metrics together for later comparison.
        with open(f'./metrics_{title}_all.json', 'w') as metrics_file:
            json.dump(metrics, metrics_file, indent=4)

    total_time = time.time() - start_time
    print(f"Total Time for Retrieval System Evaluation: {total_time:.2f} seconds")
    return metrics



_PIPELINE_TYPES = (
    "base-retriever",
    "open-ended-retriever",
    "base-retrieval-system",
    "open-ended-retrieval-system",
)


def evaluation_pipeline(type, model, classifier_model, classifier_tokenizer, corpus, queries, qrels, corpus_embeddings, create_corpus_embeddings_flag, create_query_embeddings_flag, title):
    """
    Run one evaluation experiment over the query sample ``queries``.

    Parameters:
    - type: which system to evaluate:
        'base-retriever'              fixed top-10 retrieval, retrieval metrics only
        'open-ended-retriever'        classifier-adaptive retrieval, retrieval metrics only
        'base-retrieval-system'       retrieval + answer generation + answer scoring
        'open-ended-retrieval-system' adaptive retrieval + answer generation + scoring
    - model: the embedding model
    - classifier_model, classifier_tokenizer: query classifier and its tokenizer; for
      'open-ended-retriever' they are loaded via load_classifier_model() if either is None
    - corpus: the document corpus (as returned by load_data)
    - queries: mapping keyed by query id; its keys select which qrels are kept
    - qrels: relevance judgements
    - corpus_embeddings: precomputed corpus embeddings (presumably row-aligned with the
      ids stored at corpus_ids_path)
    - create_corpus_embeddings_flag: if True, (re)build corpus embeddings first
    - create_query_embeddings_flag: if True, (re)build query embeddings for ``queries`` first
    - title: prefix for files written by the retrieval-system evaluations

    Returns:
    - For the retrieval-system types, the metrics returned by evaluate_retrieval_system;
      None for the retriever-only types.

    Raises:
    - ValueError: for an unknown ``type`` (checked before any work is done)
    """
    # Fail fast on a mistyped system name, before any expensive embedding work.
    if type not in _PIPELINE_TYPES:
        raise ValueError(f"Unknown pipeline type {type!r}; expected one of {_PIPELINE_TYPES}")

    # Keep only the relevance judgements for the sampled queries.
    qrels = {k: qrels[k] for k in queries.keys() if k in qrels}

    ks = [1, 3, 5, 10]

    if create_corpus_embeddings_flag:
        create_corpus_embeddings(corpus, model)
    if create_query_embeddings_flag:
        create_query_embeddings(queries, model)

    print("Loading embeddings")
    corpus_ids = np.load(corpus_ids_path, allow_pickle=True).tolist()
    print("Embeddings loaded")

    if type == "base-retriever":
        # Same argument order as the get_results call in evaluate_retrieval_system.
        results = get_results(corpus, queries, corpus_ids, corpus_embeddings, model, top_k=10)
        evaluate_results_multiple_ks(results, queries, corpus, ks, qrels)
        return None

    if type == "open-ended-retriever":
        # Query embeddings are only needed by the adaptive retriever.
        query_ids, query_embeddings = load_query_embeddings_for_evaluation()
        if classifier_model is None or classifier_tokenizer is None:
            classifier_model, classifier_tokenizer = load_classifier_model()
        # Argument order matches the call in evaluate_retrieval_system (embedding model
        # before the classifier); get_results_oe returns three values.
        results, _, _ = get_results_oe(queries, query_ids, query_embeddings, corpus_ids,
                                       corpus_embeddings, corpus, model, classifier_model,
                                       classifier_tokenizer, top_k=10)
        evaluate_results_multiple_ks(results, queries, corpus, ks, qrels)
        return None

    # Both retrieval-system variants share one evaluation routine.
    return evaluate_retrieval_system(type, corpus, queries, corpus_ids, corpus_embeddings, qrels, model, classifier_model, classifier_tokenizer, title, top_k=10)


## Load data, embeddings, and models

In [None]:
# Load the BEIR HotpotQA splits. Every split is read against the same document
# collection (each load shows a 5,233,329-document progress bar), so only one
# corpus copy is kept in memory; test_corpus / val_corpus alias it so existing
# references keep working.
# NOTE(review): assumes load_data returns the unfiltered corpus for every split.
# Confirm this if load_data ever starts filtering documents by split.
train_corpus, train_queries, train_qrels = load_data(data_path, "train")

_split_corpus, test_queries, test_qrels = load_data(data_path, "test")
del _split_corpus  # free the duplicate corpus before loading the next split
_split_corpus, val_queries, val_qrels = load_data(data_path, "dev")
del _split_corpus

test_corpus = val_corpus = train_corpus

# Pool queries and relevance judgements from all splits (later splits win on
# duplicate ids).
full_queries = {**train_queries, **test_queries, **val_queries}
print(len(full_queries))

full_qrels = {**train_qrels, **test_qrels, **val_qrels}
print(len(full_qrels))

  0%|          | 0/5233329 [00:00<?, ?it/s]

  0%|          | 0/5233329 [00:00<?, ?it/s]

  0%|          | 0/5233329 [00:00<?, ?it/s]

97852
97852


In [None]:
# Detect the device first and use it for every model, so this cell also runs
# on CPU-only runtimes.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# torch.from_numpy shares memory with the loaded array instead of copying it,
# so the (very large) corpus embedding matrix is held only once.
corpus_embeddings = torch.from_numpy(np.load(corpus_embeddings_path))

model = load_embedding_model()
model.to(device)

# Tokenizer for the fine-tuned BERT query classifier.
classifier_tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

# The fine-tuned classifier is restored whole from a pickle, so no freshly
# initialised BertForSequenceClassification is needed.
# NOTE: pickle executes code while loading; only load files you created yourself.
# NOTE(review): if the model was pickled while on GPU, unpickling may need CUDA.
with open(classifier_model_path, 'rb') as f:
    classifier_model = pickle.load(f)

classifier_model.to(device)
classifier_model.eval()

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/10.7k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]



1_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Using device: cuda


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  return torch.load(io.BytesIO(b))


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e

## Experiments

### Base retriever with 4000 samples stratified by query length


In [None]:
# Draw a query-length-stratified sample and save its query ids for reproducibility.
base_title = 'hotpotqa-base-4000'
base_sample = stratified_sampling(full_queries, bin_size=10, sample_size_per_bin=100)
base_sample_list = list(base_sample)
with open(f'{base_title}_sample.json', 'w') as json_file:
    json.dump(base_sample_list, json_file)

# Evaluate the fixed top-k baseline system on this sample.
evaluation_pipeline(
    "base-retrieval-system",
    model,
    classifier_model,
    classifier_tokenizer,
    corpus=train_corpus,
    queries=base_sample,
    qrels=full_qrels,
    corpus_embeddings=corpus_embeddings,
    create_corpus_embeddings_flag=False,
    create_query_embeddings_flag=True,
    title=base_title,
)


### Classification adaptive retrieval with 4000 samples stratified by query length

In [None]:
# create and save sample
classification_title = 'hotpotqa-classification-4000'
classification_sample = stratified_sampling(full_queries, bin_size=10, sample_size_per_bin=100)

classification_sample_list = list(classification_sample)

with open(f'{classification_title}_sample.json', 'w') as json_file:
    json.dump(classification_sample_list, json_file)

# start evaluation of classification sample
print(f"Classification adaptive retrieval with 4000 samples stratified by query length \n\n")
evaluation_pipeline("open-ended-retrieval-system", model, classifier_model, classifier_tokenizer, train_corpus, classification_sample, full_qrels, corpus_embeddings, False, True, classification_title)


### Complexity adaptive retrieval with 4000 samples stratified by query length

In [None]:
# create and save sample
complexity_title = 'hotpotqa-complexity-4000'
complexity_sample = stratified_sampling(full_queries, bin_size=10, sample_size_per_bin=100)

complexity_sample_list = list(complexity_sample)

with open(f'{complexity_title}_sample.json', 'w') as json_file:
    json.dump(complexity_sample_list, json_file)

# start evaluation of complexity sample
print(f"Evaluating Complexity adaptive retrieval with 4000 samples stratified by query length \n\n")
evaluation_pipeline("open-ended-retrieval-system", model, classifier_model, classifier_tokenizer, train_corpus, complexity_sample, full_qrels, corpus_embeddings, False, True, complexity_title)


### Base retriever on bridge and comparison question types

In [None]:
# Sample HotpotQA questions by type from the full-wiki dev file.
# query_type_sample is reused by the comparison cell below.
bridge_title = "base_bridge"
query_type_sample = sample_query_texts_by_type(full_wiki_path, num_samples=400)
bridge_sample = query_type_sample['bridge']
bridge_sample_list = list(bridge_sample)
with open(f'{bridge_title}_sample.json', 'w') as json_file:
    json.dump(bridge_sample_list, json_file)

# NOTE(review): unlike the other experiments, this run passes
# create_query_embeddings_flag=False. Confirm that is intended.
print("Evaluating base retriever on 400 bridge samples \n\n")
evaluation_pipeline(
    "base-retrieval-system",
    model,
    classifier_model,
    classifier_tokenizer,
    corpus=train_corpus,
    queries=bridge_sample,
    qrels=full_qrels,
    corpus_embeddings=corpus_embeddings,
    create_corpus_embeddings_flag=False,
    create_query_embeddings_flag=False,
    title=bridge_title,
)




In [None]:
# Reuse the per-type sample drawn in the bridge cell; requires that cell to have run.
comparison_title = "base_comparison"
comparison_sample = query_type_sample['comparison']
comparison_sample_list = list(comparison_sample)
with open(f'{comparison_title}_sample.json', 'w') as json_file:
    json.dump(comparison_sample_list, json_file)

print("Evaluating base retriever comparison samples \n\n")
evaluation_pipeline(
    "base-retrieval-system",
    model,
    classifier_model,
    classifier_tokenizer,
    corpus=train_corpus,
    queries=comparison_sample,
    qrels=full_qrels,
    corpus_embeddings=corpus_embeddings,
    create_corpus_embeddings_flag=False,
    create_query_embeddings_flag=True,
    title=comparison_title,
)



### K-Fold Cross Evaluation on Sampling Methods

In [None]:
# K-fold "mini benchmark": split all queries into 5 disjoint folds and evaluate
# the base retrieval system on a stratified sample drawn from each fold, so the
# five samples never share queries.
print("Running mini Benchmark")

# NOTE(review): these containers are not filled in this cell; kept for later cells.
stratified_results = defaultdict(list)
random_results = defaultdict(list)

# Split an explicit list of query ids. KFold indices depend only on the count.
all_query_ids = list(full_queries)
kf = KFold(n_splits=5, shuffle=True, random_state=42)

for fold_idx, (_, fold_indices) in enumerate(kf.split(all_query_ids), start=1):
    print(f"Processing Fold {fold_idx}")

    strat_title = f"base_stratified_fold{fold_idx}"

    # Sample only from this fold's queries.
    # NOTE(review): a fold holds ~1/5 of the queries; confirm stratified_sampling
    # copes with bins that have fewer than sample_size_per_bin entries.
    fold_queries = {all_query_ids[i]: full_queries[all_query_ids[i]] for i in fold_indices}
    stratified_sample = stratified_sampling(fold_queries, bin_size=10, sample_size_per_bin=50)

    # TODO: the random-sampling baseline (random_sampling) is currently disabled.

    # evaluation_pipeline already maps the generated answers against the gold truth
    # and scores them (map_jsons + eval), so no separate scoring pass is needed.
    evaluation_pipeline("base-retrieval-system", model, classifier_model, classifier_tokenizer,
                        train_corpus, stratified_sample, full_qrels, corpus_embeddings,
                        False, True, strat_title)

