In [85]:
import pickle as dill
from pathlib import Path

from scripts.executor import WrappedResults


def load_wrapped_results(path: Path | str) -> WrappedResults:
    """
    Read a pickled object back from disk.
    :param path: Location of the pickle file, as a string or a Path.
    :return: The object that was stored in the file.
    :raises FileNotFoundError: If nothing exists at ``path``.
    """
    source = Path(path)
    if source.exists():
        # ``dill`` is the name under which this file imports its unpickler.
        with open(source, "rb") as handle:
            return dill.load(handle)

    raise FileNotFoundError(f"File not found: {source}")

In [86]:
from enum import Enum
from dataclasses import dataclass

import numpy as np

from scripts.executor import ExecutionResult
from scripts.evaluation import is_model_response_correct


class AnswerExtractorType(Enum):
    """
    Enum for the different types of answer extractors.

    Used to choose which of the two correctness flags returned by
    ``is_model_response_correct`` is recorded for a result.
    """

    # Record the first flag (correctness of the raw response).
    RAW = "raw"
    # Record the second flag (correctness of the postprocessed response).
    POSTPROCESSED = "postprocessed"


# Modified version from https://github.com/1171-jpg/BrainTeaser/blob/main/Prompting/utils.py
# to match the format of the results
def getResultdata(
    execution_results: list[ExecutionResult], extractor_type: AnswerExtractorType
) -> tuple[dict, dict]:
    """
    Group per-riddle correctness by puzzle group and riddle variant.

    A riddle id is split on ``-``: the first part is the puzzle type and the
    second part, cut at its first ``_``, is the group id. Each group owns a
    three-slot list indexed by variant: slot 1 when the id contains ``"SR"``,
    slot 2 when it contains ``"CR"``, slot 0 otherwise. Slots that never
    receive a result keep their initial value ``0``.

    :param execution_results: Results to score; each needs a ``riddle.id``.
    :param extractor_type: ``RAW`` stores the first flag returned by
        ``is_model_response_correct``; any other value stores the second.
    :return: ``(word_play, reverse_play)``. Groups whose puzzle type is
        ``"WP"`` go into the first dict, all other groups into the second.
    """
    use_raw = extractor_type == AnswerExtractorType.RAW
    word_play: dict = {}
    reverse_play: dict = {}

    for execution_result in execution_results:
        riddle_id = execution_result.riddle.id
        id_parts = riddle_id.split("-")
        item_type = id_parts[0]
        item_id = id_parts[1].split("_")[0]

        # Substring checks on the full id, "SR" taking precedence over "CR".
        if "SR" in riddle_id:
            ad_type = 1
        elif "CR" in riddle_id:
            ad_type = 2
        else:
            ad_type = 0

        raw_correct, postprocessed_correct = is_model_response_correct(execution_result)

        # setdefault creates the group on first sight, so one pass is enough
        # and the dicts keep first-appearance order.
        target = word_play if item_type == "WP" else reverse_play
        slots = target.setdefault(item_id, [0, 0, 0])
        slots[ad_type] = raw_correct if use_raw else postprocessed_correct

    return word_play, reverse_play


@dataclass
class Metrics:
    """
    Container for one set of evaluation scores.

    ``+`` and ``-`` are supported; both work field by field and return a new
    ``Metrics`` instance.
    """

    overall_accuracy: float
    original_accuracy: float
    semantic_accuracy: float
    context_accuracy: float
    ori_sema: float
    ori_sema_cont: float

    def __sub__(self, other):
        # Field-wise difference over the six declared fields.
        return Metrics(
            **{
                name: getattr(self, name) - getattr(other, name)
                for name in Metrics.__dataclass_fields__
            }
        )

    def __add__(self, other):
        # Field-wise sum over the six declared fields.
        return Metrics(
            **{
                name: getattr(self, name) + getattr(other, name)
                for name in Metrics.__dataclass_fields__
            }
        )


def getMetric(data_list) -> Metrics:
    """
    Turn per-group correctness rows into accuracy percentages.

    :param data_list: Non-empty sequence of rows. Each row holds three 0/1
        (or boolean) flags: index 0 = original, 1 = semantic, 2 = context.
    :return: Metrics whose values are fractions rounded to four decimals and
        then multiplied by 100.
    :raises ValueError: If ``data_list`` is empty.
    """
    data = np.asarray(data_list)
    group_count = len(data)
    if group_count == 0:
        # An empty array previously surfaced as an IndexError further down.
        raise ValueError(
            "getMetric needs at least one result row; got an empty data_list"
        )

    def as_percentage(fraction):
        return np.round(fraction, 4) * 100

    # Per-column hit counts, computed once and shared by three metrics.
    column_hits = np.sum(data, axis=0)

    # Group-based metrics: a group counts only if all listed variants are correct.
    original_ok = data[:, 0] == 1
    semantic_ok = data[:, 1] == 1
    context_ok = data[:, 2] == 1
    ori_sema_hits = np.sum(original_ok & semantic_ok)
    ori_sema_cont_hits = np.sum(original_ok & semantic_ok & context_ok)

    return Metrics(
        overall_accuracy=as_percentage(np.sum(data) / 3 / group_count),
        original_accuracy=as_percentage(column_hits[0] / group_count),
        semantic_accuracy=as_percentage(column_hits[1] / group_count),
        context_accuracy=as_percentage(column_hits[2] / group_count),
        ori_sema=as_percentage(ori_sema_hits / group_count),
        ori_sema_cont=as_percentage(ori_sema_cont_hits / group_count),
    )


@dataclass
class ModelEvaluationMetrics:
    """Metrics of one model, split by puzzle category."""

    # Metrics computed from the word-play groups only.
    wordplay: Metrics
    # Metrics computed from the remaining (non word-play) groups.
    sentence: Metrics
    # Metrics computed over both categories together.
    all: Metrics


def getSeperateResult(word_play, reverse_thinking) -> ModelEvaluationMetrics:
    """
    Compute metrics per puzzle category and for both categories combined.

    :param word_play: Mapping whose values are the word-play result rows.
    :param reverse_thinking: Mapping whose values are the other result rows.
    :return: ModelEvaluationMetrics with ``wordplay``, ``sentence`` and
        ``all`` (word-play rows followed by the other rows).
    """
    word_rows = list(word_play.values())
    reverse_rows = list(reverse_thinking.values())

    return ModelEvaluationMetrics(
        wordplay=getMetric(word_rows),
        sentence=getMetric(reverse_rows),
        all=getMetric(word_rows + reverse_rows),
    )


@dataclass
class EvaluationReport:
    """Evaluation metrics of one model under both answer-extraction modes."""

    # Metrics scored with AnswerExtractorType.RAW.
    raw_results: ModelEvaluationMetrics
    # Metrics scored with AnswerExtractorType.POSTPROCESSED.
    postprocessed_results: ModelEvaluationMetrics


def generateEvaluationReport(model_results: list[ExecutionResult]) -> EvaluationReport:
    """
    Score one model's results in both extraction modes.

    :param model_results: Execution results belonging to a single model.
    :return: EvaluationReport holding the RAW and the POSTPROCESSED metrics.
    """
    # Each call yields a (word_play, sentence_play) pair of group dicts.
    raw_groups = getResultdata(model_results, AnswerExtractorType.RAW)
    postprocessed_groups = getResultdata(
        model_results, AnswerExtractorType.POSTPROCESSED
    )

    raw_metrics = getSeperateResult(*raw_groups)
    postprocessed_metrics = getSeperateResult(*postprocessed_groups)
    return EvaluationReport(
        raw_results=raw_metrics, postprocessed_results=postprocessed_metrics
    )


def create_evaluation_from_results(
    wrapped_results_path: Path | str,
) -> dict[str, EvaluationReport]:
    """
    Load pickled wrapped results and build one evaluation report per model.

    The loaded object's ``results`` attribute is read as a two-level mapping
    (outer key -> model name -> list of execution results). The lists of each
    model are concatenated across all outer keys before scoring.

    :param wrapped_results_path: Path to the pickle file.
    :return: Mapping from model name to its EvaluationReport.
    :raises FileNotFoundError: If the file does not exist.
    """
    # NOTE(security): the file is unpickled, which can execute arbitrary
    # code. Only load result files from a trusted source.
    wrapped_results = load_wrapped_results(wrapped_results_path)

    per_model: dict[str, list[ExecutionResult]] = {}
    for models_results in wrapped_results.results.values():
        for model_name, execution_results in models_results.items():
            per_model.setdefault(model_name, []).extend(execution_results)

    return {
        model_name: generateEvaluationReport(execution_results)
        for model_name, execution_results in per_model.items()
    }

In [87]:
def results_to_texttable(
    results: dict[str, EvaluationReport],
    optimized_results: dict[str, EvaluationReport],
    caption: str,
    label: str,
):
    """
    Write a LaTeX table comparing two evaluation runs to ``tables/<label>.tex``.

    Every model in ``results`` contributes a RAW row and a POST row to both
    the sentence-puzzle and the word-puzzle section. Each cell is rendered as
    ``a (b)``, where ``a`` comes from ``results`` and ``b`` from
    ``optimized_results``.

    :param results: Reports keyed by model name; defines the listed models.
    :param optimized_results: Reports of the run shown in parentheses; must
        contain every model name found in ``results``.
    :param caption: Table caption; ``&`` is escaped for LaTeX.
    :param label: Used for the ``tab:<label>`` label and the file name.
    :return: None. The table is written to disk.
    :raises KeyError: If a model is missing from ``optimized_results``.
    """

    def print_line(
        model_name: str,
        metrics_baseline: Metrics,
        metrics_optimized: Metrics,
        postfix: str = "",
    ):
        # Render one table row; ``postfix`` goes on the line after the row.
        model_name = model_name.replace("_", "-").replace("-instruct-", "-")
        # Backslashes are doubled: "\m" is an invalid escape sequence in a
        # non-raw literal and triggers a SyntaxWarning on Python 3.12+.
        model_text = f"""
        \\multicolumn{{1}}{{c|}}{{{model_name}}} & {metrics_baseline.original_accuracy:.2f} ({metrics_optimized.original_accuracy:.2f}) & {metrics_baseline.semantic_accuracy:.2f} ({metrics_optimized.semantic_accuracy:.2f}) & \\multicolumn{{1}}{{c|}}{{{metrics_baseline.context_accuracy:.2f} ({metrics_optimized.context_accuracy:.2f})}} & {metrics_baseline.ori_sema:.2f} ({metrics_optimized.ori_sema:.2f}) & \\multicolumn{{1}}{{c|}}{{{metrics_baseline.ori_sema_cont:.2f} ({metrics_optimized.ori_sema_cont:.2f})}} & {metrics_baseline.overall_accuracy:.2f} ({metrics_optimized.overall_accuracy:.2f}) \\\\
        {postfix}
        """
        return model_text

    text_sp = []
    text_wp = []

    for model_name, baseline_report in results.items():
        optimized_report = optimized_results[model_name]

        for rows, puzzle_type in ((text_sp, "sentence"), (text_wp, "wordplay")):
            rows.append(
                print_line(
                    f"{model_name} (RAW)",
                    getattr(baseline_report.raw_results, puzzle_type),
                    getattr(optimized_report.raw_results, puzzle_type),
                )
            )
            # The POST row shows a ditto mark instead of repeating the model
            # name and closes the model's pair of rows with a rule.
            rows.append(
                print_line(
                    '" (POST)',
                    getattr(baseline_report.postprocessed_results, puzzle_type),
                    getattr(optimized_report.postprocessed_results, puzzle_type),
                    postfix="\\cline{1-7}",
                )
            )

    sp_data = "\n".join(text_sp)
    wp_data = "\n".join(text_wp)

    suffix = "'Ori & Sem' means 'Original & Semantic' while 'O & S & C' means 'Original & Semantic & Context'. 'Overall' is the average of 'Ori', 'Sem', and 'Con'. 'RAW' refers to unprocessed model responses, while 'POST' refers to postprocessed model responses."
    suffix = suffix.replace("&", "\\&")
    # Placeholders are filled with str.replace rather than str.format because
    # the template is full of literal LaTeX braces.
    full_table = (
        """
    \\begin{table}[]\r\n\\caption[{caption}]{{caption}: {suffix}}\r\n\\label{tab:{label}}\r\n\\resizebox{\\textwidth}{!}{%\r\n\\begin{tabular}{ccccccc}\r\n\\cline{1-7}\r\n & \\multicolumn{3}{|c|}{Instance-based} & \\multicolumn{2}{c|}{Group-based} \\\\ \\cline{1-6}\r\n\\multicolumn{1}{c|}{\\textbf{Model}} & \\textbf{Original} & \\textbf{Semantic} & \\multicolumn{1}{c|}{\\textbf{Context}} & \\textbf{Ori \\& Sem} & \\multicolumn{1}{c|}{\\textbf{O \\& S \\& C}} & \\textbf{Overall} \\\\ \\hline\r\n\\multicolumn{7}{c}{Sentence Puzzle} \\\\ \\hline\r\n\\multicolumn{1}{c|}{\\textbf{Random*}} & 25.52 & 24.88 & \\multicolumn{1}{c|}{22.81} & 5.58 & \\multicolumn{1}{c|}{1.44} & 24.40 \\\\ \\hline\r\n{{sp_data}}\r\n\\hline\r\n\\multicolumn{7}{c}{Word Puzzle} \\\\ \\hline\r\n\\multicolumn{1}{c|}{\\textbf{Random*}} & 26.02 & 27.85 & \\multicolumn{1}{c|}{22.51} & 7.32 & \\multicolumn{1}{c|}{1.83} & 25.34 \\\\ \\hline\r\n{{wp_data}}\r\n\\\\ \\hline\r\n\\end{tabular}%\r\n}\r\n\\end{table}
    """.replace("{{sp_data}}", sp_data)
        .replace("{{wp_data}}", wp_data)
        .replace("{caption}", caption.replace("&", "\\&"))
        .replace("{label}", label)
        .replace("{suffix}", suffix)
    )
    output_table = Path("tables") / f"{label}.tex"
    # exist_ok avoids the check-then-create race of a separate exists() test.
    output_table.parent.mkdir(parents=True, exist_ok=True)
    with open(output_table, "w") as f:
        f.write(full_table)


def results_to_texttable_hightliting(
    results: dict[str, EvaluationReport],
    optimized_results: dict[str, EvaluationReport],
    caption: str,
    label: str,
    postprocessed: bool = True,
    same_data: bool = False,
):
    """
    Write a LaTeX comparison table with the best value per column highlighted.

    One row per model is written to each puzzle section. Each cell is
    rendered as ``a (b)``, where ``a`` comes from ``results`` and ``b`` from
    ``optimized_results``. Within a puzzle section, values within 0.01 of the
    column's best are wrapped in ``\\textbf`` (``results``) or ``\\underline``
    (``optimized_results``). The output file is
    ``tables/<label>-raw.tex`` or ``tables/<label>-postprocessed.tex``.

    :param results: Reports keyed by model name; defines the listed models.
    :param optimized_results: Reports of the run shown in parentheses; must
        contain every model name found in ``results``.
    :param caption: Table caption; the mode is appended and ``&`` is escaped.
    :param label: Base for the LaTeX label and the file name.
    :param postprocessed: Use postprocessed metrics if True, raw otherwise.
    :param same_data: Currently unused; kept for interface compatibility.
    :return: None. The table is written to disk.
    :raises KeyError: If a model is missing from ``optimized_results``.
    """
    metric_columns = (
        "original_accuracy",
        "semantic_accuracy",
        "context_accuracy",
        "ori_sema",
        "ori_sema_cont",
        "overall_accuracy",
    )

    def select_metrics(report):
        # Pick the metric set that matches the requested mode.
        return report.postprocessed_results if postprocessed else report.raw_results

    def find_best_values(all_results, puzzle_type):
        # Highest value per metric column across all models of one run.
        best_values = dict.fromkeys(metric_columns, -float("inf"))
        for report in all_results.values():
            metrics_for_type = getattr(select_metrics(report), puzzle_type)
            for col in metric_columns:
                best_values[col] = max(
                    best_values[col], getattr(metrics_for_type, col)
                )
        return best_values

    # Best values are tracked separately for sentence and wordplay puzzles.
    puzzle_types = ("sentence", "wordplay")
    best_baseline = {pt: find_best_values(results, pt) for pt in puzzle_types}
    best_optimized = {
        pt: find_best_values(optimized_results, pt) for pt in puzzle_types
    }

    def format_value(value, best, is_baseline):
        # 0.01 tolerance: values are printed with two decimals anyway.
        if abs(value - best) < 0.01:
            wrapper = "textbf" if is_baseline else "underline"
            return f"\\{wrapper}{{{value:.2f}}}"
        return f"{value:.2f}"

    def print_line(
        model_name: str,
        metrics_baseline: Metrics,
        metrics_optimized: Metrics,
        puzzle_type="sentence",
    ):
        model_name = model_name.replace("_", "-").replace("-instruct-", "-")
        # Any value other than "sentence" selects the wordplay best values.
        best_key = "sentence" if puzzle_type == "sentence" else "wordplay"

        cells = []
        for col in metric_columns:
            base_text = format_value(
                getattr(metrics_baseline, col), best_baseline[best_key][col], True
            )
            opt_text = format_value(
                getattr(metrics_optimized, col), best_optimized[best_key][col], False
            )
            cells.append(f"{base_text} ({opt_text})")
        original, semantic, context, ori_sem, ori_sem_ctx, overall = cells

        # Backslashes are doubled: "\m" is an invalid escape sequence in a
        # non-raw literal and triggers a SyntaxWarning on Python 3.12+.
        model_text = f"""
        \\multicolumn{{1}}{{c|}}{{{model_name}}} & {original} & {semantic} & \\multicolumn{{1}}{{c|}}{{{context}}} & {ori_sem} & \\multicolumn{{1}}{{c|}}{{{ori_sem_ctx}}} & {overall} \\\\
        """
        return model_text

    text_sp = []
    text_wp = []

    for model_name, baseline_report in results.items():
        model_baseline = select_metrics(baseline_report)
        model_optimized = select_metrics(optimized_results[model_name])
        text_sp.append(
            print_line(
                model_name,
                model_baseline.sentence,
                model_optimized.sentence,
                puzzle_type="sentence",
            )
        )
        text_wp.append(
            print_line(
                model_name,
                model_baseline.wordplay,
                model_optimized.wordplay,
                puzzle_type="wordplay",
            )
        )

    sp_data = "\n".join(text_sp)
    wp_data = "\n".join(text_wp)
    label = f"{label}-{'postprocessed' if postprocessed else 'raw'}"
    suffix = "'Ori & Sem' means 'Original & Semantic' while 'O & S & C' means 'Original & Semantic & Context'. 'Overall' is the average of 'Ori', 'Sem', and 'Con'."
    if not postprocessed:
        suffix = (
            suffix
            + " 'RAW' means that the accuracy values were calculated from raw model responses."
        )
    else:
        suffix = (
            suffix
            + " 'Postprocessed' means that the accuracy values were calculated from postprocessed model responses."
        )
    suffix = suffix.replace("&", "\\&")
    # Placeholders are filled with str.replace rather than str.format because
    # the template is full of literal LaTeX braces.
    full_table = (
        """
    \\begin{table}[hbp]\r\n\\caption[{caption}]{{caption}: {suffix}}\r\n\\label{tab:{label}}\r\n\\resizebox{\\textwidth}{!}{%\r\n\\begin{tabular}{ccccccc}\r\n\\cline{1-7}\r\n & \\multicolumn{3}{|c|}{Instance-based} & \\multicolumn{2}{c|}{Group-based} \\\\ \\cline{1-6}\r\n\\multicolumn{1}{c|}{\\textbf{Model}} & \\textbf{Original} & \\textbf{Semantic} & \\multicolumn{1}{c|}{\\textbf{Context}} & \\textbf{Ori \\& Sem} & \\multicolumn{1}{c|}{\\textbf{O \\& S \\& C}} & \\textbf{Overall} \\\\ \\hline\r\n\\multicolumn{7}{c}{Sentence Puzzle} \\\\ \\hline\r\n\\multicolumn{1}{c|}{\\textbf{Random*}} & 25.52 & 24.88 & \\multicolumn{1}{c|}{22.81} & 5.58 & \\multicolumn{1}{c|}{1.44} & 24.40 \\\\ \\hline\r\n{{sp_data}}\r\n\\hline\r\n\\multicolumn{7}{c}{Word Puzzle} \\\\ \\hline\r\n\\multicolumn{1}{c|}{\\textbf{Random*}} & 26.02 & 27.85 & \\multicolumn{1}{c|}{22.51} & 7.32 & \\multicolumn{1}{c|}{1.83} & 25.34 \\\\ \\hline\r\n{{wp_data}}\r\n\\\\ \\hline\r\n\\end{tabular}%\r\n}\r\n\\end{table}
    """.replace("{{sp_data}}", sp_data)
        .replace("{{wp_data}}", wp_data)
        .replace(
            "{caption}",
            f"{caption} ({'Postprocessed' if postprocessed else 'Raw'})".replace(
                "&", "\\&"
            ),
        )
        .replace("{suffix}", suffix)
        .replace("{label}", label)
    )
    output_table = Path("tables") / f"{label}.tex"
    # exist_ok avoids the check-then-create race of a separate exists() test.
    output_table.parent.mkdir(parents=True, exist_ok=True)
    with open(output_table, "w") as f:
        f.write(full_table)

In [88]:
# Load the pickled results of the four evaluation runs and build one
# EvaluationReport per model for each of them.
baseline_zero_shot = create_evaluation_from_results(
    "results/baseline-zero-shot-evaluation/baseline-zero-shot-evaluation_results.pkl",
)
optimized_zero_shot = create_evaluation_from_results(
    "results/system-optimized-zero-shot-evaluation/system-optimized-zero-shot-evaluation_results.pkl",
)

baseline_few_shot = create_evaluation_from_results(
    "results/baseline-few-shot-evaluation/baseline-few-shot-evaluation_results.pkl",
)
optimized_few_shot = create_evaluation_from_results(
    "results/system-optimized-few-shot-evaluation/system-optimized-few-shot-evaluation_results.pkl",
)


# Full tables (RAW and POST row per model). In every cell the value of the
# first run is followed by the value of the second run in parentheses.
results_to_texttable(
    baseline_zero_shot,
    optimized_zero_shot,
    "Zero-Shot: Default vs. Optimized System Prompt",
    "full-zero-shot",
)

results_to_texttable(
    baseline_few_shot,
    optimized_few_shot,
    "Few-Shot: Default vs. Optimized System Prompt",
    "full-few-shot",
)

results_to_texttable(
    optimized_zero_shot,
    optimized_few_shot,
    "Optimized Zero-Shot vs. Optimized Few-Shot",
    "optimized-zero-shot-vs-optimized-few-shot",
)

# Highlighted tables built from raw metrics only (postprocessed=False); the
# output file names get a "-raw" suffix.
results_to_texttable_hightliting(
    baseline_zero_shot,
    optimized_zero_shot,
    "Zero-Shot: Default vs. Optimized System Prompt",
    "baseline-vs-optimized-zero-shot",
    postprocessed=False,
)
results_to_texttable_hightliting(
    baseline_zero_shot,
    baseline_few_shot,
    "Default System Prompt: Zero-Shot vs. Few-Shot",
    "baseline-zero-shot-vs-baseline-few-shot",
    postprocessed=False,
)

results_to_texttable_hightliting(
    baseline_few_shot,
    optimized_few_shot,
    "Few-Shot: Default vs. Optimized System Prompt",
    "baseline-vs-optimized-few-shot",
    postprocessed=False,
)

In [89]:
# results_to_texttable_hightliting(
#     baseline_few_shot,
#     optimized_few_shot,
#     "Baseline vs. Optimized Few-Shot",
#     "baseline-vs-optimized-few-shot",
#     postprocessed=False,
# )
# results_to_texttable_hightliting(
#     baseline_few_shot,
#     optimized_few_shot,
#     "Baseline vs. Optimized Few-Shot",
#     "baseline-vs-optimized-few-shot",
#     postprocessed=True,
# )

In [90]:
import pandas as pd


def results_to_csv(baseline_results, optimized_results, filename, postprocessed=False):
    """
    Export baseline, optimized and diff values of every metric to a CSV file.

    The file is ``csvs/<filename>-raw.csv`` or
    ``csvs/<filename>-postprocessed.csv``, uses ``;`` as separator and two
    decimals for floats. All sentence-puzzle rows (one per model) come first,
    followed by all word-puzzle rows. ``diff`` is optimized minus baseline.

    Args:
        baseline_results: Reports keyed by model name; defines the models.
        optimized_results: Reports to compare against; must contain every
            model name found in ``baseline_results``.
        filename: Base filename for the CSV.
        postprocessed: Whether to use postprocessed results (True) or raw
            results (False).

    Raises:
        KeyError: If a model is missing from ``optimized_results``.
    """
    metric_names = (
        "original_accuracy",
        "semantic_accuracy",
        "context_accuracy",
        "ori_sema",
        "ori_sema_cont",
        "overall_accuracy",
    )
    columns = ["model_name"]
    for metric in metric_names:
        columns += [f"{metric}_baseline", f"{metric}_optimized", f"{metric}_diff"]

    def select_metrics(report):
        # Pick the metric set that matches the requested mode.
        return report.postprocessed_results if postprocessed else report.raw_results

    def build_row(model_name, baseline, optimized):
        # One CSV row for a single puzzle category of a single model.
        row = {"model_name": model_name}
        for metric in metric_names:
            baseline_value = getattr(baseline, metric)
            optimized_value = getattr(optimized, metric)
            row[f"{metric}_baseline"] = baseline_value
            row[f"{metric}_optimized"] = optimized_value
            row[f"{metric}_diff"] = optimized_value - baseline_value
        return row

    rows_sp = []
    rows_wp = []
    for model_name, report in baseline_results.items():
        baseline_model = select_metrics(report)
        optimized_model = select_metrics(optimized_results[model_name])
        rows_sp.append(
            build_row(model_name, baseline_model.sentence, optimized_model.sentence)
        )
        rows_wp.append(
            build_row(model_name, baseline_model.wordplay, optimized_model.wordplay)
        )

    combined_data = pd.DataFrame(rows_sp + rows_wp, columns=columns)

    output_csv = Path(
        f"csvs/{filename}-{'postprocessed' if postprocessed else 'raw'}.csv"
    )
    # to_csv does not create missing directories.
    output_csv.parent.mkdir(parents=True, exist_ok=True)
    combined_data.to_csv(
        output_csv,
        index=False,
        float_format="%.2f",
        sep=";",
    )


def results_to_diffs_csv(baseline_results, optimized_results, filename):
    """
    Export metric differences (optimized minus baseline) to a CSV file.

    The file is ``csvs/<filename>-diffs.csv``, uses ``;`` as separator and
    two decimals for floats. Each row holds the six raw diffs followed by the
    six postprocessed diffs. All sentence-puzzle rows (one per model) come
    first, followed by all word-puzzle rows.

    Args:
        baseline_results: Reports keyed by model name; defines the models.
        optimized_results: Reports to compare against; must contain every
            model name found in ``baseline_results``.
        filename: Base filename for the CSV.

    Raises:
        KeyError: If a model is missing from ``optimized_results``.
    """
    metric_names = (
        "original_accuracy",
        "semantic_accuracy",
        "context_accuracy",
        "ori_sema",
        "ori_sema_cont",
        "overall_accuracy",
    )
    modes = (
        ("raw", "raw_results"),
        ("postprocessed", "postprocessed_results"),
    )
    columns = ["model_name"] + [
        f"{metric}_diff_{mode}" for mode, _ in modes for metric in metric_names
    ]

    def build_row(model_name, baseline_report, optimized_report, puzzle_type):
        # One CSV row for a single puzzle category of a single model.
        row = {"model_name": model_name}
        for mode, report_attr in modes:
            baseline = getattr(getattr(baseline_report, report_attr), puzzle_type)
            optimized = getattr(getattr(optimized_report, report_attr), puzzle_type)
            for metric in metric_names:
                row[f"{metric}_diff_{mode}"] = getattr(optimized, metric) - getattr(
                    baseline, metric
                )
        return row

    rows_sp = []
    rows_wp = []
    for model_name, baseline_report in baseline_results.items():
        optimized_report = optimized_results[model_name]
        rows_sp.append(
            build_row(model_name, baseline_report, optimized_report, "sentence")
        )
        rows_wp.append(
            build_row(model_name, baseline_report, optimized_report, "wordplay")
        )

    combined_data = pd.DataFrame(rows_sp + rows_wp, columns=columns)

    output_csv = Path(f"csvs/{filename}-diffs.csv")
    # to_csv does not create missing directories.
    output_csv.parent.mkdir(parents=True, exist_ok=True)
    combined_data.to_csv(
        output_csv,
        index=False,
        float_format="%.2f",
        sep=";",
    )


# Diff CSVs: every value is computed as second argument minus first argument,
# for both raw and postprocessed metrics.
results_to_diffs_csv(
    baseline_zero_shot,
    optimized_zero_shot,
    "baseline-vs-optimized-zero-shot",
)

results_to_diffs_csv(
    baseline_zero_shot,
    baseline_few_shot,
    "baseline-zero-shot-vs-baseline-few-shot",
)

# Same pair as the first call with the arguments swapped, so the sign of
# every diff is inverted.
results_to_diffs_csv(
    optimized_zero_shot,
    baseline_zero_shot,
    "optimized-zero-shot-vs-baseline-zero-shot",
)

results_to_diffs_csv(
    optimized_zero_shot,
    optimized_few_shot,
    "optimized-zero-shot-vs-optimized-few-shot",
)


results_to_diffs_csv(
    baseline_few_shot,
    optimized_few_shot,
    "baseline-vs-optimized-few-shot",
)


# Full CSV (baseline, optimized and diff columns) from raw metrics; here the
# "baseline" columns hold zero-shot values and the "optimized" columns hold
# few-shot values.
results_to_csv(
    baseline_zero_shot,
    baseline_few_shot,
    "full-zero-shot-vs-full-few-shot",
    postprocessed=False,
)