# 第一版本

In [None]:
import json
import os
import torch
import re
from langchain_groq import ChatGroq
from langchain.chains import RetrievalQA
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.docstore.document import Document
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.prompts import PromptTemplate
from tqdm import tqdm

# CoT prompt.
# Prompt for the version-1 RetrievalQA "stuff" chain: {context} receives the
# concatenated retrieved chunks and {question} the user question.
# Despite the "CoT" (chain-of-thought) name, the text asks the model for the
# final answer only, with no reasoning steps.
# NOTE(review): the instructions pull in two directions ("strictly based on the
# provided context" vs. "try your best to infer an answer") -- confirm which
# behaviour is intended.
cot_prompt = PromptTemplate.from_template("""
You are a research assistant. Answer the question strictly based on the provided context.
Carefully analyze the context and extract only the most relevant information.
Provide a concise, direct answer to the question in a single sentence. 
If the context gives any clues or evidence, formulate an answer based on those.
If the context provides no useful information, try your best to infer an answer from what you have.
Only respond with "I don't know" if absolutely no relevant clues can be found in the context.
Avoid explanations, introductory phrases, or unnecessary details. No extra steps or reasoning, just the final answer.
If a part of the context seems irrelevant to the question, do not include it in your answer.

Context:
{context}

Question:
{question}

Answer:
""")

def clean_text(text):
    """Strip markup noise from paper text and collapse all whitespace.

    Removes HTML tags, ``BIBREFn`` citation markers and ``INLINEFORMn``
    formula placeholders, then trims the ends and squeezes every whitespace
    run (newlines included) into a single space.
    """
    noise_patterns = (
        r'<[^>]+>',        # HTML tags
        r'BIBREF\d+',      # citation markers
        r'INLINEFORM\d+',  # inline formula placeholders
    )
    for pattern in noise_patterns:
        text = re.sub(pattern, '', text)
    return re.sub(r'\s+', ' ', text.strip())

_BLANK_LINE_RUN = re.compile(r'\n{2,}')


def normalize_newlines(text):
    """Replace each run of two or more consecutive newlines with ``<SECTION>``."""
    return _BLANK_LINE_RUN.sub('<SECTION>', text)

# 按論文結構分割
# Split points: an explicit <SECTION> marker, or (zero-width) just before a
# common paper heading or an "X ::: Y" sub-heading line.
_SECTION_SPLIT_RE = re.compile(
    r'<SECTION>|(?=Abstract\n|Introduction\n|Related Work\n|Background\n|Data\n|Approach\n|Methodology\n|Evaluation\n|Experiments\n|Conclusion\n|Acknowledgements\n|(?:\w+\s*:::\s*.+\n))'
)


def split_by_sections(text):
    """Split *text* by paper structure and return the non-empty, stripped pieces."""
    stripped_pieces = (piece.strip() for piece in _SECTION_SPLIT_RE.split(text))
    return [piece for piece in stripped_pieces if piece]


def validate_answer(data):
    """Report gold answers / evidence that cannot be located in their paper.

    For every item the paper's ``full_text`` is cleaned with ``clean_text``;
    then
    * each cleaned answer must occur in it as a substring (exact or
      case-insensitive), and
    * each cleaned evidence passage must share at least 80% of its distinct
      lower-cased words with the paper.

    Args:
        data: iterable of dicts with 'full_text', 'title', 'answer' and
            'evidence' keys. 'answer' / 'evidence' may each be a list of
            strings or a single string.

    Returns:
        list[str]: one message per failed check; empty when everything matches.
    """
    issues = []
    for item in data:
        full_text = clean_text(item['full_text'])
        # Invariant per paper: lower-case it and build its word set once,
        # rather than once per answer / evidence entry.
        full_text_lower = full_text.lower()
        full_text_words = set(full_text_lower.split())

        answers = item['answer']
        if isinstance(answers, str):
            # A bare string would otherwise be checked character by character.
            answers = [answers]
        evidence = item['evidence']
        if isinstance(evidence, str):
            evidence = [evidence]

        # Check answers.
        for ans in answers:
            cleaned_ans = clean_text(ans)
            if cleaned_ans not in full_text and cleaned_ans.lower() not in full_text_lower:
                issues.append(f"Answer '{ans}' not found in full_text for title: {item['title']}")

        # Check evidence.
        for ev in evidence:
            ev_words = set(clean_text(ev).lower().split())
            # Fraction of the evidence's distinct words that also occur in the paper.
            overlap = len(ev_words & full_text_words) / len(ev_words) if ev_words else 0
            if overlap < 0.8:
                issues.append(f"Evidence '{ev}' not found or insufficient overlap ({overlap:.2f}) in full_text for title: {item['title']}")

    return issues

def create_documents(data):
    """Build LangChain Documents from papers, split by section and then by size.

    Each paper's raw ``full_text`` is cut into sections first and only then
    cleaned. Blank-line runs become ``<SECTION>`` markers, and known headings
    start a new section. Cleaning first would break sectioning: ``clean_text``
    collapses every newline into a space and deletes ``<...>`` spans. Neither
    the ``<SECTION>`` markers nor the newline-terminated heading patterns could
    then match, and each paper would end up as a single "section".

    Sections longer than 400 characters are further split with a
    RecursiveCharacterTextSplitter (chunk_size=500, chunk_overlap=200).

    Args:
        data: iterable of dicts with at least 'full_text' and 'title';
            'question', 'answer' and 'evidence' are copied into metadata when
            present (so unlabelled data can be indexed too).

    Returns:
        list[Document]: one Document per chunk. Metadata holds title, question,
        answer, cleaned evidence, section label (first line of the raw section)
        and a keyword-based dataset tag.
    """
    all_docs = []
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=200)
    for item in data:
        evidence = [clean_text(ev) for ev in item.get("evidence", [])]
        raw_sections = split_by_sections(normalize_newlines(item['full_text']))
        for raw_section in raw_sections:
            # The section's first line only exists before cleaning removes newlines.
            section_title = clean_text(raw_section.split('\n')[0])
            section = clean_text(raw_section)
            if not section:
                continue  # section contained nothing but markup / markers
            chunks = text_splitter.split_text(section) if len(section) > 400 else [section]
            for chunk in chunks:
                all_docs.append(Document(
                    page_content=chunk,
                    metadata={
                        "title": item["title"],
                        "question": item.get("question"),
                        "answer": item.get("answer"),
                        "evidence": list(evidence),
                        "section": section_title or "Unknown",
                        "dataset": "Airbnb" if "Airbnb" in chunk else "PrivacyQA" if "PrivacyQA" in chunk else "Unknown",
                    },
                ))
    return all_docs

# Load the public split; it is validated and then chunked for indexing.
with open("datasets/public_dataset.json", "r", encoding="utf-8") as f:
    public_data = json.load(f)

# Sanity-check that gold answers / evidence can be found in the cleaned text.
issues = validate_answer(public_data)
if not issues:
    print("All answers and evidence are valid.")
else:
    print("Found issues:", issues)

all_docs = create_documents(public_data)

# Embed the chunks with BGE-large (on GPU when available) and index them in FAISS.
_embed_device = "cuda" if torch.cuda.is_available() else "cpu"
embedding = HuggingFaceEmbeddings(
    model_name="BAAI/bge-large-en-v1.5",
    model_kwargs={"device": _embed_device},
)

faiss_db = FAISS.from_documents(all_docs, embedding)
faiss_db.save_local("faiss_index")

# Maximal-marginal-relevance search: fetch 40 candidates, keep 15.
# (search_type="similarity" would be the plain kNN alternative.)
retriever = faiss_db.as_retriever(
    search_type="mmr",
    search_kwargs={"k": 15, "fetch_k": 40},
)

# The Groq API key is read from the environment instead of being committed to
# source. NOTE: the key that used to be hard-coded here is exposed and should
# be revoked.
llm = ChatGroq(
    model="llama-3.1-8b-instant",
    api_key=os.environ["GROQ_API_KEY"],
    max_tokens=1024,
)

# Build the RetrievalQA "stuff" chain: all retrieved chunks are concatenated
# into cot_prompt's {context}; source documents are returned so they can be
# written out as evidence.
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    retriever=retriever,
    chain_type="stuff",
    chain_type_kwargs={"prompt": cot_prompt},
    return_source_documents=True,
)

# Load the private questions that need answers.
with open("datasets/private_dataset.json", "r", encoding="utf-8") as f:
    private_data = json.loads(f.read())

# One {"title", "answer", "evidence"} record is appended per private question.
submission_results = list()

def extract_final_answer(cot_output, default="I don't know"):
    """Pull the final answer out of a model response.

    If the response contains ``Final Answer:``, return the text after the
    first occurrence up to the end of that line. Otherwise return the last
    non-empty line.

    Falls back to *default* when nothing usable remains: an empty or
    whitespace-only response, or an empty ``Final Answer:`` line. Previously
    such responses yielded ``""``, because ``str.split`` never returns an
    empty list and the fallback was unreachable.

    Args:
        cot_output: raw text returned by the LLM.
        default: value returned when no answer text can be extracted.

    Returns:
        str: the extracted answer, or *default*.
    """
    if not cot_output:
        return default
    match = re.search(r"Final Answer:\s*(.*?)(?:\n|$)", cot_output, re.DOTALL)
    if match:
        return match.group(1).strip() or default
    lines = [line.strip() for line in cot_output.strip().splitlines() if line.strip()]
    return lines[-1] if lines else default


def is_how_many_question(q):
    """Return True if *q*, lower-cased and trimmed, begins with "how many"."""
    prefix = "how many"
    normalized = q.lower().strip()
    return normalized[:len(prefix)] == prefix

for item in tqdm(private_data, desc="Processing questions"):
    question = item["question"]
    title = item["title"]

    # Every question, including "How many ..." ones, goes through the RAG chain.
    # A fixed count of public_data articles does not answer a question about a
    # specific paper, so there is no special case.
    # NOTE(review): faiss_db is built from public_data only, so the text of the
    # private papers is not retrievable here -- confirm this is intended.
    qa_result = qa_chain.invoke({"query": question})
    cot_output = qa_result.get("answer") or qa_result.get("result") or "I don't know"
    answer_text = extract_final_answer(cot_output)
    evidence_list = [doc.page_content for doc in qa_result.get("source_documents", [])]

    submission_results.append({
        "title": title,
        "answer": answer_text,
        "evidence": evidence_list,
    })

with open("sample_submission_llama.json", "w", encoding="utf-8") as f:
    json.dump(submission_results, f, indent=2, ensure_ascii=False)

print("RAG流程完成，結果已寫入 sample_submission_llama.json")

# 第二版本

In [None]:
import json
import logging
from tqdm import tqdm
from langchain.docstore.document import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain_groq import ChatGroq
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.prompts import PromptTemplate
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains.retrieval import create_retrieval_chain

from transformers import pipeline

from rich.console import Console
from rich.logging import RichHandler

In [None]:
# Rich-formatted logging to stderr; record=True keeps a copy of console output.
console = Console(stderr=True, record=True)
log_handler = RichHandler(
    console=console,
    markup=True,
    rich_tracebacks=True,
)
logging.basicConfig(
    format="%(message)s",
    datefmt="[%X]",
    handlers=[log_handler],
)
log = logging.getLogger("rich")
log.setLevel(logging.DEBUG)

# Pipeline configuration.
DATASET_PATH = "datasets/public_dataset.json"
RETRIEVE_TOP_K = 12
CHUNK_SIZE = 400
CHUNK_OVERLAP = 200

with open(DATASET_PATH, "r", encoding="utf-8") as f:
    dataset = json.load(f)

In [None]:
import re

def clean_text(text: str) -> str:
    """Delete placeholder markers from *text*.

    Removes INLINEFORM / DISPLAYFORM formula placeholders, SECREF / TABREF /
    FIGREF cross-reference markers and UID markers (each followed by digits).
    FIGREF markers are removed together with their TABREF siblings; before,
    they leaked into the chunks. Only the markers are deleted; surrounding
    whitespace is left untouched.
    """
    return re.sub(r"(INLINEFORM\d+|DISPLAYFORM\d+|SECREF\d+|TABREF\d+|FIGREF\d+|UID\d+)", "", text)

# Clean each paper's full_text in place, so later cells see the cleaned text.
for data in dataset:
    data.update(full_text=clean_text(data["full_text"]))

In [None]:
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=CHUNK_SIZE,
    chunk_overlap=CHUNK_OVERLAP,
    length_function=len,
    add_start_index=True,
)

embedding_model = HuggingFaceEmbeddings(model_name="BAAI/bge-large-en-v1.5")

# NOTE(review): all papers go into one shared index, so retrieval for a
# question may return chunks from other papers.
all_docs = []
for data in dataset:
    # Paragraphs are separated by "\n\n\n". Drop only the segments that are
    # actually blank. The former `[:-1]` always discarded the last segment,
    # losing the final paragraph whenever the text did not end in "\n\n\n".
    paragraphs = [p for p in data["full_text"].split("\n\n\n") if p.strip()]
    docs = [Document(page_content=p) for p in paragraphs]
    all_docs.extend(text_splitter.split_documents(docs))

vector_store = FAISS.from_documents(all_docs, embedding_model)

In [None]:
import os

# Read the Groq API key from the environment instead of committing it to
# source. NOTE: the key that used to be hard-coded here is exposed and should
# be revoked.
llm = ChatGroq(
    model="gemma2-9b-it",
    api_key=os.environ["GROQ_API_KEY"],
    temperature=0.4,
    max_tokens=512,
)

# Answer prompt: {context} receives the stuffed retrieved chunks, {input} the question.
CHAT_TEMPLATE_RAG = """human: You are an expert academic assistant. Carefully read the provided context and follow these steps:
1. Identify the key concepts and reasoning required to answer the question.
2. Synthesize a concise answer using your own words.
3. Do not copy sentences directly from the context.

context:
{context}

question:
{input}

assistant:"""

retrieval_qa_prompt = PromptTemplate.from_template(template=CHAT_TEMPLATE_RAG)
combine_docs_chain = create_stuff_documents_chain(llm, retrieval_qa_prompt)

# Plain top-k similarity search over the shared index ("mmr" is an alternative).
rag_retriever = vector_store.as_retriever(
    search_type="similarity",
    search_kwargs={"k": RETRIEVE_TOP_K},
)
rag_qa_chain = create_retrieval_chain(
    retriever=rag_retriever,
    combine_docs_chain=combine_docs_chain,
)

In [None]:
results = []

for item in tqdm(dataset, desc="QA 生成中..."):
    question = item["question"]
    title = item["title"]

    response = rag_qa_chain.invoke({"input": question})

    # Trim surrounding whitespace from the generated answer (the v3 loop does the same).
    answer = response.get("answer", "").strip()
    # Evidence = the page content of every chunk the chain retrieved.
    evidence_chunks = [doc.page_content for doc in response.get("context", [])]

    results.append({
        "title": title,
        "answer": answer,
        "evidence": evidence_chunks,
    })

with open("sample_submission_public.json", "w", encoding="utf-8") as f:
    json.dump(results, f, ensure_ascii=False, indent=2)
print("Results saved to sample_submission_public.json")

# 第三版本(最終版)

In [None]:
import json
import re
import torch
from tqdm import tqdm
from langchain.docstore.document import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.chains.retrieval import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains.llm import LLMChain
from langchain_core.prompts import PromptTemplate
from langchain_groq import ChatGroq
from langchain.docstore.document import Document

from sentence_transformers import SentenceTransformer, util

=== 初始化設定、Embedding 模型與 LLM ===

In [None]:
import os

# Dataset toggle: uncomment the public pair to score locally, or keep the
# private pair to produce the submission file.
# DATASET_PATH = "datasets/public_dataset.json"
# OUTPUT_PATH = "sample_submission_public.json"

DATASET_PATH = "datasets/private_dataset.json"
OUTPUT_PATH = "sample_submission_private.json"
RETRIEVE_TOP_K = 30  # upper bound for the incremental top-k retrieval in the main loop

embedding_model = HuggingFaceEmbeddings(model_name="BAAI/bge-large-en-v1.5")
# embedding_model = HuggingFaceEmbeddings(model_name="intfloat/e5-large-v2")

# Read the Groq API key from the environment instead of committing it to
# source. NOTE: the key that used to be hard-coded here is exposed and should
# be revoked.
llm = ChatGroq(
    model="meta-llama/llama-4-maverick-17b-128e-instruct",
    api_key=os.environ["GROQ_API_KEY"],
    temperature=0.4,
    max_tokens=256,
)

=== 定義函數 ===

In [None]:
def clean_text(text: str) -> str:
    """Normalise a paper's full_text before it is split into paragraphs.

    Steps:
    1. Delete placeholder markers (INLINEFORM, DISPLAYFORM, SECREF, TABREF,
       UID, FIGREF followed by digits).
    2. Replace ``:::`` heading separators (and surrounding whitespace) with a
       newline.
    3. Cut everything from a ``References`` line (case-insensitive) to the
       end, and delete parenthesised ``(Table N)`` / ``(Figure N)`` references.
    4. Remove lines consisting only of a number (page numbers).

    Returns the result with leading/trailing whitespace stripped.
    """
    # 1. Remove display placeholders.
    text = re.sub(r"(INLINEFORM\d+|DISPLAYFORM\d+|SECREF\d+|TABREF\d+|UID\d+|FIGREF\d+)", "", text)

    # 2. Turn ":::" section separators into line breaks.
    text = re.sub(r"\s*:::+\s*", "\n", text)

    # 3. Drop the bibliography and inline table/figure references.
    text = re.sub(r"(?i)\nreferences\n.*", "", text, flags=re.DOTALL)
    text = re.sub(r"\(Table \d+\)|\(Figure \d+\)", "", text)

    # 4. Remove page-number lines. Only horizontal whitespace may follow the
    #    digits. The former `\s*` also matched newlines, so it swallowed the
    #    blank lines after the number and merged the "\n\n\n"-separated
    #    paragraphs that the main loop splits on.
    text = re.sub(r"\n\d+[ \t]*$", "", text, flags=re.MULTILINE)

    return text.strip()

from sentence_transformers import CrossEncoder

# BAAI/bge-reranker-large is a cross-encoder: it scores (query, passage) pairs
# jointly and is not trained to produce standalone sentence embeddings. Using
# it through SentenceTransformer with cosine similarity (as before) is not how
# the model is meant to be used, so score pairs with CrossEncoder instead.
sent_reranker = CrossEncoder("BAAI/bge-reranker-large")


def rerank_sentences_by_similarity(question, chunks, top_n=20, min_word_count=2):
    """Split chunks into unique sentences and return the top_n most relevant.

    Sentences are split after '.', '。', '!' or '?' followed by whitespace.
    Sentences with fewer than ``min_word_count`` words, and duplicates (first
    occurrence wins), are dropped. The rest are scored against *question* by
    the cross-encoder and returned best-first.

    Args:
        question: the query text.
        chunks: iterable of passage strings.
        top_n: maximum number of sentences to return.
        min_word_count: minimum whitespace-separated words per sentence.

    Returns:
        list[str]: up to ``top_n`` sentences ordered by descending relevance;
        an empty list when no sentence survives filtering.
    """
    seen = set()
    sentences = []
    for chunk in chunks:
        for s in re.split(r'(?<=[.。!?])\s+', chunk):
            s = s.strip()
            if len(s.split()) >= min_word_count and s not in seen:
                seen.add(s)
                sentences.append(s)

    if not sentences:
        return []

    # One batched call, instead of encoding every sentence separately.
    scores = sent_reranker.predict([(question, s) for s in sentences])
    ranked = sorted(zip(sentences, scores), key=lambda pair: float(pair[1]), reverse=True)
    return [s for s, _ in ranked[:top_n]]

=== Prompt 模板 ===

In [None]:
# Answer-generation prompt for create_stuff_documents_chain: {context} receives
# the retrieved chunks and {input} the question. The model's reply (stripped)
# becomes the submitted answer, so the prompt asks for the final answer only
# and no visible reasoning.
CHAT_TEMPLATE_RAG = (
    """human: You are an academic QA assistant. Use the context to answer precisely.
Think about the question step by step internally, then reply with only the final answer: a ***concise***, precise answer based on the context and evidence. Do not include your reasoning.
Please try to find the right keywords to answer the question based on the evidence or context you find.
If the answer is a name, number, or keyword, extract it directly.
Avoid vague or overly broad answers. Answer in a concise phrase.
Format your answer similarly to human-written academic answers from datasets like SQuAD or CoQA.

Context:  
{context}

Question:
{input}

assistant:"""
)

retrieval_qa_prompt = PromptTemplate.from_template(template=CHAT_TEMPLATE_RAG)

=== Evidence Confidence 模板 ===

In [None]:
# Confidence-judge prompt: asks the LLM whether the retrieved context is enough
# to answer the question, replying only "YES" or "NO". The main loop uses the
# verdict to decide when to stop increasing the retrieval depth.
_CONFIDENCE_TEMPLATE = """
You are a QA validation model. Based on the following retrieved context and question, judge if the context provides enough information to confidently answer the question.

Context:
{context}

Question:
{question}

Respond with only "YES" or "NO".
"""
CONFIDENCE_PROMPT = PromptTemplate.from_template(_CONFIDENCE_TEMPLATE)
confidence_chain = LLMChain(prompt=CONFIDENCE_PROMPT, llm=llm)

=== 主迴圈 ===

In [None]:
# === Load the dataset ===
with open(DATASET_PATH, "r", encoding="utf-8") as f:
    dataset = json.load(f)

sample_submission = []

# Loop-invariant pieces: built once instead of once per question.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=800,
    chunk_overlap=320,
    length_function=len,
    add_start_index=True,
)
combine_docs_chain = create_stuff_documents_chain(llm, retrieval_qa_prompt)


def _retrieve_until_confident(vector_store, question, max_k):
    """Return the shortest top-k prefix of similarity hits the judge LLM accepts.

    Runs a single similarity search for ``max_k`` hits, then grows k from 1.
    For each k, the k best chunks are split into sentences, reranked, and
    shown to ``confidence_chain``; the first k judged "YES" is returned. If
    the judge never says YES, all hits are returned.

    A top-k search result is the first k entries of the top-max_k result, so
    this matches re-running the search for every k. It avoids repeating the
    search, and it never re-asks the judge about an identical context once k
    exceeds the number of available chunks.
    """
    candidates = vector_store.similarity_search(question, k=max_k)
    for k in range(1, len(candidates) + 1):
        top_docs = candidates[:k]
        reranked = rerank_sentences_by_similarity(
            question, [doc.page_content for doc in top_docs], top_n=20
        )
        verdict = confidence_chain.run(
            {"context": "\n".join(reranked), "question": question}
        ).strip().upper()
        if "YES" in verdict:
            return top_docs
    return candidates


# === Answer every question ===
for item in tqdm(dataset, desc="QA 回答中..."):
    title = item["title"]
    full_text = clean_text(item["full_text"])
    question = item["question"]

    # Paragraphs are separated by "\n\n\n"; blank segments are skipped.
    docs = [Document(page_content=p) for p in full_text.split("\n\n\n") if p.strip()]
    docs_splits = text_splitter.split_documents(docs)
    if not docs_splits:
        # FAISS cannot index an empty corpus. Emit an empty record so the
        # output stays aligned one-to-one with the dataset.
        sample_submission.append({"title": title, "answer": "", "evidence": []})
        continue

    # Per-paper index: retrieval only sees the paper this question is about.
    vector_store = FAISS.from_documents(docs_splits, embedding_model)
    context_docs = _retrieve_until_confident(vector_store, question, RETRIEVE_TOP_K)

    # Answer from exactly the documents selected above; no second retrieval pass.
    predicted_answer = combine_docs_chain.invoke(
        {"context": context_docs, "input": question}
    ).strip()

    # Evidence = whole retrieved chunks (each treated as one "sentence"),
    # skipping any with fewer than two words.
    predicted_evidence = [
        doc.page_content.strip()
        for doc in context_docs
        if len(doc.page_content.strip().split()) >= 2
    ]

    sample_submission.append({
        "title": title,
        "answer": predicted_answer,
        "evidence": predicted_evidence,
    })

with open(OUTPUT_PATH, "w", encoding="utf-8") as f:
    json.dump(sample_submission, f, indent=2, ensure_ascii=False)

print(f"\n成功儲存：{OUTPUT_PATH}")

# Evaluate

In [None]:
from rouge_score import rouge_scorer

# ROUGE-L scorer (with stemming), shared by every evidence comparison below.
scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True)

def compute_evidence_rouge(gt_evidence, predicted_chunks):
    """Average ROUGE-L F1 of one item's predicted evidence against its gold evidence.

    Each predicted chunk is scored against the full list of gold evidence
    sentences with ``score_multi``. Per rouge_score, that keeps the
    best-matching target. The per-chunk F1 values are then averaged.
    Returns 0.0 when either list is empty.
    """
    if not (predicted_chunks and gt_evidence):
        return 0.0

    per_chunk_f1 = [
        scorer.score_multi(targets=gt_evidence, prediction=chunk)["rougeL"].fmeasure
        for chunk in predicted_chunks
    ]
    return sum(per_chunk_f1) / len(per_chunk_f1)

# After the QA loop: average the per-item evidence ROUGE-L over all items
# that have both gold and predicted evidence.
if len(sample_submission) != len(dataset):
    raise ValueError(
        f"sample_submission has {len(sample_submission)} entries but dataset has {len(dataset)}"
    )

total_score = 0.0
valid_count = 0

for item, submission in zip(dataset, sample_submission):
    # Gold evidence may be absent (e.g. an unlabelled split). .get skips such
    # items instead of raising KeyError.
    gt_evidence = item.get("evidence")
    pred_evidence = submission["evidence"]  # chunks selected by the model

    if pred_evidence and gt_evidence:
        total_score += compute_evidence_rouge(gt_evidence, pred_evidence)
        valid_count += 1

average_score = total_score / valid_count if valid_count > 0 else 0.0
print(f"[Total]: {valid_count} samples with valid evidence")
print(f"[Average ROUGE-L Evidence F1]: {average_score:.4f}")