# 搜索算法流程

## 1. 数据预处理

In [1]:
# 导入必要的库
import json
import nltk
from nltk.corpus import stopwords
from sklearn.feature_extraction.text import TfidfVectorizer
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer, util
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import numpy as np
from tqdm import tqdm

# Make sure the NLTK English stop-word corpus is available; download it only
# when the lookup fails.
try:
    stopwords.words('english')
except LookupError:
    nltk.download('stopwords')

english_stopwords = set(stopwords.words('english'))

# Locations of the claim and evidence JSON files.
train_claims_path = 'data_cleaned_digital/train-claims.json'
dev_claims_path = 'data_cleaned_digital/dev-claims.json'
evidence_path = 'data_cleaned_digital/evidence.json'
test_claims_path = 'data_cleaned_digital/test-claims-unlabelled.json'


def _load_json(path):
    """Read a UTF-8 encoded JSON file and return the parsed object."""
    with open(path, 'r', encoding='utf-8') as handle:
        return json.load(handle)


# Load every dataset up front.
train_claims = _load_json(train_claims_path)
dev_claims = _load_json(dev_claims_path)
evidences = _load_json(evidence_path)  # dict of {evidence_id: evidence_text}
test_claims = _load_json(test_claims_path)

print(f"Loaded {len(train_claims)} training claims.")
print(f"Loaded {len(dev_claims)} dev claims.")
print(f"Loaded {len(evidences)} evidences.")
print(f"Loaded {len(test_claims)} test claims.")

# Parallel lists of evidence IDs and texts (same dict order), used by BM25
# and the later stages to map a score position back to an evidence ID.
evidence_ids = list(evidences)
evidence_texts = list(evidences.values())

def preprocess_text(text):
    """Lower-case ``text``, split it on whitespace and drop English stop words.

    Returns the remaining tokens as a list, in their original order. Stop
    words are looked up in the module-level ``english_stopwords`` set. No
    punctuation stripping or stemming is done here.
    """
    lowered_tokens = text.lower().split()
    return [tok for tok in lowered_tokens if tok not in english_stopwords]

# Tokenise every evidence text once up front, showing a progress bar while it runs.
processed_evidence_texts = list(
    map(preprocess_text, tqdm(evidence_texts, desc="Preprocessing evidence"))
)
print("Evidence texts preprocessed.")

  from .autonotebook import tqdm as notebook_tqdm


Loaded 1228 training claims.
Loaded 154 dev claims.
Loaded 1208827 evidences.
Loaded 153 test claims.


Preprocessing evidence: 100%|██████████| 1208827/1208827 [00:06<00:00, 179823.40it/s]

Evidence texts preprocessed.





## 2. 第一阶段：BM25 候选召回

In [2]:
# Initialise the BM25 model over the preprocessed (tokenised) evidence texts
bm25 = BM25Okapi(processed_evidence_texts)
print("BM25 model initialized.")

def retrieve_bm25(claim_text, bm25_model, evidence_id_list, top_n=100):
    """Return the IDs of the ``top_n`` evidences with the highest BM25 score.

    The claim is tokenised with ``preprocess_text`` and scored through
    ``bm25_model.get_scores``. Position ``i`` of the score vector is mapped to
    ``evidence_id_list[i]``, so the ID list must be in the same order as the
    corpus the model was built from. IDs are returned best score first.
    """
    query_tokens = preprocess_text(claim_text)
    doc_scores = bm25_model.get_scores(query_tokens)
    # argsort is ascending, so reverse it to walk from the best score down.
    ranking = np.argsort(doc_scores)[::-1]
    best_positions = ranking[:top_n]
    return [evidence_id_list[pos] for pos in best_positions]

# Example: fetch BM25 candidates for the first dev claim.
if not dev_claims:
    print("Dev claims are empty, skipping BM25 example.")
    bm25_candidates = []  # empty list so the following steps still have a value
else:
    sample_claim_id_bm25 = next(iter(dev_claims))
    sample_claim_text_bm25 = dev_claims[sample_claim_id_bm25]['claim_text']
    bm25_candidates = retrieve_bm25(
        sample_claim_text_bm25, bm25, evidence_ids, top_n=100
    )
    print(f"BM25 candidates for dev claim '{sample_claim_id_bm25}': {len(bm25_candidates)} evidences")
    print(bm25_candidates[:5])  # show the first five candidates

BM25 model initialized.
BM25 candidates for dev claim 'claim-752': 100 evidences
['evidence-67732', 'evidence-572512', 'evidence-684667', 'evidence-452156', 'evidence-554677']


## 3. 第二阶段：Sentence-BERT 稠密重排（对 BM25 候选重新打分）

In [3]:
# Load the Sentence-BERT model
sbert_model_name = 'all-MiniLM-L6-v2' # another pretrained model can be substituted here
sbert_model = SentenceTransformer(sbert_model_name)
print(f"Sentence-BERT model '{sbert_model_name}' loaded.")

def retrieve_sbert(claim_text, candidate_ids, all_evidences_dict, sbert_model, top_n=50):
    """Re-rank candidate evidences by embedding similarity to the claim.

    Candidate IDs that are missing from ``all_evidences_dict`` are ignored.
    The claim and the remaining candidate texts are embedded with
    ``sbert_model.encode`` and compared by cosine similarity. Up to ``top_n``
    evidence IDs are returned, most similar first. An empty list is returned
    when there are no candidates or none of them is a known evidence ID.
    """
    if not candidate_ids:
        return []

    known_texts = [all_evidences_dict[eid] for eid in candidate_ids if eid in all_evidences_dict]
    known_ids = [eid for eid in candidate_ids if eid in all_evidences_dict]
    if not known_ids:
        # Nothing left after dropping unknown IDs.
        return []

    query_vec = sbert_model.encode(claim_text, convert_to_tensor=True)
    doc_vecs = sbert_model.encode(known_texts, convert_to_tensor=True, batch_size=32)

    # Row 0 holds the similarity of the single claim to every candidate.
    similarities = util.pytorch_cos_sim(query_vec, doc_vecs)[0]

    # Never ask topk for more entries than there are candidates.
    keep = min(top_n, len(known_ids))
    best = torch.topk(similarities, k=keep)
    return [known_ids[pos] for pos in best.indices.tolist()]

# Example: re-rank the BM25 candidates of the sample dev claim with SBERT.
if not dev_claims or not bm25_candidates:
    print("Skipping SBERT example due to empty dev claims or BM25 candidates.")
    sbert_candidates = []  # empty list so the following steps still have a value
else:
    sbert_candidates = retrieve_sbert(
        sample_claim_text_bm25, bm25_candidates, evidences, sbert_model, top_n=50
    )
    print(f"SBERT candidates after re-ranking BM25 results for dev claim '{sample_claim_id_bm25}': {len(sbert_candidates)} evidences")
    print(sbert_candidates[:5])

Sentence-BERT model 'all-MiniLM-L6-v2' loaded.
SBERT candidates after re-ranking BM25 results for dev claim 'claim-752': 50 evidences
['evidence-572512', 'evidence-67732', 'evidence-452156', 'evidence-780332', 'evidence-1000365']


## 4. 第三阶段：BERT Cross-Encoder 重排序

In [4]:
# Load the cross-encoder used for the final re-ranking stage.
cross_encoder_model_name = 'cross-encoder/ms-marco-MiniLM-L-6-v2'  # checkpoint suited to re-ranking
cross_encoder_tokenizer = AutoTokenizer.from_pretrained(cross_encoder_model_name)
# eval() switches the module to evaluation mode and returns the module itself,
# so it can be chained directly onto the load.
cross_encoder_model = AutoModelForSequenceClassification.from_pretrained(
    cross_encoder_model_name
).eval()
print(f"Cross-Encoder model '{cross_encoder_model_name}' loaded.")

def rerank_cross_encoder(claim_text, candidate_ids, all_evidences_dict, tokenizer, model, top_n=5):
    """Re-rank candidate evidences against a claim with a cross-encoder.

    Args:
        claim_text: The claim every candidate is scored against.
        candidate_ids: Evidence IDs to re-rank. IDs that are missing from
            ``all_evidences_dict`` are ignored.
        all_evidences_dict: Mapping of evidence ID to evidence text.
        tokenizer: Callable that turns a list of ``[claim, evidence]`` pairs
            into model inputs. It is called with ``padding=True``,
            ``truncation=True``, ``return_tensors='pt'`` and
            ``max_length=512``.
        model: Callable that accepts those inputs as keyword arguments and
            returns an object whose ``logits`` hold one score per pair. If it
            exposes a ``device`` attribute, the inputs are moved there first.
        top_n: Maximum number of evidence IDs to return.

    Returns:
        Up to ``top_n`` evidence IDs, highest score first. The list is empty
        when there are no candidates or none of them is a known evidence ID.
    """
    if not candidate_ids:
        return []

    # Filter once so IDs and texts stay aligned by construction.
    known = [(eid, all_evidences_dict[eid]) for eid in candidate_ids if eid in all_evidences_dict]
    if not known:
        return []
    valid_candidate_ids = [eid for eid, _ in known]

    # One (claim, evidence) pair per surviving candidate.
    sentence_pairs = [[claim_text, cand_text] for _, cand_text in known]

    with torch.no_grad():
        inputs = tokenizer(sentence_pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
        # Tokenizer output is not guaranteed to be on the model's device; move
        # it there so the call keeps working when the model is placed on a GPU.
        device = getattr(model, 'device', None)
        if device is not None:
            inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
        # Drop the trailing singleton dimension to get one score per pair.
        scores = model(**inputs).logits.squeeze(-1)
        if scores.ndim == 0:
            # topk needs at least one dimension, even for a single candidate.
            scores = scores.unsqueeze(0)

    # Never ask topk for more entries than there are candidates.
    top_indices = torch.topk(scores, k=min(top_n, len(valid_candidate_ids))).indices.tolist()
    return [valid_candidate_ids[i] for i in top_indices]

# Example: re-rank the SBERT candidates of the sample dev claim with the cross-encoder.
if not dev_claims or not sbert_candidates:
    print("Skipping Cross-Encoder example due to empty dev claims or SBERT candidates.")
    final_top_5_evidences = []
else:
    final_top_5_evidences = rerank_cross_encoder(
        sample_claim_text_bm25,
        sbert_candidates,
        evidences,
        cross_encoder_tokenizer,
        cross_encoder_model,
        top_n=5,
    )
    print(f"Final top 5 evidences for dev claim '{sample_claim_id_bm25}':")
    # Print each ID with the first 100 characters of its text.
    for i, eid in enumerate(final_top_5_evidences):
        print(f"{i+1}. {eid}: {evidences.get(eid, 'Evidence not found')[:100]}...")

Cross-Encoder model 'cross-encoder/ms-marco-MiniLM-L-6-v2' loaded.
Final top 5 evidences for dev claim 'claim-752':
1. evidence-572512: south australia have the high power price in the world...
2. evidence-67732: citation need south australia have the high retail price for electricity in the country...
3. evidence-723533: accord to a sierra club analysis the us kemper project which be due to be online in 2017 be the most...
4. evidence-780332: industrialise country such as canada the us and australia be among the high per capita consumer of e...
5. evidence-622374: in australian state of south australia wind power champion by premier mike rann 20022011 now compris...


## 5. 最终输出 Top 5 Evidence 并计算 Precision / Recall / F1 (针对 Dev 集)

In [5]:
def evaluate_retrieval(predictions_dict, gold_claims_dict, top_k=5):
    """Print and return precision, recall and F1 at ``top_k`` for a retrieval run.

    Counts are pooled over all evaluated claims (micro-averaged), not averaged
    per claim. A claim is evaluated only if it appears in ``gold_claims_dict``
    with a non-empty ``'evidences'`` entry. Predictions that are not a ``list``
    are treated as empty, and duplicates within the top ``top_k`` count once.

    Args:
        predictions_dict: Mapping of claim ID to a ranked list of evidence IDs.
        gold_claims_dict: Mapping of claim ID to a dict holding the gold
            ``'evidences'`` list.
        top_k: Number of leading predictions considered per claim.

    Returns:
        Tuple ``(precision, recall, f1)``. Each value is ``0`` when its
        denominator is zero.
    """
    hit_count = 0
    gold_count = 0
    predicted_count = 0
    num_claims_evaluated = 0

    for claim_id, predicted in predictions_dict.items():
        if claim_id not in gold_claims_dict:
            continue
        gold_list = gold_claims_dict[claim_id].get('evidences')
        if not gold_list:
            # No gold evidence for this claim, so there is nothing to score.
            continue

        num_claims_evaluated += 1
        gold = set(gold_list)
        # Anything that is not a list counts as "no predictions".
        top_predicted = set(predicted[:top_k]) if isinstance(predicted, list) else set()

        hit_count += len(gold & top_predicted)
        gold_count += len(gold)
        predicted_count += len(top_predicted)  # unique predictions within the top k

    precision_at_k = 0
    if predicted_count > 0:
        precision_at_k = hit_count / predicted_count

    recall_at_k = 0
    if gold_count > 0:
        recall_at_k = hit_count / gold_count

    f1_at_k = 0
    pr_sum = precision_at_k + recall_at_k
    if pr_sum > 0:
        f1_at_k = (2 * precision_at_k * recall_at_k) / pr_sum

    print(f"Evaluated on {num_claims_evaluated} claims (out of {len(gold_claims_dict)} total gold claims).")
    print(f"Precision@{top_k}: {precision_at_k:.4f}")
    print(f"Recall@{top_k}:    {recall_at_k:.4f}")
    print(f"F1-score@{top_k}:  {f1_at_k:.4f}")
    return precision_at_k, recall_at_k, f1_at_k

# Run the full three-stage pipeline over the dev set and evaluate it.
dev_set_predictions_pipeline = {}
if not dev_claims:
    print("Dev claims data is empty. Skipping dev set evaluation.")
else:
    for claim_id, claim_data in tqdm(dev_claims.items(), desc="Pipeline: Processing dev claims"):
        claim_text = claim_data['claim_text']
        # Stage 1 (BM25) -> stage 2 (SBERT) -> stage 3 (cross-encoder).
        # A stage that comes back empty short-circuits the claim to [].
        bm25_hits = retrieve_bm25(claim_text, bm25, evidence_ids, top_n=100)
        sbert_hits = (
            retrieve_sbert(claim_text, bm25_hits, evidences, sbert_model, top_n=50)
            if bm25_hits
            else []
        )
        reranked = (
            rerank_cross_encoder(
                claim_text,
                sbert_hits,
                evidences,
                cross_encoder_tokenizer,
                cross_encoder_model,
                top_n=5,
            )
            if sbert_hits
            else []
        )
        dev_set_predictions_pipeline[claim_id] = reranked

    print("\n--- Evaluation on Dev Set (Pipeline Top 5) ---")
    precision_dev, recall_dev, f1_dev = evaluate_retrieval(
        dev_set_predictions_pipeline, dev_claims, top_k=5
    )

    # Show one dev prediction as an example (only if any were produced).
    if dev_set_predictions_pipeline:
        example_claim_id_dev_eval = next(iter(dev_set_predictions_pipeline))
        print(f"\nExample prediction for dev claim '{example_claim_id_dev_eval}':")
        print(f"Claim text: {dev_claims[example_claim_id_dev_eval]['claim_text']}")
        print(f"Predicted evidence IDs: {dev_set_predictions_pipeline[example_claim_id_dev_eval]}")
        print("Predicted evidence texts:")
        for i, eid in enumerate(dev_set_predictions_pipeline[example_claim_id_dev_eval]):
            print(f"  {i+1}. {eid}: {evidences.get(eid, 'Evidence ID not found')[:150]}...")
        if dev_claims[example_claim_id_dev_eval].get('evidences'):
            print(f"Gold evidence IDs: {dev_claims[example_claim_id_dev_eval]['evidences']}")

Pipeline: Processing dev claims:   1%|          | 1/154 [00:06<15:49,  6.21s/it]


KeyboardInterrupt: 

## 6. 生成测试集预测文件 (可选)

In [6]:
# Predict for the test set: same pipeline as for dev, but without evaluation
# because the test claims are unlabelled.
test_set_final_predictions = {}
if not test_claims:
    print("\nTest claims data is empty. Skipping test set prediction generation.")
else:
    for claim_id, claim_data in tqdm(test_claims.items(), desc="Pipeline: Processing test claims"):
        claim_text = claim_data['claim_text']
        # A stage that comes back empty short-circuits the claim to [].
        bm25_hits = retrieve_bm25(claim_text, bm25, evidence_ids, top_n=100)
        sbert_hits = (
            retrieve_sbert(claim_text, bm25_hits, evidences, sbert_model, top_n=50)
            if bm25_hits
            else []
        )
        reranked = (
            rerank_cross_encoder(
                claim_text,
                sbert_hits,
                evidences,
                cross_encoder_tokenizer,
                cross_encoder_model,
                top_n=5,
            )
            if sbert_hits
            else []
        )
        test_set_final_predictions[claim_id] = reranked

    # Save the test-set predictions to a JSON file.
    output_test_predictions_path = 'test_predictions_pipeline_final.json'
    with open(output_test_predictions_path, 'w', encoding='utf-8') as out_file:
        json.dump(test_set_final_predictions, out_file, indent=4)
    print(f"\nTest set predictions (pipeline) saved to {output_test_predictions_path}")

Pipeline: Processing test claims: 100%|██████████| 153/153 [07:09<00:00,  2.81s/it]


Test set predictions (pipeline) saved to test_predictions_pipeline_final.json





In [7]:
# 7. 转换输出格式为指定结构

def convert_prediction_format(predictions_dict, claims_dict):
    """
    Convert raw predictions into the required output structure.

    Args:
        predictions_dict: Predictions in the pipeline's format, e.g.
            {"claim-1266": ["evidence-694262", ...]}
        claims_dict: Dict holding the claim text for each claim ID

    Returns:
        Dict mapping each claim ID to its claim text, a fixed "SUPPORTS"
        label and the predicted evidence list. Claims that are missing from
        claims_dict are left out.
    """
    return {
        claim_id: {
            "claim_text": claims_dict[claim_id]['claim_text'],
            "claim_label": "SUPPORTS",  # always SUPPORTS, as required
            "evidences": evidence_list,
        }
        for claim_id, evidence_list in predictions_dict.items()
        if claim_id in claims_dict
    }


# Convert the dev-set predictions (only if that cell has run and dev data exists).
if 'dev_set_predictions_pipeline' in locals() and dev_claims:
    formatted_dev_predictions = convert_prediction_format(dev_set_predictions_pipeline, dev_claims)

    if formatted_dev_predictions:
        # Print the first converted entry as an example.
        print("转换后的格式示例:")
        example_claim_id = next(iter(formatted_dev_predictions))
        example_entry = {example_claim_id: formatted_dev_predictions[example_claim_id]}
        print(json.dumps(example_entry, indent=4, ensure_ascii=False))

        # Save the converted dev predictions.
        output_formatted_dev_path = 'dev_predictions_formatted.json'
        with open(output_formatted_dev_path, 'w', encoding='utf-8') as dev_out:
            json.dump(formatted_dev_predictions, dev_out, indent=4, ensure_ascii=False)
        print(f"\nDev集格式化预测结果已保存至: {output_formatted_dev_path}")

# Convert the test-set predictions (only if that cell has run and test data exists).
if 'test_set_final_predictions' in locals() and test_claims:
    formatted_test_predictions = convert_prediction_format(test_set_final_predictions, test_claims)

    # Save the converted test predictions.
    output_formatted_test_path = 'test-output.json'
    with open(output_formatted_test_path, 'w', encoding='utf-8') as test_out:
        json.dump(formatted_test_predictions, test_out, indent=4, ensure_ascii=False)
    print(f"\nTest集格式化预测结果已保存至: {output_formatted_test_path}")

转换后的格式示例:
{
    "claim-752": {
        "claim_text": "south australia have the most expensive electricity in the world",
        "claim_label": "SUPPORTS",
        "evidences": [
            "evidence-572512",
            "evidence-67732",
            "evidence-723533",
            "evidence-780332",
            "evidence-622374"
        ]
    }
}

Dev集格式化预测结果已保存至: dev_predictions_formatted.json

Test集格式化预测结果已保存至: test-output.json
