# Bias Visualization

In [None]:
%pip install numpy pandas plotly scikit-learn seaborn tqdm

In [None]:
from typing import List, Dict, Iterable
import os

import pandas as pd
import numpy as np
from sklearn.metrics import confusion_matrix as ConfusionMatrix
from tqdm.auto import tqdm

import plotly.express as px
import seaborn as sns

from math import sqrt

# True when the DEBUG environment variable is set to any non-empty string.
DEBUG = bool(os.getenv("DEBUG"))

# Column names, in file order, applied to the header-less tab-separated
# prediction files when they are read.
COLUMN_NAMES = [
    # prediction and per-class logits (going by the column names)
    "y_pred", "logit_negative", "logit_neutral", "logit_positive",
    # reference label and the protected category/group of the example
    "y_true", "category", "group", "text",
    # run-level metadata
    "run_id", "model", "dataset", "eval_accuracy", "num_params",
]

# Grouping keys used (together with "group") whenever statistics are aggregated.
INDEPENDENT_VARIABLES = [
    "category",
    "run_id",
    "model",
    "dataset",
    "eval_accuracy",
    "num_params",
]

# Per-category allow-list of groups; categories that are not listed here keep
# all of their groups.
SELECTED_GROUPS = {"grouped_religion": ["christianity", "islam", "judaism"]}

# Display names for dataset identifiers.
DATASET_RENAME = {
    "semeval_3": "SemEval",
    "sst5_mapped_extreme": "SST5- Extreme Only",
    "sst5_mapped_grouped": "SST5",
}

## Utilities

### Fairness Metric Implementation

In [None]:
def get_stats(dataframe: pd.DataFrame) -> pd.Series:
    """Summarise one slice of predictions as accuracy plus per-class rates.

    Args:
        dataframe: Rows carrying at least the ``y_pred`` and ``y_true`` columns.

    Returns:
        A Series indexed by ``accuracy`` followed by ``<label>_FPR`` and
        ``<label>_TPR`` for the labels negative, neutral and positive.
    """
    y_true = dataframe["y_true"]
    y_pred = dataframe["y_pred"]

    stats: Dict[str, float] = {"accuracy": np.sum(y_pred == y_true) / len(dataframe)}

    # One-vs-rest counts per class, derived from the confusion matrix.
    # See https://stackoverflow.com/a/50671617
    matrix = ConfusionMatrix(y_true, y_pred)
    true_pos = np.diag(matrix)
    false_pos = matrix.sum(axis=0) - true_pos
    false_neg = matrix.sum(axis=1) - true_pos
    true_neg = matrix.sum() - (false_pos + false_neg + true_pos)

    false_pos_rate = false_pos / (false_pos + true_neg)
    false_neg_rate = false_neg / (true_pos + false_neg)
    true_pos_rate = 1 - false_neg_rate

    # NOTE(review): rates are looked up by position, which presumes the matrix
    # has one row per label in negative/neutral/positive order - i.e. that all
    # three classes occur in this slice and sort that way. TODO confirm.
    for position, label in enumerate(("negative", "neutral", "positive")):
        stats[f"{label}_FPR"] = false_pos_rate[position]
        stats[f"{label}_TPR"] = true_pos_rate[position]

    return pd.Series(stats)

### Plotting Utils

In [None]:
def format_metric_name(metric_name):
    """Turn a snake_case identifier into a human-readable label.

    The name is split on underscores. The acronyms TPR, FPR and SST5 are
    upper-cased (matched case-insensitively), the word "mapped" is dropped,
    and every other word is title-cased. Words are joined with single spaces.
    """
    acronyms = ("TPR", "FPR", "SST5")

    def render(word):
        upper = word.upper()
        return upper if upper in acronyms else word.title()

    kept_words = (word for word in metric_name.split("_") if word != "mapped")
    return " ".join(render(word) for word in kept_words)


def rgb_to_hex(rgb):
    """Convert an ``(r, g, b)`` triple of floats in [0, 1] to ``#rrggbb``.

    Each channel is scaled by 256 and truncated, then clamped to 0..255 so
    that a channel of exactly 1.0 yields ``ff`` rather than the three-digit
    ``100`` (which would produce an invalid colour string).

    Args:
        rgb: Iterable of exactly three numeric channel values.

    Returns:
        A seven-character lowercase hex colour string.
    """
    r, g, b = rgb
    channels = [min(255, max(0, int(256 * value))) for value in (r, g, b)]
    return "#{:02x}{:02x}{:02x}".format(*channels)


def get_model_palette(models: Iterable[str]) -> List[str]:
    """Return one hex colour per model name, in the order given.

    Each name is split on "-" into ``<family>-<size>``. The family ("opt" or
    "galactica") selects a reversed six-colour seaborn palette ("flare" or
    "crest"), and the size picks the colour at its position in the list of
    known sizes.

    Raises:
        ValueError: If a name does not split into exactly two parts.
        KeyError: If the family or size is not one of the known values.
    """
    sizes = ["125m", "350m", "1.3b", "2.7b", "6.7b", "13b"]
    family_palettes = {
        "opt": list(sns.color_palette("flare", n_colors=6))[::-1],  # type: ignore
        "galactica": list(sns.color_palette("crest", n_colors=6))[::-1],  # type: ignore
    }

    # family -> size -> hex colour
    hex_by_family: Dict[str, Dict[str, str]] = {
        family: {size: rgb_to_hex(shade) for size, shade in zip(sizes, shades)}
        for family, shades in family_palettes.items()
    }

    colours = []
    for name in models:
        family, size = name.split("-")
        colours.append(hex_by_family[family][size])
    return colours

## Getting Predictions Ready

### Load Prediction Files
Load prediction files from a given folder. Calculate basic metrics e.g., accuracy and false-positive rate for each protected group.

In [None]:
prediction_folder: str = "predictions/"
# Sorted so the load order does not depend on the filesystem's listing order.
prediction_files: List[str] = sorted(os.listdir(prediction_folder))
prediction_tsv_paths: List[str] = [
    os.path.join(prediction_folder, prediction_file)
    for prediction_file in prediction_files
    if prediction_file.endswith(".txt")
]

output_folder: str = "plots"

# Read every file first and concatenate once at the end; concatenating inside
# the loop re-copies all previously loaded rows on every iteration.
prediction_frames: List[pd.DataFrame] = []
for prediction_tsv_path in tqdm(prediction_tsv_paths, ncols=80):
    dataframe = pd.read_csv(prediction_tsv_path, delimiter="\t", names=COLUMN_NAMES)
    # The raw text is not needed for any statistic computed below.
    prediction_frames.append(dataframe.drop(columns="text"))

if not prediction_frames:
    # Fail here with a clear message instead of a KeyError in the groupby below.
    raise FileNotFoundError(f"No .txt prediction files found in {prediction_folder!r}")

output_table = pd.concat(prediction_frames)

# One row of accuracy / FPR / TPR statistics per combination of the
# independent variables and protected group.
analyzed_dataframe = output_table.groupby([*INDEPENDENT_VARIABLES, "group"]).apply(get_stats)
analyzed_dataframe.query('category == "grouped_race"')[["positive_TPR"]]

In [None]:
# Make sure the destination folder exists before anything is written to it.
os.makedirs(output_folder, exist_ok=True)

# Flatten the group-by index back into ordinary columns and persist the
# per-group statistics as CSV.
output_table = analyzed_dataframe.reset_index()
output_table.to_csv(os.path.join(output_folder, "aggregated.csv"))

## Calculating Statistics

### "Gap": How each group deviates from the category median
Calculate how each group deviates from the median of its category (e.g., how the "young" group deviates from the median of the "age" category).

In [None]:
# Base metrics: overall accuracy plus FPR/TPR for every sentiment label.
metrics = ["accuracy"]
for label in ["negative", "neutral", "positive"]:
    metrics.append(label + "_FPR")
    metrics.append(label + "_TPR")

# Derived column names, index-aligned with `metrics`.
gap_metrics = [f"{metric}_gap" for metric in metrics]
metrics_hypotheses = [f"{metric}_gap_hypothesis" for metric in metrics]

# Collect the groups to keep: the allow-list from SELECTED_GROUPS where one is
# configured for the category, otherwise every group seen in that category.
groups_to_include: List[str] = []
for category in output_table["category"].unique():
    groups = SELECTED_GROUPS.get(category)
    if groups is None:
        # A boolean mask avoids building a query string from data values, and
        # unique() keeps first-appearance order so the printout is stable.
        in_category = output_table["category"] == category
        groups = output_table.loc[in_category, "group"].unique().tolist()

    groups_to_include.extend(groups)

print("Groups selected:", ", ".join(groups_to_include))
# .copy() detaches the filtered rows from the original frame so that columns
# added later do not trigger pandas' SettingWithCopyWarning.
output_table = output_table[output_table["group"].isin(groups_to_include)].copy()

### Confidence Interval Calculations

In [None]:
# Gap columns are added to output_table below; work on an independent copy so
# the assignment never targets a filtered view of another frame.
output_table = output_table.copy()

group_keys = [*INDEPENDENT_VARIABLES, "group"]

# Skeleton with one row per (independent variables, group); the per-metric
# columns are attached to it inside the loop.
stats_table = output_table[[*group_keys, "accuracy"]].groupby(group_keys).agg(["mean"])

# NOTE(review): output_table is produced by grouping on these same keys and
# resetting the index, so each key combination appears to occur exactly once.
# In that case count == 1 and std is NaN, which makes every interval width NaN
# and every hypothesis 0. TODO confirm whether run-level keys (e.g. run_id)
# should be left out of `group_keys` here.
for metric, gap_metric, hypothesis in zip(metrics, gap_metrics, metrics_hypotheses):
    # Gap = value minus the median over all groups sharing the same
    # independent variables.
    median = output_table.groupby(INDEPENDENT_VARIABLES)[metric].transform("median")
    output_table[gap_metric] = output_table[metric] - median

    # See https://stackoverflow.com/a/53522680
    # Mean, sample count and standard deviation of the gap per group.
    summary = output_table.groupby(group_keys)[gap_metric].agg(["mean", "count", "std"])

    # 95% normal-approximation interval: mean +/- 1.96 * std / sqrt(n).
    mean = summary["mean"].to_numpy()
    half_width = (1.96 * summary["std"] / np.sqrt(summary["count"])).to_numpy()
    lower = mean - half_width
    upper = mean + half_width

    stats_table[gap_metric] = mean
    stats_table[gap_metric + "_width"] = half_width
    stats_table[gap_metric + "_lower"] = lower
    stats_table[gap_metric + "_upper"] = upper
    # +1 when the whole interval lies above zero, -1 when it lies below, and 0
    # when it contains zero (NaN bounds compare False and therefore give 0).
    stats_table[hypothesis] = np.select([lower > 0, upper < 0], [1, -1], default=0)

stats_table = stats_table.reset_index()
stats_table[stats_table["category"] == "grouped_race"]

In [None]:
group_keys = [*INDEPENDENT_VARIABLES, "group"]

# Skeleton with one row per (independent variables, group); the per-metric
# columns are attached to it inside the loop.
stats_table_aggregated = (
    output_table[[*group_keys, "accuracy"]].groupby(group_keys).agg(["mean"])
)

for metric in ["accuracy", "positive_TPR"]:
    # See https://stackoverflow.com/a/53522680
    # Mean, sample count and standard deviation of the raw metric per group.
    summary = output_table.groupby(group_keys)[metric].agg(["mean", "count", "std"])

    # 95% normal-approximation interval: mean +/- 1.96 * std / sqrt(n).
    mean = summary["mean"].to_numpy()
    half_width = (1.96 * summary["std"] / np.sqrt(summary["count"])).to_numpy()
    lower = mean - half_width
    upper = mean + half_width

    stats_table_aggregated[metric] = mean
    stats_table_aggregated[metric + "_width"] = half_width
    stats_table_aggregated[metric + "_lower"] = lower
    stats_table_aggregated[metric + "_upper"] = upper
    # Each metric gets its own hypothesis column. Previously the column name
    # came from a variable left over from an earlier cell's loop, so both
    # metrics overwrote a single, unrelated column.
    stats_table_aggregated[f"{metric}_hypothesis"] = np.select(
        [lower > 0, upper < 0], [1, -1], default=0
    )

stats_table_aggregated = stats_table_aggregated.reset_index()

## Plotting

### Compare between Groups of each category 
E.g. all groups in the "age" category

In [None]:
datasets = stats_table["dataset"].unique()
categories = stats_table["category"].unique()

# One figure per (model, dataset, category, metric). The pseudo-model "all"
# puts every model into the same figure.
for model in [*(stats_table["model"].unique()), "all"]:
    for dataset in tqdm(datasets):
        group_output_path = os.path.join(output_folder, model, dataset)
        os.makedirs(group_output_path, exist_ok=True)

        for category in tqdm(categories, leave=False):
            # The row selection does not depend on the metric, so it is
            # computed once per category rather than once per figure.
            row_mask = (stats_table["category"] == category) & (stats_table["dataset"] == dataset)
            if model != "all":
                row_mask = row_mask & (stats_table["model"] == model)
            filtered_table = stats_table[row_mask]

            # Removing the grouped from grouped_{category_name}
            category_name = category.split("_")[-1]

            for metric, gap_metric in zip(metrics, gap_metrics):
                filename_pdf = f"{metric}_{category_name}_{dataset}_{model}.pdf"
                filename_png = f"{metric}_{category_name}_{dataset}_{model}.png"
                title = (
                    f"{format_metric_name(gap_metric)} for {category_name.title()}"
                    f" - {format_metric_name(dataset)}"
                )

                fig = px.scatter(
                    filtered_table,
                    x="group",
                    facet_col="group",
                    y=gap_metric,
                    color="model",
                    error_y=gap_metric + "_width",
                    labels={
                        "group": "Protected Group",
                        gap_metric: format_metric_name(gap_metric),
                        "model": "LM",
                    },
                    title=title,
                    height=600,
                )

                # Each facet holds a single group: unlink the x axes and blank
                # the per-facet axis titles and annotations.
                fig.update_xaxes(matches=None, tickangle=90)
                fig.for_each_xaxis(lambda x: x.update(title=""))
                fig.for_each_annotation(lambda a: a.update(text=""))

                fig = fig.update_layout(
                    scattermode="group",
                    font_color="black",
                    font_size=18,
                    title_x=0.5,
                )
                fig.write_image(os.path.join(group_output_path, filename_pdf))
                fig.write_image(os.path.join(group_output_path, filename_png))

print("Find plots in plots/")

### Compare between models on the same protected group 
E.g. "young" people of the "age" category.


In [None]:
# One figure per (model, category, group, metric), comparing datasets side by
# side. The pseudo-model "all" puts every model into the same figure.
for model in [*(stats_table["model"].unique()), "all"]:
    for category in tqdm(categories):
        row_mask = (stats_table["category"] == category) & (stats_table["dataset"] != "")
        if model != "all":
            row_mask = row_mask & (stats_table["model"] == model)
        filtered_table = stats_table[row_mask]
        groups = filtered_table["group"].unique()

        # Removing the word "grouped" from grouped_{category_name}
        category_name = category.split("_")[-1]

        # NOTE(review): plotly documents the keys of `labels` as column names,
        # so these entries keyed by dataset *values* likely have no effect.
        # Kept as-is; confirm before removing.
        dataset_value_labels = {
            dataset_name: format_metric_name(dataset_name)
            for dataset_name in filtered_table["dataset"].unique()
        }

        for group in tqdm(groups, leave=False):
            group_output_path = os.path.join(output_folder, model, "by_group", category, group)
            os.makedirs(group_output_path, exist_ok=True)

            group_filtered_table = filtered_table[filtered_table["group"] == group]
            for metric, gap_metric in tqdm(list(zip(metrics, gap_metrics)), leave=False):
                filename_pdf = f"{metric}_{category_name}_{group}_{model}.pdf"
                filename_png = f"{metric}_{category_name}_{group}_{model}.png"
                # The figure spans every dataset, so the title names none. It
                # used to show whichever dataset a previous cell's loop had
                # left in a global variable.
                title = f"{format_metric_name(gap_metric)} for {group.title()}"

                fig = px.scatter(
                    group_filtered_table,
                    x="dataset",
                    facet_col="dataset",
                    y=gap_metric,
                    color="model",
                    error_y=gap_metric + "_width",
                    labels={
                        # Keyed by the plotted column ("dataset"), not "group".
                        "dataset": "Prompt-Tuning Dataset",
                        gap_metric: format_metric_name(gap_metric),
                        "model": "LM",
                        **dataset_value_labels,
                    },
                    title=title,
                    height=600,
                )

                # Each facet holds a single dataset: unlink the x axes and
                # blank the per-facet axis titles and annotations.
                fig.update_xaxes(matches=None, tickangle=90)
                fig.for_each_xaxis(lambda x: x.update(title=""))
                fig.for_each_annotation(lambda a: a.update(text=""))

                fig = fig.update_layout(
                    scattermode="group",
                    font_color="black",
                    font_size=18,
                    title_x=0.5,
                )
                fig.write_image(os.path.join(group_output_path, filename_pdf))
                fig.write_image(os.path.join(group_output_path, filename_png))

print("Find plots in plots/")