In [None]:
import glob
import re

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm

In [None]:
# Maps (num_true, num_false) -> DataFrame read from the matching result file.
subset_dataframes = {}

subset_pattern = "./data/result_*t_vs_*f_3attempts_llama.tsv"
# Captures the complete (possibly multi-digit) counts from the file name part
# of a path such as "./data/result_10t_vs_2f_3attempts_llama.tsv".
SUBSET_NAME_RE = re.compile(r"result_(\d+)t_vs_(\d+)f_[^/\\]*$")


def parse_example_counts(file_path):
    """Return ``(num_true, num_false)`` as encoded in a result file name.

    Parameters
    ----------
    file_path : str
        Path whose file name looks like ``result_<T>t_vs_<F>f_...``.

    Returns
    -------
    tuple[int, int]
        The integers ``<T>`` and ``<F>``; all digits are read, so
        ``result_10t_vs_2f_...`` yields ``(10, 2)`` rather than ``(1, 2)``.

    Raises
    ------
    ValueError
        If the file name does not contain both counts in the expected form.
    """
    match = SUBSET_NAME_RE.search(file_path)
    if match is None:
        raise ValueError(
            f"Cannot read num_true/num_false from file name: {file_path!r}"
        )
    return int(match.group(1)), int(match.group(2))


# Get a list of files for the current subset pattern; sorted because glob
# returns matches in arbitrary order and a stable load order is easier to debug.
subset_files = sorted(glob.glob(subset_pattern))
for file_path in subset_files:
    num_true, num_false = parse_example_counts(file_path)

    # The TSV files use a comma as decimal separator.
    df = pd.read_csv(file_path, sep='\t', decimal=",", header=0)
    subset_dataframes[(num_true, num_false)] = df

In [None]:
# Sanity check: preview the first rows of the 1-true / 1-false result table
# instead of rendering the whole frame.
subset_dataframes[(1, 1)].head()

In [None]:
colors_accuracy = "Blues"
colors_f1 = "YlOrBr"
valuefont = {'fontname': 'Libertinus Serif', 'fontweight': 'heavy'}
captionfont = {'fontname': 'Libertinus Serif'}

# (column name in the result tables, matplotlib colormap used for its heatmap)
metrics = [("accuracy", colors_accuracy), ("f1", colors_f1)]

# Cells whose mean is below this value get dark text, all others white text.
DARK_TEXT_BELOW = 0.4

if not subset_dataframes:
    raise ValueError("subset_dataframes is empty; there is nothing to plot")

# Axis categories are the same for every metric, so compute them once.
true_counts = sorted({num_true for num_true, _ in subset_dataframes})
false_counts = sorted({num_false for _, num_false in subset_dataframes})
row_of = {count: row for row, count in enumerate(true_counts)}
col_of = {count: col for col, count in enumerate(false_counts)}


def average_metric_grid(metric):
    """Average one metric column per (num_true, num_false) combination.

    Parameters
    ----------
    metric : str
        Name of the column to average in every frame of ``subset_dataframes``.

    Returns
    -------
    tuple[numpy.ndarray, numpy.ndarray]
        ``(grid, has_data)``, both of shape
        ``(len(true_counts), len(false_counts))``. ``grid[r, c]`` is the
        NaN-ignoring mean for ``(true_counts[r], false_counts[c])``;
        ``has_data[r, c]`` is False where no result table exists, in which
        case ``grid[r, c]`` stays 0.0 so the cell is drawn in the colormap's
        lowest color.
    """
    grid = np.zeros((len(true_counts), len(false_counts)))
    has_data = np.zeros(grid.shape, dtype=bool)
    # Each mean is written to the cell addressed by its own key, so a missing
    # combination can never shift the remaining values into wrong cells.
    for (num_true, num_false), frame in tqdm(subset_dataframes.items()):
        row, col = row_of[num_true], col_of[num_false]
        grid[row, col] = np.nanmean(frame[metric].values)
        has_data[row, col] = True
    return grid, has_data


# One heatmap figure (and one PDF) per metric.
for metric, cmap in metrics:
    average_matrix, has_data = average_metric_grid(metric)

    fig, ax = plt.subplots(1, 1, figsize=(5, 5))
    ax.ticklabel_format(useLocale=True)

    ax.imshow(average_matrix, cmap=cmap, vmin=0)
    # Put the smallest num_true at the bottom of the plot.
    ax.invert_yaxis()

    # Annotate every cell: a dash where no data exists, else the mean in percent.
    for m, n in np.ndindex(average_matrix.shape):
        if not has_data[m, n]:
            ax.text(n, m, "—", ha='center', va='center', color='xkcd:almost black', **valuefont)
        else:
            value = average_matrix[m, n]
            text_color = 'xkcd:almost black' if value < DARK_TEXT_BELOW else 'white'
            ax.text(n, m, f"{value:.2%}", ha='center', va='center', color=text_color, **valuefont)

    ax.set_xticks(range(len(false_counts)))
    ax.set_xticklabels(false_counts, **captionfont)

    ax.set_yticks(range(len(true_counts)))
    ax.set_yticklabels(true_counts, **captionfont)
    # German axis captions: "incorrect examples" (x) and "correct examples" (y).
    ax.set_xlabel("Inkorrekte Beispiele", **captionfont)
    ax.set_ylabel("Korrekte Beispiele", **captionfont)
    fig.savefig(f"eval_llama_{metric}.pdf", bbox_inches='tight')

    # Tighten the on-screen layout of every figure, not only the last one.
    # Done after saving so the exported PDF is unaffected.
    fig.tight_layout()

# Show the plots
plt.show()