# AllTask Baseline for FinanceRAG Tasks

This notebook loads data from the local `data` folder, runs a retrieval and reranking pipeline, and aggregates the results.

In [None]:
# TWCC-specific setup: install third-party packages through the notebook's %pip magic.
%pip install sentence-transformers datasets pytrec_eval accelerate pandas pyarrow

In [None]:
# TWCC-specific setup

# 1. Uninstall the currently installed, incompatible PyTorch
%pip uninstall -y torch torchvision torchaudio

# 2. Install the official stable release (supports Tesla V100)
# Build for CUDA 12.1 or 12.4; currently the most widely applicable command
%pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

%pip install ipywidgets

In [None]:
import logging
import os
import csv
import json
import heapq
import abc
from pathlib import Path
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union, cast

import numpy as np
import torch
from datasets import Dataset, Value, load_dataset
from sentence_transformers import SentenceTransformer, CrossEncoder
from pydantic import BaseModel, Field
from tqdm.auto import tqdm
import pandas as pd

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

## Core Classes (DataLoader, Encoder, Retrieval, Reranker)

In [None]:
class HFDataLoader:
    """Load a corpus file and a query file through ``datasets.load_dataset``.

    Both files get the same treatment: paths ending in ``.parquet`` are read
    with the "parquet" builder and everything else with the "json" builder,
    an ``_id`` column (when present) is cast to string and renamed to ``id``,
    and every column outside the keep-list is dropped.

    Loaded datasets are cached on the instance, so calling ``load()`` again
    does not re-read the files.
    """

    # Columns retained after loading.
    _CORPUS_COLUMNS = ["id", "text", "title"]
    _QUERY_COLUMNS = ["id", "text"]

    def __init__(
            self,
            data_folder: str,
            subset: str,
            corpus_file: str,
            query_file: str,
            keep_in_memory: bool = False,
    ):
        """Store the file locations; nothing is read until ``load()``.

        Args:
            data_folder: Directory that ``corpus_file`` and ``query_file``
                are joined onto.
            subset: Name of the subset; stored on the instance but not used
                by the loader itself.
            corpus_file: Corpus path relative to ``data_folder``.
            query_file: Query path relative to ``data_folder``.
            keep_in_memory: Forwarded to ``load_dataset``.
        """
        self.corpus: Optional[Dataset] = None
        self.queries: Optional[Dataset] = None
        self.data_folder = data_folder
        self.subset = subset
        self.corpus_file = os.path.join(data_folder, corpus_file)
        self.query_file = os.path.join(data_folder, query_file)
        self.keep_in_memory = keep_in_memory

    def _load_split(self, path: str, keep_cols: List[str]) -> Dataset:
        """Read one file and reduce it to the standard ``id``-keyed columns."""
        builder = "parquet" if path.endswith(".parquet") else "json"
        dataset = load_dataset(
            builder,
            data_files=path,
            split="train",
            keep_in_memory=self.keep_in_memory,
        )

        # Standardize the identifier column: "_id" -> string-typed "id".
        if "_id" in dataset.column_names:
            dataset = dataset.cast_column("_id", Value("string"))
            dataset = dataset.rename_column("_id", "id")

        # Keep only the necessary columns.
        extra_cols = [c for c in dataset.column_names if c not in keep_cols]
        return dataset.remove_columns(extra_cols)

    def load(self) -> Tuple[Dataset, Dataset]:
        """Return ``(corpus, queries)``, reading each file on first use only."""
        if self.corpus is None:
            logger.info(f"Loading Corpus from {self.corpus_file}...")
            self.corpus = self._load_split(self.corpus_file, self._CORPUS_COLUMNS)

        if self.queries is None:
            logger.info(f"Loading Queries from {self.query_file}...")
            self.queries = self._load_split(self.query_file, self._QUERY_COLUMNS)

        return self.corpus, self.queries

class Encoder(abc.ABC):
    """Interface for models that embed queries and corpus documents."""

    @abc.abstractmethod
    def encode_queries(self, queries: List[str], **kwargs) -> Union[torch.Tensor, np.ndarray]:
        """Embed a list of query strings; concrete encoders must override."""
        raise NotImplementedError

    @abc.abstractmethod
    def encode_corpus(self, corpus: Union[List[Dict], Dict], **kwargs) -> Union[torch.Tensor, np.ndarray]:
        """Embed corpus documents; concrete encoders must override."""
        raise NotImplementedError

class Retrieval(abc.ABC):
    """Interface for first-stage retrievers."""

    @abc.abstractmethod
    def retrieve(self, corpus, queries, top_k, **kwargs) -> Dict[str, Dict[str, float]]:
        """Return a nested mapping of scores per query; subclasses must override."""
        raise NotImplementedError

class Reranker(abc.ABC):
    """Interface for second-stage rerankers."""

    @abc.abstractmethod
    def rerank(self, corpus, queries, results, top_k, **kwargs) -> Dict[str, Dict[str, float]]:
        """Return a nested mapping of rescored results per query; subclasses must override."""
        raise NotImplementedError

class SentenceTransformerEncoder(Encoder):
    """Encoder backed by a single ``SentenceTransformer`` model.

    The same model instance embeds both queries and documents. Optional
    string prompts are prepended to every query / document before encoding.
    """

    def __init__(
            self,
            model_name_or_path: str,
            query_prompt: Optional[str] = None,
            doc_prompt: Optional[str] = None,
            **kwargs,
    ):
        """Build the underlying model.

        Args:
            model_name_or_path: Passed to ``SentenceTransformer``.
            query_prompt: Prefix prepended to each query; skipped when falsy.
            doc_prompt: Prefix prepended to each document; skipped when falsy.
            **kwargs: Forwarded to the ``SentenceTransformer`` constructor.
        """
        self.q_model = SentenceTransformer(model_name_or_path, **kwargs)
        # Queries and documents share one model.
        self.doc_model = self.q_model
        self.query_prompt = query_prompt
        self.doc_prompt = doc_prompt

    @staticmethod
    def _join_title_text(title, text) -> str:
        """Concatenate title and text with a space, treating ``None`` as empty."""
        return ((title or "") + " " + (text or "")).strip()

    def encode_queries(self, queries: List[str], batch_size: int = 16, **kwargs):
        """Embed ``queries``; extra keyword arguments go to ``model.encode``."""
        if self.query_prompt:
            queries = [self.query_prompt + q for q in queries]
        return self.q_model.encode(queries, batch_size=batch_size, **kwargs)

    def encode_corpus(self, corpus: Union[List[Dict], Dict], batch_size: int = 8, **kwargs):
        """Embed documents given in column or row form.

        Args:
            corpus: Either a dict of columns (``"text"`` and optionally
                ``"title"``, indexed by position) or a list of per-document
                dicts (``"text"`` required, ``"title"`` optional).
            batch_size: Forwarded to ``model.encode``.
            **kwargs: Forwarded to ``model.encode``.
        """
        if isinstance(corpus, dict):
            texts = corpus["text"]
            # Without a title column, every title is empty.
            titles = corpus["title"] if "title" in corpus else [""] * len(texts)
            sentences = [
                self._join_title_text(titles[i], text) for i, text in enumerate(texts)
            ]
        else:
            sentences = [
                self._join_title_text(doc.get("title", ""), doc["text"]) for doc in corpus
            ]
        if self.doc_prompt:
            sentences = [self.doc_prompt + s for s in sentences]
        return self.doc_model.encode(sentences, batch_size=batch_size, **kwargs)

class DenseRetrieval(Retrieval):
    """Exhaustive dense retrieval by cosine similarity.

    The corpus is encoded in chunks of ``corpus_chunk_size`` documents; each
    chunk is scored against all query embeddings and a per-query min-heap
    keeps the ``top_k`` best documents seen so far.
    """

    def __init__(self, model: Encoder, batch_size: int = 64, corpus_chunk_size: int = 50000):
        """Store the encoder and batching parameters.

        Args:
            model: Encoder used for both queries and corpus documents.
            batch_size: Batch size forwarded to the encoder.
            corpus_chunk_size: Number of documents encoded and scored per chunk.
        """
        self.model = model
        self.batch_size = batch_size
        self.corpus_chunk_size = corpus_chunk_size

    @staticmethod
    def _to_unit_tensor(embeddings):
        """Return embeddings as an L2-normalised tensor, on GPU when available."""
        if isinstance(embeddings, np.ndarray):
            embeddings = torch.from_numpy(embeddings)
        if torch.cuda.is_available():
            embeddings = embeddings.cuda()
        return torch.nn.functional.normalize(embeddings, p=2, dim=1)

    @staticmethod
    def _doc_length(doc) -> int:
        """Combined length of title and text, counting missing/None as empty."""
        return len(doc.get("title") or "") + len(doc.get("text") or "")

    def retrieve(self, corpus, queries, top_k=100, score_function="cos_sim", **kwargs):
        """Score every document against every query and keep the best ``top_k``.

        Args:
            corpus: Mapping of doc id to a dict with ``"title"`` / ``"text"``.
            queries: Mapping of query id to query text.
            top_k: Maximum number of documents kept per query.
            score_function: Accepted for interface compatibility; scoring is
                always cosine similarity regardless of this value.
            **kwargs: Forwarded to the encoder's ``encode_*`` methods.

        Returns:
            ``{query_id: {doc_id: score}}``; also stored in ``self.results``.
        """
        query_ids = list(queries.keys())
        self.results = {qid: {} for qid in query_ids}
        # Nothing to score: skip encoding instead of failing on empty embeddings.
        if not query_ids or not corpus:
            return self.results

        logger.info("Encoding queries...")
        query_texts = [queries[qid] for qid in query_ids]
        query_embeddings = self.model.encode_queries(query_texts, batch_size=self.batch_size, **kwargs)
        # Loop-invariant: convert, move and normalise the queries once.
        q_norm = self._to_unit_tensor(query_embeddings)

        logger.info("Encoding corpus and searching...")
        # Longest documents first, so each encoding batch holds similar lengths.
        sorted_corpus_ids = sorted(corpus, key=lambda k: self._doc_length(corpus[k]), reverse=True)
        corpus_list = [corpus[cid] for cid in sorted_corpus_ids]

        result_heaps = {qid: [] for qid in query_ids}

        for start_idx in tqdm(range(0, len(corpus_list), self.corpus_chunk_size), desc="Retrieving Chunks"):
            end_idx = min(start_idx + self.corpus_chunk_size, len(corpus_list))
            sub_corpus_embeddings = self.model.encode_corpus(
                corpus_list[start_idx:end_idx], batch_size=self.batch_size, **kwargs
            )
            c_norm = self._to_unit_tensor(sub_corpus_embeddings)

            cos_scores = torch.mm(q_norm, c_norm.transpose(0, 1))
            # NaN similarities are treated as the lowest possible score.
            cos_scores[torch.isnan(cos_scores)] = -1
            cos_scores = cos_scores.cpu()

            # One extra candidate, because a document whose id equals the
            # query id is skipped below.
            values, indices = torch.topk(cos_scores, min(top_k + 1, cos_scores.size(1)), dim=1)
            values, indices = values.tolist(), indices.tolist()

            for i, qid in enumerate(query_ids):
                heap = result_heaps[qid]
                for score, idx in zip(values[i], indices[i]):
                    doc_id = sorted_corpus_ids[start_idx + idx]
                    if doc_id == qid:
                        continue
                    # Min-heap of size <= top_k: the weakest candidate is evicted.
                    if len(heap) < top_k:
                        heapq.heappush(heap, (score, doc_id))
                    else:
                        heapq.heappushpop(heap, (score, doc_id))

        for qid, heap in result_heaps.items():
            for score, doc_id in heap:
                self.results[qid][doc_id] = score
        return self.results

class CrossEncoderReranker(Reranker):
    """Rescore retrieved candidates with a cross-encoder model."""

    def __init__(self, model: CrossEncoder):
        """Keep the cross-encoder whose ``predict`` scores (query, document) pairs."""
        self.model = model

    def rerank(self, corpus, queries, results, top_k, batch_size=32, **kwargs):
        """Rescore the ``top_k`` best candidates of every query.

        Args:
            corpus: Mapping of doc id to a dict with ``"title"`` / ``"text"``.
            queries: Mapping of query id to query text.
            results: First-stage scores as ``{query_id: {doc_id: score}}``.
            top_k: Number of highest-scored candidates per query to rescore.
            batch_size: Forwarded to ``model.predict``.
            **kwargs: Forwarded to ``model.predict``.

        Returns:
            ``{query_id: {doc_id: score}}`` with the model's scores as floats.
            Every query id in ``results`` is present, possibly with an empty
            mapping.
        """
        sentence_pairs, pair_ids = [], []
        for query_id, doc_scores in results.items():
            sorted_docs = sorted(doc_scores.items(), key=lambda item: item[1], reverse=True)[:top_k]
            for doc_id, _ in sorted_docs:
                doc = corpus[doc_id]
                # A missing or None title/text counts as an empty string.
                corpus_text = ((doc.get("title") or "") + " " + (doc.get("text") or "")).strip()
                pair_ids.append([query_id, doc_id])
                sentence_pairs.append([queries[query_id], corpus_text])

        reranked_results = {qid: {} for qid in results}

        logger.info(f"Starting Reranking for {len(sentence_pairs)} pairs...")
        # No candidates at all: avoid calling the model with an empty batch.
        if not sentence_pairs:
            return reranked_results

        scores = self.model.predict(sentence_pairs, batch_size=batch_size, show_progress_bar=True, **kwargs)

        for (qid, doc_id), score in zip(pair_ids, scores):
            reranked_results[qid][doc_id] = float(score)
        return reranked_results

## Task Definitions

In [None]:
class TaskMetadata(BaseModel):
    """Descriptive record attached to a task."""
    # Task name.
    name: str
    # Dataset settings; BaseTask.load_data reads the "subset" key from it.
    dataset: dict
    # Free-text description; empty by default.
    description: str = ""

class BaseTask:
    """One retrieval task: its queries, corpus, optional qrels and results.

    Data is loaded eagerly in the constructor. ``retrieve`` and ``rerank``
    delegate to the given model objects and remember their output, and
    ``evaluate`` scores a result set against the qrels when they exist.
    """

    def __init__(self, metadata: "TaskMetadata", data_folder: str = "./data"):
        """Store the configuration and immediately load the task data.

        Args:
            metadata: Task metadata; ``metadata.dataset["subset"]`` selects
                the files to load.
            data_folder: Directory containing the task files.
        """
        self.metadata = metadata
        self.data_folder = data_folder
        self.queries = None
        self.corpus = None
        self.retrieve_results = None
        self.rerank_results = None
        self.qrels = None
        self.load_data()

    def load_data(self):
        """Populate ``self.queries``, ``self.corpus`` and, if present, ``self.qrels``."""
        subset = self.metadata.dataset["subset"]
        # Determine folder name based on subset
        folder_name = subset.lower()
        if subset == "MultiHiertt":
            # NOTE(review): spelling differs from the subset name; presumably
            # it matches the folder name on disk - confirm.
            folder_name = "multiheirtt"

        if subset == 'FinDER':
            corpus_file = "FinanceRAG_corpus.parquet"
            query_file = "FinanceRAG_queries.parquet"
        else:
            corpus_file = f"{folder_name}_corpus.jsonl/corpus.jsonl"
            query_file = f"{folder_name}_queries.jsonl/queries.jsonl"

        loader = HFDataLoader(
            data_folder=self.data_folder,
            subset=subset,
            corpus_file=corpus_file,
            query_file=query_file
        )
        corpus, queries = loader.load()

        self.queries = {row["id"]: row["text"] for row in queries}
        # Missing or None title/text become "" so that later string
        # concatenation cannot fail on None.
        self.corpus = {
            row["id"]: {"title": row.get("title") or "", "text": row.get("text") or ""}
            for row in corpus
        }
        logger.info(f"Loaded {len(self.corpus)} docs and {len(self.queries)} queries for {subset}.")

        # Load Qrels
        qrels_file = os.path.join(self.data_folder, f"{subset}_qrels.tsv")
        if os.path.exists(qrels_file):
            logger.info(f"Loading qrels from {qrels_file}...")
            self.qrels = pd.read_csv(qrels_file, sep='\t')
            # Ids are compared as strings during evaluation.
            self.qrels['query_id'] = self.qrels['query_id'].astype(str)
            self.qrels['corpus_id'] = self.qrels['corpus_id'].astype(str)

    def retrieve(self, retriever, top_k=100, **kwargs):
        """Run ``retriever`` on this task's data and remember the result."""
        self.retrieve_results = retriever.retrieve(self.corpus, self.queries, top_k=top_k, **kwargs)
        return self.retrieve_results

    def rerank(self, reranker, results=None, top_k=100, batch_size=32, **kwargs):
        """Rerank ``results`` (default: the last retrieval output).

        Raises:
            ValueError: If no results are given and ``retrieve`` has not run.
        """
        if results is None:
            results = self.retrieve_results
        if results is None:
            raise ValueError("No results to rerank: pass `results` or call retrieve() first.")
        # batch_size goes by keyword: the abstract Reranker.rerank signature
        # only accepts it through **kwargs.
        self.rerank_results = reranker.rerank(
            self.corpus, self.queries, results, top_k, batch_size=batch_size, **kwargs
        )
        return self.rerank_results

    def evaluate(self, results, top_k=10):
        """Compute hit-based accuracy of ``results`` against the qrels.

        A query counts as a hit when at least one of its ``top_k``
        highest-scored documents appears with that query in the qrels.
        Every (query_id, corpus_id) row of the qrels counts, whatever other
        columns it has.

        Args:
            results: ``{query_id: {doc_id: score}}``.
            top_k: Number of top-scored documents inspected per query.

        Returns:
            ``None`` when no qrels were loaded, otherwise a dict with
            ``accuracy``, ``hits`` and ``total_queries``. Empty ``results``
            give an accuracy of ``0.0``.
        """
        if self.qrels is None:
            return None

        # Set lookup replaces the DataFrame merge, which failed on a
        # column-less frame when there were no result rows.
        relevant_pairs = set(zip(
            self.qrels['query_id'].astype(str),
            self.qrels['corpus_id'].astype(str),
        ))

        queries_with_hits = set()
        for qid, docs in results.items():
            top_docs = sorted(docs.items(), key=lambda x: x[1], reverse=True)[:top_k]
            if any((str(qid), str(doc_id)) in relevant_pairs for doc_id, _ in top_docs):
                queries_with_hits.add(str(qid))

        # Accuracy: share of queries with at least one correct document.
        unique_queries_with_hits = len(queries_with_hits)
        total_queries = len(results)
        accuracy = unique_queries_with_hits / total_queries if total_queries > 0 else 0.0

        return {'accuracy': accuracy, 'hits': unique_queries_with_hits, 'total_queries': total_queries}

class FinDER(BaseTask):
    """Task bound to the "FinDER" subset."""

    def __init__(self, data_folder="./data"):
        task_meta = TaskMetadata(dataset={"subset": "FinDER"}, name="FinDER")
        super().__init__(metadata=task_meta, data_folder=data_folder)

class FinQABench(BaseTask):
    """Task bound to the "FinQABench" subset."""

    def __init__(self, data_folder="./data"):
        task_meta = TaskMetadata(dataset={"subset": "FinQABench"}, name="FinQABench")
        super().__init__(metadata=task_meta, data_folder=data_folder)

class FinQA(BaseTask):
    """Task bound to the "FinQA" subset."""

    def __init__(self, data_folder="./data"):
        task_meta = TaskMetadata(dataset={"subset": "FinQA"}, name="FinQA")
        super().__init__(metadata=task_meta, data_folder=data_folder)

class FinanceBench(BaseTask):
    """Task bound to the "FinanceBench" subset."""

    def __init__(self, data_folder="./data"):
        task_meta = TaskMetadata(dataset={"subset": "FinanceBench"}, name="FinanceBench")
        super().__init__(metadata=task_meta, data_folder=data_folder)

class ConvFinQA(BaseTask):
    """Task bound to the "ConvFinQA" subset."""

    def __init__(self, data_folder="./data"):
        task_meta = TaskMetadata(dataset={"subset": "ConvFinQA"}, name="ConvFinQA")
        super().__init__(metadata=task_meta, data_folder=data_folder)

class MultiHiertt(BaseTask):
    """Task bound to the "MultiHiertt" subset."""

    def __init__(self, data_folder="./data"):
        task_meta = TaskMetadata(dataset={"subset": "MultiHiertt"}, name="MultiHiertt")
        super().__init__(metadata=task_meta, data_folder=data_folder)

class TATQA(BaseTask):
    """Task bound to the "TATQA" subset."""

    def __init__(self, data_folder="./data"):
        task_meta = TaskMetadata(dataset={"subset": "TATQA"}, name="TATQA")
        super().__init__(metadata=task_meta, data_folder=data_folder)

In [None]:
# Execution Pipeline
#
# For every task: load the data, retrieve with the dense encoder, rerank with
# the cross-encoder, evaluate when qrels exist, and collect each query's ten
# best reranked documents. The collected rows are written to submission.csv.
import gc
import traceback

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Initialize Models
encoder_model = SentenceTransformerEncoder(
    model_name_or_path='intfloat/e5-large-v2',
    query_prompt='query: ',
    doc_prompt='passage: ',
    device=device
)
retrieval_model = DenseRetrieval(model=encoder_model, batch_size=128)

reranker_model = CrossEncoderReranker(
    model=CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2', device=device)
)

# List of Tasks
task_classes = [
    FinDER,
    FinQABench,
    FinQA,
    FinanceBench,
    ConvFinQA,
    MultiHiertt,
    TATQA
]

all_results = []
# Tasks that raised; their rows are missing from the submission.
failed_tasks = []

for TaskClass in tqdm(task_classes, desc="Processing Tasks"):
    task = retrieve_result = reranking_result = None
    try:
        task = TaskClass(data_folder="./data")
        print(f"\n>>> Processing Task: {task.metadata.name} <<<")

        # Retrieve
        print(f"--- Retrieving ({task.metadata.name}) ---")
        retrieve_result = task.retrieve(retrieval_model, top_k=100)

        # Rerank
        print(f"--- Reranking ({task.metadata.name}) ---")
        reranking_result = task.rerank(reranker_model, results=retrieve_result, top_k=100, batch_size=64)

        # Evaluate
        print(f"--- Evaluating ({task.metadata.name}) ---")
        if task.qrels is not None:
            metrics = task.evaluate(reranking_result)
            print(f"Task {task.metadata.name} Metrics: {metrics}")
        else:
            print(f"No qrels found for {task.metadata.name}, skipping evaluation.")

        # Collect Results
        print(f"--- Collecting Results ({task.metadata.name}) ---")
        for query_id, result in reranking_result.items():
            top_10 = sorted(result.items(), key=lambda x: x[1], reverse=True)[:10]
            for corpus_id, score in top_10:
                all_results.append({
                    'query_id': query_id,
                    'corpus_id': corpus_id,
                    'score': score
                })

    except Exception as e:
        # Top-level boundary: report the failure and continue with the next task.
        print(f"Error processing task {TaskClass.__name__}: {e}")
        traceback.print_exc()
        failed_tasks.append(TaskClass.__name__)

    finally:
        # Cleanup runs on failure too, so a failed task's data is released
        # before the next task is loaded.
        task = retrieve_result = reranking_result = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

if failed_tasks:
    print(f"\nWarning: no results collected for failed tasks: {', '.join(failed_tasks)}")

# Save Final Results
print("\n=== Saving Submission File ===")
output_file = 'submission.csv'
# Explicit encoding keeps the file independent of the platform default.
with open(output_file, 'w', newline='', encoding='utf-8') as f:
    writer = csv.writer(f)
    writer.writerow(['query_id', 'corpus_id'])
    for row in all_results:
        writer.writerow([row['query_id'], row['corpus_id']])
print(f"Results saved to {output_file}")