# LLM exact-match performance with top-k retrieved contexts (k = 1, 3, 5) and oracle contexts


In [10]:
from dotenv import load_dotenv
load_dotenv()

from dexter.llms.llm_engine_orchestrator import LLMEngineOrchestrator
import json
import pandas as pd
from tqdm import tqdm

from dexter.config.constants import Split
from dexter.data.loaders.RetrieverDataset import RetrieverDataset


In [11]:
# Build the LLM engine used by every inference cell in this notebook.
config_instance = LLMEngineOrchestrator()
llm_instance = config_instance.get_llm_engine(
    data="",
    llm_class="llama",
    model_name="llama-3.1-8b-instant",
)

# System prompt shared by all experiments: it asks for a bare answer phrase
# drawn only from the supplied context, or "Not Answerable" otherwise.
system_prompt = """You are a precise answering assistant. You will be provided with a [Question] and a [Context] to help you answer the question. The [Context] consists of documents with a [Title] and [Text].
1. Answer the question using ONLY the provided context.
2. Do not use full sentences. Provide only the exact answer entity or phrase.
3. Do not add conversational filler like "The answer is" or "Based on the context".
4. If the answer is not contained within the context, respond with "Not Answerable".
"""


### Run inference for top-k = [1,3,5]

In [12]:
def load_data(dev_path, corpus_path):
    """Load the question set and the document corpus from JSON files.

    Args:
        dev_path: Path to a JSON file containing a list of question
            objects, each with an '_id' key.
        corpus_path: Path to a JSON file containing the corpus.

    Returns:
        A tuple (questions_map, corpus). questions_map maps each question's
        '_id' to its full question object; corpus is the parsed JSON
        content of corpus_path, returned unchanged.
    """
    print("Loading questions...")
    # Read as UTF-8 explicitly: the platform default encoding is not
    # guaranteed to be UTF-8 and would garble or reject non-ASCII text.
    with open(dev_path, 'r', encoding='utf-8') as f:
        questions = json.load(f)
    # Map id -> question object for easy lookup
    questions_map = {q['_id']: q for q in questions}

    print("Loading corpus...")
    with open(corpus_path, 'r', encoding='utf-8') as f:
        corpus = json.load(f)

    return questions_map, corpus

def load_retrieval_results(retrieval_path):
    """Load retrieval results (question_id -> {doc_id: score}) from JSON.

    Args:
        retrieval_path: Path to the JSON file with the retrieval results.

    Returns:
        The parsed JSON content of retrieval_path, returned unchanged.
    """
    print(f"Loading retrieval results from {retrieval_path}...")
    # Explicit UTF-8 so that reading does not depend on the platform's
    # default text encoding.
    with open(retrieval_path, 'r', encoding='utf-8') as f:
        retrieval = json.load(f)
    return retrieval

def run_inference_loop(llm_engine, retrieval_path, questions_map, corpus, system_prompt, max_samples=None):
    """
    Runs inference on questions using retrieved contexts.

    Args:
        llm_engine: Engine exposing get_llama_completion(system_prompt, user_prompt).
        retrieval_path: Path to the retrieval results JSON
            (question_id -> {doc_id: score}).
        questions_map: Mapping of question id -> question object with a
            'question' key.
        corpus: Mapping of doc id -> document dict; 'title' and 'text' are
            read and default to '' when missing.
        system_prompt: System prompt sent with every request.
        max_samples: Maximum number of questions to process, or None for no
            limit. 0 processes nothing.

    Returns:
        Dict mapping question id -> the engine's answer, or the string
        "ERROR" when the engine call raised.
    """
    retrieval_results = load_retrieval_results(retrieval_path)
    predictions = {}

    processed = 0
    for q_id, contexts in tqdm(retrieval_results.items(), desc="Running Inference"):
        # Compare against None (not truthiness) so that max_samples=0 means
        # "process nothing" instead of silently disabling the limit.
        if max_samples is not None and processed >= max_samples:
            break

        # Retrieval entries without a matching question are skipped and do
        # not count towards max_samples.
        if q_id not in questions_map:
            continue

        # 1. Get question text
        question_text = questions_map[q_id]['question']

        # 2. Build the context from the retrieved documents, in the order
        #    they appear in the retrieval file; unknown doc ids are ignored.
        context_str = "".join(
            f"[Title]: {corpus[doc_id].get('title', '')}\n"
            f"[Text]: {corpus[doc_id].get('text', '')}\n\n"
            for doc_id in contexts
            if doc_id in corpus
        )

        # 3. Construct the user prompt
        user_prompt = f"[Context]:\n{context_str}\n\n[Question]: {question_text}"

        # 4. Call the LLM. Best-effort by design: a failed request is logged
        #    and recorded as "ERROR" so one bad call does not abort the run.
        try:
            answer = llm_engine.get_llama_completion(system_prompt, user_prompt)
        except Exception as e:
            print(f"Error processing {q_id}: {e}")
            predictions[q_id] = "ERROR"
        else:
            predictions[q_id] = answer

        processed += 1

    return predictions

# --- EXECUTION BLOCK ---

# Questions and corpus are loaded once and shared by every run below.
# Both paths are resolved relative to the notebook's working directory.
questions_map, corpus = load_data('data/dev.json', 'data/wiki_musique_corpus.json')

# One inference pass per retrieval depth k: each pass reads
# retrieval_k{k}.json and writes predictions_k{k}.json.
# llm_instance is the engine created in an earlier cell. No max_samples is
# passed, so every question in the retrieval file is processed.
# NOTE: the banner below is printed once, before the loop, and names only
# k=1 even though all values in top_k are run.
print("Starting Inference for k=1...")
top_k = [1,3,5]
for k in top_k:
    predictions_k = run_inference_loop(
        llm_engine=llm_instance,
        retrieval_path=f'retrieval_k{k}.json',
        questions_map=questions_map,
        corpus=corpus,
        system_prompt=system_prompt,
    )

    # Persist this run's predictions so they can be evaluated later.
    output_file = f'predictions_k{k}.json'
    with open(output_file, 'w') as f:
        f.write(json.dumps(predictions_k))
    print(f"saved predictions to  {output_file}")

Loading questions...
Loading corpus...
Starting Inference for k=1...
Loading retrieval results from retrieval_k1.json...


Running Inference: 100%|██████████| 1200/1200 [03:33<00:00,  5.62it/s]


saved predictions to  predictions_k1.json
Loading retrieval results from retrieval_k3.json...


Running Inference: 100%|██████████| 1200/1200 [04:13<00:00,  4.74it/s]


saved predictions to  predictions_k3.json
Loading retrieval results from retrieval_k5.json...


Running Inference: 100%|██████████| 1200/1200 [04:59<00:00,  4.01it/s]

saved predictions to  predictions_k5.json





### Run inference for oracle contexts

In [14]:
def run_inference_with_oracle_contexts(llm_engine, dev_path, system_prompt, max_samples=None):
    """
    Runs inference on questions using oracle contexts from dev.json.
    Oracle contexts are the ground-truth supporting documents provided in the dataset.

    Args:
        llm_engine: Engine exposing get_llama_completion(system_prompt, user_prompt).
        dev_path: Path to a JSON list of question objects with '_id',
            'question' and 'context' keys.
        system_prompt: System prompt sent with every request.
        max_samples: Maximum number of questions to process, or None for no
            limit beyond the fixed 1200-question cap. 0 processes nothing.

    Returns:
        Dict mapping question id -> the engine's answer, or the string
        "ERROR" when the engine call raised.
    """
    print("Loading dev.json for oracle contexts...")
    # Explicit UTF-8 so that reading does not depend on the platform's
    # default text encoding.
    with open(dev_path, 'r', encoding='utf-8') as f:
        questions = json.load(f)

    predictions = {}
    processed = 0

    # Only the first 1200 questions are used.
    # NOTE(review): presumably chosen to match the size of the retrieval
    # runs - confirm before changing.
    for q in tqdm(questions[:1200], desc="Running Inference (Oracle)"):
        # Compare against None (not truthiness) so that max_samples=0 means
        # "process nothing" instead of silently disabling the limit.
        if max_samples is not None and processed >= max_samples:
            break

        q_id = q['_id']
        question_text = q['question']

        # Build context string from oracle contexts
        # Each context entry is [title, [list of paragraph strings]]
        context_str = "".join(
            f"[Title]: {ctx[0]}\n[Text]: {' '.join(ctx[1])}\n\n"
            for ctx in q['context']
        )

        user_prompt = f"[Context]:\n{context_str}\n\n[Question]: {question_text}"

        # Best-effort by design: a failed request is logged and recorded as
        # "ERROR" so one bad call does not abort the whole run.
        try:
            answer = llm_engine.get_llama_completion(system_prompt, user_prompt)
        except Exception as e:
            print(f"Error processing {q_id}: {e}")
            predictions[q_id] = "ERROR"
        else:
            predictions[q_id] = answer

        processed += 1

    return predictions

# Oracle experiment: answer each question from its ground-truth contexts in
# dev.json. No max_samples is passed.
print("Starting Inference with Oracle Contexts...")
predictions_oracle = run_inference_with_oracle_contexts(
    llm_engine=llm_instance,
    dev_path='data/dev.json',
    system_prompt=system_prompt,
)

# Persist the predictions so they can be evaluated later.
output_file = 'predictions_oracle.json'
with open(output_file, 'w') as f:
    f.write(json.dumps(predictions_oracle))
print(f"Saved predictions to {output_file}")

Starting Inference with Oracle Contexts...
Loading dev.json for oracle contexts...


Running Inference (Oracle): 100%|██████████| 1200/1200 [06:29<00:00,  3.08it/s]


Saved predictions to predictions_oracle.json


### Evaluate experiments with ExactMatch

In [15]:
from dexter.utils.metrics.ExactMatch import ExactMatch

def evaluate_predictions(predictions_path, dev_path):
    """
    Evaluates predictions against ground truth answers using ExactMatch.

    Args:
        predictions_path: Path to a JSON object mapping question id ->
            predicted answer.
        dev_path: Path to a JSON list of question objects with '_id' and
            'answer' keys.

    Returns:
        A tuple (accuracy, num_scored). accuracy is the mean of the
        per-question ExactMatch scores, or 0 when nothing was scored;
        num_scored is the number of predictions whose id was found in the
        ground truth. Predictions with unknown ids are ignored.
    """
    # Explicit UTF-8 on both reads so that non-ASCII answers are decoded the
    # same way regardless of the platform's default encoding.
    with open(predictions_path, 'r', encoding='utf-8') as f:
        predictions = json.load(f)

    with open(dev_path, 'r', encoding='utf-8') as f:
        questions = json.load(f)

    # Question id -> gold answer
    ground_truth = {q['_id']: q['answer'] for q in questions}

    metric = ExactMatch()
    # NOTE(review): scores are summed below, so metric.evaluate is assumed
    # to return a numeric/bool value - confirm against ExactMatch.
    scores = [
        metric.evaluate(pred_answer, ground_truth[q_id])
        for q_id, pred_answer in predictions.items()
        if q_id in ground_truth
    ]

    accuracy = sum(scores) / len(scores) if scores else 0
    return accuracy, len(scores)

# Score every saved prediction file against the dev-set answers.
dev_path = 'data/dev.json'
# (display name, predictions file) for each experiment, in report order.
prediction_files = [
    ('Top-1 Retrieved', 'predictions_k1.json'),
    ('Top-3 Retrieved', 'predictions_k3.json'),
    ('Top-5 Retrieved', 'predictions_k5.json'),
    ('Oracle Contexts', 'predictions_oracle.json'),
]

print("=" * 50)
print("Exact Match Evaluation Results")
print("=" * 50)

results = []
for name, pred_file in prediction_files:
    try:
        accuracy, num_samples = evaluate_predictions(pred_file, dev_path)
    except FileNotFoundError:
        # An experiment that has not been run yet is reported and left out
        # of the summary table.
        print(f"{name}: File not found ({pred_file})")
        continue

    results.append({
        'Experiment': name,
        'Exact Match': f'{accuracy:.4f}',
        'Samples': num_samples,
    })
    print(f"{name}: {accuracy:.4f} ({num_samples} samples)")

print("=" * 50)

# Summary table; the bare expression renders it as the cell output.
results_df = pd.DataFrame(results)
results_df

Exact Match Evaluation Results
Top-1 Retrieved: 0.0400 (1200 samples)
Top-3 Retrieved: 0.0692 (1200 samples)
Top-5 Retrieved: 0.0742 (1200 samples)
Oracle Contexts: 0.2675 (1200 samples)


Unnamed: 0,Experiment,Exact Match,Samples
0,Top-1 Retrieved,0.04,1200
1,Top-3 Retrieved,0.0692,1200
2,Top-5 Retrieved,0.0742,1200
3,Oracle Contexts,0.2675,1200
