In [None]:
import pandas as pd
import seaborn as sns
import re
import matplotlib.pyplot as plt
import matplotlib
from matplotlib.ticker import ScalarFormatter

In [None]:
# Fixed colour per method so every figure in the notebook uses the same hue
# for the same system.
# NOTE(review): `sns.color_palette()` without arguments returns the palette
# that is active at call time, while `sns.set_theme(...)` is only called in a
# later cell -- the colours may therefore differ between a fresh run and a
# re-run of this cell. Confirm which palette is intended.
palette = sns.color_palette()

# Index into `palette` for each method (colour names as annotated originally).
_palette_index = {
    "ELEET": 0,               # blue
    "Text-To-Table": 2,       # green
    "LLaMA-2 7B (ic)": 1,     # orange
    "LLaMA-2 7B (ft)": 3,     # red
    "gpt-3.5-turbo-0125": 4,  # purple
    "gpt-4-0613": 6,
}
color_matching = {name: palette[idx] for name, idx in _palette_index.items()}

# Order in which methods are listed in the figure legend.
methods = ["ELEET", "Text-To-Table", "LLaMA-2 7B (ic)", "LLaMA-2 7B (ft)"]

# Order in which methods are handed to seaborn as hue levels; `color_palette`
# is aligned element-by-element with this list.
hue_methods = ["LLaMA-2 7B (ic)", "LLaMA-2 7B (ft)", "ELEET", "Text-To-Table"]
color_palette = list(map(color_matching.__getitem__, hue_methods))


In [None]:
from eleet.evaluate import RunDescription

# Previously used metrics export, kept for reference:
# METRICS_FILE = "../predictions/test/metrics_eleet-t2t-gpt-gpt4-llama-bert-tabert-gpt-ft-rotowire-all-test-limit-4294967296.csv"
METRICS_FILE = "../predictions/test/metrics_eleet-t2t-llama-gpt-3.5-turbo-0125-gpt-4-0613-llama-ft-gpt-ft-rotowire-all-test-limit-4294967296.csv"
PERF_METRIC = "mean_f1"  # column plotted on the y-axis

# Raw query strings as stored in the metrics file -> short operator names
# (used as panel titles).
query_dict = {'G_name(π_name,Assists,Blocks,Points,To…(Player))': 'Aggregation',
              'player_info ⨝ player_to_reports ⨝ Player': 'Join',
              'player_stats ∪ Player': 'Union',
              'π_name,Assists,Blocks,Points,To…(Player)': 'Scan',
              'π_name,Assists,Blocks,Points,To…(σ_Points=28[0.096](Player))': 'Selection'}

# Raw method identifiers -> display names used in the figures.
method_names = {
    "LLaMA": "LLaMA-2 7B (ic)",
    "LLaMA-FT": "LLaMA-2 7B (ft)",
    "gpt-3.5-turbo-0125": "gpt-3.5-turbo-0125 (ic)",
    "gpt-4-0613": "gpt-4-0613 (ic)",
}
# Nominal split sizes that are rewritten to 3398.
# NOTE(review): presumably 3398 is the actual number of available labeled
# texts -- confirm.
split_size_fixes = {4096: 3398, 16384: 3398}


def _display_method(name):
    """Collapse fine-tuned OpenAI model ids ("ft:...") into one display name."""
    # Non-string cells (e.g. NaN) are passed through instead of raising.
    if isinstance(name, str) and name.startswith("ft:"):
        return "gpt-3.5-turbo (ft)"
    return name


metrics = pd.read_csv(METRICS_FILE)

# Apply each mapping only to the column it is meant for. A frame-wide
# `replace` would also rewrite matching values in unrelated columns
# (e.g. any cell that happens to equal 4096).
metrics["method"] = metrics["method"].apply(_display_method).replace(method_names)
metrics["split_size"] = metrics["split_size"].replace(split_size_fixes)
metrics["query"] = metrics["query"].replace(query_dict)

# adjustments
keep_size = metrics["split_size"] < 4000  # dataset is smaller than 4000
# context size too small
llama_ic_overflow = (metrics["method"] == "LLaMA-2 7B (ic)") & (metrics["split_size"] > 4)
# Explicit copy so the assignment below writes to an independent frame and
# does not trigger SettingWithCopyWarning.
metrics = metrics.loc[keep_size & ~llama_ic_overflow].copy()
# All gpt-4-0613 (ic) rows are assigned split_size 4.
metrics.loc[metrics["method"] == "gpt-4-0613 (ic)", "split_size"] = 4



# Method groupings and the marker list derived from them.
# NOTE: `markers` is not passed to the plot below (every line uses marker "X").
slm = ["ELEET", "Text-To-Table"]
in_context = ["LLaMA-2 7B (ic)", "gpt-3.5-turbo-0125 (ic)", "gpt-4-0613 (ic)"]
markers = ["X" if m in slm else ("o" if m in in_context else ".") for m in methods]

plt.rc('text', usetex=False)
sns.set_theme(style="whitegrid", font_scale=1.0)

# NOTE(review): absolute, user-specific output location -- consider a path
# relative to the notebook so the figure can be regenerated elsewhere.
OUTPUT_FILE = "/home/murban/exp2.pdf"

# One panel per query type, laid out row by row on a 2x3 grid; the sixth
# (bottom-right) cell stays empty and hosts the shared legend.
queries = ["Scan", "Join", "Union", "Selection", "Aggregation"]

fig = plt.figure(figsize=(6, 4))
gs = fig.add_gridspec(nrows=2, ncols=3, hspace=0.5, wspace=0.1)

# Legend entries collected from all panels; merged into one figure legend below.
labels = list()
handles = list()

for i, query in enumerate(queries):
    ax = fig.add_subplot(gs[i // 3, i % 3])

    # `data=` is passed by keyword so the call does not depend on the position
    # of the data argument; the y column comes from PERF_METRIC.
    sns.lineplot(data=metrics[metrics["query"] == query], y=PERF_METRIC, x="split_size", hue="method", ax=ax,
                 dashes=False, hue_order=hue_methods[::-1], palette=sns.color_palette(color_palette[::-1]),
                 marker="X")
    ax.set_title(query)

    handles_ax, labels_ax = ax.get_legend_handles_labels()
    handles.extend(handles_ax)
    labels.extend(labels_ax)
    # A panel without data has no legend; only remove one if it exists.
    legend_ax = ax.get_legend()
    if legend_ax is not None:
        legend_ax.remove()

    ax.set_xscale("log")
    ax.set_xticks([4, 64, 1024])
    ax.set_xticklabels(["4", "64", "1024"])
    ax.set_yticks([0.2, 0.4, 0.6, 0.8])
    # Only the last panel carries the x-axis label.
    ax.set_xlabel("Number of Labeled Texts" if i == 4 else "")
    ax.set_ylabel("Mean F1")
    ax.set_ylim((None, 1.0))
    if i < 2:
        # The first two top-row panels have a panel directly below them, so
        # their x tick labels are blanked.
        ax.set_xlabel(None)
        ax.set_xticklabels(["", "", ""])
    if i % 3:
        # Panels outside the first column drop the y label and y tick labels.
        ax.set_ylabel(None)
        ax.set_yticklabels(["", "", "", ""])

# Map each legend label to one handle; later panels overwrite earlier ones.
hl_dict = {l.split("_")[0]: h for l, h in zip(labels, handles)}
leg = fig.legend([hl_dict[l] for l in methods], methods, bbox_to_anchor=(0.65, 0.07), ncol=1, loc='lower left', borderaxespad=0., frameon=False)

plt.savefig(OUTPUT_FILE, bbox_inches="tight")
plt.show()