In [1]:
import json
import os

import hashlib
from typing import List, Dict, Any
from tqdm import tqdm
import sys
import concurrent.futures
import random

from get_text_embedding import get_text_embedding

from dotenv import load_dotenv
from openai import OpenAI

In [2]:
class PageChunkLoader:
    """Reads a list of chunk dictionaries from a UTF-8 encoded JSON file."""

    def __init__(self, json_path: str):
        # The file is not touched here; it is opened on each load_chunks() call.
        self.json_path = json_path

    def load_chunks(self) -> List[Dict[str, Any]]:
        """Parse the JSON file at ``self.json_path`` and return its content."""
        with open(self.json_path, mode='r', encoding='utf-8') as handle:
            parsed = json.load(handle)
        return parsed

In [3]:
class EmbeddingModel:
    """Wrapper around ``get_text_embedding`` configured from the environment.

    The API key, base URL and embedding model name are read with
    ``os.getenv`` from LOCAL_API_KEY, LOCAL_BASE_URL and
    LOCAL_EMBEDDING_MODEL respectively.
    """

    def __init__(self, batch_size: int = 64):
        """Read the connection settings and remember the batch size.

        Raises:
            ValueError: if LOCAL_API_KEY or LOCAL_BASE_URL is missing or empty.
        """
        read_env = os.getenv
        self.api_key = read_env('LOCAL_API_KEY')
        self.base_url = read_env('LOCAL_BASE_URL')
        self.embedding_model = read_env('LOCAL_EMBEDDING_MODEL')
        self.batch_size = batch_size
        if not (self.api_key and self.base_url):
            raise ValueError('请在.env中配置LOCAL_API_KEY和LOCAL_BASE_URL')

    def embed_texts(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of texts by delegating to ``get_text_embedding``."""
        request_options = {
            'api_key': self.api_key,
            'base_url': self.base_url,
            'embedding_model': self.embedding_model,
            'batch_size': self.batch_size,
        }
        return get_text_embedding(texts, **request_options)

    def embed_text(self, text: str) -> List[float]:
        """Embed a single text and return its vector."""
        vectors = self.embed_texts([text])
        return vectors[0]

In [4]:
class SimpleVectorStore:
    """In-memory vector store with brute-force cosine-similarity search.

    ``self.chunks[i]`` and ``self.embeddings[i]`` describe the same item, so
    the two lists must always have the same length.
    """

    def __init__(self):
        self.embeddings = []
        self.chunks = []

    def add_chunks(self, chunks: List[Dict[str, Any]], embeddings: List[List[float]]):
        """Append chunks together with their embeddings.

        Args:
            chunks: Chunk dictionaries to store.
            embeddings: One embedding vector per chunk, in the same order.

        Raises:
            ValueError: if ``chunks`` and ``embeddings`` differ in length.
                Nothing is stored in that case, so the index stays aligned.
        """
        chunks = list(chunks)
        embeddings = list(embeddings)
        if len(chunks) != len(embeddings):
            raise ValueError(
                f"chunks and embeddings must have the same length "
                f"(got {len(chunks)} and {len(embeddings)})"
            )
        self.chunks.extend(chunks)
        self.embeddings.extend(embeddings)

    def search(self, query_embedding: List[float], top_k: int = 3) -> List[Dict[str, Any]]:
        """Return up to ``top_k`` chunks ranked by cosine similarity.

        Args:
            query_embedding: Embedding vector of the query.
            top_k: Maximum number of chunks to return. ``None`` returns all
                chunks; zero or a negative value returns an empty list.

        Returns:
            Stored chunk dictionaries, most similar first. Empty when the
            store is empty.
        """
        import numpy as np

        if not self.embeddings:
            return []
        # A negative top_k would otherwise slice from the end of the ranking
        # (e.g. -1 -> "all but the last"), which is never what a caller wants.
        if top_k is not None and top_k <= 0:
            return []
        emb_matrix = np.array(self.embeddings)
        query_emb = np.array(query_embedding)
        # The small epsilon guards against division by zero for all-zero vectors.
        denom = np.linalg.norm(emb_matrix, axis=1) * np.linalg.norm(query_emb) + 1e-8
        sims = emb_matrix @ query_emb / denom
        idxs = sims.argsort()[::-1][:top_k]
        return [self.chunks[i] for i in idxs]

In [5]:
class SimpleRAG:
    """Minimal retrieve-then-generate pipeline over page chunks.

    Composes a ``PageChunkLoader``, an ``EmbeddingModel`` and a
    ``SimpleVectorStore``. Call :meth:`setup` once to embed and index the
    chunks before using :meth:`query` or :meth:`generate_answer`.
    """

    def __init__(self, chunk_json_path: str, model_path: str = None, batch_size: int = 32):
        """
        Args:
            chunk_json_path: Path of the JSON file holding the chunks.
            model_path: Unused; kept so existing callers keep working.
            batch_size: Batch size forwarded to ``EmbeddingModel``.
        """
        self.loader = PageChunkLoader(chunk_json_path)
        self.embedding_model = EmbeddingModel(batch_size=batch_size)
        self.vector_store = SimpleVectorStore()

    def setup(self):
        """Load every chunk, embed its ``content`` field and index the result."""
        print("加载所有页chunk...")
        chunks = self.loader.load_chunks()
        print(f"共加载 {len(chunks)} 个chunk")
        print("生成嵌入...")
        embeddings = self.embedding_model.embed_texts([c['content'] for c in chunks])
        print("存储向量...")
        self.vector_store.add_chunks(chunks, embeddings)
        print("RAG向量库构建完成！")

    def query(self, question: str, top_k: int = 3) -> Dict[str, Any]:
        """Retrieve the ``top_k`` chunks most similar to ``question``.

        Returns:
            ``{"question": question, "chunks": [...]}``.
        """
        q_emb = self.embedding_model.embed_text(question)
        results = self.vector_store.search(q_emb, top_k)
        return {
            "question": question,
            "chunks": results
        }

    @staticmethod
    def _build_prompt(question: str, chunks: List[Dict[str, Any]]) -> str:
        """Assemble the user prompt from the retrieved chunks and the question."""
        # Join the retrieved content, prefixing each chunk with its metadata.
        context = "\n".join([
            f"[文件名]{c['metadata']['file_name']} [页码]{c['metadata']['page']}\n{c['content']}" for c in chunks
        ])
        # Explicitly request JSON output with answer/filename/page keys.
        return (
            f"你是一名专业的金融分析助手，请根据以下检索到的内容回答用户问题。\n"
            f"请严格按照如下JSON格式输出：\n"
            f'{{"answer": "你的简洁回答", "filename": "来源文件名", "page": "来源页码"}}'"\n"
            f"检索内容：\n{context}\n\n问题：{question}\n"
            f"请确保输出内容为合法JSON字符串，不要输出多余内容。"
        )

    @staticmethod
    def _parse_model_output(raw: str, chunks: List[Dict[str, Any]]):
        """Turn the raw model reply into an ``(answer, filename, page)`` tuple.

        When no JSON object can be extracted from ``raw``, the raw text is used
        as the answer and the source falls back to the top-ranked chunk's
        metadata (or empty strings when nothing was retrieved).
        """
        from extract_json_array import extract_json_array

        # Use extract_json_array to pull JSON objects out of the reply.
        json_str = extract_json_array(raw, mode='objects')
        if json_str:
            try:
                parsed = json.loads(json_str)
                # Only the first object is used.
                if isinstance(parsed, list) and parsed and isinstance(parsed[0], dict):
                    first = parsed[0]
                    return first.get('answer', ''), first.get('filename', ''), first.get('page', '')
            except Exception:
                # Deliberately best-effort: an unparsable reply is not an error,
                # it simply takes the raw-text fallback below.
                pass
        if chunks:
            top_meta = chunks[0]['metadata']
            return raw, top_meta['file_name'], top_meta['page']
        return raw, '', ''

    def generate_answer(self, question: str, top_k: int = 3) -> Dict[str, Any]:
        """Retrieve context and ask the chat model for a structured answer.

        Args:
            question: The user question.
            top_k: Number of chunks to retrieve as context.

        Returns:
            A dict with keys ``question``, ``answer``, ``filename``, ``page``
            and ``retrieval_chunks``.

        Raises:
            ValueError: if LOCAL_API_KEY, LOCAL_BASE_URL or LOCAL_TEXT_MODEL is
                not set in the environment.
        """
        qwen_api_key = os.getenv('LOCAL_API_KEY')
        qwen_base_url = os.getenv('LOCAL_BASE_URL')
        qwen_model = os.getenv('LOCAL_TEXT_MODEL')
        if not qwen_api_key or not qwen_base_url or not qwen_model:
            raise ValueError('请在.env中配置LOCAL_API_KEY、LOCAL_BASE_URL、LOCAL_TEXT_MODEL')

        chunks = self.query(question, top_k)["chunks"]
        prompt = self._build_prompt(question, chunks)

        client = OpenAI(api_key=qwen_api_key, base_url=qwen_base_url)
        completion = client.chat.completions.create(
            model=qwen_model,
            messages=[
                {"role": "system", "content": "你是一名专业的金融分析助手。"},
                {"role": "user", "content": prompt}
            ],
            temperature=0.2,
            max_tokens=1024
        )

        # The API may return a null message content; treat it as an empty
        # reply instead of crashing on ``None.strip()``.
        raw = (completion.choices[0].message.content or '').strip()
        answer, filename, page = self._parse_model_output(raw, chunks)

        # Structured output.
        return {
            "question": question,
            "answer": answer,
            "filename": filename,
            "page": page,
            "retrieval_chunks": chunks
        }



In [41]:
# 1) Imports & Paths
from pathlib import Path
import os, json, random
from tqdm.auto import tqdm
import concurrent.futures
import math

# The kernel's working directory is taken to be the notebook folder, and its
# parent is treated as the project root.
NOTEBOOK_DIR = Path.cwd()
PROJ_ROOT = NOTEBOOK_DIR.parent

# Candidate locations for train.json, probed in order: project root first,
# then the notebook folder; "datas" before "data" within each.
CANDIDATE_TRAIN = [
    base_dir / folder_name / "train.json"
    for base_dir in (PROJ_ROOT, NOTEBOOK_DIR)
    for folder_name in ("datas", "data")
]
TRAIN_PATH = None
for candidate in CANDIDATE_TRAIN:
    if candidate.exists():
        TRAIN_PATH = candidate
        break
if TRAIN_PATH is None:
    raise FileNotFoundError(f"train.json not found in: {CANDIDATE_TRAIN}")

# JSON file with the page chunks that the RAG index is built from.
CHUNK_JSON_PATH = PROJ_ROOT / "notebook" / "sample_pdf_page_chunks.json"

# Evaluation outputs: raw predictions and per-item scores.
EVAL_RAW_PATH = PROJ_ROOT / "eval_train_raw.json"
EVAL_SUMMARY_PATH = PROJ_ROOT / "eval_train_scored.json"

for label, location in (
    ("Notebook Dir:", NOTEBOOK_DIR),
    ("Project Root :", PROJ_ROOT),
    ("Train JSON   :", TRAIN_PATH),
    ("Chunks JSON  :", CHUNK_JSON_PATH),
):
    print(label, location)


Notebook Dir: d:\Datawhale\Multimodal-RAG-Competitions\notebook
Project Root : d:\Datawhale\Multimodal-RAG-Competitions
Train JSON   : d:\Datawhale\Multimodal-RAG-Competitions\data\train.json
Chunks JSON  : d:\Datawhale\Multimodal-RAG-Competitions\notebook\sample_pdf_page_chunks.json


In [64]:
# ⚙️ 2) Initialize RAG
# from your_module.rag import SimpleRAG  # ← update to your actual import path

# Guard only the lookup of the class name. Construction and setup() run
# outside the guard so that a NameError raised inside them is reported as-is
# instead of being mislabelled as a missing SimpleRAG definition.
try:
    SimpleRAG
except NameError:
    raise NameError("SimpleRAG is not defined. Import your class (e.g., `from your_module.rag import SimpleRAG`).")

rag = SimpleRAG(str(CHUNK_JSON_PATH))
rag.setup()
print("RAG initialized.")


加载所有页chunk...
共加载 7802 个chunk
生成嵌入...


Embedding: 100%|██████████| 244/244 [05:26<00:00,  1.34s/batch]

存储向量...
RAG向量库构建完成！
RAG initialized.





In [65]:
# 3) Load train and sample
with open(TRAIN_PATH, mode="r", encoding="utf-8") as train_file:
    train_data = json.load(train_file)

N = len(train_data)
# Fixed seed so the same subset is drawn on every run.
random.seed(42)

# Evaluate on 10% of the training set (rounded up), but never fewer than one item.
sample_size = max(1, math.ceil(N * 0.10))
all_idx = list(range(N))
if sample_size < N:
    sample_idx = sorted(random.sample(all_idx, sample_size))
else:
    # The sample would cover everything, so use every index.
    sample_idx = all_idx

print(f"Train size = {N} | Sample size = {len(sample_idx)}")
sample_idx[:10]


Train size = 118 | Sample size = 12


[3, 13, 14, 17, 28, 31, 35, 69, 81, 86]

In [66]:
# 4) Jaccard helper
def jaccard_char(a: str, b: str) -> float:
    """Character-level Jaccard similarity of two strings.

    Both inputs are stripped of surrounding whitespace first, and a falsy
    input (``None`` or ``""``) is treated as the empty string. The result is
    ``|chars(a) & chars(b)| / |chars(a) | chars(b)|``; two empty strings are
    considered identical and score 1.0.
    """
    chars_a = set((a or "").strip())
    chars_b = set((b or "").strip())
    combined = chars_a | chars_b
    if not combined:
        # Both strings are empty after stripping.
        return 1.0
    return len(chars_a & chars_b) / len(combined)


In [67]:
# 5) Inference
def run_one(idx):
    """Generate a RAG answer for the training example at position ``idx``.

    Logs the item's position within ``sample_idx`` and the first 60
    characters of the question, then returns ``(idx, prediction)`` where the
    prediction comes from ``rag.generate_answer(question, top_k=5)``.
    """
    question = train_data[idx].get("question", "")
    position = sample_idx.index(idx) + 1
    tqdm.write(f"[{position}/{len(sample_idx)}] {question[:60]}...")
    return idx, rag.generate_answer(question, top_k=5)

results = []
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as pool:
    # Executor.map yields outputs in the order of sample_idx, so `results`
    # stays aligned with the sample even though the calls run concurrently.
    prediction_stream = pool.map(run_one, sample_idx)
    results.extend(tqdm(prediction_stream, total=len(sample_idx), desc="Infer on train sample"))

# Persist the raw (idx, pred) pairs so individual predictions can be inspected later.
with open(EVAL_RAW_PATH, mode="w", encoding="utf-8") as raw_file:
    json.dump(results, raw_file, ensure_ascii=False, indent=2)
print(f"Saved raw predictions to: {EVAL_RAW_PATH}")


[1/12] 联邦制药的UBT37034在超重/肥胖适应症方面取得了哪些临床前数据？...
[2/12] 根据华创证券对凌云股份的深度研究报告，请问该公司在2024年的主要产品收入占比是多少？...
[3/12] 关于凌云股份（600480）的热冲压技术应用和发展前景，能否详细解释热冲压成型工艺与冷冲压成型工艺的主要区别？...
[4/12] 关于凌云股份（600480）的德国WAG业务板块及客户情况，请问具体有哪些主要客户？...
[5/12] 广联达的数字施工业务在2020年的资金压力如何？与同行业其他企业相比，其资金压力有何特点？...
[6/12] 如何评估广联达在数字化转型过程中面临的挑战及其应对策略？...
[7/12] 如何评估广联达在数字化转型中的竞争优势？...
[8/12] 如何分析广联达（002410.SZ）在2021年的PS估值水平及其与可比公司的差异？...
[9/12] 千味央厨的餐饮大客户经营数据在2022年第三季度有何变化？...
[10/12] 千味央厨公司在2020年的毛利率受原材料价格波动影响如何？...


Infer on train sample:   0%|          | 0/12 [00:30<?, ?it/s]

[11/12] 关于伊利股份的历史发展和市场竞争，请问在2005年至2013年间，伊利如何通过创新产品和营销策略实现营收突破100亿大关...


Infer on train sample:   0%|          | 0/12 [00:35<?, ?it/s]

[12/12] 广联达（002410）的数字设计业务在2021年下半年将如何推进？...


Infer on train sample: 100%|██████████| 12/12 [01:06<00:00,  5.55s/it]

Saved raw predictions to: d:\Datawhale\Multimodal-RAG-Competitions\eval_train_raw.json





## 6) Scoring vs Ground Truth
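Each item is scored as `score = 0.25 * page_match + 0.25 * filename_match + 0.5 * answer_jaccard`, where:
- page_match: 1 if the predicted page exactly equals the ground-truth page (compared as strings), else 0 (weight 0.25)
- filename_match: 1 if the predicted filename exactly equals the ground-truth filename, else 0 (weight 0.25)
- answer_jaccard: character-level Jaccard similarity between the predicted and ground-truth answers (weight 0.5)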


In [68]:
# 6) Score predictions
# Index the predictions by their position in the training set.
idx2pred = dict(results)

scored_rows = []
for idx in sample_idx:
    truth = train_data[idx]
    # An item without a prediction is scored against empty fields.
    guess = idx2pred.get(idx, {})

    gt_answer = truth.get("answer", "")
    gt_filename = truth.get("filename", "")
    gt_page = truth.get("page", "")

    pr_answer = guess.get("answer", "")
    pr_filename = guess.get("filename", "")
    pr_page = guess.get("page", "")

    row = {
        "idx": idx,
        "question": truth.get("question", ""),
        "gt_answer": gt_answer,
        "pr_answer": pr_answer,
        "gt_filename": gt_filename,
        "pr_filename": pr_filename,
        "gt_page": gt_page,
        "pr_page": pr_page,
        # Exact match after casting both sides to str (e.g. 28 matches "28").
        "page_match": float(str(pr_page) == str(gt_page)),
        "filename_match": float(str(pr_filename) == str(gt_filename)),
        "answer_jaccard": jaccard_char(str(pr_answer), str(gt_answer)),
    }
    row["score"] = 0.25 * row["page_match"] + 0.25 * row["filename_match"] + 0.5 * row["answer_jaccard"]
    scored_rows.append(row)

# Ascending by score, so the weakest cases come first.
scored_rows_sorted = list(scored_rows)
scored_rows_sorted.sort(key=lambda row: row["score"])

with open(EVAL_SUMMARY_PATH, mode="w", encoding="utf-8") as summary_file:
    json.dump(scored_rows_sorted, summary_file, ensure_ascii=False, indent=2)

# Mean of one numeric column over all scored rows.
_column_mean = lambda column: sum(r[column] for r in scored_rows_sorted) / len(scored_rows_sorted)
_all_scores = [r["score"] for r in scored_rows_sorted]

print(f"Saved scored results to: {EVAL_SUMMARY_PATH}")
print(f"max score: {max(_all_scores)}")
print(f"Mean score: {_column_mean('score'):.4f}")
print(f"min score: {min(_all_scores)}")
print(f"Mean Jaccard: {_column_mean('answer_jaccard'):.4f}")
print(f"Filename exact@1: {_column_mean('filename_match'):.4f}")
print(f"Page exact@1: {_column_mean('page_match'):.4f}")


Saved scored results to: d:\Datawhale\Multimodal-RAG-Competitions\eval_train_scored.json
max score: 0.5489690721649485
Mean score: 0.2631
min score: 0.08582089552238806
Mean Jaccard: 0.3178
Filename exact@1: 0.4167
Page exact@1: 0.0000


In [69]:
# Show a couple of worst and best cases inline (adjust k as needed)
k = 1


def _print_case(r):
    """Print the score, question, both answers and both sources of one scored row."""
    print("\nScore:", r["score"])
    print("Q:", r["question"])
    print("GT:", r["gt_answer"])
    print("PR:", r["pr_answer"])
    print("GT file/page:", r["gt_filename"], r["gt_page"])
    print("PR file/page:", r["pr_filename"], r["pr_page"])


# Clamp k so that zero or negative values show nothing rather than producing
# surprising slices.
_n_shown = max(k, 0)

print("— Worst cases —")
for r in scored_rows_sorted[:_n_shown]:
    _print_case(r)

print("\n— Best cases —")
# `seq[-0:]` is the whole sequence, so the zero case must be handled explicitly.
for r in (scored_rows_sorted[-_n_shown:] if _n_shown else []):
    _print_case(r)


— Worst cases —

Score: 0.08582089552238806
Q: 广联达的数字施工业务在2020年的资金压力如何？与同行业其他企业相比，其资金压力有何特点？
GT: 根据图片中的图表和文字内容，可以得出以下结论：

1. **资金压力情况**：
   - 图表35显示，龙元建设在2015年至2020年间，应收账款占比和已完工未结算资产占比均呈现上升趋势。特别是在2015年到2016年间，应收账款占比和已完工未结算资产占比都有显著增加。

2. **与其他企业的比较**：
   - 图表36显示，宏润建设在2017年至2020年间，应收账款占比和建造合同形成的已完工未结算资产占比也
PR: 2020年广联达施工业务资金压力主要源于智慧工地和解决方案类业务占比提升导致毛利率下降，但凭借造价市场的客户基础和SaaS化转型，其资金压力较同行业企业更可控。
GT file/page: 广联达-再谈广联达当前时点下如何看待其三条增长曲线-220217131页.pdf 28
PR file/page: 广联达-深度跟踪报告设计助推数字建筑一体化落地-22031838页.pdf 6

— Best cases —

Score: 0.5489690721649485
Q: 联邦制药的UBT37034在超重/肥胖适应症方面取得了哪些临床前数据？
GT: 根据图片中的文字内容，联邦制药的UBT37034在超重/肥胖适应症方面的临床前数据如下：

1. UBT37034在饮食诱导肥胖大鼠（DIO Rats）上的临床前数据表明，21天给药后，UBT37034联用替尔泊肽减重13.6%。

2. 相比之下，Petrelintide联用替尔泊肽减重13.6%（-9.38%），Cagrilintide联用替尔泊肽减
PR: UBT37034在饮食诱导肥胖大鼠（DIO Rats）上21天给药后，联用替尔泊肽减重13.6%，优于Petrelintide联用替尔泊肽（-9.38%）、Cagrilintide联用替尔泊肽（-10.89%）及替尔泊肽单药（-3.02%）效果。
GT file/page: 联邦制药-港股公司研究报告-创新突破三靶点战略联姻诺和诺德-25071225页.pdf 11
PR file/page: 联邦制药-港股公司研究报告-创新突破三靶点战略联姻诺和

### Notes
- Set a different **sample fraction** by changing the `0.10` in `math.ceil(N * 0.10)`.
- If `filename`/`page` in ground truth differ in minor formatting (e.g., case, spaces), add normalization before comparison.
- You can plug this same scorer later for validation on a dev split.
