## Step 0: Load Dependencies

In [None]:
!pip3 install transformers rank_bm25 sentence-transformers faiss-cpu jieba tqdm

In [None]:
# Run the project's setup_models.py script in a subprocess.
# NOTE(review): the script is not visible here; presumably it prepares the
# models that Step 1 also downloads into models/ — confirm both are needed.
!python3 setup_models.py

## Step 1: Initialize Models

If a model has not been downloaded yet, this step downloads it and saves it under the `models/` directory; models that are already present are skipped.

In [None]:
# Step 1: Initialize Models (Colab Local)
from pathlib import Path
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification

def download_model(name, hf_id, is_classifier=True):
    """Fetch a Hugging Face tokenizer and model into ``models/<name>``.

    Nothing is downloaded when ``models/<name>`` already exists and contains
    at least one entry.

    Args:
        name: Sub-directory name under ``models/`` to save into.
        hf_id: Identifier passed to ``from_pretrained``.
        is_classifier: When true the model is loaded with
            ``AutoModelForSequenceClassification``; otherwise with ``AutoModel``.

    Returns:
        None. Progress is reported with ``print``.
    """
    save_dir = Path("models").joinpath(name)

    # Any entry inside the target directory counts as "already downloaded".
    if save_dir.exists() and next(save_dir.iterdir(), None) is not None:
        print(f"[✓] {name} already exists, skipping download.")
        return None

    print(f"↓ Downloading {name} from HuggingFace...")
    tokenizer = AutoTokenizer.from_pretrained(hf_id)
    if is_classifier:
        model = AutoModelForSequenceClassification.from_pretrained(hf_id)
    else:
        model = AutoModel.from_pretrained(hf_id)

    # The directory is only created once both downloads have succeeded.
    save_dir.mkdir(parents=True, exist_ok=True)
    for artifact in (tokenizer, model):
        artifact.save_pretrained(save_dir)
    print(f"[✓] {name} saved to {save_dir}")

# Each entry: (local directory name, Hugging Face id, load as sequence classifier?)
# NOTE(review): "cross_encoder" points at a sentence-transformers embedding
# checkpoint but is loaded as a sequence classifier; the classification head
# would presumably be newly initialised rather than trained — confirm intent.
for model_name, hub_id, as_classifier in (
    ("zhbert", "hfl/chinese-roberta-wwm-ext", True),
    ("labse", "sentence-transformers/LaBSE", False),
    ("cross_encoder", "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", True),
):
    download_model(model_name, hub_id, is_classifier=as_classifier)

## Step 2: Extract Paragraphs from PDF (PyMuPDF + EasyOCR fallback)

In [None]:

import fitz
import easyocr
import json
import os
from pdf2image import convert_from_path

# Module-level EasyOCR reader shared by fallback_ocr_easyocr(), configured for
# the 'ch_tra' and 'en' language codes and forced onto the CPU (gpu=False).
reader = easyocr.Reader(['ch_tra', 'en'], gpu=False)

def extract_blocks_with_heuristics(pdf_path, min_block_length=40):
    """Extract sufficiently long text blocks from a PDF with PyMuPDF.

    Blocks on each page are ordered by their top coordinate (y0). A block is
    kept when its text, stripped and with newlines replaced by spaces, has at
    least ``min_block_length`` characters.

    Args:
        pdf_path: Path of the PDF to read; its base name without extension
            becomes the ``doc_id`` prefix of every passage id.
        min_block_length: Minimum cleaned-text length for a block to be kept.

    Returns:
        A list of dicts with keys ``pid`` (``<doc_id>_p<page>_b<index>``),
        ``page`` (0-based), ``bbox`` (``[x0, y0, x1, y1]``) and ``text``.
    """
    doc_id = os.path.splitext(os.path.basename(pdf_path))[0]
    results = []
    doc = fitz.open(pdf_path)
    try:
        for page_num, page in enumerate(doc):
            blocks = page.get_text("blocks")
            for i, block in enumerate(sorted(blocks, key=lambda b: b[1])):
                x0, y0, x1, y1, text, *extra = block
                # Block tuples end with (block_no, block_type). Type 1 marks an
                # image block whose "text" is only an image description, so it
                # must not become a passage. The index `i` still counts skipped
                # blocks, keeping pids of text blocks unchanged.
                if len(extra) >= 2 and extra[1] != 0:
                    continue
                clean_text = text.strip().replace("\n", " ")
                if len(clean_text) >= min_block_length:
                    results.append({
                        "pid": f"{doc_id}_p{page_num}_b{i}",
                        "page": page_num,
                        "bbox": [x0, y0, x1, y1],
                        "text": clean_text
                    })
    finally:
        # Release the file handle even if a page fails to parse.
        doc.close()
    return results

def fallback_ocr_easyocr(pdf_path):
    """OCR every page of a PDF and return one passage per non-empty page.

    Pages are rendered at 300 dpi with ``convert_from_path`` and read with the
    module-level EasyOCR ``reader``. The non-blank text fragments of a page
    are joined with single spaces.

    Args:
        pdf_path: Path of the PDF to OCR.

    Returns:
        A list of dicts with keys ``pid`` (``<doc_id>_ocr_<page>``), ``page``
        (0-based), ``bbox`` (always ``None``) and ``text``. Pages without any
        recognised text are omitted.
    """
    page_images = convert_from_path(pdf_path, dpi=300)
    doc_id, _extension = os.path.splitext(os.path.basename(pdf_path))

    passages = []
    for page_num, page_image in enumerate(page_images):
        fragments = []
        for detection in reader.readtext(page_image):
            # Element 1 of each detection is used as the recognised text.
            fragment = detection[1]
            if len(fragment.strip()) > 0:
                fragments.append(fragment)

        page_text = " ".join(fragments).strip()
        if not page_text:
            continue
        passages.append({
            "pid": f"{doc_id}_ocr_{page_num}",
            "page": page_num,
            "bbox": None,
            "text": page_text
        })
    return passages

def process_pdf_file(pdf_path):
    """Extract passages from a PDF, falling back to OCR when needed.

    Text-layer extraction via ``extract_blocks_with_heuristics`` is tried
    first. OCR via ``fallback_ocr_easyocr`` is used when that extraction
    raises, returns nothing, or returns no segment with at least 40
    characters of text.

    Args:
        pdf_path: Path of the PDF to process.

    Returns:
        The list of passage dicts produced by whichever extractor was used.
    """
    try:
        segments = extract_blocks_with_heuristics(pdf_path)
        has_usable_text = bool(segments) and any(len(seg['text']) >= 40 for seg in segments)
        reason = "poor extraction"
    except Exception as exc:
        # `except Exception` (not a bare except) keeps extraction errors as a
        # fallback trigger while letting KeyboardInterrupt/SystemExit through.
        has_usable_text = False
        reason = f"{type(exc).__name__}: {exc}"

    if has_usable_text:
        return segments

    # Report why OCR is used so extraction problems are not hidden.
    print(f"Falling back to OCR for {pdf_path} ({reason}).")
    return fallback_ocr_easyocr(pdf_path)

pdf_path = "example.pdf"
# NOTE(review): later cells read from "outputs/" and "data/" without the
# "clir_pipeline/" prefix — confirm the intended working directory.
output_path = "clir_pipeline/outputs/structured_passages.jsonl"

results = process_pdf_file(pdf_path)

# open(..., "w") does not create missing parent directories, so make sure the
# output folder exists before writing.
os.makedirs(os.path.dirname(output_path), exist_ok=True)

# One JSON object per line (JSON Lines); ensure_ascii=False keeps CJK readable.
with open(output_path, "w", encoding="utf-8") as f:
    for r in results:
        json.dump(r, f, ensure_ascii=False)
        f.write("\n")
print("✅ Done extracting passages.")


## Step 3: GPT Translate

In [None]:
# 🌐 Step 3: GPT-based Batch Translation with Safety Check
import openai
import json
from tqdm import tqdm

import os

# Prefer the OPENAI_API_KEY environment variable so a real key never has to be
# typed into (and saved with) the notebook; the placeholder is only a fallback.
openai.api_key = os.environ.get("OPENAI_API_KEY", "your-api-key-here")

# Load the query records to translate.
# NOTE: adjust this path to wherever the query file lives in your environment.
with open("/content/questions_translated_en_fixed_q1.json", "r", encoding="utf-8") as f:
    full_queries = json.load(f)

# GPT translation helper (placeholder queries are skipped)
def translate_with_gpt(query_en, model="gpt-3.5-turbo"):
    """Translate an English search query into Traditional Chinese via OpenAI chat.

    Args:
        query_en: English query text. A query containing the marker
            "EN Translation of:" is treated as an unprocessed placeholder and
            returned unchanged without calling the API.
        model: Chat model name passed to the API.

    Returns:
        The stripped translation, or "" when the request fails or the reply
        has no content. Errors are printed instead of raised so that a batch
        run can continue.
    """
    if "EN Translation of:" in query_en:
        return query_en  # placeholder / not yet processed
    try:
        messages = [
            {"role": "system", "content": "You are a professional translator who translates English financial search queries into Traditional Chinese."},
            {"role": "user", "content": f"Translate this search query into Traditional Chinese: '{query_en}'"}
        ]
        # openai>=1.0 removed openai.ChatCompletion and exposes the same call
        # as openai.chat.completions.create; older releases only have the
        # legacy name. Support both so the installed version does not silently
        # turn every translation into "".
        chat_api = getattr(openai, "chat", None)
        if chat_api is not None:
            create = chat_api.completions.create
        else:
            create = openai.ChatCompletion.create
        response = create(
            model=model,
            messages=messages,
            temperature=0,
        )
        content = response.choices[0].message.content
        # The content field may be None; treat that as an empty translation.
        return (content or "").strip()
    except Exception as e:
        print(f"Error translating: {query_en} -> {e}")
        return ""

# Translate every query; the result is stored on each record as "query_zh_gpt".
translated_output = []
failed_queries = []
for item in tqdm(full_queries):
    zh = translate_with_gpt(item["query_en"])
    item["query_zh_gpt"] = zh
    translated_output.append(item)
    # translate_with_gpt returns "" when the API call fails; remember those so
    # the blank entries in the output file do not go unnoticed.
    if not zh:
        failed_queries.append(item["query_en"])

# Write the translated records.
with open("/content/translated_queries_gpt.json", "w", encoding="utf-8") as f:
    json.dump(translated_output, f, ensure_ascii=False, indent=2)

print("✅ GPT 翻譯完成，共處理 %d 筆查詢。" % len(translated_output))
if failed_queries:
    print("⚠️ %d queries have an empty translation; re-run them before using the output." % len(failed_queries))

## Step 4: Run Retrieval (4 Models with Runtime Logging)

In [None]:

# Run the retrieval stage through the project module `run_all_retrievals`.
# NOTE(review): its inputs and outputs are not visible here; later cells read
# "outputs/retrieval_rankings.json", presumably written by this call — confirm.
from run_all_retrievals import run_all_retrievals
run_all_retrievals()


## Step 5: Evaluate Retrieval Results (Top-K)

In [None]:

from evaluation_summary import evaluate_all_models
import pandas as pd

# Cut-offs at which the rankings are evaluated.
K_values = [10, 100]
csv_rows = []

for k in K_values:
    # One evaluation per cut-off; each call is also handed its own CSV path.
    per_k_summary = evaluate_all_models(
        k=k,
        output_csv_path=f"outputs/evaluation_summary_k{k}.csv",
        ground_truth_path="data/ground_truths_example.json",
        ranking_path="outputs/retrieval_rankings.json",
    )
    # Tag the rows with their cut-off so the combined table stays readable.
    per_k_summary["TopK"] = k
    csv_rows.append(per_k_summary)

# Stack the per-cut-off tables, save them, and display the result.
final_df = pd.concat(csv_rows)
final_df.to_csv("outputs/evaluation_summary_all.csv", index=False)
final_df


## Step 6: Translation Error Impact Analysis

In [None]:

from translate_error_analysis import extract_translation_impact

# NOTE(review): Step 3 writes "/content/translated_queries_gpt.json", while this
# cell reads "data/translated_query.json" — confirm which file is intended.
impact = extract_translation_impact(
    ground_truth_path="data/ground_truths_example.json",
    predictions_path="outputs/retrieval_rankings.json",
    queries_path="data/translated_query.json",
)

# Print a header per category followed by at most one example from it.
for category in impact:
    group = impact[category]
    print(f"\n== {category.upper()} ({len(group)} samples) ==")
    first_sample_only = group[:1]
    for qid, en, zh, pred, gt in first_sample_only:
        print(f"QID: {qid}\nEN: {en}\nZH: {zh}\nPRED: {pred}\nGT: {gt}\n---")
