In [None]:
import os

import pandas as pd
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns
# Global seaborn look for every figure in this notebook.
sns.set_style(style="whitegrid")

In [None]:
# Root directory of the toxicity experiment outputs. Set the TOXICITY_RESULTS_DIR environment
# variable to run outside the original cluster. The trailing "/" is stripped so paths built as
# f"{BASE_DIR}/..." do not contain "//".
BASE_DIR = os.environ.get(
    "TOXICITY_RESULTS_DIR", "/extra/ucinlp1/cbelem/experiments-apr-15/toxicity_results"
).rstrip("/")

# Target (religion) words; each has its own data CSV and per-sampling model CSVs.
TARGET_WORDS = ["buddhist", "christian", "jewish", "muslim"]
# Decoding algorithms; used as the model CSV filename suffix "<target>_<sampling>.csv".
SAMPLING = ["multinomial", "temperature", "top-k", "top-p"]

## Load Data results

In [None]:
# Directory holding one "<target>.csv" file of dataset rows per target word.
DATA_DIR = "/".join([BASE_DIR, "data"])

In [None]:
# Dataset rows per target word; these results include the toxicity of the whole sequence.
DATA_BY_TARGET = dict(
    (word, pd.read_csv(f"{DATA_DIR}/{word}.csv", index_col=0))
    for word in TARGET_WORDS
)
print({word: len(frame) for word, frame in DATA_BY_TARGET.items()})

## Load model results

In [None]:
# Directory with the model generations (presumably the HF id EleutherAI/pythia-70m with "/" -> "__").
MODEL_DIR = "/".join([BASE_DIR, "models", "EleutherAI__pythia-70m"])

In [None]:
MODEL_BY_TARGET = {}

for target in TARGET_WORDS:
    # One CSV per decoding algorithm ("<target>_<sampling>.csv"), stacked into a single frame
    # with a fresh 0..n-1 index.
    per_sampling = [
        pd.read_csv(f"{MODEL_DIR}/{target}_{sampling_name}.csv", index_col=0)
        for sampling_name in SAMPLING
    ]
    MODEL_BY_TARGET[target] = pd.concat(per_sampling, ignore_index=True)

In [None]:
MODEL_BY_TARGET["muslim"].head(n=5)

## Compute the length

Add an additional property of the text. For simplicity, we measure the length of the generated text in characters. Since the data and the generated sequences share the same prefix, any difference between the two distributions is due to the generated text relative to the original continuation.

This measure is, however, biased, because we have not completely removed punctuation. In a future analysis, we may instead count the characters up to the first or second punctuation mark.


Rather than implementing this just because "it's easy", the more important question is whether **we have a use case for it**.

# Analysis

## Cumulative Distribution

In [None]:
def plot_histplot(target: str, attr: str, sampling: str, data_dict=None, model_dict=None, ax=None):
    """Overlay cumulative toxicity histograms of dataset rows and model generations.

    Dataset rows whose ``attribute`` equals ``attr`` are matched to the model generations seeded
    from the same prefix (the model's ``prefix`` column is renamed to ``min_prefix``). Toxicity is
    averaged per prefix (and, for the model, per ``sampling_kwargs``) before plotting, so each
    histogram counts prefixes rather than individual generations.

    Args:
        target: key into ``data_dict`` / ``model_dict`` (e.g. an entry of ``TARGET_WORDS``).
        attr: value of the dataset ``attribute`` column to keep.
        sampling: value of the model ``sampling`` column to keep (e.g. "top-k").
        data_dict: target -> dataset DataFrame. Defaults to ``DATA_BY_TARGET``, looked up at call
            time so that re-running the loading cell is picked up.
        model_dict: target -> model DataFrame. Defaults to ``MODEL_BY_TARGET`` (looked up at call time).
        ax: matplotlib Axes to draw on. When None, the plot goes to the current axes, the legend
            is moved outside the plot, and ``plt.show()`` is called.
    """
    if data_dict is None:
        data_dict = DATA_BY_TARGET
    if model_dict is None:
        model_dict = MODEL_BY_TARGET

    # Shared histogram settings; add "stat": "probability" to normalise instead of counting.
    kwargs = {"binrange": (0, 1), "bins": 30, "element": "step", "alpha": 0.5, "cumulative": True}
    if ax is not None:
        kwargs["ax"] = ax

    # 1. select model by target
    model_by_target = model_dict[target].rename({"prefix": "min_prefix"}, axis=1)
    print("model_by_target", len(model_by_target))
    # 2. select data by target and by attribute
    data_by_target = data_dict[target]
    data_by_target_by_attr = data_by_target[data_by_target["attribute"] == attr]
    print("data_by_target_by_attr", len(data_by_target_by_attr))

    # 3. select model by attribute (implicitly via min_prefix)
    mask = model_by_target["min_prefix"].isin(data_by_target_by_attr["min_prefix"])
    model_by_target_by_attr = model_by_target[mask]
    print("model_by_target_by_attr", len(model_by_target_by_attr))

    # 4. keep only the data prefixes the model was actually run on
    mask = data_by_target_by_attr["min_prefix"].isin(model_by_target_by_attr["min_prefix"])
    data_by_target_by_attr = data_by_target_by_attr[mask]
    print("data_by_target_by_attr", len(data_by_target_by_attr))

    # 5. select the decoding algorithm. Build the mask from the already-filtered frame so its
    # index matches; a mask built from the full model frame only works because pandas reindexes
    # the boolean key (and warns about it).
    model_by_target_by_attr = model_by_target_by_attr[model_by_target_by_attr["sampling"] == sampling]
    n_configs = model_by_target_by_attr["sampling_kwargs"].nunique()
    # rows per sampling configuration; NaN instead of ZeroDivisionError when nothing matched
    rows_per_config = len(model_by_target_by_attr) / n_configs if n_configs else float("nan")
    print("model_by_target_by_attr", len(model_by_target_by_attr), rows_per_config)

    # 6. Average model's results per (prefix, sampling configuration)
    model_by_target_by_attr = (
        model_by_target_by_attr[["min_prefix", "sampling_kwargs", "toxicity"]]
        .groupby(["min_prefix", "sampling_kwargs"])
        .mean()
        .reset_index()
    )
    print("model_by_target_by_attr", len(model_by_target_by_attr))

    # 7. Average data's results per prefix
    data_by_target_by_attr = data_by_target_by_attr[["min_prefix", "toxicity"]].groupby(["min_prefix"]).mean()
    print("data_by_target_by_attr", len(data_by_target_by_attr))

    # Dataset: dashed black outline; model: one curve per sampling configuration.
    sns.histplot(data=data_by_target_by_attr, x="toxicity", color="black", ls="--", fill=False, **kwargs)
    g = sns.histplot(data=model_by_target_by_attr, x="toxicity", hue="sampling_kwargs", **kwargs)

    title = f"[{target}, {attr}]: {sampling}"
    if ax is not None:
        ax.set_title(title)
    else:
        sns.move_legend(g, "upper left", bbox_to_anchor=(1.01, 0.5))
        plt.title(title)
        plt.show()

In [None]:
# Attribute words paired with each target word, in alphabetical order.
ATTRIBUTE_WORDS = sorted([
    "happy", "sad", "calm", "angry", "terror", "peace", "dead", "death",
    "great", "good", "bad", "terrible", "positive", "negative", "skill", "food",
])

In [None]:
# For every attribute word: print a banner, then one row of four cumulative-histogram panels
# (one per decoding algorithm) for each target word.
for attr in ATTRIBUTE_WORDS:
    banner = "\n" * 5
    print(banner)
    print(attr)
    print("\n")
    for target in TARGET_WORDS:
        fig, axes = plt.subplots(1, 4, sharex=True, sharey=True, figsize=(30, 5))
        for panel, sampling_name in zip(axes, ("multinomial", "temperature", "top-p", "top-k")):
            plot_histplot(target, attr, sampling_name, ax=panel)
        plt.tight_layout()
        plt.show()

## ScatterPlot - Correlation between data toxicity and model toxicity

In [None]:
def plot_scatterplot(target: str, attr: str, sampling: str, data_dict=None, model_dict=None, ax=None):
    """Scatter per-prefix dataset toxicity against per-prefix model toxicity, with regression fits.

    Uses the same selection as ``plot_histplot``: dataset rows with ``attribute == attr``, model
    rows seeded from the same prefixes (model ``prefix`` renamed to ``min_prefix``), restricted to
    one decoding algorithm. Toxicity is averaged per prefix (and per ``sampling_kwargs`` for the
    model). One regression line is drawn per ``sampling_kwargs`` value.

    Args:
        target: key into ``data_dict`` / ``model_dict``.
        attr: value of the dataset ``attribute`` column to keep.
        sampling: value of the model ``sampling`` column to keep.
        data_dict: target -> dataset DataFrame. Defaults to ``DATA_BY_TARGET`` (looked up at call time).
        model_dict: target -> model DataFrame. Defaults to ``MODEL_BY_TARGET`` (looked up at call time).
        ax: matplotlib Axes to draw on. When None, draws on the current axes and calls ``plt.show()``.
    """
    if data_dict is None:
        data_dict = DATA_BY_TARGET
    if model_dict is None:
        model_dict = MODEL_BY_TARGET
    # Always draw on an explicit Axes. Previously the caller's ``ax`` was only used for the
    # title, so the regression plots landed on whichever axes happened to be current.
    plot_ax = ax if ax is not None else plt.gca()

    # 1. select model by target
    model_by_target = model_dict[target].rename({"prefix": "min_prefix"}, axis=1)
    print("model_by_target", len(model_by_target))
    # 2. select data by target and by attribute
    data_by_target = data_dict[target]
    data_by_target_by_attr = data_by_target[data_by_target["attribute"] == attr]
    print("data_by_target_by_attr", len(data_by_target_by_attr))

    # 3. select model by attribute (implicitly via min_prefix)
    mask = model_by_target["min_prefix"].isin(data_by_target_by_attr["min_prefix"])
    model_by_target_by_attr = model_by_target[mask]
    print("model_by_target_by_attr", len(model_by_target_by_attr))

    # 4. keep only the data prefixes the model was actually run on
    mask = data_by_target_by_attr["min_prefix"].isin(model_by_target_by_attr["min_prefix"])
    data_by_target_by_attr = data_by_target_by_attr[mask]
    print("data_by_target_by_attr", len(data_by_target_by_attr))

    # 5. select the decoding algorithm (mask built on the filtered frame so the index matches)
    model_by_target_by_attr = model_by_target_by_attr[model_by_target_by_attr["sampling"] == sampling]
    n_configs = model_by_target_by_attr["sampling_kwargs"].nunique()
    rows_per_config = len(model_by_target_by_attr) / n_configs if n_configs else float("nan")
    print("model_by_target_by_attr", len(model_by_target_by_attr), rows_per_config)

    # 6. Average model's results per (prefix, sampling configuration)
    model_means = (
        model_by_target_by_attr[["min_prefix", "sampling_kwargs", "toxicity"]]
        .groupby(["min_prefix", "sampling_kwargs"])
        .mean()
        .reset_index()
    )
    print("model_by_target_by_attr", len(model_means))

    # 7. Average data's results per prefix
    data_means = data_by_target_by_attr[["min_prefix", "toxicity"]].groupby(["min_prefix"]).mean().reset_index()
    print("data_by_target_by_attr", len(data_means))

    # Pair every model mean with the dataset mean of the same prefix. A merge stays correct even
    # if a configuration lacks some prefix; positional pairing required both frames to list the
    # exact same prefixes in the same order.
    paired = model_means.merge(data_means, on="min_prefix", suffixes=("_model", "_data"))

    for sampl_kwargs in sorted(paired["sampling_kwargs"].unique()):
        pairs = paired[paired["sampling_kwargs"] == sampl_kwargs]
        sns.regplot(
            x=pairs["toxicity_data"].to_numpy(),
            y=pairs["toxicity_model"].to_numpy(),
            label=sampl_kwargs,
            ax=plot_ax,
        )

    plot_ax.set_title(f"[{target}, {attr}]: {sampling}")
    plot_ax.set_xlabel("Data Toxicity")
    plot_ax.set_ylabel("Model Toxicity")
    plot_ax.legend()
    if ax is None:
        plt.show()

In [None]:
# TODO: also iterate over the attribute words (only "terror" is plotted for now).
for target in TARGET_WORDS:
    for sampling_name in ("multinomial", "temperature", "top-p", "top-k"):
        plot_scatterplot(target, "terror", sampling_name)
    print("\n\n\n\n =========================== \n\n\n")

In [None]:
def plot_box(target: str, attr: str, sampling: str, data_dict=None, model_dict=None, ax=None, figsize=(10, 5)):
    """Box plots of per-prefix model toxicity, grouped by binned per-prefix dataset toxicity.

    Uses the same selection as ``plot_histplot``: dataset rows with ``attribute == attr``, model
    rows seeded from the same prefixes (model ``prefix`` renamed to ``min_prefix``), restricted to
    one decoding algorithm. Each prefix's mean dataset toxicity is binned into 0.1-wide intervals.
    Each (prefix, ``sampling_kwargs``) mean model toxicity then becomes one observation in the box
    of its prefix's bin, with hue = ``sampling_kwargs``.

    Args:
        target: key into ``data_dict`` / ``model_dict``.
        attr: value of the dataset ``attribute`` column to keep.
        sampling: value of the model ``sampling`` column to keep.
        data_dict: target -> dataset DataFrame. Defaults to ``DATA_BY_TARGET`` (looked up at call time).
        model_dict: target -> model DataFrame. Defaults to ``MODEL_BY_TARGET`` (looked up at call time).
        ax: matplotlib Axes to draw on. When None, a new figure of size ``figsize`` is created and shown.
        figsize: figure size used only when ``ax`` is None.
    """
    if data_dict is None:
        data_dict = DATA_BY_TARGET
    if model_dict is None:
        model_dict = MODEL_BY_TARGET

    kwargs = {}
    if ax is not None:
        kwargs["ax"] = ax
    else:
        plt.figure(figsize=figsize)

    # 1. select model by target
    model_by_target = model_dict[target].rename({"prefix": "min_prefix"}, axis=1)
    print("model_by_target", len(model_by_target))
    # 2. select data by target and by attribute
    data_by_target = data_dict[target]
    data_by_target_by_attr = data_by_target[data_by_target["attribute"] == attr]
    print("data_by_target_by_attr", len(data_by_target_by_attr))

    # 3. select model by attribute (implicitly via min_prefix)
    mask = model_by_target["min_prefix"].isin(data_by_target_by_attr["min_prefix"])
    model_by_target_by_attr = model_by_target[mask]
    print("model_by_target_by_attr", len(model_by_target_by_attr))

    # 4. keep only the data prefixes the model was actually run on
    mask = data_by_target_by_attr["min_prefix"].isin(model_by_target_by_attr["min_prefix"])
    data_by_target_by_attr = data_by_target_by_attr[mask]
    print("data_by_target_by_attr", len(data_by_target_by_attr))

    # 5. select the decoding algorithm (mask built on the filtered frame so the index matches)
    model_by_target_by_attr = model_by_target_by_attr[model_by_target_by_attr["sampling"] == sampling]
    n_configs = model_by_target_by_attr["sampling_kwargs"].nunique()
    rows_per_config = len(model_by_target_by_attr) / n_configs if n_configs else float("nan")
    print("model_by_target_by_attr", len(model_by_target_by_attr), rows_per_config)

    # 6. Average model's results per (prefix, sampling configuration)
    model_means = (
        model_by_target_by_attr[["min_prefix", "sampling_kwargs", "toxicity"]]
        .groupby(["min_prefix", "sampling_kwargs"])
        .mean()
        .reset_index()
    )
    print("model_by_target_by_attr", len(model_means))

    # 7. Average data's results per prefix, then bin them. pd.cut intervals are right-closed,
    # so include_lowest=True keeps exact 0.0 scores in the first bin instead of making them NaN.
    data_means = data_by_target_by_attr[["min_prefix", "toxicity"]].groupby(["min_prefix"]).mean().reset_index()
    data_means["toxicity_bins"] = pd.cut(
        data_means["toxicity"], bins=[0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1], include_lowest=True
    )
    print("data_by_target_by_attr", len(data_means))

    # Pair each model mean with the data bin of the same prefix. Previously, x (one row per
    # prefix) and y/hue (one row per prefix *and* configuration) were passed as separate vectors
    # of different lengths that were not matched by prefix.
    paired = model_means.merge(data_means, on="min_prefix", suffixes=("_model", "_data"))
    g = sns.boxplot(data=paired, x="toxicity_bins", y="toxicity_model", hue="sampling_kwargs", **kwargs)

    g.set_title(f"[{target}, {attr}]: {sampling}")
    g.set_xlabel("Data Toxicity")
    g.set_ylabel("Model Toxicity")
    if ax is None:
        # move_legend raises when there is no legend (e.g. nothing matched the filters). Do not
        # call plt.legend() afterwards: it rebuilds the legend in place and undoes the move.
        if g.get_legend() is not None:
            sns.move_legend(g, loc="upper left", bbox_to_anchor=(1.01, 1.0))
        plt.ylim(0, 1)
        plt.show()

In [None]:
# One row of four box-plot panels (one per decoding algorithm) per target, for attribute "terror".
for target in TARGET_WORDS:
    fig, axes = plt.subplots(1, 4, sharex=True, sharey=True, figsize=(30, 5))
    for panel, sampling_name in zip(axes, ("multinomial", "temperature", "top-p", "top-k")):
        plot_box(target, "terror", sampling_name, ax=panel)
    plt.tight_layout()
    plt.show()
    print("\n\n\n\n =========================== \n\n\n")

### KDE Plot

In [None]:
def plot_kdeplot(target: str, attr: str, sampling: str, data_dict=None, model_dict=None, ax=None):
    """Overlay toxicity KDEs of the dataset, one decoding algorithm, and multinomial sampling.

    Draws each distribution with a dashed vertical line at its mean:
      * dataset rows with ``attribute == attr`` (black);
      * model generations for ``sampling``, one curve per ``sampling_kwargs`` value;
      * model generations for "multinomial" sampling, as a reference (purple).
    Model rows are matched to the attribute through their prefix (the model ``prefix`` column is
    renamed to ``min_prefix``). Unlike the other plotting helpers, individual rows are plotted
    without averaging per prefix, and dataset rows are not restricted to prefixes the model saw.

    Args:
        target: key into ``data_dict`` / ``model_dict``.
        attr: value of the dataset ``attribute`` column to keep.
        sampling: value of the model ``sampling`` column to plot.
        data_dict: target -> dataset DataFrame. Defaults to ``DATA_BY_TARGET`` (looked up at call time).
        model_dict: target -> model DataFrame. Defaults to ``MODEL_BY_TARGET`` (looked up at call time).
        ax: matplotlib Axes to draw on. Defaults to the current axes (``plt.gca()``).
    """
    if data_dict is None:
        data_dict = DATA_BY_TARGET
    if model_dict is None:
        model_dict = MODEL_BY_TARGET
    # The mean lines are drawn with ax.axvline, so an Axes is always required. Previously,
    # ax=None crashed on None.axvline, and plt.set_title does not exist.
    if ax is None:
        ax = plt.gca()

    # Select data by target and by attribute
    data_by_target = data_dict[target]
    data_by_target_by_attr = data_by_target[data_by_target["attribute"] == attr]

    # select model by target
    model_by_target = model_dict[target].rename({"prefix": "min_prefix"}, axis=1)

    # select model by attribute (implicitly via min_prefix) and by decoding algorithm
    attr_mask = model_by_target["min_prefix"].isin(data_by_target_by_attr["min_prefix"])
    model_by_target_by_attr = model_by_target[attr_mask & (model_by_target["sampling"] == sampling)]
    model_mult_by_target_by_attr = model_by_target[attr_mask & (model_by_target["sampling"] == "multinomial")]

    # Dataset distribution and its mean
    sns.kdeplot(data=data_by_target_by_attr, x="toxicity", color="black", common_norm=False, cut=0, ax=ax)
    ax.axvline(data_by_target_by_attr["toxicity"].mean(), color="black", ls="--")

    # One KDE per sampling configuration. The explicit palette is shared with the mean lines:
    # axvline uses matplotlib's default line color rather than the KDE hue colors, so the lines
    # would otherwise not match their curves.
    hue_order = list(model_by_target_by_attr["sampling_kwargs"].unique())
    palette = dict(zip(hue_order, sns.color_palette(n_colors=len(hue_order)))) if hue_order else None
    sns.kdeplot(
        data=model_by_target_by_attr, x="toxicity", hue="sampling_kwargs",
        hue_order=hue_order or None, palette=palette,
        common_norm=False, cut=0, alpha=0.5, ax=ax,
    )
    for sampl in hue_order:
        sampl_data = model_by_target_by_attr[model_by_target_by_attr["sampling_kwargs"] == sampl]
        ax.axvline(sampl_data["toxicity"].mean(), color=palette[sampl], ls="--", label=sampl)

    # Multinomial sampling reference
    sns.kdeplot(data=model_mult_by_target_by_attr, x="toxicity", color="purple", common_norm=False, cut=0, ax=ax)
    ax.axvline(model_mult_by_target_by_attr["toxicity"].mean(), color="purple", ls="--")

    ax.set_title(f"[{target}, {attr}]: {sampling}")

    
# KDE panels for (christian, angry): temperature, top-p and top-k sampling side by side.
fig, axes = plt.subplots(1, 3, sharex=True, sharey=True, figsize=(30, 5))
for panel, sampling_name in zip(axes, ("temperature", "top-p", "top-k")):
    plot_kdeplot("christian", "angry", sampling_name, ax=panel)
plt.show()

In [None]:
# One row of KDE panels (temperature, top-p, top-k) for every (attribute, target) pair.
for attr in ATTRIBUTE_WORDS:
    for target in TARGET_WORDS:
        fig, axes = plt.subplots(1, 3, sharex=True, sharey=True, figsize=(30, 5))
        for panel, sampling_name in zip(axes, ("temperature", "top-p", "top-k")):
            plot_kdeplot(target, attr, sampling_name, ax=panel)
        plt.tight_layout()
        plt.show()

In [None]:
# TODO
# - Analysis of the data

In [None]:
# NOTE(review): get_data is defined in a later cell; on a fresh kernel run that cell first.
# Print the 60 most toxic model generations for (buddhist, terror). The generation is the
# sequence with its first len(min_prefix) characters dropped (assumes it starts with the prefix).
buddhist_terror = get_data("buddhist", "terror")
most_toxic = buddhist_terror.model.sort_values("toxicity", ascending=False).head(60)
for row_idx, row in most_toxic.iterrows():
    prefix_len = len(row["min_prefix"])
    generation = row["sequence"][prefix_len:]
    print()
    print(row["toxicity"], "\n--> ", row["sampling_kwargs"], "\n--> prefix:", row["min_prefix"], "\n--> generation:", generation)

In [None]:
# NOTE(review): relies on get_data, which is defined in a later cell.
muslim_angry, muslim_happy, muslim_terror, muslim_food = (
    get_data("muslim", attribute) for attribute in ("angry", "happy", "terror", "food")
)

In [None]:
# KDE panels for (muslim, terror), one per decoding algorithm.
fig, axes = plt.subplots(1, 4, sharex=True, sharey=True, figsize=(30, 5))
for panel, sampling_name in zip(axes, ("multinomial", "temperature", "top-p", "top-k")):
    plot_kdeplot("muslim", "terror", sampling_name, ax=panel)

In [None]:
def get_data(target: str, attr: str, sampling: str=None, data_dict=None, model_dict=None):
    """Return the dataset rows and model generations selected for one (target, attribute) pair.

    Applies the same selection as the plotting helpers in this notebook:
      1. the model frame for ``target``, with its ``prefix`` column renamed to ``min_prefix``;
      2. dataset rows for ``target`` whose ``attribute`` equals ``attr``;
      3. model rows whose ``min_prefix`` occurs in those dataset rows;
      4. dataset rows restricted to prefixes present in the selected model rows;
      5. optionally, model rows whose ``sampling`` column equals ``sampling``.
    The row count after each step is printed.

    Args:
        target: key into ``data_dict`` and ``model_dict``.
        attr: value of the dataset ``attribute`` column to keep.
        sampling: decoding algorithm to keep (model ``sampling`` column); None keeps all of them.
        data_dict: target -> dataset DataFrame. Defaults to ``DATA_BY_TARGET``, looked up at call
            time rather than when this cell was executed.
        model_dict: target -> model DataFrame. Defaults to ``MODEL_BY_TARGET`` (looked up at call time).

    Returns:
        A namespace with attributes ``data`` (selected dataset rows) and ``model`` (selected model
        rows, with ``min_prefix`` in place of ``prefix``).
    """
    from types import SimpleNamespace

    if data_dict is None:
        data_dict = DATA_BY_TARGET
    if model_dict is None:
        model_dict = MODEL_BY_TARGET

    # 1. select model by target (rename returns a copy; the input frame is untouched)
    model_by_target = model_dict[target].rename({"prefix": "min_prefix"}, axis=1)
    print("model_by_target", len(model_by_target))
    # 2. select data by target and by attribute
    data_by_target = data_dict[target]
    data_by_target_by_attr = data_by_target[data_by_target["attribute"] == attr]
    print("data_by_target_by_attr", len(data_by_target_by_attr))

    # 3. select model by attribute (implicitly via min_prefix)
    mask = model_by_target["min_prefix"].isin(data_by_target_by_attr["min_prefix"])
    model_by_target_by_attr = model_by_target[mask]
    print("model_by_target_by_attr", len(model_by_target_by_attr))

    # 4. keep only the data prefixes the model was actually run on
    mask = data_by_target_by_attr["min_prefix"].isin(model_by_target_by_attr["min_prefix"])
    data_by_target_by_attr = data_by_target_by_attr[mask]
    print("data_by_target_by_attr", len(data_by_target_by_attr))

    # 5. select specific sampling. The mask is built on the filtered frame so its index matches
    # (instead of relying on pandas reindexing a mask built on the full model frame).
    if sampling is not None:
        model_by_target_by_attr = model_by_target_by_attr[model_by_target_by_attr["sampling"] == sampling]
        n_configs = model_by_target_by_attr["sampling_kwargs"].nunique()
        # rows per sampling configuration; NaN instead of ZeroDivisionError on an empty selection
        rows_per_config = len(model_by_target_by_attr) / n_configs if n_configs else float("nan")
        print("model_by_target_by_attr", len(model_by_target_by_attr), rows_per_config)

    return SimpleNamespace(data=data_by_target_by_attr, model=model_by_target_by_attr)

In [None]:
# Bin each row's toxicity into 0.1-wide intervals. pd.cut intervals are right-closed, so
# include_lowest=True keeps exact 0.0 scores in the first bin instead of turning them into NaN.
# .assign builds a new frame rather than writing into a (possibly filtered) slice.
# NOTE(review): data_by_target_by_attr is first defined at module level in a later cell.
data_by_target_by_attr = data_by_target_by_attr.assign(
    toxicity_bins=pd.cut(
        data_by_target_by_attr["toxicity"],
        bins=[0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1],
        include_lowest=True,
    )
)

In [None]:
data_by_target_by_attr.loc[:, "toxicity_bins"].unique()

In [None]:
# KDE panels for (muslim, terror), one per decoding algorithm (same figure as the earlier cell).
fig, axes = plt.subplots(1, 4, sharex=True, sharey=True, figsize=(30, 5))
for panel, sampling_name in zip(axes, ("multinomial", "temperature", "top-p", "top-k")):
    plot_kdeplot("muslim", "terror", sampling_name, ax=panel)

In [None]:
# KDE panels for (christian, terror), one per decoding algorithm.
fig, axes = plt.subplots(1, 4, sharex=True, sharey=True, figsize=(30, 5))
for panel, sampling_name in zip(axes, ("multinomial", "temperature", "top-p", "top-k")):
    plot_kdeplot("christian", "terror", sampling_name, ax=panel)

In [None]:
# KDE panels for (christian, happy), one per decoding algorithm.
fig, axes = plt.subplots(1, 4, sharex=True, sharey=True, figsize=(30, 5))
for panel, sampling_name in zip(axes, ("multinomial", "temperature", "top-p", "top-k")):
    plot_kdeplot("christian", "happy", sampling_name, ax=panel)

In [None]:
# Mean of every numeric column per decoding configuration. The model frame presumably also holds
# text columns (prefix, sequence); numeric_only=True skips them, since pandas >= 2.0 raises on
# them instead of silently dropping them.
# NOTE(review): model_by_target_by_attr is not defined at module level earlier in this notebook.
model_by_target_by_attr.groupby(["sampling", "sampling_kwargs"]).mean(numeric_only=True).sort_index()

In [None]:
# Select the dataset rows whose prefix was used to seed the model sequences
# ------------------------------------------------------------
# (if a min_prefix appears several times, only its first row is kept)
# ------------------------------------------------------------
# Previously this held the boolean mask itself, which always has len(data_by_target) entries.
data_by_target_seq_sampled = (
    data_by_target[data_by_target["min_prefix"].isin(model_by_target["min_prefix"])]
    .drop_duplicates(subset="min_prefix")
)

In [None]:
data_by_target.shape[0], data_by_target_seq_sampled.shape[0]

In [None]:
# TODO:
# --------------------------------------------------------------------------------------------
# 1. Plot data distribution of toxicity scores per (target, attribute)
# 2. Plot model distribution for decoding algorithm X toxicity scores per (target, attribute)
#     - Pick decoding algorithm
#     - Match prefix w/ attributes
# --------------------------------------------------------------------------------------------

attr = "happy"
is_attr = data_by_target["attribute"].eq(attr)
data_by_target_by_attr = data_by_target.loc[is_attr]

sns.kdeplot(data=data_by_target_by_attr, x="toxicity")

In [None]:
data_by_target.shape[0]

In [None]:
sns.histplot(data_by_target_by_attr, x="toxicity", element="step")

In [None]:
data_by_target.head(n=5)

In [None]:
# Key both frames by prefix. The inner join pairs each model row with every dataset row that
# shares its prefix; the suffixes mark which side a duplicated column name came from.
t1 = (
    model_by_target
    .set_index("min_prefix")
    .copy()
)
t2 = (
    data_by_target
    .set_index("min_prefix")
    .copy()
)

t = t1.join(t2, how="inner", lsuffix="_model", rsuffix="_data")

In [None]:
t2.loc[t2.index.isin(t1.index)]