In [1]:
from datasets import load_dataset
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
from sklearn.metrics import f1_score
import numpy as np
import faiss
import torch
import pandas as pd
from tqdm import tqdm

import json
import os


## Load Dataset

In [6]:
from datasets import load_dataset
import pandas as pd

# 1. QA 데이터 (question, answer) → nq_open
nq_dataset = load_dataset("nq_open", split="train")

qa_data = []
for item in nq_dataset:
    question = item["question"]
    # answer는 정답 후보 리스트 중 첫 번째 사용
    answer = item["answer"][0] if item["answer"] else ""
    qa_data.append({
        "question": question,
        "answer": answer
    })

qa_df = pd.DataFrame(qa_data)
print("✅ Loaded QA pairs:", len(qa_df))
print(qa_df.head())

# 2. Corpus 데이터 (doc_id, text) → DPR Wikipedia passages
corpus = load_dataset("wiki_dpr", "psgs_w100.nq.compressed", split="train")

corpus_df = pd.DataFrame({
    "doc_id": corpus["id"],
    "text": corpus["text"]
})
print("✅ Loaded corpus passages:", len(corpus_df))
print(corpus_df.head())

✅ Loaded QA pairs: 87925
                                            question                answer
0           where did they film hot tub time machine  Fernie Alpine Resort
1   who has the right of way in international waters        Neither vessel
2            who does annie work for attack on titan                Marley
3  when was the immigration reform and control ac...      November 6, 1986
4              when was puerto rico added to the usa                  1950


Downloading data:   0%|          | 0/157 [00:00<?, ?files/s]

data/psgs_w100/nq/train-00000-of-00157.p(…):   0%|          | 0.00/545M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00001-of-00157.p(…):   0%|          | 0.00/546M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00002-of-00157.p(…):   0%|          | 0.00/546M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00003-of-00157.p(…):   0%|          | 0.00/546M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00004-of-00157.p(…):   0%|          | 0.00/546M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00005-of-00157.p(…):   0%|          | 0.00/545M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00006-of-00157.p(…):   0%|          | 0.00/544M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00007-of-00157.p(…):   0%|          | 0.00/537M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00008-of-00157.p(…):   0%|          | 0.00/530M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00009-of-00157.p(…):   0%|          | 0.00/538M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00010-of-00157.p(…):   0%|          | 0.00/546M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00011-of-00157.p(…):   0%|          | 0.00/545M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00012-of-00157.p(…):   0%|          | 0.00/545M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00013-of-00157.p(…):   0%|          | 0.00/545M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00014-of-00157.p(…):   0%|          | 0.00/544M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00015-of-00157.p(…):   0%|          | 0.00/545M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00016-of-00157.p(…):   0%|          | 0.00/545M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00017-of-00157.p(…):   0%|          | 0.00/545M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00018-of-00157.p(…):   0%|          | 0.00/545M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00019-of-00157.p(…):   0%|          | 0.00/544M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00020-of-00157.p(…):   0%|          | 0.00/544M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00021-of-00157.p(…):   0%|          | 0.00/545M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00022-of-00157.p(…):   0%|          | 0.00/545M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00023-of-00157.p(…):   0%|          | 0.00/545M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00024-of-00157.p(…):   0%|          | 0.00/545M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00025-of-00157.p(…):   0%|          | 0.00/545M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00026-of-00157.p(…):   0%|          | 0.00/544M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00027-of-00157.p(…):   0%|          | 0.00/544M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00028-of-00157.p(…):   0%|          | 0.00/544M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00029-of-00157.p(…):   0%|          | 0.00/544M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00030-of-00157.p(…):   0%|          | 0.00/545M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00031-of-00157.p(…):   0%|          | 0.00/545M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00032-of-00157.p(…):   0%|          | 0.00/545M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00033-of-00157.p(…):   0%|          | 0.00/544M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00034-of-00157.p(…):   0%|          | 0.00/544M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00035-of-00157.p(…):   0%|          | 0.00/544M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00036-of-00157.p(…):   0%|          | 0.00/544M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00037-of-00157.p(…):   0%|          | 0.00/544M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00038-of-00157.p(…):   0%|          | 0.00/545M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00039-of-00157.p(…):   0%|          | 0.00/544M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00040-of-00157.p(…):   0%|          | 0.00/544M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00041-of-00157.p(…):   0%|          | 0.00/544M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00042-of-00157.p(…):   0%|          | 0.00/544M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00043-of-00157.p(…):   0%|          | 0.00/544M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00044-of-00157.p(…):   0%|          | 0.00/544M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00045-of-00157.p(…):   0%|          | 0.00/544M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00046-of-00157.p(…):   0%|          | 0.00/544M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00047-of-00157.p(…):   0%|          | 0.00/544M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00048-of-00157.p(…):   0%|          | 0.00/544M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00049-of-00157.p(…):   0%|          | 0.00/544M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00050-of-00157.p(…):   0%|          | 0.00/544M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00051-of-00157.p(…):   0%|          | 0.00/544M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00052-of-00157.p(…):   0%|          | 0.00/544M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00053-of-00157.p(…):   0%|          | 0.00/544M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00054-of-00157.p(…):   0%|          | 0.00/544M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00055-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00056-of-00157.p(…):   0%|          | 0.00/544M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00057-of-00157.p(…):   0%|          | 0.00/544M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00058-of-00157.p(…):   0%|          | 0.00/544M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00059-of-00157.p(…):   0%|          | 0.00/544M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00060-of-00157.p(…):   0%|          | 0.00/544M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00061-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00062-of-00157.p(…):   0%|          | 0.00/542M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00063-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00064-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00065-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00066-of-00157.p(…):   0%|          | 0.00/544M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00067-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00068-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00069-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00070-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00071-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00072-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00073-of-00157.p(…):   0%|          | 0.00/544M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00074-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00075-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00076-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00077-of-00157.p(…):   0%|          | 0.00/542M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00078-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00079-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00080-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00081-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00082-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00083-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00084-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00085-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00086-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00087-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00088-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00089-of-00157.p(…):   0%|          | 0.00/542M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00090-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00091-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00092-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00093-of-00157.p(…):   0%|          | 0.00/544M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00094-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00095-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00096-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00097-of-00157.p(…):   0%|          | 0.00/542M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00098-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00099-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00100-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00101-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00102-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00103-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00104-of-00157.p(…):   0%|          | 0.00/542M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00105-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00106-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00107-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00108-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00109-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00110-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00111-of-00157.p(…):   0%|          | 0.00/542M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00112-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00113-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00114-of-00157.p(…):   0%|          | 0.00/544M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00115-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00116-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00117-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00118-of-00157.p(…):   0%|          | 0.00/542M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00119-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00120-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00121-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00122-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00123-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00124-of-00157.p(…):   0%|          | 0.00/542M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00125-of-00157.p(…):   0%|          | 0.00/542M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00126-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00127-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00128-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00129-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00130-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00131-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00132-of-00157.p(…):   0%|          | 0.00/542M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00133-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00134-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00135-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00136-of-00157.p(…):   0%|          | 0.00/542M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00137-of-00157.p(…):   0%|          | 0.00/542M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00138-of-00157.p(…):   0%|          | 0.00/542M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00139-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00140-of-00157.p(…):   0%|          | 0.00/542M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00141-of-00157.p(…):   0%|          | 0.00/542M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00142-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00143-of-00157.p(…):   0%|          | 0.00/542M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00144-of-00157.p(…):   0%|          | 0.00/542M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00145-of-00157.p(…):   0%|          | 0.00/542M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00146-of-00157.p(…):   0%|          | 0.00/542M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00147-of-00157.p(…):   0%|          | 0.00/542M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00148-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00149-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00150-of-00157.p(…):   0%|          | 0.00/543M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00151-of-00157.p(…):   0%|          | 0.00/542M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00152-of-00157.p(…):   0%|          | 0.00/542M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00153-of-00157.p(…):   0%|          | 0.00/542M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00154-of-00157.p(…):   0%|          | 0.00/542M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00155-of-00157.p(…):   0%|          | 0.00/542M [00:00<?, ?B/s]

data/psgs_w100/nq/train-00156-of-00157.p(…):   0%|          | 0.00/542M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/21015300 [00:00<?, ? examples/s]

Loading dataset shards:   0%|          | 0/161 [00:00<?, ?it/s]

  0%|          | 0/21016 [00:00<?, ?it/s]

FileNotFoundError: [Errno 2] No such file or directory: '/home/aix7101/.cache/huggingface/datasets/wiki_dpr/psgs_w100.nq.compressed/0.0.0/66fd9b80f51375c02cd9010050e781ed3e8f759e868f690c31b2686a7a0eeb5c/psgs_w100.nq.IVF4096_HNSW128_PQ128-IP-train.faiss'

In [11]:
from datasets import load_dataset

# stream=True를 추가하면 메모리에서만 로드하며 인덱스를 시도하지 않음
corpus = load_dataset(
    "wiki_dpr",
    "psgs_w100.nq.compressed",
    split="train",
    streaming=True  # 핵심
)

# passage 뽑기 예시
first_passage = next(iter(corpus))
print(first_passage)

{'id': '1', 'text': 'Aaron Aaron ( or ; "Ahärôn") is a prophet, high priest, and the brother of Moses in the Abrahamic religions. Knowledge of Aaron, along with his brother Moses, comes exclusively from religious texts, such as the Bible and Quran. The Hebrew Bible relates that, unlike Moses, who grew up in the Egyptian royal court, Aaron and his elder sister Miriam remained with their kinsmen in the eastern border-land of Egypt (Goshen). When Moses first confronted the Egyptian king about the Israelites, Aaron served as his brother\'s spokesman ("prophet") to the Pharaoh. Part of the Law (Torah) that Moses received from', 'title': 'Aaron', 'embeddings': [0.013342111371457577, 0.582173764705658, -0.31309744715690613, -0.6991612911224365, -0.5583199858665466, 0.5187504887580872, 0.7152731418609619, -0.08567414432764053, -0.24895088374614716, -0.4495537281036377, -0.643000066280365, 0.11746902763843536, -0.22123917937278748, 0.30100083351135254, 0.08902842551469803, 0.018262844532728195,

In [15]:
import json

# IterableDataset을 리스트로 변환 (주의: 메모리에 다 올라감)
corpus_list = list(corpus)

# 필요 없는 'embeddings' 필드 제거
for item in corpus_list:
    item.pop("embeddings", None)  # 없을 수도 있으므로 안전하게 pop

# JSON으로 저장
with open("/mnt/aix7101/jeong/aix_project/dataset/nq_rag_corpus.json", "w", encoding="utf-8") as f:
    json.dump(corpus_list, f, ensure_ascii=False, indent=2)

print("✅ embeddings 제거 후 JSON 저장 완료")

In [None]:
# # 1. corpus 저장 (Retrieval 문서들)
# corpus_records = corpus_df.to_dict(orient="records")
# with open("dataset/nq_rag_corpus.json", "w", encoding="utf-8") as f:
#     json.dump(corpus_records, f, ensure_ascii=False, indent=2)

# # # 2. QA 쌍 저장 (질문-정답-문서 매핑)
# # qa_records = qa_df.to_dict(orient="records")
# # with open("dataset/nq_rag_qa_pairs.json", "w", encoding="utf-8") as f:
# #     json.dump(qa_records, f, ensure_ascii=False, indent=2)

# print("✅ JSON 파일 저장 완료:")

✅ JSON 파일 저장 완료:


In [None]:
import json
import pandas as pd

# 1. Corpus 로드
with open("/mnt/aix7101/jeong/aix_project/dataset/nq_rag_corpus.json", "r", encoding="utf-8") as f:
    corpus_records = json.load(f)
corpus_df = pd.DataFrame(corpus_records)

# 2. QA 쌍 로드
with open("dataset/nq_rag_qa_pairs.json", "r", encoding="utf-8") as f:
    qa_records = json.load(f)
qa_pairs = pd.DataFrame(qa_records)

print("📂 corpus_df shape:", corpus_df.shape)
print("📂 qa_pairs shape:", qa_pairs.shape)

## Save embedded vector

In [None]:

# 3. Load DPR model and tokenizer (use multi-qa)
ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-multiset-base")
ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-multiset-base")

q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-multiset-base")
q_encoder = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-multiset-base")


The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'DPRQuestionEncoderTokenizer'. 
The class this function is called from is 'DPRContextEncoderTokenizer'.
Some weights of the model checkpoint at facebook/dpr-ctx_encoder-multiset-base were not used when initializing DPRContextEncoder: ['ctx_encoder.bert_model.pooler.dense.bias', 'ctx_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRContextEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRContextEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification mode

In [None]:
ctx_encoder.eval()
q_encoder.eval()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ctx_encoder.to(device)
q_encoder.to(device)


DPRQuestionEncoder(
  (question_encoder): DPREncoder(
    (bert_model): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(30522, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0-11): 12 x BertLayer(
            (attention): BertAttention(
              (self): BertSdpaSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_feature

In [None]:
batch_size = 32
ctx_embeddings = []

for i in tqdm(range(0, len(corpus_df), batch_size), desc="Encoding contexts"):
    batch_texts = corpus_df["text"].iloc[i:i+batch_size].tolist()
    batch_texts = [str(t).strip() for t in batch_texts]

    inputs = ctx_tokenizer(batch_texts, return_tensors="pt", truncation=True, padding=True, max_length=512)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        output = ctx_encoder(**inputs)
        emb_batch = output.pooler_output.cpu().numpy()  # or output.last_hidden_state[:, 0]
        ctx_embeddings.append(emb_batch)

ctx_embeddings = np.vstack(ctx_embeddings)

Encoding contexts: 100%|██████████| 591/591 [01:54<00:00,  5.14it/s]


In [None]:
batch_size = 32  # 필요에 따라 조정 가능
q_embeddings = []

questions = qa_pairs["question"].tolist()

for i in tqdm(range(0, len(questions), batch_size), desc="Encoding questions"):
    batch_questions = questions[i:i+batch_size]
    batch_questions = [str(q).strip() for q in batch_questions]

    inputs = q_tokenizer(batch_questions, return_tensors="pt", truncation=True, padding=True, max_length=512)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        output = q_encoder(**inputs)
        emb_batch = output.pooler_output.cpu().numpy()  # or output.last_hidden_state[:, 0]
        q_embeddings.append(emb_batch)

q_embeddings = np.vstack(q_embeddings)

Encoding questions: 100%|██████████| 2738/2738 [00:54<00:00, 50.16it/s]


In [None]:
# 4. 저장
embedding_dir = "/mnt/aix7101/jeong/aix_project"
if not os.path.exists(embedding_dir):
    os.makedirs(embedding_dir)
    print(f"📁 Created directory: {embedding_dir}")

ctx_path = os.path.join(embedding_dir, "nq_dpr_ctx_embeddings_multiqa.npy")
q_path = os.path.join(embedding_dir, "nq_dpr_q_embeddings_multiqa.npy")

np.save(ctx_path, ctx_embeddings)
np.save(q_path, q_embeddings)

print(f"✅ Context embeddings saved to: {ctx_path}")
print(f"✅ Question embeddings saved to: {q_path}")

✅ Context embeddings saved to: /mnt/aix7101/jeong/aix_project/dpr_ctx_embeddings_multiqa.npy
✅ Question embeddings saved to: /mnt/aix7101/jeong/aix_project/dpr_q_embeddings_multiqa.npy


In [None]:
from nltk.tokenize import sent_tokenize
import numpy as np
from tqdm import tqdm

batch_size = 16  # GPU 상황에 따라 조정
ctx_sentence_embeddings = []

for doc in tqdm(corpus_df["text"], desc="Encoding multi-sentence contexts"):
    # 1. 문서 내 문장 분리
    sentences = sent_tokenize(doc)
    doc_embeddings = []

    # 2. 문장들을 배치로 처리
    for i in range(0, len(sentences), batch_size):
        batch_sents = sentences[i:i+batch_size]
        inputs = ctx_tokenizer(
            batch_sents,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=128
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            output = ctx_encoder(**inputs)
            emb_batch = output.pooler_output.cpu().numpy()  # or last_hidden_state[:, 0]
            doc_embeddings.append(emb_batch)

    # 3. 문서 하나에 대한 (문장 수, dim) 배열 생성
    doc_embeddings = np.vstack(doc_embeddings)
    ctx_sentence_embeddings.append(doc_embeddings)
    
# 4. 문서별 문장 수가 달라 3D 배열로 만들고 싶을 경우
max_len = max(e.shape[0] for e in ctx_sentence_embeddings)
dim = ctx_sentence_embeddings[0].shape[1]

padded_embeddings = np.zeros((len(ctx_sentence_embeddings), max_len, dim))
for i, emb in enumerate(ctx_sentence_embeddings):
    padded_embeddings[i, :emb.shape[0], :] = emb

Encoding multi-sentence contexts: 100%|██████████| 18891/18891 [03:39<00:00, 85.90it/s] 


In [None]:
# 4. 저장
embedding_dir = "/mnt/aix7101/jeong/aix_project"
if not os.path.exists(embedding_dir):
    os.makedirs(embedding_dir)
    print(f"📁 Created directory: {embedding_dir}")

sentence_ctx_path = os.path.join(embedding_dir, "dpr_m_ctx_embeddings_multiqa.npy")
ctx_sentence_embeddings = np.array(ctx_sentence_embeddings, dtype=object)
np.save(sentence_ctx_path, ctx_sentence_embeddings, allow_pickle=True)

print(f"✅ Context Sentence embeddings saved to: {sentence_ctx_path}")

✅ Context Sentence embeddings saved to: /mnt/aix7101/jeong/aix_project/dpr_m_ctx_embeddings_multiqa.npy


## BM25

In [None]:
from rank_bm25 import BM25Okapi
from tqdm import tqdm
import pandas as pd
import numpy as np
from transformers import BartForConditionalGeneration, BartTokenizer
from sklearn.metrics import accuracy_score, f1_score

def evaluate_generator_with_bm25(
    qa_pairs: pd.DataFrame,
    corpus_df: pd.DataFrame,
    k: int = 5,
    model_name: str = "facebook/bart-base",
    max_input_length: int = 1024,
    max_output_length: int = 50,
) -> None:
    # 1. Tokenize corpus
    tokenized_corpus = [doc.split() for doc in corpus_df["text"]]
    bm25 = BM25Okapi(tokenized_corpus)

    # 2. Load generator
    tokenizer = BartTokenizer.from_pretrained(model_name)
    model = BartForConditionalGeneration.from_pretrained(model_name)

    predictions = []
    references = []

    for _, row in tqdm(qa_pairs.iterrows(), total=len(qa_pairs), desc="Evaluating Generator"):
        question = row["question"]
        true_answer = row["answer"]

        # 3. BM25 top-k retrieval
        tokenized_query = question.split()
        scores = bm25.get_scores(tokenized_query)
        topk_indices = np.argsort(scores)[::-1][:k]
        topk_texts = corpus_df.iloc[topk_indices]["text"].tolist()

        # 4. Concatenate retrieved passages as context
        context = " ".join(topk_texts)
        input_text = f"question: {question} context: {context}"
        inputs = tokenizer(
            input_text,
            return_tensors="pt",
            max_length=max_input_length,
            truncation=True,
        )

        # 5. Generate answer
        output_ids = model.generate(
            inputs["input_ids"],
            max_length=max_output_length,
            num_beams=4,
            early_stopping=True
        )
        generated_answer = tokenizer.decode(output_ids[0], skip_special_tokens=True)

        predictions.append(generated_answer.strip())
        references.append(true_answer.strip())

    # 6. Evaluation
    acc = accuracy_score(references, predictions)
    f1 = f1_score(references, predictions, average="macro")  # or "weighted"

    print(f"✅ Accuracy: {acc:.4f}")
    print(f"✅ F1 Score: {f1:.4f}")
    return acc, f1

In [None]:
acc, f1 = evaluate_generator_with_bm25(qa_pairs, corpus_df, k=5)

Evaluating BM25 Recall@K:   0%|          | 123/87599 [00:05<1:02:03, 23.49it/s]


KeyboardInterrupt: 

## DPR

In [None]:
from transformers import pipeline
from tqdm import tqdm
from sklearn.metrics import f1_score, accuracy_score
import numpy as np

def evaluate_generator_with_dpr(qa_pairs, corpus_df, ctx_emb_path, q_emb_path, k=5, model_name="facebook/bart-base"):
    """
    저장된 DPR 임베딩을 기반으로 top-k 문서를 찾아 Generator로 답변을 생성하고, 정답과 비교하여 Accuracy와 F1을 평가합니다.

    Args:
        qa_pairs (pd.DataFrame): 질문-정답 쌍 (columns: ['question', 'answer', 'doc_id'])
        corpus_df (pd.DataFrame): 문서 집합 (columns: ['doc_id', 'text'])
        ctx_emb_path (str): 문서 임베딩이 저장된 .npy 경로
        q_emb_path (str): 질문 임베딩이 저장된 .npy 경로
        k (int): top-k 문서 수
        model_name (str): Hugging Face Generator 모델 이름

    Returns:
        (float, float): Accuracy, F1 Score
    """
    # 1. 임베딩 로드
    ctx_embeddings = np.load(ctx_emb_path)
    q_embeddings = np.load(q_emb_path)

    assert len(q_embeddings) == len(qa_pairs), "❗ 질문 임베딩 수와 QA 쌍 수가 일치하지 않습니다."

    # 2. Generator 로드
    generator = pipeline("text2text-generation", model=model_name)

    predictions = []
    references = []

    # 3. 질문마다 DPR 기반 top-k 문서로 답변 생성
    for idx, row in tqdm(qa_pairs.iterrows(), total=len(qa_pairs), desc="Evaluating Generator via DPR"):
        question = row["question"]
        gold_answer = row["answer"]
        q_emb = q_embeddings[idx]

        # dot product 기반 유사도 계산
        scores = np.dot(ctx_embeddings, q_emb)
        topk_indices = np.argsort(scores)[::-1][:k]
        topk_texts = corpus_df.iloc[topk_indices]["text"].tolist()

        # context 구성
        context = " ".join(topk_texts)
        input_text = f"question: {question} context: {context}"

        # 생성
        output = generator(input_text, max_length=64, do_sample=False)[0]["generated_text"]
        predictions.append(output.strip())
        references.append(gold_answer.strip())

    # 평가
    accuracy = np.mean([pred == ref for pred, ref in zip(predictions, references)])
    f1 = np.mean([
        f1_score(ref.split(), pred.split(), average="micro") if len(ref.split()) > 0 and len(pred.split()) > 0 else 0
        for pred, ref in zip(predictions, references)
    ])

    print(f"📌 Accuracy: {accuracy:.4f} | F1 Score: {f1:.4f}")
    return accuracy, f1

In [None]:
acc, f1 = evaluate_generator_with_dpr(
    qa_pairs=qa_pairs,
    corpus_df=corpus_df,
    ctx_emb_path="/mnt/aix7101/jeong/aix_project/dpr_ctx_embeddings_multiqa.npy",
    q_emb_path="/mnt/aix7101/jeong/aix_project/dpr_q_embeddings_multiqa.npy",
    k=5,
    model_name="facebook/bart-base"
)


Evaluating Recall@K: 100%|██████████| 87599/87599 [01:03<00:00, 1387.19it/s]

📌 Recall@3: 0.5645





## sentenceDPR

In [25]:
def compute_dprm_recall(
    qa_pairs: pd.DataFrame,
    corpus_df: pd.DataFrame,
    ctx_emb_path: str,
    q_emb_path: str,
    k: int = 5,
    aggregation: str = "mean",
) -> float:
    """
    문장 단위의 문서 임베딩을 사용하여 DPR-m 방식의 Recall@k 계산.

    Args:
        qa_pairs (pd.DataFrame): 질문-정답 쌍 (columns: ['question', 'answer', 'doc_id'])
        corpus_df (pd.DataFrame): 문서 집합 (columns: ['doc_id', 'text'])
        ctx_emb_path (str): 문장 단위 문서 임베딩 저장 경로 (.npy, shape: [num_docs, num_sents, dim])
        q_emb_path (str): 질문 임베딩 저장 경로 (.npy, shape: [num_queries, dim])
        k (int): Recall@k
        aggregation (str): 'max' 또는 'mean' 방식으로 문서 유사도 집계
        
    Returns:
        float: Recall@k
    """
    
    # 1. 임베딩 로드
    ctx_embeddings = np.load(ctx_emb_path, allow_pickle=True)  # object 배열
    q_embeddings = np.load(q_emb_path)

    assert len(q_embeddings) == len(qa_pairs), "❗ 질문 임베딩 수와 QA 쌍 수가 일치하지 않습니다."

    hit_count = 0

    for idx, row in tqdm(qa_pairs.iterrows(), total=len(qa_pairs), desc="Evaluating DPR-m Recall@K"):
        gt_doc_id = row["doc_id"]
        q_emb = q_embeddings[idx]  # (dim,)

        # 각 문서에 대해 문장 임베딩과 q_emb의 유사도 계산
        scores = []
        for doc_sents in ctx_embeddings:
            sent_scores = np.dot(doc_sents, q_emb)  # (num_sents,)
            if aggregation == "max": # 유사도가 제일 높은 문장이 있는 것으로 할지
                score = np.max(sent_scores)
            elif aggregation == "mean": # 전체적인 문장의 평균으로 계산할지
                score = np.mean(sent_scores)
            else:
                raise ValueError("aggregation은 'max' 또는 'mean'이어야 합니다.")
            scores.append(score)

        scores = np.array(scores)
        topk_indices = np.argsort(scores)[::-1][:k]
        topk_doc_ids = corpus_df.iloc[topk_indices]["doc_id"].tolist()

        if gt_doc_id in topk_doc_ids:
            hit_count += 1

    recall_at_k = hit_count / len(qa_pairs)
    print(f"📌 DPR-m Recall@{k} ({aggregation} aggregation): {recall_at_k:.4f}")
    return recall_at_k

In [36]:
compute_dprm_recall(
    qa_pairs=qa_pairs,
    corpus_df=corpus_df,
    ctx_emb_path="/mnt/aix7101/jeong/aix_project/dpr_m_ctx_embeddings_multiqa.npy",
    q_emb_path="/mnt/aix7101/jeong/aix_project/dpr_q_embeddings_multiqa.npy",
    k=3,
    aggregation="max"
)

Evaluating DPR-m Recall@K: 100%|██████████| 87599/87599 [1:55:38<00:00, 12.63it/s]

📌 DPR-m Recall@3 (max aggregation): 0.6796





0.6795739677393577

## hybrid (bm25 + DPR)

In [None]:
from transformers import pipeline
from tqdm import tqdm
from sklearn.metrics import f1_score, accuracy_score
from rank_bm25 import BM25Okapi
import numpy as np
import pandas as pd

def evaluate_generator_with_hybrid(
    qa_pairs: pd.DataFrame,
    corpus_df: pd.DataFrame,
    ctx_emb_path: str,
    q_emb_path: str,
    bm25_top_n: int = 100,
    k: int = 5,
    model_name: str = "facebook/bart-base"
) -> tuple:
    """
    BM25 + DPR 하이브리드 top-k 문서를 기반으로 generator가 답변 생성 → Accuracy/F1 계산

    Args:
        qa_pairs (pd.DataFrame): 질문-정답 쌍 (columns: ['question', 'answer', 'doc_id'])
        corpus_df (pd.DataFrame): 문서 집합 (columns: ['doc_id', 'text'])
        ctx_emb_path (str): DPR 문서 임베딩 경로 (.npy)
        q_emb_path (str): DPR 질문 임베딩 경로 (.npy)
        bm25_top_n (int): BM25로 먼저 추리는 후보 문서 수
        k (int): 최종 DPR로 고르는 문서 수
        model_name (str): Hugging Face generator 모델명

    Returns:
        tuple: (accuracy, f1 score)
    """
    ctx_embeddings = np.load(ctx_emb_path)
    q_embeddings = np.load(q_emb_path)
    assert len(q_embeddings) == len(qa_pairs), "❗ 질문 임베딩 수와 QA 쌍 수가 다릅니다."

    tokenized_corpus = [doc.split() for doc in corpus_df["text"]]
    bm25 = BM25Okapi(tokenized_corpus)
    generator = pipeline("text2text-generation", model=model_name)

    predictions = []
    references = []

    for idx, row in tqdm(qa_pairs.iterrows(), total=len(qa_pairs), desc="Evaluating Hybrid Generator"):
        question = row["question"]
        gold_answer = row["answer"]
        q_emb = q_embeddings[idx]

        # Step 1: BM25 top-N 후보
        tokenized_query = question.split()
        bm25_scores = bm25.get_scores(tokenized_query)
        bm25_top_indices = np.argsort(bm25_scores)[::-1][:bm25_top_n]

        # Step 2: DPR 점수 계산
        candidate_ctx_embs = ctx_embeddings[bm25_top_indices]
        dpr_scores = np.dot(candidate_ctx_embs, q_emb)

        # Step 3: 최종 top-k 인덱스
        topk_local_indices = np.argsort(dpr_scores)[::-1][:k]
        topk_doc_indices = [bm25_top_indices[i] for i in topk_local_indices]
        topk_texts = corpus_df.iloc[topk_doc_indices]["text"].tolist()

        # Step 4: Generator로 답 생성
        context = " ".join(topk_texts)
        input_text = f"question: {question} context: {context}"
        output = generator(input_text, max_length=64, do_sample=False)[0]["generated_text"]

        predictions.append(output.strip())
        references.append(gold_answer.strip())

    # Step 5: 평가
    accuracy = np.mean([pred == ref for pred, ref in zip(predictions, references)])
    f1 = np.mean([
        f1_score(ref.split(), pred.split(), average="micro") if ref and pred else 0
        for pred, ref in zip(predictions, references)
    ])

    print(f"📌 Hybrid Generator Accuracy: {accuracy:.4f} | F1: {f1:.4f}")
    return accuracy, f1

In [None]:
acc, f1 = evaluate_generator_with_hybrid(
    qa_pairs=qa_pairs,
    corpus_df=corpus_df,
    ctx_emb_path="/mnt/aix7101/jeong/aix_project/dpr_ctx_embeddings_multiqa.npy",
    q_emb_path="/mnt/aix7101/jeong/aix_project/dpr_q_embeddings_multiqa.npy",
    bm25_top_n=300,
    k=5,
    model_name="facebook/bart-base"
)

Evaluating Hybrid Recall@K: 100%|██████████| 87599/87599 [50:34<00:00, 28.86it/s]  

📌 Hybrid Recall@5 (BM25 top-300 + DPR top-5): 0.6972





In [None]:
with open("hybrid_recall_result.txt", "w", encoding="utf-8") as f:
    f.write(f"{recall_hybrid:.4f}")

## Custom Retrieval 구성 요소
1. 문장에서 keyword 추출 (phrase 단위로 추출할 수 있는 방법이 있는지)
2. 추출한 keyword와의 score도 함께 계산
4. query만으로 추출한 recall@k
5. keyword만으로 추출한 recall@k
6. 둘을 hybrid하는 것도 ㄱㅊ

### keyword extract function

In [None]:
import spacy
import pandas as pd

# 1. spaCy 영어 모델 로드
nlp = spacy.load("en_core_web_sm")

# 2. 의문사 리스트 정의
WH_WORDS = {"what", "who", "whom", "where", "when", "why", "how"}

# 3. keyphrase 추출 함수 정의
def extract_keyphrases_spacy(question: str):
    doc = nlp(question.lower())
    keyphrases = set()

    wh_word = None
    for token in doc:
        if token.text in WH_WORDS:
            wh_word = token.text
            break

    for chunk in doc.noun_chunks:
        if any(not token.is_stop and token.pos_ in {"NOUN", "PROPN"} for token in chunk):
            keyphrases.add(chunk.text.strip())

    # 의문사에 따른 힌트 키워드 추가
    if wh_word:
        hint_map = {
            "who": "person",
            "where": "location",
            "when": "time",
            "why": "reason",
            "how": "method",
        }
        hint = hint_map.get(wh_word)
        if hint:
            keyphrases.add(hint)

    return list(keyphrases)

### use keybert

In [None]:
from keybert import KeyBERT
from typing import List
import re

# KeyBERT 모델 초기화 (기본적으로 'all-MiniLM-L6-v2' 사용)
kw_model = KeyBERT(model='all-MiniLM-L6-v2')

def extract_keyphrases_keybert(question: str, top_n: int = 5, diversity: bool = False) -> List[str]:
    """
    KeyBERT 기반 keyphrase 추출 함수 (의문사 힌트 없음)

    Args:
        question (str): 입력 질문
        top_n (int): 추출할 키프레이즈 개수
        diversity (bool): MMR(Minimal Marginal Relevance) 사용 여부

    Returns:
        List[str]: 추출된 키프레이즈 리스트
    """
    question_clean = re.sub(r"[^\w\s]", "", question.lower())  # 간단한 전처리

    if diversity:
        keyphrases = kw_model.extract_keywords(
            question_clean,
            keyphrase_ngram_range=(1, 3),
            stop_words='english',
            use_mmr=True,
            diversity=0.7,
            top_n=top_n
        )
    else:
        keyphrases = kw_model.extract_keywords(
            question_clean,
            keyphrase_ngram_range=(1, 3),
            stop_words='english',
            top_n=top_n
        )

    return [phrase for phrase, _ in keyphrases]

1. keyphrase-based pre-filtering
- keyphrase를 추출
- 각 keyphrase를 embedding하고 corpus와의 유사도 계산을 통해 후보 100개씩 추출
2. query-to-context matching
- 전체 corpus가 아닌 후보 corpus와만 비교해서 최종 recall@k를 계산

In [None]:
from transformers import pipeline
from sklearn.metrics import accuracy_score, f1_score
from sklearn.preprocessing import normalize
from tqdm import tqdm
import numpy as np
import torch

def evaluate_generator_with_custom_keyphrase_retrieval(
    qa_pairs,
    corpus_df,
    ctx_emb_path,
    extract_keyphrases_fn,
    top_n_per_keyphrase=50,
    final_top_k=5,
    model_name="facebook/bart-base",
    device="cuda" if torch.cuda.is_available() else "cpu"
):
    """
    키프레이즈 기반 Hybrid Retrieval 후 Generator로 답변 생성하여 Accuracy/F1 평가

    Returns:
        tuple: (accuracy, f1_score)
    """
    ctx_embeddings = np.load(ctx_emb_path, allow_pickle=True)
    ctx_embeddings = normalize(ctx_embeddings)

    generator = pipeline("text2text-generation", model=model_name, device=0 if device == "cuda" else -1)

    predictions = []
    references = []

    for idx, row in tqdm(qa_pairs.iterrows(), total=len(qa_pairs), desc="Evaluating Generator (Custom Keyphrase)"):
        question = row["question"]
        gold_answer = row["answer"]
        gt_doc_id = row["doc_id"]

        # 1. 키프레이즈 추출
        keyphrases = extract_keyphrases_fn(question)
        if not keyphrases:
            continue

        # 2. 키프레이즈 임베딩
        phrase_embs = []
        for phrase in keyphrases:
            inputs = ctx_tokenizer(phrase, return_tensors="pt", truncation=True, padding="max_length", max_length=512)
            inputs = {k: v.to(device) for k, v in inputs.items()}
            with torch.no_grad():
                emb = ctx_encoder(**inputs).pooler_output[0].cpu().numpy()
            phrase_embs.append(emb)
        phrase_embs = normalize(np.stack(phrase_embs))

        # 3. 키워드 별 후보 문서 수집
        candidate_indices = set()
        for emb in phrase_embs:
            scores = np.dot(ctx_embeddings, emb)
            top_indices = np.argsort(scores)[::-1][:top_n_per_keyphrase]
            candidate_indices.update(top_indices)

        if not candidate_indices:
            continue

        # 4. 쿼리 임베딩
        q_inputs = q_tokenizer(question, return_tensors="pt", truncation=True, padding="max_length", max_length=512)
        q_inputs = {k: v.to(device) for k, v in q_inputs.items()}
        with torch.no_grad():
            query_emb = q_encoder(**q_inputs).pooler_output[0].cpu().numpy()
        query_emb = normalize(query_emb.reshape(1, -1))[0]

        # 5. 재랭킹
        candidate_indices = list(candidate_indices)
        candidate_embs = ctx_embeddings[candidate_indices]
        rerank_scores = np.dot(candidate_embs, query_emb)
        top_k_indices = np.argsort(rerank_scores)[::-1][:final_top_k]
        top_k_doc_ids = [corpus_df.iloc[candidate_indices[i]]["doc_id"] for i in top_k_indices]
        top_k_texts = [corpus_df.iloc[candidate_indices[i]]["text"] for i in top_k_indices]

        # 6. 답변 생성
        context = " ".join(top_k_texts)
        input_text = f"question: {question} context: {context}"
        output = generator(input_text, max_length=64, do_sample=False)[0]["generated_text"]

        predictions.append(output.strip())
        references.append(gold_answer.strip())

    # 7. 평가
    accuracy = np.mean([pred == ref for pred, ref in zip(predictions, references)])
    f1 = np.mean([
        f1_score(ref.split(), pred.split(), average="micro") if ref and pred else 0
        for pred, ref in zip(predictions, references)
    ])

    print(f"📌 Custom Keyphrase Generator Accuracy: {accuracy:.4f} | F1: {f1:.4f}")
    return accuracy, f1

In [None]:
acc, f1 = evaluate_generator_with_custom_keyphrase_retrieval(
    qa_pairs=qa_pairs,
    corpus_df=corpus_df,
    ctx_emb_path="/mnt/aix7101/jeong/aix_project/dpr_ctx_embeddings2.npy",
    extract_keyphrases_fn=extract_keyphrases_spacy,  # 앞서 정의한 spaCy 기반 함수
    top_n_per_keyphrase=100,
    final_top_k=5
)

Hybrid DPR Retrieval Recall@K:   1%|▏         | 1102/87599 [00:52<1:55:53, 12.44it/s]

In [None]:
acc, f1 = evaluate_generator_with_custom_keyphrase_retrieval(
    qa_pairs=qa_pairs,
    corpus_df=corpus_df,
    ctx_emb_path="/mnt/aix7101/jeong/aix_project/dpr_ctx_embeddings2.npy",
    extract_keyphrases_fn=extract_keyphrases_keybert,  # 앞서 정의한 spaCy 기반 함수
    top_n_per_keyphrase=100,
    final_top_k=5
)

Custom Retrieval Recall@K:  25%|██▌       | 22272/87599 [25:53<1:03:49, 17.06it/s]

In [None]:
# use keybert

In [None]:
#-- retrieval

In [None]:
#-- custom checking code
