In [1]:
# Automatically re-import edited project modules (e.g. `faim`) before each cell runs.
%load_ext autoreload
%autoreload 2

# Synthetic Data

## Generate

In [2]:
from faim.data_preparation.synthetic import NormalSyntheticGroupedDatasetBuilder
import numpy as np
from numpy.random import PCG64, Generator


# Fixed seed so the synthetic dataset is reproducible across runs.
random_generator = Generator(PCG64(4))

# Two equally sized groups. Each group's (truth, prediction) score pairs use the
# group's own mean vector and the same 0.8 correlation; the disadvantaged
# group's means sit below the privileged group's.
group_names = ["privileged", "disadvantaged"]
group_size = 50000
means_by_group = [np.array([1, 2]), np.array([-1, -3])]
# One (separate) correlation matrix per group, identical in value.
correlations_by_group = [np.array([[1, 0.8], [0.8, 1]]) for _group in group_names]

synth_data_builder = NormalSyntheticGroupedDatasetBuilder(
    group_names=group_names,
    n_by_group=[group_size for _group in group_names],
    truth_prediction_means_by_group=means_by_group,
    truth_prediction_correlation_matrixs_by_group=correlations_by_group,
    random_generator=random_generator,
)
synth_data = synth_data_builder.build()
synth_data.head()

Unnamed: 0,uuid,group,true_score,pred_score,true_label,pred_label
37946,61690248984788657218284987529151945635,0,0.897985,1.856003,1,1
4589,313863767218095658908289679367809502023,1,-0.221229,-2.008019,0,0
3166,54831838391621729574676657023165588603,0,2.361094,3.257432,1,1
17204,227503409615892428828483690159797022919,1,-0.124138,-2.095224,0,0
46899,61559048902492163868108324151274561739,1,-1.730384,-3.377135,0,0


## Figure 1

In [None]:
from copy import deepcopy
import seaborn as sns
import matplotlib.pyplot as plt

# `sns.color_palette` only returns a palette object; `set_palette` applies it.
sns.set_palette("tab10")

# Plot with readable group names instead of the integer group codes.
# `DataFrame.copy()` is a deep copy, so the source frame stays untouched.
plot_data = synth_data.copy()
plot_data["group"] = plot_data["group"].map(lambda idx: group_names[idx])

fig, ax = plt.subplots()
# Dashed: predicted-score density per group; solid: true-score density per group.
sns.kdeplot(data=plot_data, x="pred_score", hue="group", linestyle="--", ax=ax)
sns.kdeplot(data=plot_data, x="true_score", hue="group", linestyle="-", ax=ax)
# Both score kinds share the x axis, so label it generically rather than
# inheriting "true_score" from the last kdeplot call.
_ = ax.set(
    xlabel="score",
    ylabel="probability density",
    xlim=(-12, 12),
    ylim=(0, 0.25),
)

In [None]:
# Remove the legend and re-display the figure. `get_legend()` returns None once
# the legend is gone, so guard the call to keep this cell safe to re-run.
figure1_legend = ax.get_legend()
if figure1_legend is not None:
    figure1_legend.remove()
fig

In [None]:
from pathlib import Path

# savefig does not create missing directories; make sure the target exists.
figure_dir = Path("figures")
figure_dir.mkdir(parents=True, exist_ok=True)
fig.savefig(figure_dir / "figure1.svg", format="svg")

## Figure 2

In [None]:
# TODO: Figure 2 is not produced yet. Until then, show a numeric summary of the
# synthetic data rather than rendering the full 100k-row frame.
synth_data.describe()

# Synthetic Data from the Paper

## Compute Fair Scores with FAIM

In [None]:
import pandas as pd
import numpy as np
from pathlib import Path

# Load the pre-generated two-group synthetic dataset used in the paper.
# NOTE(review): an earlier comment said truncated data is removed here, but the
# file is loaded as-is. FAIM produces `pred_score_truncated` itself (it appears
# in its results below); confirm the CSV carries no stale truncated column.
synth_data_from_paper_filepath = Path("../prepared-data/synthetic/2groups/2022-01-12/dataset.csv")
data = pd.read_csv(synth_data_from_paper_filepath)

# FAIM configuration.
group_names = {0: "advantaged", 1: "disadvantaged"}  # keyed by integer group code
pred_score_column = "pred_score"
score_stepsize = 0.1
# One theta vector per group; the plotting section below labels [1, 0, 0] the
# calibration condition.
thetas = {0: np.array([1, 0, 0]), 1: np.array([1, 0, 0])}
optimal_transport_regularization = 0.001
plot_dir = Path("../results/synthetic/2groups/2022-01-12-notebook/")

# Fail early with a clear message instead of deep inside FAIM.
assert pred_score_column in data.columns, (
    f"{synth_data_from_paper_filepath} has no {pred_score_column!r} column"
)

In [None]:
from faim.algorithm.faim import FairInterpolationMethod

# Run FAIM once with the thetas configured above (the calibration condition).
faim_kwargs = dict(
    rawData=data,
    group_names=group_names,
    pred_score_column=pred_score_column,
    score_stepsize=score_stepsize,
    thetas=thetas,
    regForOT=optimal_transport_regularization,
    plot_dir=plot_dir,
    plot=False,  # presumably disables FAIM's own plotting into plot_dir — confirm
)
fair_interpolation_method = FairInterpolationMethod(**faim_kwargs)
results = fair_interpolation_method.run()

## Decision Boundary

In [None]:
# Predicted label against the truncated predicted score; the switch from 0 to 1
# is where `find_boundary` (next cells) places the threshold.
# Use a dedicated name so Figure 1's `ax` is not overwritten.
boundary_ax = results.plot(
    x="pred_score_truncated",
    y="predictedLabel",
    style=".",
    markersize=2,
    figsize=(3, 1),
    legend=False,
)
_ = boundary_ax.set_ylabel("predictedLabel")

In [None]:
def find_boundary(
    results: pd.DataFrame,
    score_column: str = "pred_score_truncated",
    label_column: str = "predictedLabel",
) -> float:
    """Return the score threshold separating predicted label 0 from label 1.

    The threshold is the midpoint between the highest score predicted as 0 and
    the lowest score predicted as 1. If the classes overlap (highest 0-score
    above lowest 1-score), no single threshold separates them and the midpoint
    falls inside the overlap.

    Args:
        results: Frame holding a score column and a 0/1 predicted-label column.
        score_column: Name of the score column.
        label_column: Name of the predicted-label column.

    Returns:
        The midpoint threshold.

    Raises:
        ValueError: If no row has label 0 or no row has label 1, in which case
            the midpoint would silently be NaN.
    """
    labels = results[label_column]
    scores = results[score_column]
    negative_scores = scores[labels == 0]
    positive_scores = scores[labels == 1]
    if negative_scores.empty or positive_scores.empty:
        raise ValueError(
            f"find_boundary needs rows with both {label_column} == 0 and "
            f"{label_column} == 1"
        )
    return float((negative_scores.max() + positive_scores.min()) / 2)
    

In [None]:
# Score threshold between predicted labels 0 and 1 for the calibration run.
boundary = find_boundary(results)
boundary

## Results Plots - Calibration Condition

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# `sns.color_palette` only returns a palette; `set_palette` actually applies it.
sns.set_palette("tab10")
# NOTE(review): this rebinds `fig`, which the Figure 1 cells also use; re-running
# the Figure 1 save cell after this one would save this figure instead.
fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True, figsize=(3, 5))

# Top panel: per-group densities of the truncated predicted score (dashed) and
# of FAIM's fair score (dotted).
sns.kdeplot(data=results, x="pred_score_truncated", hue="group", linestyle="--", ax=ax1)
sns.kdeplot(data=results, x="fairScore", hue="group", linestyle=":", ax=ax1)

# Bottom panel: fair score vs. truncated predicted score for each row, with the
# identity line for reference. Drawn explicitly on ax2 rather than via the
# pyplot "current axes". The line spans [0, 1] — presumably the score range; confirm.
sns.scatterplot(data=results, x="pred_score_truncated", y="fairScore", hue="group", s=8, ax=ax2)
ax2.plot([0, 1], [0, 1], "k-", linewidth=1)

# Drop the legends; use a dedicated loop name so Figure 1's `ax` is not clobbered.
for panel in (ax1, ax2):
    panel_legend = panel.get_legend()
    if panel_legend is not None:
        panel_legend.remove()

## All Conditions

In [None]:
# Re-run FAIM for each theta condition (same per-group theta for both groups).
thetas_by_condition = [
    {0: np.array([1, 0, 0]), 1: np.array([1, 0, 0])},
    {0: np.array([0, 1, 0]), 1: np.array([0, 1, 0])},
    {0: np.array([0, 0, 1]), 1: np.array([0, 0, 1])},
    {0: np.array([1, 1, 1]), 1: np.array([1, 1, 1])},
]
# Each condition's output directory is named after group 0's and group 1's
# thetas, concatenated.
conditions_plot_root = Path("../results/synthetic/2groups/2022-01-12-notebook")
plot_dir_by_condition = [
    conditions_plot_root / "1,0,0,1,0,0",
    conditions_plot_root / "0,1,0,0,1,0",
    conditions_plot_root / "0,0,1,0,0,1",
    conditions_plot_root / "1,1,1,1,1,1",
]
assert len(thetas_by_condition) == len(plot_dir_by_condition)

# Dedicated loop names avoid overwriting the single-run config globals
# (`thetas`, `plot_dir`) and `fair_interpolation_method` from earlier cells.
results_by_condition = []
for condition_thetas, condition_plot_dir in zip(thetas_by_condition, plot_dir_by_condition):
    condition_faim = FairInterpolationMethod(
        rawData=data,
        group_names=group_names,
        pred_score_column=pred_score_column,
        score_stepsize=score_stepsize,
        thetas=condition_thetas,
        regForOT=optimal_transport_regularization,
        plot_dir=condition_plot_dir,
        plot=False,
    )
    results_by_condition.append(condition_faim.run())