# Bias Visualization

In [1]:
import os
from math import sqrt
from pathlib import Path
from typing import Any, Dict, List, Sequence

import ipywidgets as widgets
import numpy as np
import pandas as pd
import plotly.express as px
from IPython.display import clear_output, display
from plotly.subplots import make_subplots
from sklearn.metrics import confusion_matrix as ConfusionMatrix
from tqdm.auto import tqdm

### Global Constants

This notebook visualizes how each of the following variables affects fairness toward a particular protected group.

In [2]:
# Independent variables whose effect on fairness is visualized below.
INDEPENDENT_VARIABLES = ["model", "dataset", "num_params"]

# Protected-attribute category to analyze. Choose one of
# "grouped_race", "grouped_gender", or "grouped_sexuality".
SENSITIVE_ATTRIBUTE = "grouped_gender"

# Figure geometry (pixels) and base font size shared by every plot.
FIG_WIDTH, FIG_HEIGHT = 1500, 950
FONT_SIZE = 24

## Utilities

### Plotting Utils

In [3]:
def format_metric_name(metric_name: str) -> str:
    """
    Turn a snake_case metric key into a human-readable diagram label.

    Every underscore-separated word is title-cased, except the acronyms
    "FPR" and "TPR", which stay upper-case
    (e.g. "negative_FPR_gap" -> "Negative FPR Gap").
    """
    acronyms = {"FPR", "TPR"}
    return " ".join(
        word.upper() if word in acronyms else word.title()
        for word in metric_name.split("_")
    )

## Getting Predictions Ready

### Load Prediction File
Load every prediction `.tsv` file found (recursively) under the predictions directory, keeping only the rows for `SENSITIVE_ATTRIBUTE`.

In [4]:
def find_csv_filenames(
    path_to_dir: str,
    template_name: str,
    suffix: str = ".tsv",
    exclude_substrings: Sequence[str] = ("llama_7b",),
) -> List[str]:
    """
    Recursively collect prediction files under ``path_to_dir``.

    A file is kept when its name ends with ``suffix``, contains
    ``template_name`` (an empty string matches every file) and contains none
    of ``exclude_substrings``. By default the ``llama_7b`` results are
    dropped because they are not fine-tuned. Despite the historical name,
    the default suffix is ``.tsv``.

    :returns: full paths, sorted, so that the order in which results are
        concatenated (and therefore table and legend order) does not depend
        on ``os.walk``'s filesystem-dependent traversal order.
    """
    matches = []
    for root, _dirs, files in os.walk(path_to_dir):
        for file in files:
            if not file.endswith(suffix) or template_name not in file:
                continue
            if any(excluded in file for excluded in exclude_substrings):
                continue
            matches.append(os.path.join(root, file))
    return sorted(matches)


repo_abs_path = Path(os.path.abspath("")).parent.parent.parent
predictions_dir = f"{repo_abs_path}/unstated_norms_llm_bias/prompt_based_classification/predictions"
prediction_tsv_paths = find_csv_filenames(predictions_dir, template_name="")

plots_dir = f"{repo_abs_path}/unstated_norms_llm_bias/visualization/plots"
plot_name = "pure_prompt"

output_folder: str = "stats"

# Collect one filtered frame per prediction file and concatenate once at the
# end; growing a DataFrame with concat inside the loop is quadratic.
prediction_frames: List[pd.DataFrame] = []
for prediction_tsv_path in tqdm(prediction_tsv_paths, ncols=80):
    dataframe = pd.read_csv(prediction_tsv_path, delimiter="\t")
    del dataframe["text"]
    # Detect chain-of-thought runs from the path *relative to* the predictions
    # directory. Checking the absolute path would also match unrelated
    # ancestor directories (e.g. a home directory such as "/home/scott/").
    relative_path = os.path.relpath(prediction_tsv_path, predictions_dir)
    if "cot" in relative_path:
        dataframe["dataset"] = "CoT"
    prediction_frames.append(dataframe[dataframe["category"] == SENSITIVE_ATTRIBUTE])

# An empty table (no prediction files found) makes the following cells fail on
# missing columns; check `predictions_dir` if that happens.
output_table = pd.concat(prediction_frames) if prediction_frames else pd.DataFrame()

  0%|                                                   | 0/115 [00:00<?, ?it/s]

## Calculate Statistics
Since there can be multiple examples for each protected group, we apply the `groupby` method to compute aggregated statistics across all the examples of each protected group.

Refer to the following function for details on how we implemented the metrics.

### Function for aggregating statistics


Below you will find a function that takes in a dataframe. This dataframe is generated automatically by `pd.DataFrame.groupby`. It holds the subset of one particular prediction tsv for one particular protected group (e.g., the "young" group for the protected class "age").

The function will return a `pd.Series` (like a dictionary) mapping the following fairness metrics to their floating point values:
- Accuracy
- TPR(Positive)
- TPR(Neutral)
- TPR(Negative)
- FPR(Positive)
- FPR(Neutral)
- FPR(Negative)

Also refer to https://stackoverflow.com/a/50671617 for details on how the TPR and FPR values are computed.

In [5]:
# Keep at most `top_n_filter` runs per (model, dataset) pair.
# NOTE: runs are not ranked by any score here. After sorting by
# (model, dataset), the first runs encountered within each pair are kept.
top_n_filter = 5
run_keys = ["model", "dataset", "run_id"]
unique_runs = output_table[run_keys].drop_duplicates().sort_values(["model", "dataset"], ascending=False)
best_runs_by_model_dataset = unique_runs.groupby(["model", "dataset"]).head(top_n_filter).reset_index(drop=True)
# Run ids such as "run_1" repeat across models and datasets, so filter on the
# full (model, dataset, run_id) key. Filtering on run_id alone would let a run
# kept for one pair leak through for every other pair.
runs_to_keep = pd.MultiIndex.from_frame(best_runs_by_model_dataset)
run_mask = pd.MultiIndex.from_frame(output_table[run_keys]).isin(runs_to_keep)
output_table = output_table.loc[run_mask]

In [6]:
def modify_model_name(model_name: str) -> str:
    """
    Map a raw model identifier to the display name used in plots.

    Path-like prefixes ("facebook/", "data/models/", "model-weights/") and the
    " CoT" suffix are stripped first. Then OPT names are upper-cased, RoBERTa,
    Llama-2, Llama-3 and Qwen-2.5 get their canonical spelling, and anything
    else is returned as-is.
    """
    cleaned = model_name
    for noise in ("facebook/", "data/models/", "model-weights/", " CoT"):
        cleaned = cleaned.replace(noise, "")

    if cleaned.startswith("opt"):
        return cleaned.replace("opt", "OPT").upper()
    if cleaned.startswith("roberta"):
        return cleaned.replace("roberta", "RoBERTa")
    # First matching family wins, mirroring an if/elif chain.
    for raw, pretty in (("Llama2", "Llama-2"), ("Llama3", "Llama-3")):
        if raw in cleaned:
            return cleaned.replace(raw, pretty)
    if "qwen" in cleaned:
        return cleaned.replace("qwen2_5", "Qwen-2.5")
    return cleaned


def modify_dataset_name(dataset_name: str) -> str:
    """
    Map a raw dataset identifier to the short label used in plot legends.

    "SST5" and "CoT" are kept verbatim, any SemEval variant becomes "SE",
    and zero-shot runs get an empty label so only the model name shows.
    Any other name is returned unchanged.
    """
    if dataset_name in ("SST5", "CoT"):
        return dataset_name
    if "SemEval" in dataset_name:
        return "SE"
    if "ZeroShot" in dataset_name:
        return ""
    return dataset_name

In [7]:
def get_stats(
    dataframe: pd.DataFrame,
    labels: Sequence[int] = (0, 1, 2),
    label_names: Sequence[str] = ("negative", "neutral", "positive"),
) -> pd.Series:
    """
    Compute accuracy and per-class one-vs-rest FPR/TPR for one group of predictions.

    Input: dataframe representing all predictions for a particular
    protected group in a particular run. It must contain ``y_true`` and
    ``y_pred`` columns.

    :param labels: class ids, in confusion-matrix order.
    :param label_names: display names aligned with ``labels``. The output keys
        are ``f"{name}_FPR"`` and ``f"{name}_TPR"``.
    :returns: Series with ``accuracy`` followed by the FPR/TPR pair for each
        label. A rate whose denominator is zero (e.g. the TPR of a class that
        never occurs in ``y_true``) is NaN, so pandas aggregations skip it
        instead of averaging in an arbitrary value.
    :raises ValueError: if ``labels`` and ``label_names`` differ in length.
    """
    if len(labels) != len(label_names):
        raise ValueError("labels and label_names must have the same length")

    y_true = dataframe["y_true"]
    y_pred = dataframe["y_pred"]
    output: Dict[str, float] = {"accuracy": (y_pred == y_true).mean()}

    # See https://stackoverflow.com/a/50671617 (rows: true labels, columns: predictions).
    confusion_matrix = ConfusionMatrix(y_true, y_pred, labels=list(labels))
    true_positives = np.diag(confusion_matrix)
    false_positives = confusion_matrix.sum(axis=0) - true_positives
    false_negatives = confusion_matrix.sum(axis=1) - true_positives
    true_negatives = confusion_matrix.sum() - (false_positives + false_negatives + true_positives)

    # 0/0 yields NaN (an undefined rate); silence numpy's warning for it.
    with np.errstate(divide="ignore", invalid="ignore"):
        fpr = false_positives / (false_positives + true_negatives)
        tpr = true_positives / (true_positives + false_negatives)

    for index, name in enumerate(label_names):
        output[f"{name}_FPR"] = fpr[index]
        output[f"{name}_TPR"] = tpr[index]

    return pd.Series(output)

### Calculating Per-Group Statistics

In [8]:
# Per-(independent variables, run, category, group) statistics. Only the label
# columns are handed to `get_stats`, which keeps the grouping columns out of the
# applied frame (pandas >= 2.2 deprecates passing them to `DataFrameGroupBy.apply`).
analyzed_dataframe = output_table.groupby([*INDEPENDENT_VARIABLES, "run_id", "category", "group"])[
    ["y_true", "y_pred"]
].apply(get_stats)
analyzed_dataframe.query(f'category == "{SENSITIVE_ATTRIBUTE}"')[["positive_TPR"]]

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,positive_TPR
model,dataset,num_params,run_id,category,group,Unnamed: 6_level_1
Llama2-7B,SST5,7.0,run_1,grouped_gender,aab,0.933333
Llama2-7B,SST5,7.0,run_1,grouped_gender,cis,0.864662
Llama2-7B,SST5,7.0,run_1,grouped_gender,female,0.923077
Llama2-7B,SST5,7.0,run_1,grouped_gender,male,0.884615
Llama2-7B,SST5,7.0,run_1,grouped_gender,many-genders,0.854369
...,...,...,...,...,...,...
qwen2_5-7B CoT,CoT,7.0,run_5,grouped_gender,male,0.961538
qwen2_5-7B CoT,CoT,7.0,run_5,grouped_gender,many-genders,0.966019
qwen2_5-7B CoT,CoT,7.0,run_5,grouped_gender,no-gender,0.983333
qwen2_5-7B CoT,CoT,7.0,run_5,grouped_gender,non-binary,0.948387


In [9]:
output_table = analyzed_dataframe.reset_index()
# Persist the per-(run, group) statistics so they can be inspected outside the notebook.
aggregated_csv_dir = Path(output_folder)
aggregated_csv_dir.mkdir(parents=True, exist_ok=True)
output_table.to_csv(aggregated_csv_dir / "aggregated.csv")

## Calculating Statistics

Before we begin, note that some of the protected classes include a large number of groups. You may filter on the protected groups to include in each category. If a category isn't in this dictionary, the visualizations will include all the groups under that category by default.

In [10]:
SELECTED_GROUPS = {
    "grouped_religion": ["christianity", "islam", "judaism"],
    "grouped_disability": [
        "without",
        "cognitive",
        "physical",
        "hearing",
        "sight",
        "chronic_illness",
        "mobility",
        "mental_health",
    ],
    "grouped_gender": ["many-genders", "non-binary", "trans", "female", "male"],
    "grouped_sexuality": ["asexual", "other", "bisexual", "homosexual", "heterosexual"],
    "grouped_race": ["african_american", "american_indian", "asian", "hispanic", "pacific_islander", "white"],
}

# For each category, keep the curated groups from SELECTED_GROUPS when present.
# Otherwise keep every group observed for that category, in order of first
# appearance so the result is reproducible.
groups_to_include = []
keep_rows = pd.Series(False, index=output_table.index)
for category in output_table["category"].unique():
    in_category = output_table["category"] == category
    groups = SELECTED_GROUPS.get(category)
    if groups is None:
        groups = output_table.loc[in_category, "group"].unique().tolist()
    groups_to_include.extend(groups)
    # Match groups only within their own category, so a group name shared by
    # two categories cannot pull in rows from a category that did not select it.
    keep_rows |= in_category & output_table["group"].isin(groups)

print("Groups selected:", ", ".join(groups_to_include))
# Explicit copy: later cells add columns to this table.
output_table = output_table.loc[keep_rows].copy()

Groups selected: many-genders, non-binary, trans, female, male


### "Gap": How each group deviates from the category mean
Below, we will calculate how far each group deviates from the mean of that category. For example, how far does the accuracy on examples mentioning the "young" group deviate from the mean of all examples of the "age" category?

If there is more than one run for each model, gap metrics are calculated for each run separately.

We use a for loop to add the following metrics of interest to our list:
- Accuracy
- False Positive Rates
    - FPR("Negative")
    - FPR("Neutral")
    - FPR("Positive")
- True Positive Rates
    - TPR("Negative")
    - TPR("Neutral")
    - TPR("Positive")

In [11]:
# Accuracy first, then the (FPR, TPR) pair for each sentiment label.
metrics = ["accuracy"] + [
    f"{label}_{rate}" for label in ("negative", "neutral", "positive") for rate in ("FPR", "TPR")
]
# Name of the column holding each metric's deviation from its category mean.
gap_metrics = [f"{metric}_gap" for metric in metrics]

### Confidence Interval Calculations

In the previous step, we found the gaps for each (metric, category, group, model, run/seed). If we have more than one run/seed for each model, we can aggregate the results to find a confidence interval for each gap variable.

To do so, we will aggregate the previous table (metric, category, group, model, run/seed) along the run/seed axis. The result will be a table with indices (metric, category, group, model).

In [12]:
# Mean accuracy across runs for every (independent variables, category, group).
# The ["mean"] list aggregation yields two-level column labels ("accuracy", "mean").
stats_table = output_table.groupby([*INDEPENDENT_VARIABLES, "category", "group"])[["accuracy"]].agg(["mean"])

In [13]:
# Critical value for a two-sided 95% normal-approximation interval.
# NOTE(review): with only a handful of runs per cell (top_n_filter <= 5), a
# Student-t critical value would give wider, more honest intervals.
z_score = 1.96
run_group_keys = [*INDEPENDENT_VARIABLES, "category", "run_id"]
summary_keys = [*INDEPENDENT_VARIABLES, "category", "group"]

for metric, gap_metric in zip(metrics, gap_metrics):
    # Gap = the group's metric minus the mean over all groups of the same
    # category within the same run (and the same independent variables).
    run_category_mean = output_table.groupby(run_group_keys)[metric].transform("mean")
    output_table[gap_metric] = output_table[metric] - run_category_mean

    # Aggregate the per-run gaps across runs for each
    # (independent variables, category, group) and build the 95% CI.
    gap_stats = output_table.groupby(summary_keys)[gap_metric].agg(["mean", "count", "std"])
    half_width = z_score * gap_stats["std"] / np.sqrt(gap_stats["count"])

    # Series assignment aligns on the shared index, so rows always line up
    # with stats_table regardless of ordering.
    stats_table[gap_metric] = gap_stats["mean"]
    stats_table[gap_metric + "_width"] = half_width
    stats_table[gap_metric + "_lower"] = gap_stats["mean"] - half_width
    stats_table[gap_metric + "_upper"] = gap_stats["mean"] + half_width

stats_table = stats_table.reset_index()

In [14]:
# Replace raw model/dataset identifiers with the short labels used in legends.
for column, prettify in (("model", modify_model_name), ("dataset", modify_dataset_name)):
    stats_table[column] = stats_table[column].apply(prettify)

## Plotting

### Compare between Groups of each category 
E.g. all groups in the "age" category

In [15]:
# Dropdown choices. The "(all)" entry disables the corresponding filter.
categories = stats_table["category"].unique()
models = ["(all)"] + list(stats_table["model"].unique())
datasets = ["(all)"] + list(stats_table["dataset"].unique())

##### Interactive Visualization Utilities

In [16]:
# Interactive selectors read by visualize_fn; each starts on its first option.
model_selector = widgets.Dropdown(
    options=models,
    value=models[0],
    description="Model:",
)
category_selector = widgets.Dropdown(
    options=categories,
    value=categories[0],
    description="Category:",
)
dataset_selector = widgets.Dropdown(
    options=datasets,
    value=datasets[0],
    description="Dataset:",
)

In [17]:
def visualize_fn(button: Any) -> None:
    """
    Redraw the FPR-gap figure for the current dropdown selections.

    Serves both as the "Visualize" button's click handler and as the initial
    render. It clears the cell output, re-displays the selectors and
    ``button``, and filters ``stats_table`` by the selected category, model
    and dataset ("(all)" disables a filter). It then plots the negative- and
    positive-sentiment FPR gaps in two stacked subplots, with the CI
    half-widths as error bars, and saves and displays the figure through
    ``save_and_display_figure``.
    """
    clear_output()
    display(model_selector, category_selector, dataset_selector, button)

    model = model_selector.value
    category = category_selector.value  # type: ignore
    dataset = dataset_selector.value  # type: ignore

    keep = stats_table["category"] == category
    if model != "(all)":
        keep &= stats_table["model"] == model
    if dataset != "(all)":
        keep &= stats_table["dataset"] == dataset
    # Explicit copy: columns are added/overwritten below, and doing that on a
    # filtered slice of stats_table triggers SettingWithCopyWarning.
    filtered_table = stats_table[keep].copy()
    if filtered_table.empty:
        print("None of the runs in the table matched the selections.")
        return

    # Legend key. Zero-shot runs have an empty dataset label, hence strip().
    filtered_table["model - dataset"] = (filtered_table["model"] + " " + filtered_table["dataset"]).str.strip()

    category_name = category.split("_")[-1]
    title = f"Gap for {category_name.title()}<br>{dataset}"

    # x-axis order: the curated list for the *selected* category when one
    # exists, otherwise the order in which groups first appear.
    group_order = SELECTED_GROUPS.get(category)
    if group_order is None:
        group_order = list(filtered_table["group"].unique())
    group_cat_type = pd.CategoricalDtype(categories=group_order, ordered=True)
    filtered_table["group"] = filtered_table["group"].astype(group_cat_type)
    filtered_table = filtered_table.sort_values(by="group")

    legend_order = [
        "OPT-6.7B",
        "OPT-6.7B SST5",
        "OPT-6.7B SE",
        "LLaMA-7B",
        "LLaMA-7B CoT",
        "LLaMA-7B SST5",
        "LLaMA-7B SE",
        "Llama-2-7B",
        "Llama-2-7B CoT",
        "Llama-2-7B SST5",
        "Llama-2-7B SE",
        "Llama-3-8B",
        "Llama-3-8B SST5",
        "Llama-3-8B SE",
        "Llama-3-8B CoT",
        "Mistral-7B",
        "Mistral-7B SST5",
        "Mistral-7B SE",
        "Mistral-7B CoT",
        "Qwen-2.5-3B",
        "Qwen-2.5-3B SST5",
        "Qwen-2.5-3B SE",
        "Qwen-2.5-3B CoT",
        "Qwen-2.5-7B",
        "Qwen-2.5-7B SST5",
        "Qwen-2.5-7B SE",
        "Qwen-2.5-7B CoT",
    ]

    gen_fig = make_subplots(
        rows=2,
        cols=1,
        shared_xaxes=True,
        row_heights=[1000, 1000],
        row_titles=("Negative Sentiment FPR Gap", "Positive Sentiment FPR Gap"),
        vertical_spacing=0.03,
    )

    # Style the row titles, which make_subplots stores as layout annotations.
    for annotation in gen_fig["layout"]["annotations"]:
        annotation["font"] = dict(size=FONT_SIZE, family="Helvetica")

    for i, metric in enumerate(["negative_FPR_gap", "positive_FPR_gap"]):
        fig = px.scatter(
            filtered_table,
            x="group",
            y=metric,
            color="model - dataset",
            error_y=metric + "_width",
            labels={"group": "Protected Group", metric: format_metric_name(metric)},
            category_orders={"model - dataset": legend_order},
            title=title,
        )

        customize_traces(fig, i)

        fig.for_each_xaxis(lambda x: x.update(title=""))
        fig.for_each_annotation(lambda a: a.update(text=""))
        # Copy this metric's traces into row i + 1 of the combined figure.
        for trace in fig.data:
            gen_fig.add_trace(trace, row=i + 1, col=1)

    gen_fig.update_xaxes(
        matches=None, tickangle=0, title_font_family="Helvetica", tickfont_family="Helvetica", tickfont_size=FONT_SIZE
    )

    customize_figure_layout(gen_fig)
    save_and_display_figure(gen_fig, plot_name, plots_dir)


def customize_traces(fig: Any, row_index: int) -> None:
    """
    Restyle every trace of ``fig`` in place for the combined FPR-gap figure.

    All traces get size-10 markers and 3.5-wide error bars. The legend is
    shown only for the first row (``row_index == 0``) so entries are not
    duplicated. Each trace's ``legendgroup`` ("model - dataset" label) picks
    its colour from a fixed palette. Its marker symbol comes from the model
    family (OPT, Llama-2, Llama-3, Mistral); other families keep the default.
    """
    palette = {
        "OPT-6.7B": "#bfdb81",
        "OPT-6.7B SST5": "#83a561",
        "OPT-6.7B SE": "#48723e",
        "LLaMA-7B": "#f3ccff",
        "LLaMA-7B CoT": "#d896ff",
        "LLaMA-7B SST5": "#be29ec",
        "LLaMA-7B SE": "#800080",
        # https://www.color-hex.com/color-palette/97036
        "Llama-2-7B": "#afe3ff",
        "Llama-2-7B CoT": "#53abff",
        "Llama-2-7B SST5": "#0060ff",
        "Llama-2-7B SE": "#0034c3",
        "Llama-3-8B": "#f3aeae",
        "Llama-3-8B SST5": "#f78d8d",
        "Llama-3-8B SE": "#dd8787",
        "Llama-3-8B CoT": "#ef7979",
        "Mistral-7B": "#d896ff",
        "Mistral-7B SST5": "#be29ec",
        "Mistral-7B SE": "#800080",
        "Mistral-7B CoT": "#660066",
        "Qwen-2.5-3B": "#d2a56d",
        "Qwen-2.5-3B SST5": "#bd7e4a",
        "Qwen-2.5-3B SE": "#96613d",
        "Qwen-2.5-3B CoT": "#83502e",
        "Qwen-2.5-7B": "#ffaed7",
        "Qwen-2.5-7B SST5": "#ff77bc",
        "Qwen-2.5-7B SE": "#ff48a5",
        "Qwen-2.5-7B CoT": "#ff0081",
    }
    # Checked in order; the first family substring found in the label wins.
    family_symbols = (
        ("OPT", "diamond"),
        ("Llama-2", "square"),
        ("Llama-3", "x"),
        ("Mistral", "triangle-up"),
    )
    show_in_legend = row_index == 0

    for trace in fig.data:
        series_label = trace.legendgroup
        trace.showlegend = show_in_legend
        trace.marker.size = 10
        trace.error_y.thickness = 3.5
        if series_label in palette:
            trace.marker.color = palette[series_label]
        symbol = next((shape for family, shape in family_symbols if family in series_label), None)
        if symbol is not None:
            trace.marker.symbol = symbol


def customize_figure_layout(gen_fig: Any) -> None:
    """
    Apply the shared presentation layout to ``gen_fig`` in place.

    Sets the figure size from FIG_WIDTH/FIG_HEIGHT and places a horizontal
    legend centred above the plots with entries at FONT_SIZE - 4. Text is
    black Helvetica at FONT_SIZE. The layout also uses grouped scatter
    points, a light-grey plot background and tight margins, and forces
    horizontal x-axis tick labels.
    """
    legend_style = {
        "y": 1.2,
        "x": 0.48,
        "xanchor": "center",
        "orientation": "h",
        "valign": "top",
        "title_text": "",
        "font": {"size": FONT_SIZE - 4},
        "title_font_family": "Helvetica",
    }
    gen_fig.update_layout(
        height=FIG_HEIGHT,
        width=FIG_WIDTH,
        legend=legend_style,
        scattermode="group",
        font_color="black",
        title_font_family="Helvetica",
        title_x=0.47,
        margin={"l": 10, "r": 10, "t": 10, "b": 10},
        plot_bgcolor="#eeeeee",
        font={"size": FONT_SIZE, "family": "Helvetica"},
    )
    gen_fig.update_xaxes(tickangle=0)


def save_and_display_figure(gen_fig: Any, template_name: str, plots_dir: str) -> None:
    """
    Save ``gen_fig`` as ``<plots_dir>/<template_name>_FPR_Gaps.{png,pdf}`` and display the PNG.

    :param gen_fig: plotly figure (anything exposing ``write_image``).
    :param template_name: file-name prefix for the exported images.
    :param plots_dir: output directory, created if missing.

    The PNG is exported at 2x scale and the PDF at default scale. The PNG is
    then read back and shown as an ipywidgets ``Image`` sized
    FIG_WIDTH x FIG_HEIGHT.
    """
    os.makedirs(plots_dir, exist_ok=True)
    base_path = os.path.join(plots_dir, f"{template_name}_FPR_Gaps")
    plot_png_path = f"{base_path}.png"
    plot_pdf_path = f"{base_path}.pdf"

    gen_fig.write_image(plot_png_path, scale=2)
    gen_fig.write_image(plot_pdf_path)

    with open(plot_png_path, "rb") as img_file:
        png_bytes = img_file.read()
    display(widgets.Image(value=png_bytes, format="png", width=FIG_WIDTH, height=FIG_HEIGHT))

In [18]:
# Wire the button to the redraw handler, then render once right away so the
# selectors and the default figure appear without requiring a click.
button: widgets.Button = widgets.Button(description="Visualize")
button.on_click(visualize_fn)
visualize_fn(button)

Dropdown(description='Model:', options=('(all)', 'Llama-2-7B', 'Llama-3-8B', 'Mistral-7B', 'OPT-6.7B', 'Qwen-2…

Dropdown(description='Category:', options=('grouped_gender',), value='grouped_gender')

Dropdown(description='Dataset:', options=('(all)', 'SST5', 'SE', '', 'CoT'), value='(all)')

Button(description='Visualize', style=ButtonStyle())

Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x0b\xb8\x00\x00\x07l\x08\x06\x00\x00\x00?\x1a\xc14\x…