In [None]:
import pandas as pd
import seaborn as sns
import re
import matplotlib.pyplot as plt
import matplotlib
from matplotlib.ticker import ScalarFormatter

from eleet.evaluate import RunDescription

# CSV file holding the evaluation metrics, and the name of the score column.
METRICS_FILE = "../predictions/test/metrics_eleet-t2t-gpt-gpt4-llama-bert-tabert-rotowire-all-test-limit-4294967296.csv"
PERF_METRIC = "mean_f1"

# Translates the query strings stored in the metrics file into short display names.
query_dict = {
    'G_name(π_name,Assists,Blocks,Points,To…(Player))': 'Aggregation',
    'player_info ⨝ player_to_reports ⨝ Player': 'Join',
    'player_stats ∪ Player': 'Union',
    'π_name,Assists,Blocks,Points,To…(Player)': 'Scan',
    'π_name,Assists,Blocks,Points,To…(σ_Points=28[0.096](Player))': 'Selection',
}

# Load the metrics and rewrite raw values into their display form. The
# substitution is applied to every column of the frame, not to a single one.
# NOTE(review): 4096 and 16384 are both rewritten to 3398 -- presumably the
# actual number of available examples; confirm against the experiment setup.
metrics = pd.read_csv(METRICS_FILE).replace(
    {
        "ELEET": "ELEET weights",
        "LLaMA": "LLaMA-2 (7B)",
        4096: 3398,
        16384: 3398,
        "TABERT": "TaBERT weights",
        "BERT": "BERT weights",
        **query_dict,
    }
)

# Keep only the rows whose split size is below 4000.
metrics = metrics.loc[metrics["split_size"] < 4000]



# Render text with matplotlib's own engine (no LaTeX) and apply the seaborn theme.
plt.rc('text', usetex=False)
sns.set_theme(style="whitegrid", font_scale=1.0)

# Tick positions shared by all panels. The x tick labels are derived from
# X_TICKS so that a label can never disagree with the position it annotates.
X_TICKS = [4, 64, 1024]
Y_TICKS = [0.2, 0.4, 0.6, 0.8]

fig = plt.figure(figsize=(6, 4))
gs = fig.add_gridspec(nrows=2, ncols=3, hspace=0.5, wspace=0.1)

# Legend entries collected from every panel; consumed after the loop to build
# one shared figure-level legend.
labels = list()
handles = list()

# Five panels fill the 2x3 grid row by row, leaving the bottom-right cell empty.
for i, query in enumerate(["Scan", "Join", "Union", "Selection", "Aggregation"]):
    ax = fig.add_subplot(gs[i // 3, i % 3])

    sns.lineplot(data=metrics[metrics["query"] == query], y=PERF_METRIC, x="split_size", hue="method", ax=ax,
                 marker="X", hue_order=["ELEET weights", "TaBERT weights", "BERT weights"])
    ax.set_title(query)

    handles_ax, labels_ax = ax.get_legend_handles_labels()
    handles.extend(handles_ax)
    labels.extend(labels_ax)
    # A panel without any plotted data has no legend; only remove one that exists.
    panel_legend = ax.get_legend()
    if panel_legend is not None:
        panel_legend.remove()

    ax.set_xscale("log")
    ax.set_xticks(X_TICKS)
    ax.set_yticks(Y_TICKS)
    ax.set_ylim((None, 1.0))

    if i < 2:
        # These two top-row panels sit directly above another panel, so their
        # x axis annotation is blanked out.
        ax.set_xlabel("")
        ax.set_xticklabels([""] * len(X_TICKS))
    else:
        # Only the panel with index 4 carries the x axis title.
        ax.set_xlabel("Number of Labeled Texts" if i == 4 else "")
        ax.set_xticklabels([str(tick) for tick in X_TICKS])

    if i % 3:
        # Panels outside the first column drop the y axis title and tick labels.
        ax.set_ylabel(None)
        ax.set_yticklabels([""] * len(Y_TICKS))
    else:
        ax.set_ylabel("Mean F1")

# Collapse the legend entries gathered from all panels into one handle per
# label. The key is the label text before its first underscore; when a key
# occurs more than once, the handle seen last is kept.
hl_dict = dict(zip((name.split("_")[0] for name in labels), handles))
print(hl_dict)

# Fixed order of the legend entries; each one is looked up in hl_dict.
labels = ["ELEET weights", "TaBERT weights", "BERT weights"]

# Single figure-level legend, anchored by its lower-left corner at (0.68, 0.09).
# NOTE(review): presumably this targets the unused bottom-right grid cell -- confirm visually.
leg = fig.legend(
    [hl_dict[name] for name in labels],
    labels,
    loc='lower left',
    bbox_to_anchor=(0.68, 0.09),
    ncol=1,
    frameon=False,
    borderaxespad=0.,
)

# Write the figure to disk, then display it.
plt.savefig("/home/murban/exp4.pdf", bbox_inches="tight")
plt.show()