In [None]:
import os
from glob import glob

import numpy as np
import pandas as pd
from tqdm.auto import tqdm

from notebooks.ablations.utils import TOPICS_UNDER_10K
from models import BM25Reranker
from utils import HyResearch

# HyDE ranking files to rerank. Without recursive=True, "**" matches exactly one
# directory level here (ablation/hyde/<subdir>/*.parquet).
# Paths are normalised to "/" because run() rewrites "/hyde/" -> "/hyqe/", and
# sorted because glob() returns files in arbitrary, filesystem-dependent order.
files = sorted(path.replace("\\", "/") for path in glob("ablation/hyde/**/*.parquet"))
eval_df = pd.read_excel("data/eval_cps.xlsx")

In [None]:
def bm25_query(topic: str, query: str) -> str:
    """Build a BM25 query string for ``topic`` from ``query`` using HyResearch.

    Requests a single generated query from ``HyResearch.generate_n_queries`` and
    joins the returned strings with newlines.
    """
    researcher = HyResearch()
    generated = researcher.generate_n_queries(query, topic, 1)
    return "\n".join(generated)

def get_topic_data(df: pd.DataFrame, topic: str) -> tuple[str, dict[str, str]]:
    """Extract the query and the candidate documents for a single topic.

    Args:
        df: frame with at least ``topic``, ``query``, ``id`` and ``text`` columns.
        topic: value of the ``topic`` column to select.

    Returns:
        ``(query, documents)``. ``query`` comes from the topic's first row; when it
        is stored as a list / numpy array of strings, the parts are joined with
        newlines. ``documents`` maps each row's ``id`` to its ``text`` (a repeated
        id keeps the text of its last row).

    Raises:
        ValueError: if ``df`` has no rows for ``topic``.
    """
    topic_df = df[df["topic"] == topic]
    if topic_df.empty:
        # Fail with a clear message instead of an opaque IndexError from .iloc[0].
        raise ValueError(f"No rows found for topic {topic!r}")
    query = topic_df["query"].iloc[0]
    # The query may be stored as a sequence of strings rather than a single string.
    if isinstance(query, (list, np.ndarray)):
        query = "\n".join(query)
    documents = dict(zip(topic_df["id"].values, topic_df["text"].values))
    return query, documents

def run():
    """Rerank every HyDE result file with BM25 and save the top hits per topic.

    For each parquet path in ``files``, and for each ``(topic, top_n)`` in
    ``TOPICS_UNDER_10K``, the topic's query is rewritten with ``bm25_query``, the
    topic's documents are scored with ``BM25Reranker.rerank``, and the ``top_n``
    highest-scoring ids are kept. The results (columns ``topic``, ``id``,
    ``scores``) are written to the same path with ``/hyde/`` replaced by ``/hyqe/``.
    """
    reranker = BM25Reranker()
    for file in tqdm(files, desc="Files"):
        df = pd.read_parquet(file)
        output_path = file.replace("/hyde/", "/hyqe/")
        results = {
            "topic": [],
            "id": [],
            "scores": [],
        }
        for topic, top_n in tqdm(TOPICS_UNDER_10K.items(), desc="Topics", leave=False):
            query, documents = get_topic_data(df, topic)
            query = bm25_query(topic, query)
            scores = reranker.rerank(query, documents)
            # Rank explicitly (highest score first) instead of relying on the dict
            # order returned by rerank; parse_data also ranks by descending score.
            top_hits = sorted(scores.items(), key=lambda item: item[1], reverse=True)[:top_n]
            # A topic may have fewer than top_n documents: size every column by the
            # actual number of hits so all columns keep the same length.
            results["topic"].extend([topic] * len(top_hits))
            results["id"].extend(doc_id for doc_id, _ in top_hits)
            results["scores"].extend(score for _, score in top_hits)

        # The mirrored "hyqe" directory may not exist yet; to_parquet won't create it.
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        pd.DataFrame(results).to_parquet(output_path)

In [None]:
def _top_ids(frame: pd.DataFrame, topic: str, score_col: str, top_n: int) -> pd.Series:
    """Return the ids of the ``top_n`` rows of ``frame`` for ``topic`` with the highest ``score_col``."""
    return (
        frame[frame["topic"] == topic]
        .sort_values(score_col, ascending=False)
        .head(top_n)["id"]
    )


def parse_data(files: list[str]) -> list[tuple[str, pd.DataFrame]]:
    """Compute per-topic recall of the HyDE and HyQE rankings for each result file.

    For every HyQE parquet file, the matching HyDE file (same path with ``/hyqe/``
    replaced by ``/hyde/``) is loaded too. For each ``(topic, top_n)`` in
    ``TOPICS_UNDER_10K``, the ``top_n`` highest-scoring ids of each method are
    compared with the ids listed for that topic in ``eval_df``.

    Args:
        files: HyQE result paths shaped like
            ``ablation/hyqe/<cp_type>/..._HYDE<n>...parquet``.

    Returns:
        One ``(label, frame)`` pair per file, in input order. ``label`` reads
        ``"CP type: <cp_type> HyDe: <n>"``. ``frame`` has columns ``topic``,
        ``hyde`` and ``hyqe``, the latter two being hits / number of eval ids.
        They are NaN for topics with no ids in ``eval_df``.
    """
    # The eval ids for a topic do not depend on the file, so look them up once.
    eval_ids = {
        topic: eval_df.loc[eval_df["topic"] == topic, "id"].tolist()
        for topic in TOPICS_UNDER_10K
    }
    dfs = []
    for file in tqdm(files):
        # glob() may return backslash separators on Windows; normalise so the
        # "/"-based parsing below works on every platform.
        file = file.replace("\\", "/")
        cp_type = file.split("/")[2]
        hyde_n = file.split("_HYDE")[1][0]
        df_id = f"CP type: {cp_type} HyDe: {hyde_n}"
        hyde_df = pd.read_parquet(file.replace("/hyqe/", "/hyde/"))
        hyqe_df = pd.read_parquet(file)
        results = {
            "topic": [],
            "hyde": [],
            "hyqe": [],
        }
        for topic, top_n in TOPICS_UNDER_10K.items():
            relevant = eval_ids[topic]
            # HyDE files name their score column "score"; HyQE files (written by
            # run) name it "scores".
            hyde_hits = _top_ids(hyde_df, topic, "score", top_n).isin(relevant).sum()
            hyqe_hits = _top_ids(hyqe_df, topic, "scores", top_n).isin(relevant).sum()
            n_relevant = len(relevant)
            results["topic"].append(topic)
            # Recall is undefined without eval ids: record NaN explicitly rather
            # than dividing by zero.
            results["hyde"].append(hyde_hits / n_relevant if n_relevant else np.nan)
            results["hyqe"].append(hyqe_hits / n_relevant if n_relevant else np.nan)
        dfs.append((df_id, pd.DataFrame(results)))

    return dfs


def show_results(files: list[str]):
    """Print mean HyQE recall per HyDe setting and display one styled recall table per file.

    Args:
        files: HyQE result parquet paths, forwarded to ``parse_data``.
    """
    dfs = parse_data(files)  # type: list[tuple[str, pd.DataFrame]]

    def hyde_label(name: str) -> str:
        # Labels read "CP type: <cp> HyDe: <n>"; keep the part after "HyDe: ".
        return name.rsplit("HyDe: ", 1)[-1]

    # glob() returns files in arbitrary order, so order by the parsed HyDe value
    # instead of assuming list positions 0/1/2 mean HyDe=0/1/2. This also avoids an
    # IndexError when fewer than three files are given.
    dfs = sorted(dfs, key=lambda pair: hyde_label(pair[0]))
    summary_lines = []
    to_display = []
    for name, df in dfs:
        # Mean over topics, taken before the "Average" row is appended.
        # Series.mean skips NaN by default; numeric_only is unsupported on older pandas.
        summary_lines.append(f"HyDe={hyde_label(name)}: {df['hyqe'].mean():.3f}")
        df.loc["Average"] = df.mean(axis=0, numeric_only=True)
        to_display.append(
            df.style.highlight_max(axis=1, subset=["hyde", "hyqe"], color="green")
            .set_caption(name)
            .format("{:.2}", subset=["hyde", "hyqe"])
            .set_table_attributes('style="width: 50%;"')
        )

    for line in summary_lines:
        print(line)
    for styler in to_display:
        display(styler)

#### Using the Worst CP

In [None]:
# Sort for a reproducible table order (glob order is filesystem-dependent) and use
# "/" separators, which parse_data's path parsing relies on.
worst_files = sorted(
    path.replace("\\", "/") for path in glob("ablation/hyqe/worst/All_*.parquet")
)
show_results(worst_files)

#### Using the Average CP

In [None]:
# Sort for a reproducible table order (glob order is filesystem-dependent) and use
# "/" separators, which parse_data's path parsing relies on.
average_files = sorted(
    path.replace("\\", "/") for path in glob("ablation/hyqe/average/All_*.parquet")
)
show_results(average_files)

#### Using the Best CP

In [None]:
# Sort for a reproducible table order (glob order is filesystem-dependent) and use
# "/" separators, which parse_data's path parsing relies on.
best_files = sorted(
    path.replace("\\", "/") for path in glob("ablation/hyqe/best/All_*.parquet")
)
show_results(best_files)