In [38]:
import pickle as dill
from pathlib import Path

from scripts.executor import WrappedResults


def load_wrapped_results(path: Path | str) -> WrappedResults:
    """
    Unpickle a results object from disk.

    :param path: Location of the pickle file, as ``str`` or ``Path``.
    :return: The unpickled object (expected to be a WrappedResults).
    :raises FileNotFoundError: If nothing exists at ``path``.
    """
    results_path = Path(path)
    if not results_path.exists():
        raise FileNotFoundError(f"File not found: {results_path}")

    # `dill` is the stdlib pickle module imported under that alias. Unpickling can
    # execute arbitrary code, so only load result files from a trusted source.
    with results_path.open("rb") as handle:
        loaded = dill.load(handle)
    return loaded

In [39]:
from enum import Enum

import numpy as np

from scripts.executor import ExecutionResult
from scripts.evaluation import is_model_response_correct


class AnswerExtractorType(Enum):
    """
    Selects which correctness flag from ``is_model_response_correct`` gets recorded.

    ``RAW`` selects the first element of the returned pair, and ``POSTPROCESSED``
    selects the second (see ``getResultdata``).
    """

    RAW = "raw"
    POSTPROCESSED = "postprocessed"


# Modified version from https://github.com/1171-jpg/BrainTeaser/blob/main/Prompting/utils.py
# to match the format of the results
def getResultdata(
    execution_results: list[ExecutionResult], extractor_type: AnswerExtractorType
) -> tuple[dict, dict]:
    """
    Group per-riddle correctness into three-slot lists.

    Riddle ids are split as ``<TYPE>-<ID>[_<suffix>]``. Ids whose type is ``"WP"``
    go into the word-play dict; every other type goes into the "reverse"
    (sentence) dict. The slot is chosen by a substring test on the whole id:
    ``"SR"`` -> 1, else ``"CR"`` -> 2, else 0. ``getMetric`` reads slots
    0/1/2 as original/semantic/context. A slot with no matching result keeps
    its initial value 0, which ``getMetric`` counts as incorrect.

    :param execution_results: Results to group; each must expose ``riddle.id``.
        Any iterable works, because it is traversed only once.
    :param extractor_type: ``RAW`` records the first flag returned by
        ``is_model_response_correct``; any other value records the second
        (postprocessed) flag.
    :return: ``(word_play, reverse_play)``. Each maps the ``<ID>`` part of the
        riddle id to a 3-element list.
    """
    use_raw = extractor_type == AnswerExtractorType.RAW
    word_play: dict[str, list] = {}
    reverse_play: dict[str, list] = {}

    for execution_result in execution_results:
        riddle_id = execution_result.riddle.id
        id_parts = riddle_id.split("-")
        item_type = id_parts[0]
        item_id = id_parts[1].split("_")[0]

        # The substring test runs on the full id; "SR" takes precedence over "CR".
        if "SR" in riddle_id:
            ad_type = 1
        elif "CR" in riddle_id:
            ad_type = 2
        else:
            ad_type = 0

        raw_correct, postprocessed_correct = is_model_response_correct(execution_result)
        table = word_play if item_type == "WP" else reverse_play
        # setdefault creates the [0, 0, 0] row on first sight, so keys keep
        # first-appearance order and no separate initialisation pass is needed.
        table.setdefault(item_id, [0, 0, 0])[ad_type] = (
            raw_correct if use_raw else postprocessed_correct
        )

    return word_play, reverse_play


def getMetric(data_list):
    """
    Compute accuracy metrics over riddle groups.

    Each row of ``data_list`` is ``[original, semantic, context]``. An entry equal
    to 1 (or ``True``) marks a correct answer.

    :param data_list: Sequence of 3-element rows, one per riddle group.
    :return: Dict of percentages rounded to 2 decimals:
        ``over_all_accuracy`` (mean over every entry), the per-column
        ``original_accuracy`` / ``semantic_accuracy`` / ``context_accuracy``,
        ``ori_sema`` (original and semantic both correct) and ``ori_sema_cont``
        (all three correct). If ``data_list`` is empty, every value is NaN.
    """
    metric_names = (
        "over_all_accuracy",
        "original_accuracy",
        "semantic_accuracy",
        "context_accuracy",
        "ori_sema",
        "ori_sema_cont",
    )
    num_variants = 3  # original, semantic, context

    data = np.array(data_list)
    num_groups = len(data)
    if num_groups == 0:
        # Without this guard, summing an empty 1-D array over axis 0 yields a
        # scalar that cannot be indexed, and every ratio would divide by zero.
        return {name: np.nan for name in metric_names}

    column_sums = np.sum(data, axis=0)
    is_correct = data == 1
    ori_sema_correct = is_correct[:, 0] & is_correct[:, 1]
    all_correct = ori_sema_correct & is_correct[:, 2]

    scores = (
        np.sum(data) / num_variants / num_groups,
        column_sums[0] / num_groups,
        column_sums[1] / num_groups,
        column_sums[2] / num_groups,
        np.count_nonzero(ori_sema_correct) / num_groups,
        np.count_nonzero(all_correct) / num_groups,
    )
    return {name: np.round(score * 100, 2) for name, score in zip(metric_names, scores)}


def getSeperateResult(word_play, reverse_thinking):
    """
    Score word puzzles, sentence puzzles and their union separately.

    :param word_play: Dict whose values are the word-puzzle rows passed to ``getMetric``.
    :param reverse_thinking: Dict whose values are the sentence-puzzle rows.
    :return: ``{"wordplay": ..., "sentence": ..., "all": ...}``, each a
        ``getMetric`` result. ``"all"`` uses word rows followed by sentence rows.
    """
    word_rows = list(word_play.values())
    sentence_rows = list(reverse_thinking.values())
    return {
        "wordplay": getMetric(word_rows),
        "sentence": getMetric(sentence_rows),
        "all": getMetric(word_rows + sentence_rows),
    }


def getModelEvaluations(model_results: list[ExecutionResult]):
    """
    Evaluate one model's results under both answer extractors.

    :param model_results: All execution results of a single model.
    :return: ``(raw_metrics, postprocessed_metrics)``, each as produced by
        ``getSeperateResult`` (keys ``"wordplay"``, ``"sentence"``, ``"all"``).
    """
    extractors = (AnswerExtractorType.RAW, AnswerExtractorType.POSTPROCESSED)
    # All grouping happens before any scoring.
    grouped = [getResultdata(model_results, extractor) for extractor in extractors]
    raw_metrics, postprocessed_metrics = (
        getSeperateResult(word_play, sentence_play)
        for word_play, sentence_play in grouped
    )
    return raw_metrics, postprocessed_metrics


def results_to_texttable(
    wrapped_results: WrappedResults,
    caption: str,
    label: str,
):
    """
    Convert WrappedResults to a LaTeX table and write it to ``tables/<label>.tex``.

    Execution results are merged per model name across every entry of
    ``wrapped_results.results``. Each model gets one row in the "Sentence Puzzle"
    section and one row in the "Word Puzzle" section. Every metric cell reads
    ``raw (postprocessed)``.

    :param wrapped_results: WrappedResults object. Its ``results`` is expected to map
        some key to a ``{model_name: list[ExecutionResult]}`` dict.
    :param caption: Table caption, inserted verbatim (not LaTeX-escaped).
    :param label: Table label (rendered as ``tab:<label>``) and output file stem.
    :return: The LaTeX source that was written to disk.
    """
    combined_results: dict[str, list[ExecutionResult]] = {}
    for models_results in wrapped_results.results.values():
        for model_name, results in models_results.items():
            combined_results.setdefault(model_name, []).extend(results)

    def format_row(model_name: str, key: str, evaluation_raw, evaluation_post) -> str:
        """Render one table row for the puzzle category ``key``."""
        # "_" is a LaTeX special character, so it is replaced in model names.
        display_name = model_name.replace("_", "-").replace("-instruct-", "-")

        def cell(metric: str) -> str:
            return f"{evaluation_raw[key][metric]} ({evaluation_post[key][metric]})"

        def ruled(text: str) -> str:
            # Single cell followed by a vertical rule.
            return r"\multicolumn{1}{c|}{" + text + "}"

        cells = [
            ruled(display_name),
            cell("original_accuracy"),
            cell("semantic_accuracy"),
            ruled(cell("context_accuracy")),
            cell("ori_sema"),
            ruled(cell("ori_sema_cont")),
            cell("over_all_accuracy"),
        ]
        return " & ".join(cells) + r" \\"

    sentence_rows: list[str] = []
    word_rows: list[str] = []
    for model_name, model_execution_results in combined_results.items():
        evaluation_raw, evaluation_post = getModelEvaluations(model_execution_results)
        sentence_rows.append(
            format_row(model_name, "sentence", evaluation_raw, evaluation_post)
        )
        word_rows.append(
            format_row(model_name, "wordplay", evaluation_raw, evaluation_post)
        )

    # Raw strings keep backslashes as literal LaTeX. "{caption}", "{label}",
    # "{{sp_data}}" and "{{wp_data}}" are placeholders filled via str.replace below.
    # The "Random*" rows contain hard-coded reference numbers.
    template = "\n".join(
        [
            r"\begin{table}[]",
            r"\caption{{caption}}",
            r"\label{tab:{label}}",
            r"\resizebox{\textwidth}{!}{%",
            r"\begin{tabular}{ccccccc}",
            r"\cline{1-7}",
            r" & \multicolumn{3}{|c|}{Instance-based} & \multicolumn{2}{c|}{Group-based} \\ \cline{1-6}",
            r"\multicolumn{1}{c|}{\textbf{Model}} & \textbf{Original} & \textbf{Semantic} & \multicolumn{1}{c|}{\textbf{Context}} & \textbf{Ori \& Sem} & \multicolumn{1}{c|}{\textbf{Ori \& Sem \& Con}} & \textbf{Overall} \\ \hline",
            r"\multicolumn{7}{c}{Sentence Puzzle} \\ \hline",
            r"\multicolumn{1}{c|}{\textbf{Random*}} & 25.52 & 24.88 & \multicolumn{1}{c|}{22.81} & 5.58 & \multicolumn{1}{c|}{1.44} & 24.40 \\ \hline",
            r"{{sp_data}}",
            r"\hline",
            r"\multicolumn{7}{c}{Word Puzzle} \\ \hline",
            r"\multicolumn{1}{c|}{\textbf{Random*}} & 26.02 & 27.85 & \multicolumn{1}{c|}{22.51} & 7.32 & \multicolumn{1}{c|}{1.83} & 25.34 \\ \hline",
            r"{{wp_data}}",
            # Data rows already end in "\\"; no extra line break here (it would add an empty row).
            r"\hline",
            r"\end{tabular}%",
            r"}",
            r"\end{table}",
        ]
    ) + "\n"

    full_table = (
        template.replace("{caption}", caption)
        .replace("{label}", label)
        .replace("{{sp_data}}", "\n".join(sentence_rows))
        .replace("{{wp_data}}", "\n".join(word_rows))
    )

    output_table = Path("tables") / f"{label}.tex"
    output_table.parent.mkdir(parents=True, exist_ok=True)
    output_table.write_text(full_table, encoding="utf-8")
    return full_table


def create_latex_table_from_results_file(
    results_file: Path | str, caption: str, label: str
):
    """
    Load a pickled results file and hand it to ``results_to_texttable``.

    :param results_file: Path to the pickle file, as ``str`` or ``Path``
        (forwarded to ``load_wrapped_results``).
    :param caption: Table caption.
    :param label: Table label, also used as the output file stem.
    :raises FileNotFoundError: If ``results_file`` does not exist.
    """
    wrapped_results = load_wrapped_results(results_file)
    results_to_texttable(wrapped_results, caption, label)

In [None]:
# Baseline zero-shot results -> tables/baseline-zero-shot-evaluation.tex
create_latex_table_from_results_file(
    results_file="results/baseline-zero-shot-evaluation/baseline-zero-shot-evaluation_results.pkl",
    caption="Zero-Shot Accuracy (Baseline)",
    label="baseline-zero-shot-evaluation",
)

In [None]:
# Prompt-engineered zero-shot results -> tables/optimized-zero-shot-evaluation.tex
create_latex_table_from_results_file(
    results_file="results/system-optimized-zero-shot-evaluation/system-optimized-zero-shot-evaluation_results.pkl",
    caption="Zero-Shot Accuracy (with Prompt Engineering)",
    label="optimized-zero-shot-evaluation",
)