In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import seaborn as sns
# Plot styling: switch off all four axis spines, then set the on-screen and
# saved-figure resolutions before handing everything to seaborn.
custom_params = dict.fromkeys(
    ("axes.spines.right", "axes.spines.top", "axes.spines.left", "axes.spines.bottom"),
    False,
)
custom_params.update({"figure.dpi": 700, 'savefig.dpi': 300})

sns.set_theme(style = "whitegrid", rc = custom_params, font_scale = 1.5)

In [None]:
# Observable
# Suffix `_g` refers to group g, `_ng` to the other group (plotted as "not g").

# presumably the observed per-group means -- TODO confirm against the paper
muo_g, muo_ng = 0.5, 0.

# Per-group alpha; B_group divides by sqrt((1 - alpha) * alpha), so both values
# must lie strictly between 0 and 1.
alpha_g, alpha_ng = 0.7, 0.8

# Weight given to group g when the population averages are computed
# (the other group gets 1 - ratio).
ratio = 0.25

In [None]:
# Population-level averages of the observables.
# alpha is the ratio-weighted mean of the two group alphas.
alpha = (
    alpha_g * ratio
    + alpha_ng * (1 - ratio)
)

# muo averages the group means, each weighted by its share of alpha.
muo_share_g = muo_g * ratio * alpha_g / alpha
muo_share_ng = muo_ng * (1 - ratio) * alpha_ng / alpha
muo = muo_share_g + muo_share_ng

alpha, muo

In [None]:
# Unobservables
# sigma_* is passed as `sigma` to B_group; the Theorem 3 / Theorem 4 helpers
# assert that both groups share the same value.
sigma_g, sigma_ng = 0.5, 0.5

# std_* enters the losses only through the additive std ** 2 term.
std_g, std_ng = 1, 1

In [None]:
# Theorem 1
def L_pop(alpha, rho, sigma, muo_g, muo, std):
    """Loss term ``(B_group + muo_g - muo) ** 2 + std ** 2``.

    Presumably the loss of the population-level model evaluated on one group
    (Theorem 1) -- confirm against the accompanying paper.

    Parameters
    ----------
    alpha, rho, sigma : forwarded unchanged to ``B_group``.
    muo_g : observed mean of the group under consideration.
    muo : population-level observed mean.
    std : additive noise scale; contributes ``std ** 2``.
    """
    shifted_bias = B_group(alpha, rho, sigma) + muo_g - muo
    return shifted_bias ** 2 + std ** 2

def L_group(alpha, rho, sigma, std):
    """Loss term ``B_group ** 2 + std ** 2``.

    Presumably the loss of the group-specific model on its own group
    (Theorem 1) -- confirm against the accompanying paper.

    Parameters
    ----------
    alpha, rho, sigma : forwarded unchanged to ``B_group``.
    std : additive noise scale; contributes ``std ** 2``.
    """
    bias = B_group(alpha, rho, sigma)
    return bias ** 2 + std ** 2

def B_group(alpha, rho, sigma):
    return -  rho * sigma / np.sqrt((1 - alpha) * alpha)

In [None]:
# Theorem 2
def theorem2_empirical(alpha, rho, sigma, muo_g, muo, std):
    """Directly evaluate both losses and report whether ``L_pop < L_group``.

    Serves as the brute-force reference that ``theorem2_theory`` is compared
    against in the grid sweep.
    """
    loss_pop = L_pop(alpha, rho, sigma, muo_g, muo, std)
    loss_group = L_group(alpha, rho, sigma, std)
    return loss_pop < loss_group

def theorem2_theory(alpha, rho, sigma, muo_g, muo, std):
    """Closed-form condition for ``L_pop < L_group`` (Theorem 2).

    Compares ``q1 = rho / sqrt(alpha * (1 - alpha))`` with
    ``q2 = (muo_g - muo) / (2 * sigma)``:

    * ``q2 < 0``  -> condition holds when ``q1 < q2``
    * ``q2 > 0``  -> condition holds when ``q1 > q2``
    * ``q2 == 0`` -> never holds: with ``muo_g == muo`` the expressions in
      ``L_pop`` and ``L_group`` are identical, so the strict inequality is False.

    Parameters
    ----------
    alpha : value strictly between 0 and 1.
    rho : scalar or array; the result has the same shape.
    sigma : non-zero scalar.
    muo_g, muo : scalar group and population observed means.
    std : unused; kept so the signature mirrors ``theorem2_empirical``.

    Returns
    -------
    Boolean (numpy scalar or array shaped like ``rho``).
    """
    q1 = rho / np.sqrt((1 - alpha) * alpha)
    q2 = (muo_g - muo) / (2 * sigma)
    if q2 == 0:
        # Degenerate case: the `else` branch below would wrongly report
        # True whenever rho > 0. `[()]` turns a 0-d result into a scalar.
        return np.zeros(np.shape(q1), dtype=bool)[()]
    return q1 < q2 if q2 < 0 else q1 > q2

In [None]:
def compute_deltas(alpha_g, rho_g, sigma_g, muo_g, muo, std_g, alpha_ng, rho_ng, sigma_ng, muo_ng, std_ng):
    """Return the between-group loss gaps ``(delta_group, delta_pop)``.

    ``delta_group`` is ``L_group`` of group g minus ``L_group`` of the other
    group; ``delta_pop`` is the same difference computed with ``L_pop``.
    Both are signed (group g minus the other group).
    """
    group_loss_g = L_group(alpha_g, rho_g, sigma_g, std_g)
    group_loss_ng = L_group(alpha_ng, rho_ng, sigma_ng, std_ng)

    pop_loss_g = L_pop(alpha_g, rho_g, sigma_g, muo_g, muo, std_g)
    pop_loss_ng = L_pop(alpha_ng, rho_ng, sigma_ng, muo_ng, muo, std_ng)

    return group_loss_g - group_loss_ng, pop_loss_g - pop_loss_ng

# Theorem 3
def theorem3_empirical(alpha_g, rho_g, sigma_g, muo_g, muo, std_g, alpha_ng, rho_ng, sigma_ng, muo_ng, std_ng):
    """Brute-force check of ``0 < delta_pop < delta_group``.

    Both gaps come from ``compute_deltas``.

    Raises
    ------
    AssertionError
        If ``sigma_g != sigma_ng``.
    """
    assert sigma_g == sigma_ng, "Theorem assume equal unobserved variance"
    gap_group, gap_pop = compute_deltas(
        alpha_g, rho_g, sigma_g, muo_g, muo, std_g,
        alpha_ng, rho_ng, sigma_ng, muo_ng, std_ng,
    )
    pop_gap_is_smaller = gap_group > gap_pop
    pop_gap_is_positive = gap_pop > 0
    return pop_gap_is_smaller & pop_gap_is_positive

def theorem3_theory(r_g, alpha_g, rho_g, sigma_g, muo_g, muo, std_g, alpha_ng, rho_ng, sigma_ng, muo_ng, std_ng):
    """Closed-form condition of Theorem 3.

    Three inequalities are evaluated; the result is True when the first
    holds and the second and third are either both true or both false.
    ``muo`` is used only in the precondition; ``std_g`` and ``std_ng`` are
    not used at all.

    Parameters
    ----------
    r_g : weight of group g (the other group gets ``1 - r_g``).
    alpha_g, alpha_ng : per-group alpha, strictly between 0 and 1.
    rho_g, rho_ng : per-group rho.
    sigma_g, sigma_ng : per-group sigma; must be equal.
    muo_g, muo_ng : per-group observed means.
    muo : population observed mean; must be smaller than ``muo_g``.

    Raises
    ------
    AssertionError
        If ``sigma_g != sigma_ng`` or ``muo_g <= muo``.
    """
    assert sigma_g == sigma_ng, "Theorem assume equal unobserved variance"
    assert muo_g > muo, "Theorem only applies for larger observed mean"

    def _f(a, r, an):
        # Coefficient multiplying rho * sigma in the first inequality.
        return (2 * an * (1 - r)) / np.sqrt(a * (1 - a)) - np.sqrt((1 - a) / a) * (an * (1 - r) - a * r)

    def _d(a, r, an):
        # Coefficient multiplying rho * sigma in the second inequality.
        return (a * r + an * (1 - r)) / np.sqrt(a * (1 - a)) - np.sqrt((1 - a) / a) * (an * (1 - r) - a * r)

    def _e(a):
        # Coefficient multiplying rho * sigma in the third inequality.
        return np.sqrt(a / (1 - a))

    # Difference of observed means, shifted by each group's rho/sigma term.
    diffmu = muo_g - muo_ng - rho_g * np.sqrt((1 - alpha_g) / alpha_g) * sigma_g + rho_ng * np.sqrt((1 - alpha_ng) / alpha_ng) * sigma_ng

    # Right-hand side shared by the first two inequalities.
    weighted_diff = ((1 - r_g) * alpha_ng - r_g * alpha_g) * diffmu

    cond1 = (rho_g * sigma_g * _f(alpha_g, r_g, alpha_ng) + rho_ng * sigma_ng * _f(alpha_ng, 1 - r_g, alpha_g)) > weighted_diff
    cond2 = (rho_g * sigma_g * _d(alpha_g, r_g, alpha_ng) + rho_ng * sigma_ng * _d(alpha_ng, 1 - r_g, alpha_g)) < weighted_diff
    cond3 = (rho_g * sigma_g * _e(alpha_g) - rho_ng * sigma_ng * _e(alpha_ng)) < diffmu

    # cond2 and cond3 must agree (both true or both false).
    return cond1 & (cond2 == cond3)

In [None]:
def theorem4(alpha_g, rho_g, sigma_g, muo_g, muo, std_g, alpha_ng, rho_ng, sigma_ng, muo_ng, std_ng):
    """Report whether ``|delta_pop| < |delta_group|``.

    Both gaps come from ``compute_deltas``.

    Raises
    ------
    AssertionError
        If ``sigma_g != sigma_ng``.
    """
    assert sigma_g == sigma_ng, "Theorem assume equal unobserved variance"
    gap_group, gap_pop = compute_deltas(
        alpha_g, rho_g, sigma_g, muo_g, muo, std_g,
        alpha_ng, rho_ng, sigma_ng, muo_ng, std_ng,
    )
    return np.abs(gap_pop) < np.abs(gap_group)

In [None]:
# See when the different constraints are satisfied with rho varying
linspace = np.linspace(-1, 1, num = 100)
grid_shape = (len(linspace), len(linspace))

# Results are accumulated in plain numpy arrays and wrapped into DataFrames once
# at the end. Writing through `frame[col][row] = value` is chained assignment:
# it warns on recent pandas and does not modify the frame under Copy-on-Write.
# Layout: rows follow rho_ng, columns follow rho_g.
differences_values = np.zeros(grid_shape, dtype = float)
mask_values = np.zeros(grid_shape, dtype = bool)
mask_comp_values = np.zeros(grid_shape, dtype = bool)

for col, rho_g in enumerate(linspace):
    # The Theorem 2 check for group g does not involve rho_ng, so run it once per rho_g.
    assert theorem2_empirical(alpha_g, rho_g, sigma_g, muo_g, muo, std_g) == theorem2_theory(alpha_g, rho_g, sigma_g, muo_g, muo, std_g), "Theorem 2 not verified for group g"

    for row, rho_ng in enumerate(linspace):
        delta_group, delta_pop = compute_deltas(alpha_g, rho_g, sigma_g, muo_g, muo, std_g, alpha_ng, rho_ng, sigma_ng, muo_ng, std_ng)
        # Evaluate the closed-form condition once; it feeds both the mask and the check below.
        theorem3_holds = theorem3_theory(ratio, alpha_g, rho_g, sigma_g, muo_g, muo, std_g, alpha_ng, rho_ng, sigma_ng, muo_ng, std_ng)

        differences_values[row, col] = delta_pop - delta_group
        mask_values[row, col] = theorem3_holds
        mask_comp_values[row, col] = theorem4(alpha_g, rho_g, sigma_g, muo_g, muo, std_g, alpha_ng, rho_ng, sigma_ng, muo_ng, std_ng)

        assert theorem2_empirical(alpha_ng, rho_ng, sigma_ng, muo_ng, muo, std_ng) == theorem2_theory(alpha_ng, rho_ng, sigma_ng, muo_ng, muo, std_ng), "Theorem 2 not verified for group ng"
        assert theorem3_empirical(alpha_g, rho_g, sigma_g, muo_g, muo, std_g, alpha_ng, rho_ng, sigma_ng, muo_ng, std_ng) == (theorem3_holds & (delta_pop > 0)), "Theorem 3 not verified"

differences = pd.DataFrame(differences_values, index = linspace, columns = linspace)
mask = pd.DataFrame(mask_values, index = linspace, columns = linspace)
mask_comp = pd.DataFrame(mask_comp_values, index = linspace, columns = linspace)

In [None]:
# Heat map of delta_pop - delta_group over the (rho_g, rho_ng) grid:
# columns of `differences` run along x, its index along y.
rho_g_axis = differences.columns
rho_ng_axis = differences.index

plt.pcolor(rho_g_axis, rho_ng_axis, differences, cmap='RdBu')
clb = plt.colorbar()
clb.ax.set_title(r'$\Delta^{pop} - \Delta^{group}$')
clb.ax.set_ylabel('Difference in fairness gap')

# Fully transparent filled contours: only the hatching is visible, marking
# the cells where each mask equals 1.
region_hatches = ('xxx', '...')
for region, hatch in zip((mask == 1, mask_comp == 1), region_hatches):
    plt.contourf(rho_g_axis, rho_ng_axis, region, 1, hatches=['', hatch], alpha = 0)

# Proxy rectangles so the hatch patterns get legend entries.
legend_handles = [
    plt.Rectangle((0,0),1,1, hatch = hatch, edgecolor = 'k', fill = False, alpha = 0.5)
    for hatch in region_hatches
]
legend_labels = [
    r'0 < $\Delta^{pop} < \Delta^{group}$ (Theorem 3)',
    r'$|\Delta^{pop}| < |\Delta^{group}|$',
]
plt.legend(legend_handles, legend_labels, loc='upper left', bbox_to_anchor=(1.4, 1.04), frameon=False,
           handletextpad = 0.5, handlelength = 1.0, columnspacing = -0.5)

# Dashed guides through rho = 0 on both axes.
plt.axhline(y=0, color='k', alpha = 0.5, linestyle='--')
plt.axvline(x=0, color='k', alpha = 0.5, linestyle='--')

ax = plt.gca()
plt.xlabel(r'$\rho_g$')
plt.ylabel(r'$\rho_{\neg g}$')