In [1]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten

In [2]:
# Fetch the MNIST train/test splits (downloaded on first use).
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Divide pixel values by 255 and unroll each 28x28 image into a 784-vector.
x_train = (x_train / 255.0).reshape(-1, 28 * 28)
x_test = (x_test / 255.0).reshape(-1, 28 * 28)

# One-hot encode the labels over the 10 digit classes.
y_train = tf.keras.utils.to_categorical(y_train, 10)
y_test = tf.keras.utils.to_categorical(y_test, 10)

# Partition the training set into equally sized, contiguous client shards.
# Floor division means any samples left over after an even split are unused.
num_clients = 3
data_per_client = len(x_train) // num_clients

shard_slices = [
    slice(shard * data_per_client, (shard + 1) * data_per_client)
    for shard in range(num_clients)
]
client_data = [(x_train[rows], y_train[rows]) for rows in shard_slices]

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [3]:
def create_model(input_dim=28 * 28, hidden_units=5, num_classes=10):
    """Build the small fully connected classifier shared by server and clients.

    The architecture is one sigmoid hidden layer followed by a softmax output
    layer. Calling ``create_model()`` with no arguments yields the same
    784 -> 5 -> 10 network as before; the parameters only exist so the layer
    sizes are no longer hard-coded.

    Args:
        input_dim: Length of each flattened input vector (default ``28 * 28``).
        hidden_units: Number of neurons in the hidden layer (default 5).
        num_classes: Number of softmax outputs (default 10).

    Returns:
        An uncompiled ``Sequential`` model; callers compile it themselves.
    """
    return Sequential([
        Dense(hidden_units, activation='sigmoid', input_shape=(input_dim,)),  # Hidden layer
        Dense(num_classes, activation='softmax'),  # One output per class
    ])

# Seed the Python, NumPy and TensorFlow RNGs before any weights are created so
# the initial global weights (and therefore the logged loss/accuracy curve)
# are reproducible after Restart & Run All.
SEED = 42
tf.keras.utils.set_random_seed(SEED)

# Server-side model. It is compiled with an accuracy metric because the
# training loop calls evaluate() on it and unpacks (loss, accuracy).
global_model = create_model()
global_model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])

In [5]:
def client_compute_gradients(client_model, client_data, global_weights):
    """Compute one client's gradient of the loss at the given global weights.

    Args:
        client_model: Model whose weights are overwritten with
            ``global_weights``. Its ``compiled_loss`` is used, so it is
            presumably expected to be compiled already - confirm at call sites.
        client_data: ``(inputs, targets)`` pair; the whole pair is pushed
            through the model in a single forward pass.
        global_weights: Weight list accepted by ``client_model.set_weights``.

    Returns:
        The result of ``tape.gradient`` taken with respect to
        ``client_model.trainable_variables``.
    """
    # Synchronise the client with the server before differentiating.
    client_model.set_weights(global_weights)
    features, targets = client_data

    with tf.GradientTape() as grad_tape:
        outputs = client_model(features, training=True)
        # NOTE(review): newer Keras releases appear to deprecate
        # `compiled_loss` in favour of `compute_loss` - confirm before upgrading.
        batch_loss = client_model.compiled_loss(targets, outputs)

    return grad_tape.gradient(batch_loss, client_model.trainable_variables)

# Function to aggregate gradients
def aggregate_gradients(gradients_list):
    """Average per-variable gradients element-wise across clients.

    The number of variables is taken from the input itself rather than from a
    global model, so the function works for any model and can be used
    independently of notebook state.

    Args:
        gradients_list: Non-empty list with one entry per client. Each entry
            is a sequence of gradients, one per model variable, in the same
            order for every client. Individual gradients may be tensors that
            expose ``.numpy()`` or anything ``np.asarray`` accepts.

    Returns:
        A list of ``np.ndarray``, one per variable, holding the mean of that
        variable's gradient over all clients.

    Raises:
        ValueError: If ``gradients_list`` is empty, or if clients report a
            different number of gradients.
    """
    if len(gradients_list) == 0:
        # Averaging nothing would silently produce NaNs and corrupt the model.
        raise ValueError("gradients_list must contain at least one client's gradients")

    num_variables = len(gradients_list[0])
    if any(len(client_grads) != num_variables for client_grads in gradients_list):
        raise ValueError("all clients must provide the same number of gradients")

    def _to_numpy(grad):
        # Tensors provide .numpy(); plain arrays/lists go through np.asarray.
        return grad.numpy() if hasattr(grad, "numpy") else np.asarray(grad)

    # zip(*...) regroups [client][variable] into [variable][client].
    return [
        np.mean([_to_numpy(grad) for grad in per_variable], axis=0)
        for per_variable in zip(*gradients_list)
    ]



In [19]:
# FedSGD process: in every round each client computes a gradient on its whole
# shard at the current global weights, the server averages those gradients and
# takes a single SGD step on the global model.
num_rounds = 500
learning_rate = 0.1

# Loop-invariant setup, done once instead of once per client per round.
# The client shards never change, so convert them to tensors a single time.
# (This keeps all shards resident as tensors simultaneously.)
client_tensors = [
    (tf.constant(x_client), tf.constant(y_client))
    for x_client, y_client in client_data
]

# One compiled model per client. client_compute_gradients() overwrites the
# model's weights with the global weights on every call, and the model only
# contains Dense layers, so reusing the models carries nothing over between
# rounds.
client_models = []
for _ in client_data:
    client_model = create_model()
    # Compiling attaches the loss that client_compute_gradients() reads.
    client_model.compile(optimizer='sgd', loss='categorical_crossentropy')
    client_models.append(client_model)

for round_num in range(num_rounds):
    print(f"Round {round_num + 1}")

    # Snapshot the global weights once; every client starts from this same
    # snapshot and the server update below is applied to it.
    global_weights = global_model.get_weights()

    # Store gradients from all clients
    gradients_list = [
        client_compute_gradients(client_model, tensors, global_weights)
        for client_model, tensors in zip(client_models, client_tensors)
    ]

    # Aggregate gradients
    avg_gradients = aggregate_gradients(gradients_list)

    # Update global model with one plain SGD step. Pairing get_weights() with
    # the per-trainable-variable gradients relies on every weight of this
    # model being trainable.
    updated_weights = [
        weight - learning_rate * gradient
        for weight, gradient in zip(global_weights, avg_gradients)
    ]
    global_model.set_weights(updated_weights)

    # Evaluate global model
    loss, accuracy = global_model.evaluate(x_test, y_test, verbose=0)
    print(f"Global Model Loss: {loss:.4f}, Accuracy: {accuracy:.4f}")

Round 1
Global Model Loss: 1.2476, Accuracy: 0.6667
Round 2
Global Model Loss: 1.2466, Accuracy: 0.6673
Round 3
Global Model Loss: 1.2456, Accuracy: 0.6676
Round 4
Global Model Loss: 1.2447, Accuracy: 0.6678
Round 5
Global Model Loss: 1.2437, Accuracy: 0.6682
Round 6
Global Model Loss: 1.2427, Accuracy: 0.6684
Round 7
Global Model Loss: 1.2417, Accuracy: 0.6693
Round 8
Global Model Loss: 1.2408, Accuracy: 0.6702
Round 9
Global Model Loss: 1.2398, Accuracy: 0.6708
Round 10
Global Model Loss: 1.2388, Accuracy: 0.6710
Round 11
Global Model Loss: 1.2379, Accuracy: 0.6712
Round 12
Global Model Loss: 1.2369, Accuracy: 0.6712
Round 13
Global Model Loss: 1.2359, Accuracy: 0.6718
Round 14
Global Model Loss: 1.2350, Accuracy: 0.6718
Round 15
Global Model Loss: 1.2340, Accuracy: 0.6728
Round 16
Global Model Loss: 1.2331, Accuracy: 0.6736
Round 17
Global Model Loss: 1.2321, Accuracy: 0.6742
Round 18
Global Model Loss: 1.2311, Accuracy: 0.6749
Round 19
Global Model Loss: 1.2302, Accuracy: 0.6753
Ro