<a href="https://colab.research.google.com/github/artsasse/fedkan/blob/main/Keras_MNIST_Federated_SGD.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Funções

In [None]:
import numpy as np
import random
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Activation
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.utils import to_categorical
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
from sklearn.metrics import accuracy_score


def create_clients(image_list, label_list, num_clients=10, initial='client'):
    ''' Randomly split a labelled dataset into equally sized per-client shards.

        args:
            image_list: sequence of training samples (e.g. a numpy array of images)
            label_list: sequence of labels aligned with image_list, one per sample
            num_clients: number of federated members (clients); must be >= 1
            initial: the clients' name prefix, e.g. 'client' -> client_1, client_2, ...
        return:
            a dictionary mapping each client name to its data shard, a list of
            (image, label) tuples. Shards are disjoint and each holds
            len(image_list) // num_clients samples; after shuffling, the last
            len(image_list) % num_clients samples are not assigned to any client.
        raises:
            ValueError: if num_clients < 1, if image_list and label_list differ
                in length, or if there are fewer samples than clients.
    '''
    if num_clients < 1:
        raise ValueError('num_clients must be >= 1, got {}'.format(num_clients))
    # zip() would silently truncate to the shorter input, losing samples or labels.
    if len(image_list) != len(label_list):
        raise ValueError('image_list and label_list differ in length: {} != {}'.format(
            len(image_list), len(label_list)))

    client_names = ['{}_{}'.format(initial, i + 1) for i in range(num_clients)]

    # Pair every sample with its label, then randomize the order before sharding.
    data = list(zip(image_list, label_list))
    random.shuffle(data)

    size = len(data) // num_clients
    if size == 0:
        raise ValueError('cannot give each of {} clients at least one sample from {} samples'.format(
            num_clients, len(data)))

    # Client i receives the contiguous, non-overlapping slice [i*size, (i+1)*size).
    return {name: data[i * size:(i + 1) * size] for i, name in enumerate(client_names)}


def batch_data(data_shard, bs=32):
    '''Build a shuffled, batched tf.data.Dataset from one client's data shard.

    args:
        data_shard: list of (sample, label) tuples, as produced by create_clients
        bs: batch size
    return:
        tf.data.Dataset yielding (samples, labels) batches of up to bs elements.
        The shuffle buffer is as large as the shard, so the sample order is
        randomized across the whole shard rather than within a window.
    '''
    # Transpose the list of pairs into one list of samples and one of labels.
    samples, labels = (list(column) for column in zip(*data_shard))
    full_dataset = tf.data.Dataset.from_tensor_slices((samples, labels))
    shuffled = full_dataset.shuffle(len(labels))
    return shuffled.batch(bs)


class SimpleMLP:
    '''Factory for the small fully connected classifier used by every model in this notebook.'''

    @staticmethod
    def build(shape, classes, hidden_units=(200, 200)):
        '''Return a new, uncompiled Sequential MLP.

        args:
            shape: number of input features (784 for flattened 28x28 MNIST images)
            classes: number of output classes
            hidden_units: sizes of the ReLU hidden layers; the default reproduces
                the original two layers of 200 units
        return:
            Sequential model ending in a softmax, i.e. it outputs class
            probabilities, not logits.
        '''
        model = Sequential()
        # Declare the input with an Input object instead of passing
        # `input_shape` to the first Dense layer, which Keras 3 warns about.
        model.add(tf.keras.Input(shape=(shape,)))
        for units in hidden_units:
            model.add(Dense(units))
            model.add(Activation("relu"))
        model.add(Dense(classes))
        model.add(Activation("softmax"))
        return model


def weight_scaling_factor(clients_trn_data, client_name):
    '''Return client_name's share of the federation's training data.

    The share is measured in batches: the client's number of batches divided by
    the total over all clients. (Multiplying both counts by the batch size, as
    an earlier version did, cancels out, so there is no need to iterate the
    dataset to read a batch.) This equals the share of samples when all clients
    hold the same number of samples or every batch is full; otherwise a
    client's smaller final batch counts as a full one.

    args:
        clients_trn_data: dict mapping client name -> batched dataset; each
            value must support len() (e.g. a finite tf.data.Dataset)
        client_name: key of the client whose factor is requested
    return:
        float in [0, 1]; the factors of all clients sum to 1 (up to rounding)
    raises:
        KeyError: if client_name is not in clients_trn_data
        ValueError: if the clients hold no batches at all
    '''
    local_batches = len(clients_trn_data[client_name])
    # len() on a tf.data.Dataset raises TypeError when its length is unknown or
    # infinite, instead of silently using a negative cardinality sentinel.
    global_batches = sum(len(dataset) for dataset in clients_trn_data.values())
    if global_batches == 0:
        raise ValueError('clients_trn_data holds no training batches')
    return local_batches / global_batches


def scale_model_weights(weight, scalar):
    '''Multiply every array in a model's weight list by a scalar.

    args:
        weight: list of per-layer weight arrays (as from model.get_weights())
        scalar: factor applied to each array
    return:
        new list in the same order holding scalar * w for each entry w
    '''
    return [scalar * layer_weights for layer_weights in weight]


def sum_scaled_weights(scaled_weight_list):
    '''Return the layer-by-layer, element-wise sum of several weight lists.

    If each client's weights were already multiplied by its share of the
    training data (shares summing to 1), this sum is the data-weighted average
    of the client models.

    args:
        scaled_weight_list: one entry per client; each entry is a list of
            per-layer arrays, with all clients sharing the same layer order
            and shapes
    return:
        list of numpy arrays, one per layer, in the form model.set_weights()
        accepts; an empty input gives an empty list
    '''
    # zip(*...) groups the same layer from every client; summing along the new
    # leading axis adds the clients' versions of that layer together.
    return [np.sum(np.stack(layer_versions), axis=0)
            for layer_versions in zip(*scaled_weight_list)]


def test_model(X_test, Y_test, model, comm_round, from_logits=False):
    '''Evaluate model on one batch of data, print accuracy and loss, and return them.

    args:
        X_test: input samples
        Y_test: one-hot targets aligned with X_test
        model: Keras model to evaluate
        comm_round: label printed next to the metrics (the communication round)
        from_logits: pass True only for models that output raw logits. The
            SimpleMLP models end in a softmax, so they output probabilities;
            treating probabilities as logits applies softmax a second time and
            keeps the reported loss artificially high.
    return:
        (acc, loss): accuracy as a float in [0, 1] and the categorical
        cross-entropy returned by tf.keras.losses.CategoricalCrossentropy
    '''
    cce = tf.keras.losses.CategoricalCrossentropy(from_logits=from_logits)
    predictions = model.predict(X_test, verbose=0)
    loss = cce(Y_test, predictions)
    # accuracy_score expects (y_true, y_pred); compare predicted and true class indices.
    acc = accuracy_score(tf.argmax(Y_test, axis=1), tf.argmax(predictions, axis=1))
    print('comm_round: {} | global_acc: {:.3%} | global_loss: {}'.format(comm_round, acc, float(loss)))
    return acc, loss

# Execução

In [None]:
# Load MNIST dataset
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()

# Preprocess: flatten each 28x28 image into a 784-vector, scale pixels to
# [0, 1] and one-hot encode the 10 digit labels.
X_train = X_train.reshape(-1, 28 * 28) / 255.0
X_test = X_test.reshape(-1, 28 * 28) / 255.0
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)

# Hold out 10% of the training images as a validation split. X_val / y_val are
# not used below; only the remaining 90% (X_train / y_train) is trained on.
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.1, random_state=42)

# Hyper-parameters shared by the federated clients and the centralized baseline.
NUM_CLIENTS = 10
CLIENT_BATCH_SIZE = 32   # local batch size of every federated client
SGD_BATCH_SIZE = 320     # batch size of the centralized SGD baseline
SGD_EPOCHS = 100
comms_round = 10         # number of communication rounds
lr = 0.01
loss = 'categorical_crossentropy'
metrics = ['accuracy']


def _make_sgd_optimizer(learning_rate, decay):
    '''Return a new SGD optimizer (momentum 0.9) whose step size follows
    learning_rate / (1 + decay * step).'''
    # Keras 3 warns that the legacy `decay=` argument is no longer supported and
    # ignores it, so the decay is expressed as an explicit schedule instead.
    schedule = tf.keras.optimizers.schedules.InverseTimeDecay(
        initial_learning_rate=learning_rate, decay_steps=1, decay_rate=decay)
    return SGD(learning_rate=schedule, momentum=0.9)


# Give every client a disjoint random shard of the training set.
clients = create_clients(X_train, y_train, num_clients=NUM_CLIENTS, initial='client')

# Process and batch the training data for each client
clients_batched = dict()
for (client_name, data) in clients.items():
    clients_batched[client_name] = batch_data(data, bs=CLIENT_BATCH_SIZE)

# The whole test set as a single batch, used to evaluate the models.
test_batched = tf.data.Dataset.from_tensor_slices((X_test, y_test)).batch(len(y_test))

# Initialize global model
smlp_global = SimpleMLP()
global_model = smlp_global.build(784, 10)

# Start global training loop
for comm_round in range(comms_round):

    # Every client starts this round from the current global weights.
    global_weights = global_model.get_weights()

    # Initial list to collect local model weights after scaling
    scaled_local_weight_list = []

    # Visit the clients in a random order. The weights are summed afterwards,
    # so the order only affects floating-point rounding.
    client_names = list(clients_batched.keys())
    random.shuffle(client_names)

    for client in client_names:
        local_model = SimpleMLP.build(784, 10)

        # A fresh optimizer per local model, so optimizer state (momentum,
        # step count) is never carried over between clients or rounds.
        local_optimizer = _make_sgd_optimizer(lr, lr / comms_round)
        local_model.compile(loss=loss, optimizer=local_optimizer, metrics=metrics)

        # Start local training from the global model's weights.
        local_model.set_weights(global_weights)

        # Each client trains only on its own shard, clients_batched[client],
        # which create_clients cut from a non-overlapping slice of the
        # shuffled training set.
        local_model.fit(clients_batched[client], epochs=1, verbose=0)

        # Scale the client's weights by its share of the training data; summing
        # the scaled weights below then gives the data-weighted average of the
        # local models.
        scaling_factor = weight_scaling_factor(clients_batched, client)
        scaled_weights = scale_model_weights(local_model.get_weights(), scaling_factor)
        scaled_local_weight_list.append(scaled_weights)

        # Reset Keras' global state after each client: a new model is built for
        # every client, and that state would otherwise keep accumulating.
        tf.keras.backend.clear_session()

    # To get the average over all the local models, we simply take the sum of the scaled weights
    average_weights = sum_scaled_weights(scaled_local_weight_list)

    # Update global model
    global_model.set_weights(average_weights)

    # Evaluate the updated global model on the test set. Distinct loop names
    # keep the global X_test from being rebound.
    for (X_batch, Y_batch) in test_batched:
        global_acc, global_loss = test_model(X_batch, Y_batch, global_model, comm_round)

# Centralized baseline: one model trained with plain SGD on the full training
# split. A shuffle buffer as large as the dataset fully randomizes the sample
# order, and tf.data reshuffles on every pass (i.e. every epoch) by default.
SGD_dataset = (tf.data.Dataset.from_tensor_slices((X_train, y_train))
               .shuffle(len(y_train))
               .batch(SGD_BATCH_SIZE))

# Initialize the SGD model
smlp_SGD = SimpleMLP()
SGD_model = smlp_SGD.build(784, 10)

# Create optimizer and compile model (same learning-rate schedule as the clients)
optimizer = _make_sgd_optimizer(lr, lr / comms_round)
SGD_model.compile(loss=loss, optimizer=optimizer, metrics=metrics)

# Fit the SGD training data to model
SGD_model.fit(SGD_dataset, epochs=SGD_EPOCHS, verbose=0)

# Test the SGD baseline; its metrics are labelled 'SGD' rather than a round number.
for (X_batch, Y_batch) in test_batched:
    SGD_acc, SGD_loss = test_model(X_batch, Y_batch, SGD_model, 'SGD')


[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 2ms/step
comm_round: 0 | global_acc: 90.060% | global_loss: 1.6277786493301392


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 2ms/step
comm_round: 1 | global_acc: 92.180% | global_loss: 1.5873966217041016


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 2ms/step
comm_round: 2 | global_acc: 93.690% | global_loss: 1.5679781436920166


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 2ms/step
comm_round: 3 | global_acc: 94.260% | global_loss: 1.5547382831573486


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 2ms/step
comm_round: 4 | global_acc: 94.760% | global_loss: 1.5457147359848022


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 2ms/step
comm_round: 5 | global_acc: 95.010% | global_loss: 1.5392152070999146


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 3ms/step
comm_round: 6 | global_acc: 95.700% | global_loss: 1.5317574739456177


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 2ms/step
comm_round: 7 | global_acc: 95.810% | global_loss: 1.5280823707580566


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 2ms/step
comm_round: 8 | global_acc: 96.050% | global_loss: 1.523876667022705


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 3ms/step
comm_round: 9 | global_acc: 96.290% | global_loss: 1.5203217267990112


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 2ms/step
comm_round: 1 | global_acc: 97.920% | global_loss: 1.4841071367263794
