```
Copyright 2021 Twitter, Inc.
SPDX-License-Identifier: Apache-2.0
```

## Demographic Bias Plots FairFace

In [None]:
import math
import os
import random
import shlex
import subprocess
import sys
import time
from collections import OrderedDict
from pathlib import Path

import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from IPython.display import display
from matplotlib.patches import Rectangle
from PIL import Image

## Add seed for reproducibility
random.seed(42)
np.random.seed(42)

In [None]:
# Global seaborn styling applied to every figure drawn after this cell:
# the "paper" plotting context and the "ticks" axes style.
sns.set_context("paper")
sns.set_style("ticks")

In [None]:
# Data directory, resolved relative to the notebook's working directory.
# (expanduser() only has an effect on paths that start with "~".)
base_dir = Path("../data").expanduser()
# Cell output: True when the data directory is present.
base_dir.exists()

In [None]:
import ast
import pickle

# Each non-empty line of results.txt is expected to be the repr of one result
# dictionary; the parsed dictionaries become the rows of df_results.
all_dicts = []
with open(base_dir / "./results.txt", "r") as f:
    lines = f.read().split("\n")
for i, line in enumerate(lines):
    if not line.strip():
        # Blank lines (e.g. the trailing one after the final newline) carry no data.
        continue
    # ast.literal_eval only accepts plain literals, so the wrapper names
    # "array" and "dict_keys" (presumably numpy / dict-view reprs) are removed,
    # leaving just the parenthesised literal they wrap.
    # NOTE(review): this is a plain substring replace, so it would also alter
    # any string value that happens to contain "array" or "dict_keys".
    cleaned = line.replace("array", "").replace("dict_keys", "")
    try:
        all_dicts.append(ast.literal_eval(cleaned))
    except (ValueError, SyntaxError, TypeError) as err:
        # Report and skip malformed lines instead of aborting the whole load;
        # only parsing failures are caught so interrupts still propagate.
        print(f"Cannot read dictionary from line {i} ({err}):", cleaned)

df_results = pd.DataFrame(all_dicts)

In [None]:
# Preview the first rows of the parsed results table.
df_results.head()

In [None]:
# Stats keyed by group_name_list for every row with num_iterations == 10000,
# without filtering on the experiment setting.
df_results[df_results.num_iterations == 10000].set_index("group_name_list")[
    "stats"
].to_dict()

In [None]:
# Number of result rows recorded for each experiment setting.
df_results.setting.value_counts()

In [None]:
def plot_pairwise_stats(
    stat_dict,
    figsize=None,
    setting_name="unknown",
    num_iteration="unknown",
    confidence_interval_err=2,
    middle_band_width=0.1,
    x_label_angle=10,
    annotate=False,
):
    """Plot how far each pairwise pick probability deviates from 0.5.

    ``stat_dict`` maps a pair of groups to comparison counts::

        (group1, group2): [num_group1_is_picked, num_group2_is_picked]

    Each group is a sequence whose first two items are rendered as "a-b" in
    the tick label. One bar is drawn per pair showing ``p - 0.5`` where
    ``p = n1 / (n1 + n2)`` is the observed probability that group1 is picked.

    The standard error is taken to be sqrt(p * (1 - p) / n), the usual
    estimate for a Bernoulli proportion. Error bars span
    +- ``confidence_interval_err`` standard errors, so the default of 2 gives
    roughly a 95% confidence interval.

    The figure is saved as "<setting_name>_n=<num_iteration>.jpg" in the
    current working directory.

    Args:
        stat_dict: mapping of group pairs to the two pick counts (see above).
        figsize: figure size forwarded to ``plt.subplots``.
        setting_name: name of the experiment setting; only used in the name
            of the saved figure.
        num_iteration: number of samples used (int or str); only used in the
            title and in the name of the saved figure.
        confidence_interval_err: half-width of the error bars, in standard
            errors.
        middle_band_width: if not None, a horizontal "Demographic Parity"
            reference line is drawn at 0 (i.e. p = 0.5) and a legend is shown.
            The numeric value itself is not used.
        x_label_angle: rotation of the x tick labels in degrees; may need to
            be increased for lengthy labels.
        annotate: if True, write the deviation from 0.5 (in percent) for
            each bar.
    """
    x_labels = [
        "{}-{}".format(*pair[0]) + "\nhigher than\n" + "{}-{}".format(*pair[1])
        for pair in stat_dict.keys()
    ]
    counts = list(stat_dict.values())
    total = [val[0] + val[1] for val in counts]
    prob = [val[0] / n for val, n in zip(counts, total)]
    y_err = [
        confidence_interval_err * math.sqrt(p * (1 - p) / n)
        for p, n in zip(prob, total)
    ]

    fig, ax = plt.subplots(figsize=figsize)
    # Bars are centred on 0 so that their sign shows which group is favoured.
    ax.bar(x_labels, np.array(prob) - 0.5, yerr=y_err, color="0.5")

    if annotate:
        # All annotations share one height above the tallest bar (positive
        # deviations) or below the lowest bar (negative deviations).
        y_above = (max(prob) + 0.05) - 0.5
        y_below = (min(prob) - 0.1) - 0.5
        for i, v in enumerate(prob):
            ax.text(
                i,
                y_above if v > 0.5 else y_below,
                f"{100*((v-0.5)):.0f}%",
                color="red",
                fontweight="bold",
                fontsize=16,
                ha="center",
            )

    show_parity_line = middle_band_width is not None
    if show_parity_line:
        ax.axhline(0, color="k", label="Demographic Parity")
    plt.xlim(-0.5, len(x_labels) - 0.5)
    plt.ylim(-0.5, 0.5)
    # The multiplier in the label follows the actual error-bar width.
    ax.set_ylabel(
        f"Probability - 0.5\n$\\pm$ {confidence_interval_err} * error", fontsize=20
    )
    sns.despine(offset=10)
    plt.xticks(rotation=x_label_angle, fontsize=16)
    plt.yticks(fontsize=16)
    ax.yaxis.grid(True)
    if show_parity_line:
        # The parity line is the only labelled artist; without it a legend
        # would be empty and matplotlib would warn.
        plt.legend(fontsize=16)
    plt.title(f"{num_iteration} samples", fontsize=20)
    plt.tight_layout()
    plt.savefig(setting_name + "_n=" + str(num_iteration) + ".jpg")


def _pair_counts(stat_dict, first, second):
    """Return counts [first_picked, second_picked] as an array, or None.

    A comparison may be stored under either key order; when only
    ``(second, first)`` is present its counts are reversed so that the result
    is always oriented as ``first`` versus ``second``.
    """
    if (first, second) in stat_dict:
        return np.array(stat_dict[(first, second)])
    if (second, first) in stat_dict:
        return np.array(stat_dict[(second, first)])[::-1]
    return None


def prepare_plot_vars(stat_dict, confidence_interval_err=2):
    """Derive everything the pairwise bar plots need from raw comparison counts.

    ``stat_dict`` maps ``((race1, gender1), (race2, gender2))`` to
    ``[num_group1_is_picked, num_group2_is_picked]`` for the black/white and
    female/male intersectional groups. Two aggregate comparisons are added:

    * ``(("black", ""), ("white", ""))``: sum over all four black-vs-white
      comparisons.
    * ``(("", "female"), ("", "male"))``: sum over all female-vs-male
      comparisons that are recorded, in either key orientation.

    The caller's dictionary is not modified.

    Args:
        stat_dict: mapping of group pairs to the two pick counts.
        confidence_interval_err: number of standard errors used for the
            error-bar half-widths.

    Returns:
        A tuple ``(stat_dict, x_labels, prob1, prob2, total, x, y_err1, y_err2)``:
        an OrderedDict of the eight plotted pairs (race, gender, cross,
        aggregate; two each), the tick labels, the probability of the first
        and of the second group being picked, the number of comparisons per
        pair, the bar positions ``np.arange(8)``, and the error-bar
        half-widths ``confidence_interval_err * sqrt(p * (1 - p) / n)`` for
        prob1 and prob2.

    Raises:
        KeyError: if a comparison needed for the plot or for an aggregate is
            missing from ``stat_dict``.
    """
    # Work on a shallow copy so repeated calls see the same input.
    stats = dict(stat_dict)

    race_counts = []
    for g1 in ("male", "female"):
        for g2 in ("male", "female"):
            counts = _pair_counts(stats, ("black", g1), ("white", g2))
            if counts is None:
                raise KeyError((("black", g1), ("white", g2)))
            race_counts.append(counts)
    stats[(("black", ""), ("white", ""))] = sum(race_counts)

    gender_counts = []
    for r1 in ("black", "white"):
        for r2 in ("black", "white"):
            counts = _pair_counts(stats, (r1, "female"), (r2, "male"))
            # Comparisons absent in both orientations are skipped.
            if counts is not None:
                gender_counts.append(counts)
    if not gender_counts:
        raise KeyError((("", "female"), ("", "male")))
    stats[(("", "female"), ("", "male"))] = sum(gender_counts)

    pairs = [
        # Race
        (("black", "female"), ("white", "female")),
        (("black", "male"), ("white", "male")),
        # Gender
        (("black", "female"), ("black", "male")),
        (("white", "female"), ("white", "male")),
        # Cross
        (("black", "female"), ("white", "male")),
        (("black", "male"), ("white", "female")),
        # Agg
        (("black", ""), ("white", "")),
        (("", "female"), ("", "male")),
    ]
    ordered = OrderedDict((pair, stats[pair]) for pair in pairs)

    x_labels, prob1, prob2, total = [], [], [], []
    for ((r1, g1), (r2, g2)), val in ordered.items():
        n = val[0] + val[1]
        # Labels use initials, e.g. "BF<->WM"; aggregates have an empty part.
        x_labels.append(
            f"{r1[:1].upper()}{g1[:1].upper()}"
            + "$\\leftrightarrow$"
            + f"{r2[:1].upper()}{g2[:1].upper()}"
        )
        prob1.append(val[0] / n)
        prob2.append(val[1] / n)
        total.append(n)

    y_err1 = [
        confidence_interval_err * math.sqrt(p * (1 - p) / n)
        for p, n in zip(prob1, total)
    ]
    y_err2 = [
        confidence_interval_err * math.sqrt(p * (1 - p) / n)
        for p, n in zip(prob2, total)
    ]
    x = np.arange(len(x_labels))
    # Tuples are returned (as before) so callers may concatenate prob1 + prob2.
    return (
        ordered,
        tuple(x_labels),
        tuple(prob1),
        tuple(prob2),
        tuple(total),
        x,
        y_err1,
        y_err2,
    )


def plot_pairwise_stats_new(
    stat_dict,
    figsize=None,
    setting_name="unknown",
    num_iteration="unknown",
    confidence_interval_err=2,
    middle_band_width=0.1,
    x_label_angle=10,
    annotate=False,
    demographic_parity=True,
):
    """Plot all pairwise comparisons as grouped bars in a single figure.

    ``stat_dict`` maps a pair of groups to comparison counts::

        (group1, group2): [num_group1_is_picked, num_group2_is_picked]

    For every plotted pair two bars are drawn: the probability that group1
    (left, red) and group2 (right, grey) is picked. The eight pairs are
    shaded and titled in four regions of two pairs each: Race, Gender, Cross
    and Aggregate.

    The standard error is taken to be sqrt(p * (1 - p) / n), the usual
    estimate for a Bernoulli proportion. Error bars span
    +- ``confidence_interval_err`` standard errors, so the default of 2 gives
    roughly a 95% confidence interval.

    The figure is saved as "<setting_name>_n=<num_iteration>.jpg" in the
    current working directory.

    Args:
        stat_dict: mapping of group pairs to the two pick counts (see above).
        figsize: figure size forwarded to ``plt.subplots``.
        setting_name: name of the experiment setting; only used in the name
            of the saved figure.
        num_iteration: number of samples used (int or str); only used in the
            title and in the name of the saved figure.
        confidence_interval_err: half-width of the error bars, in standard
            errors.
        middle_band_width: if not None, a dashed "Demographic Parity"
            reference line is drawn at 0.5 with a text label to its right.
            The numeric value itself is not used.
        x_label_angle: rotation of the x tick labels in degrees; may need to
            be increased for lengthy labels.
        annotate: if True, print both probabilities inside the bars and the
            gap above each pair.
        demographic_parity: selects the annotated gap: ``p_left - 0.5`` when
            True, the relative gap ``p_left / p_right - 1`` when False.
    """
    stat_dict, x_labels, prob1, prob2, total, x, y_err1, y_err2 = prepare_plot_vars(
        stat_dict, confidence_interval_err=confidence_interval_err
    )
    fig, ax = plt.subplots(figsize=figsize)
    width = 0.4
    ax.bar(x - width / 2, np.array(prob1), width=width, yerr=y_err1, color="r")
    ax.bar(x + width / 2, np.array(prob2), width=width, yerr=y_err2, color="0.5")

    if annotate:
        # All gap annotations sit at one height, just above the tallest bar.
        gap_y = max(max(prob1), max(prob2)) + 0.05
        for i, (p_left, p_right) in enumerate(zip(prob1, prob2)):
            ax.text(
                x[i] - width / 2,
                0.1,
                f"{100*(p_left):.0f}%",
                color="white",
                fontweight="bold",
                fontsize=12,
                ha="center",
            )
            ax.text(
                x[i] + width / 2,
                0.1,
                f"{100*(p_right):.0f}%",
                color="white",
                fontweight="bold",
                fontsize=12,
                ha="center",
            )
            if demographic_parity:
                parity_title = (
                    f"$\\Delta_{{0.5}}$\n$\\mathbf{{{100*(p_left - 0.5):+.0f}\\%}}$"
                )
            else:
                # A right-hand probability of 0 makes the ratio unbounded;
                # show it as +inf instead of failing with ZeroDivisionError.
                ratio = p_left / p_right if p_right else float("inf")
                parity_title = f"$\\Delta$\n$\\mathbf{{{100*(ratio - 1):+.0f}\\%}}$"
            ax.text(
                x[i],
                gap_y,
                parity_title,
                color="k",
                fontweight="bold",
                fontsize=12,
                ha="center",
            )

    if middle_band_width is not None:
        ax.axhline(0.5, color="k", linestyle="--", label="Demographic Parity", lw=0.5)
        ax.text(
            x[-1] + 1.75 * width,
            0.5,
            "Demographic\nParity",
            color="k",
            fontweight="normal",
            fontsize=12,
            ha="center",
            va="center",
        )

    # Every region covers two consecutive pairs of the ordering produced by
    # prepare_plot_vars.
    for i, region in enumerate(["Race", "Gender", "Cross", "Aggregate"]):
        ax.fill_betweenx(
            [0, 1], x[2 * i] - width, x[2 * i + 1] + width, color="0.9", alpha=0.4
        )
        ax.text(
            (x[2 * i] + x[2 * i + 1]) / 2,
            0.9,
            region,
            fontsize=12,
            fontweight="bold",
            ha="center",
        )

    plt.ylim(0, 1.1)
    # The multiplier in the label follows the actual error-bar width.
    ax.set_ylabel(f"$p(x) \\pm {confidence_interval_err} S.E.$", fontsize=20)
    sns.despine(offset=10)
    ax.set_axisbelow(False)
    plt.xticks(x, rotation=x_label_angle, fontsize=16)
    ax.set_xticklabels(x_labels)
    plt.yticks(fontsize=16)
    ax.yaxis.grid(True, color="white", linewidth=1, zorder=3)

    # The thousands separator only applies to numbers; a string such as the
    # default "unknown" is shown as is.
    try:
        n_text = f"{num_iteration:,}"
    except (ValueError, TypeError):
        n_text = str(num_iteration)
    if demographic_parity:
        title = f"n={n_text} samples; $\\Delta_{{0.5}} = p_{{left}} - 0.5$"
    else:
        title = f"n={n_text} samples; $\\Delta = p_{{left}}/p_{{right}} - 1$"
    sub_title = "B=Black, W=White, F=Female, M=Male"
    plt.title(f"{title}; {sub_title}", fontsize=16, loc="left")
    plt.tight_layout()
    plt.savefig(setting_name + "_n=" + str(num_iteration) + ".jpg")

In [None]:
# Pick the pairwise-comparison rows of one experiment setting and sample size
# and key their stats by group pair.
NUM_ITERATION = 10000
SETTING_NAME = "wiki_no_scaling_intersect"
_matches_run = (df_results.num_iterations == NUM_ITERATION) & (
    df_results.setting == SETTING_NAME
)
_run_stats = df_results[_matches_run].set_index("group_name_list")["stats"]
max_salient_compare_dict = _run_stats.to_dict()
# Cell output: the selected {group pair: stats} mapping.
max_salient_compare_dict

In [None]:
# Grouped bar chart for the selected setting, annotated with the relative gap
# p_left / p_right - 1 (demographic_parity=False).
plot_pairwise_stats_new(
    max_salient_compare_dict,
    setting_name=SETTING_NAME,
    num_iteration=NUM_ITERATION,
    figsize=(12, 4),
    x_label_angle=0,
    annotate=True,
    demographic_parity=False
)

In [None]:
# Same chart, annotated with the gap to demographic parity, p_left - 0.5.
# It is saved under the same file name as the previous call, since the name
# depends only on setting_name and num_iteration.
plot_pairwise_stats_new(
    max_salient_compare_dict,
    setting_name=SETTING_NAME,
    num_iteration=NUM_ITERATION,
    figsize=(12, 4),
    x_label_angle=0,
    annotate=True,
    demographic_parity=True
)

In [None]:
def plot_pairwise_stats_new_split(
    stat_dict,
    figsize=None,
    setting_name="unknown",
    num_iteration="unknown",
    confidence_interval_err=2,
    middle_band_width=0.1,
    x_label_angle=10,
    annotate=False,
    demographic_parity=True,
):
    """Plot the pairwise comparisons as one small figure per region.

    ``stat_dict`` maps a pair of groups to comparison counts::

        (group1, group2): [num_group1_is_picked, num_group2_is_picked]

    The eight plotted pairs are split into four regions of two consecutive
    pairs each (Race, Gender, Cross, Aggregate). For every pair two bars are
    drawn: the probability that group1 (left, red) and group2 (right, grey)
    is picked.

    The standard error is taken to be sqrt(p * (1 - p) / n), the usual
    estimate for a Bernoulli proportion. Error bars span
    +- ``confidence_interval_err`` standard errors, so the default of 2 gives
    roughly a 95% confidence interval.

    Each figure is saved as "<setting_name>_n=<num_iteration>_<region>.jpg"
    in the current working directory.

    Args:
        stat_dict: mapping of group pairs to the two pick counts (see above).
        figsize: size of each figure, forwarded to ``plt.subplots``.
        setting_name: name of the experiment setting; only used in the names
            of the saved figures.
        num_iteration: number of samples used (int or str); only used in the
            titles and in the names of the saved figures.
        confidence_interval_err: half-width of the error bars, in standard
            errors.
        middle_band_width: if not None, a dashed "Demographic Parity"
            reference line is drawn at 0.5. The numeric value itself is not
            used.
        x_label_angle: rotation of the x tick labels in degrees; may need to
            be increased for lengthy labels.
        annotate: if True, print both probabilities inside the bars and the
            gap above each pair.
        demographic_parity: selects the annotated gap: ``p_left - 0.5`` when
            True, the relative gap ``p_left / p_right - 1`` when False.
    """
    stat_dict, x_labels, prob1, prob2, total, x, y_err1, y_err2 = prepare_plot_vars(
        stat_dict, confidence_interval_err=confidence_interval_err
    )
    regions = ["Race", "Gender", "Cross", "Aggregate"]
    pairs_per_region = 2
    width = 0.4
    x = np.arange(pairs_per_region)
    # One annotation height for all figures, just above the tallest bar overall.
    gap_y = max(max(prob1), max(prob2)) + 0.05

    # The thousands separator only applies to numbers; a string such as the
    # default "unknown" is shown as is.
    try:
        n_text = f"{num_iteration:,}"
    except (ValueError, TypeError):
        n_text = str(num_iteration)
    if demographic_parity:
        title = f"n={n_text}; $\\Delta_{{0.5}} = p_{{left}} - 0.5$"
    else:
        title = f"n={n_text}; $\\Delta = p_{{left}}/p_{{right}} - 1$"

    for region_idx, region in enumerate(regions):
        # Region k owns pairs 2k and 2k + 1. Slicing from k (not 2k) would
        # make neighbouring regions overlap and show the wrong pairs.
        start = region_idx * pairs_per_region
        stop = start + pairs_per_region
        region_prob1 = prob1[start:stop]
        region_prob2 = prob2[start:stop]

        fig, ax = plt.subplots(figsize=figsize)
        ax.bar(
            x - width / 2,
            np.array(region_prob1),
            width=width,
            yerr=y_err1[start:stop],
            color="r",
        )
        ax.bar(
            x + width / 2,
            np.array(region_prob2),
            width=width,
            yerr=y_err2[start:stop],
            color="0.5",
        )
        if annotate:
            for j, (p_left, p_right) in enumerate(zip(region_prob1, region_prob2)):
                ax.text(
                    x[j] - width / 2,
                    0.1,
                    f"{100*(p_left):.0f}%",
                    color="white",
                    fontweight="bold",
                    fontsize=12,
                    ha="center",
                )
                ax.text(
                    x[j] + width / 2,
                    0.1,
                    f"{100*(p_right):.0f}%",
                    color="white",
                    fontweight="bold",
                    fontsize=12,
                    ha="center",
                )
                if demographic_parity:
                    parity_title = (
                        f"$\\Delta_{{0.5}}$\n"
                        f"$\\mathbf{{{100*(p_left - 0.5):+.0f}\\%}}$"
                    )
                else:
                    # A right-hand probability of 0 makes the ratio unbounded;
                    # show it as +inf instead of failing with ZeroDivisionError.
                    ratio = p_left / p_right if p_right else float("inf")
                    parity_title = (
                        f"$\\Delta$\n$\\mathbf{{{100*(ratio - 1):+.0f}\\%}}$"
                    )
                ax.text(
                    x[j],
                    gap_y,
                    parity_title,
                    color="k",
                    fontweight="bold",
                    fontsize=12,
                    ha="center",
                )

        if middle_band_width is not None:
            ax.axhline(
                0.5, color="k", linestyle="--", label="Demographic Parity", lw=0.5
            )
        plt.ylim(0, 1.1)
        # The multiplier in the label follows the actual error-bar width.
        ax.set_ylabel(f"$p(x) \\pm {confidence_interval_err} S.E.$", fontsize=20)
        sns.despine(offset=10)
        ax.set_axisbelow(False)
        plt.xticks(x, rotation=x_label_angle, fontsize=14)
        ax.set_xticklabels(x_labels[start:stop])
        plt.yticks(fontsize=16)
        ax.yaxis.grid(True, color="white", linewidth=1, zorder=3)
        plt.title(f"$\\mathbf{{{region}}}$\n{title}", fontsize=16, loc="right")
        plt.tight_layout()
        plt.savefig(f"{setting_name}_n={str(num_iteration)}_{region}.jpg")

In [None]:
# One narrow figure per region (Race, Gender, Cross, Aggregate) for the
# selected setting; demographic_parity is left at its default (True).
plot_pairwise_stats_new_split(
    max_salient_compare_dict,
    setting_name=SETTING_NAME,
    num_iteration=NUM_ITERATION,
    figsize=(3, 4),
    x_label_angle=0,
    annotate=True,
)

In [None]:
# Pick the pairwise-comparison rows of the fixed-height setting at 5000
# samples and key their stats by group pair.
NUM_ITERATION = 5000
SETTING_NAME = "wiki_fixed_height_intersect"
_matches_run = (df_results.num_iterations == NUM_ITERATION) & (
    df_results.setting == SETTING_NAME
)
_run_stats = df_results[_matches_run].set_index("group_name_list")["stats"]
max_salient_compare_dict = _run_stats.to_dict()
# Cell output: the selected {group pair: stats} mapping.
max_salient_compare_dict

In [None]:
# Deviation-from-0.5 bar chart (one bar per recorded pair) for the selected
# setting, using the single-bar plot_pairwise_stats layout.
plot_pairwise_stats(
    max_salient_compare_dict,
    setting_name=SETTING_NAME,
    num_iteration=NUM_ITERATION,
    figsize=(12, 4),
    x_label_angle=0,
)

In [None]:
def plot_dict_values(
    stat_dict,
    figsize=None,
    setting_name="unknown",
    num_iteration="unknown",
    confidence_interval_err=2,
    middle_band_width=0.1,
    x_label_angle=10,
):
    """Plot the share of the total count that each group received.

    ``stat_dict`` maps a group (a sequence whose first two items are rendered
    as "a-b" in the tick label) to a count. Each bar shows
    ``count / sum(counts)`` with error bars of
    +- ``confidence_interval_err`` * sqrt(p * (1 - p) / total). A red line
    marks the equal share ``1 / number_of_groups``.

    The tick labels are also printed, and the figure is saved as
    "<setting_name>_n=<num_iteration>.jpg" in the current working directory.

    Args:
        stat_dict: mapping of group to count.
        figsize: figure size forwarded to ``plt.subplots``.
        setting_name: name of the experiment setting; only used in the name
            of the saved figure.
        num_iteration: number of samples used (int or str); only used in the
            title and in the name of the saved figure.
        confidence_interval_err: half-width of the error bars, in standard
            errors.
        middle_band_width: unused; accepted so the signature matches the
            pairwise plotting functions.
        x_label_angle: rotation of the x tick labels in degrees.
    """
    x_labels = ["{}-{}".format(*group_name) for group_name in stat_dict.keys()]
    print(x_labels)
    values = np.array(list(stat_dict.values()))
    total = values.sum()
    prob = values / total
    y_err = [confidence_interval_err * math.sqrt(p * (1 - p) / total) for p in prob]

    fig, ax = plt.subplots(figsize=figsize)
    ax.bar(x_labels, prob, yerr=y_err)

    # Reference line at the share every group would get if all were equal.
    ax.plot(
        [-0.5, len(x_labels) - 0.5],
        np.full(2, 1 / len(x_labels)),
        "r",
        label="average",
    )

    plt.xticks(rotation=x_label_angle, fontsize=16)
    # The multiplier in the label follows the actual error-bar width.
    ax.set_ylabel(
        f"Probability $\\pm$ {confidence_interval_err} * error", fontsize=20
    )
    plt.ylim(0.0, 1.0)
    ax.yaxis.grid(True)
    plt.yticks(fontsize=16)
    plt.legend(fontsize=16)
    plt.title(f"Probabilities with {num_iteration} samples", fontsize=20)
    plt.tight_layout()
    plt.savefig(setting_name + "_n=" + str(num_iteration) + ".jpg")

In [None]:
# Switch to the setting used for the all-groups plot below.
NUM_ITERATION = 10
SETTING_NAME = "wiki_no_scaling_intersect_together"

# Cell output: the "stats" entry of the first row matching this setting and
# iteration count.
df_results[
    (df_results.num_iterations == NUM_ITERATION) & (df_results.setting == SETTING_NAME)
][["group_name_list", "stats"]].values[0][1]

In [None]:
# Cell output: for the first matching row, pair each entry of its
# group_name_list with the entry of its stats at the same position.
dict(
    zip(
        *df_results[
            (df_results.num_iterations == NUM_ITERATION)
            & (df_results.setting == SETTING_NAME)
        ][["group_name_list", "stats"]].values[0]
    )
)

In [None]:
# {group: stat} mapping built from the first row matching SETTING_NAME and
# NUM_ITERATION, by zipping that row's group_name_list with its stats.
max_salient_all_groups_dict = dict(
    zip(
        *df_results[
            (df_results.num_iterations == NUM_ITERATION)
            & (df_results.setting == SETTING_NAME)
        ][["group_name_list", "stats"]].values[0]
    )
)

In [None]:
# Cell output: the per-group mapping that is plotted next.
max_salient_all_groups_dict

In [None]:
# Bar chart of each group's share of the total, with the equal-share
# reference line drawn by plot_dict_values.
plot_dict_values(
    max_salient_all_groups_dict,
    setting_name=SETTING_NAME,
    num_iteration=NUM_ITERATION,
    figsize=(12, 4),
    x_label_angle=0,
)