# Bias Visualization

### Dependencies
Note that we use the scikit-learn confusion matrix utility to help calculate some of the fairness metrics. We generate the visualizations using plotly Express.

In [156]:
%%capture
%pip install kaleido numpy pandas plotly scikit-learn tqdm ipywidgets

In [157]:
import os
from math import sqrt
from typing import Any, Dict

import ipywidgets as widgets
import numpy as np
import pandas as pd
import plotly.express as px
from IPython.display import clear_output, display
from sklearn.metrics import confusion_matrix as ConfusionMatrix
from tqdm.auto import tqdm
from os import listdir


### Global Constants

This notebook will visualize how each of the following variable affects the fairness towards a particular protected group.

In [158]:
INDEPENDENT_VARIABLES = ["model", "dataset", "num_params"]

## Utilities

### Plotting Utils

In [159]:
def format_metric_name(metric_name: str) -> str:
    """
    Prettify diagram texts by replacing
    snake case with regular title case.
    """
    output_words = []
    for word in metric_name.split("_"):
        word = word.title() if word not in ["FPR", "TPR"] else word.upper()
        output_words.append(word)

    return " ".join(output_words)

## Getting Predictions Ready

### Load Prediction File
Load the prediction `tsv` file from the provided paths

In [160]:
def find_csv_filenames(path_to_dir, template_name, suffix=".tsv"):
    filenames = listdir(path_to_dir)
    return [path_to_dir + filename for filename in filenames if filename.endswith(suffix) and template_name in filename]

# template_name = "amazon"
# template_name = "NS-prompts"
template_name = "regard"

results_path_to_dir = "resources/final_all/"
prediction_tsv_paths = find_csv_filenames(results_path_to_dir, template_name)


plot_name = template_name
# model_names = ["opt-125m", "opt-350m", "opt-1d3b", "opt-6d7b", "roberta-base", "roberta-large", "llama-7b"]
# case_names = [template_name + "_" + model_name + "_SST5" for model_name in model_names]
# prediction_tsv_paths = [
#     "resources/final_all/" + case_name + ".tsv" for case_name in case_names
# ]


# prediction_tsv_paths = [
#     "/Users/david/Documents/SoftPromptTuningFairness/TACLRevisionsData/July_12_2023_SST5_Experiments/"
#     "predictions-eval-sweep-20230709a2.tsv"
# ]

output_folder: str = "stats"
output_table = pd.DataFrame()

for prediction_tsv_path in tqdm(prediction_tsv_paths, ncols=80):
    prediction_tsv_filename = os.path.basename(prediction_tsv_path)
    dataframe = pd.read_csv(prediction_tsv_path, delimiter="\t")
    del dataframe["text"]

    if output_table is None:
        output_table = dataframe
    else:
        assert isinstance(output_table, pd.DataFrame)
        output_table = pd.concat([output_table, dataframe])
    print(prediction_tsv_path)
    print(output_table.size)

  0%|                                                    | 0/15 [00:00<?, ?it/s]

resources/final_all/predictions_regard_2000.tsv
12221
resources/final_all/predictions_regard_5023.tsv
24442
resources/final_all/predictions_regard_4.tsv
36663
resources/final_all/predictions_regard_3679.tsv
48884
resources/final_all/regard_roberta-base_SST5.tsv
109989
resources/final_all/regard_roberta-large_SST5.tsv
171094
resources/final_all/predictions_regard.tsv
232199
resources/final_all/predictions_regard_16087.tsv
244420
resources/final_all/predictions_regard_862.tsv
256641
resources/final_all/predictions_regard_902.tsv
268862
resources/final_all/predictions_regard_455.tsv
281083
resources/final_all/regard_opt-350m_SST5.tsv
342188
resources/final_all/predictions_regard_89.tsv
354409
resources/final_all/regard_opt-125m_SST5.tsv
415514
resources/final_all/predictions_regard_8.tsv
427735


## Calculate Statistics
Since there can be multiple examples for each protected group, we apply the `groupby` method to compute aggregated statistics across all the examples of each protected group.

Refer to the following function for details on how we implemented the metrics.

### Function for aggregating statistics


Below you will find a function that takes in a dataframe. This dataframe is will be automatically generated using the `pd.DataFrame.groupby`. The content of this dataframe is a subset of one particular prediction tsv for a particular protected group (e.g., the "young" group for the protected class "age"). 

The function will return a `pd.Series` (like a dictionary) mapping the following fairness metrics to their floating point values:
- Accuracy
- TPR(Positive)
- TPR(Neutral)
- TPR(Negative)
- FPR(Positive)
- FPR(Neutral)
- FPR(Negative)

Also refer to https://stackoverflow.com/a/50671617 for details on how the TPR and FPR values are computed.

In [161]:
top_n_filter = 5
unique_runs_with_accuracy = output_table[["model", "run_id", "dataset"]].drop_duplicates()
unique_runs_with_accuracy = unique_runs_with_accuracy.sort_values(
    ["model", "dataset"], ascending=False
)
best_runs_by_model_dataset = unique_runs_with_accuracy.groupby(["model", "dataset"]).head(top_n_filter).reset_index()
run_ids_to_keep = best_runs_by_model_dataset["run_id"].tolist()
output_table = output_table.loc[output_table["run_id"].isin(run_ids_to_keep)]

In [162]:
print(output_table.shape)
print(output_table["model"].unique())

(38885, 11)
['/ssd005/projects/llm/fair-llm/llama/llama-2-7b' 'facebook/opt-1.3b'
 'roberta-base fine-tuned' 'roberta-large fine-tuned' 'facebook/opt-6.7b'
 'opt-350m fine-tuned' 'opt-125m fine-tuned']


In [163]:
# Modify the names of the models and the datasets for display purposes
def modify_model_name(model_name: str) -> str:
    if model_name.startswith("facebook/"):
        return model_name.replace("facebook/", "").upper()
    if model_name.startswith("/ssd005/projects/llm/fair-llm/llama/"):
        return model_name.replace("/ssd005/projects/llm/fair-llm/llama/", "").replace("llama", "LLaMA")
    if "fine-tuned" in model_name:
        if model_name.startswith("opt"):
            return model_name.replace("fine-tuned", "").replace("opt", "OPT")
        if model_name.startswith("roberta"):
            return model_name.replace("fine-tuned", "").replace("roberta", "RoBERTa")   
    else:
        llama_name = model_name.replace("/data/models/", "")
        return llama_name.replace("llama/", "LLaMA-")


def modify_dataset_name(dataset_name: str) -> str:
    # dataset_name = dataset_name.replace("jacobthebanana/", "")
    # if "_mapped_" in dataset_name:
    #     return dataset_name.replace("sst5_mapped_grouped", "SST5 Grouped")
    # elif "-mapped-" in dataset_name:
    #     return dataset_name.replace("sst5-mapped-extreme", "SST5 Extreme")
    # elif "data/processed/semeval_3" in dataset_name:
    #     return "SemEval"
    # return ""
    return "SST5 Extreme"

In [164]:
def get_stats(dataframe: pd.DataFrame) -> pd.Series:
    """
    Input: dataframe representing all predictions for a particular
    protected group in a particular run.
    """
    output: Dict[str, float] = {}

    num_examples = len(dataframe)
    num_correct = np.sum(dataframe["y_pred"] == dataframe["y_true"])

    output["accuracy"] = num_correct / num_examples

    # See https://stackoverflow.com/a/50671617
    confusion_matrix = ConfusionMatrix(dataframe["y_true"], dataframe["y_pred"], labels=[0,1,2])

    FP = confusion_matrix.sum(axis=0) - np.diag(confusion_matrix)
    FN = confusion_matrix.sum(axis=1) - np.diag(confusion_matrix)
    TP = np.diag(confusion_matrix)
    TN = confusion_matrix.sum() - (FP + FN + TP)

    eps = 10 ** -10

    FP = np.copy(FP) + eps
    FN = np.copy(FN) + eps
    TP = np.copy(TP) + eps
    TN = np.copy(TN) + eps
    
    FPR = FP / (FP + TN)
    FNR = FN / (TP + FN)
    TPR = 1 - FNR

    for label_index, label in enumerate(["negative", "neutral", "positive"]):
        output[label + "_FPR"] = FPR[label_index]
        output[label + "_TPR"] = TPR[label_index]

    return pd.Series(output)

### Calculating Per-Group Statistics

In [165]:
# output_table
analyzed_dataframe = output_table.groupby([*INDEPENDENT_VARIABLES, "run_id", "category", "group"]).apply(get_stats)
analyzed_dataframe.query('category == "grouped_race"')[["positive_TPR"]]

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,positive_TPR
model,dataset,num_params,run_id,category,group,Unnamed: 6_level_1
/ssd005/projects/llm/fair-llm/llama/llama-2-7b,jacobthebanana/sst5-mapped-extreme,-8.00,2sjwxsn3,grouped_race,african_american,0.057143
/ssd005/projects/llm/fair-llm/llama/llama-2-7b,jacobthebanana/sst5-mapped-extreme,-8.00,2sjwxsn3,grouped_race,american_indian,0.057143
/ssd005/projects/llm/fair-llm/llama/llama-2-7b,jacobthebanana/sst5-mapped-extreme,-8.00,2sjwxsn3,grouped_race,asian,0.057143
/ssd005/projects/llm/fair-llm/llama/llama-2-7b,jacobthebanana/sst5-mapped-extreme,-8.00,2sjwxsn3,grouped_race,hispanic,0.057143
/ssd005/projects/llm/fair-llm/llama/llama-2-7b,jacobthebanana/sst5-mapped-extreme,-8.00,2sjwxsn3,grouped_race,pacific_islander,0.057143
...,...,...,...,...,...,...
roberta-large fine-tuned,SST5,0.35,r4,grouped_race,american_indian,0.142857
roberta-large fine-tuned,SST5,0.35,r4,grouped_race,asian,0.085714
roberta-large fine-tuned,SST5,0.35,r4,grouped_race,hispanic,0.142857
roberta-large fine-tuned,SST5,0.35,r4,grouped_race,pacific_islander,0.171429


In [166]:
output_table = analyzed_dataframe.reset_index()
os.makedirs(output_folder, exist_ok=True)
output_table.to_csv(os.path.join(output_folder, "aggregated.csv"))

## Calculating Statistics

Before we begin, note that some of the protected classes include a large number of groups. You may filter on the protected groups to include in each category. If a category isn't in this dictionary, the visualizations will include all the groups under that category by default.

In [167]:
SELECTED_GROUPS = {
    "grouped_religion": ["christianity", "islam", "judaism"],
    "grouped_disability": [
        "without",
        "cognitive",
        "physical",
        "hearing",
        "sight",
        "chronic_illness",
        "mobility",
        "mental_health",
    ],
}

groups_to_include = []
for category in output_table["category"].unique():
    groups = SELECTED_GROUPS.get(category)
    if groups is None:
        groups = list(set(output_table.query(f'category == "{category}"')["group"]))

    groups_to_include.extend(groups)

print("Groups selected:", ", ".join(groups_to_include))
output_table = output_table.query("group == @groups_to_include")

Groups selected: pacific_islander, african_american, hispanic, american_indian, white, asian


### "Gap": How each group deviates from the category mean
Below, we will calculate how far each group deviates from the mean of that category. For example, how far does the accuracy on examples mentioning the "young" group deviate from the mean of all examples of the "age" category?

If there are more than one runs for each model, we would calculate gap metrics for each run separately. 

We use a for loop to add the following metrics of interest to our list:
- Accuracy
- False Positive Rates
    - FPR("Negative")
    - FPR("Neutral")
    - FPR("Positive")
- True Positive Rates
    - TPR("Negative")
    - TPR("Neutral")
    - TPR("Positive")

In [168]:
metrics = ["accuracy"]
for label in ["negative", "neutral", "positive"]:
    metrics.append(label + "_FPR")
    metrics.append(label + "_TPR")

gap_metrics = []
for metric in metrics:
    gap_metrics.append(f"{metric}_gap")

print(gap_metrics)

['accuracy_gap', 'negative_FPR_gap', 'negative_TPR_gap', 'neutral_FPR_gap', 'neutral_TPR_gap', 'positive_FPR_gap', 'positive_TPR_gap']


### Confidence Interval Calculations

In the previous step, we've found the gaps for each (metric, category, group, model, run/seed). If we have more than one runs/seeds for each model, we can aggregate the results to find a confidence interval for each gap variable.

To do so, we will aggregate the previous table (metric, category, group, model, run/seed) along the run/seed axis. The result will be a table with indices (metric, category, group, model).

In [169]:
# Create a (metric, group) placeholder table for storing the output values.
# Unlike in the previous step, "run_id" isn't part of groupby,
# and this axis will be squeezed in the result.
stats_table = (
    output_table[[*INDEPENDENT_VARIABLES, "category", "group", "accuracy"]]
    .groupby([*INDEPENDENT_VARIABLES, "category", "group"])
    .agg(["mean"])
)

In [170]:
for metric, gap_metric in zip(metrics, gap_metrics):
    group_by = [*INDEPENDENT_VARIABLES, "category", "run_id"]
    grouped = output_table[[*INDEPENDENT_VARIABLES, "category", "run_id", metric]].groupby(group_by)
    group_mean = grouped.transform("mean")
    normalized_values = output_table[metric] - group_mean[metric]
    output_table[gap_metric] = normalized_values

    # See https://stackoverflow.com/a/53522680
    # Calculate mean and stdev for each (metric, category, group, model).
    stats = (
        output_table[[*INDEPENDENT_VARIABLES, "category", "group", gap_metric]]
        .groupby([*INDEPENDENT_VARIABLES, "category", "group"])
        .agg(["mean", "count", "std"])
    )

    mean_values = []
    ci_width = []
    lower_values = []
    upper_values = []
    hypothesis_selected = []
    z_score = 1.96

    # Calculate 95% Confidence interval for each
    # (metric, category, group, model).
    for index in stats.index:
        mean, n, stdev = stats.loc[index]  # type: ignore
        lower = mean - z_score * stdev / sqrt(n)
        upper = mean + z_score * stdev / sqrt(n)

        mean_values.append(mean)
        ci_width.append(z_score * stdev / sqrt(n))
        lower_values.append(lower)
        upper_values.append(upper)

        if lower > 0:
            hypothesis_selected.append(1)
        elif upper < 0:
            hypothesis_selected.append(-1)
        else:
            hypothesis_selected.append(0)

    stats_table[gap_metric] = mean_values
    stats_table[gap_metric + "_width"] = ci_width
    stats_table[gap_metric + "_lower"] = lower_values
    stats_table[gap_metric + "_upper"] = upper_values

stats_table = stats_table.reset_index()

In [171]:
stats_table["model"] = stats_table["model"].apply(modify_model_name)
stats_table["dataset"] = stats_table["dataset"].apply(modify_dataset_name)

## Plotting

### Compare between Groups of each category 
E.g. all groups in the "age" category

In [172]:
datasets = stats_table["dataset"].unique()
models = ["(all)", *(stats_table["model"].unique())]
categories = stats_table["category"].unique()

##### Interactive Visualization Utilities

In [173]:
model_selector = widgets.Dropdown(options=models, value=models[0], description="Model:")
category_selector = widgets.Dropdown(options=categories, value=categories[0], description="Category:")
metric_selector = widgets.Dropdown(options=metrics, value=metrics[0], description="Metric:")
dataset_selector = widgets.Dropdown(options=datasets, value=datasets[0], description="Dataset:")

In [174]:
def visualize_fn(button: Any) -> None:
    """
    Function to be called when the button is pressed.
    This function clears previous output and any widget
    that has already been rendered. To add these widgets
    back (including the button) the "button" parameter
    needs to be supplied from the button action.
    """
    clear_output()

    display(model_selector)
    display(category_selector)
    display(metric_selector)
    display(dataset_selector)
    display(button)

    model = model_selector.value
    category: str = category_selector.value  # type: ignore
    metric: str = metric_selector.value  # type: ignore
    dataset: str = dataset_selector.value  # type: ignore

    metric += "_gap"

    if model == "(all)":
        filtered_table = stats_table[(stats_table["category"] == category) & (stats_table["dataset"] == dataset)]
    else:
        filtered_table = stats_table[
            (stats_table["category"] == category)
            & (stats_table["model"] == model)
            & (stats_table["dataset"] == dataset)
        ]

    if len(filtered_table) == 0:
        print("None of the runs in the table matched the selections.")
        return


    # Removing the "grouped" from grouped_{category_name}
    category_name = category.split("_")[-1]
    title = f"Gap for {category_name.title()}<br>{dataset}"

    # Generate scatterplot with confidence interval.
    fig = px.scatter(
        filtered_table,
        x="group",
        facet_col="group",
        y=metric,
        color="model",
        error_y=metric + "_width",
        labels={
            "group": "Protected Group",
            metric: format_metric_name(metric),
            "model": "LM",
        },
        # category_orders={
        #     "model": ["OPT-125M", "OPT-350M", "OPT-1.3B", "OPT-2.7B", "OPT-6.7B", "OPT-13B", "LLaMA-7B", "LLaMA-13B"]
        # },
        title=title,
        height=900,
        width=1200,
    )

    # Hide redundant x axis labels.
    fig.update_xaxes(
        matches=None, tickangle=-35, title_font_family="Helvetica", tickfont_family="Helvetica", tickfont_size=28
    )
    fig.update_yaxes(title_font_family="Helvetica", tickfont_family="Helvetica", title_font_size=36)
    fig.for_each_xaxis(lambda x: x.update(title=""))
    fig.for_each_annotation(lambda a: a.update(text=""))
    fig.update_traces(
        marker=dict(
            size=10,
        ),
        error_y=dict(thickness=5),
    )

    fig = fig.update_layout(legend=dict(title_font_family="Helvetica", font=dict(size=28)))

    fig = fig.update_layout(
        scattermode="group",
        font_color="black",
        title_font_family="Helvetica",
        font_size=24,
        title_x=0.47,
    )

    fig.update_yaxes(fixedrange=True)
    fig.update_yaxes(range=[-0.2, 0.2])


    plot_dir = "plots/" + plot_name + "/"
    if not os.path.exists(plot_dir):
        os.makedirs(plot_dir)
    plot_output_path = plot_dir + metric + ".png"
    fig.write_image(plot_output_path, width=370)
    with open(plot_output_path, "rb") as img_file:
        image_widget = widgets.Image(value=img_file.read(), format="png", width=370)
        display(image_widget)

##### Interative Visualization

In [None]:
button = widgets.Button(description="Visualize")
button.on_click(visualize_fn)
visualize_fn(button)

Dropdown(description='Model:', options=('(all)', 'LLaMA-2-7b', 'OPT-1.3B', 'OPT-6.7B', 'OPT-125m ', 'OPT-350m …

Dropdown(description='Category:', options=('grouped_race',), value='grouped_race')

Dropdown(description='Metric:', index=1, options=('accuracy', 'negative_FPR', 'negative_TPR', 'neutral_FPR', '…

Dropdown(description='Dataset:', options=('SST5 Extreme',), value='SST5 Extreme')

Button(description='Visualize', style=ButtonStyle())

Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x04\xb0\x00\x00\x03\x84\x08\x06\x00\x00\x00\xb1m\xc8…

### Compare between models and datasets on the same protected group 
E.g. "young" people of the "age" category.


##### Interative Visualization Utilities

In [176]:
unique_groups = stats_table["group"].unique()
model_selector = widgets.Dropdown(options=models, value=models[0], description="Model:")
group_selector = widgets.Dropdown(options=unique_groups, value=unique_groups[0], description="Group:")
metric_selector = widgets.Dropdown(options=metrics, value=metrics[0], description="Metric:")

In [177]:
def visualize_fn_2(button: Any) -> None:
    clear_output()
    display(model_selector)
    display(group_selector)
    display(metric_selector)
    display(button)

    model = model_selector.value
    group: str = group_selector.value  # type: ignore
    metric: str = metric_selector.value  # type: ignore

    gap_metric = metric + "_gap"

    if model == "(all)":
        filtered_table = stats_table[(stats_table["dataset"] != "")]
    else:
        filtered_table = stats_table[(stats_table["model"] == model) & (stats_table["dataset"] != "")]

    group_filtered_table = filtered_table[filtered_table["group"] == group]

    if len(group_filtered_table) == 0:
        print("None of the runs in the table matched the selections.")
        return

    title = f"Gap for {group.title()} "

    fig = px.scatter(
        group_filtered_table,
        x="dataset",
        facet_col="dataset",
        y=gap_metric,
        color="model",
        error_y=gap_metric + "_width",
        labels={
            "group": "Protected Group",
            gap_metric: format_metric_name(gap_metric),
            "model": "LM",
            **{dataset: format_metric_name(dataset) for dataset in filtered_table["dataset"].unique()},
        },
        title=title,
        height=600,
    )

    fig.update_xaxes(matches=None, tickangle=90)
    fig.for_each_xaxis(lambda x: x.update(title=""))
    fig.for_each_annotation(lambda a: a.update(text=""))


    fig = fig.update_layout(
        scattermode="group",
        font_color="black",
        font_size=18,
        title_x=0.5,
    )
    fig.update_yaxes(fixedrange=True)
    fig.update_yaxes(range=[-0.1, 0.1])


    plot_dir = "plots/" + plot_name + "/"
    if not os.path.exists(plot_dir):
        os.makedirs(plot_dir)
    plot_output_path = plot_dir + metric + ".png"

    fig.write_image(plot_output_path, scale=5)
    with open(plot_output_path, "rb") as img_file:
        image_widget = widgets.Image(value=img_file.read(), format="png", width=600)
        display(image_widget)

##### Interative Visualization

In [178]:
button = widgets.Button(description="Visualize")
button.on_click(visualize_fn_2)
visualize_fn_2(button)

Dropdown(description='Model:', options=('(all)', 'LLaMA-2-7b', 'OPT-1.3B', 'OPT-6.7B', 'OPT-125m ', 'OPT-350m …

Dropdown(description='Group:', options=('african_american', 'american_indian', 'asian', 'hispanic', 'pacific_i…

Dropdown(description='Metric:', options=('accuracy', 'negative_FPR', 'negative_TPR', 'neutral_FPR', 'neutral_T…

Button(description='Visualize', style=ButtonStyle())

Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\r\xac\x00\x00\x0b\xb8\x08\x06\x00\x00\x00\xb7\xe5\x1…