# Legal-RAG：构建索引、检索评估

本 Notebook 包含：
1. 克隆仓库 & 安装依赖
2. 预处理民法典合同编文本
3. 构建 Hybrid（FAISS + BM25）索引
4. 多指标评估：Hit@K、MRR、nDCG、Precision、Recall
5. 检索模型对比：Hybrid、FAISS、BM25
6. 示例展示 Hybrid 检索
7. 修改问题批量测试
8. 可视化检索表现


## 1. 环境准备

In [None]:
import os, sys, json, math
# These environment variables are set before any later cell imports the
# retrieval stack. NOTE(review): presumably they keep TensorFlow from being
# loaded and silence its C++ logging -- confirm against the library docs.
os.environ['TRANSFORMERS_NO_TF'] = '1'
os.environ['USE_TF'] = '0'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from pathlib import Path

# Use the Kaggle working directory as the root for everything that follows.
ROOT = Path("/kaggle/working").resolve()
os.chdir(ROOT)
ROOT
# Show GPU status; when nvidia-smi fails, print a CPU fallback notice instead.
!nvidia-smi || echo 'No GPU. Falling back to CPU.'

## 2. 克隆 Legal-RAG 仓库

In [None]:
# Remote repository and the local checkout location under ROOT.
REPO_URL = "https://github.com/Fan-Luo/Legal-RAG.git"
REPO_DIR = ROOT / "Legal-RAG"

if not REPO_DIR.exists():
    # First run: clone the repository into REPO_DIR.
    !git clone "$REPO_URL" "$REPO_DIR"
else:
    # Checkout already present: enter it and pull the latest changes.
    print("Repo exists, pulling latest...")
    %cd "{REPO_DIR}"
    !git pull

# Later cells use relative paths (data/..., scripts.*), so run from the checkout.
%cd "{REPO_DIR}"


## 3. 安装依赖

In [None]:
%%bash
# Install the packages listed in requirements.txt; `set -e` aborts the cell
# as soon as a command fails, so the success message only prints on success.
set -e
pip install -q -r requirements.txt
echo 'Dependencies Installed.'

## 4. 预处理民法典合同编文本

In [None]:
# Make sure the raw-data directory exists so the listing below cannot fail.
RAW_DIR = Path("data/raw")
RAW_DIR.mkdir(parents=True, exist_ok=True)

print("Raw files:", list(RAW_DIR.iterdir()))
# Run the repository's preprocessing script as a module.
# NOTE(review): presumably it reads data/raw and writes data/processed -- confirm in the script.
!python -m scripts.preprocess_law

print("Processed:")
!ls data/processed

## 5. 构建索引（FAISS + BM25）

In [None]:
# Run the repository's index-building script as a module.
# NOTE(review): presumably it writes the FAISS and BM25 artifacts to data/index -- confirm in the script.
!python -m scripts.build_index

print("Index built:")
!ls data/index

## 6. 多指标检索评估

In [None]:
from legalrag.config import AppConfig
from legalrag.retrieval.hybrid_retriever import HybridRetriever
from legalrag.retrieval.vector_store import VectorStore
from legalrag.retrieval.bm25_retriever import BM25Retriever

import pandas as pd

# Load the application configuration.
# NOTE(review): passing None presumably selects the default config -- confirm in AppConfig.load.
cfg = AppConfig.load(None)

# One instance of each retriever, all built from the same configuration.
hybrid = HybridRetriever(cfg)
vs = VectorStore(cfg)
bm25 = BM25Retriever(cfg)

EVAL_PATH = Path("data/eval/contract_law_qa.jsonl")

# Read the JSONL evaluation set: one JSON object per line.
# The context manager closes the file handle deterministically, and blank
# lines (e.g. a trailing newline at end of file) are skipped rather than
# being passed to json.loads, which would raise on them.
with EVAL_PATH.open("r", encoding="utf8") as eval_file:
    eval_samples = [json.loads(line) for line in eval_file if line.strip()]
print("Loaded eval samples:", len(eval_samples))

In [None]:
def mrr(ranks):
    """Return the reciprocal rank for one query.

    Args:
        ranks: 1-based ranks at which relevant results were found.

    Returns:
        ``1 / min(ranks)``, or ``0.0`` when ``ranks`` is empty.
    """
    if not ranks:
        return 0.0
    best_rank = min(ranks)
    return 1 / best_rank

def ndcg_at_k(true_ids, retrieved, k):
    """Compute binary-relevance nDCG@k for a single query.

    Args:
        true_ids: Collection of relevant (gold) identifiers.
        retrieved: Ranked sequence of retrieved identifiers, best first.
        k: Cut-off rank; only the first ``k`` retrieved items are scored.

    Returns:
        DCG of the top-k results divided by the DCG of an ideal ranking, in
        the range [0, 1]. Returns 0.0 when there are no relevant ids, when
        ``k`` is not positive, or when nothing relevant was retrieved.
    """
    import math

    relevant = set(true_ids)
    if not relevant or k <= 0:
        return 0.0

    # Credit each relevant id once, at its first occurrence. Several retrieved
    # chunks can map to the same id, and counting them all would let the
    # score exceed 1.
    credited = set()
    dcg = 0.0
    for position, item in enumerate(retrieved[:k]):
        if item in relevant and item not in credited:
            credited.add(item)
            dcg += 1.0 / math.log2(position + 2)

    # The ideal ranking places every relevant id (at most k of them) at the
    # top. Deriving it from the retrieved gains instead would hide relevant
    # ids that were never retrieved and inflate the score.
    ideal_hits = min(len(relevant), k)
    idcg = sum(1.0 / math.log2(i + 2) for i in range(ideal_hits))
    return dcg / idcg if idcg > 0 else 0.0

def _get_gold(sample: dict) -> set:
    # 兼容不同版本的字段名
    gold = sample.get("target_articles") or sample.get("article_numbers") or sample.get("gold") or []
    # 统一成字符串列表（条号）
    return set(str(x).strip() for x in gold if str(x).strip())

def _retrieved_article_id(item):
    """Return the stripped article id carried by one search result, or None.

    Two result shapes are recognised: an object exposing a ``chunk``
    attribute, and a tuple whose first element is the chunk. Anything else,
    or a chunk without a truthy ``article_id``, yields None.
    """
    if hasattr(item, "chunk"):
        chunk = item.chunk
    elif isinstance(item, tuple) and len(item) >= 1:
        chunk = item[0]
    else:
        return None
    if chunk and getattr(chunk, "article_id", None):
        return str(chunk.article_id).strip()
    return None


def evaluate(name, retriever, top_k=10, samples=None):
    """Evaluate one retriever on a set of question/gold-article samples.

    Args:
        name: Label stored under the ``"retriever"`` key of the result.
        retriever: Object exposing ``search(question, top_k=...)``.
        top_k: Number of results requested per question and metric cut-off.
        samples: Iterable of sample dicts, each with a ``"question"`` key.
            Defaults to the module-level ``eval_samples``.

    Returns:
        Dict with the retriever name and the mean Hit@K, MRR, nDCG@K,
        Precision@K and Recall@K over all samples, rounded to 4 decimals.
    """
    if samples is None:
        samples = eval_samples
    samples = list(samples)

    hits, mrrs, ndcgs, precisions, recalls = [], [], [], [], []

    for sample in samples:
        question = sample["question"]
        gold = _get_gold(sample)

        # Keep the ranked order of the results; drop those without an article id.
        ret = []
        for item in retriever.search(question, top_k=top_k):
            article_id = _retrieved_article_id(item)
            if article_id is not None:
                ret.append(article_id)

        hits.append(int(bool(gold) and any(r in gold for r in ret)))
        ranks = [i + 1 for i, r in enumerate(ret) if r in gold]
        mrrs.append(mrr(ranks))
        ndcgs.append(ndcg_at_k(gold, ret, top_k))

        # tp counts distinct gold articles that were retrieved.
        tp = len(gold.intersection(ret))
        precisions.append(tp / len(ret) if ret else 0.0)
        # A sample without gold labels contributes 0.0 here and still counts
        # in the denominator of the mean below.
        recalls.append(tp / len(gold) if gold else 0.0)

    # Avoid division by zero when there are no samples at all.
    n = len(samples) if samples else 1
    return {
        "retriever": name,
        "Hit@K": round(sum(hits) / n, 4),
        "MRR": round(sum(mrrs) / n, 4),
        "nDCG@K": round(sum(ndcgs) / n, 4),
        "Precision@K": round(sum(precisions) / n, 4),
        "Recall@K": round(sum(recalls) / n, 4),
    }


In [None]:
# Evaluate each retriever on the same evaluation set and tabulate the metrics.
# The FAISS and BM25 retrievers are the `vs` and `bm25` objects constructed
# earlier in the notebook; the names `faiss_ret` and `bm25_ret` were never
# defined and raised NameError.
results = [
    evaluate("Hybrid", hybrid),
    evaluate("FAISS", vs),
    evaluate("BM25", bm25),
]
pd.DataFrame(results)

## 7. 检索可视化示例：条文命中率分布

In [None]:
import matplotlib.pyplot as plt

# One row per retriever; every numeric metric column becomes a group of bars.
df = pd.DataFrame(results)
df.plot.bar(x="retriever", figsize=(10, 5), title="Retrieval Metrics Comparison")
plt.show()

## 8. Hybrid 检索示例

In [None]:
# Sample contract-law questions used to demonstrate hybrid retrieval.
examples = [
    "合同约定违约金过高是否可以调整？",
    "定金和订金有何区别？",
    "买卖合同中商品质量争议如何处理？",
    "对方迟延履行，我能否解除合同？",
    "租赁合同未约定租金的法律后果是什么？"
]

def inspect(question, top_k=5):
    """Print the top hybrid-retrieval hits for one question.

    Args:
        question: Query text passed to ``hybrid.search``.
        top_k: Number of hits to request.

    For each hit, prints its rank, score and article id, followed by the
    first 120 characters of the chunk text with newlines replaced by spaces.
    """
    print("问题：", question)
    for hit in hybrid.search(question, top_k=top_k):
        chunk = hit.chunk
        print(f"- rank {hit.rank} score={hit.score:.3f}  {chunk.article_id}")
        preview = chunk.text[:120].replace("\n", " ")
        print("  文本:", preview, "...")
    print()

# Print the retrieval results for every example question.
for q in examples:
    inspect(q)

## 9. 更多测试

In [None]:
# Additional questions for batch testing; edit this list to try other queries.
# Every entry needs a trailing comma: adjacent string literals without one are
# silently concatenated into a single question.
custom_questions = [
    "合同一方隐瞒重要信息是否构成欺诈？",
    "保证合同是否需要书面形式？",
    "能否要求继续履行合同？",
    "对方迟延履行，我能否解除合同？",
    "合同中格式条款是否需要特别提示？",
    "订金和定金有什么法律区别？",
    "合同约定的违约金过高是否可以请求调整？",
    "货物买卖中瑕疵担保责任如何承担？",
    "合同未约定履行期限如何处理？",
    "当事人约定自动续期条款是否有效？",
    "对方拒绝履行主要债务，我可以中止履行吗？",
    "租赁合同到期承租人继续使用是否构成默示续租？",
    "承揽合同中成果不合格可以返工吗？由谁承担费用？",
    "委托合同中受托人是否可以转委托？",
    "借款合同未约定利息是否当然产生利息？",
    "担保合同中保证期间如何计算？",
    "连带保证与一般保证的区别是什么？",
    "债务人财产不足清偿时，留置权如何行使？",
    "抵押合同未登记是否对抗第三人？",
    "因重大误解订立的合同是否有效？",
    "合同未完全履行时可否部分解除？",
    "双方互负债务但先后履行顺序不明怎么办？",
    "合同相对人可以主张合同无效吗？",
    "因不可抗力不能履行是否要承担违约责任？",
    "借用合同中使用期间造成损害由谁负责？",
    "运输合同中托运人是否可以撤回货物？",
    "技术合同中成果归属如何认定？",
    "财产租赁中出租人是否承担维修义务？",
    "提前交付不合格标的物是否构成违约？",
    "线上签订的电子合同是否具有同等法律效力？",
]

# Print the retrieval results for every custom question.
for q in custom_questions:
    inspect(q)