# Day 23: Secure Aggregation

**Cryptography-Based Privacy for Federated Learning**

## Overview
- **Paper**: Bonawitz et al., CCS 2017
- **Goal**: Server learns ONLY the aggregate, not individual updates
- **Method**: Shamir's Secret Sharing + Pairwise Masking

## What You'll Learn
1. **Secret Sharing**: Splitting data into shares
2. **Pairwise Masking**: Masks that cancel during aggregation
3. **Dropout Recovery**: Handling client disconnections
4. **Security Model**: What secure aggregation hides, and how it compares with DP

---

## 1. What is Secure Aggregation?

```python
print("""

SECURE AGGREGATION PROBLEM:

Standard FL:
  Client sends: update_i (PLAINTEXT)
  Server computes: Σ update_i
  Server sees: INDIVIDUAL updates ❌

Secure Aggregation:
  Client sends: masked_update_i (looks random on its own)
  Server computes: Σ masked_update_i
  Server sees: ONLY THE SUM ✅

Key Insight:
  • Clients add pairwise masks that cancel out
  • mask_ij + mask_ji = 0
  • Server receives: update_i + mask_i, where mask_i = Σ_j mask_ij
  • During sum: masks cancel, leaving Σ update_i

Security Guarantee:
  • Server learns nothing about individual updates
    beyond what the sum itself reveals
  • Security is computational (key agreement + PRG);
    only the secret-sharing layer is information-theoretic

""")
```

## 2. Shamir's Secret Sharing

```python
import secrets

def shamir_share(secret, n_shares, threshold, prime=2**61 - 1):
    """
    Split secret into n shares using Shamir's Secret Sharing.
    
    Properties:
    • Any `threshold` shares reconstruct the secret
    • Fewer than `threshold` shares reveal NOTHING about it
    • Information-theoretic security
    
    Args:
        secret: Integer secret, 0 <= secret < prime
        n_shares: Number of shares to create (n)
        threshold: Minimum shares needed to reconstruct (t)
        prime: Prime modulus of the field
        
    Returns:
        List of shares (x_i, y_i) with x_i = 1..n
    """
    # Random polynomial f(x) = secret + a₁x + a₂x² + ... + aₜ₋₁xᵗ⁻¹ (mod prime).
    # Coefficients come from the `secrets` CSPRNG: np.random is not
    # cryptographically secure, and its int64 values overflow in coeff * x**j.
    coefficients = [secret] + [secrets.randbelow(prime) for _ in range(threshold - 1)]
    
    # Evaluate f at x = 1..n with Horner's rule (x = 0 would reveal the secret)
    shares = []
    for x in range(1, n_shares + 1):
        y = 0
        for coeff in reversed(coefficients):
            y = (y * x + coeff) % prime
        shares.append((x, y))
    
    return shares

def shamir_reconstruct(shares, prime=2**61 - 1):
    """
    Reconstruct secret from shares using Lagrange interpolation at x = 0.
    
    Args:
        shares: List of (x, y) pairs (need ≥ threshold)
        prime: Prime modulus of the field
        
    Returns:
        Reconstructed secret
    """
    secret = 0
    for i, (xi, yi) in enumerate(shares):
        # Lagrange basis polynomial L_i evaluated at 0
        numerator = 1
        denominator = 1
        for j, (xj, _) in enumerate(shares):
            if i != j:
                numerator = (numerator * (-xj)) % prime
                denominator = (denominator * (xi - xj)) % prime
        
        # Add y_i * L_i(0); pow(d, -1, p) is the modular inverse (Python 3.8+)
        term = (yi * numerator * pow(denominator, -1, prime)) % prime
        secret = (secret + term) % prime
    
    return secret

# Example
secret_value = 42
n_clients = 10
threshold = 7  # Any 7 of the 10 shares reconstruct the secret

shares = shamir_share(secret_value, n_clients, threshold)

print(f"Secret: {secret_value}")
print(f"Split into {n_clients} shares (threshold={threshold})")
print(f"\nShares (first 3): {shares[:3]}")

# Reconstruct with threshold shares
reconstructed = shamir_reconstruct(shares[:threshold])
print(f"\nReconstructed from {threshold} shares: {reconstructed}")
print(f"Match: {reconstructed == secret_value}")
```

## 3. Pairwise Masking Protocol

```python
print("""

PAIRWISE MASKING PROTOCOL:

Goal: Clients add masks that cancel during aggregation

Protocol (simplified from Bonawitz et al.):

1. KEY AGREEMENT
   • EVERY pair of clients (i, j), not just disjoint pairs,
     agrees on a shared random seed s_ij = s_ji
   • Done via Diffie-Hellman: the server relays public keys
     but cannot compute s_ij itself

2. MASK GENERATION
   • Each client expands the shared seed with a PRG:
     mask_ij = +PRG(s_ij) if i < j,   -PRG(s_ij) if i > j
   • Hence mask_ij + mask_ji = 0 (mod R)
   • Masks are never transmitted; both sides recompute them

3. UPDATE MASKING
   For each client i:
   • Total mask: mask_i = Σ_{j≠i} mask_ij
   • Send: y_i = update_i + mask_i (mod R)

4. AGGREGATION
   • Server computes: Σ_i y_i (mod R)
   • Masks cancel: Σ_i mask_i = 0 (pairwise cancellation)
   • Result: Σ update_i (exact sum of the encoded updates!)

Security:
  • Each y_i looks uniformly random to the server (one-time-pad style)
  • Fresh keys and seeds each round, so masks are never reused
  • Security is computational (key agreement + PRG)
  • Combined with secret sharing for dropout recovery

""")
```

## 4. Handling Client Dropout

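```python
print("""

CLIENT DROPOUT PROBLEM:

Issue:
  • Client u drops out before sending its masked update y_u
  • Surviving clients already added masks ±PRG(s_uv) shared with u
  • Those masks have no partner in the sum, so they don't cancel:
    Σ_{survivors} y_i ≠ Σ_{survivors} update_i

Solution: Double Masking + Secret Sharing

1. FIRST MASK (Pairwise)
   • As above: ±PRG(s_uv) for every peer v

2. SECOND MASK (Self mask)
   • Client u also adds PRG(b_u) for a private random seed b_u
   • y_u = update_u + PRG(b_u) + Σ_v ±PRG(s_uv)

3. SECRET SHARING
   • Client u Shamir-shares (t-of-n) its key-agreement secret key
     and its seed b_u with the other clients

4. DROPOUT RECOVERY
   • For each dropped client u: server collects ≥ t shares of u's
     secret key, recomputes s_uv, and removes u's pairwise masks
   • For each surviving client u: server collects ≥ t shares of b_u
     and removes PRG(b_u)
   • Honest clients never reveal both kinds of share for the same
     client, so a late-arriving y_u stays hidden

Guarantee:
  • Exact sum as long as ≥ t clients survive
  • Tolerates up to (n - t) dropouts
  • Example: n=10, t=7 → tolerate 3 dropouts

""")
```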

## 5. Comparison: Secure Agg vs DP

```python
import pandas as pd

comparison_df = pd.DataFrame({
    'Aspect': [
        'Privacy Guarantee',
        'Utility Loss',
        'Communication',
        'Computation',
        'Dropout Tolerance',
        'Implementation',
    ],
    'Secure Aggregation': [
        'Hides individual updates; sum is revealed',
        'None (exact aggregation)',
        'Higher (key exchange + secret shares)',
        'Moderate (key agreement, PRG, sharing)',
        'Threshold-based (t of n must survive)',
        'Complex (Bonawitz protocol)',
    ],
    'Differential Privacy': [
        'ε-DP bound on leakage from the output',
        'Yes (noise degrades accuracy)',
        'No extra (just adds noise)',
        'Low (noise generation)',
        'Unaffected by dropouts',
        'Simple (add noise)',
    ],
})

print("\n" + "="*70)
print("SECURE AGGREGATION vs DIFFERENTIAL PRIVACY")
print("="*70)
print(comparison_df.to_string(index=False))
```

## 6. Summary

### Secure Aggregation Summary:

**Core Idea:**
- Clients add pairwise masks that cancel during aggregation
- Server sees ONLY the sum, not individual updates
- Cryptographic security; complementary to DP, since the sum itself is still revealed

**Key Components:**

1. **Shamir's Secret Sharing**:
   - Split secret into shares
   - t-of-n threshold scheme
   - < t shares reveal nothing

2. **Pairwise Masking**:
   - mask_ij + mask_ji = 0
   - Cancels during aggregation
   - One-time use per round

3. **Dropout Recovery**:
   - Secret-share keys and self-mask seeds so masks of dropped clients can be removed
   - Tolerate up to (n - t) dropouts

**Advantages:**
- ✅ Individual updates hidden from the server
- ✅ Exact aggregation (no noise, no utility loss)
- ✅ Can be combined with DP to also protect the aggregate

**Limitations:**
- ❌ Complex implementation
- ❌ Higher communication and computation overhead
- ❌ Requires at least t clients to survive each round
- ❌ Does not hide what the aggregate itself reveals

**When to Use:**
- ✅ High privacy requirement
- ✅ Can tolerate communication cost
- ✅ Stable client participation

### Next Steps:
→ **Day 24**: SignGuard (multi-layer defense)
→ **Day 25**: Membership Inference Attack (privacy attacks)

---

**📁 Project Location**: `05_security_research/secure_aggregation_fl/`

**📚 Paper**: Bonawitz et al., "Practical Secure Aggregation for Privacy-Preserving Machine Learning", CCS 2017