In [7]:
import numpy as np

# Toy federated-learning setup: a small population of clients, each holding
# its own local model update (a flat vector of weights).
num_clients = 3
model_dim = 5  # number of parameters in the example model

# Each client draws its update from NumPy's global RNG, one vector per client,
# in client order (np.random.rand samples uniformly from [0, 1)).
client_updates = []
for _ in range(num_clients):
    client_updates.append(np.random.rand(model_dim))

# Show the raw, unmasked updates so later cells can be compared against them.
print("Client Updates (original):")
for i in range(num_clients):
    update = client_updates[i]
    print(f"Client {i+1}:", update)


Client Updates (original):
Client 1: [0.12399584 0.5912612  0.00124619 0.1935042  0.96519322]
Client 2: [0.83714278 0.42288568 0.83671578 0.37253776 0.85427903]
Client 3: [0.67526861 0.67121011 0.34039998 0.20446692 0.13474787]


In [9]:
# Simulate pairwise masks for secure aggregation
def generate_pairwise_masks(num_clients, dim, rng=None):
    """Draw one random mask vector for every ordered pair of clients.

    Parameters
    ----------
    num_clients : int
        Number of participating clients; the result is a
        ``num_clients x num_clients`` nested list.
    dim : int
        Length of each mask vector (should match the model update length).
    rng : numpy.random.Generator, optional
        Source of randomness. When omitted, the global legacy NumPy RNG
        (``np.random.rand``) is used, exactly as before. Passing a seeded
        generator makes the masks reproducible and independent of global
        RNG state.

    Returns
    -------
    list[list[numpy.ndarray]]
        ``masks[i][j]`` is a 1-D array of ``dim`` floats drawn uniformly from
        [0, 1). Entries are generated row by row (``i`` outer, ``j`` inner).
        Diagonal entries ``masks[i][i]`` are generated as well, so the
        structure is a full square.
    """
    # Pick the sampler once; both callables accept the vector length as an int.
    draw = np.random.rand if rng is None else rng.random
    masks = [[draw(dim) for _ in range(num_clients)] for _ in range(num_clients)]
    return masks

# One mask vector per ordered (client, peer) pair.
pairwise_masks = generate_pairwise_masks(num_clients, model_dim)

# Every client blinds its own update before upload. For each peer j, client i
# adds the mask indexed [i][j] and removes the mask indexed [j][i]; across all
# clients every mask therefore appears once with "+" and once with "-".
masked_updates = []
for i in range(num_clients):
    masked_update = client_updates[i].copy()  # never touch the raw update
    for j in range(num_clients):
        if j == i:
            continue  # a client has no pairwise mask with itself
        # Add first, then subtract, to keep the same floating-point order.
        masked_update = masked_update + pairwise_masks[i][j] - pairwise_masks[j][i]
    masked_updates.append(masked_update)

# These blinded vectors are all the server gets to see.
print("\nMasked Updates (sent to server):")
for i in range(num_clients):
    update = masked_updates[i]
    print(f"Client {i+1}:", update)



Masked Updates (sent to server):
Client 1: [ 1.21581244  0.58539925 -0.20581144  0.69571031  1.05448949]
Client 2: [-0.09802374  0.70029512  0.52968995  0.15305303  0.37228872]
Client 3: [ 0.51861854  0.39966261  0.85448343 -0.07825447  0.52744191]


In [11]:
# Server aggregates masked updates: it only sums the blinded vectors it
# received, element-wise across clients.
aggregated_update = np.sum(masked_updates, axis=0)

print("\nAggregated Update at Server (after Secure Aggregation):")
print(aggregated_update)

# Ground truth: expected result if no masks were used
true_aggregation = np.sum(client_updates, axis=0)

print("\nTrue Aggregation (for validation):")
print(true_aggregation)

# Verify if Secure Aggregation worked: show the element-wise gap...
print("\nDifference:", np.abs(aggregated_update - true_aggregation))

# ...and fail loudly if it is more than floating-point round-off, so a broken
# masking scheme cannot slip through a "Run All" unnoticed.
np.testing.assert_allclose(
    aggregated_update,
    true_aggregation,
    rtol=1e-9,
    atol=1e-12,
    err_msg="Pairwise masks did not cancel: secure aggregation differs from the plain sum",
)



Aggregated Update at Server (after Secure Aggregation):
[1.63640724 1.68535699 1.17836195 0.77050888 1.95422013]

True Aggregation (for validation):
[1.63640724 1.68535699 1.17836195 0.77050888 1.95422013]

Difference: [2.22044605e-16 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00]
