# Bias Visualization

### Dependencies
Note that we use the scikit-learn confusion matrix utility to help calculate some of the fairness metrics. We generate the visualizations using Plotly Express.

In [None]:
import os
from math import sqrt
from typing import Any, Dict

import ipywidgets as widgets
import numpy as np
import pandas as pd
import plotly.express as px
from IPython.display import clear_output, display
from plotly.subplots import make_subplots
from sklearn.metrics import confusion_matrix as ConfusionMatrix
from tqdm.auto import tqdm

### Global Constants

This notebook visualizes how each of the following variables affects the fairness towards a particular protected group.

In [None]:
# Columns of the prediction tables that are treated as independent variables.
# Every groupby in this notebook is keyed on these columns, together with
# "category", "group" and/or "run_id" depending on the aggregation.
INDEPENDENT_VARIABLES = ["model", "dataset", "num_params"]

## Utilities

### Plotting Utils

In [None]:
def format_metric_name(metric_name: str) -> str:
    """
    Prettify diagram texts by replacing snake case with regular title case.

    The acronyms FPR and TPR are always rendered in upper case, however they
    are cased in ``metric_name`` (e.g. ``"positive_fpr_gap"`` becomes
    ``"Positive FPR Gap"``). Every other word is title-cased.
    """
    acronyms = {"FPR", "TPR"}
    return " ".join(
        # Compare case-insensitively so lower-case acronyms are not
        # title-cased into "Fpr" / "Tpr".
        word.upper() if word.upper() in acronyms else word.title()
        for word in metric_name.split("_")
    )

## Getting Predictions Ready

In [None]:
# Name of the template to analyze. Only prediction files whose filename
# contains this string are loaded, and it also prefixes the saved plot files.
# Alternative values (uncomment one to switch):
# template_name = "amazon"
# template_name = "NS-prompts"
template_name = "regard"

In [None]:
def find_csv_filenames(path_to_dir, template_name, suffix=".tsv"):
    """
    Recursively collect files below ``path_to_dir`` whose filename ends with
    ``suffix`` and contains ``template_name``.

    Returns a list of paths, each joined onto the directory the file was
    found in, in the order in which ``os.walk`` yields them.
    """
    return [
        os.path.join(folder, filename)
        for folder, _subdirs, filenames in os.walk(path_to_dir)
        for filename in filenames
        if filename.endswith(suffix) and template_name in filename
    ]

# Locate every prediction file that belongs to the selected template.
predictions_dir = "../predictions/"
prediction_tsv_paths = find_csv_filenames(predictions_dir, template_name)

# Output locations used further down in the notebook.
plots_dir = "plots/"
output_folder: str = "stats"

# Read each prediction file once and concatenate in a single pass. Growing a
# DataFrame with pd.concat inside the loop re-copies all earlier rows on every
# iteration and concatenates against an initially empty frame.
prediction_frames = []
for prediction_tsv_path in tqdm(prediction_tsv_paths, ncols=80):
    dataframe = pd.read_csv(prediction_tsv_path, delimiter="\t")
    # The raw text column is not used by any of the statistics below.
    del dataframe["text"]
    prediction_frames.append(dataframe)

# pd.concat raises on an empty list, so fall back to an empty table when no
# prediction file matched the template.
output_table = pd.concat(prediction_frames) if prediction_frames else pd.DataFrame()

## Calculate Statistics
Since there can be multiple examples for each protected group, we apply the `groupby` method to compute aggregated statistics across all the examples of each protected group.

Refer to the following function for details on how we implemented the metrics.

### Function for aggregating statistics


Below you will find a function that takes in a dataframe. This dataframe is generated automatically by `pd.DataFrame.groupby`. Its content is the subset of one particular prediction tsv that belongs to a particular protected group (e.g., the "young" group for the protected class "age").

The function will return a `pd.Series` (like a dictionary) mapping the following fairness metrics to their floating point values:
- Accuracy
- TPR(Positive)
- TPR(Neutral)
- TPR(Negative)
- FPR(Positive)
- FPR(Neutral)
- FPR(Negative)

Also refer to https://stackoverflow.com/a/50671617 for details on how the TPR and FPR values are computed.

In [None]:
# Keep at most `top_n_filter` runs for every (model, dataset) pair and drop
# the predictions of all other runs.
# NOTE(review): despite the "accuracy" / "best" naming, no accuracy column is
# consulted here; runs are only ordered by (model, dataset) before taking the
# head of each group. Confirm that this is the intended selection.
top_n_filter = 5
unique_runs_with_accuracy = (
    output_table[["model", "run_id", "dataset"]]
    .drop_duplicates()
    .sort_values(["model", "dataset"], ascending=False)
)
best_runs_by_model_dataset = (
    unique_runs_with_accuracy
    .groupby(["model", "dataset"])
    .head(top_n_filter)
    .reset_index()
)
run_ids_to_keep = best_runs_by_model_dataset["run_id"].tolist()
output_table = output_table.loc[output_table["run_id"].isin(run_ids_to_keep)]

In [None]:
# Modify the names of the models and the datasets for display purposes
def modify_model_name(model_name: str) -> str:
    """
    Turn a raw model identifier into the short label used in the plots.

    The substitutions below are applied one after another, in the listed
    order, to the whole string: size suffixes are upper-cased, path and
    fine-tuning markers are stripped, and model family names are re-cased.
    """
    substitutions = (
        # Parameter-count suffixes.
        ("125m", "125M"),
        ("350m", "350M"),
        ("1.3b", "1.3B"),
        ("7b", "7B"),
        ("8b", "8B"),
        ("13b", "13B"),
        ("base", "125M"),
        ("large", "355M"),
        # Path prefixes and training markers that should not be displayed.
        ("/data/llama-farnaz/", ""),
        (" fine-tuned", ""),
        ("facebook/", ""),
        ("-lora-lr1e4", ""),
        # Model family names.
        ("opt", "OPT"),
        ("roberta", "RoBERTa"),
        ("llama", "Llama"),
        ("mistral", "Mistral"),
    )
    for old, new in substitutions:
        model_name = model_name.replace(old, new)

    return model_name


def modify_dataset_name(dataset_name: str) -> str:
    """
    Map a raw dataset identifier to its display name.

    Names containing "extreme" become "SST5 Extreme"; every other name,
    including those containing "grouped", becomes "SST5 Grouped".
    """
    return "SST5 Extreme" if "extreme" in dataset_name else "SST5 Grouped"

In [None]:
def get_stats(dataframe: pd.DataFrame) -> pd.Series:
    """
    Compute accuracy and per-label FPR / TPR for one slice of predictions.

    Input: dataframe with "y_true" and "y_pred" columns representing all
    predictions for a particular protected group in a particular run.
    Labels 0, 1 and 2 are reported as "negative", "neutral" and "positive".

    Returns a Series keyed by "accuracy" and, for each label,
    "<label>_FPR" and "<label>_TPR".
    """
    y_true = dataframe["y_true"]
    y_pred = dataframe["y_pred"]

    stats: Dict[str, float] = {"accuracy": np.sum(y_pred == y_true) / len(dataframe)}

    # One-vs-rest counts per label; see https://stackoverflow.com/a/50671617
    matrix = ConfusionMatrix(y_true, y_pred, labels=[0, 1, 2])
    hits = np.diag(matrix)
    false_pos = matrix.sum(axis=0) - hits
    false_neg = matrix.sum(axis=1) - hits
    true_neg = matrix.sum() - (false_pos + false_neg + hits)

    # Shift every count by a tiny epsilon so that no ratio divides by zero.
    # NOTE(review): when a denominator would otherwise be zero (e.g. a label
    # that never occurs in y_true), the ratio evaluates to eps / (2 * eps),
    # i.e. 0.5 rather than NaN.
    eps = 10 ** -10
    false_pos = false_pos + eps
    false_neg = false_neg + eps
    true_pos = hits + eps
    true_neg = true_neg + eps

    fpr = false_pos / (false_pos + true_neg)
    tpr = 1 - false_neg / (true_pos + false_neg)

    for label, label_fpr, label_tpr in zip(["negative", "neutral", "positive"], fpr, tpr):
        stats[label + "_FPR"] = label_fpr
        stats[label + "_TPR"] = label_tpr

    return pd.Series(stats)

### Calculating Per-Group Statistics

In [None]:
# output_table
# Compute the fairness statistics once for every
# (independent variables, run, category, group) slice of the predictions.
analyzed_dataframe = output_table.groupby([*INDEPENDENT_VARIABLES, "run_id", "category", "group"]).apply(get_stats)
# Notebook preview: positive-label TPR of every group in "grouped_race".
analyzed_dataframe.query('category == "grouped_race"')[["positive_TPR"]]

In [None]:
# Turn the groupby keys back into regular columns; from here on, output_table
# holds one row of statistics per (independent variables, run, category, group).
output_table = analyzed_dataframe.reset_index()
# Persist the aggregated statistics as <output_folder>/aggregated.csv.
os.makedirs(output_folder, exist_ok=True)
output_table.to_csv(os.path.join(output_folder, "aggregated.csv"))

## Calculating Statistics

Before we begin, note that some of the protected classes include a large number of groups. You may filter on the protected groups to include in each category. If a category isn't in this dictionary, the visualizations will include all the groups under that category by default.

In [None]:
# Per-category whitelist of the groups to visualize. Categories that are not
# listed here keep all of their groups.
SELECTED_GROUPS = {
    "grouped_religion": ["christianity", "islam", "judaism"],
    "grouped_disability": [
        "without",
        "cognitive",
        "physical",
        "hearing",
        "sight",
        "chronic_illness",
        "mobility",
        "mental_health",
    ],
}

groups_to_include = []
for category in output_table["category"].unique():
    groups = SELECTED_GROUPS.get(category)
    if groups is None:
        # No whitelist for this category: keep every group, in order of first
        # appearance so the printed list is reproducible between runs. A
        # boolean mask is used instead of an f-string query so that category
        # names containing quotes cannot break the expression.
        in_category = output_table["category"] == category
        groups = output_table.loc[in_category, "group"].unique().tolist()

    groups_to_include.extend(groups)

print("Groups selected:", ", ".join(groups_to_include))
# NOTE(review): the filter matches on the group name only, not on the
# (category, group) pair; confirm group names are unique across categories.
# .copy() makes the filtered table independent of the original, so the gap
# columns assigned to it later do not raise SettingWithCopyWarning.
output_table = output_table.query("group == @groups_to_include").copy()

### "Gap": How each group deviates from the category mean
Below, we will calculate how far each group deviates from the mean of that category. For example, how far does the accuracy on examples mentioning the "young" group deviate from the mean of all examples of the "age" category?

If there is more than one run for a model, we calculate the gap metrics for each run separately.

We use a for loop to add the following metrics of interest to our list:
- Accuracy
- False Positive Rates
    - FPR("Negative")
    - FPR("Neutral")
    - FPR("Positive")
- True Positive Rates
    - TPR("Negative")
    - TPR("Neutral")
    - TPR("Positive")

In [None]:
# Metrics tracked per group: overall accuracy, followed by the FPR and TPR of
# each sentiment label.
metrics = ["accuracy"] + [
    f"{label}_{rate}"
    for label in ("negative", "neutral", "positive")
    for rate in ("FPR", "TPR")
]

# Column name of the gap variant of each metric, in the same order as above.
gap_metrics = [f"{name}_gap" for name in metrics]

### Confidence Interval Calculations

In the previous step, we found the gaps for each (metric, category, group, model, run/seed). If we have more than one run/seed for each model, we can aggregate the results to find a confidence interval for each gap variable.

To do so, we will aggregate the previous table (metric, category, group, model, run/seed) along the run/seed axis. The result will be a table with indices (metric, category, group, model).

In [None]:
# Create a (metric, group) placeholder table for storing the output values.
# Unlike in the previous step, "run_id" isn't part of groupby,
# and this axis will be squeezed in the result.
# The table is indexed by (*INDEPENDENT_VARIABLES, "category", "group") and
# starts out with only the mean accuracy; the gap columns are added later.
# NOTE: passing a list to .agg() (["mean"]) yields two-level column labels,
# i.e. ("accuracy", "mean") rather than a flat "accuracy" column.
stats_table = (
    output_table[[*INDEPENDENT_VARIABLES, "category", "group", "accuracy"]]
    .groupby([*INDEPENDENT_VARIABLES, "category", "group"])
    .agg(["mean"])
)

In [None]:
# z-score of a two-sided 95% confidence interval under a normal approximation.
# NOTE(review): with only a handful of runs per model a t-distribution quantile
# may be more appropriate; confirm before reporting the intervals.
z_score = 1.96

for metric, gap_metric in zip(metrics, gap_metrics):
    # Gap = value of the metric for a group minus the mean of that metric over
    # all groups of the same (independent variables, category, run).
    group_by = [*INDEPENDENT_VARIABLES, "category", "run_id"]
    grouped = output_table[[*group_by, metric]].groupby(group_by)
    group_mean = grouped.transform("mean")
    output_table[gap_metric] = output_table[metric] - group_mean[metric]

    # See https://stackoverflow.com/a/53522680
    # Calculate mean and stdev for each (metric, category, group, model),
    # aggregating over the runs.
    stats = (
        output_table[[*INDEPENDENT_VARIABLES, "category", "group", gap_metric]]
        .groupby([*INDEPENDENT_VARIABLES, "category", "group"])
        .agg(["mean", "count", "std"])
    )

    # Calculate the 95% confidence interval for every row at once instead of
    # looking rows up one by one. Plain arrays are used so the values are
    # assigned to stats_table positionally; both tables are grouped by the
    # same keys of the same frame.
    summary = stats[gap_metric]
    mean_values = summary["mean"].to_numpy()
    ci_width = z_score * summary["std"].to_numpy() / np.sqrt(summary["count"].to_numpy())

    stats_table[gap_metric] = mean_values
    stats_table[gap_metric + "_width"] = ci_width
    stats_table[gap_metric + "_lower"] = mean_values - ci_width
    stats_table[gap_metric + "_upper"] = mean_values + ci_width

# Turn the groupby keys back into regular columns for filtering and plotting.
stats_table = stats_table.reset_index()

In [None]:
# Replace the raw model and dataset identifiers with their display names, so
# that the selectors and the plot legend show the prettified labels.
stats_table["model"] = stats_table["model"].apply(modify_model_name)
stats_table["dataset"] = stats_table["dataset"].apply(modify_dataset_name)

## Plotting

### Compare between Groups of each category 
E.g. all groups in the "age" category

In [None]:
# Options offered by the interactive selectors. "(all)" is an extra model
# entry that disables the model filter in the visualization.
datasets = stats_table["dataset"].unique()
models = ["(all)", *(stats_table["model"].unique())]
categories = stats_table["category"].unique()

##### Interactive Visualization Utilities

In [None]:
# Dropdown widgets for choosing what to plot; each defaults to its first option.
model_selector = widgets.Dropdown(options=models, value=models[0], description="Model:")
category_selector = widgets.Dropdown(options=categories, value=categories[0], description="Category:")
dataset_selector = widgets.Dropdown(options=datasets, value=datasets[0], description="Dataset:")

In [None]:
def visualize_fn(button: Any) -> None:
    """
    Button callback: plot the FPR gaps for the current widget selections.

    Clears the cell output, re-displays the selectors and ``button``, filters
    ``stats_table`` by the selected category, dataset and (unless "(all)" is
    chosen) model, and builds a two-row figure with the negative and positive
    FPR gaps per protected group. The figure is then styled and handed to
    ``save_and_display_figure``. Prints a message and returns early when no
    row matches the selection.
    """
    # Redraw the controls, since clear_output() removes them as well.
    clear_output()
    display(model_selector, category_selector, dataset_selector, button)

    model = model_selector.value
    category = category_selector.value  # type: ignore
    dataset = dataset_selector.value  # type: ignore

    # "(all)" disables the model filter: `cond & True` leaves cond unchanged.
    filter_cond = (stats_table["category"] == category) & (stats_table["dataset"] == dataset)
    filtered_table = stats_table[filter_cond & (stats_table["model"] == model if model != "(all)" else True)]

    if filtered_table.empty:
        print("None of the runs in the table matched the selections.")
        return

    # Use the last "_"-separated token of the category for the title.
    category_name = category.split("_")[-1]
    title = f"Gap for {category_name.title()}<br>{dataset}"

    # One subplot row per metric, sharing the x axis (the protected groups).
    gen_fig = make_subplots(rows=2, cols=1, shared_xaxes=True,
                            row_heights=[1000, 1000],
                            row_titles=('Negative Sentiment FPR Gap', 'Positive Sentiment FPR Gap'),
                            vertical_spacing=0.03)

    for annotation in gen_fig['layout']['annotations']:
        annotation['font'] = dict(size=24, family='Helvetica')  # 24 pt Helvetica for every annotation of the figure

    # NOTE(review): only the traces of each per-metric `fig` are copied into
    # `gen_fig` below, so `title` and the axis/annotation updates made on
    # `fig` presumably do not reach the saved figure -- confirm.
    for i, metric in enumerate(["negative_FPR_gap", "positive_FPR_gap"]):
        fig = px.scatter(
            filtered_table,
            x="group",
            y=metric,
            color="model",
            error_y=metric + "_width",
            labels={"group": "Protected Group", metric: format_metric_name(metric), "model": "LM"},
            category_orders={"model": ["RoBERTa-125M", "RoBERTa-355M", "OPT-125M", "OPT-350M",
                                       "OPT-1.3B", "OPT-2.7B", "OPT-6.7B", "OPT-13B",
                                       "Llama-2-7B", "Llama-2-13B", "Llama-3-8B", "Mistral-7B"]},
            title=title)

        # Per-model colours/symbols; the legend is shown for the first row only.
        customize_traces(fig, i)

        fig.for_each_xaxis(lambda x: x.update(title=""))
        fig.for_each_annotation(lambda a: a.update(text=""))
        for trace in fig.data:
            gen_fig.add_trace(trace, row=i+1, col=1)

    customize_figure_layout(gen_fig)
    save_and_display_figure(gen_fig, template_name, plots_dir)

def customize_traces(fig, row_index):
    """
    Style every trace of ``fig`` in place.

    Sets marker size and error-bar thickness on all traces, shows the legend
    only when ``row_index`` is 0, and assigns a per-model colour and a
    per-family marker symbol based on each trace's ``legendgroup``.
    """
    show_in_legend = row_index == 0  # Show legend only in the first row

    # NOTE(review): there is no colour entry for "OPT-2.7B"; such a trace
    # keeps whatever colour it already has.
    model_colors = {
        "RoBERTa-125M": "#7ab3ef",
        "RoBERTa-355M": "#368ce7",
        "OPT-125M": "#f3ccff",
        "OPT-350M": "#d896ff",
        "OPT-1.3B": "#be29ec",
        "OPT-6.7B": "#800080",
        "OPT-13B": "#660066",
        "Llama-2-7B": "#bfdb81",
        "Llama-2-13B": "#83a561",
        "Llama-3-8B": "#48723e",
        "Mistral-7B": "#c68642",
    }
    # Checked in order; the first family contained in the name wins. Names
    # matching no family keep their current symbol.
    family_symbols = (
        ("OPT", "diamond"),
        ("Llama", "square"),
        ("Mistral", "triangle-up"),
    )

    for trace in fig.data:
        name = trace.legendgroup
        trace.showlegend = show_in_legend
        trace.marker.size = 10  # Increased marker size for better visibility
        trace.error_y.thickness = 3.5  # Thicker error bars
        if name in model_colors:
            trace.marker.color = model_colors[name]
        for family, symbol in family_symbols:
            if family in name:
                trace.marker.symbol = symbol
                break

def customize_figure_layout(gen_fig):
    """
    Apply the shared layout to ``gen_fig`` in place: overall size, a
    horizontal legend centred above the plot, Helvetica fonts, tight margins
    and a light grey plot background.
    """
    font_size = 24  # Larger font size for better readability

    legend_style = dict(
        y=1.14,
        x=0.5,
        xanchor='center',
        orientation='h',
        valign='top',
        title_text='',
        font=dict(size=font_size),
        title_font_family="Helvetica",
    )
    tight_margins = dict(l=10, r=10, t=10, b=10)
    base_font = dict(size=font_size, family="Helvetica")

    gen_fig.update_layout(
        height=1000,  # Increase overall figure size
        width=1300,
        legend=legend_style,
        scattermode="group",
        font_color="black",
        title_font_family="Helvetica",
        title_x=0.47,
        margin=tight_margins,
        plot_bgcolor='#eeeeee',
        font=base_font,
    )

def save_and_display_figure(gen_fig, template_name, plots_dir):
    """
    Save ``gen_fig`` as PNG and PDF inside ``plots_dir`` and display it.

    The files are named ``<template_name>_FPR_Gaps.png`` / ``.pdf``;
    ``plots_dir`` is created when it does not exist yet. The PNG is written
    at twice the default scale.
    """
    os.makedirs(plots_dir, exist_ok=True)

    # os.path.join avoids a doubled separator when plots_dir ends with "/".
    base_path = os.path.join(plots_dir, f"{template_name}_FPR_Gaps")

    # Write images
    gen_fig.write_image(base_path + ".png", scale=2)
    gen_fig.write_image(base_path + ".pdf")

    # Render the figure as well: the interactive callback clears the cell
    # output before plotting, so without this nothing would be shown.
    gen_fig.show()


##### Interactive Visualization

In [None]:
# Wire up the "Visualize" button and render once with the default selections.
button = widgets.Button(description="Visualize")
button.on_click(visualize_fn)
visualize_fn(button)