# Secure Aggregation in Federated Learning

## Introduction

In federated learning, multiple clients collaborate to train a shared model without revealing their private data. However, even sharing model updates can leak sensitive information about local datasets. **Secure aggregation** solves this by allowing a server to compute the sum of client updates without seeing any individual update.

### The Core Idea
- Each client adds a carefully crafted "mask" to their true update before sending it to the server
- These masks are designed to cancel out perfectly when all updates are summed
- The server gets the correct aggregate (sum) but cannot recover individual updates
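
As a minimal sketch of this idea, consider just two clients who already share one mask. The toy values and the names `x_a`, `x_b`, and `mask` are illustrative assumptions, not part of the protocol built later in this notebook.

```python
import numpy as np

rng = np.random.default_rng(0)

# Two clients with private updates (toy values).
x_a = np.array([1.0, 2.0, 3.0])
x_b = np.array([10.0, 20.0, 30.0])

# A mask both clients know but the server does not.
mask = rng.standard_normal(3)

# Client A adds the mask, client B subtracts it.
y_a = x_a + mask
y_b = x_b - mask

# The server only sees y_a and y_b; the mask drops out of their sum.
aggregate = y_a + y_b
print(aggregate)
```

With only two clients, each one can work out the other's update from the aggregate, so the technique only becomes useful with more participants.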

### What We'll Demonstrate
1. **Pairwise masking protocol**: How clients use shared secrets to create canceling masks
2. **Privacy preservation**: Individual masked updates are dominated by the mask, so the server cannot read a client's true update from what it receives
3. **Correctness**: The server still gets the exact sum of all true updates
4. **Robustness challenges**: What happens when clients drop out (spoiler: it breaks!)

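This is a simplified educational implementation. It masks floating-point vectors with Gaussian noise, whereas real-world protocols (such as Bonawitz et al., 2017, from Google) work in modular integer arithmetic, handle dropouts using secret sharing, and are much more complex.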


In [32]:
import numpy as np, hashlib, random
np.random.seed(0); random.seed(0)

### 1) Problem Setup: Simulating Client Updates

**What we're modeling**: In real federated learning, each client computes a model update (gradient or weight changes) from their local data. Here we simulate this by generating random vectors to represent these updates.

**Key insight**: Each update vector contains sensitive information about the client's data. Our goal is to compute their sum without revealing individual updates.

We'll treat each model update as a 1-D vector (imagine flattened neural network weights).

In [33]:
N = 20  # number of clients
D = 2000  # number of model parameters

true_updates = [np.random.standard_normal(D).astype(np.float32) for _ in range(N)]

In [34]:
np.shape(true_updates)

(20, 2000)

### 2) Pseudo-random generator (PRG) from a seed

We’ll derive a deterministic “random” vector from a shared seed.

In [35]:
def prg(seed_bytes: bytes, d: int) -> np.ndarray:
    '''Expand a seed into a deterministic length-d float32 vector (demo only, not a cryptographic PRG).'''
    # Hash the seed and keep the first 8 bytes as a 64-bit integer seed
    s = int.from_bytes(hashlib.sha256(seed_bytes).digest()[:8], 'big')
    rng = np.random.default_rng(s)
    return rng.standard_normal(d).astype(np.float32)

**Key insight**: This PRG function is the foundation of our protocol. When two clients share the same seed, they can both generate the exact same "random" vector without communicating. This shared randomness will become the masks that cancel out.
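
A quick way to check this property is to call the PRG twice with the same seed and once with a different seed. The sketch below redefines `prg` so it runs on its own; the seed strings are hypothetical placeholders. Note that NumPy's generator is not cryptographically secure, so this only illustrates determinism, not security.

```python
import hashlib
import numpy as np

def prg(seed_bytes: bytes, d: int) -> np.ndarray:
    # Same construction as in the notebook: hash the seed, then drive a NumPy RNG.
    s = int.from_bytes(hashlib.sha256(seed_bytes).digest()[:8], "big")
    return np.random.default_rng(s).standard_normal(d).astype(np.float32)

shared_seed = b"hypothetical shared secret"

vec_i = prg(shared_seed, 5)               # what client i derives
vec_j = prg(shared_seed, 5)               # what client j derives independently
vec_other = prg(b"a different seed", 5)   # a party with another seed

same = np.array_equal(vec_i, vec_j)
different = not np.array_equal(vec_i, vec_other)
print(same, different)
```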


### 3) Pairwise-mask secure aggregation (simplified)

Idea: Every client pair (i,j) shares a secret seed s_ij. <br>
Client i adds PRG(s_ij) if i<j, and subtracts it if i>j. <br>
In the sum over all clients, each pair’s masks cancel.

**Understanding the Pairwise Masking Protocol**:

1. **Setup**: Every pair of clients (i,j) shares a secret seed s_ij through secure key exchange
2. **Masking rule**: Client i adds PRG(s_ij) if i < j, subtracts if i > j  
3. **Cancellation**: In the final sum, each mask appears as both +PRG(s_ij) and -PRG(s_ij), so they cancel
4. **Result**: Server gets the true sum, but individual updates are hidden by random noise
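
The cancellation can also be written out explicitly. Here $x_i$ denotes client $i$'s true update and $y_i$ the masked vector it sends (symbols introduced only for this derivation), with $s_{ab}$ the seed shared by the pair $a < b$:

$$
y_i = x_i + \sum_{j > i} \mathrm{PRG}(s_{ij}) - \sum_{j < i} \mathrm{PRG}(s_{ji})
$$

Summing over all clients, every pair $a < b$ contributes $+\mathrm{PRG}(s_{ab})$ once (from client $a$) and $-\mathrm{PRG}(s_{ab})$ once (from client $b$):

$$
\sum_{i} y_i = \sum_{i} x_i + \sum_{a < b} \mathrm{PRG}(s_{ab}) - \sum_{a < b} \mathrm{PRG}(s_{ab}) = \sum_{i} x_i
$$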

Let's implement this step by step:


In [36]:
# 3A) Build symmetric pairwise seeds (simulate ECDH; here just random bytes)
pair_seeds = {}
for i in range(N):
    for j in range(i+1, N):
        pair_seeds[i, j] = np.random.bytes(32)


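def client_mask(i: int, d: int) -> np.ndarray:
    '''Sum of client i's pairwise masks: +PRG(s_ij) for j > i, -PRG(s_ij) for j < i.'''
    m = np.zeros(d, dtype=np.float32)
    for j in range(N):
        if j == i:
            continue
        # Seeds are stored once per unordered pair, keyed as (smaller, larger)
        seed = pair_seeds[min(i, j), max(i, j)]
        vec = prg(seed, d)
        m += vec if i < j else -vec
    return m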


# 3B) Simulate client updates with masks
masked_updates = [true_updates[i] + client_mask(i, D) for i in range(N)]


# 3C) Server aggregates masked updates
server_sum = np.sum(masked_updates, axis=0)

# Ground truth sum (what we *want* the server to learn)
true_sum = np.sum(true_updates, axis=0)

print("L2 error between server_sum and true_sum:", np.linalg.norm(server_sum - true_sum))

L2 error between server_sum and true_sum: 9.006238e-05


**What just happened?**

1. **Step 3A**: We simulated the key exchange by generating random seeds for each client pair
2. **Step 3B**: Each client computed their mask by adding/subtracting PRG outputs from all shared seeds  
3. **Step 3C**: The server summed all masked updates and recovered the true sum up to floating-point rounding.

The L2 error is tiny (about 1e-4 across 2,000 float32 coordinates) compared with the norm of the vectors involved, which shows that the masks canceled up to rounding error.


✅ **Expectation: L2 error on the order of 1e-4 (floating-point noise). Masks cancel exactly in theory.**
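
The rounding error comes from doing the masking in floating point. A minimal sketch of the integer alternative follows, assuming updates are quantized to fixed point and masks are uniform integers modulo 2**32. The names `SCALE`, `quantize`, `dequantize`, and `pair_masks` are hypothetical, and the random masks stand in for PRG outputs. Here the masks cancel exactly, and only quantization error remains.

```python
import numpy as np

SCALE = 2**16  # fixed-point scale; an illustrative choice

def quantize(x: np.ndarray) -> np.ndarray:
    """Map floats to fixed-point integers, stored modulo 2**32."""
    return np.round(x * SCALE).astype(np.int64).astype(np.uint32)

def dequantize(q: np.ndarray) -> np.ndarray:
    """Reinterpret uint32 as signed int32 and undo the scaling."""
    return q.view(np.int32).astype(np.float64) / SCALE

rng = np.random.default_rng(0)
n_clients, dim = 5, 8
updates = [rng.standard_normal(dim) for _ in range(n_clients)]
quantized = [quantize(u) for u in updates]

# One uniform mask per pair (a, b) with a < b; stands in for PRG(s_ab).
pair_masks = {
    (a, b): rng.integers(0, 2**32, size=dim, dtype=np.uint64).astype(np.uint32)
    for a in range(n_clients) for b in range(a + 1, n_clients)
}

def masked_update(i: int) -> np.ndarray:
    y = quantized[i].copy()
    for j in range(n_clients):
        if j == i:
            continue
        if i < j:
            y += pair_masks[i, j]   # uint32 arrays wrap modulo 2**32
        else:
            y -= pair_masks[j, i]
    return y

masked = [masked_update(i) for i in range(n_clients)]

masked_total = np.zeros(dim, dtype=np.uint32)
plain_total = np.zeros(dim, dtype=np.uint32)
for y, q in zip(masked, quantized):
    masked_total += y
    plain_total += q

exact = np.array_equal(masked_total, plain_total)  # bit-for-bit cancellation
recovered = dequantize(masked_total)
print(exact)
```

The trade-off is that the sum must fit within the modulus after scaling, so `SCALE` and the modulus have to be chosen with the number of clients and the update range in mind.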

### 4) “Privacy sanity checks” (not a proof, but useful)

The server sees individual masked vectors. They should look random and be weakly correlated with the true update.

**Privacy Intuition**: The server sees N masked vectors, but can't recover the N true updates because:
1. Each masked vector = true_update + complex_mask (where mask looks random)
2. The system is underdetermined: N equations, N unknowns + many unknown mask components
3. In a real protocol the masks come from a cryptographically secure PRG keyed by secret seeds; the NumPy generator used in this demo is only a stand-in


In [37]:
def cosine(a, b):
    na, nb = np.linalg.norm(a)+1e-12, np.linalg.norm(b)+1e-12
    return float(a.dot(b)/(na*nb))


# Cosine between masked and true updates (about 1/sqrt(N) here, since the true update is one of N comparable terms)
cosims = [cosine(true_updates[i], masked_updates[i]) for i in range(N)]
print("Cosine(true, masked) mean±std:", np.mean(cosims), np.std(cosims))

# Server can't recover any single update from the set of masked vectors alone.
# (Underdetermined: N unknown updates + many unknown pairwise masks, but only N observed masked vectors.)

Cosine(true, masked) mean±std: 0.21904045573207948 0.016546532329040468


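Observation: Cosine values sit near 0.22, which is about 1/√N for N = 20, rather than near 0. Each masked vector is the true update plus N - 1 = 19 independent unit-variance mask terms, so the update accounts for roughly 1/N of the variance. This residual correlation is a property of the demo's Gaussian float masks; uniform masks in modular integer arithmetic would, in principle, make each masked vector statistically independent of the update.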
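
The measured cosine can be compared with a rough estimate. Assuming unit-variance Gaussian updates and a mask that is the sum of N - 1 independent unit-variance vectors, the expected cosine between an update and its masked version is about 1/√N. The sketch below uses a fresh simulation rather than the notebook's variables, and also measures the cosine against the mask alone.

```python
import numpy as np

rng = np.random.default_rng(0)
n_clients, dim = 20, 2000

def cosine(a: np.ndarray, b: np.ndarray) -> float:
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

update = rng.standard_normal(dim)
# Stand-in for a client's total mask: a sum of N-1 independent unit-variance vectors.
mask = rng.standard_normal((n_clients - 1, dim)).sum(axis=0)
masked = update + mask

cos_masked = cosine(update, masked)    # update vs. what the server sees
cos_mask_only = cosine(update, mask)   # update vs. the mask itself
predicted = 1 / np.sqrt(n_clients)     # about 0.224 for N = 20

print(cos_masked, cos_mask_only, predicted)
```

The mask itself is nearly uncorrelated with the update; the nonzero cosine of the masked vector comes entirely from the update being one of its N additive terms.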

In [38]:
# Test 1: Correctness of aggregation
assert np.allclose(server_sum, true_sum, atol=1e-5), "Secure aggregation failed to cancel masks."

# Test 2: Mean cosine matches the ~1/sqrt(N) expected when the mask dominates (heuristic)
assert abs(np.mean(cosims) - 1/np.sqrt(N)) < 0.05, "Masked updates correlate with true updates more than expected."

### 6) (Optional) Sum-to-zero masks variant (quick sanity model)

Fastest possible demo: choose random masks for clients 0..N-2, and set the last one to make the sum of masks = 0.

In [None]:
masks = [np.random.standard_normal(D).astype(np.float32) for _ in range(N-1)]
last = -np.sum(masks, axis=0)
masks.append(last.astype(np.float32))

masked2 = [true_updates[i] + masks[i] for i in range(N)]
server_sum2 = np.sum(masked2, axis=0)

assert np.allclose(server_sum2, true_sum, atol=1e-5), "Sum-to-zero masks failed to cancel masks."

**Observation:** Works, but not robust (any client dropout breaks cancellation, and the last client is “special”). Use only as a teaching aid.

### 7) Simulate a dropout (to see why production protocols are harder)

In [30]:
drop_idx = 3    # pretend client 3 disappeared after masking
# Sum only what the server actually received, and compare with the survivors' true sum
server_sum_drop = np.sum([m for i, m in enumerate(masked_updates) if i != drop_idx], axis=0)
true_sum_drop = np.sum([u for i, u in enumerate(true_updates) if i != drop_idx], axis=0)
print("L2 error with one dropout:", np.linalg.norm(server_sum_drop - true_sum_drop))

L2 error with one dropout: 192.92607


Observation: The error explodes because pairwise masks involving the missing client no longer cancel.
Real protocols handle this with secret sharing and mask reconstruction (e.g., Bonawitz et al., 2017). That is beyond this notebook's scope, but the result shows why it is needed.

**The Dropout Problem**: This is why real secure aggregation is hard! When client 3 disappears:
- All masks involving client 3 (with clients 0,1,2,4,5,...,19) no longer cancel
- The server's sum includes these unmatched random masks  
- Error explodes from ~1e-4 to ~193 (completely wrong!)
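
To see what "repairing" the sum would involve, the sketch below assumes the server has somehow obtained the dropped client's pairwise seeds, regenerates the unmatched masks, and subtracts them. This is an illustrative simplification with hypothetical names and small sizes. In a real protocol those seeds would not simply be handed over, and Bonawitz et al. (2017) also add a second, per-client mask to protect clients that are only slow rather than gone.

```python
import hashlib
import numpy as np

def prg(seed_bytes: bytes, d: int) -> np.ndarray:
    s = int.from_bytes(hashlib.sha256(seed_bytes).digest()[:8], "big")
    return np.random.default_rng(s).standard_normal(d)

rng = np.random.default_rng(0)
n_clients, dim, dropped = 6, 50, 3
updates = [rng.standard_normal(dim) for _ in range(n_clients)]
seeds = {(a, b): rng.bytes(32)
         for a in range(n_clients) for b in range(a + 1, n_clients)}

def client_mask(i: int) -> np.ndarray:
    m = np.zeros(dim)
    for j in range(n_clients):
        if j != i:
            vec = prg(seeds[min(i, j), max(i, j)], dim)
            m += vec if i < j else -vec
    return m

masked = [updates[i] + client_mask(i) for i in range(n_clients)]
survivors = [i for i in range(n_clients) if i != dropped]

naive_sum = sum(masked[i] for i in survivors)   # what the server can add up
target = sum(updates[i] for i in survivors)     # what it wants to learn

# Each survivor i still carries +/- PRG(s_{i,dropped}); regenerate and remove it.
leftover = np.zeros(dim)
for i in survivors:
    vec = prg(seeds[min(i, dropped), max(i, dropped)], dim)
    leftover += vec if i < dropped else -vec

repaired_sum = naive_sum - leftover
err_naive = np.linalg.norm(naive_sum - target)
err_repaired = np.linalg.norm(repaired_sum - target)
print(err_naive, err_repaired)
```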

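**Real-world solutions** use secret sharing: in Bonawitz et al. (2017), each client splits its secret key material into shares (using Shamir's scheme) that are held by the other clients. If a client drops out, a threshold number of survivors can help the server reconstruct that client's pairwise masks and remove them. This adds significant complexity but is essential for robustness.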
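
For intuition about the secret-sharing building block, here is a minimal sketch of Shamir's scheme over a prime field. The field size, the function names `split_secret` and `reconstruct`, and the use of Python's `random` module are assumptions made for illustration; a real implementation would draw coefficients from a cryptographically secure source.

```python
import random

PRIME = 2**127 - 1  # a Mersenne prime; the field size is an illustrative choice

def split_secret(secret: int, n_shares: int, threshold: int, rng: random.Random):
    """Return n_shares points on a random degree-(threshold-1) polynomial with f(0) = secret."""
    coeffs = [secret] + [rng.randrange(PRIME) for _ in range(threshold - 1)]
    shares = []
    for x in range(1, n_shares + 1):
        y = 0
        for c in reversed(coeffs):  # Horner's rule
            y = (y * x + c) % PRIME
        shares.append((x, y))
    return shares

def reconstruct(shares):
    """Lagrange interpolation at x = 0 over the prime field."""
    secret = 0
    for xi, yi in shares:
        num, den = 1, 1
        for xj, _ in shares:
            if xj != xi:
                num = (num * -xj) % PRIME
                den = (den * (xi - xj)) % PRIME
        secret = (secret + yi * num * pow(den, -1, PRIME)) % PRIME
    return secret

rng = random.Random(0)  # NOT cryptographically secure; for a reproducible demo only
secret = rng.randrange(PRIME)

shares = split_secret(secret, n_shares=5, threshold=3, rng=rng)
recovered = reconstruct(shares[:3])
recovered_other = reconstruct([shares[0], shares[2], shares[4]])
print(recovered == secret, recovered_other == secret)
```

Any three of the five shares recover the secret, so up to two share holders can disappear without losing it.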


## Conclusion: Secure Aggregation in Practice

### What We've Learned

1. **Core Principle**: Secure aggregation uses carefully designed masks that cancel when summed, hiding individual updates while preserving the aggregate.

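2. **Cryptographic Foundation**: Shared pseudo-random generation from secret seeds enables clients to create perfectly coordinated masks without further communication once the seeds are agreed.

3. **Privacy and Utility Together**: The server recovers the aggregate (exactly in integer-arithmetic protocols, up to rounding in this float demo) while individual updates stay hidden behind masks. The cryptographic guarantee applies to real protocols with secure PRGs and modular arithmetic, not to this sketch.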

4. **Robustness Challenge**: Client dropouts break the cancellation property, requiring sophisticated solutions in practice.

### Real-World Impact

- **Google's Federated Learning**: Google has described secure aggregation as a component of its federated learning system, which is used for mobile keyboard prediction
- **Privacy Preservation**: Enables training on sensitive data (medical records, financial data) without exposure
- **Regulatory Compliance**: Helps meet GDPR and other privacy requirements in ML systems

### Limitations of This Demo

- **No Dropout Handling**: Real protocols use secret sharing for robustness
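- **Simplified Threat Model**: Real systems defend against more sophisticated attacks, and use cryptographically secure PRGs with modular integer arithmetic rather than Gaussian float masks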
- **No Network Considerations**: Actual implementations handle network failures, timeouts, etc.

### Next Steps

To implement production-ready secure aggregation, study:
- Bonawitz et al. (2017): "Practical Secure Aggregation for Privacy-Preserving Machine Learning"
- Secret sharing schemes (Shamir's, additive)
- Differential privacy integration
- Efficient cryptographic implementations

**Bottom line**: Secure aggregation is a beautiful cryptographic technique that makes privacy-preserving federated learning practical at scale!
