In [None]:
!pip install pandas==2.2.3 
!pip install jupyter==1.1.1 
!pip install langchain==0.3.23 
!pip install langchain-community==0.3.21 
!pip install rich==14.0.0 
!pip install openai==1.71.0 
!pip install langchain-groq==0.3.2 
!pip install langchain-ollama==0.3.1 
!pip install faiss-gpu==1.7.2 
!pip install "numpy<2"
!pip install rouge-score 

In [None]:
pip install sentence-transformers faiss-cpu

In [None]:
import logging
import json

from rich.console import Console
from rich.logging import RichHandler

# --- Logging: route everything through a Rich handler writing to stderr ---
console = Console(record=True, stderr=True)
log_handler = RichHandler(console=console, markup=True, rich_tracebacks=True)
logging.basicConfig(
    format="%(message)s",
    datefmt="[%X]",
    handlers=[log_handler],
)
log = logging.getLogger("rich")
log.setLevel(logging.DEBUG)

# --- Run configuration ---
DEBUG: bool = False
# NOTE(review): later cells open "datasets/public_dataset.json" directly and do
# not read this constant — confirm which relative path is the right one.
DATASET_PATH: str = "../datasets/public_dataset.json"

# --- Self-hosted OpenAI-compatible endpoint settings ---
USING_MODEL: str = "meta-llama/Meta-Llama-3.1-8B-Instruct"
USING_PORT: int = 8092
API_ENDPOINT: str = f"http://192.168.0.7:{USING_PORT}/v1"
API_KEY: str = "abc"  # presumably a placeholder for a local server — confirm

# --- Generation / retrieval knobs ---
MODEL_TEMPERATURE: float = 0.3
MODEL_MAX_TOKENS: int = 128
RETRIEVE_TOP_K: int = 5

Groq API

In [None]:
# [REDACTED] A Groq API key was pasted here in plain text. It has been removed and
# should be revoked; provide the key through the GROQ_API_KEY environment variable.

In [None]:
# [REDACTED] A Groq API key was pasted here in plain text. It has been removed and
# should be revoked; provide the key through the GROQ_API_KEY environment variable.

In [None]:
# [REDACTED] A Groq API key was pasted here in plain text. It has been removed and
# should be revoked; provide the key through the GROQ_API_KEY environment variable.

In [None]:
# [REDACTED] A Groq API key was pasted here in plain text. It has been removed and
# should be revoked; provide the key through the GROQ_API_KEY environment variable.

In [None]:
# [REDACTED] A Groq API key was pasted here in plain text. It has been removed and
# should be revoked; provide the key through the GROQ_API_KEY environment variable.

# 改良版

In [None]:
import json
import os
import torch
import re
from langchain_groq import ChatGroq
from langchain.chains import RetrievalQA
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.docstore.document import Document
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.prompts import PromptTemplate
from tqdm import tqdm

# Prompt used by the QA chain. It asks for a concise, single-sentence answer
# grounded in the retrieved context; `{context}` and `{question}` are the
# placeholders filled in at query time.
_COT_TEMPLATE = """
You are a research assistant. Answer the question strictly based on the provided context.
Carefully analyze the context and extract only the most relevant information.
Provide a concise, direct answer to the question in a single sentence. 
If the context gives any clues or evidence, formulate an answer based on those.
If the context provides no useful information, try your best to infer an answer from what you have.
Only respond with "I don't know" if absolutely no relevant clues can be found in the context.
Avoid explanations, introductory phrases, or unnecessary details. No extra steps or reasoning, just the final answer.
If a part of the context seems irrelevant to the question, do not include it in your answer.

Context:
{context}

Question:
{question}

Answer:
"""

cot_prompt = PromptTemplate.from_template(_COT_TEMPLATE)


def clean_text(text):
    """Strip markup from `text` and collapse whitespace.

    Removes, in order: HTML-style tags (`<...>`), `BIBREF<n>` citation markers
    and `INLINEFORM<n>` formula markers. The result is then stripped and every
    run of whitespace (newlines included) is replaced by a single space.
    """
    for marker_pattern in (r'<[^>]+>', r'BIBREF\d+', r'INLINEFORM\d+'):
        text = re.sub(marker_pattern, '', text)
    trimmed = text.strip()
    return re.sub(r'\s+', ' ', trimmed)

# Normalize newlines
def normalize_newlines(text):
    """Replace every run of two or more newlines with the literal marker '<SECTION>'.

    Single newlines are left untouched.
    """
    blank_line_run = re.compile(r'\n{2,}')
    return blank_line_run.sub('<SECTION>', text)

# Split by paper structure
def split_by_sections(text):
    """Split `text` into non-empty, stripped sections.

    A split happens at every literal '<SECTION>' marker and immediately before
    a line consisting of a known heading (Abstract, Introduction, ...) or a
    'word ::: subsection' style heading.

    The heading alternatives are anchored with a word boundary. Without it the
    `\\w+ :::` lookahead also matches at every character inside the heading
    word, and `re.split` cuts the heading into single letters.
    """
    sections = re.split(
        r'<SECTION>|(?=\b(?:Abstract|Introduction|Related Work|Background|Data|Approach|Methodology|Evaluation|Experiments|Conclusion|Acknowledgements)\n|\b\w+\s*:::\s*.+\n)',
        text
    )
    return [s.strip() for s in sections if s.strip()]


def _as_list(value):
    """Return `value` as a list, wrapping a bare string so it is not iterated per character."""
    if value is None:
        return []
    if isinstance(value, str):
        return [value]
    return value


def validate_answer(data):
    """Check that each item's answers and evidence can be found in its full text.

    Args:
        data: iterable of dicts with the keys 'full_text', 'answer', 'evidence'
            and 'title'. 'answer' and 'evidence' may be a single string or a
            list of strings.

    Returns:
        A list of human-readable issue strings. An answer is reported when it
        is not a (case-insensitive) substring of the cleaned full text; an
        evidence entry is reported when fewer than 80% of its distinct words
        occur in the cleaned full text.
    """
    issues = []
    for item in data:
        full_text = clean_text(item['full_text'])
        # These depend only on the item, so compute them once per item rather
        # than once per answer / evidence entry.
        full_text_lower = full_text.lower()
        full_text_words = set(full_text_lower.split())

        # Check answers
        for ans in _as_list(item['answer']):
            cleaned_ans = clean_text(ans)
            if cleaned_ans not in full_text and cleaned_ans.lower() not in full_text_lower:
                issues.append(f"Answer '{ans}' not found in full_text for title: {item['title']}")

        # Check evidence by word overlap
        for ev in _as_list(item['evidence']):
            cleaned_ev = clean_text(ev)
            ev_words = set(cleaned_ev.lower().split())
            overlap = len(ev_words & full_text_words) / len(ev_words) if ev_words else 0
            if overlap < 0.8:
                issues.append(f"Evidence '{ev}' not found or insufficient overlap ({overlap:.2f}) in full_text for title: {item['title']}")

    return issues


# Create documents
def create_documents(data):
    """Turn dataset items into chunked `Document` objects for indexing.

    The raw full text is split into sections first and each section is cleaned
    afterwards. `clean_text` collapses every newline and strips anything that
    looks like an HTML tag (which includes the '<SECTION>' marker), so cleaning
    before splitting would leave nothing to split on.

    Sections longer than 400 characters are further chunked with a
    RecursiveCharacterTextSplitter (chunk_size=500, chunk_overlap=200).

    Args:
        data: iterable of dicts with the keys 'full_text', 'title', 'question',
            'answer' and 'evidence'.

    Returns:
        A list of `Document` objects, one per chunk, carrying the item's
        title/question/answer/cleaned evidence plus the section's first line.
    """
    all_docs = []
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=200)
    for item in data:
        raw_sections = split_by_sections(normalize_newlines(item['full_text']))
        # Identical for every chunk of this item, so clean it once.
        cleaned_evidence = [clean_text(ev) for ev in _as_list(item["evidence"])]
        for raw_section in raw_sections:
            section = clean_text(raw_section)
            if not section:
                # The section contained only markup / citation markers.
                continue
            # The first raw line is the heading when the section has one.
            section_title = clean_text(raw_section.split('\n')[0]) or "Unknown"
            chunks = text_splitter.split_text(section) if len(section) > 400 else [section]
            for chunk in chunks:
                doc = Document(
                    page_content=chunk,
                    metadata={
                        "title": item["title"],
                        "question": item["question"],
                        "answer": item["answer"],
                        # Copy so documents do not share one mutable list.
                        "evidence": list(cleaned_evidence),
                        "section": section_title,
                        "dataset": "Airbnb" if "Airbnb" in chunk else "PrivacyQA" if "PrivacyQA" in chunk else "Unknown"
                    }
                )
                all_docs.append(doc)
    return all_docs

# Load the public dataset
with open("datasets/public_dataset.json", "r", encoding="utf-8") as f:
    public_data = json.load(f)

# Validate answers and evidence against each paper's full text
issues = validate_answer(public_data)
if issues:
    print("Found issues:", issues)
else:
    print("All answers and evidence are valid.")

# Build the chunked documents
all_docs = create_documents(public_data)

# Initialise the embedding model (GPU when available)
embedding = HuggingFaceEmbeddings(
    model_name="BAAI/bge-large-en-v1.5",
    model_kwargs={"device": "cuda" if torch.cuda.is_available() else "cpu"}
)

# Build the FAISS vector store and persist it to disk
faiss_db = FAISS.from_documents(all_docs, embedding)
faiss_db.save_local("faiss_index")

# Retriever: MMR search returning 15 chunks out of 40 fetched candidates
retriever = faiss_db.as_retriever(
    search_type="mmr",  #knn
    search_kwargs={"k": 15, "fetch_k": 40}
)

# Prepare the LLM.
# The API key is read from the environment so that no credential is stored in
# the notebook; fail early with a clear message when it is missing.
groq_api_key = os.environ.get("GROQ_API_KEY")
if not groq_api_key:
    raise RuntimeError(
        "GROQ_API_KEY is not set. Export it in the environment that starts the "
        "kernel before running this cell."
    )

llm = ChatGroq(
    model="llama-3.1-8b-instant",
    api_key=groq_api_key,
    max_tokens=1024,
)

# Build the RetrievalQA chain; source documents are returned as evidence
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    retriever=retriever,
    chain_type="stuff",
    chain_type_kwargs={"prompt": cot_prompt},
    return_source_documents=True
)

# Load private_dataset.json
with open("datasets/private_dataset.json", "r", encoding="utf-8") as f:
    private_data = json.load(f)

# 4. Process the questions and generate answers.
# Collects one result dict per processed question.
submission_results = list()

def extract_final_answer(cot_output):
    """Pull the final answer out of the model output.

    If the output contains 'Final Answer:', the text following it (up to the
    end of that line) is returned. Otherwise the last line of the output is
    returned.

    Args:
        cot_output: raw model output; `None` is treated like an empty string.

    Returns:
        The extracted answer, or "I don't know" when the output is empty or
        blank, or when nothing follows 'Final Answer:'.
    """
    fallback = "I don't know"
    text = (cot_output or "").strip()
    if not text:
        # str.split never returns an empty list, so blank output has to be
        # caught explicitly instead of relying on an emptiness check later.
        return fallback

    match = re.search(r"Final Answer:\s*(.*?)(?:\n|$)", text, re.DOTALL)
    if match:
        return match.group(1).strip() or fallback

    # `text` is stripped and non-empty, so its last line is non-empty too.
    return text.split("\n")[-1].strip()


def is_how_many_question(q):
    """Return True when `q`, lower-cased and stripped, begins with 'how many'."""
    prefix = "how many"
    normalized = q.lower().strip()
    return normalized[:len(prefix)] == prefix

# Answer every private question with the RetrievalQA chain.
#
# "How many ..." questions used to receive a fixed "There are N articles in the
# dataset." string regardless of what was asked, while the chain was still
# invoked and its answer discarded. They now go through the same path as every
# other question.
#
# NOTE(review): the index behind `qa_chain` is built from `public_data` only,
# and retrieval is not restricted to the paper named by `item["title"]` —
# confirm that the private papers are expected to be answerable from it.
for item in tqdm(private_data, desc="Processing questions"):
    question = item["question"]

    qa_result = qa_chain.invoke({"query": question})
    cot_output = qa_result.get("answer") or qa_result.get("result") or "I don't know"

    submission_results.append({
        "title": item["title"],
        "answer": extract_final_answer(cot_output),
        # The retrieved chunks are reported as the supporting evidence.
        "evidence": [doc.page_content for doc in qa_result.get("source_documents", [])],
    })

# Save the results
with open("sample_submission_llama.json", "w", encoding="utf-8") as f:
    json.dump(submission_results, f, indent=2, ensure_ascii=False)

print("RAG流程完成，結果已寫入 sample_submission_llama.json")