# Bias Visualization

### Dependencies
Note that we use the scikit-learn confusion matrix utility to help calculate some of the fairness metrics. We generate the visualizations using plotly Express.

In [21]:
%%capture
%pip install kaleido numpy pandas plotly scikit-learn tqdm ipywidgets

In [22]:
import os
from math import sqrt
from typing import Any, Dict, List

import ipywidgets as widgets
import numpy as np
import pandas as pd
import plotly.express as px
from IPython.display import clear_output, display
from sklearn.metrics import confusion_matrix as ConfusionMatrix
from tqdm.auto import tqdm

### Global Constants

This notebook will visualize how each of the following variable affects the fairness towards a particular protected group.

In [23]:
INDEPENDENT_VARIABLES = ["model", "dataset", "num_params"]

## Utilities

### Plotting Utils

In [24]:
def format_metric_name(metric_name: str) -> str:
    """
    Prettify diagram texts by replacing
    snake case with regular title case.
    """
    output_words = []
    for word in metric_name.split("_"):
        word = word.title() if word not in ["FPR", "TPR"] else word.upper()
        output_words.append(word)

    return " ".join(output_words)


# Modify the names of the models and the datasets for display purposes
def modify_model_name(model_name: str) -> str:
    if model_name.startswith("opt"):
        return model_name.replace("opt", "OPT").upper()
    elif model_name.startswith("roberta"):
        return model_name.replace("roberta", "RoBERTa")
    elif "llama" in model_name:
        return model_name.replace("llama", "LLaMA")
    else:
        return model_name


def modify_dataset_name(dataset_name: str) -> str:
    return dataset_name

## Getting Predictions Ready

### Load Prediction Files
Load all prediction `tsv` files from a given folder into a dataframe.

In [25]:
prediction_folder: str = "resources/predictions/"
prediction_files: List[str] = os.listdir(prediction_folder)
prediction_tsv_paths: List[str] = [
    os.path.join(prediction_folder, prediction_file)
    for prediction_file in prediction_files
    if prediction_file.endswith(".tsv") and "placeholder" not in prediction_file
]

output_folder: str = "stats"
output_table = pd.DataFrame()

for prediction_tsv_path in tqdm(prediction_tsv_paths, ncols=80):
    prediction_tsv_filename = os.path.basename(prediction_tsv_path)
    dataframe = pd.read_csv(prediction_tsv_path, delimiter="\t")
    del dataframe["text"]

    if output_table is None:
        output_table = dataframe
    else:
        assert isinstance(output_table, pd.DataFrame)
        output_table = pd.concat([output_table, dataframe])

  0%|                                                     | 0/3 [00:00<?, ?it/s]

## Calculate Statistics
Since there can be multiple examples for each protected group, we apply the `groupby` method to compute aggregated statistics across all the examples of each protected group.

Refer to the following function for details on how we implemented the metrics.

### Function for aggregating statistics


Below you will find a function that takes in a dataframe. This dataframe is will be automatically generated using the `pd.DataFrame.groupby`. The content of this dataframe is a subset of one particular prediction tsv for a particular protected group (e.g., the "young" group for the protected class "age"). 

The function will return a `pd.Series` (like a dictionary) mapping the following fairness metrics to their floating point values:
- Accuracy
- TPR(Positive)
- TPR(Neutral)
- TPR(Negative)
- FPR(Positive)
- FPR(Neutral)
- FPR(Negative)

Also refer to https://stackoverflow.com/a/50671617 for details on how the TPR and FPR values are computed.

In [26]:
def get_stats(dataframe: pd.DataFrame) -> pd.Series:
    """
    Input: dataframe representing all predictions for a particular
    protected group in a particular run.
    """
    output: Dict[str, float] = {}

    num_examples = len(dataframe)
    num_correct = np.sum(dataframe["y_pred"] == dataframe["y_true"])

    output["accuracy"] = num_correct / num_examples

    # See https://stackoverflow.com/a/50671617
    confusion_matrix = ConfusionMatrix(dataframe["y_true"], dataframe["y_pred"])

    FP = confusion_matrix.sum(axis=0) - np.diag(confusion_matrix)
    FN = confusion_matrix.sum(axis=1) - np.diag(confusion_matrix)
    TP = np.diag(confusion_matrix)
    TN = confusion_matrix.sum() - (FP + FN + TP)

    FPR = FP / (FP + TN)
    FNR = FN / (TP + FN)
    TPR = 1 - FNR

    for label_index, label in enumerate(["negative", "neutral", "positive"]):
        output[label + "_FPR"] = FPR[label_index]
        output[label + "_TPR"] = TPR[label_index]

    return pd.Series(output)

### Calculating Per-Group Statistics

In [27]:
# output_table
analyzed_dataframe = output_table.groupby([*INDEPENDENT_VARIABLES, "run_id", "category", "group"]).apply(get_stats)
analyzed_dataframe.query('category == "grouped_race"')[["positive_TPR"]]

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,positive_TPR
model,dataset,num_params,run_id,category,group,Unnamed: 6_level_1
opt-175b,SST5,175.0,run_1,grouped_race,african_american,1.0
opt-175b,SST5,175.0,run_1,grouped_race,american_indian,1.0
opt-175b,SST5,175.0,run_1,grouped_race,asian,1.0
opt-175b,SST5,175.0,run_1,grouped_race,hispanic,1.0
opt-175b,SST5,175.0,run_1,grouped_race,pacific_islander,1.0
opt-175b,SST5,175.0,run_1,grouped_race,white,0.75
opt-6.7b,SST5,6.7,run_1,grouped_race,african_american,1.0
opt-6.7b,SST5,6.7,run_1,grouped_race,american_indian,1.0
opt-6.7b,SST5,6.7,run_1,grouped_race,asian,1.0
opt-6.7b,SST5,6.7,run_1,grouped_race,hispanic,1.0


In [28]:
output_table = analyzed_dataframe.reset_index()
os.makedirs(output_folder, exist_ok=True)
output_table.to_csv(os.path.join(output_folder, "aggregated.csv"))

## Calculating Statistics

Before we begin, note that some of the protected classes include a large number of groups. You may filter on the protected groups to include in each category. If a category isn't in this dictionary, the visualizations will include all the groups under that category by default.

In [29]:
SELECTED_GROUPS = {
    "grouped_religion": ["christianity", "islam", "judaism"],
    "grouped_disability": [
        "without",
        "cognitive",
        "physical",
        "hearing",
        "sight",
        "chronic_illness",
        "mobility",
        "mental_health",
    ],
}

groups_to_include = []
for category in output_table["category"].unique():
    groups = SELECTED_GROUPS.get(category)
    if groups is None:
        groups = list(set(output_table.query(f'category == "{category}"')["group"]))

    groups_to_include.extend(groups)

print("Groups selected:", ", ".join(groups_to_include))
output_table = output_table.query("group == @groups_to_include")

Groups selected: adult, old, young, 5q, 2q, 4q, 3q, 1q, 6q, without, cognitive, physical, hearing, sight, chronic_illness, mobility, mental_health, aab, cis, many-genders, no-gender, trans, female, male, non-binary, asian, american_indian, hispanic, white, pacific_islander, african_american, christianity, islam, judaism, asexual, heterosexual, other, bisexual, homosexual, female, male


### "Gap": How each group deviates from the category mean
Below, we will calculate how far each group deviates from the mean of that category. For example, how far does the accuracy on examples mentioning the "young" group deviate from the mean of all examples of the "age" category?

If there are more than one runs for each model, we would calculate gap metrics for each run separately. 

We use a for loop to add the following metrics of interest to our list:
- Accuracy
- False Positive Rates
    - FPR("Negative")
    - FPR("Neutral")
    - FPR("Positive")
- True Positive Rates
    - TPR("Negative")
    - TPR("Neutral")
    - TPR("Positive")

In [30]:
metrics = ["accuracy"]
for label in ["negative", "neutral", "positive"]:
    metrics.append(label + "_FPR")
    metrics.append(label + "_TPR")

gap_metrics = []
for metric in metrics:
    gap_metrics.append(f"{metric}_gap")

print(gap_metrics)

['accuracy_gap', 'negative_FPR_gap', 'negative_TPR_gap', 'neutral_FPR_gap', 'neutral_TPR_gap', 'positive_FPR_gap', 'positive_TPR_gap']


### Confidence Interval Calculations

In the previous step, we've found the gaps for each (metric, category, group, model, run/seed). If we have more than one runs/seeds for each model, we can aggregate the results to find a confidence interval for each gap variable.

To do so, we will aggregate the previous table (metric, category, group, model, run/seed) along the run/seed axis. The result will be a table with indices (metric, category, group, model).

In [31]:
# Create a (metric, group) placeholder table for storing the output values.
# Unlike in the previous step, "run_id" isn't part of groupby,
# and this axis will be squeezed in the result.
stats_table = (
    output_table[[*INDEPENDENT_VARIABLES, "category", "group", "accuracy"]]
    .groupby([*INDEPENDENT_VARIABLES, "category", "group"])
    .agg(["mean"])
)

In [32]:
# Calculate gap for each metric and
# save the results in a column named f"{metrics}_gap".
for metric, gap_metric in zip(metrics, gap_metrics):
    # Calculate mean for each metric, group, and run/seed.
    group_by = [*INDEPENDENT_VARIABLES, "category", "run_id"]
    grouped = output_table[[*INDEPENDENT_VARIABLES, "category", "run_id", metric]].groupby(group_by)
    group_mean = grouped.transform("mean")

    # Broadcast the mean values across the table
    # to find the gap for each entry.
    normalized_values = output_table[metric] - group_mean[metric]
    output_table[gap_metric] = normalized_values

    # See https://stackoverflow.com/a/53522680
    # Calculate mean and stdev for each (metric, category, group, model).
    stats = (
        output_table[[*INDEPENDENT_VARIABLES, "category", "group", gap_metric]]
        .groupby([*INDEPENDENT_VARIABLES, "category", "group"])
        .agg(["mean", "count", "std"])
    )

    mean_values = []
    ci_width = []
    lower_values = []
    upper_values = []
    hypothesis_selected = []
    z_score = 1.96

    # Calculate 95% Confidence interval for each
    # (metric, category, group, model).
    for index in stats.index:
        mean, n, stdev = stats.loc[index]  # type: ignore
        lower = mean - z_score * stdev / sqrt(n)
        upper = mean + z_score * stdev / sqrt(n)

        mean_values.append(mean)
        ci_width.append(z_score * stdev / sqrt(n))
        lower_values.append(lower)
        upper_values.append(upper)

        if lower > 0:
            hypothesis_selected.append(1)
        elif upper < 0:
            hypothesis_selected.append(-1)
        else:
            hypothesis_selected.append(0)

    stats_table[gap_metric] = mean_values
    stats_table[gap_metric + "_width"] = ci_width
    stats_table[gap_metric + "_lower"] = lower_values
    stats_table[gap_metric + "_upper"] = upper_values

stats_table = stats_table.reset_index()

In [33]:
stats_table["model"] = stats_table["model"].apply(modify_model_name)
stats_table["dataset"] = stats_table["dataset"].apply(modify_dataset_name)

## Plotting

### Compare between Groups of each category 
E.g. all groups in the "age" category

In [34]:
datasets = ["(all)", *stats_table["dataset"].unique()]
models = ["(all)", *(stats_table["model"].unique())]
categories = stats_table["category"].unique()

##### Interactive Visualization Utilities

In [35]:
model_selector = widgets.Dropdown(options=models, value=models[0], description="Model:")
category_selector = widgets.Dropdown(options=categories, value=categories[0], description="Category:")
metric_selector = widgets.Dropdown(options=metrics, value=metrics[0], description="Metric:")
dataset_selector = widgets.Dropdown(options=datasets, value=datasets[0], description="Dataset:")

In [36]:
def visualize_fn(button: Any) -> None:
    """
    Function to be called when the button is pressed.
    This function clears previous output and any widget
    that has already been rendered. To add these widgets
    back (including the button) the "button" parameter
    needs to be supplied from the button action.
    """
    clear_output()

    display(model_selector)
    display(category_selector)
    display(metric_selector)
    display(dataset_selector)
    display(button)

    model = model_selector.value
    category: str = category_selector.value  # type: ignore
    metric: str = metric_selector.value  # type: ignore
    dataset: str = dataset_selector.value  # type: ignore

    metric += "_gap"

    if model == "(all)" and dataset == "(all)":
        filtered_table = stats_table[(stats_table["category"] == category)]
    elif model == "(all)":
        filtered_table = stats_table[(stats_table["category"] == category) & (stats_table["dataset"] == dataset)]
    elif dataset == "(all)":
        filtered_table = stats_table[(stats_table["category"] == category) & (stats_table["model"] == model)]
    else:
        filtered_table = stats_table[
            (stats_table["category"] == category)
            & (stats_table["model"] == model)
            & (stats_table["dataset"] == dataset)
        ]

    if len(filtered_table) == 0:
        print("None of the runs in the table matched the selections.")
        return

    # Removing the "grouped" from grouped_{category_name}
    category_name = category.split("_")[-1]
    title = f"Gap for {category_name.title()}<br>{dataset}"

    # Generate scatterplot with confidence interval.
    fig = px.scatter(
        filtered_table,
        x="group",
        facet_col="group",
        y=metric,
        color="model",
        error_y=metric + "_width",
        labels={
            "group": "Protected Group",
            metric: format_metric_name(metric),
            "model": "LM",
        },
        title=title,
        height=700,
        width=1200,
    )

    # Hide redundant x axis labels.
    fig.update_xaxes(
        matches=None, tickangle=-35, title_font_family="Helvetica", tickfont_family="Helvetica", tickfont_size=28
    )
    fig.update_yaxes(title_font_family="Helvetica", tickfont_family="Helvetica", title_font_size=36)
    fig.for_each_xaxis(lambda x: x.update(title=""))
    fig.for_each_annotation(lambda a: a.update(text=""))
    fig.update_traces(
        marker=dict(
            size=14,
        ),
        error_y=dict(thickness=5),
    )

    fig = fig.update_layout(legend=dict(title_font_family="Helvetica", font=dict(size=28)))

    fig = fig.update_layout(
        scattermode="group",
        font_color="black",
        title_font_family="Helvetica",
        font_size=24,
        title_x=0.47,
    )

    plot_output_path = "plot.png"
    fig.write_image(plot_output_path, scale=3)
    with open(plot_output_path, "rb") as img_file:
        image_widget = widgets.Image(value=img_file.read(), format="png", width=700, height=1200)
        display(image_widget)

##### Interative Visualization

In [37]:
button = widgets.Button(description="Visualize")
button.on_click(visualize_fn)
visualize_fn(button)

Dropdown(description='Model:', options=('(all)', 'OPT-175B', 'OPT-6.7B', 'RoBERTa-base'), value='(all)')

Dropdown(description='Category:', options=('grouped_age', 'grouped_country_by_gdp_ppp_quantile', 'grouped_disa…

Dropdown(description='Metric:', options=('accuracy', 'negative_FPR', 'negative_TPR', 'neutral_FPR', 'neutral_T…

Dropdown(description='Dataset:', options=('(all)', 'SST5', 'TweetEval'), value='(all)')

Button(description='Visualize', style=ButtonStyle())

Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x0e\x10\x00\x00\x084\x08\x06\x00\x00\x00\xe2W\xdd \x…

### Compare between models and datasets on the same protected group 
E.g. "young" people of the "age" category.


##### Interative Visualization Utilities

In [38]:
unique_groups = stats_table["group"].unique()
model_selector = widgets.Dropdown(options=models, value=models[0], description="Model:")
group_selector = widgets.Dropdown(options=unique_groups, value=unique_groups[0], description="Group:")
metric_selector = widgets.Dropdown(options=metrics, value=metrics[0], description="Metric:")

In [39]:
def visualize_fn_2(button: Any) -> None:
    clear_output()
    display(model_selector)
    display(group_selector)
    display(metric_selector)
    display(button)

    model = model_selector.value
    group: str = group_selector.value  # type: ignore
    metric: str = metric_selector.value  # type: ignore

    gap_metric = metric + "_gap"

    if model == "(all)":
        filtered_table = stats_table[(stats_table["dataset"] != "")]
    else:
        filtered_table = stats_table[(stats_table["model"] == model) & (stats_table["dataset"] != "")]

    group_filtered_table = filtered_table[filtered_table["group"] == group]

    if len(group_filtered_table) == 0:
        print("None of the runs in the table matched the selections.")
        return

    title = f"Gap for {group.title()} "

    fig = px.scatter(
        group_filtered_table,
        x="dataset",
        facet_col="dataset",
        y=gap_metric,
        color="model",
        error_y=gap_metric + "_width",
        labels={
            "group": "Protected Group",
            gap_metric: format_metric_name(gap_metric),
            "model": "LM",
            **{dataset: format_metric_name(dataset) for dataset in filtered_table["dataset"].unique()},
        },
        title=title,
        height=700,
        width=1200,
    )

    # Hide redundant x axis labels.
    fig.update_xaxes(
        matches=None, tickangle=-35, title_font_family="Helvetica", tickfont_family="Helvetica", tickfont_size=28
    )
    fig.update_yaxes(title_font_family="Helvetica", tickfont_family="Helvetica", title_font_size=36)
    fig.for_each_xaxis(lambda x: x.update(title=""))
    fig.for_each_annotation(lambda a: a.update(text=""))
    fig.update_traces(
        marker=dict(
            size=14,
        ),
        error_y=dict(thickness=5),
    )

    fig = fig.update_layout(legend=dict(title_font_family="Helvetica", font=dict(size=28)))

    fig = fig.update_layout(
        scattermode="group",
        font_color="black",
        title_font_family="Helvetica",
        font_size=24,
        title_x=0.47,
    )

    plot_output_path = "plot.png"
    fig.write_image(plot_output_path, scale=5)
    with open(plot_output_path, "rb") as img_file:
        image_widget = widgets.Image(value=img_file.read(), format="png", width=700, height=1200)
        display(image_widget)

##### Interative Visualization

In [40]:
button = widgets.Button(description="Visualize")
button.on_click(visualize_fn_2)
visualize_fn_2(button)

Dropdown(description='Model:', options=('(all)', 'OPT-175B', 'OPT-6.7B', 'RoBERTa-base'), value='(all)')

Dropdown(description='Group:', options=('adult', 'old', 'young', '1q', '2q', '3q', '4q', '5q', '6q', 'chronic_…

Dropdown(description='Metric:', options=('accuracy', 'negative_FPR', 'negative_TPR', 'neutral_FPR', 'neutral_T…

Button(description='Visualize', style=ButtonStyle())

Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x17p\x00\x00\r\xac\x08\x06\x00\x00\x00E\x06\x98\x97\…