In [None]:
import re
import ast
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.metrics import classification_report

# Register tqdm's progress-bar helpers (e.g. `progress_apply`) on pandas objects.
tqdm.pandas()
# Global seaborn look applied to the figures drawn below.
sns.set_style("darkgrid")
sns.set_context("notebook")
# Captures the text between <aug> and </aug>; DOTALL lets one capture span several lines.
aug_regex = re.compile(r"<aug>(.*?)</aug>", re.DOTALL)

In [None]:
# Read the three inference-log CSVs first, one frame per evaluation dataset.
sst5_data = pd.read_csv("../datasets/analysis/sst5_stabilityaiStableBeluga-7b_random_16_Kyle1668boss-sentiment-bert-base-uncased_style_logs.csv")
toxigen_data = pd.read_csv("../datasets/analysis/toxigen_stabilityaiStableBeluga-7b_random_16_Kyle1668boss-toxicity-bert-base-uncased_style_logs.csv")
agt_data = pd.read_csv("../datasets/analysis/test_stabilityaiStableBeluga-7b_random_16_Kyle1668ag-news-bert-base-uncased_style_logs.csv")

# Then preview each frame in turn: its first row followed by its (rows, columns) shape.
for log_frame in (sst5_data, toxigen_data, agt_data):
    display(log_frame.head(1))
    display(log_frame.shape)

## Analyze ICR Generations

In [None]:
# Show the "input" value of the first SST-5 log row.
sst5_data["input"].values[0]

## Does TTA Affect Some Classes More Than Others?

In [None]:
# Row counts for every (label, outcome) pair in SST-5, ordered by label and then outcome.
sst5_data.value_counts(["label", "outcome"]).sort_index()

In [None]:
# Pool sst5_data, toxigen_data and agt_data and count rows per outcome.
# This displays raw counts (not a ratio); the New Correct : New Mistake ratio can be read off them.
pd.concat([sst5_data, toxigen_data, agt_data]).value_counts(["outcome"]).sort_index()

In [None]:
def _percent_new_predictions(outcome_shares):
    """Return the percentage of examples whose outcome is "New Correct" or "New Mistake".

    Args:
        outcome_shares: Series of outcome fractions indexed by outcome name, as
            produced by ``value_counts(normalize=True)``.

    Returns:
        The combined share of the two "New ..." outcomes on a 0-100 scale. An
        outcome that never occurs contributes 0 instead of raising IndexError.
    """
    new_share = outcome_shares.get("New Correct", 0.0) + outcome_shares.get("New Mistake", 0.0)
    # Scale the *sum*. The earlier form `100 * a + b` multiplied only the
    # "New Correct" share by 100 because `*` binds tighter than `+`.
    return 100 * new_share


# For each dataset, report the percent of examples whose prediction is new (vs. unchanged).
sst5_outcomes = sst5_data["outcome"].value_counts(normalize=True)
new_predcitions_percent = _percent_new_predictions(sst5_outcomes)
print(f"SST-5: {new_predcitions_percent:.2f}% of examples are new predictions")

toxicgen_outcomes = toxigen_data["outcome"].value_counts(normalize=True)
new_predcitions_percent = _percent_new_predictions(toxicgen_outcomes)
print(f"ToxicGen: {new_predcitions_percent:.2f}% of examples are new predictions")

agt_outcomes = agt_data["outcome"].value_counts(normalize=True)
new_predcitions_percent = _percent_new_predictions(agt_outcomes)
print(f"AGT: {new_predcitions_percent:.2f}% of examples are new predictions")

In [None]:
# Discard figures left over from earlier cells. Unlike plt.clf(), this does not
# create a stray empty figure when no figure is currently open.
plt.close("all")

# Three histograms on one row, one per dataset
fig, axes = plt.subplots(1, 3, figsize=(20, 5))

# Fixed hue order so each outcome gets the same colour in every panel; the single
# shared legend below is only valid if the panels agree on colours.
CHANGED_OUTCOMES = ["New Correct", "New Mistake"]


def _changed_predictions(log_frame, label_names, display_order=None):
    """Return the "New Correct" / "New Mistake" rows with readable class names.

    Args:
        log_frame: Inference-log frame with "label" and "outcome" columns.
        label_names: Mapping from the values in "label" to display names. A label
            missing from the mapping raises KeyError.
        display_order: Optional list of display names giving the row order (and
            therefore the category order on the histogram's x axis). When omitted,
            rows stay ordered by the original label value.

    Returns:
        A new frame (an explicit copy, so the assignments below never write into a
        slice of ``log_frame``) sorted by label and outcome, with "label" replaced
        by its display name.
    """
    changed = log_frame[log_frame["outcome"].isin(CHANGED_OUTCOMES)].copy()
    changed = changed.sort_values(by=["label", "outcome"])
    changed["label"] = changed["label"].map(lambda label_id: label_names[label_id])
    if display_order is not None:
        rank = {name: position for position, name in enumerate(display_order)}
        # A stable sort keeps the outcome ordering within each class intact.
        changed = changed.sort_values(by="label", key=lambda names: names.map(rank), kind="stable")
    return changed


sst5_labels = {
    0: "Negative",
    1: "Positive",
    2: "Neutral",
}
# Show the SST-5 classes as Negative, Neutral, Positive in that order
sst5_corruptions_corrections = _changed_predictions(
    sst5_data, sst5_labels, display_order=["Negative", "Neutral", "Positive"]
)
sns.histplot(data=sst5_corruptions_corrections, x="label", hue="outcome", hue_order=CHANGED_OUTCOMES, multiple="dodge", shrink=.8, ax=axes[0])

toxigen_labels = {
    0: "Non-Toxic",
    1: "Toxic",
}
toxigen_corruptions_corrections = _changed_predictions(toxigen_data, toxigen_labels)
sns.histplot(data=toxigen_corruptions_corrections, x="label", hue="outcome", hue_order=CHANGED_OUTCOMES, multiple="dodge", shrink=.8, ax=axes[1])

agt_labels = {
    0: "World",
    1: "Sports",
    2: "Business",
    3: "Sci/Tech",
}
agt_corruptions_corrections = _changed_predictions(agt_data, agt_labels)
sns.histplot(data=agt_corruptions_corrections, x="label", hue="outcome", hue_order=CHANGED_OUTCOMES, multiple="dodge", shrink=.8, ax=axes[2])

axes[0].set_ylabel("Count", labelpad=20, fontsize=14)
axes[1].set_ylabel("")
axes[2].set_ylabel("")
axes[0].set_xlabel("SST-5", labelpad=20, fontsize=14)
axes[1].set_xlabel("Toxigen", labelpad=20, fontsize=14)
axes[2].set_xlabel("AG News Tweets", labelpad=20, fontsize=14)

# Put the dataset names (x labels) above the plots
for axis in axes:
    axis.xaxis.set_label_position('top')

# One shared legend, anchored under the middle panel.
# NOTE(review): the labels below are matched to the drawn artists by position,
# not by outcome name — confirm in the rendered figure that "Corrections" carries
# the "New Correct" colour.
axes[0].get_legend().remove()
axes[1].legend(loc='upper center', bbox_to_anchor=(0.5, -0.2), labels=["Corruptions", "Corrections"], ncol=2, fancybox=False, frameon=False, fontsize=14)
axes[2].get_legend().remove()


# add padding
fig.tight_layout(pad=3.0)
fig.savefig("../datasets/analysis/figures/corruptions_corrections_histograms.png", bbox_inches='tight', dpi=300)



# Entropy-Based Selective Augmentation

# Entropy Analysis

## Entropy Accuracy Curves

In [None]:
# For every SST-5 input containing a backtick, print the first <aug>...</aug>
# capture as raw text and then as the Python literal it parses to.
has_backtick = sst5_data["input"].str.contains("`")
for logged_input in sst5_data[has_backtick]["input"].values:
    captured = aug_regex.findall(logged_input)
    if not captured:
        continue
    first_rewrite = captured[0]
    print(first_rewrite)
    print(ast.literal_eval(first_rewrite))

## Entropy Accuracy Curves

In [None]:
# Entropy cut-offs swept by the accuracy curves below: 0 (inclusive) up to
# 1 (exclusive) in steps of 1e-5.
thresholds = np.arange(0, 1, 0.00001)
# Coarser alternative grid, left disabled:
# thresholds = np.arange(0, 1, 0.05)

# Per-dataset accuracy drawn as the dashed "No TTA (Baseline)" line in the
# curves below. The values are hard-coded; their source is not shown in this notebook.
baseline_perf = {
    "SST-5": 0.6847,
    "Sem Eval": 0.4498,
    "Dynasent": 0.4271,
    "ToxiGen": 0.6670,
    "Adv Civil": 0.3050,
    "Implicit Hate": 0.6454,
    "AG News Tweets": 0.8857,
}

In [None]:
# Create a figure with one row of three panels on a 15 x 6 inch canvas;
# calculate_entropy_threshold_jugments below draws into `axs` by dataset name.
fig, axs = plt.subplots(ncols=3, figsize=(15, 6))

def calculate_entropy_threshold_jugments(inference_log_frame, dataset_name, half=False):
    """Sweep the entropy thresholds, plot accuracy vs. augmentation rate, and apply one threshold.

    For every value in the module-level ``thresholds`` the accuracy and rewrite
    rate are computed with ``get_threshold_accuracy``. The resulting curve, the
    highest-accuracy point and the ``baseline_perf`` line are drawn on the panel
    of the module-level ``axs`` that belongs to ``dataset_name``.

    Args:
        inference_log_frame: Frame with "label", "original judgment", "judgment"
            and "original entropy" columns.
        dataset_name: One of "SST-5", "ToxiGen" or "AG News Tweets"; selects the
            panel and the baseline. Any other name raises KeyError.
        half: When exactly ``False`` (default) the threshold with the highest
            accuracy is applied (ties resolved towards the highest rewrite rate).
            Otherwise the threshold whose rewrite rate is closest to 50% is applied.

    Returns:
        ``(judgments, rewrite_rate)``: the per-row judgments under the selected
        threshold (original judgment where the original entropy is below it,
        otherwise the augmented judgment) and that threshold's rewrite rate as a
        fraction in [0, 1].
    """
    threshold_scores = []
    threshold_rewrite_rates = []
    for t in tqdm(thresholds, desc="Calculating entropy threshold scores"):
        t_perf, t_rate = get_threshold_accuracy(t, inference_log_frame)
        threshold_scores.append(t_perf)
        threshold_rewrite_rates.append(t_rate)

    thresholds_frame = pd.DataFrame({"threshold": thresholds, "accuracy": threshold_scores, "rewrite_rate": threshold_rewrite_rates})

    # Column of the 1 x 3 grid used for each dataset
    panel_index = {
        "SST-5": 0,
        "ToxiGen": 1,
        "AG News Tweets": 2,
    }

    # Accuracy as a function of the fraction of examples that get augmented
    ax = sns.lineplot(data=thresholds_frame, x="rewrite_rate", y="accuracy", label="TTA", ax=axs[panel_index[dataset_name]])
    ax.set_title(dataset_name, fontsize=18, pad=15)
    # Only the middle panel carries the x label and only the left one the y label.
    ax.set_xlabel("Augmentation Rate" if dataset_name == "ToxiGen" else "", labelpad=20, fontsize=14)
    ax.set_ylabel("Accuracy" if dataset_name == "SST-5" else "", labelpad=20, fontsize=14)
    ax.xaxis.set_major_formatter(plt.FuncFormatter(lambda x, loc: "{:.0%}".format(x)))
    ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, loc: "{:.2%}".format(x)))
    ax.set_xlim(left=0)
    ax.lines[0].set_linewidth(2)
    ax.legend_.remove()

    # Highest-accuracy row; among ties the one with the largest rewrite rate.
    best_accuracy = thresholds_frame["accuracy"].max()
    optimal_row = thresholds_frame[thresholds_frame["accuracy"] == best_accuracy].sort_values("rewrite_rate").iloc[-1]

    # Mark and annotate the max accuracy point
    ax.plot(optimal_row["rewrite_rate"],
            optimal_row["accuracy"],
            marker="o",
            markersize=6,
            label="Optimal",
            )
    ax.annotate(f"{optimal_row['accuracy']:.2%}",
                (optimal_row["rewrite_rate"], optimal_row["accuracy"]),
                textcoords="offset points",
                xytext=(10, 0),
                ha="left",
                fontsize=10)

    # Dashed gray line representing the baseline without augmentation
    ax.plot([0, 1], [baseline_perf[dataset_name], baseline_perf[dataset_name]], color="gray", linestyle="--", linewidth=1.5, alpha=0.75, label="No TTA (Baseline)")
    if dataset_name == "SST-5":
        ax.set_ylim(bottom=baseline_perf[dataset_name] - 0.005)

    if dataset_name == "ToxiGen":
        ax.legend(loc="upper center", fontsize=12, frameon=False, ncol=3,
                  bbox_to_anchor=(0.5, -0.2),
                  )

    if half is False:
        target_threshold = optimal_row
    else:
        # rewrite_rate is a fraction in [0, 1] (count / number of rows), so the
        # half-way point is 0.5. Comparing against 50 always selected the row with
        # the largest rewrite rate instead.
        distance_to_half = (thresholds_frame["rewrite_rate"] - 0.5).abs()
        target_threshold = thresholds_frame.iloc[int(distance_to_half.to_numpy().argmin())]

    # Already a fraction; the former `/ 100` shrank it by two orders of magnitude.
    rewrite_rate = target_threshold["rewrite_rate"]
    original_judgments = inference_log_frame.apply(lambda row: row["original judgment"] if row["original entropy"] < target_threshold["threshold"] else row["judgment"], axis=1)
    return original_judgments, rewrite_rate


def get_threshold_accuracy(threshold, inference_logs_frame):
    """Accuracy and augmentation rate obtained with one entropy threshold.

    A row keeps its "original judgment" when its "original entropy" is strictly
    below ``threshold``; every other row uses the augmented "judgment".

    Args:
        threshold: Entropy cut-off.
        inference_logs_frame: Frame with "label", "original judgment", "judgment"
            and "original entropy" columns.

    Returns:
        ``(accuracy, llm_call_rate)``: the fraction of rows whose selected
        judgment equals "label", and the fraction of rows whose original entropy
        is at or above ``threshold``.

    Raises:
        ValueError: If ``inference_logs_frame`` has no rows.
    """
    row_count = len(inference_logs_frame)
    if row_count == 0:
        raise ValueError("inference_logs_frame is empty; accuracy is undefined")

    entropy = inference_logs_frame["original entropy"]
    # Column-wise selection replaces a Python-level row-by-row apply; this function
    # is called once per threshold, so the per-call cost dominates the sweep.
    threshold_judgments = np.where(
        (entropy < threshold).to_numpy(),
        inference_logs_frame["original judgment"].to_numpy(),
        inference_logs_frame["judgment"].to_numpy(),
    )
    is_correct = np.asarray(threshold_judgments == inference_logs_frame["label"].to_numpy())
    accuracy = float(is_correct.mean())

    llm_call_count = (entropy >= threshold).sum()
    llm_call_rate = llm_call_count / row_count
    return accuracy, llm_call_rate

# Draw one panel per dataset, left to right, then export the finished figure.
for log_frame, panel_title in (
    (sst5_data, "SST-5"),
    (toxigen_data, "ToxiGen"),
    (agt_data, "AG News Tweets"),
):
    calculate_entropy_threshold_jugments(log_frame, panel_title)

fig.tight_layout(pad=1.0)
fig.savefig("../datasets/analysis/entropy_figures/main_acc_rewrite_curves.png", bbox_inches="tight")

## Appendix Entropy Figures

In [None]:
# 2 x 4 grid of panels on a 20 x 10 inch canvas; the function below indexes `axs` as [row][column].
fig, axs = plt.subplots(ncols=4, figsize=(20, 10), nrows=2)

def calculate_entropy_threshold_jugments(inference_log_frame, dataset_name, half=False):
    """Sweep the entropy thresholds, plot accuracy vs. augmentation rate, and apply one threshold.

    For every value in the module-level ``thresholds`` the accuracy and rewrite
    rate are computed with ``get_threshold_accuracy``. The resulting curve, the
    highest-accuracy point and the ``baseline_perf`` line are drawn on the panel
    of the module-level 2-D ``axs`` grid that belongs to ``dataset_name``.

    Args:
        inference_log_frame: Frame with "label", "original judgment", "judgment"
            and "original entropy" columns.
        dataset_name: One of the seven dataset names in the grid mapping below;
            selects the panel and the baseline. Any other name raises KeyError.
        half: When exactly ``False`` (default) the threshold with the highest
            accuracy is applied (ties resolved towards the highest rewrite rate).
            Otherwise the threshold whose rewrite rate is closest to 50% is applied.

    Returns:
        ``(judgments, rewrite_rate)``: the per-row judgments under the selected
        threshold (original judgment where the original entropy is below it,
        otherwise the augmented judgment) and that threshold's rewrite rate as a
        fraction in [0, 1].
    """
    threshold_scores = []
    threshold_rewrite_rates = []
    for t in tqdm(thresholds, desc="Calculating entropy threshold scores"):
        t_perf, t_rate = get_threshold_accuracy(t, inference_log_frame)
        threshold_scores.append(t_perf)
        threshold_rewrite_rates.append(t_rate)

    thresholds_frame = pd.DataFrame({"threshold": thresholds, "accuracy": threshold_scores, "rewrite_rate": threshold_rewrite_rates})

    # (row, column) of the 2 x 4 grid used for each dataset
    grid_position = {
        "SST-5": (0, 0),
        "Sem Eval": (0, 1),
        "Dynasent": (0, 2),
        "ToxiGen": (0, 3),
        "Adv Civil": (1, 0),
        "Implicit Hate": (1, 1),
        "AG News Tweets": (1, 2),
    }
    grid_row, grid_col = grid_position[dataset_name]

    # Accuracy as a function of the fraction of examples that get augmented
    ax = sns.lineplot(data=thresholds_frame, x="rewrite_rate", y="accuracy", label="TTA", ax=axs[grid_row][grid_col])
    ax.set_title(dataset_name, fontsize=18, pad=15)
    # Only selected panels carry the x label, and only the first panel of each row the y label.
    ax.set_xlabel("Augmentation Rate" if dataset_name in ["ToxiGen", "Adv Civil", "Implicit Hate", "AG News Tweets"] else "", labelpad=20, fontsize=14)
    ax.set_ylabel("Accuracy" if dataset_name in ["SST-5", "Adv Civil"] else "", labelpad=20, fontsize=14)
    ax.xaxis.set_major_formatter(plt.FuncFormatter(lambda x, loc: "{:.0%}".format(x)))
    ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, loc: "{:.2%}".format(x)))
    ax.set_xlim(left=0)
    ax.lines[0].set_linewidth(2)
    ax.legend_.remove()

    # Highest-accuracy row; among ties the one with the largest rewrite rate.
    best_accuracy = thresholds_frame["accuracy"].max()
    optimal_row = thresholds_frame[thresholds_frame["accuracy"] == best_accuracy].sort_values("rewrite_rate").iloc[-1]

    # Mark and annotate the max accuracy point
    ax.plot(optimal_row["rewrite_rate"],
            optimal_row["accuracy"],
            marker="o",
            markersize=6,
            label="Optimal",
            )
    ax.annotate(f"{optimal_row['accuracy']:.2%}",
                (optimal_row["rewrite_rate"], optimal_row["accuracy"]),
                textcoords="offset points",
                xytext=(10, 0),
                ha="left",
                fontsize=10)

    # Dashed gray line for the baseline accuracy of this dataset
    ax.plot([0, 1], [baseline_perf[dataset_name], baseline_perf[dataset_name]], color="gray", linestyle="--", linewidth=1.5, alpha=0.75, label="No TTA (Baseline)")

    if dataset_name == "SST-5":
        ax.set_ylim(bottom=baseline_perf[dataset_name] - 0.005)

    if half is False:
        target_threshold = optimal_row
    else:
        # rewrite_rate is a fraction in [0, 1] (count / number of rows), so the
        # half-way point is 0.5. Comparing against 50 always selected the row with
        # the largest rewrite rate instead.
        distance_to_half = (thresholds_frame["rewrite_rate"] - 0.5).abs()
        target_threshold = thresholds_frame.iloc[int(distance_to_half.to_numpy().argmin())]

    # Already a fraction; the former `/ 100` shrank it by two orders of magnitude.
    rewrite_rate = target_threshold["rewrite_rate"]
    original_judgments = inference_log_frame.apply(lambda row: row["original judgment"] if row["original entropy"] < target_threshold["threshold"] else row["judgment"], axis=1)
    return original_judgments, rewrite_rate


def get_threshold_accuracy(threshold, inference_logs_frame):
    """Accuracy and augmentation rate obtained with one entropy threshold.

    A row keeps its "original judgment" when its "original entropy" is strictly
    below ``threshold``; every other row uses the augmented "judgment".

    Args:
        threshold: Entropy cut-off.
        inference_logs_frame: Frame with "label", "original judgment", "judgment"
            and "original entropy" columns.

    Returns:
        ``(accuracy, llm_call_rate)``: the fraction of rows whose selected
        judgment equals "label", and the fraction of rows whose original entropy
        is at or above ``threshold``.

    Raises:
        ValueError: If ``inference_logs_frame`` has no rows.
    """
    row_count = len(inference_logs_frame)
    if row_count == 0:
        raise ValueError("inference_logs_frame is empty; accuracy is undefined")

    entropy = inference_logs_frame["original entropy"]
    # Column-wise selection replaces a Python-level row-by-row apply; this function
    # is called once per threshold, so the per-call cost dominates the sweep.
    threshold_judgments = np.where(
        (entropy < threshold).to_numpy(),
        inference_logs_frame["original judgment"].to_numpy(),
        inference_logs_frame["judgment"].to_numpy(),
    )
    is_correct = np.asarray(threshold_judgments == inference_logs_frame["label"].to_numpy())
    accuracy = float(is_correct.mean())

    llm_call_count = (entropy >= threshold).sum()
    llm_call_rate = llm_call_count / row_count
    return accuracy, llm_call_rate


# NOTE(review): semval_data, dynasent_data, adv_civil_data and implicit_hate_data
# are not loaded by any cell visible in this notebook — confirm they are defined
# before this cell runs, otherwise it fails with NameError on a fresh kernel.
calculate_entropy_threshold_jugments(sst5_data, "SST-5")
calculate_entropy_threshold_jugments(semval_data, "Sem Eval")
calculate_entropy_threshold_jugments(dynasent_data, "Dynasent")
calculate_entropy_threshold_jugments(toxigen_data, "ToxiGen")
calculate_entropy_threshold_jugments(adv_civil_data, "Adv Civil")
calculate_entropy_threshold_jugments(implicit_hate_data, "Implicit Hate")
calculate_entropy_threshold_jugments(agt_data, "AG News Tweets")

# The last cell of the 2 x 4 grid has no dataset assigned to it; drop it.
fig.delaxes(axs[1, -1])

# Build the shared legend from explicit handles taken from the first panel.
# Passing only `labels=` pairs the labels with whatever artists the axes happen
# to contain, in drawing order, which does not match the label order used here.
panel_handles, panel_labels = axs[0, 0].get_legend_handles_labels()
handle_by_label = dict(zip(panel_labels, panel_handles))
fig.legend(
    handles=[handle_by_label["TTA"], handle_by_label["No TTA (Baseline)"], handle_by_label["Optimal"]],
    labels=["TTA", "No TTA (Baseline)", "Optimal Aug Rate"],
    loc="lower center",
    fontsize=12,
    frameon=False,
    ncol=3,
    bbox_to_anchor=(0.5, -0.025),
)
fig.tight_layout(pad=2.0)
fig.savefig("../datasets/analysis/entropy_figures/appendix_acc_rewrite_curves.png", bbox_inches="tight")

In [None]:
# Scatter every Sem Eval example's original entropy against its row position,
# coloured by outcome.
sem_eval_original_entropies = semval_data["original entropy"].tolist()
row_positions = range(len(sem_eval_original_entropies))
figure = sns.scatterplot(
    data=semval_data,
    x=row_positions,
    y="original entropy",
    hue="outcome",
    s=5,
)
# Log-scaled y axis
figure.set_yscale("log")
# Legend placed just outside the right edge of the axes
figure.legend(bbox_to_anchor=(1.01, 1), borderaxespad=0, frameon=False, title="Outcome")

In [None]:
# Summary statistics of the SST-5 original entropy, one row per outcome.
sst5_data[["original entropy", "outcome"]].groupby("outcome").describe()

In [None]:
# Cross-tabulate original entropy against outcome for SST-5.
# Both columns come from sst5_data: crossing SST-5 entropies with Sem Eval
# outcomes paired unrelated rows that merely shared an index value.
pd.crosstab(sst5_data["original entropy"], sst5_data["outcome"])