# 本程式建議使用 colab 執行

# 環境設定

In [1]:
## 若執行環境為 google colab pro+請執行以下程序

# from google.colab import drive
# drive.mount('/content/drive')

## terminal 指令集，以下程序建議打開 colab pro+ 的 terminal 執行
# !curl -fsSL https://ollama.com/install.sh | sh
# !ollama serve &
# !ollama pull llama3.1:8b
# !pip install -U langchain langchain-community langchain-openai langchain-ollama tiktoken ragas sacrebleu faiss-cpu

In [None]:
import os
import pandas as pd
import numpy as np
import json
import pickle
import hashlib
import time
import logging
from datetime import datetime
from tqdm import tqdm
from typing import Dict, List, Any, Optional, Union

from langchain_community.vectorstores import FAISS
from langchain_ollama import OllamaEmbeddings, ChatOllama
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser

In [None]:
try:
    from langchain_openai import ChatOpenAI, OpenAIEmbeddings
    OPENAI_AVAILABLE = True
except ImportError:
    OPENAI_AVAILABLE = False

In [None]:
## Embedding model and LLM configuration
_OLLAMA_MODEL_NAME = "llama3.1:8b"

# Embedding model; served by the local Ollama instance by default.
embedding_model_init = OllamaEmbeddings(model=_OLLAMA_MODEL_NAME)

# Chat LLM used for answering (and for merging partial answers when batching).
# NOTE(review): ChatOllama documents `num_predict` as its output-length option;
# confirm that `max_tokens` is actually applied here.
llm = ChatOllama(
    model=_OLLAMA_MODEL_NAME,
    temperature=0.6,
    max_tokens=4096,
)

# To use the OpenAI API instead, uncomment the two lines below. Read the key from
# the environment rather than pasting it into the notebook:
# api_key = os.environ["OPENAI_API_KEY"]
# llm = ChatOpenAI(model="gpt-4o", temperature=1, openai_api_key=api_key, max_tokens = 4096)

In [None]:
# Vector-store root folder: one sub-folder per PDF, listed by get_pdf_list and
# loaded with FAISS.load_local in rag_qa_pipeline.
# NOTE: absolute Google Drive path; adjust for your own environment.
VECTORSTORE_ROOT_PATH = '/content/drive/MyDrive/03_碩班研究/llm_risk/data_0511/2_pdf_extract_db/ollama_300_char/vector_store'

In [None]:
# Output folder for this run's result files (created if missing).
# Every file written later is tagged with the run timestamp below.
output_dir = '/content/drive/MyDrive/03_碩班研究/llm_risk/data_0511/2_pdf_extract_db/0718_Report/chatGPT'
os.makedirs(output_dir, exist_ok=True)

timestamp = datetime.now().strftime("%Y%m%d_%H%M")


# 任務文字

In [None]:
# Query text passed only to the retriever, never to the LLM. It is used when
# rag_qa_pipeline is called with query=... and QUERY_OPT=True. The main loop
# below calls rag_qa_pipeline without a query, so these are not used there.
query_task1 = """
Extract geographic information and location entities from text. Identify locations, company facilities, and addresses in text. Prompt for a text-based geographic information extraction system
"""

query_task2 = """
Extract extreme weather events and TCFD climate risks from text. Identify climate-related risks like floods, heatwaves, and TCFD disclosures. Tool for parsing corporate disclosures to find mentions of extreme weather.
"""

In [None]:
# Prompts sent to the LLM as the "question" slot of the RAG template.
# Task 1 (geographic entities) asks for an "entities" list; Task 2 (extreme
# weather / TCFD) asks for an "events" list. decode_response reads "detection"
# plus that list, and matches each item to a retrieved chunk by its "page" (or by
# its "text" when no page is given).
TASK1_ZERO_SHOT = """
You are a geographic information extraction system. Your task is to identify and extract geographic entities from the given text.

Extract the following types of geographic information:
- Countries, cities, provinces, regions
- Company facilities (plants, factories, offices, sites)
- Addresses and specific locations

Return your answer in the following JSON format ONLY:
{
  "detection": 1 or 0,
  "entities": [
    {
      "text": "extracted entity text",
      "type": "location/facility/address",
      "page": page_number
    }
  ]
}

Rules:
- detection: 1 if ANY geographic entities found, 0 if NONE found
- Only return the JSON format above, no other text
- If no entities found, return: {"detection": 0, "entities": []}
"""

# Same instructions as TASK1_ZERO_SHOT plus one worked input/output example.
TASK1_ONE_SHOT = """
You are a geographic information extraction system. Your task is to identify and extract geographic entities from the given text.

Extract the following types of geographic information:
- Countries, cities, provinces, regions
- Company facilities (plants, factories, offices, sites)
- Addresses and specific locations

Example:
Input: "Our headquarters is located in Taipei, Taiwan. We operate manufacturing plants in Suzhou, China and have a regional office in Singapore."
Output: {
  "detection": 1,
  "entities": [
    {"text": "Taipei", "type": "location", "page": 1},
    {"text": "Taiwan", "type": "location", "page": 1},
    {"text": "Suzhou", "type": "location", "page": 1},
    {"text": "China", "type": "location", "page": 1},
    {"text": "Singapore", "type": "location", "page": 1},
    {"text": "headquarters", "type": "facility", "page": 1},
    {"text": "manufacturing plants", "type": "facility", "page": 1},
    {"text": "regional office", "type": "facility", "page": 1}
  ]
}

Return your answer in the following JSON format ONLY:
{
  "detection": 1 or 0,
  "entities": [
    {
      "text": "extracted entity text",
      "type": "location/facility/address",
      "page": page_number
    }
  ]
}

Rules:
- detection: 1 if ANY geographic entities found, 0 if NONE found
- Only return the JSON format above, no other text
- If no entities found, return: {"detection": 0, "entities": []}
"""

TASK2_ZERO_SHOT = """
You are a climate risk disclosure analyzer. Your task is to identify extreme weather events and TCFD-related climate risk information.

Identify the following extreme weather events:
- Heat waves, cold waves, typhoons, floods, droughts, landslides, sea level rise

Also identify TCFD disclosure elements:
- Risk identification, risk assessment, financial impact, response strategies

Return your answer in the following JSON format ONLY:
{
  "detection": 1 or 0,
  "events": [
    {
      "text": "extracted event/disclosure text",
      "type": "weather_event/tcfd_element",
      "category": "specific category",
      "page": page_number
    }
  ]
}

Rules:
- detection: 1 if ANY extreme weather events or TCFD elements found, 0 if NONE found
- Only return the JSON format above, no other text
- If nothing found, return: {"detection": 0, "events": []}
"""

# Same instructions as TASK2_ZERO_SHOT plus one worked input/output example.
TASK2_ONE_SHOT = """
You are a climate risk disclosure analyzer. Your task is to identify extreme weather events and TCFD-related climate risk information.

Identify the following extreme weather events:
- Heat waves, cold waves, typhoons, floods, droughts, landslides, sea level rise

Also identify TCFD disclosure elements:
- Risk identification, risk assessment, financial impact, response strategies

Example:
Input: "Climate change poses significant risks to our operations. Extreme weather events such as typhoons and floods may disrupt our supply chain. We assess physical risks including heat waves that could affect our facilities."
Output: {
  "detection": 1,
  "events": [
    {"text": "typhoons", "type": "weather_event", "category": "typhoon", "page": 1},
    {"text": "floods", "type": "weather_event", "category": "flood", "page": 1},
    {"text": "heat waves", "type": "weather_event", "category": "heat_wave", "page": 1},
    {"text": "assess physical risks", "type": "tcfd_element", "category": "risk_assessment", "page": 1},
    {"text": "risks to our operations", "type": "tcfd_element", "category": "risk_identification", "page": 1}
  ]
}

Return your answer in the following JSON format ONLY:
{
  "detection": 1 or 0,
  "events": [
    {
      "text": "extracted event/disclosure text",
      "type": "weather_event/tcfd_element",
      "category": "specific category",
      "page": page_number
    }
  ]
}

Rules:
- detection: 1 if ANY extreme weather events or TCFD elements found, 0 if NONE found
- Only return the JSON format above, no other text
- If nothing found, return: {"detection": 0, "events": []}
"""

# 功能函數

## 取得檔案、資料排序

In [None]:
def natural_sort(s):
    """Sort key that orders embedded ASCII digit runs numerically ("doc2" < "doc10").

    Text parts are lower-cased, so the ordering is case-insensitive.
    """
    import re
    parts = re.split('([0-9]+)', s)
    # re.split with one capture group alternates text / captured digit-run, so the
    # odd indices are exactly the [0-9]+ runs. Deciding by position (not by
    # str.isdigit) avoids int() failing on text such as "²" that isdigit() accepts.
    return [int(part) if i % 2 else part.lower() for i, part in enumerate(parts)]


def get_pdf_list(vectorstore_path):
    """List the per-PDF vector-store folders under `vectorstore_path`, naturally sorted.

    Hidden entries (e.g. '.DS_Store', '.ipynb_checkpoints') and plain files are
    skipped, because every returned name is later opened as a FAISS index folder.

    Returns:
        list of folder names (not full paths).
    """
    folders = (
        name for name in os.listdir(vectorstore_path)
        if not name.startswith('.') and os.path.isdir(os.path.join(vectorstore_path, name))
    )
    return sorted(folders, key=natural_sort)

## RAG Pipeline

In [None]:
import json
from operator import itemgetter
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser

def _split_docs_into_batches(context_docs, question_text, prompt, llm, limited_token, margin_tokens=100):
  """Greedily pack retrieved docs into batches whose rendered prompt fits in `limited_token`.

  Each doc is costed on repr(doc). The prompt template renders the context list
  with str(list), which calls repr() on every element, so whatever repr()
  contains (for LangChain Documents that includes the metadata) is sent to the
  LLM and must be counted. `margin_tokens` leaves headroom for list brackets and
  separators. A single doc larger than the budget still gets a batch of its own.

  Returns:
    list of non-empty lists of docs, in retrieval order.
  """
  base_tokens = llm.get_num_tokens(prompt.format(context="", question=question_text))
  available_tokens = limited_token - base_tokens - margin_tokens

  batches = []
  current_batch = []
  current_batch_tokens = 0
  for doc in context_docs:
    doc_tokens = llm.get_num_tokens(repr(doc))
    # Close the current batch before this doc would push it over budget.
    if current_batch and current_batch_tokens + doc_tokens > available_tokens:
      batches.append(current_batch)
      current_batch = []
      current_batch_tokens = 0
    current_batch.append(doc)
    current_batch_tokens += doc_tokens
  if current_batch:
    batches.append(current_batch)
  return batches


def rag_qa_pipeline(pdfname, question, llm, limited_token=4096, query=None, QUERY_OPT=False, vectorstore_path = VECTORSTORE_ROOT_PATH, embedding_model = embedding_model_init, search_kwargs_init = 20):
  """Run retrieval-augmented QA for one PDF's FAISS index.

  Research design 1 (QUERY_OPT=False): `question` (the full task prompt) is the
  retriever query and also the LLM question.
  Research design 2 (QUERY_OPT=True): `query` is sent to the retriever only and
  `question` goes to the LLM. If `query` is None, `question` is used for retrieval.

  Retrieval runs once. If the rendered prompt exceeds `limited_token`, the
  retrieved docs are split into batches that each fit and each batch is answered
  separately. When there is more than one batch, the partial answers are merged
  with a second LLM call.

  Args:
    pdfname: sub-folder of `vectorstore_path` holding this PDF's FAISS index.
    question: text placed in the template's {question} slot.
    llm: LangChain chat model (created outside this function).
    limited_token: token budget for one prompt, measured with llm.get_num_tokens.
    query: retriever-only query text (design 2).
    QUERY_OPT: select design 2 when truthy.
    vectorstore_path: root folder of the per-PDF indexes.
    embedding_model: embeddings used to load the index (created outside).
    search_kwargs_init: number of chunks `k` returned by the MMR retriever.

  Returns:
    dict with "context" (list of retrieved docs), "question" and "answer"
    (raw LLM output string).
  """
  Sys_Template = """
  You are a professional information extraction assistant.
  Please analyze the provided context and answer the question based ONLY on the given context.
  Context:
  {context}

  Question:
  {question}

  Answer (in English, JSON format only)
  """

  # Template for merging the partial answers produced by batched calls.
  Merge_Template = """
  You are a professional information synthesis assistant.
  Please synthesize the following partial answers into one comprehensive, coherent answer.

  Partial Answers:
  {partial_answers}

  Original Question:
  {question}

  Final Answer (in English, JSON format only):
  """

  # Load the vector store. allow_dangerous_deserialization unpickles the index
  # files, so only point this at indexes you built yourself.
  vectorstore = FAISS.load_local(
      folder_path = os.path.join(vectorstore_path, pdfname),
      embeddings = embedding_model,
      allow_dangerous_deserialization=True
  )
  retriever = vectorstore.as_retriever(search_type="mmr", search_kwargs={"k": search_kwargs_init})

  prompt = ChatPromptTemplate.from_template(Sys_Template)
  answer_chain = prompt | llm | StrOutputParser()

  if QUERY_OPT:
    print(f'{pdfname} 執行加入額外索引問答的方式')
    retriever_query = question if query is None else query
  else:
    retriever_query = question
  context_docs = retriever.invoke(retriever_query)

  # Token length of the full prompt (context included) as the LLM would receive it.
  prompt_tokens = llm.get_num_tokens(prompt.format(context=context_docs, question=question))

  if prompt_tokens <= limited_token:
    print(f"Token 數量 ({prompt_tokens}) 在限制內，使用正常流程")
    # Reuse the docs retrieved above rather than running retrieval a second time.
    answer = answer_chain.invoke({"context": context_docs, "question": question})
  else:
    print(f"Token 數量 ({prompt_tokens}) 超出限制 ({limited_token})，開始分批處理")
    batches = _split_docs_into_batches(context_docs, question, prompt, llm, limited_token) or [context_docs]
    print(f"將 context 分為 {len(batches)} 批次處理")

    partial_answers = [
        answer_chain.invoke({"context": batch, "question": question})
        for batch in batches
    ]

    if len(partial_answers) == 1:
      # Nothing to merge; an extra LLM pass could only distort the JSON.
      answer = partial_answers[0]
    else:
      print("合併所有部分答案")
      merge_chain = ChatPromptTemplate.from_template(Merge_Template) | llm | StrOutputParser()
      answer = merge_chain.invoke({
          "partial_answers": "\n\n".join(f"Answer {i+1}: {ans}" for i, ans in enumerate(partial_answers)),
          "question": question,
      })

  print(f'{pdfname}問答完成')

  return {"context": context_docs, "question": question, "answer": answer}

## 解析 RAG 回傳結果（Response）

In [None]:
# Parse the JSON object out of an LLM reply.
def safe_json_parse(text):
    """Extract and parse the JSON object embedded in an LLM reply.

    Markdown fences (``` / ```json) and any prose around the object are ignored:
    everything from the first '{' to the last '}' is parsed.

    Returns:
        The parsed dict. Returns the fallback
        {"detection": 0, "entities": [], "events": []} when no object can be
        parsed (non-string input, no braces, invalid or too deeply nested JSON).
        Unlike a bare `except:`, interrupts such as KeyboardInterrupt propagate.
    """
    fallback = {"detection": 0, "entities": [], "events": []}
    if not isinstance(text, str):
        return fallback

    start_idx = text.find('{')
    end_idx = text.rfind('}') + 1
    if start_idx == -1 or end_idx <= start_idx:
        return fallback

    try:
        parsed = json.loads(text[start_idx:end_idx])
    except (ValueError, RecursionError):  # json.JSONDecodeError subclasses ValueError
        return fallback
    return parsed if isinstance(parsed, dict) else fallback

In [None]:
# Confusion-matrix label for one scored chunk.
def calculate_performance_metrics(ground_truth, predicted):
    """Classify one (ground truth, prediction) pair as a confusion-matrix cell.

    Only a value equal to 1 counts as positive. Anything else (0, None, '1', ...)
    counts as negative, on either side.

    Returns:
        'TP', 'FP', 'FN' or 'TN'.
    """
    actual_positive = ground_truth == 1
    predicted_positive = predicted == 1

    if predicted_positive:
        return 'TP' if actual_positive else 'FP'
    return 'FN' if actual_positive else 'TN'

In [None]:
# 解析 rag_qa_pipeline 回傳內容的執行程式
# 會完整的匯出所有 retriever 提供的 context ，並且辨別 LLM 回傳的資料哪一筆其實是幻覺
import re
from typing import List, Dict, Tuple

_FUZZY_WINDOW = 20  # characters per sliding window in _fuzzy_text_match


def _normalize_text(text: str) -> str:
    """Collapse whitespace runs to single spaces and lowercase, for lenient matching."""
    return re.sub(r'\s+', ' ', text.strip()).lower()


def _fuzzy_text_match(llm_text: str, original_text: str, window: int = _FUZZY_WINDOW) -> bool:
    """True if the (normalized) LLM text and context text overlap.

    Overlap means one string contains the other, or some `window`-character slice
    of the LLM text occurs in the context. Empty strings never match, because ''
    is a substring of everything.
    """
    llm_norm = _normalize_text(llm_text)
    orig_norm = _normalize_text(original_text)
    if not llm_norm or not orig_norm:
        return False
    if llm_norm in orig_norm or orig_norm in llm_norm:
        return True
    # "+ 1" so the last window (ending at the final character) is also checked.
    return any(llm_norm[i:i + window] in orig_norm
               for i in range(len(llm_norm) - window + 1))


def _page_key(page) -> str:
    """Canonical string for a page number, so 5, 5.0, "5" and "05" compare equal."""
    if isinstance(page, float) and page.is_integer():
        return str(int(page))
    text = str(page).strip()
    if text.isascii() and text.isdigit():
        return str(int(text))
    return text


def _item_page(item):
    """Page reported by an LLM item; missing, null or blank pages become 'unknown'."""
    page = item.get('page', 'unknown')
    if page is None or (isinstance(page, str) and not page.strip()):
        return 'unknown'
    return page


def _item_text(item) -> str:
    """Text reported by an LLM item, as a string ('' when missing or null)."""
    text = item.get('text', '')
    if text is None:
        return ''
    return text if isinstance(text, str) else str(text)


def _score_or_unknown(ground_truth, predicted):
    """Confusion-matrix label, or 'Unknown' when the chunk has no ground truth."""
    if ground_truth is None:
        return 'Unknown'
    return calculate_performance_metrics(ground_truth, predicted)


def decode_response(response, task_name):
    """Turn one rag_qa_pipeline response into evaluation rows and retrieval records.

    For every retrieved chunk, one evaluation row is produced. The first LLM item
    attributed to that chunk is used, either because its "page" equals the
    chunk's page or, when the item has no page, because its text overlaps the
    chunk text. The row is scored against the chunk's ground-truth labels. Every
    LLM item whose page is not among the retrieved pages additionally gets its
    own row, scored as a false positive (hallucinated page).

    Args:
        response: dict with "context" (retrieved docs exposing .metadata and
            .page_content) and "answer" (raw LLM string).
        task_name: e.g. 'Task1_ZeroShot'. A 'Task1' prefix selects the
            "entities" list and geog/org labels; otherwise "events" / event_ans.

    Returns:
        (doc_metadata, retrieval_records): two lists of dicts. Both are empty
        when nothing was retrieved.
    """
    contexts = response['context']
    llm_answer = response['answer']
    if not contexts:
        # No chunks: no page to score against and no PDF metadata to label rows with.
        return [], []

    parsed_result = safe_json_parse(llm_answer)
    first_meta = contexts[0].metadata
    pdf_name = f"{first_meta['company']}_{first_meta['year']}"
    query_id = f"{task_name}_{pdf_name}"

    # One record per retrieved chunk, so every context the retriever supplied is exported.
    retrieval_records = []
    for doc_idx, doc in enumerate(contexts):
        meta = doc.metadata
        content_hash = hashlib.md5(doc.page_content.encode()).hexdigest()[:8]
        company = meta.get('company', 'unknown')
        year = meta.get('year', 'unknown')
        page = meta.get('page', 'unknown')
        retrieval_records.append({
            'query_id': query_id,
            'doc_index': doc_idx,
            'doc_id': f"{company}_{year}_{page}_{content_hash}",
            'content_hash': content_hash,
            'pdf_name': pdf_name,
            'company': company,
            'year': year,
            'page': page,
            'standard': meta.get('standard', 'unknown'),
            'label': meta.get('label', 'unknown'),
            'geog_ans_ground_truth': meta.get('geog_ans', None),
            'org_ans_ground_truth': meta.get('org_ans', None),
            'events_ans_ground_truth': meta.get('event_ans', None),
            'content': doc.page_content,
            'content_length': len(doc.page_content)
        })

    is_task1 = task_name.split('_')[0] == 'Task1'
    entity_key = 'entities' if is_task1 else 'events'
    detection = parsed_result.get('detection', 0)
    raw_items = parsed_result.get(entity_key) or []
    # Ignore anything the LLM returned that is not an item object.
    items = [item for item in raw_items if isinstance(item, dict)] if isinstance(raw_items, list) else []

    retrieved_pages = {_page_key(doc.metadata['page']) for doc in contexts}

    doc_metadata = []
    for context in contexts:
        metadata = context.metadata

        if is_task1:
            ground_truth = 1 if (metadata['geog_ans'] == 1 or metadata['org_ans'] == 1) else 0
        else:
            ground_truth = metadata['event_ans']

        doc_info = {
            'pdf_name': pdf_name,
            'company': metadata['company'],
            'year': metadata['year'],
            'standard': metadata['standard'],
            'page': metadata['page'],
            'label': metadata['label'],
            "llm_page": 0,
            'llm_ans': 0,
            'llm_detection': detection,
            'ground_truth': ground_truth,
            'performance': None,
            "results": '',
            'type': "",
            'num_retrieved_docs': len(contexts),
            'raw_llm_response': llm_answer,
            'task_type': task_name
        }

        # First item attributed to this chunk: same page, or (page unknown) overlapping text.
        chunk_page = _page_key(metadata['page'])
        matched_item = None
        for item in items:
            llm_page = _item_page(item)
            if llm_page == 'unknown':
                if _fuzzy_text_match(_item_text(item), context.page_content):
                    matched_item = item
                    break
            elif _page_key(llm_page) == chunk_page:
                matched_item = item
                break

        if matched_item is None:
            doc_info['performance'] = _score_or_unknown(ground_truth, 0)
        else:
            text = _item_text(matched_item)
            predicted = 1 if text.strip() else 0
            doc_info.update({
                'llm_page': _item_page(matched_item),
                'llm_ans': predicted,
                'performance': _score_or_unknown(ground_truth, predicted),
                'results': text,
                'type': matched_item.get('type', '')
            })

        doc_metadata.append(doc_info)

    # Items citing a page that was never retrieved get one extra row each (FP).
    for item in items:
        llm_page = _item_page(item)
        if llm_page == 'unknown' or _page_key(llm_page) in retrieved_pages:
            continue
        print('處理幻覺項目')
        text = _item_text(item)
        doc_metadata.append({
            'pdf_name': pdf_name,
            'company': first_meta['company'],
            'year': first_meta['year'],
            'standard': "",
            'page': "",
            'label': "",
            "llm_page": llm_page,
            'llm_ans': 1 if text.strip() else 0,
            'llm_detection': detection,
            'ground_truth': 0,
            'performance': calculate_performance_metrics(0, 1),
            "results": text,
            'type': item.get('type', ''),
            'num_retrieved_docs': len(contexts),
            'raw_llm_response': llm_answer,
            'task_type': task_name
        })

    return doc_metadata, retrieval_records

In [None]:
# Build lookup indices over all retrieval records of this run.
def create_lookup_index(retrieval_records):
    """Group retrieval records into dictionaries for quick lookup.

    Args:
        retrieval_records: list of record dicts (keys include query_id, doc_id,
            pdf_name, company, year, geog_ans_ground_truth, events_ans_ground_truth).

    Returns:
        (lookup_indices, retrieval_df). lookup_indices maps an index name to
        {group key: [record dict, ...]}. Single-column indices use scalar keys and
        multi-column indices use tuple keys. retrieval_df is the records as a
        DataFrame. Empty input yields empty indices instead of a KeyError.
    """
    retrieval_df = pd.DataFrame(retrieval_records)

    index_names = ['by_query_id', 'by_doc_id', 'by_pdf_name', 'by_company_year', 'by_performance']
    if retrieval_df.empty:
        return {name: {} for name in index_names}, retrieval_df

    def records_by(keys, dropna=True):
        # Iterating the groupby yields whole groups (grouping columns included).
        # This avoids the pandas >= 2.2 deprecation of DataFrameGroupBy.apply
        # operating on the grouping columns.
        return {
            key: group.to_dict('records')
            for key, group in retrieval_df.groupby(keys, dropna=dropna)
        }

    lookup_indices = {
        'by_query_id': records_by('query_id'),
        'by_doc_id': records_by('doc_id'),
        'by_pdf_name': records_by('pdf_name'),
        'by_company_year': records_by(['company', 'year']),
        # Ground-truth labels may be missing (None/NaN). Keep those rows in their own
        # groups instead of silently dropping them.
        'by_performance': records_by(['geog_ans_ground_truth', 'events_ans_ground_truth'], dropna=False),
    }

    return lookup_indices, retrieval_df

In [None]:
# Turn the per-row performance labels into a confusion matrix, metrics and breakdowns.
def analyze_performance_detailed(results_df, retrieval_df, task_name):
    """Summarize the 'performance' column of one task's result rows.

    `retrieval_df` and `task_name` are accepted for interface compatibility but
    are not used.

    Returns:
        dict with (when results_df has a 'performance' column):
          'confusion_matrix': TP/FP/FN/TN counts,
          'metrics': precision, recall, f1_score, accuracy (0 when undefined);
            accuracy divides by the total row count,
          'by_company' / 'by_standard': label counts per group, present when
            the corresponding column exists.
        Returns an empty dict when there is no 'performance' column.
    """
    report = {}
    columns = results_df.columns
    if 'performance' not in columns:
        return report

    counts = results_df['performance'].value_counts()
    cells = {label: counts.get(label, 0) for label in ('TP', 'FP', 'FN', 'TN')}
    report['confusion_matrix'] = dict(cells)

    tp, fp, fn, tn = cells['TP'], cells['FP'], cells['FN'], cells['TN']
    n_rows = len(results_df)

    predicted_positive = tp + fp
    actual_positive = tp + fn
    precision = tp / predicted_positive if predicted_positive > 0 else 0
    recall = tp / actual_positive if actual_positive > 0 else 0
    pr_sum = precision + recall
    f1 = 2 * (precision * recall) / pr_sum if pr_sum > 0 else 0
    accuracy = (tp + tn) / n_rows if n_rows > 0 else 0

    report['metrics'] = {
        'precision': precision,
        'recall': recall,
        'f1_score': f1,
        'accuracy': accuracy
    }

    for column, key in (('company', 'by_company'), ('standard', 'by_standard')):
        if column in columns:
            breakdown = results_df.groupby(column)['performance'].value_counts().unstack(fill_value=0)
            report[key] = breakdown.to_dict()

    return report

In [None]:
# =================== LOOKUP AND QUERY FUNCTIONS ===================
def lookup_retrieval_by_query_id(lookup_indices, query_id):
    """Return the retrieval records indexed under `query_id`, or [] if absent."""
    by_query = lookup_indices['by_query_id']
    return by_query[query_id] if query_id in by_query else []

def lookup_retrieval_by_pdf_name(lookup_indices, pdf_name):
    """Return every retrieval record for one PDF (e.g. 'Company_2020'), or [] if absent."""
    by_pdf = lookup_indices['by_pdf_name']
    return by_pdf[pdf_name] if pdf_name in by_pdf else []

def lookup_retrieval_by_doc_id(lookup_indices, doc_id):
    """Return the retrieval records for one document id, or [] if absent."""
    by_doc = lookup_indices['by_doc_id']
    return by_doc[doc_id] if doc_id in by_doc else []

def find_documents_with_ground_truth(retrieval_df, geog_ans=None, events_ans=None):
    """Filter retrieval rows by ground-truth label; a filter set to None is not applied."""
    filters = (
        ('geog_ans_ground_truth', geog_ans),
        ('events_ans_ground_truth', events_ans),
    )
    subset = retrieval_df
    for column, wanted in filters:
        if wanted is None:
            continue
        subset = subset[subset[column] == wanted]
    return subset

def cross_reference_results(results_df, retrieval_df, query_id):
    """Pair the LLM result rows and the retrieved chunks that belong to one query.

    `query_id` has the form f"{task_name}_{pdf_name}" (as built in decode_response).
    The result rows produced by decode_response carry no 'query_id' column, so
    when it is absent the id is rebuilt from their 'task_type' and 'pdf_name'.

    Returns:
        dict with 'llm_results' and 'retrieved_documents' (lists of row dicts)
        and a 'summary' with both counts and the ground-truth label distribution
        of the retrieved chunks (missing labels kept as their own group).

    Raises:
        KeyError: if results_df has neither 'query_id' nor both 'task_type' and 'pdf_name'.
    """
    result_columns = set(results_df.columns)
    if 'query_id' in result_columns:
        result_mask = results_df['query_id'] == query_id
    elif {'task_type', 'pdf_name'} <= result_columns:
        rebuilt_ids = results_df['task_type'].astype(str) + '_' + results_df['pdf_name'].astype(str)
        result_mask = rebuilt_ids == query_id
    else:
        raise KeyError("results_df needs a 'query_id' column or both 'task_type' and 'pdf_name' columns")

    result_records = results_df[result_mask]
    retrieval_records = retrieval_df[retrieval_df['query_id'] == query_id]

    return {
        'llm_results': result_records.to_dict('records'),
        'retrieved_documents': retrieval_records.to_dict('records'),
        'summary': {
            'num_results': len(result_records),
            'num_retrieved_docs': len(retrieval_records),
            'ground_truth_distribution': retrieval_records.groupby(
                ['geog_ans_ground_truth', 'events_ans_ground_truth'], dropna=False
            ).size().to_dict()
        }
    }

# 主要執行程序

In [None]:
# Task name -> prompt sent to the LLM as the "question".
# Task1_* prompts ask for an "entities" list, Task2_* prompts for an "events" list.
tasks = dict(
    Task1_ZeroShot=TASK1_ZERO_SHOT,
    Task1_OneShot=TASK1_ONE_SHOT,
    Task2_ZeroShot=TASK2_ZERO_SHOT,
    Task2_OneShot=TASK2_ONE_SHOT,
)

# Retriever-only queries, for rag_qa_pipeline(..., query=..., QUERY_OPT=True):
# queries = {
#     'Task1_query': query_task1,
#     'Task2_query': query_task2,
# }

In [None]:
pdf_list = get_pdf_list(VECTORSTORE_ROOT_PATH)  # one vector-store folder per PDF
print("Found {} PDF folders to process".format(len(pdf_list)))

In [None]:
# Run every task over every PDF folder and save one Excel workbook per task.
all_results = {}            # task_name -> row dicts; f"{task_name}_DataFrame" -> DataFrame
all_retrieval_records = []  # retrieval records of all tasks
all_lookup_indices = {}

for task_name, task_prompt in tasks.items():
  task_type = task_name.split('_')[0]
  print(task_type)
  banner = '-' * 40
  print(f"\n{banner}")
  print(f"Running {task_name}")
  print(banner)

  tmp_results = []
  tmp_retrieval_records = []
  tmp_lookup_indices = []
  for pdf_name in tqdm(pdf_list):
    # For research design 2, pass query=... and QUERY_OPT=True to rag_qa_pipeline.
    response = rag_qa_pipeline(pdf_name, task_prompt, llm)
    results, retrieval_records = decode_response(response, task_name)
    tmp_results.extend(results)
    tmp_retrieval_records.extend(retrieval_records)

  all_results[task_name] = tmp_results
  all_retrieval_records.extend(tmp_retrieval_records)

  df = pd.DataFrame(tmp_results)
  all_results[f"{task_name}_DataFrame"] = df
  task_retrieval_df = pd.DataFrame(tmp_retrieval_records)

  output_path = os.path.join(output_dir, f"{task_name}_results_{timestamp}.xlsx")
  with pd.ExcelWriter(output_path, engine='openpyxl') as writer:
    df.to_excel(writer, sheet_name='Results', index=False)
    task_retrieval_df.to_excel(writer, sheet_name='Retrieval_Records', index=False)

    if not df.empty:
      analysis = analyze_performance_detailed(df, task_retrieval_df, task_name)
      # (analysis key, sheet name) for the optional summary sheets.
      for analysis_key, sheet_name in (('confusion_matrix', 'Confusion_Matrix'),
                                       ('metrics', 'Performance_Metrics')):
        if analysis_key in analysis:
          pd.DataFrame([analysis[analysis_key]]).to_excel(writer, sheet_name=sheet_name, index=False)

  print(f"Saved {task_name} detailed results to: {output_path}")

  # An empty DataFrame has no columns, so the performance breakdown only applies here.
  if not df.empty:
    total_docs = len(df)
    detected_docs = len(df[df['llm_detection'] == 1])
    found_entities = len(df[df['llm_ans'] == 1])
    print(f"Summary for {task_name}:")
    print(f"  - Total documents processed: {total_docs}")
    print(f"  - Documents with detection: {detected_docs}")
    print(f"  - Entities/Events found: {found_entities}")
    print(f"  - Detection rate: {detected_docs/total_docs*100:.2f}%")
    if 'performance' in df.columns:
      perf_counts = df['performance'].value_counts()
      print(f"  - Performance breakdown:")
      for perf_type, count in perf_counts.items():
        print(f"    {perf_type}: {count} ({count/total_docs*100:.2f}%)")

In [None]:
# Export the combined retrieval records, lookup indices and per-PDF / per-company summaries.
def _positive_doc_summary(frame, group_keys):
    """Per group: number of retrieved chunks and how many carry a label equal to 1."""
    def count_positive(column):
        return (column == 1).sum()

    aggregated = frame.groupby(group_keys).agg({
        'query_id': 'count',
        'geog_ans_ground_truth': count_positive,
        'events_ans_ground_truth': count_positive,
    })
    return aggregated.rename(columns={
        'query_id': 'total_retrievals',
        'geog_ans_ground_truth': 'geog_positive_docs',
        'events_ans_ground_truth': 'events_positive_docs',
    })


all_retrieval_df = pd.DataFrame(all_retrieval_records)
lookup_indices, retrieval_df = create_lookup_index(all_retrieval_records)
comprehensive_data = {
    'task_results': all_results,
    'retrieval_records': all_retrieval_records,
    'lookup_indices': lookup_indices,
    'retrieval_dataframe': all_retrieval_df,
}

pickle_path = os.path.join(output_dir, f'comprehensive_rag_evaluation_results{timestamp}.pkl')
with open(pickle_path, 'wb') as f:
    pickle.dump(comprehensive_data, f)

retrieval_db_path = os.path.join(output_dir, f'retrieval_database{timestamp}.xlsx')
with pd.ExcelWriter(retrieval_db_path, engine='openpyxl') as writer:
    all_retrieval_df.to_excel(writer, sheet_name='All_Retrievals', index=False)

    pdf_summary = _positive_doc_summary(all_retrieval_df, ['company', 'year'])
    pdf_summary.to_excel(writer, sheet_name='PDF_Summary')

    company_summary = _positive_doc_summary(all_retrieval_df, 'company')
    company_summary.to_excel(writer, sheet_name='Company_Summary')

In [None]:
# One summary row per task (counts, rates and confusion-matrix metrics), saved to Excel.
def _ratio_or_zero(numerator, denominator):
    """numerator / denominator, or 0 when the denominator is not positive."""
    return numerator / denominator if denominator > 0 else 0


summary_data = []
for task_name in tasks.keys():
    frame_key = f"{task_name}_DataFrame"
    if frame_key not in all_results:
        continue
    df = all_results[frame_key]
    if df.empty:
        continue

    total = len(df)
    detected = len(df[df['llm_detection'] == 1])
    found = len(df[df['llm_ans'] == 1])
    summary_row = {
        'Task': task_name,
        'Total_Documents': total,
        'Documents_With_Detection': detected,
        'Entities_Found': found,
        'Detection_Rate': _ratio_or_zero(detected, total) * 100,
        'Success_Rate': _ratio_or_zero(found, total) * 100,
    }

    if 'performance' in df.columns:
        perf_counts = df['performance'].value_counts()
        tp, fp, fn, tn = (perf_counts.get(label, 0) for label in ('TP', 'FP', 'FN', 'TN'))
        summary_row.update({
            'TP': tp,
            'FP': fp,
            'FN': fn,
            'TN': tn,
            'Precision': _ratio_or_zero(tp, tp + fp),
            'Recall': _ratio_or_zero(tp, tp + fn),
            'Accuracy': _ratio_or_zero(tp + tn, total),
        })
    summary_data.append(summary_row)

summary_df = pd.DataFrame(summary_data)
summary_path = os.path.join(output_dir, f"evaluation_summary_{timestamp}.xlsx")
summary_df.to_excel(summary_path, index=False)
print(f"Saved comprehensive summary report to: {summary_path}")