In [1]:
import json
import pandas as pd
import re
import seaborn as sns
import numpy as np
from matplotlib.lines import Line2D

from pathlib import Path

# Evaluation results JSON, relative to this notebook's working directory.
results_path = Path("..", "..") / "experiments" / "models" / "nlp" / "eval_results.json"


In [None]:
with results_path.open() as f:
    results = json.load(f)

# Flatten the nested {model: {eval dataset: {split: {"accuracy": ...}}}} JSON
# into one row per (model, eval dataset, split).
records = [
    {"model": model, "dataset": dataset, "split": split, "acc": info["accuracy"]}
    for model, model_res in results.items()
    for dataset, split_res in model_res.items()
    for split, info in split_res.items()
]
data = pd.DataFrame.from_records(records, columns=["model", "dataset", "split", "acc"])


def extract_setting(s: str):
    """Return the training setting encoded in a model identifier.

    Looks for one of ``RandomLabels``, ``Augmentation``, ``Shortcut`` or
    ``Normal`` delimited by underscores on both sides, e.g.
    ``"NLP_BERT-L_mnli_Normal_3_None"`` -> ``"Normal"``.

    Args:
        s: model identifier string.

    Returns:
        The matched setting name, or ``np.nan`` when no setting is present
        (so the result can be stored directly in a pandas column).
    """
    # Requiring an underscore on both sides avoids matching inside longer tokens.
    match = re.search(r"_(RandomLabels|Augmentation|Shortcut|Normal)_", s)
    return match.group(1) if match else np.nan


def extract_setting_strength(s: str):
    """Return the strength token that follows the setting name in a model identifier.

    ``"NLP_BERT-L_sst2_mem_rate05_RandomLabels_50_0_None"`` -> ``"50"``.
    The digits are returned verbatim as a string, so leading zeros are kept
    (e.g. ``"08385"``). Identifiers without a RandomLabels/Augmentation/Shortcut
    setting followed by digits yield ``np.nan``.
    """
    found = re.search(r"_(RandomLabels|Augmentation|Shortcut)_(\d+)_", s)
    if not found:
        return np.nan
    return found.group(2)


# Split each model identifier (e.g. "NLP_BERT-L_mnli_Normal_3_None") on "_" to
# recover training metadata: position 1 is taken as the architecture, position 2
# as the base training dataset and position -2 as the seed.
model_names = data["model"]
data = data.assign(
    train_dataset_base=model_names.apply(lambda name: name.split("_")[2]),
    train_setting=model_names.apply(extract_setting),
    train_setting_strength=model_names.apply(extract_setting_strength),
    train_seed=model_names.apply(lambda name: name.split("_")[-2]),
    arch=model_names.apply(lambda name: name.split("_")[1]),
)
data

In [None]:
# Distinct training settings parsed from the model names.
data["train_setting"].unique()

In [None]:
# Evaluation rows of one BERT-L model trained with the Shortcut setting on mnli.
data[data["model"].eq("NLP_BERT-L_mnli_sc_rate08385_Shortcut_08385_4_None")]

In [None]:
# All models trained with randomized labels, then the evaluation rows of one of them.
random_label_models = data.loc[data["train_setting"] == "RandomLabels", "model"].unique()
print(random_label_models)
data[data["model"] == "NLP_BERT-L_sst2_mem_rate05_RandomLabels_50_0_None"]

In [None]:
# Models trained in the "Normal" setting, then the evaluation rows of two of them.
# A cell only displays its last expression, so the model list is printed and
# both lookups are combined into a single frame.
print(data[data.train_setting == "Normal"].model.unique())
data.loc[data.model.isin(["NLP_BERT-L_mnli_Normal_3_None", "NLP_BERT-L_sst2_Normal_6_None"])]


In [None]:
# Number of evaluation rows per model, grouped by training dataset and setting.
group_keys = ["train_dataset_base", "train_setting", "model"]
data.groupby(group_keys).count()

In [None]:
# Last few rows of the results table.
data.tail(5)

In [None]:
# Exploratory view of one (architecture, training dataset, training setting)
# combination: accuracy per evaluation split, one column per evaluation dataset.
# Other values used here: arch "albert-base-v2" / "BERT-L", training dataset
# "mnli", settings "Augmentation" / "RandomLabels". Optional extra filters that
# were tried: data.dataset == "sst2_aug_rate0" / "mnli_aug_rate0";
# data.dataset.isin(["sst2_sc_rate10", "sst2_sc_rate0558"]),
# (["mnli_sc_rate1", "mnli_sc_rate0354"]), (["sst2_mem_rate0", "mnli_mem_rate0"]);
# data.train_setting_strength.isin(["0", "50", "100"]).
arch_mask = data.arch == "smollm2-1.7b"
train_base_mask = data.train_dataset_base == "sst2"
setting_mask = data.train_setting == "Shortcut"
selection = data.loc[arch_mask & train_base_mask & setting_mask]

# train_setting_strength holds the raw strings parsed from the model names, so
# the palette is keyed by those strings; each group reuses the same colour cycle.
colour_cycle = ["C0", "C1", "C2", "C4", "C3"]
strength_label_palette = {
    # RandomLabels / Augmentation strengths
    **dict(zip(["0", "25", "50", "75", "100"], colour_cycle)),
    # shortcut strengths (mnli)
    **dict(zip(["0354", "05155", "0677", "08385", "1"], colour_cycle)),
    # shortcut strengths (sst2)
    **dict(zip(["0558", "0668", "0779", "0889", "10"], colour_cycle)),
}

sns.catplot(
    data=selection,
    x="split",
    y="acc",
    hue="train_setting_strength",
    col="dataset",
    palette=strength_label_palette,
)


In [10]:
import matplotlib.pyplot as plt

In [None]:
# Distinct evaluation datasets present in the results.
data["dataset"].unique()

In [None]:
# Colour for each numeric training strength. One mapping serves every setting:
# label randomization / augmentation use 0-100, and shortcut runs use the
# dataset-specific rates below. 100 is shared by all settings and is listed
# exactly once (a repeated dict key is silently collapsed by Python).
palette = {
    # label randomization / augmentation
    0: "C0",
    25: "C4",
    50: "C2",
    75: "C1",
    100: "C3",  # also the highest shortcut rate for both mnli and sst2
    # shortcut (mnli)
    35.4: "C0",
    51.5: "C4",
    67.7: "C2",
    83.9: "C1",
    # shortcut (sst2)
    55.8: "C0",
    66.8: "C4",
    77.9: "C2",
    88.9: "C1",
}

# Raw strength strings parsed from the model names -> numeric strength used for
# plotting. Shortcut strings are mapped to one-decimal values
# (e.g. "0354" -> 35.4; "1" (mnli) and "10" (sst2) -> 100).
STRENGTH_BY_LABEL = {
    "0": 0,
    "25": 25,
    "50": 50,
    "75": 75,
    "100": 100,
    # shortcut stuff (mnli)
    "0354": 35.4,
    "05155": 51.5,
    "0677": 67.7,
    "08385": 83.9,
    "1": 100,
    # shortcut (sst2)
    "0558": 55.8,
    "0668": 66.8,
    "0779": 77.9,
    "0889": 88.9,
    "10": 100,
}
# Labels absent from the mapping (e.g. NaN strengths of "Normal" runs) become NaN.
data["Strength"] = data["train_setting_strength"].map(STRENGTH_BY_LABEL)

sns.set_theme("paper", style="whitegrid", font_scale=1.5)

# Training datasets on the x axis, in a fixed order, and the tick labels shown
# for them. An explicit order keeps labels aligned with the data regardless of
# the row order in the results table.
DATASET_ORDER = ["sst2", "mnli"]
DATASET_LABELS = ["SST2", "MNLI"]
# SST2 is evaluated on "validation", MNLI on "validation_matched".
VALIDATION_SPLITS = ["validation", "validation_matched"]


def _add_rate_legend(ax, entries, title):
    """Replace the legend of `ax` with one round marker per (label, colour) entry."""
    handles = [
        Line2D([], [], marker="o", linestyle="", markerfacecolor=colour, markeredgecolor=colour)
        for _, colour in entries
    ]
    labels = [label for label, _ in entries]
    ax.legend(handles=handles, labels=labels, ncol=1, fontsize="small", title=title)


def _present_rate_entries(subset):
    """Legend entries (label, colour from `palette`) for each Strength present in `subset`."""
    return [(f"{rate:g}", palette[rate]) for rate in sorted(subset.Strength.dropna().unique())]


def _strip_by_dataset(ax, subset):
    """Strip plot of accuracy per training dataset on `ax`, coloured by Strength."""
    sns.stripplot(
        data=subset,
        x="train_dataset_base",
        y="acc",
        hue="Strength",
        order=DATASET_ORDER,
        palette=palette,
        legend=False,  # each panel builds its own legend
        ax=ax,
    )
    ax.set_xlabel("Dataset")
    ax.set_xticks(range(len(DATASET_ORDER)))
    ax.set_xticklabels(DATASET_LABELS)


def _plot_random_labels(ax, data, arch, eval_datasets):
    """Label Randomization panel: validation accuracy at rates 0, 75 and 100."""
    # train_seed is text parsed from the model name; non-numeric seeds are
    # excluded instead of making the integer conversion fail for every row.
    seeds = pd.to_numeric(data.train_seed, errors="coerce")
    subset = data.loc[
        (data.train_setting == "RandomLabels")
        & (data.arch == arch)
        & data.dataset.isin(eval_datasets)
        & data.split.isin(VALIDATION_SPLITS)
        & data.Strength.isin([0, 75, 100])
        & (seeds < 128)
    ]
    _strip_by_dataset(ax, subset)
    ax.set_title("Label Randomization")
    ax.set_ylabel("Validation Accuracy")
    _add_rate_legend(ax, _present_rate_entries(subset), title="Rate")


def _plot_shortcut(ax, data, arch, eval_datasets):
    """Shortcut Affinity panel; SST2 and MNLI rates are paired by colour in the legend."""
    subset = data.loc[
        (data.train_setting == "Shortcut")
        & (data.arch == arch)
        & data.dataset.isin(eval_datasets)
        & data.split.isin(VALIDATION_SPLITS)
        & data.Strength.isin([55.8, 35.4, 83.9, 88.9, 100])
    ]
    _strip_by_dataset(ax, subset)
    ax.set_title("Shortcut Affinity")
    ax.set_ylabel("")
    # The two datasets use different rates that are given the same colour in
    # `palette`, so the legend is built from the palette rather than from
    # per-level handles whose positions depend on which levels are present.
    _add_rate_legend(
        ax,
        [("55.8/35.4", palette[55.8]), ("88.9/83.9", palette[88.9]), ("100/100", palette[100])],
        title="Rate (SST2/MNLI)",
    )


def _plot_augmentation(ax, data, arch):
    """Augmentation panel: SST2 at rates 0/100; MNLI at 0/25/100 for BERT-L, else 0/100."""
    mnli_rates = [0, 25, 100] if arch == "BERT-L" else [0, 100]
    sst2_rows = (
        data.dataset.isin(["sst2_aug_rate0"])
        & data.split.isin(["validation"])
        & data.Strength.isin([0, 100])
    )
    mnli_rows = (
        data.dataset.isin(["mnli_aug_rate0"])
        & data.split.isin(["validation_matched"])
        & data.Strength.isin(mnli_rates)
    )
    subset = data.loc[
        (data.train_setting == "Augmentation") & (data.arch == arch) & (sst2_rows | mnli_rows)
    ]
    _strip_by_dataset(ax, subset)
    ax.set_title("Augmentation")
    ax.set_ylabel("")
    _add_rate_legend(ax, _present_rate_entries(subset), title="Rate")


def create_validation_plot(data, arch: str, fig_dir="../../figs"):
    """Save a three-panel validation-accuracy figure for `arch`.

    Panels: label randomization, shortcut affinity, augmentation.

    Args:
        data: results table with the columns train_setting, arch, dataset,
            split, Strength, train_seed, train_dataset_base and acc.
        arch: architecture name as parsed from the model identifier, e.g. "BERT-L".
        fig_dir: directory the figure is written to as ``nlp_accs_{arch}.pdf``.
    """
    fig, axes = plt.subplots(1, 3, figsize=(10, 6))
    _plot_random_labels(axes[0], data, arch, ["sst2_mem_rate0", "mnli_mem_rate0"])
    _plot_shortcut(axes[1], data, arch, ["sst2_sc_rate0558", "mnli_sc_rate0354"])
    _plot_augmentation(axes[2], data, arch)
    fig.tight_layout()
    fig.savefig(Path(fig_dir) / f"nlp_accs_{arch}.pdf", bbox_inches="tight")


def create_validation_plot_smollm(data, arch: str, fig_dir="../../figs"):
    """Save a two-panel figure (label randomization, shortcut affinity) for `arch`.

    Same layout as the first two panels of ``create_validation_plot``, but
    evaluated on the ``*_sft_*`` datasets. Written to
    ``{fig_dir}/nlp_accs_{arch}.pdf``.
    """
    fig, axes = plt.subplots(1, 2, figsize=(7, 6))
    _plot_random_labels(axes[0], data, arch, ["sst2_sft_mem_rate0", "mnli_sft_mem_rate0"])
    _plot_shortcut(axes[1], data, arch, ["sst2_sft_sc_rate0558", "mnli_sft_sc_rate0354"])
    fig.tight_layout()
    fig.savefig(Path(fig_dir) / f"nlp_accs_{arch}.pdf", bbox_inches="tight")

# Uncomment to render the three-panel figure for the other architectures:
# for other_arch in ("BERT-L", "albert-base-v2"):
#     create_validation_plot(data, other_arch)
create_validation_plot_smollm(data, "smollm2-1.7b")

In [None]:
# Full results table, including the columns derived from the model names.
data

In [None]:
# Per-model count of evaluation rows with exactly zero accuracy among the
# albert-base-v2 RandomLabels runs.
albert_random_labels = (data.arch == "albert-base-v2") & (data.train_setting == "RandomLabels")
zero_acc = data.acc == 0.00
data.loc[albert_random_labels & zero_acc, "model"].value_counts()

In [None]:
# albert-base-v2 label-randomization results on the validation splits, without
# the seed-0 run at rate 75, sorted by accuracy (highest first).
is_candidate = (
    (data.train_setting == "RandomLabels")
    & (data.arch == "albert-base-v2")
    & data.dataset.isin(["sst2_mem_rate0", "mnli_mem_rate0"])
    & data.split.isin(["validation", "validation_matched"])
    & data.Strength.isin([0, 75, 100])
    & (data.train_seed.astype(int) < 128)
)
is_excluded = (data.Strength == 75.0) & (data.train_seed == "0")
data.loc[is_candidate & ~is_excluded].sort_values(by="acc", ascending=False)

# TODO(review): albert ok?

In [None]:
# `arch` is already derived from the model name where the other train_* columns
# are parsed, so it is not recomputed (and the frame not mutated) here.
data.head()

In [None]:
# Overview of all results: accuracy per architecture, one panel per
# (training dataset, training setting), coloured by the raw strength label.
sns.catplot(
    data=data,
    kind="strip",
    x="arch",
    y="acc",
    hue="train_setting_strength",
    row="train_dataset_base",
    col="train_setting",
    sharey=False,
)