## Import libraries

In [None]:
import os
import jsonlines

import numpy as np
from scipy.stats import spearmanr, pearsonr

In [None]:
# Root of the experiment outputs, relative to the notebook's working directory.
# Per-run files are looked up under <result_dir>/<dataset>/<model>/.
result_dir = "/".join([".."] * 3 + ["results"])

## Correlation between model confidence and consistency

In [None]:
# Output folder for the LaTeX table file.
os.makedirs("correlation_analysis_table", exist_ok=True)

dataset_names = ["CommonsenseQA", "QASC", "100TFQA", "GSM8K", "MMLU-Pro-Law-100Q"]
model_names = ["Llama-3.1-8B-Instruct", "gpt-4o-2024-11-20"]
# File-name prefixes of the per-strategy prediction files.
prompting_strategies = ["zero-shot", "zero-shot-cot", "few-shot", "few-shot-cot"]
# Number of prompt formats per example; an example is used only when a
# confidence was recovered for every one of them.
num_formats = NUM_FORMATS = 8


def _format_confidence(example, format_id, id_predictions_map):
    """Return the probability of one format's predicted answer token, or None.

    Scans the generated positions in ``example["top_tokens"][format_id]`` from
    the last one backwards and stops at the first position whose first
    candidate (presumably the top-1 token -- verify against the generator)
    equals the stored prediction for this format. The probability is read from
    ``example["top_probs"]`` when that key exists, otherwise it is
    ``exp(example["top_logprobs"][...])``. Returns None if no position matches.
    """
    top_tokens = example["top_tokens"][format_id]
    for offset in range(1, len(top_tokens) + 1):
        if top_tokens[-offset][0] == id_predictions_map[example["id"]][format_id]:
            if "top_probs" in example:
                return example["top_probs"][format_id][-offset][0]
            return np.exp(example["top_logprobs"][format_id][-offset][0])
    return None


def _collect_confidence_consistency(predictions_path, raw_predictions_path, num_formats, setwise=True):
    """Pair each example's mean confidence with its consistency score.

    Args:
        predictions_path: JSONL with per-example ``id``, ``predictions`` and
            ``consistency["mean"]``.
        raw_predictions_path: JSONL with per-example ``id``, ``top_tokens`` and
            ``top_probs`` or ``top_logprobs``, keyed by format id.
        num_formats: examples with fewer recovered confidences are skipped.
        setwise: if True (default), Y is 1.0 when the mean consistency is
            >= 0.99 and 0.0 otherwise; if False, Y is the raw mean consistency.

    Returns:
        ``(X, Y)`` numpy arrays of equal length: mean confidence over formats
        and the consistency score.

    Raises:
        Whatever ``jsonlines`` raises for a missing/unreadable file, and
        ``KeyError`` if a raw example's id or field is missing.
    """
    id_predictions_map, id_consistency_map = {}, {}
    with jsonlines.open(predictions_path) as fin:
        for example in fin.iter():
            id_predictions_map[example["id"]] = example["predictions"]
            id_consistency_map[example["id"]] = example["consistency"]["mean"]

    X, Y = [], []
    with jsonlines.open(raw_predictions_path) as fin:
        for example in fin.iter():
            confidences = []
            for format_id in example["top_tokens"]:
                confidence = _format_confidence(example, format_id, id_predictions_map)
                if confidence is not None:
                    confidences.append(confidence)
            # Skip examples where some format's prediction was not found among
            # the generated tokens: no per-format confidence can be assigned.
            if len(confidences) != num_formats:
                continue
            consistency = id_consistency_map[example["id"]]
            X.append(np.mean(confidences))
            Y.append(float(consistency >= 0.99) if setwise else consistency)
    return np.array(X), np.array(Y)


# -2 marks (model, dataset, strategy) cells that could not be computed; the
# table renders them as "-".
matrix_shape = (len(model_names), len(dataset_names), len(prompting_strategies))
corr_matrix = np.full(matrix_shape, -2.0)
p_matrix = np.full(matrix_shape, -2.0)

for i, model_name in enumerate(model_names):
    for j, dataset_name in enumerate(dataset_names):
        output_dir = f"{result_dir}/{dataset_name}/{model_name}"
        for k, prompting_strategy in enumerate(prompting_strategies):
            predictions_path = os.path.join(output_dir, f"{prompting_strategy}_predictions.jsonl")
            raw_predictions_path = os.path.join(output_dir, f"{prompting_strategy}_raw_predictions.jsonl")
            try:
                X, Y = _collect_confidence_consistency(predictions_path, raw_predictions_path, num_formats)
            except Exception as e:
                # Best effort: not every (model, dataset, strategy) run exists.
                print(f"skipped {model_name} / {dataset_name} / {prompting_strategy}: {e!r}")
                continue

            # A rank correlation needs at least two paired observations;
            # spearmanr raises ValueError otherwise.
            if len(X) < 2:
                print(f"skipped {model_name} / {dataset_name} / {prompting_strategy}: only {len(X)} usable examples")
                continue

            # Spearman rank correlation between mean confidence and set-wise
            # consistency (NaN if either side is constant).
            spearman_corr, spearman_p = spearmanr(X, Y)
            corr_matrix[i, j, k] = spearman_corr
            p_matrix[i, j, k] = spearman_p

# Human-readable strategy names for the table header, in the same order as the
# strategy axis of corr_matrix / p_matrix. Note: this rebinds the name that the
# correlation cell uses for file-name prefixes.
prompting_strategies = ["Zero-shot", "Zero-shot CoT", "Few-shot", "Few-shot CoT"]
# Set to True to also write the table to correlation_analysis_table/.
save_latex_table = False


def _latex_cell(corr, p):
    """Return one "& <rho> (<p-value>) " table cell, or "& - " for a missing value.

    Values below -1.5 are the -2 "not computed" sentinel (a correlation is never
    below -1); NaN is what spearmanr returns for constant input.
    """
    if not np.isfinite(corr) or corr < -1.5:
        return "& - "
    # Scientific notation only for very small p-values.
    p_str = f"{p:.3e}" if p < 0.001 else f"{p:.3f}"
    return f"& {corr:.2f} ({p_str}) "


latex_lines = [
    "\\begin{table}[ht]",
    "\\centering",
    "\\caption{Spearman correlation between model confidence (mean over prompt formats) and set-wise consistency. "
    "Each cell shows the Spearman correlation coefficient and its p-value (in parentheses) for each model, task, "
    "and prompting strategy; a dash marks unavailable results.}",
    "\\label{tab:confidence_correlation}",
    # Two label columns (model, task) plus one column per prompting strategy.
    "\\begin{tabular}{c|c|" + "c" * len(prompting_strategies) + "}",
    "\\toprule",
    "Model & Task & " + " & ".join(prompting_strategies) + " \\\\",
    "\\midrule\\midrule",
]

# One group of rows per model; the model name spans all of its task rows.
model_groups = []
for model_idx, model in enumerate(model_names):
    group = [f"\\multirow{{{len(dataset_names)}}}{{*}}{{{model}}}"]
    for task_idx, task in enumerate(dataset_names):
        cells = "".join(
            _latex_cell(corr_matrix[model_idx, task_idx, strat_idx], p_matrix[model_idx, task_idx, strat_idx])
            for strat_idx in range(len(prompting_strategies))
        )
        group.append(f"      & {task} {cells}\\\\")
    model_groups.append("\n".join(group))

# Separate groups with \midrule via join, so no trailing separator ever needs
# to be stripped off the end of the table body.
latex_lines.append("\n\\midrule\n".join(model_groups))
latex_lines += ["\\bottomrule", "\\end{tabular}", "\\end{table}"]
latex_code = "\n".join(latex_lines)

# Display the generated LaTeX table code
print(latex_code)

if save_latex_table:
    os.makedirs("correlation_analysis_table", exist_ok=True)
    with open("correlation_analysis_table/correlation_with_consistency_table.txt", "w", encoding="utf-8") as fout:
        fout.write(latex_code)