# Dense Retrieval on HotpotQA

## Installation

In [1]:
!pip install sentence-transformers faiss-cpu accelerate datasets tqdm evaluate 



## Imports

In [2]:
from sentence_transformers import SentenceTransformer 
from datasets import load_dataset
import faiss 
import tqdm 
import torch 
import pickle 
import os 
import numpy as np
import json
from itertools import chain
import re, nltk, evaluate 
nltk.download("punkt")
nltk.download("punkt_tab")

[nltk_data] Downloading package punkt to
[nltk_data]     /Users/ajitkumarsingh/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt_tab to
[nltk_data]     /Users/ajitkumarsingh/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


True

## Load Dataset

In [3]:
hp = load_dataset("hotpot_qa", "distractor", split="train[:10000]")

question = hp["question"][:100]
gold = hp["answer"][:100]


## Encode Corpus

In [4]:
model_id = "BAAI/bge-base-en-v1.5"
encoder  = SentenceTransformer(model_id, device='cpu')


In [5]:
corpus_raw = hp["context"]

corpus = []
for context in corpus_raw:
    corpus.extend(list(chain.from_iterable(context['sentences'])))

In [18]:
batch = 256 

embeddings = []
count = 0
for i in tqdm.trange(0, len(corpus), batch):

    chunk = encoder.encode(
        corpus[i:i+batch], 
        normalize_embeddings=True,
        batch_size = batch,
        show_progress_bar=True
    )
    embeddings.append(chunk.astype("float32"))
    if count >=15:
        break

  0%|          | 0/1595 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 1/1595 [02:27<65:22:00, 147.63s/it]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 2/1595 [05:32<75:04:23, 169.66s/it]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 3/1595 [07:37<65:58:40, 149.20s/it]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 4/1595 [09:26<58:49:59, 133.12s/it]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 5/1595 [11:52<60:52:34, 137.83s/it]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 6/1595 [13:38<56:05:39, 127.09s/it]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 7/1595 [15:23<52:52:41, 119.88s/it]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

  1%|          | 8/1595 [17:43<55:44:26, 126.44s/it]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

  1%|          | 9/1595 [20:04<57:42:03, 130.97s/it]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

  1%|          | 10/1595 [22:17<57:52:27, 131.45s/it]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

  1%|          | 11/1595 [23:50<52:43:32, 119.83s/it]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

  1%|          | 12/1595 [25:32<50:17:07, 114.36s/it]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

  1%|          | 13/1595 [27:46<52:51:30, 120.28s/it]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

  1%|          | 14/1595 [29:21<49:29:39, 112.70s/it]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

  1%|          | 15/1595 [31:36<52:18:30, 119.18s/it]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

  1%|          | 16/1595 [33:34<52:07:09, 118.83s/it]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

  1%|          | 17/1595 [52:09<183:25:44, 418.47s/it]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

  1%|          | 18/1595 [55:43<156:22:04, 356.96s/it]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

  1%|          | 19/1595 [58:24<130:32:41, 298.20s/it]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

  1%|▏         | 20/1595 [1:00:44<109:41:09, 250.71s/it]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

  1%|▏         | 21/1595 [1:06:02<118:30:00, 271.03s/it]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

  1%|▏         | 22/1595 [1:07:30<94:25:58, 216.12s/it] 

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

  1%|▏         | 23/1595 [1:09:05<78:25:42, 179.61s/it]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

  2%|▏         | 24/1595 [1:10:40<67:14:57, 154.10s/it]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

  2%|▏         | 25/1595 [1:14:26<76:40:05, 175.80s/it]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

  2%|▏         | 26/1595 [1:16:33<70:17:13, 161.27s/it]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

In [20]:
doc_embeddings = np.vstack(embeddings)
doc_embeddings.shape

(6912, 768)

## Building FAISS Index

**Trial:**
- `Exact`
- `HNSW`

In [None]:
embedding_dim = doc_embeddings.shape[1]

#normalize embeddings
faiss.normalize_L2(doc_embeddings)

### Exact

In [22]:
index_flat = faiss.IndexFlatIP(embedding_dim)
index_flat.add(doc_embeddings)

### Approximately HNSW

In [23]:
index_hnsw = faiss.IndexHNSWFlat(embedding_dim, 32)
index_hnsw.hnsw.efConstruction = 200
index_hnsw.add(doc_embeddings)
index_hnsw.hnsw.efSearch= 64

: 

: 

## Encode Documents

In [None]:
def embed_q(xs):

    return encoder.encode(xs, normalize_embeddings=True, 
                          batch_size=64, show_progress_bar=False)

q_vecs = embed_q(question)


## Evaluate - Recall@k

In [None]:
def recall(index, k=5):

    D, I = index.search(q_vecs, k) #I : (n_q, k) doc indices
    hit = 0
    for candidate, answer in zip(I, gold):
        if any(answer.lower() in corpus[i].lower() for i in candidate):
            hit += 1
        return hit*100/len(gold)

In [None]:
for k in [1, 3, 5]:
    print(f"HNSW Recall@{k}: {recall(index_hnsw, k):.1f}")

In [None]:
for k in [1, 3, 5]:
    print(f"Exact Recall@{k}: {recall(index_flat, k):.1f}")

## Plug into RAG Generation

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

model_name = "meta-llama/Meta-Llama-3-8B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_name)
llm = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)

generator = pipeline("text-generation", model=llm, tokenizer=tokenizer, temperature = 0.1, max_new_tokens=128)

In [None]:
def answer_dense(q, k=3):

    qv = encoder.encode([q], normalize_embeddings=True)
    D, I = index_hnsw.search(qv, k)
    context = "\n".join(corpus[i] for i in I[0])
    prompt = ("You are a question-answering system. Use the context. \n\n"
        "Context: {context}\n\n"
        "Question: {q}\n\n"
        "Answer briefly:"
    )

    return generator(prompt)[0]["generated_text"].split("Answer briefly:")[-1].strip()

## Model Evaluation

In [None]:
predictions = [answer_dense(q, k=3) for q in tqdm.tqdm(question[:100])]

predictions_formatted = []
references_formatted = []

for i, (pred, ref) in enumerate(zip(predictions, gold[:100])):
    predictions_formatted.append({"id": str(i), "prediction_text": pred})
    references_formatted.append({"id": str(i), "answers": {"text": [ref], "answer_start": [0]}})
squad = evaluate.load("squad")
results = squad.compute(predictions=predictions_formatted, references=references_formatted)
print(json.dumps(results, indent=2))

## Tunable Hyper-parameters

- Model Size: `bge-large-en` (1024 d) maybe improve recall at cost of more RAM
- Untied encoders: Train a doc-encoder specialized for long passages, query-encoder for short questions (e.g. DPR)
- Pooling: Mean Pool
- ANN Knobs: raising `efSearch`(HNSW) from 64 -> 128 bumps recall but increase latency 
- Float16 vs float32: Saves 50% RAM, negligible recall drop