# ⚖️ Fairness and Bias Analysis

Compute and visualize performance differences across simulated client groups, which stand in for demographic slices. Useful for research reports and for demonstrating responsible ML practices to reviewers.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from polyscale_fl.datasets import get_synthetic_client_data
from polyscale_fl.aggregator.aggregator_node import AggregatorNode
from polyscale_fl.client.client_node import ClientNode
from polyscale_fl.ipfs.ipfs_client import IPFSClient
from polyscale_fl.chain.chain_stub import ChainStub


## 1. Set up a toy FL run with 4 client groups (simulating demographic groups)

In [None]:
# Create heterogeneous synthetic data per client (4 clients, 400 samples each).
clients_data = get_synthetic_client_data(
    n_clients=4, samples_per_client=400, n_features=20, hetero=True, seed=123
)

# Model configuration passed to the aggregator and to every client.
# NOTE(review): input_dim presumably has to match n_features above (both are
# 20) -- keep the two in sync if either is changed.
model_cfg = {"input_dim": 20, "hidden_dim": 64, "output_dim": 2}

ipfs = IPFSClient()

# chain stub for recording metrics
chain = ChainStub()

# aggregator with simple FedAvg
aggregator = AggregatorNode(model_cfg=model_cfg, ipfs_client=ipfs, chain_client=chain)

# One ClientNode per synthetic dataset, all sharing the aggregator's p2p
# object, the chain stub and the IPFS client.
client_nodes = {}
for cid, data in clients_data.items():
    client_nodes[cid] = ClientNode(cid, model_cfg, data, aggregator.p2p, chain, ipfs_client=ipfs)

# register in p2p stub: every client's message handler first, then the
# aggregator's own handler, and finally mark which id is the aggregator.
for cid, client in client_nodes.items():
    aggregator.p2p.register(cid, client._on_message)
aggregator.p2p.register(aggregator.aggregator_id, aggregator._on_message)
aggregator.p2p.set_aggregator(aggregator.aggregator_id)


## 2. Run a few rounds and collect per-client accuracy

In [None]:
from polyscale_fl.aggregator.evaluator import eval_state_on_sample

# Number of federated rounds to run.
rounds = 3

# Per-client list of validation accuracies, appended once per completed round.
# NOTE: a round whose aggregation returns None is skipped entirely, so an
# entry's list index is not guaranteed to equal its round number.
acc_history = {cid: [] for cid in client_nodes}

for r in range(rounds):
    # each client trains one epoch
    for cid, client in client_nodes.items():
        client.local_train_and_prepare_update(epochs=1)

    # aggregator aggregates
    res = aggregator.collect_and_aggregate()
    if res is None:
        # Nothing was aggregated this round; skip evaluation.
        continue

    # evaluate each client on its own validation
    for cid, client in client_nodes.items():
        X_val = client.dataset["X_val"]
        y_val = client.dataset["y_val"]
        # Only the accuracy is recorded; the loss is returned but not used.
        _loss, acc = eval_state_on_sample(aggregator.get_global_state(), [(X_val, y_val)])
        acc_history[cid].append(acc)
        print(f"Round {r+1}, {cid} acc={acc:.3f}")


## 3. Plot fairness across groups

In [None]:
# Draw one accuracy-vs-round line per client on a single 8x5 figure.
fig, ax = plt.subplots(figsize=(8, 5))
for client_id, accuracies in acc_history.items():
    # x positions are 1-based and count the recorded evaluations.
    round_axis = range(1, len(accuracies) + 1)
    ax.plot(round_axis, accuracies, marker='o', label=client_id)
ax.set_xlabel('Round')
ax.set_ylabel('Validation Accuracy')
ax.set_title('Per-client accuracy across rounds')
ax.legend()
plt.show()