In [20]:
from transformers import AutoTokenizer, AutoModel
import torch
import faiss
import numpy as np
from collections import defaultdict
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoModelForSeq2SeqLM

In [2]:
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")
model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
  

In [None]:
dataset = load_dataset("pubmed_qa", "pqa_artificial")["train"].select(range(20000))

In [11]:
# Mean pooling function
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]  # (batch_size, seq_len, hidden_size)
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

In [32]:
# Encode function for passages and questions
def encode_texts(texts, batch_size=8):
    all_embeddings = []

    for i in range(0, len(texts), batch_size):
        batch_texts = texts[i:i+batch_size]
        encoded_input = tokenizer(batch_texts, padding=True, truncation=True, return_tensors='pt', max_length=512).to(device)

        with torch.no_grad():
            with torch.amp.autocast(device_type=device.type, dtype=torch.float16):  # mixed precision for lower memory use
                model_output = model(**encoded_input)
                embeddings = mean_pooling(model_output, encoded_input['attention_mask'])

        embeddings = embeddings.cpu().numpy()
        embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
        all_embeddings.append(embeddings)

        torch.cuda.empty_cache()  # clear memory after each batch

    return np.vstack(all_embeddings)

In [13]:
# Extract unique passages
all_passages = list({p for item in dataset for p in item["context"]["contexts"]})

In [27]:
len(all_passages)

61992

In [14]:
# Encode and normalize embeddings
baseline_passage_embeddings = encode_texts(all_passages)

  with torch.cuda.amp.autocast():  # mixed precision for lower memory use


In [15]:
# Build FAISS index using Inner Product (dot product on normalized vectors = cosine similarity)
index = faiss.IndexFlatIP(baseline_passage_embeddings.shape[1])
index.add(baseline_passage_embeddings)

In [34]:
faiss.write_index(index, "bert_cosine_index.faiss")

In [35]:
import pickle

with open("passages.pkl", "wb") as f:
    pickle.dump(all_passages, f)

In [16]:
# Map each question to its gold contexts
gold_contexts_by_question = defaultdict(list)
for item in dataset:
    q = item["question"]
    gold_contexts_by_question[q] = item["context"]["contexts"]

In [17]:
# Retrieval evaluation
top_k = 5
correct = 0
total = 0

In [33]:
from tqdm import tqdm

correct = 0
total = 0

for question, gold_contexts in tqdm(gold_contexts_by_question.items(), desc="Evaluating Retriever"):
    if not gold_contexts:
        continue

    query_embedding = encode_texts([question])  # already normalized
    D, I = index.search(query_embedding, top_k)
    retrieved = [all_passages[i] for i in I[0]]

    match_count = sum(any(retr.strip() == gold.strip() for gold in gold_contexts) for retr in retrieved)

    if match_count > 0:
        correct += 1
    total += 1

recall_at_k = correct / total if total > 0 else 0.0
print(f"[BERT-Base Cosine Retriever] Recall@{top_k}: {recall_at_k:.3f}")


Evaluating Retriever: 100%|██████████| 19998/19998 [31:28<00:00, 10.59it/s]

[BERT-Base Cosine Retriever] Recall@5: 0.492





In [21]:
gen_tokenizer = AutoTokenizer.from_pretrained("t5-base")
gen_model = AutoModelForSeq2SeqLM.from_pretrained("t5-base").to(device)
gen_model.eval()

model.safetensors:   0%|          | 0.00/892M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

T5ForConditionalGeneration(
  (shared): Embedding(32128, 768)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 768)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=768, out_features=768, bias=False)
              (k): Linear(in_features=768, out_features=768, bias=False)
              (v): Linear(in_features=768, out_features=768, bias=False)
              (o): Linear(in_features=768, out_features=768, bias=False)
              (relative_attention_bias): Embedding(32, 12)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseActDense(
              (wi): Linear(in_features=768, out_features=3072, bias=False)
              (wo): Linear(in_features=3072, out_features=768, bias=False)
              (dropout): Dro

In [38]:
def generate_answer(question, contexts, max_input_len=512, max_output_len=32):
    # Concatenate contexts and prepend the question
    input_text = "question: " + question + " context: " + " ".join(contexts)
    inputs = gen_tokenizer(
        input_text, return_tensors="pt", truncation=True,
        padding=True, max_length=max_input_len
    ).to(device)

    with torch.no_grad():
        with torch.amp.autocast(device_type=device.type, dtype=torch.float16):
            output = gen_model.generate(
                **inputs, max_length=max_output_len,
                num_beams=4, early_stopping=True
            )
    return gen_tokenizer.decode(output[0], skip_special_tokens=True)

In [None]:
import json
from tqdm import tqdm

generated_answers = []

for item in tqdm(dataset, desc="Generating RAG Answers"):
    question = item["question"]
    gold_contexts = item["context"]["contexts"]

    if not gold_contexts:
        continue

    # Retrieve top-k relevant contexts
    query_embedding = encode_texts([question])
    D, I = index.search(query_embedding, top_k)
    retrieved = [all_passages[i] for i in I[0]]

    # Generate answer
    pred_answer = generate_answer(question, retrieved).strip()

    # Store the question and generated answer
    generated_answers.append({
        "question": question,
        "generated_answer": pred_answer
    })

# Save to JSON file
with open("generated_rag_answers.json", "w", encoding="utf-8") as f:
    json.dump(generated_answers, f, indent=2, ensure_ascii=False)

print(f"Saved {len(generated_answers)} question-answer pairs to generated_rag_answers.json")

Generating RAG Answers:  24%|██▍       | 4843/20000 [39:22<2:11:04,  1.93it/s]