In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/nq10k-comp5423/New_metrics_calculation.py
/kaggle/input/nq10k-comp5423/data_and_code/metrics_calculation.py
/kaggle/input/nq10k-comp5423/data_and_code/data/val.jsonl
/kaggle/input/nq10k-comp5423/data_and_code/data/test.jsonl
/kaggle/input/nq10k-comp5423/data_and_code/data/(example) test_predict.jsonl
/kaggle/input/nq10k-comp5423/data_and_code/data/val_predict.jsonl
/kaggle/input/nq10k-comp5423/data_and_code/data/train.jsonl
/kaggle/input/nq10k-comp5423/data_and_code/data/documents.jsonl
/kaggle/input/crawl-300d-2m/crawl-300d-2M.vec


In [2]:
import json
import pandas as pd
import os

# --- 配置 ---
# 定义数据文件所在的目录路径
# 请确保 'data' 文件夹与此脚本在同一目录下，或者提供绝对路径
DATA_DIR = '/kaggle/input/nq10k-comp5423/data_and_code/data/'
DOCUMENTS_FILE = os.path.join(DATA_DIR, 'documents.jsonl')
TRAIN_FILE = os.path.join(DATA_DIR, 'train.jsonl')
VAL_FILE = os.path.join(DATA_DIR, 'val.jsonl')
TEST_FILE = os.path.join(DATA_DIR, 'test.jsonl')

# --- 函数：加载 JSON Lines 文件 ---
def load_jsonl(file_path):
    """
    加载 JSON Lines 文件 (.jsonl) 并返回数据列表。
    每个元素是一个从 JSON 行解析出的 Python 字典。

    Args:
        file_path (str): .jsonl 文件的路径。

    Returns:
        list[dict]: 包含文件中所有 JSON 对象的列表。
                    如果文件不存在或为空，则返回空列表。
    """
    data = []
    if not os.path.exists(file_path):
        print(f"警告: 文件未找到 {file_path}")
        return data
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                if line.strip(): # 确保行不为空
                    try:
                        data.append(json.loads(line))
                    except json.JSONDecodeError as e:
                        print(f"警告: 跳过无法解析的行: {line.strip()} - 错误: {e}")
        print(f"成功加载 {len(data)} 条记录从 {file_path}")
    except Exception as e:
        print(f"加载文件时出错 {file_path}: {e}")
    return data

# --- 函数：将列表数据转换为 Pandas DataFrame ---
def list_to_dataframe(data_list):
    """
    将字典列表转换为 Pandas DataFrame。

    Args:
        data_list (list[dict]): 字典列表。

    Returns:
        pd.DataFrame: 转换后的 DataFrame。如果列表为空，则返回空的 DataFrame。
    """
    if not data_list:
        return pd.DataFrame()
    return pd.DataFrame(data_list)

# --- 主程序 ---
if __name__ == "__main__":
    print("开始加载 NQ10K 数据集...")

    # 加载文档数据
    print(f"\n--- 加载文档 ---")
    documents_data = load_jsonl(DOCUMENTS_FILE)
    documents_df = list_to_dataframe(documents_data)
    if not documents_df.empty:
        print(f"文档数据预览 (前 5 条):")
        print(documents_df.head())
        print(f"\n文档总数: {len(documents_df)}")
        print(f"列名: {documents_df.columns.tolist()}")
        # 检查是否有缺失值
        print("\n文档数据缺失值检查:")
        print(documents_df.isnull().sum())
        # 将文档存储在字典中，方便后续按 ID 查找
        documents_dict = {doc['document_id']: doc['document_text'] for doc in documents_data if 'document_id' in doc and 'document_text' in doc}
        print(f"已将 {len(documents_dict)} 个文档存入字典。")
    else:
        print("未能加载文档数据或文档数据为空。")
        documents_dict = {}

    # 加载训练数据
    print(f"\n--- 加载训练数据 ---")
    train_data = load_jsonl(TRAIN_FILE)
    train_df = list_to_dataframe(train_data)
    if not train_df.empty:
        print(f"训练数据预览 (前 5 条):")
        print(train_df.head())
        print(f"\n训练样本总数: {len(train_df)}")
        print(f"列名: {train_df.columns.tolist()}")
        print("\n训练数据缺失值检查:")
        print(train_df.isnull().sum())
    else:
        print("未能加载训练数据或训练数据为空。")

    # 加载验证数据
    print(f"\n--- 加载验证数据 ---")
    val_data = load_jsonl(VAL_FILE)
    val_df = list_to_dataframe(val_data)
    if not val_df.empty:
        print(f"验证数据预览 (前 5 条):")
        print(val_df.head())
        print(f"\n验证样本总数: {len(val_df)}")
        print(f"列名: {val_df.columns.tolist()}")
        print("\n验证数据缺失值检查:")
        print(val_df.isnull().sum())
    else:
        print("未能加载验证数据或验证数据为空。")

    # 加载测试数据 (注意：测试集没有答案和文档 ID)
    print(f"\n--- 加载测试数据 ---")
    test_data = load_jsonl(TEST_FILE)
    test_df = list_to_dataframe(test_data)
    if not test_df.empty:
        print(f"测试数据预览 (前 5 条):")
        print(test_df.head())
        print(f"\n测试样本总数: {len(test_df)}")
        print(f"列名: {test_df.columns.tolist()}")
        print("\n测试数据缺失值检查:")
        print(test_df.isnull().sum()) # 预期 answer 和 document_id 列为 null
    else:
        print("未能加载测试数据或测试数据为空。")

    print("\n数据加载和初步探索完成。")


开始加载 NQ10K 数据集...

--- 加载文档 ---
成功加载 12138 条记录从 /kaggle/input/nq10k-comp5423/data_and_code/data/documents.jsonl
文档数据预览 (前 5 条):
   document_id                                      document_text
0            0  Email marketing - Wikipedia <H1> Email marketi...
1            1  The Mother ( How I Met Your Mother ) - wikiped...
2            2  Human fertilization - wikipedia <H1> Human fer...
3            3  List of National Football League career quarte...
4            4  Roanoke Colony - wikipedia <H1> Roanoke Colony...

文档总数: 12138
列名: ['document_id', 'document_text']

文档数据缺失值检查:
document_id      0
document_text    0
dtype: int64
已将 12138 个文档存入字典。

--- 加载训练数据 ---
成功加载 8000 条记录从 /kaggle/input/nq10k-comp5423/data_and_code/data/train.jsonl
训练数据预览 (前 5 条):
                                            question  \
0           what did the huns do to the roman empire   
1       who won women's singles australian open 2018   
2         who plays the gunslinger in the dark tower   
3        what 

In [3]:
import json
import os
import re # 用于正则表达式，处理 HTML 标签等
import pandas as pd
import nltk # 自然语言处理工具包
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

# --- 配置 ---
DATA_DIR = '/kaggle/input/nq10k-comp5423/data_and_code/data' # 假设数据在 'data' 文件夹下
DOCUMENTS_FILE = os.path.join(DATA_DIR, 'documents.jsonl')
VAL_FILE = os.path.join(DATA_DIR, 'val.jsonl')

# --- 下载 NLTK 数据 (如果尚未下载) ---
try:
    nltk.data.find('tokenizers/punkt')
except nltk.downloader.DownloadError:
    print("下载 NLTK 'punkt' 数据...")
    nltk.download('punkt', quiet=True)
try:
    nltk.data.find('corpora/stopwords')
except nltk.downloader.DownloadError:
    print("下载 NLTK 'stopwords' 数据...")
    nltk.download('stopwords', quiet=True)

# --- 加载停用词 ---
# 使用英文停用词列表
stop_words = set(stopwords.words('english'))

# --- 函数：加载 JSON Lines 文件 ---
def load_jsonl(file_path):
    """加载 JSON Lines 文件 (.jsonl) 并返回数据列表。"""
    data = []
    if not os.path.exists(file_path):
        print(f"警告: 文件未找到 {file_path}")
        return data
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                if line.strip():
                    try:
                        data.append(json.loads(line))
                    except json.JSONDecodeError as e:
                        print(f"警告: 跳过无法解析的行: {line.strip()} - 错误: {e}")
        print(f"成功加载 {len(data)} 条记录从 {file_path}")
    except Exception as e:
        print(f"加载文件时出错 {file_path}: {e}")
    return data

# --- 函数：文本预处理 ---
def preprocess_text(text):
    """
    对文本进行预处理：
    1. 移除 HTML 标签
    2. 移除标点符号和数字
    3. 转换为小写
    4. 分词
    5. 移除停用词
    """
    if not isinstance(text, str):
        return ""
    # 1. 移除 HTML 标签 (简单的正则表达式)
    text = re.sub(r'<[^>]+>', ' ', text)
    # 2. 移除标点符号和数字，只保留字母和空格
    text = re.sub(r'[^a-zA-Z\s]', '', text)
    # 3. 转换为小写
    text = text.lower()
    # 4. 分词
    tokens = word_tokenize(text)
    # 5. 移除停用词
    filtered_tokens = [word for word in tokens if word.isalpha() and word not in stop_words]
    # 将处理后的词语列表重新组合成字符串，供 TfidfVectorizer 使用
    return " ".join(filtered_tokens)

# --- TF-IDF 检索器类 ---
class TfidfRetriever:
    """
    使用 TF-IDF 进行文档检索的类。
    """
    def __init__(self):
        # 初始化 TfidfVectorizer
        # max_df=0.85: 忽略在超过 85% 的文档中出现的词语 (过于常见)
        # min_df=2: 忽略在少于 2 个文档中出现的词语 (过于罕见或可能是噪音)
        self.vectorizer = TfidfVectorizer(max_df=0.85, min_df=2)
        self.tfidf_matrix = None
        self.document_ids = []
        self.documents_dict = {} # 存储原始文档内容，方便查看

    def build_index(self, documents_data):
        """
        使用文档数据构建 TF-IDF 索引。

        Args:
            documents_data (list[dict]): 从 documents.jsonl 加载的文档数据列表。
                                         每个字典应包含 'document_id' 和 'document_text'。
        """
        print("开始构建 TF-IDF 索引...")
        if not documents_data:
            print("错误：文档数据为空，无法构建索引。")
            return

        # 提取文档 ID 和文本
        doc_texts = []
        self.document_ids = []
        self.documents_dict = {}

        for doc in documents_data:
            if 'document_id' in doc and 'document_text' in doc:
                self.document_ids.append(doc['document_id'])
                # 预处理文档文本
                processed_text = preprocess_text(doc['document_text'])
                doc_texts.append(processed_text)
                self.documents_dict[doc['document_id']] = doc['document_text'] # 存储原始文本
            else:
                print(f"警告: 跳过格式不正确的文档记录: {doc}")

        if not doc_texts:
            print("错误：没有有效的文档文本用于构建索引。")
            return

        # 拟合 TfidfVectorizer 并转换文档文本
        print(f"正在对 {len(doc_texts)} 个文档进行 TF-IDF 向量化...")
        self.tfidf_matrix = self.vectorizer.fit_transform(doc_texts)
        print(f"TF-IDF 索引构建完成。矩阵形状: {self.tfidf_matrix.shape}")
        print(f"词汇表大小: {len(self.vectorizer.get_feature_names_out())}")

    def retrieve(self, query, top_n=5):
        """
        根据查询检索最相关的 top_n 个文档 ID。

        Args:
            query (str): 用户输入的查询问题。
            top_n (int): 要返回的最相关文档的数量。

        Returns:
            list[int]: 按相关性排序的 top_n 个文档 ID 列表。
                       如果索引未构建或查询无效，则返回空列表。
        """
        if self.tfidf_matrix is None or not self.document_ids:
            print("错误：TF-IDF 索引尚未构建。请先调用 build_index()。")
            return []
        if not query or not isinstance(query, str):
            print("错误：查询无效。")
            return []

        # 1. 预处理查询
        processed_query = preprocess_text(query)
        if not processed_query:
            print("警告：预处理后的查询为空。")
            return []

        # 2. 将查询转换为 TF-IDF 向量
        # 注意：使用 transform 而不是 fit_transform，因为要用已有的词汇表和 IDF
        query_vector = self.vectorizer.transform([processed_query])

        # 3. 计算查询向量与所有文档向量的余弦相似度
        # cosine_similarity 返回一个形状为 (n_queries, n_documents) 的矩阵
        cosine_similarities = cosine_similarity(query_vector, self.tfidf_matrix).flatten()

        # 4. 获取相似度最高的 top_n 个文档的索引
        # argsort 返回排序后的索引，[::-1] 将其反转为降序
        # [:top_n] 取前 n 个
        top_n_indices = np.argsort(cosine_similarities)[::-1][:top_n]

        # 5. 将索引映射回文档 ID
        retrieved_doc_ids = [self.document_ids[i] for i in top_n_indices]

        # (可选) 打印检索结果和相似度得分
        # print(f"\n查询: '{query}'")
        # print(f"检索到的 Top-{top_n} 文档 ID (及相似度):")
        # for i, doc_id in enumerate(retrieved_doc_ids):
        #     similarity_score = cosine_similarities[top_n_indices[i]]
        #     print(f"  {i+1}. ID: {doc_id} (相似度: {similarity_score:.4f})")
            # print(f"     内容预览: {self.documents_dict.get(doc_id, 'N/A')[:100]}...") # 打印部分原文

        return retrieved_doc_ids

# --- 主程序：演示 TF-IDF 检索 ---
if __name__ == "__main__":
    # 1. 加载文档数据
    print("--- 加载文档数据 ---")
    documents_data = load_jsonl(DOCUMENTS_FILE)

    # 2. 初始化并构建 TF-IDF 检索器索引
    retriever = TfidfRetriever()
    retriever.build_index(documents_data)

    # 3. 加载验证数据以获取示例问题
    print("\n--- 加载验证数据以获取示例 ---")
    val_data = load_jsonl(VAL_FILE)

    if val_data and retriever.tfidf_matrix is not None:
        # 4. 选择一个示例问题进行检索
        example_index = 0 # 可以修改这个索引来测试不同的问题
        example_question = val_data[example_index]['question']
        true_doc_id = val_data[example_index]['document_id']

        print(f"\n--- 使用 TF-IDF 进行检索示例 ---")
        print(f"示例问题 (来自 val.jsonl[{example_index}]): '{example_question}'")
        print(f"真实相关的文档 ID: {true_doc_id}")

        # 5. 执行检索
        retrieved_ids = retriever.retrieve(example_question, top_n=5)

        # 6. 显示检索结果
        print(f"\nTF-IDF 检索到的 Top-5 文档 ID:")
        print(retrieved_ids)

        # (可选) 检查真实文档是否在前 5 个结果中
        if true_doc_id in retrieved_ids:
            rank = retrieved_ids.index(true_doc_id) + 1
            print(f"成功! 真实文档 ID {true_doc_id} 在检索结果中排名第 {rank}。")
        else:
            print(f"失败。真实文档 ID {true_doc_id} 未在前 5 个检索结果中。")

        # (可选) 查看检索到的第一个文档的原文
        if retrieved_ids:
             first_retrieved_id = retrieved_ids[0]
             print(f"\n检索到的第一个文档 (ID: {first_retrieved_id}) 的原文预览:")
             print(retriever.documents_dict.get(first_retrieved_id, "未找到原文")[:500] + "...") # 显示前 500 个字符

    elif retriever.tfidf_matrix is None:
         print("\n由于索引构建失败，无法执行检索示例。")
    else:
        print("\n未能加载验证数据，无法执行检索示例。")



--- 加载文档数据 ---
成功加载 12138 条记录从 /kaggle/input/nq10k-comp5423/data_and_code/data/documents.jsonl
开始构建 TF-IDF 索引...
正在对 12138 个文档进行 TF-IDF 向量化...
TF-IDF 索引构建完成。矩阵形状: (12138, 279270)
词汇表大小: 279270

--- 加载验证数据以获取示例 ---
成功加载 1000 条记录从 /kaggle/input/nq10k-comp5423/data_and_code/data/val.jsonl

--- 使用 TF-IDF 进行检索示例 ---
示例问题 (来自 val.jsonl[0]): 'when did the british first land in north america'
真实相关的文档 ID: 11484

TF-IDF 检索到的 Top-5 文档 ID:
[8852, 7355, 11484, 1592, 860]
成功! 真实文档 ID 11484 在检索结果中排名第 3。

检索到的第一个文档 (ID: 8852) 的原文预览:
Geography of North America - wikipedia <H1> Geography of North America </H1> Jump to : navigation , search Global view centered on North America <P> North America is the third largest continent , and is also a portion of the second largest supercontinent if North and South America are combined into the Americas and Africa , Europe , and Asia are considered to be part of one supercontinent called Afro - Eurasia . </P> <P> With an estimated population of 460 million and an a

In [4]:
!pip install rank_bm25

Collecting rank_bm25
  Downloading rank_bm25-0.2.2-py3-none-any.whl.metadata (3.2 kB)
Downloading rank_bm25-0.2.2-py3-none-any.whl (8.6 kB)
Installing collected packages: rank_bm25
Successfully installed rank_bm25-0.2.2


In [5]:
import json
import os
import re # 用于正则表达式
import pandas as pd
import nltk # 自然语言处理工具包
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from rank_bm25 import BM25Okapi # BM25 库
import numpy as np
import time # 用于计时

# --- 配置 ---
DATA_DIR = '/kaggle/input/nq10k-comp5423/data_and_code/data' # 假设数据在 'data' 文件夹下
DOCUMENTS_FILE = os.path.join(DATA_DIR, 'documents.jsonl')
VAL_FILE = os.path.join(DATA_DIR, 'val.jsonl')

# --- 下载 NLTK 数据 (如果尚未下载) ---
try:
    nltk.data.find('tokenizers/punkt')
except nltk.downloader.DownloadError:
    print("下载 NLTK 'punkt' 数据...")
    nltk.download('punkt', quiet=True)
try:
    nltk.data.find('corpora/stopwords')
except nltk.downloader.DownloadError:
    print("下载 NLTK 'stopwords' 数据...")
    nltk.download('stopwords', quiet=True)

# --- 加载停用词 ---
stop_words = set(stopwords.words('english'))

# --- 函数：加载 JSON Lines 文件 ---
def load_jsonl(file_path):
    """加载 JSON Lines 文件 (.jsonl) 并返回数据列表。"""
    data = []
    if not os.path.exists(file_path):
        print(f"警告: 文件未找到 {file_path}")
        return data
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                if line.strip():
                    try:
                        data.append(json.loads(line))
                    except json.JSONDecodeError as e:
                        print(f"警告: 跳过无法解析的行: {line.strip()} - 错误: {e}")
        print(f"成功加载 {len(data)} 条记录从 {file_path}")
    except Exception as e:
        print(f"加载文件时出错 {file_path}: {e}")
    return data

# --- 函数：文本预处理 (适用于 BM25，返回 token 列表) ---
def preprocess_text_bm25(text):
    """
    对文本进行预处理，返回 token 列表：
    1. 移除 HTML 标签
    2. 移除标点符号和数字
    3. 转换为小写
    4. 分词
    5. 移除停用词
    """
    if not isinstance(text, str):
        return []
    # 1. 移除 HTML 标签
    text = re.sub(r'<[^>]+>', ' ', text)
    # 2. 移除标点符号和数字
    text = re.sub(r'[^a-zA-Z\s]', '', text)
    # 3. 转换为小写
    text = text.lower()
    # 4. 分词
    tokens = word_tokenize(text)
    # 5. 移除停用词，并确保是字母
    filtered_tokens = [word for word in tokens if word.isalpha() and word not in stop_words]
    return filtered_tokens

# --- BM25 检索器类 ---
class Bm25Retriever:
    """
    使用 BM25 进行文档检索的类。
    """
    def __init__(self):
        self.bm25_index = None
        self.document_ids = []
        self.documents_dict = {} # 存储原始文档内容

    def build_index(self, documents_data):
        """
        使用文档数据构建 BM25 索引。

        Args:
            documents_data (list[dict]): 从 documents.jsonl 加载的文档数据列表。
        """
        print("开始构建 BM25 索引...")
        start_time = time.time()
        if not documents_data:
            print("错误：文档数据为空，无法构建索引。")
            return

        tokenized_corpus = []
        self.document_ids = []
        self.documents_dict = {}

        print(f"正在对 {len(documents_data)} 个文档进行预处理和分词...")
        count = 0
        for doc in documents_data:
            if 'document_id' in doc and 'document_text' in doc:
                self.document_ids.append(doc['document_id'])
                self.documents_dict[doc['document_id']] = doc['document_text']
                # 使用适用于 BM25 的预处理函数
                processed_tokens = preprocess_text_bm25(doc['document_text'])
                tokenized_corpus.append(processed_tokens)
                count += 1
                if count % 1000 == 0:
                    print(f"  已处理 {count}/{len(documents_data)} 个文档...")
            else:
                print(f"警告: 跳过格式不正确的文档记录: {doc}")

        if not tokenized_corpus:
            print("错误：没有有效的文档文本用于构建索引。")
            return

        print(f"\n使用 {len(tokenized_corpus)} 个文档的 token 列表初始化 BM25Okapi...")
        # 使用分词后的文档列表初始化 BM25Okapi
        self.bm25_index = BM25Okapi(tokenized_corpus)
        end_time = time.time()
        print(f"BM25 索引构建完成。耗时: {end_time - start_time:.2f} 秒")

    def retrieve(self, query, top_n=5):
        """
        根据查询检索最相关的 top_n 个文档 ID。

        Args:
            query (str): 用户输入的查询问题。
            top_n (int): 要返回的最相关文档的数量。

        Returns:
            list[int]: 按相关性排序的 top_n 个文档 ID 列表。
        """
        if self.bm25_index is None or not self.document_ids:
            print("错误：BM25 索引尚未构建。请先调用 build_index()。")
            return []
        if not query or not isinstance(query, str):
            print("错误：查询无效。")
            return []

        # 1. 预处理查询 (得到 token 列表)
        tokenized_query = preprocess_text_bm25(query)
        if not tokenized_query:
            print("警告：预处理后的查询为空。")
            return []

        # 2. 计算查询与所有文档的 BM25 得分
        # get_scores 返回一个包含所有文档得分的 numpy 数组
        doc_scores = self.bm25_index.get_scores(tokenized_query)

        # 3. 获取得分最高的 top_n 个文档的索引
        # argsort 返回排序后的索引，[::-1] 将其反转为降序
        top_n_indices = np.argsort(doc_scores)[::-1][:top_n]

        # 4. 将索引映射回文档 ID
        retrieved_doc_ids = [self.document_ids[i] for i in top_n_indices]

        # (可选) 打印检索结果和得分
        # print(f"\n查询: '{query}'")
        # print(f"BM25 检索到的 Top-{top_n} 文档 ID (及得分):")
        # for i, doc_id in enumerate(retrieved_doc_ids):
        #     score = doc_scores[top_n_indices[i]]
        #     print(f"  {i+1}. ID: {doc_id} (得分: {score:.4f})")

        return retrieved_doc_ids

# --- 主程序：演示 BM25 检索 ---
if __name__ == "__main__":
    # 1. 加载文档数据
    print("--- 加载文档数据 ---")
    documents_data = load_jsonl(DOCUMENTS_FILE)

    # 2. 初始化并构建 BM25 检索器索引
    retriever = Bm25Retriever()
    retriever.build_index(documents_data)

    # 3. 加载验证数据以获取示例问题
    print("\n--- 加载验证数据以获取示例 ---")
    val_data = load_jsonl(VAL_FILE)

    if val_data and retriever.bm25_index is not None:
        # 4. 选择与 TF-IDF 相同的示例问题进行检索
        example_index = 0 # 与 TF-IDF 示例保持一致
        example_question = val_data[example_index]['question']
        true_doc_id = val_data[example_index]['document_id']

        print(f"\n--- 使用 BM25 进行检索示例 ---")
        print(f"示例问题 (来自 val.jsonl[{example_index}]): '{example_question}'")
        print(f"真实相关的文档 ID: {true_doc_id}")

        # 5. 执行检索
        retrieved_ids = retriever.retrieve(example_question, top_n=5)

        # 6. 显示检索结果
        print(f"\nBM25 检索到的 Top-5 文档 ID:")
        print(retrieved_ids)

        # (可选) 检查真实文档是否在前 5 个结果中
        if true_doc_id in retrieved_ids:
            rank = retrieved_ids.index(true_doc_id) + 1
            print(f"成功! 真实文档 ID {true_doc_id} 在检索结果中排名第 {rank}。")
        else:
            print(f"失败。真实文档 ID {true_doc_id} 未在前 5 个检索结果中。")

        # (可选) 查看检索到的第一个文档的原文
        if retrieved_ids:
             first_retrieved_id = retrieved_ids[0]
             print(f"\n检索到的第一个文档 (ID: {first_retrieved_id}) 的原文预览:")
             print(retriever.documents_dict.get(first_retrieved_id, "未找到原文")[:500] + "...")

    elif retriever.bm25_index is None:
         print("\n由于索引构建失败，无法执行检索示例。")
    else:
        print("\n未能加载验证数据，无法执行检索示例。")



--- 加载文档数据 ---
成功加载 12138 条记录从 /kaggle/input/nq10k-comp5423/data_and_code/data/documents.jsonl
开始构建 BM25 索引...
正在对 12138 个文档进行预处理和分词...
  已处理 1000/12138 个文档...
  已处理 2000/12138 个文档...
  已处理 3000/12138 个文档...
  已处理 4000/12138 个文档...
  已处理 5000/12138 个文档...
  已处理 6000/12138 个文档...
  已处理 7000/12138 个文档...
  已处理 8000/12138 个文档...
  已处理 9000/12138 个文档...
  已处理 10000/12138 个文档...
  已处理 11000/12138 个文档...
  已处理 12000/12138 个文档...

使用 12138 个文档的 token 列表初始化 BM25Okapi...
BM25 索引构建完成。耗时: 342.66 秒

--- 加载验证数据以获取示例 ---
成功加载 1000 条记录从 /kaggle/input/nq10k-comp5423/data_and_code/data/val.jsonl

--- 使用 BM25 进行检索示例 ---
示例问题 (来自 val.jsonl[0]): 'when did the british first land in north america'
真实相关的文档 ID: 11484

BM25 检索到的 Top-5 文档 ID:
[11423, 7092, 411, 2358, 3066]
失败。真实文档 ID 11484 未在前 5 个检索结果中。

检索到的第一个文档 (ID: 11423) 的原文预览:
History of Antarctica - Wikipedia <H1> History of Antarctica </H1> Jump to : navigation , search <Dl> <Dd> For the natural history of the Antarctic continent , see Antarctica . </Dd

In [6]:
import json
import os
import re
import pandas as pd
import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
import gensim # 用于加载和使用 FastText/Word2Vec 模型
# 不再需要 downloader，因为我们直接加载本地文件
# import gensim.downloader as api
from gensim.models import KeyedVectors # 用于加载 .vec 文件
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import time

# --- 配置 ---
DATA_DIR = '/kaggle/input/nq10k-comp5423/data_and_code/data' # 假设数据在 'data' 文件夹下
DOCUMENTS_FILE = os.path.join(DATA_DIR, 'documents.jsonl')
VAL_FILE = os.path.join(DATA_DIR, 'val.jsonl')

# --- 重要：修改这里 ---
# 将 PRETRAINED_MODEL_PATH 设置为你的本地 .vec 文件路径
PRETRAINED_MODEL_PATH = '/kaggle/input/crawl-300d-2m/crawl-300d-2M.vec'
# 如果你想尝试 gensim downloader (需要网络和下载)，可以取消注释下面这行
# PRETRAINED_MODEL_NAME_DOWNLOAD = 'fasttext-wiki-news-subwords-300'

# --- 下载 NLTK 数据 (如果尚未下载) ---
try:
    nltk.data.find('tokenizers/punkt')
except nltk.downloader.DownloadError:
    print("下载 NLTK 'punkt' 数据...")
    nltk.download('punkt', quiet=True)
try:
    nltk.data.find('corpora/stopwords')
except nltk.downloader.DownloadError:
    print("下载 NLTK 'stopwords' 数据...")
    nltk.download('stopwords', quiet=True)

# --- 加载停用词 ---
stop_words = set(stopwords.words('english'))

# --- 函数：加载 JSON Lines 文件 ---
def load_jsonl(file_path):
    """加载 JSON Lines 文件 (.jsonl) 并返回数据列表。"""
    data = []
    if not os.path.exists(file_path):
        print(f"警告: 文件未找到 {file_path}")
        return data
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                if line.strip():
                    try:
                        data.append(json.loads(line))
                    except json.JSONDecodeError as e:
                        print(f"警告: 跳过无法解析的行: {line.strip()} - 错误: {e}")
        print(f"成功加载 {len(data)} 条记录从 {file_path}")
    except Exception as e:
        print(f"加载文件时出错 {file_path}: {e}")
    return data

# --- 函数：文本预处理 (返回 token 列表) ---
def preprocess_text_tokens(text):
    """
    对文本进行预处理，返回 token 列表：
    1. 移除 HTML 标签
    2. 移除标点符号和数字
    3. 转换为小写
    4. 分词
    5. 移除停用词
    """
    if not isinstance(text, str):
        return []
    # 1. 移除 HTML 标签
    text = re.sub(r'<[^>]+>', ' ', text)
    # 2. 移除标点符号和数字
    text = re.sub(r'[^a-zA-Z\s]', '', text)
    # 3. 转换为小写
    text = text.lower()
    # 4. 分词
    tokens = word_tokenize(text)
    # 5. 移除停用词
    filtered_tokens = [word for word in tokens if word.isalpha() and word not in stop_words]
    return filtered_tokens

# --- 函数：计算文本的平均词向量 ---
def get_average_vector(tokens, model):
    """
    计算给定 token 列表的平均词向量。
    忽略模型词汇表中不存在的词。
    """
    vectors = []
    # KeyedVectors 使用 model.key_to_index 检查词是否存在
    # 或者直接尝试访问，如果不存在会抛出 KeyError
    for token in tokens:
        try:
            vectors.append(model[token])
        except KeyError:
            # 忽略不在词汇表中的词
            pass

    if not vectors:
        # 如果文本中没有词在模型词汇表中，返回零向量
        return np.zeros(model.vector_size)

    # 计算向量的平均值
    return np.mean(vectors, axis=0)

# --- 词向量检索器类 (名称保持不变，但现在加载 .vec) ---
class FastTextRetriever:
    """
    使用预训练词向量 (.vec 文件) 进行文档检索的类。
    """
    def __init__(self, model_path=PRETRAINED_MODEL_PATH):
        self.model_path = model_path # 修改为 model_path
        self.model = None
        self.doc_vectors = None
        self.document_ids = []
        self.documents_dict = {} # 存储原始文档内容

    def load_model(self):
        """从本地文件路径加载预训练的词向量模型 (.vec 格式)。"""
        print(f"开始从本地文件加载预训练模型: {self.model_path} ...")
        print("这可能需要一些时间，取决于模型大小和你的机器性能。")
        start_time = time.time()
        if not os.path.exists(self.model_path):
             print(f"错误: 模型文件未找到: {self.model_path}")
             self.model = None
             return
        try:
            # --- 修改：使用 KeyedVectors.load_word2vec_format 加载 ---
            # binary=False 表示是文本格式的 .vec 文件
            # limit 参数可以限制加载的向量数量，用于快速测试 (例如 limit=500000)
            # 如果内存足够，可以移除 limit 参数加载完整模型
            self.model = KeyedVectors.load_word2vec_format(self.model_path, binary=False) #, limit=500000)
            end_time = time.time()
            print(f"模型加载完成。耗时: {end_time - start_time:.2f} 秒")
            print(f"向量维度: {self.model.vector_size}")
            print(f"词汇表大小: {len(self.model.key_to_index)}") # 显示加载的词汇量
        except Exception as e:
            print(f"加载模型时发生错误: {e}")
            print("请确保文件路径正确且文件是有效的 .vec 格式。")
            self.model = None

    def build_index(self, documents_data):
        """
        使用文档数据计算并存储文档向量。

        Args:
            documents_data (list[dict]): 从 documents.jsonl 加载的文档数据列表。
        """
        if self.model is None:
            print("错误：预训练模型未加载。请先调用 load_model()。")
            return

        print("开始构建文档向量索引...")
        start_time = time.time()
        if not documents_data:
            print("错误：文档数据为空，无法构建索引。")
            return

        doc_vectors_list = []
        self.document_ids = []
        self.documents_dict = {}

        print(f"正在对 {len(documents_data)} 个文档进行预处理和向量化...")
        count = 0
        for doc in documents_data:
            if 'document_id' in doc and 'document_text' in doc:
                self.document_ids.append(doc['document_id'])
                self.documents_dict[doc['document_id']] = doc['document_text']
                # 预处理文档文本得到 tokens
                processed_tokens = preprocess_text_tokens(doc['document_text'])
                # 计算文档的平均向量
                doc_vector = get_average_vector(processed_tokens, self.model)
                doc_vectors_list.append(doc_vector)
                count += 1
                if count % 1000 == 0:
                    print(f"  已处理 {count}/{len(documents_data)} 个文档...")
            else:
                print(f"警告: 跳过格式不正确的文档记录: {doc}")

        if not doc_vectors_list:
            print("错误：未能为任何文档生成向量。")
            return

        # 将向量列表转换为 numpy 数组，方便计算
        self.doc_vectors = np.array(doc_vectors_list)
        end_time = time.time()
        print(f"文档向量索引构建完成。向量矩阵形状: {self.doc_vectors.shape}")
        print(f"耗时: {end_time - start_time:.2f} 秒")


    def retrieve(self, query, top_n=5):
        """
        根据查询检索最相关的 top_n 个文档 ID。

        Args:
            query (str): 用户输入的查询问题。
            top_n (int): 要返回的最相关文档的数量。

        Returns:
            list[int]: 按相关性排序的 top_n 个文档 ID 列表。
        """
        if self.model is None or self.doc_vectors is None or not self.document_ids:
            print("错误：模型或文档向量索引尚未构建。")
            return []
        if not query or not isinstance(query, str):
            print("错误：查询无效。")
            return []

        # 1. 预处理查询并计算查询向量
        tokenized_query = preprocess_text_tokens(query)
        if not tokenized_query:
            print("警告：预处理后的查询为空。")
            return []
        query_vector = get_average_vector(tokenized_query, self.model)

        # 如果查询向量全为零 (例如，查询中所有词都不在模型词汇表里)
        if np.all(query_vector == 0):
             print("警告：无法为查询生成有效向量 (可能所有词都不在词汇表中)。")
             return []

        # 2. 计算查询向量与所有文档向量的余弦相似度
        # query_vector 需要 reshape 成 (1, vector_size) 来进行计算
        cosine_similarities = cosine_similarity(query_vector.reshape(1, -1), self.doc_vectors).flatten()

        # 3. 获取相似度最高的 top_n 个文档的索引
        top_n_indices = np.argsort(cosine_similarities)[::-1][:top_n]

        # 4. 将索引映射回文档 ID
        retrieved_doc_ids = [self.document_ids[i] for i in top_n_indices]

        # (可选) 打印检索结果和得分
        # print(f"\n查询: '{query}'")
        # print(f"词向量检索到的 Top-{top_n} 文档 ID (及相似度):")
        # for i, doc_id in enumerate(retrieved_doc_ids):
        #     similarity_score = cosine_similarities[top_n_indices[i]]
        #     print(f"  {i+1}. ID: {doc_id} (相似度: {similarity_score:.4f})")

        return retrieved_doc_ids

# --- 主程序：演示词向量检索 ---
if __name__ == "__main__":
    # 1. 初始化检索器并加载预训练模型 (从本地文件)
    print("--- 初始化词向量检索器并加载本地模型 ---")
    # 注意：这里的 model_path 参数现在指向你的本地文件
    retriever = FastTextRetriever(model_path=PRETRAINED_MODEL_PATH)
    retriever.load_model() # 加载本地模型

    if retriever.model: # 仅在模型加载成功后继续
        # 2. 加载文档数据
        print("\n--- 加载文档数据 ---")
        documents_data = load_jsonl(DOCUMENTS_FILE)

        # 3. 构建文档向量索引
        retriever.build_index(documents_data)

        # 4. 加载验证数据以获取示例问题
        print("\n--- 加载验证数据以获取示例 ---")
        val_data = load_jsonl(VAL_FILE)

        if val_data and retriever.doc_vectors is not None:
            # 5. 选择与之前相同的示例问题进行检索
            example_index = 0
            example_question = val_data[example_index]['question']
            true_doc_id = val_data[example_index]['document_id']

            print(f"\n--- 使用词向量进行检索示例 ---")
            print(f"示例问题 (来自 val.jsonl[{example_index}]): '{example_question}'")
            print(f"真实相关的文档 ID: {true_doc_id}")

            # 6. 执行检索
            retrieved_ids = retriever.retrieve(example_question, top_n=5)

            # 7. 显示检索结果
            print(f"\n词向量检索到的 Top-5 文档 ID:")
            print(retrieved_ids)

            # (可选) 检查真实文档是否在前 5 个结果中
            if retrieved_ids and true_doc_id in retrieved_ids:
                rank = retrieved_ids.index(true_doc_id) + 1
                print(f"成功! 真实文档 ID {true_doc_id} 在检索结果中排名第 {rank}。")
            elif retrieved_ids:
                print(f"失败。真实文档 ID {true_doc_id} 未在前 5 个检索结果中。")
            else:
                print("未能检索到任何文档。")


            # (可选) 查看检索到的第一个文档的原文
            if retrieved_ids:
                 first_retrieved_id = retrieved_ids[0]
                 print(f"\n检索到的第一个文档 (ID: {first_retrieved_id}) 的原文预览:")
                 print(retriever.documents_dict.get(first_retrieved_id, "未找到原文")[:500] + "...")

        elif retriever.doc_vectors is None:
             print("\n由于文档向量索引构建失败，无法执行检索示例。")
        else:
            print("\n未能加载验证数据，无法执行检索示例。")
    else:
        print("\n由于模型加载失败，无法继续执行。")



--- 初始化词向量检索器并加载本地模型 ---
开始从本地文件加载预训练模型: /kaggle/input/crawl-300d-2m/crawl-300d-2M.vec ...
这可能需要一些时间，取决于模型大小和你的机器性能。
模型加载完成。耗时: 416.74 秒
向量维度: 300
词汇表大小: 1999995

--- 加载文档数据 ---
成功加载 12138 条记录从 /kaggle/input/nq10k-comp5423/data_and_code/data/documents.jsonl
开始构建文档向量索引...
正在对 12138 个文档进行预处理和向量化...
  已处理 1000/12138 个文档...
  已处理 2000/12138 个文档...
  已处理 3000/12138 个文档...
  已处理 4000/12138 个文档...
  已处理 5000/12138 个文档...
  已处理 6000/12138 个文档...
  已处理 7000/12138 个文档...
  已处理 8000/12138 个文档...
  已处理 9000/12138 个文档...
  已处理 10000/12138 个文档...
  已处理 11000/12138 个文档...
  已处理 12000/12138 个文档...
文档向量索引构建完成。向量矩阵形状: (12138, 300)
耗时: 453.60 秒

--- 加载验证数据以获取示例 ---
成功加载 1000 条记录从 /kaggle/input/nq10k-comp5423/data_and_code/data/val.jsonl

--- 使用词向量进行检索示例 ---
示例问题 (来自 val.jsonl[0]): 'when did the british first land in north america'
真实相关的文档 ID: 11484

词向量检索到的 Top-5 文档 ID:
[11484, 81, 4117, 5493, 570]
成功! 真实文档 ID 11484 在检索结果中排名第 1。

检索到的第一个文档 (ID: 11484) 的原文预览:
British colonization of the Americas - wikiped

In [8]:
import json
import os
import re
import time
import pandas as pd
import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from rank_bm25 import BM25Okapi
from gensim.models import KeyedVectors
import numpy as np
from tqdm import tqdm # 用于显示进度条

# --- 配置 ---
DATA_DIR = '/kaggle/input/nq10k-comp5423/data_and_code/data'
DOCUMENTS_FILE = os.path.join(DATA_DIR, 'documents.jsonl')
VAL_FILE = os.path.join(DATA_DIR, 'val.jsonl')
# 你的本地 .vec 文件路径
PRETRAINED_MODEL_PATH = '/kaggle/input/crawl-300d-2m/crawl-300d-2M.vec'

# --- NLTK 数据下载检查 ---
try:
    nltk.data.find('tokenizers/punkt')
except nltk.downloader.DownloadError:
    print("下载 NLTK 'punkt' 数据...")
    nltk.download('punkt', quiet=True)
try:
    nltk.data.find('corpora/stopwords')
except nltk.downloader.DownloadError:
    print("下载 NLTK 'stopwords' 数据...")
    nltk.download('stopwords', quiet=True)

stop_words = set(stopwords.words('english'))

# --- 辅助函数 ---

def load_jsonl(file_path):
    """加载 JSON Lines 文件。"""
    data = []
    if not os.path.exists(file_path):
        print(f"警告: 文件未找到 {file_path}")
        return data
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            # 使用 tqdm 显示文件读取进度 (如果文件很大)
            total_lines = sum(1 for line in f) # 计算总行数可能较慢
            f.seek(0) # 重置文件指针
            for line in tqdm(f, total=total_lines, desc=f"加载 {os.path.basename(file_path)}"):
            # for line in f: # 不显示进度条的简化版本
                 if line.strip():
                    try:
                        data.append(json.loads(line))
                    except json.JSONDecodeError as e:
                        print(f"警告: 跳过无法解析的行: {line.strip()} - 错误: {e}")
        print(f"成功加载 {len(data)} 条记录从 {file_path}")
    except Exception as e:
        print(f"加载文件时出错 {file_path}: {e}")
    return data

def preprocess_text_tfidf(text):
    """文本预处理 (TF-IDF): 返回处理后的字符串。"""
    if not isinstance(text, str): return ""
    text = re.sub(r'<[^>]+>', ' ', text)
    text = re.sub(r'[^a-zA-Z\s]', '', text)
    text = text.lower()
    tokens = word_tokenize(text)
    filtered_tokens = [word for word in tokens if word.isalpha() and word not in stop_words]
    return " ".join(filtered_tokens)

def preprocess_text_tokens(text):
    """文本预处理 (BM25/FastText): 返回 token 列表。"""
    if not isinstance(text, str): return []
    text = re.sub(r'<[^>]+>', ' ', text)
    text = re.sub(r'[^a-zA-Z\s]', '', text)
    text = text.lower()
    tokens = word_tokenize(text)
    filtered_tokens = [word for word in tokens if word.isalpha() and word not in stop_words]
    return filtered_tokens

def get_average_vector(tokens, model):
    """计算平均词向量。"""
    vectors = []
    for token in tokens:
        try:
            vectors.append(model[token])
        except KeyError:
            pass
    if not vectors: return np.zeros(model.vector_size)
    return np.mean(vectors, axis=0)

# --- 检索器类定义 (从之前步骤复制并整合) ---

class BaseRetriever:
    """检索器基类，定义通用接口。"""
    def __init__(self):
        self.document_ids = []
        self.documents_dict = {}

    def build_index(self, documents_data):
        raise NotImplementedError

    def retrieve(self, query, top_n=5):
        raise NotImplementedError

    def _prepare_docs(self, documents_data):
        """提取文档 ID 和内容字典。"""
        self.document_ids = [doc['document_id'] for doc in documents_data if 'document_id' in doc]
        self.documents_dict = {doc['document_id']: doc['document_text'] for doc in documents_data if 'document_id' in doc and 'document_text' in doc}
        print(f"准备了 {len(self.document_ids)} 个文档 ID 和 {len(self.documents_dict)} 个文档原文。")
        return documents_data # 返回原始数据供子类使用

class TfidfRetriever(BaseRetriever):
    """TF-IDF 检索器。"""
    def __init__(self):
        super().__init__()
        self.vectorizer = TfidfVectorizer(max_df=0.85, min_df=2)
        self.tfidf_matrix = None

    def build_index(self, documents_data):
        print("\n--- 构建 TF-IDF 索引 ---")
        start_time = time.time()
        documents_data = self._prepare_docs(documents_data)
        doc_texts = [preprocess_text_tfidf(self.documents_dict.get(doc_id, '')) for doc_id in self.document_ids]
        if not doc_texts: print("错误：无有效文档文本。"); return
        self.tfidf_matrix = self.vectorizer.fit_transform(doc_texts)
        print(f"TF-IDF 索引构建完成。矩阵形状: {self.tfidf_matrix.shape}. 耗时: {time.time() - start_time:.2f} 秒")

    def retrieve(self, query, top_n=5):
        if self.tfidf_matrix is None: return []
        processed_query = preprocess_text_tfidf(query)
        if not processed_query: return []
        query_vector = self.vectorizer.transform([processed_query])
        cosine_similarities = cosine_similarity(query_vector, self.tfidf_matrix).flatten()
        # 使用 argpartition 获取 top_n 最快，然后排序这 top_n 个
        # top_n_indices = np.argsort(cosine_similarities)[::-1][:top_n]
        k = min(top_n, self.tfidf_matrix.shape[0]) # 确保 top_n 不超过文档数
        top_n_indices = np.argpartition(cosine_similarities, -k)[-k:]
        # 对这 k 个索引按得分排序
        top_n_indices = top_n_indices[np.argsort(cosine_similarities[top_n_indices])[::-1]]
        return [self.document_ids[i] for i in top_n_indices]

class Bm25Retriever(BaseRetriever):
    """BM25 检索器。"""
    def __init__(self):
        super().__init__()
        self.bm25_index = None
        self.tokenized_corpus = []

    def build_index(self, documents_data):
        print("\n--- 构建 BM25 索引 ---")
        start_time = time.time()
        documents_data = self._prepare_docs(documents_data)
        print(f"正在对 {len(self.document_ids)} 个文档进行预处理和分词...")
        self.tokenized_corpus = [preprocess_text_tokens(self.documents_dict.get(doc_id, '')) for doc_id in tqdm(self.document_ids, desc="BM25 Preprocessing")]
        if not self.tokenized_corpus: print("错误：无有效文档文本。"); return
        print(f"\n初始化 BM25Okapi...")
        self.bm25_index = BM25Okapi(self.tokenized_corpus)
        print(f"BM25 索引构建完成。耗时: {time.time() - start_time:.2f} 秒")

    def retrieve(self, query, top_n=5):
        if self.bm25_index is None: return []
        tokenized_query = preprocess_text_tokens(query)
        if not tokenized_query: return []
        doc_scores = self.bm25_index.get_scores(tokenized_query)
        # top_n_indices = np.argsort(doc_scores)[::-1][:top_n]
        k = min(top_n, len(self.document_ids))
        top_n_indices = np.argpartition(doc_scores, -k)[-k:]
        top_n_indices = top_n_indices[np.argsort(doc_scores[top_n_indices])[::-1]]
        return [self.document_ids[i] for i in top_n_indices]

class FastTextRetriever(BaseRetriever):
    """词向量检索器。"""
    def __init__(self, model_path=PRETRAINED_MODEL_PATH):
        super().__init__()
        self.model_path = model_path
        self.model = None
        self.doc_vectors = None

    def load_model(self):
        print(f"\n--- 加载词向量模型 ({os.path.basename(self.model_path)}) ---")
        start_time = time.time()
        if not os.path.exists(self.model_path): print(f"错误: 模型文件未找到: {self.model_path}"); return
        try:
            self.model = KeyedVectors.load_word2vec_format(self.model_path, binary=False)
            print(f"模型加载完成。耗时: {time.time() - start_time:.2f} 秒. V={len(self.model.key_to_index)}, D={self.model.vector_size}")
        except Exception as e: print(f"加载模型时发生错误: {e}"); self.model = None

    def build_index(self, documents_data):
        if self.model is None: print("错误：模型未加载。"); return
        print("\n--- 构建文档向量索引 ---")
        start_time = time.time()
        documents_data = self._prepare_docs(documents_data)
        print(f"正在对 {len(self.document_ids)} 个文档进行预处理和向量化...")
        doc_vectors_list = [get_average_vector(preprocess_text_tokens(self.documents_dict.get(doc_id, '')), self.model)
                            for doc_id in tqdm(self.document_ids, desc="Vectorizing Docs")]
        if not doc_vectors_list: print("错误：未能生成向量。"); return
        self.doc_vectors = np.array(doc_vectors_list)
        print(f"文档向量索引构建完成。矩阵形状: {self.doc_vectors.shape}. 耗时: {time.time() - start_time:.2f} 秒")

    def retrieve(self, query, top_n=5):
        if self.model is None or self.doc_vectors is None: return []
        tokenized_query = preprocess_text_tokens(query)
        if not tokenized_query: return []
        query_vector = get_average_vector(tokenized_query, self.model)
        if np.all(query_vector == 0): return []
        cosine_similarities = cosine_similarity(query_vector.reshape(1, -1), self.doc_vectors).flatten()
        # top_n_indices = np.argsort(cosine_similarities)[::-1][:top_n]
        k = min(top_n, self.doc_vectors.shape[0])
        top_n_indices = np.argpartition(cosine_similarities, -k)[-k:]
        top_n_indices = top_n_indices[np.argsort(cosine_similarities[top_n_indices])[::-1]]
        return [self.document_ids[i] for i in top_n_indices]

# --- 评估函数 ---
def evaluate_retriever(retriever, validation_data, top_n=5):
    """
    评估给定检索器在验证集上的 Recall@N 和 MRR@N。

    Args:
        retriever (BaseRetriever): 已构建好索引的检索器对象。
        validation_data (list[dict]): 验证数据集。
        top_n (int): 评估指标的 N 值 (例如 Recall@5, MRR@5)。

    Returns:
        dict: 包含 'recall@N' 和 'mrr@N' 的字典。
    """
    recall_sum = 0
    mrr_sum = 0
    total = len(validation_data)

    if total == 0:
        return {f'recall@{top_n}': 0, f'mrr@{top_n}': 0}

    print(f"\n开始评估 {type(retriever).__name__} (共 {total} 个问题)...")
    # 使用 tqdm 显示评估进度
    for item in tqdm(validation_data, desc=f"Evaluating {type(retriever).__name__}"):
        question = item['question']
        true_doc_id = item['document_id']

        # 执行检索
        retrieved_ids = retriever.retrieve(question, top_n=top_n)

        # 计算 Recall@N
        if true_doc_id in retrieved_ids:
            recall_sum += 1

            # 计算 MRR@N
            try:
                rank = retrieved_ids.index(true_doc_id) + 1
                mrr_sum += 1.0 / rank
            except ValueError:
                #理论上如果 recall_sum+=1，这里不会发生，但为了健壮性保留
                pass # MRR 贡献为 0

    recall_at_n = recall_sum / total
    mrr_at_n = mrr_sum / total

    return {f'recall@{top_n}': recall_at_n, f'mrr@{top_n}': mrr_at_n}

# --- 主评估流程 ---
if __name__ == "__main__":
    # 1. 加载文档数据 (只需要加载一次)
    print("--- 加载文档数据 ---")
    documents_data = load_jsonl(DOCUMENTS_FILE)
    if not documents_data:
        print("错误：无法加载文档数据，评估中止。")
        exit()

    # 2. 加载验证数据 (只需要加载一次)
    print("\n--- 加载验证数据 ---")
    val_data = load_jsonl(VAL_FILE)
    if not val_data:
        print("错误：无法加载验证数据，评估中止。")
        exit()

    # 3. 初始化、构建索引并评估各个检索器
    results = {}

    # --- TF-IDF ---
    tfidf_retriever = TfidfRetriever()
    tfidf_retriever.build_index(documents_data)
    if tfidf_retriever.tfidf_matrix is not None:
        results['TF-IDF'] = evaluate_retriever(tfidf_retriever, val_data, top_n=5)
    else:
        results['TF-IDF'] = {'recall@5': '构建失败', 'mrr@5': '构建失败'}


    # --- BM25 ---
    bm25_retriever = Bm25Retriever()
    bm25_retriever.build_index(documents_data)
    if bm25_retriever.bm25_index is not None:
         results['BM25'] = evaluate_retriever(bm25_retriever, val_data, top_n=5)
    else:
        results['BM25'] = {'recall@5': '构建失败', 'mrr@5': '构建失败'}

    # --- FastText (词向量平均) ---
    fasttext_retriever = FastTextRetriever(model_path=PRETRAINED_MODEL_PATH)
    fasttext_retriever.load_model() # 加载模型
    if fasttext_retriever.model is not None:
        fasttext_retriever.build_index(documents_data) # 构建文档向量
        if fasttext_retriever.doc_vectors is not None:
            results['FastText Avg'] = evaluate_retriever(fasttext_retriever, val_data, top_n=5)
        else:
             results['FastText Avg'] = {'recall@5': '构建失败', 'mrr@5': '构建失败'}
    else:
         results['FastText Avg'] = {'recall@5': '模型加载失败', 'mrr@5': '模型加载失败'}


    # 4. 打印结果表格
    print("\n\n--- 检索器性能评估结果 (验证集) ---")
    print("-" * 50)
    print(f"{'Retriever':<15} | {'Recall@5':<15} | {'MRR@5':<15}")
    print("-" * 50)
    for name, metrics in results.items():
        recall_str = f"{metrics.get('recall@5', 'N/A'):.4f}" if isinstance(metrics.get('recall@5'), float) else str(metrics.get('recall@5', 'N/A'))
        mrr_str = f"{metrics.get('mrr@5', 'N/A'):.4f}" if isinstance(metrics.get('mrr@5'), float) else str(metrics.get('mrr@5', 'N/A'))
        print(f"{name:<15} | {recall_str:<15} | {mrr_str:<15}")
    print("-" * 50)



--- 加载文档数据 ---


加载 documents.jsonl: 100%|██████████| 12138/12138 [00:04<00:00, 2969.53it/s]


成功加载 12138 条记录从 /kaggle/input/nq10k-comp5423/data_and_code/data/documents.jsonl

--- 加载验证数据 ---


加载 val.jsonl: 100%|██████████| 1000/1000 [00:00<00:00, 203864.29it/s]


成功加载 1000 条记录从 /kaggle/input/nq10k-comp5423/data_and_code/data/val.jsonl

--- 构建 TF-IDF 索引 ---
准备了 12138 个文档 ID 和 12138 个文档原文。
TF-IDF 索引构建完成。矩阵形状: (12138, 279270). 耗时: 365.52 秒

开始评估 TfidfRetriever (共 1000 个问题)...


Evaluating TfidfRetriever: 100%|██████████| 1000/1000 [12:43<00:00,  1.31it/s]



--- 构建 BM25 索引 ---
准备了 12138 个文档 ID 和 12138 个文档原文。
正在对 12138 个文档进行预处理和分词...


BM25 Preprocessing: 100%|██████████| 12138/12138 [05:38<00:00, 35.84it/s]



初始化 BM25Okapi...
BM25 索引构建完成。耗时: 357.33 秒

开始评估 Bm25Retriever (共 1000 个问题)...


Evaluating Bm25Retriever: 100%|██████████| 1000/1000 [00:48<00:00, 20.76it/s]



--- 加载词向量模型 (crawl-300d-2M.vec) ---
模型加载完成。耗时: 433.05 秒. V=1999995, D=300

--- 构建文档向量索引 ---
准备了 12138 个文档 ID 和 12138 个文档原文。
正在对 12138 个文档进行预处理和向量化...


Vectorizing Docs: 100%|██████████| 12138/12138 [07:39<00:00, 26.41it/s]


文档向量索引构建完成。矩阵形状: (12138, 300). 耗时: 459.70 秒

开始评估 FastTextRetriever (共 1000 个问题)...


Evaluating FastTextRetriever: 100%|██████████| 1000/1000 [00:13<00:00, 74.33it/s]



--- 检索器性能评估结果 (验证集) ---
--------------------------------------------------
Retriever       | Recall@5        | MRR@5          
--------------------------------------------------
TF-IDF          | 0.6960          | 0.5123         
BM25            | 0.7220          | 0.5538         
FastText Avg    | 0.5570          | 0.3936         
--------------------------------------------------





In [4]:
!pip install transformers torch faiss-gpu # 或者 faiss-cpu
# 可能還需要 sentencepiece
!pip install sentencepiece

Collecting faiss-gpu
  Downloading faiss_gpu-1.7.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.4 kB)
Downloading faiss_gpu-1.7.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (85.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m85.5/85.5 MB[0m [31m19.8 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hInstalling collected packages: faiss-gpu
Successfully installed faiss-gpu-1.7.2


In [2]:
import torch
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer
import time
import os

# --- 配置 ---
# 選擇預訓練的 DPR 模型名稱 (針對 NQ 數據集微調過的)
# Facebook 提供了針對 NQ 微調的模型
QUESTION_ENCODER_NAME = 'facebook/dpr-question_encoder-single-nq-base'
CONTEXT_ENCODER_NAME = 'facebook/dpr-ctx_encoder-single-nq-base'
# 模型保存路徑 (可選，如果想將下載的模型保存在特定位置)
# MODEL_SAVE_DIR = './dpr_models'
# os.makedirs(MODEL_SAVE_DIR, exist_ok=True)

# --- 檢查是否有可用的 GPU ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用的設備: {device}")

# --- 加載 DPR 模型和分詞器 ---

def load_dpr_models(q_encoder_name, ctx_encoder_name, device):
    """
    從 Hugging Face Hub 加載 DPR 問題編碼器、上下文編碼器及對應的分詞器。

    Args:
        q_encoder_name (str): 問題編碼器模型名稱或路徑。
        ctx_encoder_name (str): 上下文編碼器模型名稱或路徑。
        device (torch.device): 模型運行的設備 (cpu 或 cuda)。

    Returns:
        tuple: 包含 (question_tokenizer, question_encoder, context_tokenizer, context_encoder) 的元組。
               如果加載失敗，則對應項為 None。
    """
    print(f"開始加載 DPR 模型...")
    start_time = time.time()
    try:
        # 加載問題編碼器和分詞器
        print(f"  加載問題分詞器: {q_encoder_name}")
        q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(q_encoder_name)
        print(f"  加載問題編碼器: {q_encoder_name}")
        q_encoder = DPRQuestionEncoder.from_pretrained(q_encoder_name)
        q_encoder.to(device) # 將模型移動到指定設備
        q_encoder.eval() # 設置為評估模式

        # 加載上下文編碼器和分詞器
        print(f"  加載上下文分詞器: {ctx_encoder_name}")
        ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained(ctx_encoder_name)
        print(f"  加載上下文編碼器: {ctx_encoder_name}")
        ctx_encoder = DPRContextEncoder.from_pretrained(ctx_encoder_name)
        ctx_encoder.to(device) # 將模型移動到指定設備
        ctx_encoder.eval() # 設置為評估模式

        end_time = time.time()
        print(f"DPR 模型加載完成。耗時: {end_time - start_time:.2f} 秒")
        return q_tokenizer, q_encoder, ctx_tokenizer, ctx_encoder

    except Exception as e:
        print(f"加載 DPR 模型時出錯: {e}")
        print("請確保模型名稱正確，網絡連接正常，或已安裝所需依賴。")
        return None, None, None, None

# --- 主程序入口 ---
if __name__ == "__main__":
    # 執行模型加載
    question_tokenizer, question_encoder, context_tokenizer, context_encoder = load_dpr_models(
        QUESTION_ENCODER_NAME,
        CONTEXT_ENCODER_NAME,
        device
    )

    if question_encoder and context_encoder:
        print("\n成功加載所有 DPR 組件！")
        # 打印一些模型信息 (可選)
        print(f"問題編碼器配置: {question_encoder.config.model_type}, 隱藏層大小: {question_encoder.config.hidden_size}")
        print(f"上下文編碼器配置: {context_encoder.config.model_type}, 隱藏層大小: {context_encoder.config.hidden_size}")

        # --- 接下來的步驟 (將在後續代碼中實現) ---
        # 1. 加載 documents.jsonl
        # 2. 遍歷文檔，使用 context_tokenizer 和 context_encoder 生成文檔向量
        # 3. 使用 faiss 構建索引
        # 4. 實現檢索函數 (接收問題，使用 question_tokenizer 和 question_encoder 生成問題向量，在 faiss 中搜索)
        # 5. (可選) 進行評估
        print("\n下一步：使用加載的模型對文檔進行編碼並構建 FAISS 索引。")

    else:
        print("\nDPR 模型加載失敗，請檢查錯誤信息。")

使用的設備: cuda
開始加載 DPR 模型...
  加載問題分詞器: facebook/dpr-question_encoder-single-nq-base


tokenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/493 [00:00<?, ?B/s]

  加載問題編碼器: facebook/dpr-question_encoder-single-nq-base


pytorch_model.bin:   0%|          | 0.00/438M [00:00<?, ?B/s]

Some weights of the model checkpoint at facebook/dpr-question_encoder-single-nq-base were not used when initializing DPRQuestionEncoder: ['question_encoder.bert_model.pooler.dense.bias', 'question_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRQuestionEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRQuestionEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


  加載上下文分詞器: facebook/dpr-ctx_encoder-single-nq-base


tokenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/492 [00:00<?, ?B/s]

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'DPRQuestionEncoderTokenizer'. 
The class this function is called from is 'DPRContextEncoderTokenizer'.


  加載上下文編碼器: facebook/dpr-ctx_encoder-single-nq-base


pytorch_model.bin:   0%|          | 0.00/438M [00:00<?, ?B/s]

Some weights of the model checkpoint at facebook/dpr-ctx_encoder-single-nq-base were not used when initializing DPRContextEncoder: ['ctx_encoder.bert_model.pooler.dense.bias', 'ctx_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRContextEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRContextEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


DPR 模型加載完成。耗時: 9.51 秒

成功加載所有 DPR 組件！
問題編碼器配置: dpr, 隱藏層大小: 768
上下文編碼器配置: dpr, 隱藏層大小: 768

下一步：使用加載的模型對文檔進行編碼並構建 FAISS 索引。


In [5]:
import torch
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer
import faiss # 引入 faiss
import numpy as np
import time
import os
import json
from tqdm import tqdm

# --- 配置 ---
QUESTION_ENCODER_NAME = 'facebook/dpr-question_encoder-single-nq-base'
CONTEXT_ENCODER_NAME = 'facebook/dpr-ctx_encoder-single-nq-base'
DATA_DIR = '/kaggle/input/nq10k-comp5423/data_and_code/data'
DOCUMENTS_FILE = os.path.join(DATA_DIR, 'documents.jsonl')
# 保存 FAISS 索引和文檔 ID 映射的文件路徑
FAISS_INDEX_PATH = "dpr_faiss_index.idx"
DOC_IDS_PATH = "dpr_doc_ids.json"
# 編碼時的批次大小 (根據你的內存調整，CPU 上可以設小一點)
BATCH_SIZE = 32 # 例如 16, 32, 64
# DPR 模型通常的最大序列長度
MAX_LENGTH = 512

# --- 檢查是否有可用的 GPU ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用的設備: {device}")

# --- 加載 DPR 模型和分詞器 (與上一步相同) ---
def load_dpr_models(q_encoder_name, ctx_encoder_name, device):
    """加載 DPR 模型。"""
    print(f"開始加載 DPR 模型...")
    start_time = time.time()
    try:
        q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(q_encoder_name)
        q_encoder = DPRQuestionEncoder.from_pretrained(q_encoder_name)
        q_encoder.to(device).eval()

        ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained(ctx_encoder_name)
        ctx_encoder = DPRContextEncoder.from_pretrained(ctx_encoder_name)
        ctx_encoder.to(device).eval()

        end_time = time.time()
        print(f"DPR 模型加載完成。耗時: {end_time - start_time:.2f} 秒")
        return q_tokenizer, q_encoder, ctx_tokenizer, ctx_encoder
    except Exception as e:
        print(f"加載 DPR 模型時出錯: {e}")
        return None, None, None, None

# --- 加載文檔數據 ---
def load_jsonl(file_path):
    """加載 JSON Lines 文件。"""
    data = []
    if not os.path.exists(file_path):
        print(f"警告: 文件未找到 {file_path}")
        return data
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                 if line.strip():
                    try:
                        data.append(json.loads(line))
                    except json.JSONDecodeError as e:
                        print(f"警告: 跳過無法解析的行: {line.strip()} - 錯誤: {e}")
        print(f"成功加載 {len(data)} 條記錄從 {file_path}")
    except Exception as e:
        print(f"加載文件時出錯 {file_path}: {e}")
    return data

# --- 文檔編碼函數 ---
def encode_documents(documents, tokenizer, encoder, device, batch_size=32, max_length=512):
    """
    使用 DPR 上下文編碼器對文檔進行編碼。

    Args:
        documents (list[dict]): 文檔數據列表，每個 dict 包含 'document_id' 和 'document_text'。
        tokenizer: DPR 上下文分詞器。
        encoder: DPR 上下文編碼器。
        device: 運行設備。
        batch_size (int): 批處理大小。
        max_length (int): 分詞器最大長度。

    Returns:
        tuple: (numpy.ndarray of document vectors, list of document ids)
               如果出錯則返回 (None, None)。
    """
    doc_vectors = []
    doc_ids = []
    total_docs = len(documents)
    print(f"\n開始對 {total_docs} 個文檔進行編碼 (Batch Size: {batch_size})...")
    print("這一步會非常耗時，尤其是在 CPU 上。")

    # 記錄所有 ID
    for doc in documents:
         if 'document_id' in doc:
             doc_ids.append(doc['document_id'])

    # 使用 tqdm 顯示進度
    with tqdm(total=total_docs, desc="Encoding Documents") as pbar:
        for i in range(0, total_docs, batch_size):
            batch_docs = documents[i : i + batch_size]
            # 提取文本，處理可能的 None 或非字符串值
            texts = [str(doc.get('document_text', '')) for doc in batch_docs]

            try:
                # 分詞
                inputs = tokenizer(
                    texts,
                    max_length=max_length,
                    padding='longest', # 在 batch 內部填充到最長
                    truncation=True,
                    return_tensors='pt'
                )
                # 將輸入移動到設備
                inputs = {key: val.to(device) for key, val in inputs.items()}

                # 在 no_grad 模式下進行編碼，節省內存和計算
                with torch.no_grad():
                    outputs = encoder(**inputs)
                    # DPR 的輸出通常在 pooler_output 屬性中
                    batch_vectors = outputs.pooler_output

                # (推薦) L2 歸一化向量，以便使用 IndexFlatIP 計算余弦相似度
                norms = torch.linalg.norm(batch_vectors, dim=1, keepdim=True)
                normalized_vectors = batch_vectors / norms

                # 將向量移回 CPU 並轉換為 NumPy 數組
                doc_vectors.append(normalized_vectors.cpu().numpy())

            except Exception as e:
                print(f"\n處理批次 {i // batch_size} 時出錯: {e}")
                # 可以選擇跳過錯誤批次或中止
                # continue
                # return None, None

            pbar.update(len(batch_docs))

    if not doc_vectors:
        print("錯誤：未能生成任何文檔向量。")
        return None, None

    # 將所有批次的向量合併為一個大的 NumPy 數組
    all_doc_vectors = np.concatenate(doc_vectors, axis=0)
    print(f"\n文檔編碼完成。生成向量矩陣形狀: {all_doc_vectors.shape}")
    return all_doc_vectors.astype('float32'), doc_ids # FAISS 通常需要 float32

# --- FAISS 索引構建函數 ---
def build_faiss_index(doc_vectors, index_path):
    """
    使用文檔向量構建 FAISS 索引並保存。

    Args:
        doc_vectors (np.ndarray): 文檔向量 (float32)。
        index_path (str): 保存索引的文件路徑。

    Returns:
        faiss.Index: 構建好的 FAISS 索引。如果出錯則返回 None。
    """
    if doc_vectors is None or doc_vectors.shape[0] == 0:
        print("錯誤：沒有有效的文檔向量來構建索引。")
        return None

    vector_dim = doc_vectors.shape[1]
    print(f"\n開始構建 FAISS 索引 (向量維度: {vector_dim})...")
    start_time = time.time()

    # 選擇索引類型: IndexFlatIP (內積)。因為向量已歸一化，內積等價於余弦相似度。
    # IndexFlatL2 計算歐氏距離。
    index = faiss.IndexFlatIP(vector_dim)

    # 添加向量到索引
    print(f"正在添加 {doc_vectors.shape[0]} 個向量到索引...")
    index.add(doc_vectors)

    end_time = time.time()
    print(f"FAISS 索引構建完成。索引中文檔數: {index.ntotal}")
    print(f"耗時: {end_time - start_time:.2f} 秒")

    # 保存索引
    try:
        print(f"正在將索引保存到: {index_path}")
        faiss.write_index(index, index_path)
        print("索引保存成功。")
    except Exception as e:
        print(f"保存 FAISS 索引時出錯: {e}")

    return index

# --- 主程序入口 ---
if __name__ == "__main__":
    # 1. 加載 DPR 模型
    question_tokenizer, question_encoder, context_tokenizer, context_encoder = load_dpr_models(
        QUESTION_ENCODER_NAME,
        CONTEXT_ENCODER_NAME,
        device
    )

    faiss_index = None
    doc_ids_list = []

    # 檢查索引和 ID 映射是否已存在，如果存在則直接加載
    if os.path.exists(FAISS_INDEX_PATH) and os.path.exists(DOC_IDS_PATH):
        print(f"\n檢測到已存在的 FAISS 索引 ({FAISS_INDEX_PATH}) 和文檔 ID 映射 ({DOC_IDS_PATH})。")
        try:
            print("正在加載 FAISS 索引...")
            faiss_index = faiss.read_index(FAISS_INDEX_PATH)
            print(f"索引加載成功。包含 {faiss_index.ntotal} 個向量。")

            print("正在加載文檔 ID 映射...")
            with open(DOC_IDS_PATH, 'r') as f:
                doc_ids_list = json.load(f)
            print(f"文檔 ID 映射加載成功。包含 {len(doc_ids_list)} 個 ID。")

            if faiss_index.ntotal != len(doc_ids_list):
                print("警告：索引中的向量數與加載的 ID 數量不匹配！可能需要重新構建。")
                faiss_index = None # 標記為無效，觸發重新構建
                doc_ids_list = []

        except Exception as e:
            print(f"加載已存文件時出錯: {e}。將重新構建索引。")
            faiss_index = None
            doc_ids_list = []

    # 如果模型加載成功，並且索引未從文件加載，則執行編碼和構建
    if context_encoder and context_tokenizer and faiss_index is None:
        # 2. 加載文檔數據
        print("\n--- 加載文檔數據 ---")
        documents_data = load_jsonl(DOCUMENTS_FILE)

        if documents_data:
            # 3. 文檔編碼
            document_vectors, doc_ids_list = encode_documents(
                documents_data,
                context_tokenizer,
                context_encoder,
                device,
                batch_size=BATCH_SIZE,
                max_length=MAX_LENGTH
            )

            if document_vectors is not None and doc_ids_list:
                # 4. 構建 FAISS 索引
                faiss_index = build_faiss_index(document_vectors, FAISS_INDEX_PATH)

                # 5. 保存文檔 ID 列表
                if faiss_index is not None: # 確保索引構建成功
                    try:
                        print(f"正在將文檔 ID 列表保存到: {DOC_IDS_PATH}")
                        with open(DOC_IDS_PATH, 'w') as f:
                            json.dump(doc_ids_list, f)
                        print("文檔 ID 列表保存成功。")
                    except Exception as e:
                        print(f"保存文檔 ID 列表時出錯: {e}")
            else:
                 print("文檔編碼失敗，無法構建索引。")
        else:
            print("無法加載文檔數據，無法構建索引。")

    elif faiss_index is not None:
         print("\n已成功加載預先構建的 FAISS 索引和文檔 ID。")

    else:
        print("\nDPR 模型加載失敗 或 索引構建/加載失敗，無法繼續。")


    # --- 接下來的步驟 ---
    if faiss_index is not None and question_encoder is not None and doc_ids_list:
        print("\n下一步：實現使用 FAISS 索引進行檢索的函數。")
        # 示例：獲取索引中的向量數量
        print(f"最終 FAISS 索引中的向量總數: {faiss_index.ntotal}")
        print(f"對應的文檔 ID 數量: {len(doc_ids_list)}")
        # 實現 retrieve_dpr(query, top_n) -> list[doc_id]
    else:
        print("\n未能成功準備好用於檢索的組件。")


使用的設備: cuda
開始加載 DPR 模型...


Some weights of the model checkpoint at facebook/dpr-question_encoder-single-nq-base were not used when initializing DPRQuestionEncoder: ['question_encoder.bert_model.pooler.dense.bias', 'question_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRQuestionEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRQuestionEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'DPRQuestionEncoderTokenizer'. 
The class this function is called from is 'DPRCon

DPR 模型加載完成。耗時: 2.83 秒

--- 加載文檔數據 ---
成功加載 12138 條記錄從 /kaggle/input/nq10k-comp5423/data_and_code/data/documents.jsonl

開始對 12138 個文檔進行編碼 (Batch Size: 32)...
這一步會非常耗時，尤其是在 CPU 上。


Encoding Documents: 100%|██████████| 12138/12138 [41:13<00:00,  4.91it/s]


文檔編碼完成。生成向量矩陣形狀: (12138, 768)

開始構建 FAISS 索引 (向量維度: 768)...
正在添加 12138 個向量到索引...
FAISS 索引構建完成。索引中文檔數: 12138
耗時: 0.03 秒
正在將索引保存到: dpr_faiss_index.idx
索引保存成功。
正在將文檔 ID 列表保存到: dpr_doc_ids.json
文檔 ID 列表保存成功。

下一步：實現使用 FAISS 索引進行檢索的函數。
最終 FAISS 索引中的向量總數: 12138
對應的文檔 ID 數量: 12138





In [7]:
import torch
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer
import faiss # 引入 faiss
import numpy as np
import time
import os
import json
from tqdm import tqdm

# --- 配置 ---
QUESTION_ENCODER_NAME = 'facebook/dpr-question_encoder-single-nq-base'
CONTEXT_ENCODER_NAME = 'facebook/dpr-ctx_encoder-single-nq-base'
DATA_DIR = '/kaggle/input/nq10k-comp5423/data_and_code/data'
DOCUMENTS_FILE = os.path.join(DATA_DIR, 'documents.jsonl')
VAL_FILE = os.path.join(DATA_DIR, 'val.jsonl') # 用於演示檢索
# 保存/加載 FAISS 索引和文檔 ID 映射的文件路徑
FAISS_INDEX_PATH = "/kaggle/working/dpr_faiss_index.idx"
DOC_IDS_PATH = "/kaggle/working/dpr_doc_ids.json"
# 編碼時的批次大小
BATCH_SIZE = 32 # 文檔編碼時使用
# DPR 模型通常的最大序列長度
MAX_LENGTH = 512

# --- 檢查是否有可用的 GPU ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用的設備: {device}")

# --- 加載 DPR 模型和分詞器 ---
def load_dpr_models(q_encoder_name, ctx_encoder_name, device):
    """加載 DPR 模型。"""
    print(f"開始加載 DPR 模型...")
    start_time = time.time()
    try:
        q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(q_encoder_name)
        q_encoder = DPRQuestionEncoder.from_pretrained(q_encoder_name)
        q_encoder.to(device).eval()

        # 注意：上下文分詞器僅在需要重新編碼文檔時才需要加載
        # 如果索引已存在，可以考慮延遲加載或不加載 ctx_tokenizer, ctx_encoder
        ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained(ctx_encoder_name)
        ctx_encoder = DPRContextEncoder.from_pretrained(ctx_encoder_name)
        ctx_encoder.to(device).eval()

        end_time = time.time()
        print(f"DPR 模型加載完成。耗時: {end_time - start_time:.2f} 秒")
        return q_tokenizer, q_encoder, ctx_tokenizer, ctx_encoder
    except Exception as e:
        print(f"加載 DPR 模型時出錯: {e}")
        # 即使上下文編碼器加載失敗，如果索引存在，仍可能繼續進行檢索
        # 但為了完整性，這裡返回 None
        return None, None, None, None

# --- 加載文檔/驗證數據 ---
def load_jsonl(file_path):
    """加載 JSON Lines 文件。"""
    data = []
    if not os.path.exists(file_path):
        print(f"警告: 文件未找到 {file_path}")
        return data
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                 if line.strip():
                    try:
                        data.append(json.loads(line))
                    except json.JSONDecodeError as e:
                        print(f"警告: 跳過無法解析的行: {line.strip()} - 錯誤: {e}")
        print(f"成功加載 {len(data)} 條記錄從 {file_path}")
    except Exception as e:
        print(f"加載文件時出錯 {file_path}: {e}")
    return data

# --- 文檔編碼函數 ---
def encode_documents(documents, tokenizer, encoder, device, batch_size=32, max_length=512):
    """使用 DPR 上下文編碼器對文檔進行編碼。"""
    doc_vectors = []
    doc_ids = []
    total_docs = len(documents)
    print(f"\n開始對 {total_docs} 個文檔進行編碼 (Batch Size: {batch_size})...")
    print("這一步會非常耗時，尤其是在 CPU 上。")
    for doc in documents:
         if 'document_id' in doc: doc_ids.append(doc['document_id'])
    with tqdm(total=total_docs, desc="Encoding Documents") as pbar:
        for i in range(0, total_docs, batch_size):
            batch_docs = documents[i : i + batch_size]
            texts = [str(doc.get('document_text', '')) for doc in batch_docs]
            try:
                inputs = tokenizer(texts, max_length=max_length, padding='longest', truncation=True, return_tensors='pt')
                inputs = {key: val.to(device) for key, val in inputs.items()}
                with torch.no_grad():
                    outputs = encoder(**inputs)
                    batch_vectors = outputs.pooler_output
                norms = torch.linalg.norm(batch_vectors, dim=1, keepdim=True)
                # 添加一個小的 epsilon 防止除以零
                normalized_vectors = batch_vectors / (norms + 1e-8)
                doc_vectors.append(normalized_vectors.cpu().numpy())
            except Exception as e:
                print(f"\n處理批次 {i // batch_size} 時出錯: {e}")
            pbar.update(len(batch_docs))
    if not doc_vectors: return None, None
    all_doc_vectors = np.concatenate(doc_vectors, axis=0)
    print(f"\n文檔編碼完成。生成向量矩陣形狀: {all_doc_vectors.shape}")
    return all_doc_vectors.astype('float32'), doc_ids

# --- FAISS 索引構建函數 ---
def build_faiss_index(doc_vectors, index_path):
    """使用文檔向量構建 FAISS 索引並保存。"""
    if doc_vectors is None or doc_vectors.shape[0] == 0: return None
    vector_dim = doc_vectors.shape[1]
    print(f"\n開始構建 FAISS 索引 (向量維度: {vector_dim})...")
    start_time = time.time()
    index = faiss.IndexFlatIP(vector_dim)
    print(f"正在添加 {doc_vectors.shape[0]} 個向量到索引...")
    index.add(doc_vectors)
    end_time = time.time()
    print(f"FAISS 索引構建完成。索引中文檔數: {index.ntotal}. 耗時: {end_time - start_time:.2f} 秒")
    try:
        print(f"正在將索引保存到: {index_path}")
        faiss.write_index(index, index_path)
        print("索引保存成功。")
    except Exception as e: print(f"保存 FAISS 索引時出錯: {e}")
    return index

# --- 新增：DPR 檢索函數 ---
def retrieve_dpr(query, q_tokenizer, q_encoder, faiss_index, doc_ids_list, top_n, device):
    """
    使用 DPR 問題編碼器和 FAISS 索引進行檢索。

    Args:
        query (str): 查詢問題。
        q_tokenizer: DPR 問題分詞器。
        q_encoder: DPR 問題編碼器。
        faiss_index: 構建好的 FAISS 索引。
        doc_ids_list (list): 與 FAISS 索引順序一致的文檔 ID 列表。
        top_n (int): 需要檢索的 top N 結果。
        device: 運行設備。

    Returns:
        list[int]: 檢索到的 top_n 個文檔 ID 列表。
    """
    if not query or faiss_index is None or q_encoder is None or not doc_ids_list:
        print("錯誤：檢索所需的組件不完整。")
        return []

    try:
        # 1. 編碼查詢
        inputs = q_tokenizer(query, max_length=MAX_LENGTH, padding=False, truncation=True, return_tensors='pt')
        inputs = {key: val.to(device) for key, val in inputs.items()}

        with torch.no_grad():
            outputs = q_encoder(**inputs)
            query_vector = outputs.pooler_output

        # 2. 歸一化查詢向量
        norm = torch.linalg.norm(query_vector)
        normalized_query_vector = query_vector / (norm + 1e-8) # 添加 epsilon

        # 轉換為 NumPy float32 並 reshape
        query_np = normalized_query_vector.cpu().numpy().astype('float32').reshape(1, -1)

        # 3. 在 FAISS 中搜索
        # search 返回 (距離/分數 D, 索引 I)
        distances, indices = faiss_index.search(query_np, top_n)

        # 4. 將 FAISS 索引映射回文檔 ID
        retrieved_ids = [doc_ids_list[i] for i in indices[0] if i != -1] # i == -1 表示沒有找到足夠的鄰居

        return retrieved_ids

    except Exception as e:
        print(f"檢索過程中出錯: {e}")
        return []


# --- 主程序入口 ---
if __name__ == "__main__":
    # 1. 加載 DPR 模型
    # 只需要問題編碼器用於檢索，但如果需要重新構建索引，則兩者都需要
    question_tokenizer, question_encoder, context_tokenizer, context_encoder = load_dpr_models(
        QUESTION_ENCODER_NAME,
        CONTEXT_ENCODER_NAME,
        device
    )

    faiss_index = None
    doc_ids_list = []

    # 2. 嘗試加載已存在的索引和 ID 列表
    if os.path.exists(FAISS_INDEX_PATH) and os.path.exists(DOC_IDS_PATH):
        print(f"\n檢測到已存在的 FAISS 索引和文檔 ID 映射。")
        try:
            print("正在加載 FAISS 索引...")
            faiss_index = faiss.read_index(FAISS_INDEX_PATH)
            print(f"索引加載成功。包含 {faiss_index.ntotal} 個向量。")
            print("正在加載文檔 ID 映射...")
            with open(DOC_IDS_PATH, 'r') as f:
                doc_ids_list = json.load(f)
            print(f"文檔 ID 映射加載成功。包含 {len(doc_ids_list)} 個 ID。")
            if faiss_index.ntotal != len(doc_ids_list):
                print("警告：索引向量數與 ID 數量不匹配！")
                faiss_index = None; doc_ids_list = []
        except Exception as e:
            print(f"加載已存文件時出錯: {e}。")
            faiss_index = None; doc_ids_list = []

    # 3. 如果索引加載失敗或不存在，則嘗試構建
    if faiss_index is None and context_encoder and context_tokenizer:
        print("\n需要構建 FAISS 索引...")
        print("\n--- 加載文檔數據 ---")
        documents_data = load_jsonl(DOCUMENTS_FILE)
        if documents_data:
            document_vectors, doc_ids_list_build = encode_documents(
                documents_data, context_tokenizer, context_encoder, device, BATCH_SIZE, MAX_LENGTH)
            if document_vectors is not None and doc_ids_list_build:
                faiss_index = build_faiss_index(document_vectors, FAISS_INDEX_PATH)
                if faiss_index is not None:
                    doc_ids_list = doc_ids_list_build # 使用新生成的 ID 列表
                    try:
                        print(f"正在將新生成的文檔 ID 列表保存到: {DOC_IDS_PATH}")
                        with open(DOC_IDS_PATH, 'w') as f: json.dump(doc_ids_list, f)
                        print("文檔 ID 列表保存成功。")
                    except Exception as e: print(f"保存文檔 ID 列表時出錯: {e}")
            else: print("文檔編碼失敗，無法構建索引。")
        else: print("無法加載文檔數據，無法構建索引。")

    # 4. 檢查是否可以進行檢索
    if faiss_index is not None and question_encoder is not None and question_tokenizer is not None and doc_ids_list:
        print("\n所有檢索組件準備就緒！")

        # 5. 演示檢索功能
        print("\n--- 演示 DPR 檢索 ---")
        # 加載驗證數據以獲取示例問題
        val_data = load_jsonl(VAL_FILE)
        if val_data:
            example_index = 0 # 使用第一個驗證問題
            example_question = val_data[example_index]['question']
            true_doc_id = val_data[example_index]['document_id']

            print(f"示例問題 (來自 val.jsonl[{example_index}]): '{example_question}'")
            print(f"真實相關的文檔 ID: {true_doc_id}")

            start_retrieval_time = time.time()
            retrieved_ids = retrieve_dpr(
                example_question,
                question_tokenizer,
                question_encoder,
                faiss_index,
                doc_ids_list,
                top_n=5,
                device=device
            )
            end_retrieval_time = time.time()

            print(f"\nDPR 檢索到的 Top-5 文檔 ID:")
            print(retrieved_ids)
            print(f"檢索耗時: {end_retrieval_time - start_retrieval_time:.4f} 秒")

            # 檢查結果
            if retrieved_ids and true_doc_id in retrieved_ids:
                rank = retrieved_ids.index(true_doc_id) + 1
                print(f"成功! 真實文檔 ID {true_doc_id} 在檢索結果中排名第 {rank}。")
            elif retrieved_ids:
                print(f"失敗。真實文檔 ID {true_doc_id} 未在前 5 個檢索結果中。")
            else:
                print("未能檢索到任何文檔。")
        else:
            print("未能加載驗證數據進行演示。")

        print("\n下一步：(可選) 使用驗證集評估 DPR 檢索性能，或開始實現答案生成。")

    else:
        print("\n未能成功準備好用於檢索的組件。請檢查之前的錯誤信息。")



使用的設備: cuda
開始加載 DPR 模型...


Some weights of the model checkpoint at facebook/dpr-question_encoder-single-nq-base were not used when initializing DPRQuestionEncoder: ['question_encoder.bert_model.pooler.dense.bias', 'question_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRQuestionEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRQuestionEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'DPRQuestionEncoderTokenizer'. 
The class this function is called from is 'DPRCon

DPR 模型加載完成。耗時: 2.31 秒

檢測到已存在的 FAISS 索引和文檔 ID 映射。
正在加載 FAISS 索引...
索引加載成功。包含 12138 個向量。
正在加載文檔 ID 映射...
文檔 ID 映射加載成功。包含 12138 個 ID。

所有檢索組件準備就緒！

--- 演示 DPR 檢索 ---
成功加載 1000 條記錄從 /kaggle/input/nq10k-comp5423/data_and_code/data/val.jsonl
示例問題 (來自 val.jsonl[0]): 'when did the british first land in north america'
真實相關的文檔 ID: 11484

DPR 檢索到的 Top-5 文檔 ID:
[11484, 4117, 81, 411, 7632]
檢索耗時: 0.0242 秒
成功! 真實文檔 ID 11484 在檢索結果中排名第 1。

下一步：(可選) 使用驗證集評估 DPR 檢索性能，或開始實現答案生成。


In [5]:
import torch
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer
import faiss # 引入 faiss
import numpy as np
import time
import os
import json
from tqdm import tqdm

# --- 配置 ---
QUESTION_ENCODER_NAME = 'facebook/dpr-question_encoder-single-nq-base'
CONTEXT_ENCODER_NAME = 'facebook/dpr-ctx_encoder-single-nq-base'
DATA_DIR = '/kaggle/input/nq10k-comp5423/data_and_code/data'
DOCUMENTS_FILE = os.path.join(DATA_DIR, 'documents.jsonl')
VAL_FILE = os.path.join(DATA_DIR, 'val.jsonl') # 用於演示和評估
# 保存/加載 FAISS 索引和文檔 ID 映射的文件路徑
FAISS_INDEX_PATH = "/kaggle/working/dpr_faiss_index.idx"
DOC_IDS_PATH = "/kaggle/working/dpr_doc_ids.json"
# 編碼時的批次大小
BATCH_SIZE = 32 # 文檔編碼時使用
# DPR 模型通常的最大序列長度
MAX_LENGTH = 512

# --- 檢查是否有可用的 GPU ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用的設備: {device}")

# --- 加載 DPR 模型和分詞器 ---
def load_dpr_models(q_encoder_name, ctx_encoder_name, device):
    """加載 DPR 模型。"""
    print(f"開始加載 DPR 模型...")
    start_time = time.time()
    # 初始化為 None
    q_tokenizer, q_encoder, ctx_tokenizer, ctx_encoder = None, None, None, None
    try:
        print(f"  加載問題分詞器: {q_encoder_name}")
        q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(q_encoder_name)
        print(f"  加載問題編碼器: {q_encoder_name}")
        q_encoder = DPRQuestionEncoder.from_pretrained(q_encoder_name)
        q_encoder.to(device).eval()

        # 嘗試加載上下文模型，即使失敗也可能繼續（如果索引已存在）
        try:
            print(f"  加載上下文分詞器: {ctx_encoder_name}")
            ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained(ctx_encoder_name)
            print(f"  加載上下文編碼器: {ctx_encoder_name}")
            ctx_encoder = DPRContextEncoder.from_pretrained(ctx_encoder_name)
            ctx_encoder.to(device).eval()
        except Exception as e_ctx:
             print(f"警告：加載上下文模型時出錯（如果索引已存在，可能不影響檢索）: {e_ctx}")

        end_time = time.time()
        print(f"DPR 模型加載完成（或部分完成）。耗時: {end_time - start_time:.2f} 秒")
        return q_tokenizer, q_encoder, ctx_tokenizer, ctx_encoder
    except Exception as e_q:
        print(f"加載 DPR 問題模型時出錯，無法繼續: {e_q}")
        return None, None, None, None


# --- 加載文檔/驗證數據 ---
def load_jsonl(file_path):
    """加載 JSON Lines 文件。"""
    data = []
    if not os.path.exists(file_path):
        print(f"警告: 文件未找到 {file_path}")
        return data
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                 if line.strip():
                    try:
                        data.append(json.loads(line))
                    except json.JSONDecodeError as e:
                        print(f"警告: 跳過無法解析的行: {line.strip()} - 錯誤: {e}")
        print(f"成功加載 {len(data)} 條記錄從 {file_path}")
    except Exception as e:
        print(f"加載文件時出錯 {file_path}: {e}")
    return data

# --- 文檔編碼函數 ---
def encode_documents(documents, tokenizer, encoder, device, batch_size=32, max_length=512):
    """使用 DPR 上下文編碼器對文檔進行編碼。"""
    doc_vectors = []
    doc_ids = []
    total_docs = len(documents)
    print(f"\n開始對 {total_docs} 個文檔進行編碼 (Batch Size: {batch_size})...")
    # print("這一步會非常耗時，尤其是在 CPU 上。") # 已知信息
    for doc in documents:
         if 'document_id' in doc: doc_ids.append(doc['document_id'])
    with tqdm(total=total_docs, desc="Encoding Documents") as pbar:
        for i in range(0, total_docs, batch_size):
            batch_docs = documents[i : i + batch_size]
            texts = [str(doc.get('document_text', '')) for doc in batch_docs]
            try:
                inputs = tokenizer(texts, max_length=max_length, padding='longest', truncation=True, return_tensors='pt')
                inputs = {key: val.to(device) for key, val in inputs.items()}
                with torch.no_grad():
                    outputs = encoder(**inputs)
                    batch_vectors = outputs.pooler_output
                norms = torch.linalg.norm(batch_vectors, dim=1, keepdim=True)
                normalized_vectors = batch_vectors / (norms + 1e-8)
                doc_vectors.append(normalized_vectors.cpu().numpy())
            except Exception as e:
                print(f"\n處理批次 {i // batch_size} 時出錯: {e}")
            pbar.update(len(batch_docs))
    if not doc_vectors: return None, None
    all_doc_vectors = np.concatenate(doc_vectors, axis=0)
    print(f"\n文檔編碼完成。生成向量矩陣形狀: {all_doc_vectors.shape}")
    return all_doc_vectors.astype('float32'), doc_ids

# --- FAISS 索引構建函數 ---
def build_faiss_index(doc_vectors, index_path):
    """使用文檔向量構建 FAISS 索引並保存。"""
    if doc_vectors is None or doc_vectors.shape[0] == 0: return None
    vector_dim = doc_vectors.shape[1]
    print(f"\n開始構建 FAISS 索引 (向量維度: {vector_dim})...")
    start_time = time.time()
    index = faiss.IndexFlatIP(vector_dim)
    print(f"正在添加 {doc_vectors.shape[0]} 個向量到索引...")
    index.add(doc_vectors)
    end_time = time.time()
    print(f"FAISS 索引構建完成。索引中文檔數: {index.ntotal}. 耗時: {end_time - start_time:.2f} 秒")
    try:
        print(f"正在將索引保存到: {index_path}")
        faiss.write_index(index, index_path)
        print("索引保存成功。")
    except Exception as e: print(f"保存 FAISS 索引時出錯: {e}")
    return index

# --- 新增：DPR 檢索器類 ---
class DprFaissRetriever:
    """
    封裝 DPR + FAISS 檢索邏輯的類。
    """
    def __init__(self, q_tokenizer, q_encoder, faiss_index, doc_ids_list, device):
        self.q_tokenizer = q_tokenizer
        self.q_encoder = q_encoder
        self.faiss_index = faiss_index
        self.doc_ids_list = doc_ids_list
        self.device = device
        self.vector_dim = q_encoder.config.hidden_size if q_encoder else None

    def retrieve(self, query, top_n=5):
        """
        執行檢索。
        """
        if not all([query, self.faiss_index, self.q_encoder, self.q_tokenizer, self.doc_ids_list, self.vector_dim]):
            print("錯誤：DPR 檢索器初始化不完整或查詢無效。")
            return []

        try:
            # 1. 編碼查詢
            inputs = self.q_tokenizer(query, max_length=MAX_LENGTH, padding=False, truncation=True, return_tensors='pt')
            inputs = {key: val.to(self.device) for key, val in inputs.items()}

            with torch.no_grad():
                outputs = self.q_encoder(**inputs)
                query_vector = outputs.pooler_output

            # 2. 歸一化查詢向量
            norm = torch.linalg.norm(query_vector)
            normalized_query_vector = query_vector / (norm + 1e-8)

            # 轉換為 NumPy float32 並 reshape
            query_np = normalized_query_vector.cpu().numpy().astype('float32').reshape(1, -1)

            # 3. 在 FAISS 中搜索
            distances, indices = self.faiss_index.search(query_np, top_n)

            # 4. 將 FAISS 索引映射回文檔 ID
            retrieved_ids = [self.doc_ids_list[i] for i in indices[0] if 0 <= i < len(self.doc_ids_list)] # 增加邊界檢查

            return retrieved_ids

        except Exception as e:
            print(f"檢索過程中出錯: {e}")
            return []

# --- 評估函數 (與之前相同) ---
def evaluate_retriever(retriever, validation_data, top_n=5):
    """評估檢索器。"""
    recall_sum = 0
    mrr_sum = 0
    total = len(validation_data)
    if total == 0: return {f'recall@{top_n}': 0, f'mrr@{top_n}': 0}
    print(f"\n開始評估 {type(retriever).__name__} (共 {total} 個問題)...")
    retrieved_results = [] # 用於存儲每次的檢索結果 (可選)
    for item in tqdm(validation_data, desc=f"Evaluating {type(retriever).__name__}"):
        question = item['question']
        true_doc_id = item['document_id']
        # 調用傳入的 retriever 對象的 retrieve 方法
        retrieved_ids = retriever.retrieve(question, top_n=top_n)
        retrieved_results.append(retrieved_ids) # 存儲結果
        if true_doc_id in retrieved_ids:
            recall_sum += 1
            try:
                rank = retrieved_ids.index(true_doc_id) + 1
                mrr_sum += 1.0 / rank
            except ValueError: pass
    recall_at_n = recall_sum / total
    mrr_at_n = mrr_sum / total
    return {f'recall@{top_n}': recall_at_n, f'mrr@{top_n}': mrr_at_n} # 可以考慮同時返回 retrieved_results 用於分析

# --- 主程序入口 ---
if __name__ == "__main__":
    # 1. 加載 DPR 模型
    question_tokenizer, question_encoder, context_tokenizer, context_encoder = load_dpr_models(
        QUESTION_ENCODER_NAME,
        CONTEXT_ENCODER_NAME,
        device
    )

    faiss_index = None
    doc_ids_list = []

    # 2. 嘗試加載已存在的索引和 ID 列表
    if os.path.exists(FAISS_INDEX_PATH) and os.path.exists(DOC_IDS_PATH):
        print(f"\n檢測到已存在的 FAISS 索引和文檔 ID 映射。")
        try:
            print("正在加載 FAISS 索引...")
            faiss_index = faiss.read_index(FAISS_INDEX_PATH)
            print(f"索引加載成功。包含 {faiss_index.ntotal} 個向量。")
            print("正在加載文檔 ID 映射...")
            with open(DOC_IDS_PATH, 'r') as f:
                doc_ids_list = json.load(f)
            print(f"文檔 ID 映射加載成功。包含 {len(doc_ids_list)} 個 ID。")
            if faiss_index.ntotal != len(doc_ids_list):
                print("警告：索引向量數與 ID 數量不匹配！")
                faiss_index = None; doc_ids_list = []
        except Exception as e:
            print(f"加載已存文件時出錯: {e}。")
            faiss_index = None; doc_ids_list = []

    # 3. 如果索引加載失敗或不存在，則嘗試構建
    if faiss_index is None:
        print("\n需要構建 FAISS 索引...")
        if context_encoder and context_tokenizer: # 確保上下文模型已加載
             print("\n--- 加載文檔數據 ---")
             documents_data = load_jsonl(DOCUMENTS_FILE)
             if documents_data:
                 document_vectors, doc_ids_list_build = encode_documents(
                     documents_data, context_tokenizer, context_encoder, device, BATCH_SIZE, MAX_LENGTH)
                 if document_vectors is not None and doc_ids_list_build:
                     faiss_index = build_faiss_index(document_vectors, FAISS_INDEX_PATH)
                     if faiss_index is not None:
                         doc_ids_list = doc_ids_list_build
                         try:
                             print(f"正在將新生成的文檔 ID 列表保存到: {DOC_IDS_PATH}")
                             with open(DOC_IDS_PATH, 'w') as f: json.dump(doc_ids_list, f)
                             print("文檔 ID 列表保存成功。")
                         except Exception as e: print(f"保存文檔 ID 列表時出錯: {e}")
                 else: print("文檔編碼失敗，無法構建索引。")
             else: print("無法加載文檔數據，無法構建索引。")
        else:
             print("錯誤：上下文編碼器未成功加載，無法構建新的 FAISS 索引。")


    # 4. 檢查是否可以進行評估
    if faiss_index is not None and question_encoder is not None and question_tokenizer is not None and doc_ids_list:
        print("\n所有 DPR 檢索組件準備就緒！")

        # 5. 加載驗證數據
        print("\n--- 加載驗證數據 ---")
        val_data = load_jsonl(VAL_FILE)

        if val_data:
             # 6. 創建 DPR 檢索器實例
             dpr_retriever = DprFaissRetriever(
                 question_tokenizer,
                 question_encoder,
                 faiss_index,
                 doc_ids_list,
                 device
             )

             # 7. 執行評估
             results = {}
             # 運行 DPR 評估
             results['DPR + FAISS'] = evaluate_retriever(dpr_retriever, val_data, top_n=5)

             # (可選) 重新運行之前的基線評估進行比較
             # print("\n--- 重新運行基線評估 (可選) ---")
             # tfidf_retriever = TfidfRetriever()
             # tfidf_retriever.build_index(load_jsonl(DOCUMENTS_FILE)) # 需要重新加載文檔
             # if tfidf_retriever.tfidf_matrix is not None:
             #     results['TF-IDF'] = evaluate_retriever(tfidf_retriever, val_data, top_n=5)
             # bm25_retriever = Bm25Retriever()
             # bm25_retriever.build_index(load_jsonl(DOCUMENTS_FILE))
             # if bm25_retriever.bm25_index is not None:
             #     results['BM25'] = evaluate_retriever(bm25_retriever, val_data, top_n=5)


             # 8. 打印結果表格
             print("\n\n--- DPR 檢索器性能評估結果 (驗證集) ---")
             print("-" * 60)
             print(f"{'Retriever':<25} | {'Recall@5':<15} | {'MRR@5':<15}")
             print("-" * 60)
             # 僅打印 DPR 結果，或取消註釋上面基線部分以顯示完整比較
             order = ['DPR + FAISS'] #, 'TF-IDF', 'BM25']
             for name in order:
                 if name in results:
                     metrics = results[name]
                     recall_str = f"{metrics.get('recall@5', 'N/A'):.4f}" if isinstance(metrics.get('recall@5'), float) else str(metrics.get('recall@5', 'N/A'))
                     mrr_str = f"{metrics.get('mrr@5', 'N/A'):.4f}" if isinstance(metrics.get('mrr@5'), float) else str(metrics.get('mrr@5', 'N/A'))
                     print(f"{name:<25} | {recall_str:<15} | {mrr_str:<15}")
             print("-" * 60)

             print("\n下一步：實現答案生成。")

        else:
            print("未能加載驗證數據進行評估。")

    else:
        print("\n未能成功準備好用於 DPR 評估的組件。請檢查之前的錯誤信息。")



使用的設備: cuda
開始加載 DPR 模型...
  加載問題分詞器: facebook/dpr-question_encoder-single-nq-base


tokenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/493 [00:00<?, ?B/s]

  加載問題編碼器: facebook/dpr-question_encoder-single-nq-base


pytorch_model.bin:   0%|          | 0.00/438M [00:00<?, ?B/s]

Some weights of the model checkpoint at facebook/dpr-question_encoder-single-nq-base were not used when initializing DPRQuestionEncoder: ['question_encoder.bert_model.pooler.dense.bias', 'question_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRQuestionEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRQuestionEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


  加載上下文分詞器: facebook/dpr-ctx_encoder-single-nq-base


tokenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/492 [00:00<?, ?B/s]

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'DPRQuestionEncoderTokenizer'. 
The class this function is called from is 'DPRContextEncoderTokenizer'.


  加載上下文編碼器: facebook/dpr-ctx_encoder-single-nq-base


pytorch_model.bin:   0%|          | 0.00/438M [00:00<?, ?B/s]

Process Process-auto_conversion:
Traceback (most recent call last):
  File "/usr/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/usr/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/safetensors_conversion.py", line 93, in auto_conversion
    sharded = api.file_exists(
  File "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_validators.py", line 114, in _inner_fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/huggingface_hub/hf_api.py", line 2885, in file_exists
    get_hf_file_metadata(url, token=token)
  File "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_validators.py", line 114, in _inner_fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/huggingface_hub/file_download.py", line 1296, in get_hf_file_metadata
    r = _request_wrappe

KeyboardInterrupt: 

In [6]:
import torch
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer
import faiss # 引入 faiss
import numpy as np
import time
import os
import json
from tqdm import tqdm
import math
import re # 用於簡單清理
import nltk # 用於分詞 (如果需要更精確的詞數分割)

# --- 配置 ---
QUESTION_ENCODER_NAME = 'facebook/dpr-question_encoder-single-nq-base'
CONTEXT_ENCODER_NAME = 'facebook/dpr-ctx_encoder-single-nq-base'
DATA_DIR = '/kaggle/input/nq10k-comp5423/data_and_code/data'
DOCUMENTS_FILE = os.path.join(DATA_DIR, 'documents.jsonl')
VAL_FILE = os.path.join(DATA_DIR, 'val.jsonl') # 用於評估

# --- 修改：段落索引和映射的文件路徑 ---
PASSAGE_FAISS_INDEX_PATH = "dpr_passage_faiss_index.idx"
PASSAGE_MAPPING_PATH = "dpr_passage_mapping.json" # 存儲索引位置 -> (原始文檔ID, 段落文本/ID)

# 編碼時的批次大小
BATCH_SIZE = 32
# DPR 模型最大序列長度
MAX_LENGTH = 512
# --- 新增：文檔分塊參數 ---
CHUNK_SIZE = 100 # 每個段落的目標詞數
CHUNK_OVERLAP = 20 # 段落間重疊的詞數

# --- 檢查是否有可用的 GPU ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用的設備: {device}")

# --- NLTK數據檢查 (用於分詞) ---
try:
    nltk.data.find('tokenizers/punkt')
except nltk.downloader.DownloadError:
    print("下載 NLTK 'punkt' 數據...")
    nltk.download('punkt', quiet=True)

# --- 加載 DPR 模型 (與之前類似) ---
def load_dpr_models(q_encoder_name, ctx_encoder_name, device):
    """加載 DPR 模型。"""
    print(f"開始加載 DPR 模型...")
    start_time = time.time()
    q_tokenizer, q_encoder, ctx_tokenizer, ctx_encoder = None, None, None, None
    try:
        print(f"  加載問題分詞器: {q_encoder_name}")
        q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(q_encoder_name)
        print(f"  加載問題編碼器: {q_encoder_name}")
        q_encoder = DPRQuestionEncoder.from_pretrained(q_encoder_name)
        q_encoder.to(device).eval()
        try:
            print(f"  加載上下文分詞器: {ctx_encoder_name}")
            ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained(ctx_encoder_name)
            print(f"  加載上下文編碼器: {ctx_encoder_name}")
            ctx_encoder = DPRContextEncoder.from_pretrained(ctx_encoder_name)
            ctx_encoder.to(device).eval()
        except Exception as e_ctx:
             print(f"警告：加載上下文模型時出錯: {e_ctx}")
        end_time = time.time()
        print(f"DPR 模型加載完成（或部分完成）。耗時: {end_time - start_time:.2f} 秒")
        return q_tokenizer, q_encoder, ctx_tokenizer, ctx_encoder
    except Exception as e_q:
        print(f"加載 DPR 問題模型時出錯，無法繼續: {e_q}")
        return None, None, None, None

# --- 加載數據 (與之前相同) ---
def load_jsonl(file_path):
    """加載 JSON Lines 文件。"""
    data = []
    if not os.path.exists(file_path):
        print(f"警告: 文件未找到 {file_path}")
        return data
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                 if line.strip():
                    try:
                        data.append(json.loads(line))
                    except json.JSONDecodeError as e:
                        print(f"警告: 跳過無法解析的行: {line.strip()} - 錯誤: {e}")
        print(f"成功加載 {len(data)} 條記錄從 {file_path}")
    except Exception as e:
        print(f"加載文件時出錯 {file_path}: {e}")
    return data

# --- 新增：文檔分塊函數 ---
def split_document(doc_text, chunk_size, overlap):
    """
    將文檔文本按詞數分割成重疊的段落。
    使用 nltk.word_tokenize 進行分詞。
    """
    if not isinstance(doc_text, str) or not doc_text.strip():
        return []
    # 簡單清理：移除多餘空白
    text = re.sub(r'\s+', ' ', doc_text).strip()
    # 使用 nltk 分詞
    try:
        tokens = nltk.word_tokenize(text)
    except Exception as e:
        print(f"分詞錯誤: {e}, 使用簡單空格分割替代。")
        tokens = text.split() # 回退機制

    if not tokens:
        return []

    passages = []
    stride = chunk_size - overlap # 每次滑動的步長
    if stride <= 0: # 確保步長為正
        stride = max(1, chunk_size // 2)
        print(f"警告：重疊過大，調整步長為 {stride}")

    for i in range(0, len(tokens), stride):
        chunk = tokens[i : i + chunk_size]
        if not chunk: continue
        # 將 token 列表重新組合成字符串
        passages.append(" ".join(chunk))
        # 如果最後一個 chunk 已包含所有剩餘 token，則停止
        if i + chunk_size >= len(tokens):
            break
    return passages


# --- 修改：段落編碼函數 ---
def encode_passages(documents, tokenizer, encoder, device, batch_size=32, max_length=512, chunk_size=100, overlap=20):
    """
    對文檔進行分塊，並使用 DPR 上下文編碼器對段落進行編碼。

    Args:
        documents (list[dict]): 原始文檔數據列表。
        tokenizer, encoder, device, batch_size, max_length: 同之前。
        chunk_size (int): 段落目標詞數。
        overlap (int): 段落重疊詞數。

    Returns:
        tuple: (numpy.ndarray of passage vectors, list of passage mappings)
               passage_mappings: [(原始文檔ID, 段落文本), ...]
               如果出錯則返回 (None, None)。
    """
    all_passages_info = [] # 存儲 (原始文檔ID, 段落文本)
    print(f"\n開始對 {len(documents)} 個文檔進行分塊 (Chunk Size: {chunk_size}, Overlap: {overlap})...")
    for doc in tqdm(documents, desc="Splitting Documents"):
        doc_id = doc.get('document_id')
        doc_text = doc.get('document_text', '')
        if doc_id is None or not doc_text:
            continue
        passages = split_document(doc_text, chunk_size, overlap)
        for passage_text in passages:
            all_passages_info.append((doc_id, passage_text)) # 記錄來源文檔ID和段落文本

    if not all_passages_info:
        print("錯誤：未能從文檔中生成任何段落。")
        return None, None

    total_passages = len(all_passages_info)
    print(f"\n共生成 {total_passages} 個段落。開始對段落進行編碼 (Batch Size: {batch_size})...")

    passage_vectors = []
    passage_mapping_for_index = [] # 存儲與向量順序一致的 (原始文檔ID, 段落文本)

    with tqdm(total=total_passages, desc="Encoding Passages") as pbar:
        for i in range(0, total_passages, batch_size):
            batch_info = all_passages_info[i : i + batch_size]
            batch_texts = [info[1] for info in batch_info] # 提取段落文本

            try:
                inputs = tokenizer(batch_texts, max_length=max_length, padding='longest', truncation=True, return_tensors='pt')
                inputs = {key: val.to(device) for key, val in inputs.items()}
                with torch.no_grad():
                    outputs = encoder(**inputs)
                    batch_vectors = outputs.pooler_output
                norms = torch.linalg.norm(batch_vectors, dim=1, keepdim=True)
                normalized_vectors = batch_vectors / (norms + 1e-8)
                passage_vectors.append(normalized_vectors.cpu().numpy())
                # 將當前批次的信息加入映射列表
                passage_mapping_for_index.extend(batch_info)

            except Exception as e:
                print(f"\n處理段落批次 {i // batch_size} 時出錯: {e}")
            pbar.update(len(batch_info))

    if not passage_vectors:
        print("錯誤：未能生成任何段落向量。")
        return None, None

    all_passage_vectors = np.concatenate(passage_vectors, axis=0)
    print(f"\n段落編碼完成。生成向量矩陣形狀: {all_passage_vectors.shape}")
    # 確保向量數量和映射信息數量一致
    if all_passage_vectors.shape[0] != len(passage_mapping_for_index):
        print(f"警告：向量數量 ({all_passage_vectors.shape[0]}) 與映射信息數量 ({len(passage_mapping_for_index)}) 不匹配！")
        # 可以選擇截斷或報錯
        min_len = min(all_passage_vectors.shape[0], len(passage_mapping_for_index))
        all_passage_vectors = all_passage_vectors[:min_len]
        passage_mapping_for_index = passage_mapping_for_index[:min_len]
        print(f"已截斷至最小長度: {min_len}")

    return all_passage_vectors.astype('float32'), passage_mapping_for_index

# --- 修改：FAISS 索引構建函數 (名稱變化) ---
def build_passage_faiss_index(passage_vectors, index_path):
    """使用段落向量構建 FAISS 索引並保存。"""
    if passage_vectors is None or passage_vectors.shape[0] == 0: return None
    vector_dim = passage_vectors.shape[1]
    print(f"\n開始構建段落 FAISS 索引 (向量維度: {vector_dim})...")
    start_time = time.time()
    index = faiss.IndexFlatIP(vector_dim)
    print(f"正在添加 {passage_vectors.shape[0]} 個段落向量到索引...")
    index.add(passage_vectors)
    end_time = time.time()
    print(f"段落 FAISS 索引構建完成。索引中段落數: {index.ntotal}. 耗時: {end_time - start_time:.2f} 秒")
    try:
        print(f"正在將索引保存到: {index_path}")
        faiss.write_index(index, index_path)
        print("索引保存成功。")
    except Exception as e: print(f"保存 FAISS 索引時出錯: {e}")
    return index

# --- 修改：DPR 段落檢索器類 ---
class DprPassageFaissRetriever:
    """
    封裝 DPR (段落級) + FAISS 檢索邏輯的類。
    """
    def __init__(self, q_tokenizer, q_encoder, faiss_index, passage_mapping, device):
        """
        Args:
            passage_mapping (list): 列表，每個元素是 (原始文檔ID, 段落文本)，順序與 FAISS 索引一致。
        """
        self.q_tokenizer = q_tokenizer
        self.q_encoder = q_encoder
        self.faiss_index = faiss_index
        self.passage_mapping = passage_mapping # 存儲 (原始文檔ID, 段落文本) 列表
        self.device = device
        self.vector_dim = q_encoder.config.hidden_size if q_encoder else None

    def retrieve_passages(self, query, top_k):
        """檢索 top_k 個最相關的段落及其得分和來源文檔ID。"""
        if not all([query, self.faiss_index, self.q_encoder, self.q_tokenizer, self.passage_mapping, self.vector_dim]):
            return [], []

        try:
            inputs = self.q_tokenizer(query, max_length=MAX_LENGTH, padding=False, truncation=True, return_tensors='pt')
            inputs = {key: val.to(self.device) for key, val in inputs.items()}
            with torch.no_grad():
                outputs = self.q_encoder(**inputs)
                query_vector = outputs.pooler_output
            norm = torch.linalg.norm(query_vector)
            normalized_query_vector = query_vector / (norm + 1e-8)
            query_np = normalized_query_vector.cpu().numpy().astype('float32').reshape(1, -1)

            # 搜索 FAISS 獲取段落索引和得分（內積）
            scores, indices = self.faiss_index.search(query_np, top_k)

            results = []
            valid_indices = indices[0]
            valid_scores = scores[0]

            for i, idx in enumerate(valid_indices):
                if 0 <= idx < len(self.passage_mapping):
                    original_doc_id, passage_text = self.passage_mapping[idx]
                    results.append({
                        "doc_id": original_doc_id,
                        "passage_text": passage_text, # 可以包含段落文本用於後續處理
                        "score": float(valid_scores[i]),
                        "passage_index_in_faiss": int(idx) # 記錄在faiss中的原始索引
                    })
            return results

        except Exception as e:
            print(f"檢索段落過程中出錯: {e}")
            return []


    def retrieve(self, query, top_n=5):
        """
        執行檢索，返回 top_n 個最相關的 *文檔* ID。
        策略：檢索更多段落，按文檔聚合，取每個文檔最高得分，再排序。
        """
        # 檢索更多的段落，例如 top 20 或 50，以增加覆蓋到 top 5 文檔的可能性
        top_k_passages = 20
        passage_results = self.retrieve_passages(query, top_k=top_k_passages)

        if not passage_results:
            return []

        # 按文檔 ID 聚合，記錄每個文檔的最高得分
        doc_scores = {}
        for p in passage_results:
            doc_id = p["doc_id"]
            score = p["score"]
            if doc_id not in doc_scores or score > doc_scores[doc_id]:
                doc_scores[doc_id] = score

        # 按最高得分對文檔 ID 進行排序
        sorted_docs = sorted(doc_scores.items(), key=lambda item: item[1], reverse=True)

        # 返回 top_n 個文檔 ID
        top_n_doc_ids = [doc_id for doc_id, score in sorted_docs[:top_n]]

        return top_n_doc_ids

# --- 評估函數 (與之前相同) ---
def evaluate_retriever(retriever, validation_data, top_n=5):
    """評估檢索器。"""
    recall_sum = 0
    mrr_sum = 0
    total = len(validation_data)
    if total == 0: return {f'recall@{top_n}': 0, f'mrr@{top_n}': 0}
    print(f"\n開始評估 {type(retriever).__name__} (共 {total} 個問題)...")
    for item in tqdm(validation_data, desc=f"Evaluating {type(retriever).__name__}"):
        question = item['question']
        true_doc_id = item['document_id']
        retrieved_ids = retriever.retrieve(question, top_n=top_n) # 這裡返回的是文檔 ID
        if true_doc_id in retrieved_ids:
            recall_sum += 1
            try:
                rank = retrieved_ids.index(true_doc_id) + 1
                mrr_sum += 1.0 / rank
            except ValueError: pass
    recall_at_n = recall_sum / total
    mrr_at_n = mrr_sum / total
    return {f'recall@{top_n}': recall_at_n, f'mrr@{top_n}': mrr_at_n}

# --- 主程序入口 ---
if __name__ == "__main__":
    # 1. 加載 DPR 模型
    question_tokenizer, question_encoder, context_tokenizer, context_encoder = load_dpr_models(
        QUESTION_ENCODER_NAME,
        CONTEXT_ENCODER_NAME,
        device
    )

    passage_faiss_index = None
    passage_mapping = [] # 存儲 (doc_id, passage_text)

    # 2. 嘗試加載已存在的段落索引和映射
    if os.path.exists(PASSAGE_FAISS_INDEX_PATH) and os.path.exists(PASSAGE_MAPPING_PATH):
        print(f"\n檢測到已存在的段落 FAISS 索引和映射文件。")
        try:
            print("正在加載段落 FAISS 索引...")
            passage_faiss_index = faiss.read_index(PASSAGE_FAISS_INDEX_PATH)
            print(f"索引加載成功。包含 {passage_faiss_index.ntotal} 個段落向量。")
            print("正在加載段落映射...")
            with open(PASSAGE_MAPPING_PATH, 'r') as f:
                passage_mapping = json.load(f) # 加載 [(doc_id, passage_text), ...]
            print(f"段落映射加載成功。包含 {len(passage_mapping)} 個條目。")
            if passage_faiss_index.ntotal != len(passage_mapping):
                print("警告：索引向量數與映射數量不匹配！")
                passage_faiss_index = None; passage_mapping = []
        except Exception as e:
            print(f"加載已存文件時出錯: {e}。")
            passage_faiss_index = None; passage_mapping = []

    # 3. 如果索引加載失敗或不存在，則嘗試構建
    if passage_faiss_index is None:
        print("\n需要構建段落 FAISS 索引...")
        if context_encoder and context_tokenizer:
             print("\n--- 加載原始文檔數據 ---")
             documents_data = load_jsonl(DOCUMENTS_FILE)
             if documents_data:
                 # 執行段落編碼
                 passage_vectors, passage_mapping_build = encode_passages(
                     documents_data, context_tokenizer, context_encoder, device,
                     BATCH_SIZE, MAX_LENGTH, CHUNK_SIZE, CHUNK_OVERLAP)

                 if passage_vectors is not None and passage_mapping_build:
                     # 構建段落 FAISS 索引
                     passage_faiss_index = build_passage_faiss_index(passage_vectors, PASSAGE_FAISS_INDEX_PATH)
                     if passage_faiss_index is not None:
                         passage_mapping = passage_mapping_build # 使用新生成的映射
                         try:
                             print(f"正在將新生成的段落映射保存到: {PASSAGE_MAPPING_PATH}")
                             with open(PASSAGE_MAPPING_PATH, 'w') as f: json.dump(passage_mapping, f)
                             print("段落映射保存成功。")
                         except Exception as e: print(f"保存段落映射時出錯: {e}")
                 else: print("段落編碼失敗，無法構建索引。")
             else: print("無法加載文檔數據，無法構建索引。")
        else:
             print("錯誤：上下文編碼器未成功加載，無法構建新的 FAISS 索引。")


    # 4. 檢查是否可以進行評估
    if passage_faiss_index is not None and question_encoder is not None and question_tokenizer is not None and passage_mapping:
        print("\n所有 DPR (段落級) 檢索組件準備就緒！")

        # 5. 加載驗證數據
        print("\n--- 加載驗證數據 ---")
        val_data = load_jsonl(VAL_FILE)

        if val_data:
             # 6. 創建 DPR 段落檢索器實例
             dpr_passage_retriever = DprPassageFaissRetriever(
                 question_tokenizer,
                 question_encoder,
                 passage_faiss_index,
                 passage_mapping, # 傳遞段落映射
                 device
             )

             # 7. 執行評估
             results = {}
             results['DPR Passage'] = evaluate_retriever(dpr_passage_retriever, val_data, top_n=5)

             # 8. 打印結果表格
             print("\n\n--- DPR (段落級) 檢索器性能評估結果 (驗證集) ---")
             print("-" * 60)
             print(f"{'Retriever':<25} | {'Recall@5':<15} | {'MRR@5':<15}")
             print("-" * 60)
             order = ['DPR Passage']
             for name in order:
                 if name in results:
                     metrics = results[name]
                     recall_str = f"{metrics.get('recall@5', 'N/A'):.4f}" if isinstance(metrics.get('recall@5'), float) else str(metrics.get('recall@5', 'N/A'))
                     mrr_str = f"{metrics.get('mrr@5', 'N/A'):.4f}" if isinstance(metrics.get('mrr@5'), float) else str(metrics.get('mrr@5', 'N/A'))
                     print(f"{name:<25} | {recall_str:<15} | {mrr_str:<15}")
             print("-" * 60)

             print("\n下一步：實現答案生成。")

        else:
            print("未能加載驗證數據進行評估。")

    else:
        print("\n未能成功準備好用於 DPR (段落級) 評估的組件。請檢查之前的錯誤信息。")



使用的設備: cuda
開始加載 DPR 模型...
  加載問題分詞器: facebook/dpr-question_encoder-single-nq-base
  加載問題編碼器: facebook/dpr-question_encoder-single-nq-base


Some weights of the model checkpoint at facebook/dpr-question_encoder-single-nq-base were not used when initializing DPRQuestionEncoder: ['question_encoder.bert_model.pooler.dense.bias', 'question_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRQuestionEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRQuestionEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'DPRQuestionEncoderTokenizer'. 
The class this function is called from is 'DPRCon

  加載上下文分詞器: facebook/dpr-ctx_encoder-single-nq-base
  加載上下文編碼器: facebook/dpr-ctx_encoder-single-nq-base


pytorch_model.bin:  55%|#####5    | 241M/438M [00:00<?, ?B/s]

Some weights of the model checkpoint at facebook/dpr-ctx_encoder-single-nq-base were not used when initializing DPRContextEncoder: ['ctx_encoder.bert_model.pooler.dense.bias', 'ctx_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRContextEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRContextEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


DPR 模型加載完成（或部分完成）。耗時: 4.66 秒

需要構建段落 FAISS 索引...

--- 加載原始文檔數據 ---
成功加載 12138 條記錄從 /kaggle/input/nq10k-comp5423/data_and_code/data/documents.jsonl

開始對 12138 個文檔進行分塊 (Chunk Size: 100, Overlap: 20)...


Splitting Documents: 100%|██████████| 12138/12138 [13:02<00:00, 15.51it/s]



共生成 1932676 個段落。開始對段落進行編碼 (Batch Size: 32)...


Encoding Passages: 100%|██████████| 1932676/1932676 [3:11:10<00:00, 168.50it/s]  



段落編碼完成。生成向量矩陣形狀: (1932676, 768)

開始構建段落 FAISS 索引 (向量維度: 768)...
正在添加 1932676 個段落向量到索引...
段落 FAISS 索引構建完成。索引中段落數: 1932676. 耗時: 5.63 秒
正在將索引保存到: dpr_passage_faiss_index.idx
索引保存成功。
正在將新生成的段落映射保存到: dpr_passage_mapping.json
段落映射保存成功。

所有 DPR (段落級) 檢索組件準備就緒！

--- 加載驗證數據 ---
成功加載 1000 條記錄從 /kaggle/input/nq10k-comp5423/data_and_code/data/val.jsonl

開始評估 DprPassageFaissRetriever (共 1000 個問題)...


Evaluating DprPassageFaissRetriever: 100%|██████████| 1000/1000 [34:04<00:00,  2.04s/it]



--- DPR (段落級) 檢索器性能評估結果 (驗證集) ---
------------------------------------------------------------
Retriever                 | Recall@5        | MRR@5          
------------------------------------------------------------
DPR Passage               | 0.8050          | 0.6027         
------------------------------------------------------------

下一步：實現答案生成。





In [3]:
!pip install openai



In [5]:
import torch
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer
import faiss
import numpy as np
import time
import os
import json
from tqdm import tqdm
import re
import nltk
from openai import OpenAI # *** 新增導入 ***
import sys # 用於退出

# --- 配置 ---
QUESTION_ENCODER_NAME = 'facebook/dpr-question_encoder-single-nq-base'
CONTEXT_ENCODER_NAME = 'facebook/dpr-ctx_encoder-single-nq-base' # 僅在需要重建索引時加載
# --- 修改：確保 DATA_DIR 路徑正確 ---
# DATA_DIR = 'data' # 如果在本地運行
DATA_DIR = '/kaggle/input/nq10k-comp5423/data_and_code/data' # 根據你的環境調整
VAL_FILE = os.path.join(DATA_DIR, 'val.jsonl') # 使用驗證集生成預測
TEST_FILE = os.path.join(DATA_DIR, 'test.jsonl') # 最終在測試集上運行
PASSAGE_FAISS_INDEX_PATH = "dpr_passage_faiss_index.idx"
PASSAGE_MAPPING_PATH = "dpr_passage_mapping.json"
# --- 輸出文件配置 ---
PREDICTION_TARGET = 'val' # 或 'test'
OUTPUT_PREDICTION_FILE = f"{PREDICTION_TARGET}_predict_dpr_passage.jsonl"

# --- 答案生成相關配置 ---
NUM_CONTEXT_PASSAGES = 3
# DPR 模型通常的最大序列長度
MAX_LENGTH = 512 # 定義在這裡
MAX_CONTEXT_LENGTH = 2000 # 根據 Qwen 模型限制調整
LLM_MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct" # 使用項目指定的模型
LLM_MAX_TOKENS = 200 # LLM 生成答案的最大 token 數 (之前是 150，根據需要調整)
LLM_TEMPERATURE = 0.7 # LLM 生成的溫度參數

# --- SiliconFlow API 配置 (需要你填充！) ---
# *** 重要：請將 YOUR_API_KEY 替換為你在 SiliconFlow 創建的真實 API 密鑰 ***
SILICONFLOW_API_KEY = "sk-xvrbrpercvlfxkfsxsfaidnwpvfjdwouqrxsauhxbdjnkmhh" # 已填寫
SILICONFLOW_BASE_URL = "https://api.siliconflow.cn/v1"

# --- 設備檢測 ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用的設備: {device}")

# --- 初始化 SiliconFlow 客戶端 ---
if SILICONFLOW_API_KEY == "YOUR_API_KEY":
    print("警告：請在代碼中設置你的 SiliconFlow API 密鑰 (SILICONFLOW_API_KEY)。")
    llm_client = None
else:
    try:
        llm_client = OpenAI(api_key=SILICONFLOW_API_KEY, base_url=SILICONFLOW_BASE_URL)
        print("SiliconFlow OpenAI 客戶端初始化成功。")
    except Exception as e:
        print(f"初始化 SiliconFlow OpenAI 客戶端時出錯: {e}")
        llm_client = None

# --- 模型加載、數據加載、檢索器類定義 ---
def load_dpr_models(q_encoder_name, ctx_encoder_name, device):
    # ... (加載邏輯，與之前相同) ...
    print(f"開始加載 DPR 模型...")
    start_time = time.time()
    q_tokenizer, q_encoder, ctx_tokenizer, ctx_encoder = None, None, None, None
    try:
        print(f"  加載問題分詞器: {q_encoder_name}")
        q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(q_encoder_name)
        print(f"  加載問題編碼器: {q_encoder_name}")
        q_encoder = DPRQuestionEncoder.from_pretrained(q_encoder_name)
        q_encoder.to(device).eval()
        try:
            print(f"  嘗試加載上下文分詞器: {ctx_encoder_name}")
            ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained(ctx_encoder_name)
            print(f"  嘗試加載上下文編碼器: {ctx_encoder_name}")
            ctx_encoder = DPRContextEncoder.from_pretrained(ctx_encoder_name)
            ctx_encoder.to(device).eval()
        except Exception as e_ctx:
             print(f"警告：加載上下文模型時出錯: {e_ctx}")
        end_time = time.time()
        print(f"DPR 模型加載完成（或部分完成）。耗時: {end_time - start_time:.2f} 秒")
        return q_tokenizer, q_encoder, ctx_tokenizer, ctx_encoder
    except Exception as e_q:
        print(f"加載 DPR 問題模型時出錯，無法繼續: {e_q}")
        return None, None, None, None


def load_jsonl(file_path):
    # ... (加載邏輯，與之前相同) ...
    data = []
    if not os.path.exists(file_path): return data
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            if line.strip(): data.append(json.loads(line))
    print(f"成功加載 {len(data)} 條記錄從 {file_path}")
    return data

# --- 修改：DprPassageFaissRetriever 類 ---
class DprPassageFaissRetriever:
    def __init__(self, q_tokenizer, q_encoder, faiss_index, passage_mapping, device):
        self.q_tokenizer = q_tokenizer
        self.q_encoder = q_encoder
        self.faiss_index = faiss_index
        self.passage_mapping = passage_mapping
        self.device = device
        self.vector_dim = q_encoder.config.hidden_size if q_encoder else None
        # *** 新增：將 MAX_LENGTH 存儲為實例屬性 ***
        self.max_length = MAX_LENGTH
        print(f"DPR Passage Retriever 初始化完成。索引段落數: {self.faiss_index.ntotal if self.faiss_index else 'N/A'}")

    def retrieve_passages(self, query, top_k):
        if not all([query, self.faiss_index, self.q_encoder, self.q_tokenizer, self.passage_mapping, self.vector_dim]): return []
        try:
            # *** 修改：使用 self.max_length ***
            inputs = self.q_tokenizer(query, max_length=self.max_length, padding=False, truncation=True, return_tensors='pt')
            inputs = {key: val.to(self.device) for key, val in inputs.items()}
            with torch.no_grad():
                outputs = self.q_encoder(**inputs)
                query_vector = outputs.pooler_output
            norm = torch.linalg.norm(query_vector)
            normalized_query_vector = query_vector / (norm + 1e-8)
            query_np = normalized_query_vector.cpu().numpy().astype('float32').reshape(1, -1)
            scores, indices = self.faiss_index.search(query_np, top_k)
            results = []
            valid_indices = indices[0]
            valid_scores = scores[0]
            for i, idx in enumerate(valid_indices):
                if 0 <= idx < len(self.passage_mapping):
                    original_doc_id, passage_text = self.passage_mapping[idx]
                    results.append({
                        "doc_id": original_doc_id,
                        "passage_text": passage_text,
                        "score": float(valid_scores[i]),
                        "passage_index_in_faiss": int(idx)
                    })
            return results
        except Exception as e: print(f"檢索段落過程中出錯: {e}"); return []

    def retrieve(self, query, top_n=5):
        top_k_passages = 20
        passage_results = self.retrieve_passages(query, top_k=top_k_passages)
        if not passage_results: return [-1]*top_n
        doc_scores = {}
        for p in passage_results:
            doc_id = p["doc_id"]
            score = p["score"]
            if doc_id not in doc_scores or score > doc_scores[doc_id]: doc_scores[doc_id] = score
        sorted_docs = sorted(doc_scores.items(), key=lambda item: item[1], reverse=True)
        top_n_doc_ids = [doc_id for doc_id, score in sorted_docs[:top_n]]
        while len(top_n_doc_ids) < top_n:
             top_n_doc_ids.append(-1)
        return top_n_doc_ids[:top_n]


# --- generate_answer_with_qwen 函數 (與之前相同) ---
def generate_answer_with_qwen(prompt):
    """
    使用 SiliconFlow API (通過 OpenAI 兼容接口) 調用 Qwen LLM 生成答案。
    """
    global llm_client
    if llm_client is None:
        print("錯誤：SiliconFlow 客戶端未初始化（可能缺少 API 密鑰）。")
        return "LLM 客戶端未初始化"
    try:
        response = llm_client.chat.completions.create(
            model=LLM_MODEL_NAME,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=LLM_MAX_TOKENS,
            temperature=LLM_TEMPERATURE,
            stream=False
        )
        generated_answer = response.choices[0].message.content.strip()
        return generated_answer
    except Exception as e:
        print(f"調用 SiliconFlow API 時發生錯誤: {e}")
        return f"調用 LLM 時出錯: {e}"

# --- build_prompt 函數 (與之前相同) ---
def build_prompt(question, context_passages):
    """
    根據問題和檢索到的上下文段落構建 Prompt。
    """
    if not context_passages:
        context_str = "沒有找到相關上下文。"
    else:
        context_str = "\n\n".join(context_passages)
        if len(context_str) > MAX_CONTEXT_LENGTH:
            context_str = context_str[:MAX_CONTEXT_LENGTH] + "..."
    prompt = f"""請根據以下提供的上下文信息來回答問題。如果上下文沒有提供足夠的信息，請回答 "信息不足"。

上下文:
{context_str}

問題: {question}

答案:"""
    return prompt

# --- 主程序：生成預測文件 (與之前相同) ---
if __name__ == "__main__":
    # 1. 加載 DPR 模型
    question_tokenizer, question_encoder, _, _ = load_dpr_models(
        QUESTION_ENCODER_NAME,
        CONTEXT_ENCODER_NAME,
        device
    )

    # 2. 加載 FAISS 索引和段落映射
    passage_faiss_index = None
    passage_mapping = []
    # ... (省略索引和映射加載邏輯) ...
    if os.path.exists(PASSAGE_FAISS_INDEX_PATH) and os.path.exists(PASSAGE_MAPPING_PATH):
        try:
            print(f"\n正在加載段落 FAISS 索引: {PASSAGE_FAISS_INDEX_PATH}")
            passage_faiss_index = faiss.read_index(PASSAGE_FAISS_INDEX_PATH)
            print(f"正在加載段落映射: {PASSAGE_MAPPING_PATH}")
            with open(PASSAGE_MAPPING_PATH, 'r') as f: passage_mapping = json.load(f)
            if passage_faiss_index.ntotal != len(passage_mapping): raise ValueError("索引和映射數量不匹配！")
            print("索引和映射加載成功。")
        except Exception as e: print(f"加載索引或映射失敗: {e}"); passage_faiss_index = None
    else: print("未找到索引或映射文件。")


    # 3. 檢查組件是否齊全
    if all([passage_faiss_index, question_encoder, question_tokenizer, passage_mapping, llm_client]):
        print("\n所有檢索和 LLM 組件準備就緒！")

        # 4. 創建檢索器實例 (現在會存儲 max_length)
        retriever = DprPassageFaissRetriever(
            question_tokenizer,
            question_encoder,
            passage_faiss_index,
            passage_mapping,
            device
        )

        # 5. 加載目標數據集
        target_file = VAL_FILE if PREDICTION_TARGET == 'val' else TEST_FILE
        print(f"\n--- 加載目標數據集: {target_file} ---")
        target_data = load_jsonl(target_file)

        if target_data:
            print(f"\n--- 開始生成預測並寫入文件: {OUTPUT_PREDICTION_FILE} ---")
            # 6. 遍歷數據集，執行檢索、生成答案、寫入文件
            with open(OUTPUT_PREDICTION_FILE, 'w', encoding='utf-8') as outfile:
                for i, item in enumerate(tqdm(target_data, desc=f"Generating {PREDICTION_TARGET} predictions")):
                    question = item['question']
                    try:
                        # a. 檢索上下文段落
                        retrieved_passages_info = retriever.retrieve_passages(question, top_k=NUM_CONTEXT_PASSAGES)
                        context_texts = [p['passage_text'] for p in retrieved_passages_info]
                        # b. 檢索 Top-5 文檔 ID
                        top_5_doc_ids = retriever.retrieve(question, top_n=5)
                        # c. 構建 Prompt
                        prompt = build_prompt(question, context_texts)
                        # d. 調用 LLM 生成答案
                        generated_answer = generate_answer_with_qwen(prompt)
                        # e. 格式化輸出
                        output_record = {"question": question, "answer": generated_answer, "document_id": top_5_doc_ids}
                        # f. 寫入文件
                        outfile.write(json.dumps(output_record, ensure_ascii=False) + '\n')
                    except Exception as e:
                        print(f"\n處理問題 {i} ('{question}') 時發生錯誤: {e}")
                        error_record = {"question": question, "answer": f"處理時發生錯誤: {e}", "document_id": [-1]*5 }
                        outfile.write(json.dumps(error_record, ensure_ascii=False) + '\n')

            print(f"\n預測文件 {OUTPUT_PREDICTION_FILE} 生成完成。")
            print(f"下一步：使用 metrics_calculation.py 評估 {OUTPUT_PREDICTION_FILE} 文件 (如果目標是驗證集)。")
            if PREDICTION_TARGET == 'test':
                 print("請將生成的 test_predict_dpr_passage.jsonl 文件用於最終提交。")

        else:
            print(f"未能加載目標數據集 {target_file}。")

    else:
        # 提供更詳細的失敗原因
        missing = []
        if not passage_faiss_index: missing.append("FAISS 索引")
        if not question_encoder: missing.append("問題編碼器")
        if not question_tokenizer: missing.append("問題分詞器")
        if not passage_mapping: missing.append("段落映射")
        if not llm_client: missing.append("LLM 客戶端 (API Key?)")
        print(f"\n未能成功準備好用於生成預測的組件。缺少或失敗：{', '.join(missing)}。請檢查之前的錯誤信息。")



使用的設備: cuda
SiliconFlow OpenAI 客戶端初始化成功。
開始加載 DPR 模型...
  加載問題分詞器: facebook/dpr-question_encoder-single-nq-base
  加載問題編碼器: facebook/dpr-question_encoder-single-nq-base


Some weights of the model checkpoint at facebook/dpr-question_encoder-single-nq-base were not used when initializing DPRQuestionEncoder: ['question_encoder.bert_model.pooler.dense.bias', 'question_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRQuestionEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRQuestionEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'DPRQuestionEncoderTokenizer'. 
The class this function is called from is 'DPRCon

  嘗試加載上下文分詞器: facebook/dpr-ctx_encoder-single-nq-base
  嘗試加載上下文編碼器: facebook/dpr-ctx_encoder-single-nq-base


Some weights of the model checkpoint at facebook/dpr-ctx_encoder-single-nq-base were not used when initializing DPRContextEncoder: ['ctx_encoder.bert_model.pooler.dense.bias', 'ctx_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRContextEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRContextEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


DPR 模型加載完成（或部分完成）。耗時: 2.12 秒

正在加載段落 FAISS 索引: dpr_passage_faiss_index.idx
正在加載段落映射: dpr_passage_mapping.json
索引和映射加載成功。

所有檢索和 LLM 組件準備就緒！
DPR Passage Retriever 初始化完成。索引段落數: 1932676

--- 加載目標數據集: /kaggle/input/nq10k-comp5423/data_and_code/data/val.jsonl ---
成功加載 1000 條記錄從 /kaggle/input/nq10k-comp5423/data_and_code/data/val.jsonl

--- 開始生成預測並寫入文件: val_predict_dpr_passage.jsonl ---


Generating val predictions: 100%|██████████| 1000/1000 [2:30:07<00:00,  9.01s/it] 


預測文件 val_predict_dpr_passage.jsonl 生成完成。
下一步：使用 metrics_calculation.py 評估 val_predict_dpr_passage.jsonl 文件 (如果目標是驗證集)。





In [12]:
cd /kaggle/input/nq10k-comp5423/data_and_code

/kaggle/input/nq10k-comp5423/data_and_code


In [30]:
!python /kaggle/input/nq10k-comp5423/data_and_code/metrics_calculation.py 

Evaluation Result:
Answer Accuracy:             0.0020
Document Retrieval Recall@5: 0.0010
Document Retrieval MRR@5   : 0.0003


In [22]:
import torch
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer
import faiss
import numpy as np
import time
import os
import json
from tqdm import tqdm
import re
import nltk
from openai import OpenAI # *** 新增導入 ***
import sys # 用於退出

# --- 配置 ---
QUESTION_ENCODER_NAME = 'facebook/dpr-question_encoder-single-nq-base'
CONTEXT_ENCODER_NAME = 'facebook/dpr-ctx_encoder-single-nq-base' # 僅在需要重建索引時加載
# --- 修改：確保 DATA_DIR 路徑正確 ---
# DATA_DIR = 'data' # 如果在本地運行
DATA_DIR = '/kaggle/input/nq10k-comp5423/data_and_code/data' # 根據你的環境調整
VAL_FILE = os.path.join(DATA_DIR, 'val.jsonl') # 使用驗證集生成預測
TEST_FILE = os.path.join(DATA_DIR, 'test.jsonl') # 最終在測試集上運行

# --- 修改：明確指定可寫目錄和文件路徑 ---
WRITABLE_DIR = "/kaggle/working/" # Kaggle 的可寫目錄 (如果是其他環境，請修改)
# 確保目錄存在 (雖然 /kaggle/working/ 通常存在)
os.makedirs(WRITABLE_DIR, exist_ok=True)

PASSAGE_FAISS_INDEX_PATH = os.path.join(WRITABLE_DIR, "dpr_passage_faiss_index.idx")
PASSAGE_MAPPING_PATH = os.path.join(WRITABLE_DIR, "dpr_passage_mapping.json")
# --- 輸出文件配置 ---
PREDICTION_TARGET = 'val' # 或 'test'
# *** 修改：使用 os.path.join 指定完整輸出路徑 ***
OUTPUT_PREDICTION_FILE = os.path.join(WRITABLE_DIR, f"{PREDICTION_TARGET}_predict_dpr_passage_debug.jsonl") # 加個後綴區分

# --- 答案生成相關配置 ---
NUM_CONTEXT_PASSAGES = 3
MAX_LENGTH = 512
MAX_CONTEXT_LENGTH = 2000
LLM_MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"
LLM_MAX_TOKENS = 200
LLM_TEMPERATURE = 0.7

# --- SiliconFlow API 配置 ---
SILICONFLOW_API_KEY = "sk-xvrbrpercvlfxkfsxsfaidnwpvfjdwouqrxsauhxbdjnkmhh" # 已填寫
SILICONFLOW_BASE_URL = "https://api.siliconflow.cn/v1"

# --- 設備檢測 ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用的設備: {device}")

# --- 初始化 SiliconFlow 客戶端 ---
if SILICONFLOW_API_KEY == "YOUR_API_KEY" or not SILICONFLOW_API_KEY: # 檢查是否為空或佔位符
    print("警告：請在代碼中設置你的 SiliconFlow API 密鑰 (SILICONFLOW_API_KEY)。")
    llm_client = None
else:
    try:
        llm_client = OpenAI(api_key=SILICONFLOW_API_KEY, base_url=SILICONFLOW_BASE_URL)
        print("SiliconFlow OpenAI 客戶端初始化成功。")
    except Exception as e:
        print(f"初始化 SiliconFlow OpenAI 客戶端時出錯: {e}")
        llm_client = None

# --- 模型加載、數據加載、檢索器類定義 ---
# (與上一版本相同，確保 DprPassageFaissRetriever 類定義存在)
def load_dpr_models(q_encoder_name, ctx_encoder_name, device):
    # ... (加載邏輯) ...
    print(f"開始加載 DPR 模型...")
    start_time = time.time()
    q_tokenizer, q_encoder, ctx_tokenizer, ctx_encoder = None, None, None, None
    try:
        print(f"  加載問題分詞器: {q_encoder_name}")
        q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(q_encoder_name)
        print(f"  加載問題編碼器: {q_encoder_name}")
        q_encoder = DPRQuestionEncoder.from_pretrained(q_encoder_name)
        q_encoder.to(device).eval()
        try:
            print(f"  嘗試加載上下文分詞器: {ctx_encoder_name}")
            ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained(ctx_encoder_name)
            print(f"  嘗試加載上下文編碼器: {ctx_encoder_name}")
            ctx_encoder = DPRContextEncoder.from_pretrained(ctx_encoder_name)
            ctx_encoder.to(device).eval()
        except Exception as e_ctx:
             print(f"警告：加載上下文模型時出錯: {e_ctx}")
        end_time = time.time()
        print(f"DPR 模型加載完成（或部分完成）。耗時: {end_time - start_time:.2f} 秒")
        return q_tokenizer, q_encoder, ctx_tokenizer, ctx_encoder
    except Exception as e_q:
        print(f"加載 DPR 問題模型時出錯，無法繼續: {e_q}")
        return None, None, None, None

def load_jsonl(file_path):
    # ... (加載邏輯) ...
    data = []
    if not os.path.exists(file_path): return data
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            if line.strip(): data.append(json.loads(line))
    print(f"成功加載 {len(data)} 條記錄從 {file_path}")
    return data

class DprPassageFaissRetriever:
    # ... (與上一版本相同) ...
    def __init__(self, q_tokenizer, q_encoder, faiss_index, passage_mapping, device):
        self.q_tokenizer = q_tokenizer
        self.q_encoder = q_encoder
        self.faiss_index = faiss_index
        self.passage_mapping = passage_mapping
        self.device = device
        self.vector_dim = q_encoder.config.hidden_size if q_encoder else None
        self.max_length = MAX_LENGTH
        print(f"DPR Passage Retriever 初始化完成。索引段落數: {self.faiss_index.ntotal if self.faiss_index else 'N/A'}")

    def retrieve_passages(self, query, top_k):
        if not all([query, self.faiss_index, self.q_encoder, self.q_tokenizer, self.passage_mapping, self.vector_dim]): return []
        try:
            inputs = self.q_tokenizer(query, max_length=self.max_length, padding=False, truncation=True, return_tensors='pt')
            inputs = {key: val.to(self.device) for key, val in inputs.items()}
            with torch.no_grad():
                outputs = self.q_encoder(**inputs)
                query_vector = outputs.pooler_output
            norm = torch.linalg.norm(query_vector)
            normalized_query_vector = query_vector / (norm + 1e-8)
            query_np = normalized_query_vector.cpu().numpy().astype('float32').reshape(1, -1)
            scores, indices = self.faiss_index.search(query_np, top_k)
            results = []
            valid_indices = indices[0]
            valid_scores = scores[0]
            for i, idx in enumerate(valid_indices):
                if 0 <= idx < len(self.passage_mapping):
                    original_doc_id, passage_text = self.passage_mapping[idx]
                    results.append({
                        "doc_id": original_doc_id,
                        "passage_text": passage_text,
                        "score": float(valid_scores[i]),
                        "passage_index_in_faiss": int(idx)
                    })
            return results
        except Exception as e: print(f"檢索段落過程中出錯: {e}"); return []

    def retrieve(self, query, top_n=5):
        top_k_passages = 20
        passage_results = self.retrieve_passages(query, top_k=top_k_passages)
        if not passage_results: return [-1]*top_n
        doc_scores = {}
        for p in passage_results:
            doc_id = p["doc_id"]
            score = p["score"]
            # 確保 doc_id 是有效的整數或可比較類型
            if not isinstance(doc_id, (int, np.integer)):
                 print(f"警告：在聚合時遇到非整數 doc_id: {doc_id} (類型: {type(doc_id)})")
                 continue # 跳過無效的 doc_id
            if doc_id not in doc_scores or score > doc_scores[doc_id]: doc_scores[doc_id] = score
        sorted_docs = sorted(doc_scores.items(), key=lambda item: item[1], reverse=True)
        top_n_doc_ids = [doc_id for doc_id, score in sorted_docs[:top_n]]
        while len(top_n_doc_ids) < top_n:
             top_n_doc_ids.append(-1) # 使用 -1 作為填充/錯誤標識
        # 確保返回列表中的元素都是整數
        final_ids = []
        for doc_id in top_n_doc_ids[:top_n]:
            try:
                final_ids.append(int(doc_id))
            except (ValueError, TypeError):
                print(f"警告：無法將 doc_id '{doc_id}' 轉換為整數，使用 -1 替代。")
                final_ids.append(-1) # 如果轉換失敗，也使用 -1
        return final_ids


# --- generate_answer_with_qwen 函數 (與之前相同) ---
def generate_answer_with_qwen(prompt):
    """
    使用 SiliconFlow API (通過 OpenAI 兼容接口) 調用 Qwen LLM 生成答案。
    """
    global llm_client
    if llm_client is None:
        print("錯誤：SiliconFlow 客戶端未初始化（可能缺少 API 密鑰）。")
        return "LLM 客戶端未初始化"
    try:
        response = llm_client.chat.completions.create(
            model=LLM_MODEL_NAME,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=LLM_MAX_TOKENS,
            temperature=LLM_TEMPERATURE,
            stream=False
        )
        generated_answer = response.choices[0].message.content.strip()
        # 基本的答案清理：移除可能的引號
        if generated_answer.startswith('"') and generated_answer.endswith('"'):
             generated_answer = generated_answer[1:-1]
        if generated_answer.startswith("'") and generated_answer.endswith("'"):
             generated_answer = generated_answer[1:-1]
        # 移除常見的拒絕回答或前綴
        refusal_prefixes = ["信息不足。", "根據提供的上下文信息，", "根據上下文信息，", "答案:"]
        for prefix in refusal_prefixes:
             if generated_answer.startswith(prefix):
                  generated_answer = generated_answer[len(prefix):].strip()

        return generated_answer if generated_answer else "未能生成答案" # 避免返回空字符串

    except Exception as e:
        print(f"調用 SiliconFlow API 時發生錯誤: {type(e).__name__} - {e}")
        return f"調用 LLM 時出錯" # 返回錯誤信息

# --- build_prompt 函數 (與之前相同) ---
def build_prompt(question, context_passages):
    """
    根據問題和檢索到的上下文段落構建 Prompt。
    """
    if not context_passages:
        context_str = "沒有找到相關上下文。"
    else:
        context_str = "\n\n".join(context_passages)
        if len(context_str) > MAX_CONTEXT_LENGTH:
            context_str = context_str[:MAX_CONTEXT_LENGTH] + "..."
    prompt = f"""請根據以下提供的上下文信息來回答問題。如果上下文沒有提供足夠的信息，請回答 "信息不足"。

上下文:
{context_str}

問題: {question}

答案:"""
    return prompt

# --- 主程序：生成預測文件 ---
if __name__ == "__main__":
    # 1. 加載 DPR 模型
    question_tokenizer, question_encoder, _, _ = load_dpr_models(
        QUESTION_ENCODER_NAME,
        CONTEXT_ENCODER_NAME,
        device
    )

    # 2. 加載 FAISS 索引和段落映射
    passage_faiss_index = None
    passage_mapping = []
    # ... (省略索引和映射加載邏輯) ...
    if os.path.exists(PASSAGE_FAISS_INDEX_PATH) and os.path.exists(PASSAGE_MAPPING_PATH):
        try:
            print(f"\n正在加載段落 FAISS 索引: {PASSAGE_FAISS_INDEX_PATH}")
            passage_faiss_index = faiss.read_index(PASSAGE_FAISS_INDEX_PATH)
            print(f"正在加載段落映射: {PASSAGE_MAPPING_PATH}")
            with open(PASSAGE_MAPPING_PATH, 'r') as f: passage_mapping = json.load(f)
            if passage_faiss_index.ntotal != len(passage_mapping): raise ValueError("索引和映射數量不匹配！")
            print("索引和映射加載成功。")
        except Exception as e: print(f"加載索引或映射失敗: {e}"); passage_faiss_index = None
    else: print(f"未找到索引或映射文件 ({PASSAGE_FAISS_INDEX_PATH}, {PASSAGE_MAPPING_PATH})。需要先運行索引構建步驟。")


    # 3. 檢查組件是否齊全
    if all([passage_faiss_index, question_encoder, question_tokenizer, passage_mapping, llm_client]):
        print("\n所有檢索和 LLM 組件準備就緒！")

        # 4. 創建檢索器實例
        retriever = DprPassageFaissRetriever(
            question_tokenizer,
            question_encoder,
            passage_faiss_index,
            passage_mapping,
            device
        )

        # 5. 加載目標數據集
        target_file = VAL_FILE if PREDICTION_TARGET == 'val' else TEST_FILE
        print(f"\n--- 加載目標數據集: {target_file} ---")
        target_data = load_jsonl(target_file)

        if target_data:
            print(f"\n--- 開始生成預測並寫入文件: {OUTPUT_PREDICTION_FILE} ---")
            # 6. 遍歷數據集，執行檢索、生成答案、寫入文件
            with open(OUTPUT_PREDICTION_FILE, 'w', encoding='utf-8') as outfile:
                for i, item in enumerate(tqdm(target_data, desc=f"Generating {PREDICTION_TARGET} predictions")):
                    question = item['question']
                    generated_answer = "處理時發生內部錯誤" # 默認錯誤答案
                    top_5_doc_ids = [-1]*5 # 默認錯誤ID

                    try:
                        # a. 檢索上下文段落
                        retrieved_passages_info = retriever.retrieve_passages(question, top_k=NUM_CONTEXT_PASSAGES)
                        context_texts = [p['passage_text'] for p in retrieved_passages_info]

                        # b. 檢索 Top-5 文檔 ID
                        top_5_doc_ids = retriever.retrieve(question, top_n=5) # 確保返回的是 list[int]

                        # c. 構建 Prompt
                        prompt = build_prompt(question, context_texts)

                        # d. 調用 LLM 生成答案
                        generated_answer = generate_answer_with_qwen(prompt)

                        # e. 格式化輸出
                        output_record = {
                            "question": question,
                            "answer": str(generated_answer), # 確保答案是字符串
                            "document_id": top_5_doc_ids # 確保這是整數列表
                        }

                        # *** 加入詳細的 Debug 打印 ***
                        print(f"\n--- DEBUG: Preparing to write record {i} ---")
                        print(f"  Question Type: {type(question)}, Value: '{question[:100]}...'") # 打印部分問題
                        print(f"  Answer Type: {type(generated_answer)}, Value: '{generated_answer}'")
                        # 檢查 generated_answer 是否包含錯誤信息
                        if isinstance(generated_answer, str) and generated_answer.startswith("調用 LLM 時出錯"):
                            print("  WARNING: Generated answer indicates LLM call error.")
                        print(f"  Doc IDs Type: {type(top_5_doc_ids)}, Value: {top_5_doc_ids}")
                        # 檢查 Doc IDs 列表內容
                        if isinstance(top_5_doc_ids, list):
                            print(f"  Doc IDs List Length: {len(top_5_doc_ids)}")
                            if top_5_doc_ids and len(top_5_doc_ids) > 0 and top_5_doc_ids[0] != -1 : # 檢查第一個有效ID的類型
                                print(f"  First Valid Doc ID Type: {type(top_5_doc_ids[0])}")
                        else:
                             print(f"  WARNING: Doc IDs is not a list!")
                        print(f"  Output Record Dict: {output_record}")
                        print(f"--- End Debug Info ---")
                        # *** Debug 打印結束 ***

                        # f. 寫入文件
                        outfile.write(json.dumps(output_record, ensure_ascii=False) + '\n')

                    except Exception as e:
                        print(f"\n處理問題 {i} ('{question}') 時發生嚴重錯誤: {e}")
                        # 寫入包含錯誤信息的記錄
                        error_record = {
                            "question": question,
                            "answer": f"處理時發生嚴重錯誤: {e}",
                            "document_id": [-1]*5 # 使用錯誤標識符
                        }
                        # 在寫入錯誤記錄前也打印一下，以防萬一
                        print(f"\n--- DEBUG: Writing ERROR record {i} ---")
                        print(f"  Error Record Dict: {error_record}")
                        print(f"--- End Debug Info ---")
                        outfile.write(json.dumps(error_record, ensure_ascii=False) + '\n')


            print(f"\n預測文件 {OUTPUT_PREDICTION_FILE} 生成完成。")
            print(f"下一步：使用 metrics_calculation.py 評估 {OUTPUT_PREDICTION_FILE} 文件 (如果目標是驗證集)。")
            if PREDICTION_TARGET == 'test':
                 print("請將生成的 test_predict_dpr_passage_debug.jsonl 文件用於最終提交。")

        else:
            print(f"未能加載目標數據集 {target_file}。")

    else:
        # 提供更詳細的失敗原因
        missing = []
        if not passage_faiss_index: missing.append("FAISS 索引")
        if not question_encoder: missing.append("問題編碼器")
        if not question_tokenizer: missing.append("問題分詞器")
        if not passage_mapping: missing.append("段落映射")
        if not llm_client: missing.append("LLM 客戶端 (API Key?)")
        print(f"\n未能成功準備好用於生成預測的組件。缺少或失敗：{', '.join(missing)}。請檢查之前的錯誤信息。")



使用的設備: cuda
SiliconFlow OpenAI 客戶端初始化成功。
開始加載 DPR 模型...
  加載問題分詞器: facebook/dpr-question_encoder-single-nq-base
  加載問題編碼器: facebook/dpr-question_encoder-single-nq-base


Some weights of the model checkpoint at facebook/dpr-question_encoder-single-nq-base were not used when initializing DPRQuestionEncoder: ['question_encoder.bert_model.pooler.dense.bias', 'question_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRQuestionEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRQuestionEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'DPRQuestionEncoderTokenizer'. 
The class this function is called from is 'DPRCon

  嘗試加載上下文分詞器: facebook/dpr-ctx_encoder-single-nq-base
  嘗試加載上下文編碼器: facebook/dpr-ctx_encoder-single-nq-base


Some weights of the model checkpoint at facebook/dpr-ctx_encoder-single-nq-base were not used when initializing DPRContextEncoder: ['ctx_encoder.bert_model.pooler.dense.bias', 'ctx_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRContextEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRContextEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


DPR 模型加載完成（或部分完成）。耗時: 2.29 秒

正在加載段落 FAISS 索引: /kaggle/working/dpr_passage_faiss_index.idx
正在加載段落映射: /kaggle/working/dpr_passage_mapping.json
索引和映射加載成功。

所有檢索和 LLM 組件準備就緒！
DPR Passage Retriever 初始化完成。索引段落數: 1932676

--- 加載目標數據集: /kaggle/input/nq10k-comp5423/data_and_code/data/val.jsonl ---
成功加載 1000 條記錄從 /kaggle/input/nq10k-comp5423/data_and_code/data/val.jsonl

--- 開始生成預測並寫入文件: /kaggle/working/val_predict_dpr_passage_debug.jsonl ---


Generating val predictions:   0%|          | 1/1000 [00:09<2:39:23,  9.57s/it]


--- DEBUG: Preparing to write record 0 ---
  Question Type: <class 'str'>, Value: 'when did the british first land in north america...'
  Answer Type: <class 'str'>, Value: '根据提供的上下文信息，英国首先在北美登陆的时间是1607年。具体信息来自文本中的这句话：“The first successful English settlement was established in 1607.” 这里的“settlement”可以理解为登陆并建立定居点。'
  Doc IDs Type: <class 'list'>, Value: [8502, 5655, 4507, 5318, 5222]
  Doc IDs List Length: 5
  First Valid Doc ID Type: <class 'int'>
  Output Record Dict: {'question': 'when did the british first land in north america', 'answer': '根据提供的上下文信息，英国首先在北美登陆的时间是1607年。具体信息来自文本中的这句话：“The first successful English settlement was established in 1607.” 这里的“settlement”可以理解为登陆并建立定居点。', 'document_id': [8502, 5655, 4507, 5318, 5222]}
--- End Debug Info ---


Generating val predictions:   0%|          | 2/1000 [00:17<2:25:45,  8.76s/it]


--- DEBUG: Preparing to write record 1 ---
  Question Type: <class 'str'>, Value: 'when did the 1st world war officially end...'
  Answer Type: <class 'str'>, Value: '根据提供的上下文信息，第一次世界大战于1918年11月11日正式结束。具体答案如下：

答案: 1918年11月11日'
  Doc IDs Type: <class 'list'>, Value: [10722, 2266, 10128, 3325, 2011]
  Doc IDs List Length: 5
  First Valid Doc ID Type: <class 'int'>
  Output Record Dict: {'question': 'when did the 1st world war officially end', 'answer': '根据提供的上下文信息，第一次世界大战于1918年11月11日正式结束。具体答案如下：\n\n答案: 1918年11月11日', 'document_id': [10722, 2266, 10128, 3325, 2011]}
--- End Debug Info ---


Generating val predictions:   0%|          | 3/1000 [00:25<2:17:29,  8.27s/it]


--- DEBUG: Preparing to write record 2 ---
  Question Type: <class 'str'>, Value: 'who's the girl that plays the new wonder woman...'
  Answer Type: <class 'str'>, Value: '根据提供的上下文信息，扮演新神奇女侠的是Gal Gadot。'
  Doc IDs Type: <class 'list'>, Value: [7556, 5970, 8669, 10786, 8552]
  Doc IDs List Length: 5
  First Valid Doc ID Type: <class 'int'>
  Output Record Dict: {'question': "who's the girl that plays the new wonder woman", 'answer': '根据提供的上下文信息，扮演新神奇女侠的是Gal Gadot。', 'document_id': [7556, 5970, 8669, 10786, 8552]}
--- End Debug Info ---


Generating val predictions:   0%|          | 4/1000 [00:34<2:23:49,  8.66s/it]


--- DEBUG: Preparing to write record 3 ---
  Question Type: <class 'str'>, Value: 'who is the director of the cia today...'
  Answer Type: <class 'str'>, Value: 'Context中提到的是美国国家情报总监（Director of National Intelligence）的现任负责人是Christopher A. Wray，但并未提及中央情报局局长（Director of CIA）。根据提供的信息，无法确定中央情报局局长的现任负责人。'
  Doc IDs Type: <class 'list'>, Value: [457, 3556, 4733, 11721, 7816]
  Doc IDs List Length: 5
  First Valid Doc ID Type: <class 'int'>
  Output Record Dict: {'question': 'who is the director of the cia today', 'answer': 'Context中提到的是美国国家情报总监（Director of National Intelligence）的现任负责人是Christopher A. Wray，但并未提及中央情报局局长（Director of CIA）。根据提供的信息，无法确定中央情报局局长的现任负责人。', 'document_id': [457, 3556, 4733, 11721, 7816]}
--- End Debug Info ---


Generating val predictions:   0%|          | 5/1000 [00:42<2:18:00,  8.32s/it]


--- DEBUG: Preparing to write record 4 ---
  Question Type: <class 'str'>, Value: 'who plays ben in the new fantastic four...'
  Answer Type: <class 'str'>, Value: '根据提供的上下文信息，Ben Grimm / The Thing 由 Toby Kebbell 扮演。'
  Doc IDs Type: <class 'list'>, Value: [8635, 11119, 1784, 11160, 9192]
  Doc IDs List Length: 5
  First Valid Doc ID Type: <class 'int'>
  Output Record Dict: {'question': 'who plays ben in the new fantastic four', 'answer': '根据提供的上下文信息，Ben Grimm / The Thing 由 Toby Kebbell 扮演。', 'document_id': [8635, 11119, 1784, 11160, 9192]}
--- End Debug Info ---


Generating val predictions:   1%|          | 6/1000 [00:51<2:20:42,  8.49s/it]


--- DEBUG: Preparing to write record 5 ---
  Question Type: <class 'str'>, Value: 'when is the men's ice hockey winter olympics 2018...'
  Answer Type: <class 'str'>, Value: '根据提供的上下文信息，2018年冬季奥运会男子冰球比赛在2018年2月14日至25日在韩国江陵举行。

答案: 2018年2月14日至25日'
  Doc IDs Type: <class 'list'>, Value: [10903, 8675, 4047, 1725, 6126]
  Doc IDs List Length: 5
  First Valid Doc ID Type: <class 'int'>
  Output Record Dict: {'question': "when is the men's ice hockey winter olympics 2018", 'answer': '根据提供的上下文信息，2018年冬季奥运会男子冰球比赛在2018年2月14日至25日在韩国江陵举行。\n\n答案: 2018年2月14日至25日', 'document_id': [10903, 8675, 4047, 1725, 6126]}
--- End Debug Info ---


Generating val predictions:   1%|          | 7/1000 [00:58<2:15:15,  8.17s/it]


--- DEBUG: Preparing to write record 6 ---
  Question Type: <class 'str'>, Value: 'who plays fiddle on don't pass me by...'
  Answer Type: <class 'str'>, Value: '信息不足'
  Doc IDs Type: <class 'list'>, Value: [5401, 11066, 1479, 8266, 10088]
  Doc IDs List Length: 5
  First Valid Doc ID Type: <class 'int'>
  Output Record Dict: {'question': "who plays fiddle on don't pass me by", 'answer': '信息不足', 'document_id': [5401, 11066, 1479, 8266, 10088]}
--- End Debug Info ---


Generating val predictions:   1%|          | 8/1000 [01:06<2:12:56,  8.04s/it]


--- DEBUG: Preparing to write record 7 ---
  Question Type: <class 'str'>, Value: 'who played drusilla on the young and the restless...'
  Answer Type: <class 'str'>, Value: 'Rowell played Drucilla on The Young and the Restless。'
  Doc IDs Type: <class 'list'>, Value: [11702, 501, 7345, 7087, 5144]
  Doc IDs List Length: 5
  First Valid Doc ID Type: <class 'int'>
  Output Record Dict: {'question': 'who played drusilla on the young and the restless', 'answer': 'Rowell played Drucilla on The Young and the Restless。', 'document_id': [11702, 501, 7345, 7087, 5144]}
--- End Debug Info ---


Generating val predictions:   1%|          | 9/1000 [01:14<2:11:57,  7.99s/it]


--- DEBUG: Preparing to write record 8 ---
  Question Type: <class 'str'>, Value: 'name of the actor who plays captain america...'
  Answer Type: <class 'str'>, Value: '根据提供的上下文信息，扮演美国队长（Captain America）的演员是克里斯·埃文斯（Chris Evans）。'
  Doc IDs Type: <class 'list'>, Value: [12061, 11050, 5426, 8347, 3035]
  Doc IDs List Length: 5
  First Valid Doc ID Type: <class 'int'>
  Output Record Dict: {'question': 'name of the actor who plays captain america', 'answer': '根据提供的上下文信息，扮演美国队长（Captain America）的演员是克里斯·埃文斯（Chris Evans）。', 'document_id': [12061, 11050, 5426, 8347, 3035]}
--- End Debug Info ---


Generating val predictions:   1%|          | 10/1000 [01:21<2:09:46,  7.86s/it]


--- DEBUG: Preparing to write record 9 ---
  Question Type: <class 'str'>, Value: 'where is the national multiple sclerosis society located...'
  Answer Type: <class 'str'>, Value: 'The National Multiple Sclerosis Society (NMSS) is based in New York City.'
  Doc IDs Type: <class 'list'>, Value: [6708, 145, 10932, 7939, 9558]
  Doc IDs List Length: 5
  First Valid Doc ID Type: <class 'int'>
  Output Record Dict: {'question': 'where is the national multiple sclerosis society located', 'answer': 'The National Multiple Sclerosis Society (NMSS) is based in New York City.', 'document_id': [6708, 145, 10932, 7939, 9558]}
--- End Debug Info ---


Generating val predictions:   1%|          | 11/1000 [01:29<2:07:11,  7.72s/it]


--- DEBUG: Preparing to write record 10 ---
  Question Type: <class 'str'>, Value: 'who is financing the curse of oak island...'
  Answer Type: <class 'str'>, Value: '信息不足'
  Doc IDs Type: <class 'list'>, Value: [8103, 6194, 10946, 3885, 7376]
  Doc IDs List Length: 5
  First Valid Doc ID Type: <class 'int'>
  Output Record Dict: {'question': 'who is financing the curse of oak island', 'answer': '信息不足', 'document_id': [8103, 6194, 10946, 3885, 7376]}
--- End Debug Info ---


Generating val predictions:   1%|          | 12/1000 [01:37<2:09:21,  7.86s/it]


--- DEBUG: Preparing to write record 11 ---
  Question Type: <class 'str'>, Value: 'where is the most gold stored in the world...'
  Answer Type: <class 'str'>, Value: '提供的上下文信息没有提及世界上黄金存储的具体地点。'
  Doc IDs Type: <class 'list'>, Value: [7338, 9823, 9187, 3616, 10102]
  Doc IDs List Length: 5
  First Valid Doc ID Type: <class 'int'>
  Output Record Dict: {'question': 'where is the most gold stored in the world', 'answer': '提供的上下文信息没有提及世界上黄金存储的具体地点。', 'document_id': [7338, 9823, 9187, 3616, 10102]}
--- End Debug Info ---


Generating val predictions:   1%|▏         | 13/1000 [01:47<2:21:23,  8.60s/it]


--- DEBUG: Preparing to write record 12 ---
  Question Type: <class 'str'>, Value: 'what age can you get a tattoo in louisiana...'
  Answer Type: <class 'str'>, Value: '根据提供的上下文信息，所有50个州和哥伦比亚特区都有法律规定，接受纹身的人必须年满18岁。因此，在路易斯安那州，你可以年满18岁时进行纹身。答案是18岁。'
  Doc IDs Type: <class 'list'>, Value: [10489, 7176, 1857, 2499, 7120]
  Doc IDs List Length: 5
  First Valid Doc ID Type: <class 'int'>
  Output Record Dict: {'question': 'what age can you get a tattoo in louisiana', 'answer': '根据提供的上下文信息，所有50个州和哥伦比亚特区都有法律规定，接受纹身的人必须年满18岁。因此，在路易斯安那州，你可以年满18岁时进行纹身。答案是18岁。', 'document_id': [10489, 7176, 1857, 2499, 7120]}
--- End Debug Info ---


Generating val predictions:   1%|▏         | 14/1000 [01:56<2:19:32,  8.49s/it]


--- DEBUG: Preparing to write record 13 ---
  Question Type: <class 'str'>, Value: 'who plays the cop in once upon a time...'
  Answer Type: <class 'str'>, Value: 'James Dornan plays the cop in Once Upon a Time. Specifically, he played Sheriff Graham Humbert in the ABC series Once Upon a Time (2011 -- 2013).'
  Doc IDs Type: <class 'list'>, Value: [10373, 5461, 8352, 4888, 5234]
  Doc IDs List Length: 5
  First Valid Doc ID Type: <class 'int'>
  Output Record Dict: {'question': 'who plays the cop in once upon a time', 'answer': 'James Dornan plays the cop in Once Upon a Time. Specifically, he played Sheriff Graham Humbert in the ABC series Once Upon a Time (2011 -- 2013).', 'document_id': [10373, 5461, 8352, 4888, 5234]}
--- End Debug Info ---


Generating val predictions:   2%|▏         | 15/1000 [02:04<2:16:36,  8.32s/it]


--- DEBUG: Preparing to write record 14 ---
  Question Type: <class 'str'>, Value: 'who played young martha may in the grinch...'
  Answer Type: <class 'str'>, Value: '信息不足'
  Doc IDs Type: <class 'list'>, Value: [8260, 6768, 3309, 1107, 1351]
  Doc IDs List Length: 5
  First Valid Doc ID Type: <class 'int'>
  Output Record Dict: {'question': 'who played young martha may in the grinch', 'answer': '信息不足', 'document_id': [8260, 6768, 3309, 1107, 1351]}
--- End Debug Info ---


Generating val predictions:   2%|▏         | 16/1000 [02:12<2:19:35,  8.51s/it]


--- DEBUG: Preparing to write record 15 ---
  Question Type: <class 'str'>, Value: 'when did the first harry potter book get released...'
  Answer Type: <class 'str'>, Value: '根据提供的上下文信息，第一本哈利·波特书籍《哈利·波特与魔法石》（Harry Potter and the Philosopher's Stone）是在1997年6月26日发布的。'
  Doc IDs Type: <class 'list'>, Value: [12044, 2256, 284, 5621, 7470]
  Doc IDs List Length: 5
  First Valid Doc ID Type: <class 'int'>
  Output Record Dict: {'question': 'when did the first harry potter book get released', 'answer': "根据提供的上下文信息，第一本哈利·波特书籍《哈利·波特与魔法石》（Harry Potter and the Philosopher's Stone）是在1997年6月26日发布的。", 'document_id': [12044, 2256, 284, 5621, 7470]}
--- End Debug Info ---


Generating val predictions:   2%|▏         | 17/1000 [02:22<2:21:58,  8.67s/it]


--- DEBUG: Preparing to write record 16 ---
  Question Type: <class 'str'>, Value: 'when did mexico come under spanish colonial rule...'
  Answer Type: <class 'str'>, Value: '根据提供的上下文，墨西哥被西班牙殖民统治始于1519年， Hernán Cortés 和他的同伴开始他们的探索和征服之旅。具体来说，Cortés 和他的队伍是在1519年2月登陆墨西哥的。因此，墨西哥正式被西班牙殖民统治的时间是1519年。'
  Doc IDs Type: <class 'list'>, Value: [7128, 5522, 8953, 10252, 10690]
  Doc IDs List Length: 5
  First Valid Doc ID Type: <class 'int'>
  Output Record Dict: {'question': 'when did mexico come under spanish colonial rule', 'answer': '根据提供的上下文，墨西哥被西班牙殖民统治始于1519年， Hernán Cortés 和他的同伴开始他们的探索和征服之旅。具体来说，Cortés 和他的队伍是在1519年2月登陆墨西哥的。因此，墨西哥正式被西班牙殖民统治的时间是1519年。', 'document_id': [7128, 5522, 8953, 10252, 10690]}
--- End Debug Info ---


Generating val predictions:   2%|▏         | 18/1000 [02:29<2:17:49,  8.42s/it]


--- DEBUG: Preparing to write record 17 ---
  Question Type: <class 'str'>, Value: 'how many seats in the house and senate...'
  Answer Type: <class 'str'>, Value: '众议院有435个席位，参议院有100个席位。'
  Doc IDs Type: <class 'list'>, Value: [7224, 801, 4798, 10243, 10156]
  Doc IDs List Length: 5
  First Valid Doc ID Type: <class 'int'>
  Output Record Dict: {'question': 'how many seats in the house and senate', 'answer': '众议院有435个席位，参议院有100个席位。', 'document_id': [7224, 801, 4798, 10243, 10156]}
--- End Debug Info ---


Generating val predictions:   2%|▏         | 19/1000 [02:37<2:14:44,  8.24s/it]


--- DEBUG: Preparing to write record 18 ---
  Question Type: <class 'str'>, Value: 'who plays dr jo wilson on grey's anatomy...'
  Answer Type: <class 'str'>, Value: 'Camilla Luddington plays Dr. Jo Wilson on Grey's Anatomy.'
  Doc IDs Type: <class 'list'>, Value: [8688, 5542, 10256, 6931, 9314]
  Doc IDs List Length: 5
  First Valid Doc ID Type: <class 'int'>
  Output Record Dict: {'question': "who plays dr jo wilson on grey's anatomy", 'answer': "Camilla Luddington plays Dr. Jo Wilson on Grey's Anatomy.", 'document_id': [8688, 5542, 10256, 6931, 9314]}
--- End Debug Info ---


Generating val predictions:   2%|▏         | 20/1000 [02:46<2:17:38,  8.43s/it]


--- DEBUG: Preparing to write record 19 ---
  Question Type: <class 'str'>, Value: 'who sang the song please come to boston...'
  Answer Type: <class 'str'>, Value: '根据提供的上下文信息，歌曲 "Please Come to Boston" 是由 American singer-songwriter Dave Loggins 撰写并录制的。所以答案是 Dave Loggins。'
  Doc IDs Type: <class 'list'>, Value: [10214, 11512, 11066, 6936, 8732]
  Doc IDs List Length: 5
  First Valid Doc ID Type: <class 'int'>
  Output Record Dict: {'question': 'who sang the song please come to boston', 'answer': '根据提供的上下文信息，歌曲 "Please Come to Boston" 是由 American singer-songwriter Dave Loggins 撰写并录制的。所以答案是 Dave Loggins。', 'document_id': [10214, 11512, 11066, 6936, 8732]}
--- End Debug Info ---


Generating val predictions:   2%|▏         | 21/1000 [02:54<2:17:28,  8.43s/it]


--- DEBUG: Preparing to write record 20 ---
  Question Type: <class 'str'>, Value: 'when does the new season of the arrangement come out...'
  Answer Type: <class 'str'>, Value: '根据提供的上下文信息，第二季于2017年4月13日宣布续订，但具体播出日期没有提供。因此，答案是：信息不足。'
  Doc IDs Type: <class 'list'>, Value: [8741, 7444, 7064, 6264, 420]
  Doc IDs List Length: 5
  First Valid Doc ID Type: <class 'int'>
  Output Record Dict: {'question': 'when does the new season of the arrangement come out', 'answer': '根据提供的上下文信息，第二季于2017年4月13日宣布续订，但具体播出日期没有提供。因此，答案是：信息不足。', 'document_id': [8741, 7444, 7064, 6264, 420]}
--- End Debug Info ---


Generating val predictions:   2%|▏         | 22/1000 [03:02<2:13:24,  8.18s/it]


--- DEBUG: Preparing to write record 21 ---
  Question Type: <class 'str'>, Value: 'the point at which a planet is the greatest distance away from the sun...'
  Answer Type: <class 'str'>, Value: '蚀日（aphelion）'
  Doc IDs Type: <class 'list'>, Value: [7744, 1623, 4672, 489, 5523]
  Doc IDs List Length: 5
  First Valid Doc ID Type: <class 'int'>
  Output Record Dict: {'question': 'the point at which a planet is the greatest distance away from the sun', 'answer': '蚀日（aphelion）', 'document_id': [7744, 1623, 4672, 489, 5523]}
--- End Debug Info ---


Generating val predictions:   2%|▏         | 23/1000 [03:10<2:14:08,  8.24s/it]


--- DEBUG: Preparing to write record 22 ---
  Question Type: <class 'str'>, Value: 'who is the vice president if the president dies...'
  Answer Type: <class 'str'>, Value: '如果总统去世，副总统将接任总统职务。'
  Doc IDs Type: <class 'list'>, Value: [10656, 11725, 563, 3975, 3222]
  Doc IDs List Length: 5
  First Valid Doc ID Type: <class 'int'>
  Output Record Dict: {'question': 'who is the vice president if the president dies', 'answer': '如果总统去世，副总统将接任总统职务。', 'document_id': [10656, 11725, 563, 3975, 3222]}
--- End Debug Info ---


Generating val predictions:   2%|▏         | 24/1000 [03:21<2:23:23,  8.81s/it]


--- DEBUG: Preparing to write record 23 ---
  Question Type: <class 'str'>, Value: 'when do purple martins migrate to south america...'
  Answer Type: <class 'str'>, Value: '根据提供的上下文信息，紫 Martins 在春季从南美洲迁移到北美繁殖，因此它们在冬季在南美洲度过。具体来说，文中有提到 "Wintering in South America, purple martins migrate to North America in spring to breed." 因此，紫 Martins 迁移到南美洲的时间是冬季。'
  Doc IDs Type: <class 'list'>, Value: [4651, 4057, 8650, -1, -1]
  Doc IDs List Length: 5
  First Valid Doc ID Type: <class 'int'>
  Output Record Dict: {'question': 'when do purple martins migrate to south america', 'answer': '根据提供的上下文信息，紫 Martins 在春季从南美洲迁移到北美繁殖，因此它们在冬季在南美洲度过。具体来说，文中有提到 "Wintering in South America, purple martins migrate to North America in spring to breed." 因此，紫 Martins 迁移到南美洲的时间是冬季。', 'document_id': [4651, 4057, 8650, -1, -1]}
--- End Debug Info ---


Generating val predictions:   2%|▎         | 25/1000 [03:30<2:25:14,  8.94s/it]


--- DEBUG: Preparing to write record 24 ---
  Question Type: <class 'str'>, Value: 'loss of memory due to the passage of time during which the memory trace is not used...'
  Answer Type: <class 'str'>, Value: '信息不足

根据提供的上下文信息，没有直接提到由于时间流逝导致记忆痕迹未被使用从而造成的记忆丧失。提供的信息主要讨论了时间片错误、干扰理论以及注意力缺失等问题，但没有具体提到因时间流逝而导致的记忆丧失情况。'
  Doc IDs Type: <class 'list'>, Value: [6319, 3775, 3707, 5265, 10285]
  Doc IDs List Length: 5
  First Valid Doc ID Type: <class 'int'>
  Output Record Dict: {'question': 'loss of memory due to the passage of time during which the memory trace is not used', 'answer': '信息不足\n\n根据提供的上下文信息，没有直接提到由于时间流逝导致记忆痕迹未被使用从而造成的记忆丧失。提供的信息主要讨论了时间片错误、干扰理论以及注意力缺失等问题，但没有具体提到因时间流逝而导致的记忆丧失情况。', 'document_id': [6319, 3775, 3707, 5265, 10285]}
--- End Debug Info ---


Generating val predictions:   3%|▎         | 26/1000 [03:41<2:35:52,  9.60s/it]


--- DEBUG: Preparing to write record 25 ---
  Question Type: <class 'str'>, Value: 'the printmaking technique that uses acid to cut the lines of the image is called...'
  Answer Type: <class 'str'>, Value: '根据提供的上下文信息，使用酸来切割图像线条的版画技术是蚀刻（etching）。

蚀刻是一种版画技术，在这种技术中，金属（通常为铜、锌或钢）制成的平板覆盖着一层蜡质或丙烯酸地膜。艺术家然后使用尖锐的蚀刻针穿过地膜进行绘制。裸露的金属线条随后浸入酸性溶液（例如硝酸或氯化铁）中，酸溶液会蚀刻裸露的金属，留下线条痕迹。因此，答案是蚀刻。'
  Doc IDs Type: <class 'list'>, Value: [12001, 687, 1102, 5113, -1]
  Doc IDs List Length: 5
  First Valid Doc ID Type: <class 'int'>
  Output Record Dict: {'question': 'the printmaking technique that uses acid to cut the lines of the image is called', 'answer': '根据提供的上下文信息，使用酸来切割图像线条的版画技术是蚀刻（etching）。\n\n蚀刻是一种版画技术，在这种技术中，金属（通常为铜、锌或钢）制成的平板覆盖着一层蜡质或丙烯酸地膜。艺术家然后使用尖锐的蚀刻针穿过地膜进行绘制。裸露的金属线条随后浸入酸性溶液（例如硝酸或氯化铁）中，酸溶液会蚀刻裸露的金属，留下线条痕迹。因此，答案是蚀刻。', 'document_id': [12001, 687, 1102, 5113, -1]}
--- End Debug Info ---


Generating val predictions:   3%|▎         | 27/1000 [03:50<2:35:14,  9.57s/it]


--- DEBUG: Preparing to write record 26 ---
  Question Type: <class 'str'>, Value: 'who was the original actor in walking tall...'
  Answer Type: <class 'str'>, Value: '根据提供的上下文信息，乔·唐·贝克（Joe Don Baker）是1973年电影《Walking Tall》的主演。因此，乔·唐·贝克是《Walking Tall》的原始演员。'
  Doc IDs Type: <class 'list'>, Value: [740, 321, 478, 4932, 5106]
  Doc IDs List Length: 5
  First Valid Doc ID Type: <class 'int'>
  Output Record Dict: {'question': 'who was the original actor in walking tall', 'answer': '根据提供的上下文信息，乔·唐·贝克（Joe Don Baker）是1973年电影《Walking Tall》的主演。因此，乔·唐·贝克是《Walking Tall》的原始演员。', 'document_id': [740, 321, 478, 4932, 5106]}
--- End Debug Info ---


Generating val predictions:   3%|▎         | 28/1000 [03:59<2:27:37,  9.11s/it]


--- DEBUG: Preparing to write record 27 ---
  Question Type: <class 'str'>, Value: 'who plays jamie in a walk to remember...'
  Answer Type: <class 'str'>, Value: '根据提供的上下文信息，没有提到任何扮演Jamie这个角色的演员。'
  Doc IDs Type: <class 'list'>, Value: [7538, 1916, 4306, 2460, 3394]
  Doc IDs List Length: 5
  First Valid Doc ID Type: <class 'int'>
  Output Record Dict: {'question': 'who plays jamie in a walk to remember', 'answer': '根据提供的上下文信息，没有提到任何扮演Jamie这个角色的演员。', 'document_id': [7538, 1916, 4306, 2460, 3394]}
--- End Debug Info ---


Generating val predictions:   3%|▎         | 29/1000 [04:06<2:19:01,  8.59s/it]


--- DEBUG: Preparing to write record 28 ---
  Question Type: <class 'str'>, Value: 'what channel is fox sports on dish network...'
  Answer Type: <class 'str'>, Value: '信息不足'
  Doc IDs Type: <class 'list'>, Value: [8000, 4775, 9089, 4489, 11095]
  Doc IDs List Length: 5
  First Valid Doc ID Type: <class 'int'>
  Output Record Dict: {'question': 'what channel is fox sports on dish network', 'answer': '信息不足', 'document_id': [8000, 4775, 9089, 4489, 11095]}
--- End Debug Info ---


Generating val predictions:   3%|▎         | 30/1000 [04:13<2:12:54,  8.22s/it]


--- DEBUG: Preparing to write record 29 ---
  Question Type: <class 'str'>, Value: 'who won the women's college basketball championship 2018...'
  Answer Type: <class 'str'>, Value: '信息不足'
  Doc IDs Type: <class 'list'>, Value: [10444, 12018, 8239, 9500, 11719]
  Doc IDs List Length: 5
  First Valid Doc ID Type: <class 'int'>
  Output Record Dict: {'question': "who won the women's college basketball championship 2018", 'answer': '信息不足', 'document_id': [10444, 12018, 8239, 9500, 11719]}
--- End Debug Info ---


Generating val predictions:   3%|▎         | 31/1000 [04:23<2:19:15,  8.62s/it]


--- DEBUG: Preparing to write record 30 ---
  Question Type: <class 'str'>, Value: 'who played the district attorney in law and order...'
  Answer Type: <class 'str'>, Value: '根据提供的上下文信息，没有明确提到谁扮演了检察官（District Attorney）的角色。上下文提到了Henry Sharpe，他是导演，但并未说明他是否扮演了检察官。因此，答案是：

信息不足'
  Doc IDs Type: <class 'list'>, Value: [7847, 9412, 8291, 948, 3327]
  Doc IDs List Length: 5
  First Valid Doc ID Type: <class 'int'>
  Output Record Dict: {'question': 'who played the district attorney in law and order', 'answer': '根据提供的上下文信息，没有明确提到谁扮演了检察官（District Attorney）的角色。上下文提到了Henry Sharpe，他是导演，但并未说明他是否扮演了检察官。因此，答案是：\n\n信息不足', 'document_id': [7847, 9412, 8291, 948, 3327]}
--- End Debug Info ---


Generating val predictions:   3%|▎         | 32/1000 [04:38<2:52:45, 10.71s/it]


--- DEBUG: Preparing to write record 31 ---
  Question Type: <class 'str'>, Value: 'where is the gizzard located on the chicken...'
  Answer Type: <class 'str'>, Value: '根据提供的上下文信息，关于鸡的胃囊的具体位置描述不够明确。但是，可以知道鸡的胃囊是消化道的一部分，并且在“glandular stomach (腺胃) ”之后，“Then the food passes into the gizzard ( also known as the muscular stomach or ventriculus ) .” 这句话说明食物进入胃囊。结合“Gizzard of a chicken <P> The gizzard, also referred to as the ventriculus, gastric mill, and gigerium, is an organ found in the digestive tract of some animals, including archosaurs (dinosaurs including birds, pterosaurs, crocodiles and alligators), earthworms, some gastropods, some fish, and some crustaceans.” 这段描述，可以推测鸡的胃囊在腺胃之后的消化道中，但具体的相对位置如在消化道的哪个部分，信息不足。

答案：信息不足'
  Doc IDs Type: <class 'list'>, Value: [9633, 3791, 197, 8446, 11972]
  Doc IDs List Length: 5
  First Valid Doc ID Type: <class 'int'>
  Output Record Dict: {'question': 'where is the gizzard located on the chicken', 'answer': '根据提供的上下文信息，关于鸡的胃囊的具体位置描述不够明确。但是，可以知道鸡

Generating val predictions:   3%|▎         | 33/1000 [04:49<2:52:37, 10.71s/it]


--- DEBUG: Preparing to write record 32 ---
  Question Type: <class 'str'>, Value: 'what year did beyonce do the super bowl...'
  Answer Type: <class 'str'>, Value: '根据提供的上下文信息，Beyoncé在Super Bowl XLVII halftime show中表演，Super Bowl XLVII是第47届超级碗，对应的年份是2013年（2013年2月3日）。因此，Beyoncé在2013年做了超级碗表演。

答案：2013年'
  Doc IDs Type: <class 'list'>, Value: [1170, 5870, 8851, 162, 10064]
  Doc IDs List Length: 5
  First Valid Doc ID Type: <class 'int'>
  Output Record Dict: {'question': 'what year did beyonce do the super bowl', 'answer': '根据提供的上下文信息，Beyoncé在Super Bowl XLVII halftime show中表演，Super Bowl XLVII是第47届超级碗，对应的年份是2013年（2013年2月3日）。因此，Beyoncé在2013年做了超级碗表演。\n\n答案：2013年', 'document_id': [1170, 5870, 8851, 162, 10064]}
--- End Debug Info ---


Generating val predictions:   3%|▎         | 33/1000 [04:56<2:24:41,  8.98s/it]


KeyboardInterrupt: 