# Bias Visualization

### Dependencies
Note that we use the scikit-learn confusion matrix utility to help calculate some of the fairness metrics. We generate the visualizations using plotly Express.

In [None]:
from typing import List, Dict, Iterable
import os

import pandas as pd
import numpy as np
from sklearn.metrics import confusion_matrix as ConfusionMatrix
from tqdm.auto import tqdm

import plotly.express as px
import seaborn as sns

from math import sqrt

### Global Constants

This notebook visualizes how each of the following variables affects fairness towards a particular protected group.

In [None]:
# Experimental-configuration columns. They are used as grouping keys
# whenever statistics are aggregated below.
INDEPENDENT_VARIABLES = [
    "model",
    "dataset",
    "num_params",
]

## Utilities

### Plotting Utils

In [None]:
def format_metric_name(metric_name):
    """
    Convert a snake_case identifier into space-separated, title-cased
    words for plot titles and labels, e.g. "accuracy_gap" -> "Accuracy Gap".
    """
    return " ".join(map(str.title, metric_name.split("_")))

## Getting Predictions Ready

### Load Prediction Files
Load all prediction `tsv` files from a given folder into a dataframe.

In [None]:
prediction_folder: str = "predictions/"
# Sort so the concatenation order (and thus the row order) is reproducible;
# os.listdir returns entries in arbitrary order.
prediction_files: List[str] = sorted(os.listdir(prediction_folder))
prediction_tsv_paths: List[str] = [
    os.path.join(prediction_folder, prediction_file)
    for prediction_file in prediction_files
    if prediction_file.endswith(".tsv")
]

output_folder: str = "plots"

# Read every file first and concatenate once. Growing a DataFrame with
# pd.concat inside the loop re-copies all earlier rows on each iteration,
# and concatenating onto an initially empty DataFrame can trigger pandas'
# FutureWarning about empty entries.
prediction_frames: List[pd.DataFrame] = []
for prediction_tsv_path in tqdm(prediction_tsv_paths, ncols=80):
    dataframe = pd.read_csv(prediction_tsv_path, delimiter="\t")
    # The raw text and the run-level accuracy are not needed for the
    # per-group statistics computed below.
    prediction_frames.append(dataframe.drop(columns=["text", "eval_accuracy"]))

# Keep the original "empty table" result when no .tsv file is found
# (pd.concat raises on an empty list).
output_table = (
    pd.concat(prediction_frames, ignore_index=True) if prediction_frames else pd.DataFrame()
)

## Calculate Statistics
Since there can be multiple examples for each protected group, we apply the `groupby` method to compute aggregated statistics across all the examples of each protected group.

Refer to the following function for details on how we implemented the metrics.

### Function for aggregating statistics


Below you will find a function that takes in a dataframe. This dataframe is generated automatically by `pd.DataFrame.groupby`. The content of this dataframe is a subset of one particular prediction tsv for a particular protected group (e.g., the "young" group for the protected class "age").

The function will return a `pd.Series` (like a dictionary) mapping the following fairness metrics to their floating point values:
- Accuracy
- TPR(Positive)
- TPR(Neutral)
- TPR(Negative)
- FPR(Positive)
- FPR(Neutral)
- FPR(Negative)

Also refer to https://stackoverflow.com/a/50671617 for details on how the TPR and FPR values are computed.

In [None]:
# Sentiment classes in the positional order used for the per-label metric
# columns: index 0 -> "negative", 1 -> "neutral", 2 -> "positive".
SENTIMENT_LABELS = ["negative", "neutral", "positive"]


def _infer_label_order(y_true: pd.Series, y_pred: pd.Series) -> list:
    """
    Return the full, ordered label set to build the confusion matrix with.

    A single protected group may not contain every class in y_true or
    y_pred. If the labels were inferred from that group's data alone, the
    confusion matrix would shrink and its rows would no longer line up with
    SENTIMENT_LABELS. When the observed values look like the label names or
    like integer ids 0..2, the complete set is returned so the matrix is
    always 3x3. Otherwise, the sorted observed values are returned, which
    is the scikit-learn default.
    """
    observed = set(y_true) | set(y_pred)
    if observed <= set(SENTIMENT_LABELS):
        return list(SENTIMENT_LABELS)
    # NOTE(review): assumes integer ids 0/1/2 encode negative/neutral/positive,
    # as implied by the positional mapping below — confirm against the data.
    integer_ids = list(range(len(SENTIMENT_LABELS)))
    if observed <= set(integer_ids):
        return integer_ids
    return sorted(observed)


def get_stats(dataframe: pd.DataFrame, labels=None) -> pd.Series:
    """
    Compute accuracy and per-class TPR/FPR for one protected group in one run.

    Input: dataframe representing all predictions for a particular
    protected group in a particular run; it must have "y_true" and
    "y_pred" columns.

    labels: optional ordered iterable of the three class values, in the
    order of SENTIMENT_LABELS. When omitted, the order is inferred (see
    _infer_label_order). Passing it explicitly guarantees a 3x3 confusion
    matrix even when a group lacks some class.

    Returns a pd.Series with "accuracy" plus "<label>_FPR" and
    "<label>_TPR" for each label in SENTIMENT_LABELS. A rate whose
    denominator is zero (e.g. a class absent from y_true) is NaN.
    """
    output: Dict[str, float] = {}

    y_true = dataframe["y_true"]
    y_pred = dataframe["y_pred"]

    num_examples = len(dataframe)
    num_correct = np.sum(y_pred == y_true)
    output["accuracy"] = num_correct / num_examples

    labels = _infer_label_order(y_true, y_pred) if labels is None else list(labels)

    # See https://stackoverflow.com/a/50671617
    # Explicit labels keep rows/columns aligned with SENTIMENT_LABELS even
    # when this group does not contain every class.
    confusion_matrix = ConfusionMatrix(y_true, y_pred, labels=labels)

    TP = np.diag(confusion_matrix)
    FP = confusion_matrix.sum(axis=0) - TP
    FN = confusion_matrix.sum(axis=1) - TP
    TN = confusion_matrix.sum() - (FP + FN + TP)

    # 0/0 (e.g. a class never present in y_true) is undefined: report NaN
    # without emitting a numpy RuntimeWarning for every such group.
    with np.errstate(divide="ignore", invalid="ignore"):
        FPR = FP / (FP + TN)
        FNR = FN / (TP + FN)
    TPR = 1 - FNR

    for label_index, label in enumerate(SENTIMENT_LABELS):
        output[label + "_FPR"] = FPR[label_index]
        output[label + "_TPR"] = TPR[label_index]

    return pd.Series(output)

### Calculating Per-Group Statistics

In [None]:
# Aggregate statistics per (model, dataset, num_params, run, category, group).
# Only the label columns are passed to get_stats, which reads nothing else.
# Selecting them explicitly avoids pandas' deprecated behaviour of handing
# the grouping columns to DataFrameGroupBy.apply (warns since pandas 2.2).
# NOTE(review): groupby drops rows whose key columns are NaN (pandas default
# dropna=True) — confirm e.g. num_params is always populated.
analyzed_dataframe = output_table.groupby([*INDEPENDENT_VARIABLES, "run_id", "category", "group"])[
    ["y_true", "y_pred"]
].apply(get_stats)

# Notebook preview: positive-class TPR for the race groups.
analyzed_dataframe.query('category == "grouped_race"')[["positive_TPR"]]

In [None]:
# Make sure the output directory exists, then flatten the group keys back
# into ordinary columns and persist the per-group statistics.
os.makedirs(output_folder, exist_ok=True)
output_table = analyzed_dataframe.reset_index()
aggregated_csv_path = os.path.join(output_folder, "aggregated.csv")
output_table.to_csv(aggregated_csv_path)

## Calculating Statistics

Before we begin, note that some of the protected classes include a large number of groups. You may filter on the protected groups to include in each category. If a category isn't in this dictionary, the visualizations will include all the groups under that category by default.

In [None]:
SELECTED_GROUPS = {
    "grouped_religion": ["christianity", "islam", "judaism"],
}

# A row is kept unless its category is restricted in SELECTED_GROUPS and
# its group is not one of the selected ones. The filter works on
# (category, group) pairs, so a same-named group under a different category
# is not accidentally kept or dropped. A boolean mask also avoids building
# a query string from raw category names.
keep_rows = pd.Series(True, index=output_table.index)
for restricted_category, selected in SELECTED_GROUPS.items():
    keep_rows &= (output_table["category"] != restricted_category) | output_table["group"].isin(selected)

# Groups that survive the filter, in order of first appearance (deterministic).
groups_to_include = list(output_table.loc[keep_rows, "group"].unique())

print("Groups selected:", ", ".join(groups_to_include))
# .copy() so the gap columns added later are written to an independent
# frame instead of a slice (avoids SettingWithCopyWarning).
output_table = output_table[keep_rows].copy()

### "Gap": How each group deviates from the category mean
Below, we will calculate how far each group deviates from the mean of that category. For example, how far does the accuracy on examples mentioning the "young" group deviate from the mean of all examples of the "age" category?

If there is more than one run for each model, we calculate gap metrics for each run separately.

We use a for loop to add the following metrics of interest to our list:
- Accuracy
- False Positive Rates
    - FPR("Negative")
    - FPR("Neutral")
    - FPR("Positive")
- True Positive Rates
    - TPR("Negative")
    - TPR("Neutral")
    - TPR("Positive")

In [None]:
# Metric columns produced by get_stats: accuracy, then FPR and TPR for each
# sentiment label (label-major order).
metrics = ["accuracy"] + [
    label + suffix
    for label in ["negative", "neutral", "positive"]
    for suffix in ("_FPR", "_TPR")
]

# Name of the column holding each metric's deviation from its category mean.
gap_metrics = [f"{metric}_gap" for metric in metrics]

# Placeholder list; it is never populated in this cell.
metrics_hypotheses = []

print(gap_metrics)

In [None]:
# For every metric, store in "<metric>_gap" how far each row lies from the
# mean of its (model, dataset, num_params, category, run_id) bucket, i.e.
# the average over all groups of that category within the same run.
group_by = [*INDEPENDENT_VARIABLES, "category", "run_id"]
for metric, gap_metric in zip(metrics, gap_metrics):
    bucket_mean = output_table.groupby(group_by)[metric].transform("mean")
    output_table[gap_metric] = output_table[metric] - bucket_mean

### Confidence Interval Calculations

In the previous step, we've found the gaps for each (metric, category, group, model, run/seed). If we have more than one run/seed for each model, we can aggregate the results to find a confidence interval for each gap variable.

To do so, we will aggregate the previous table (metric, category, group, model, run/seed) along the run/seed axis. The result will be a table with indices (metric, category, group, model).

In [None]:
# Skeleton table indexed by (model, dataset, num_params, category, group).
# "run_id" is deliberately left out of the keys, so all runs collapse into
# one row; the CI columns are attached to this table afterwards.
placeholder_keys = [*INDEPENDENT_VARIABLES, "category", "group"]
stats_table = output_table.groupby(placeholder_keys)[["accuracy"]].agg(["mean"])

In [None]:
# z-value for a two-sided 95% normal-approximation confidence interval.
# NOTE(review): with only a handful of runs per model a Student-t quantile
# would be more appropriate than the normal approximation.
Z_95 = 1.96
stats_keys = [*INDEPENDENT_VARIABLES, "category", "group"]

for metric, gap_metric in zip(metrics, gap_metrics):
    # NOTE(review): this recomputes "<metric>_gap" as the deviation from the
    # category *median*, pooled over all groups and runs ("run_id" is not a
    # key). It overwrites the per-run, mean-based gap computed in the
    # previous cell, and these are the values the CIs below are built on.
    # Confirm which gap definition is intended.
    category_median = output_table.groupby([*INDEPENDENT_VARIABLES, "category"])[metric].transform("median")
    output_table[gap_metric] = output_table[metric] - category_median

    # See https://stackoverflow.com/a/53522680
    # Mean, number of (non-NaN) runs and sample stdev of the gap for each
    # (model, dataset, num_params, category, group).
    summary = output_table.groupby(stats_keys)[gap_metric].agg(["mean", "count", "std"])

    # 95% CI half-width; NaN when only one run is available (std is NaN).
    half_width = Z_95 * summary["std"] / np.sqrt(summary["count"])

    # Assign Series (not positional lists) so values align on the group
    # index of stats_table rather than relying on matching row order.
    stats_table[gap_metric] = summary["mean"]
    stats_table[gap_metric + "_width"] = half_width
    stats_table[gap_metric + "_lower"] = summary["mean"] - half_width
    stats_table[gap_metric + "_upper"] = summary["mean"] + half_width

stats_table = stats_table.reset_index()

## Plotting

### Compare between Groups of each category 
E.g. all groups in the "age" category

In [None]:
datasets = stats_table["dataset"].unique()
categories = stats_table["category"].unique()

# One set of plots per model, plus an "all" pseudo-model that overlays every
# model (coloured by model) for comparison.
for model in [*(stats_table["model"].unique()), "all"]:
    for dataset in tqdm(datasets):
        group_output_path = os.path.join(output_folder, model, dataset)
        os.makedirs(group_output_path, exist_ok=True)

        for category in tqdm(categories, leave=False):
            # The row selection depends only on (model, dataset, category),
            # so it is computed once here rather than once per metric.
            row_mask = (stats_table["category"] == category) & (stats_table["dataset"] == dataset)
            if model != "all":
                row_mask &= stats_table["model"] == model
            filtered_table = stats_table[row_mask]

            if len(filtered_table) == 0:
                continue

            # Remove only the "grouped_" prefix from grouped_{category_name};
            # splitting on "_" would truncate multi-word category names.
            category_name = category[len("grouped_"):] if category.startswith("grouped_") else category

            # Only (metric, gap_metric) are used in the loop body.
            # metrics_hypotheses is never populated, so zipping it in as
            # well made this loop run zero times and no plot was written.
            for metric, gap_metric in zip(metrics, gap_metrics):
                filename_pdf = f"{metric}_{category_name}_{dataset}_{model}.pdf"
                filename_png = f"{metric}_{category_name}_{dataset}_{model}.png"
                title = (
                    f"{format_metric_name(gap_metric)} for {category_name.title()}"
                    + f" - {format_metric_name(dataset)}"
                )

                fig = px.scatter(
                    filtered_table,
                    x="group",
                    facet_col="group",
                    y=gap_metric,
                    color="model",
                    error_y=gap_metric + "_width",
                    labels={
                        "group": "Protected Group",
                        gap_metric: format_metric_name(gap_metric),
                        "model": "LM",
                    },
                    title=title,
                    height=600,
                )

                # One facet per group: give each facet its own x-axis and
                # hide the redundant per-facet axis titles and annotations.
                fig.update_xaxes(matches=None, tickangle=90)
                fig.for_each_xaxis(lambda x: x.update(title=""))
                fig.for_each_annotation(lambda a: a.update(text=""))

                fig = fig.update_layout(
                    scattermode="group",
                    font_color="black",
                    font_size=18,
                    title_x=0.5,
                )
                fig.write_image(os.path.join(group_output_path, filename_pdf))
                fig.write_image(os.path.join(group_output_path, filename_png))

print(f"Find plots in {output_folder}/")

### Compare between models on the same protected group 
E.g. "young" people of the "age" category.


In [None]:
# For every protected group, compare the gap across prompt-tuning datasets
# (x axis), per model and for an "all" pseudo-model overlaying every model.
for model in tqdm([*(stats_table["model"].unique()), "all"]):
    for category in tqdm(categories, leave=False):
        row_mask = (stats_table["category"] == category) & (stats_table["dataset"] != "")
        if model != "all":
            row_mask &= stats_table["model"] == model
        filtered_table = stats_table[row_mask]

        # Remove only the "grouped_" prefix from grouped_{category_name};
        # splitting on "_" would truncate multi-word category names.
        category_name = category[len("grouped_"):] if category.startswith("grouped_") else category

        groups = filtered_table["group"].unique()

        for group in tqdm(groups, leave=False):
            group_output_path = os.path.join(output_folder, model, "by_group", category, group)
            os.makedirs(group_output_path, exist_ok=True)

            group_filtered_table = filtered_table[filtered_table["group"] == group]
            if len(group_filtered_table) == 0:
                continue

            # Only (metric, gap_metric) are used in the loop body.
            # metrics_hypotheses is never populated, so zipping it in as
            # well made this loop run zero times and no plot was written.
            for metric, gap_metric in zip(metrics, gap_metrics):
                filename_pdf = f"{metric}_{category_name}_{group}_{model}.pdf"
                filename_png = f"{metric}_{category_name}_{group}_{model}.png"
                # This plot spans all datasets, so the title names the
                # category instead of a single (previously stale) dataset.
                title = f"{format_metric_name(gap_metric)} for {group.title()} - {category_name.title()}"

                fig = px.scatter(
                    group_filtered_table,
                    x="dataset",
                    facet_col="dataset",
                    y=gap_metric,
                    color="model",
                    error_y=gap_metric + "_width",
                    # `labels` maps column names to display names; the x
                    # column here is "dataset", not "group".
                    labels={
                        "dataset": "Prompt-Tuning Dataset",
                        gap_metric: format_metric_name(gap_metric),
                        "model": "LM",
                    },
                    title=title,
                    height=600,
                )

                # One facet per dataset: give each facet its own x-axis and
                # hide the redundant per-facet axis titles and annotations.
                fig.update_xaxes(matches=None, tickangle=90)
                fig.for_each_xaxis(lambda x: x.update(title=""))
                fig.for_each_annotation(lambda a: a.update(text=""))

                fig = fig.update_layout(
                    scattermode="group",
                    font_color="black",
                    font_size=18,
                    title_x=0.5,
                )
                fig.write_image(os.path.join(group_output_path, filename_pdf))
                fig.write_image(os.path.join(group_output_path, filename_png))

print(f"Find plots in {output_folder}/")