# FairGroups Partitioning on Folktables Dataset

This notebook demonstrates how to use the `FairGroups` class to partition a continuous sensitive attribute into groups that maximize the variance of the fairness criterion $\Phi$ across groups.

## Dataset: Folktables
- **Sensitive attribute (S)**: Age
- **Outcome variable (Y)**: Income (binary: whether yearly income exceeds \$50,000)

## Overview
We load the preprocessed Folktables data as an age sample and an income sample, inspect how the probability of a positive outcome varies with age, and then use `FairGroups` to partition age into groups that maximize the variance of $\Phi$ across groups.
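To make the objective concrete, here is a minimal sketch of one way to score a candidate partition. It rests on two assumptions that this notebook does not confirm: that $\Phi$ for a group is the share of positive outcomes in that group, and that the variance is weighted by group size. The helper `between_group_variance` and the toy data are hypothetical and are not part of `fair_groups`.

```python
import numpy as np

def between_group_variance(s, y, edges):
    """Size-weighted variance of per-group positive rates (illustrative assumption)."""
    s = np.asarray(s, dtype=float)
    y = np.asarray(y, dtype=float)
    # Use interior edges only, so group labels run from 0 to n_groups - 1
    groups = np.digitize(s, edges[1:-1])
    n_groups = len(edges) - 1
    overall_rate = y.mean()
    variance = 0.0
    for g in range(n_groups):
        mask = groups == g
        if mask.any():
            weight = mask.mean()          # share of samples in group g
            group_rate = y[mask].mean()   # assumed Phi for group g
            variance += weight * (group_rate - overall_rate) ** 2
    return variance

# Toy data: two young people with a negative outcome, two older with a positive one
toy_s = np.array([20, 22, 40, 42])
toy_y = np.array([0, 0, 1, 1])

score_split_at_30 = between_group_variance(toy_s, toy_y, np.array([18, 30, 50]))
score_split_at_21 = between_group_variance(toy_s, toy_y, np.array([18, 21, 50]))
print(score_split_at_30, score_split_at_21)
```

Under these assumptions, the cut at 30 separates the two outcome levels cleanly and scores higher than the cut at 21, which is the kind of partition a variance-maximizing search would prefer.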

In [None]:
# Import necessary libraries
import sys

import numpy as np

# Add the parent directory to the import path so the local packages resolve
sys.path.append('..')
from fair_groups.partition_estimation import FairGroups
from fair_groups.fairness_metrics import get_conditional_positive_y_proba
from fair_groups.visualization import plot_partition, plot_partition_with_ci, plot_conditional_proba, plot_group_summary_statistics_table
from data.folktables_income_data import load_folktables_data

# Set random seed for reproducibility
np.random.seed(42)

## 1. Load Preprocessed Folktables Dataset

In [None]:
# Load Folktables dataset
age_sample, income_sample = load_folktables_data()

In [None]:
# Visualize the conditional probability of positive outcome given feature S
s_bins, y_s_proba = get_conditional_positive_y_proba(age_sample, income_sample, n_bins=30)
plot_conditional_proba(s_bins, y_s_proba, 'Age')
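
For readers who want to see what a binned conditional probability looks like in plain NumPy, the following sketch estimates $P(Y = 1 \mid S)$ on equal-width bins. It is an illustration only: `binned_positive_rate` is a hypothetical helper, and `get_conditional_positive_y_proba` may choose its bins or handle empty bins differently.

```python
import numpy as np

def binned_positive_rate(s, y, n_bins=30):
    """Share of positive outcomes per equal-width bin of s (illustrative sketch)."""
    s = np.asarray(s, dtype=float)
    y = np.asarray(y, dtype=float)
    edges = np.linspace(s.min(), s.max(), n_bins + 1)
    # Interior edges give bin indices 0..n_bins-1; the maximum lands in the last bin
    bin_index = np.digitize(s, edges[1:-1])
    counts = np.bincount(bin_index, minlength=n_bins)
    positives = np.bincount(bin_index, weights=y, minlength=n_bins)
    # Empty bins produce NaN rather than raising a warning
    with np.errstate(invalid="ignore", divide="ignore"):
        rates = positives / counts
    centers = (edges[:-1] + edges[1:]) / 2
    return centers, rates

# Tiny example: the outcome flips from 0 to 1 halfway through the range
centers, rates = binned_positive_rate([0, 1, 2, 3], [0, 0, 1, 1], n_bins=2)
print(centers, rates)
```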

## 2. Apply FairGroups Partitioning

In [None]:
# Set the FairGroups parameters
n_groups = 5
grid_size = 50  # Reduced for faster computation

print(f"\n=== Fitting FairGroups with {n_groups} groups ===")

# Initialize and fit FairGroups
fair_groups = FairGroups(n_groups=n_groups, grid_size=grid_size)
fair_groups.fit(age_sample, income_sample)

# Print a summary of the fitted model
print("\nDetailed model information:")
fair_groups.print()

## 3. Visualize Partition

In [None]:
print("\nDetailed analysis of the partition:")

# Retrieve the fitted partition and the fairness criterion value for each group
partition = fair_groups.partition
phi_by_group = fair_groups.phi_by_group

plot_group_summary_statistics_table(age_sample, income_sample, partition,
                                    phi_by_group, sensitive_var_name="Age")

In [None]:
# Visualize the partition, without and with confidence intervals
phi_by_group_ci = fair_groups.phi_by_group_ci

plot_partition(partition, phi_by_group, sensitive_var_name="Age")
plot_partition_with_ci(partition, phi_by_group_ci, sensitive_var_name="Age")

## 4. Predict Groups with FairGroups

In [None]:
# Test the predict method with new data
print("Testing predictions with new age values...")

# Create test ages
test_ages = np.array([18, 20, 24, 25, 27, 31, 32, 35, 45, 55, 65, 75])
predicted_groups = fair_groups.predict(test_ages)

print("\nPrediction Results:")
for age, group in zip(test_ages, predicted_groups):
    print(f"Age {age}: Group {group}")
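
As a rough mental model of the prediction step, assigning a value to a group amounts to finding the interval of the partition that contains it. The sketch below does this with `np.digitize` on made-up edges. The edge values, the helper `assign_groups`, and the boundary convention (lower edge inclusive, out-of-range values placed in the first or last group) are all assumptions, and `FairGroups.predict` may behave differently.

```python
import numpy as np

# Hypothetical partition edges for five age groups (not the fitted partition)
hypothetical_edges = np.array([17.0, 25.0, 32.0, 45.0, 60.0, 95.0])

def assign_groups(values, edges):
    """Map each value to the index of the interval [edges[i], edges[i + 1]) containing it."""
    values = np.asarray(values, dtype=float)
    # Interior edges only: values beyond the outer edges fall into the first or last group
    return np.digitize(values, edges[1:-1])

example_ages = np.array([18, 25, 31, 32, 75])
example_groups = assign_groups(example_ages, hypothetical_edges)
for age, group in zip(example_ages, example_groups):
    print(f"Age {age}: Group {group}")
```

With this convention, a value that sits exactly on an interior edge, such as 25 or 32 here, is assigned to the group that starts at that edge.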