<a href="https://colab.research.google.com/github/Jaideep07/Federated-Learning-Intro/blob/main/Federated_Learning_Workshop.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## FL model implementation using a simple Feed Forward Neural Network on MNIST digits dataset

In [None]:
import numpy as np
import random
import cv2
import os

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelBinarizer
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
from sklearn.metrics import accuracy_score

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Activation
from tensorflow.keras.layers import Flatten
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import SGD
from tensorflow.keras import backend as K
import matplotlib.pyplot as plt

In [None]:
# Load MNIST and keep only the first (training) split as X and y;
# the second split returned by load_data() is not used in this notebook.
X, y = tf.keras.datasets.mnist.load_data()[0]

In [None]:
# Flatten every 28x28 image into a 784-element vector and scale the pixel
# intensities so that they fall in the range [0, 1].
X = X.reshape(len(X), -1) / 255.

# One-hot encode the target variable y.
lb = LabelBinarizer()
y = lb.fit_transform(y)

# Hold out 20% of the samples as the test set (fixed random_state so the
# split is the same on every run).
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
)

In [None]:
def create_clients(X, y, num_clients=10, initial='clients'):
    '''Shuffle the (image, label) pairs and split them into equally sized client shards.

        args:
            X: a sequence of training images (e.g. a list/array of numpy arrays)
            y: a sequence of binarized labels, aligned with X
            num_clients: number of federated members (clients); must be at
                         least 1 and at most the number of (image, label) pairs
            initial: the clients' name prefix, e.g. 'clients' -> 'clients_1'
        return:
            a dictionary with the clients' names as keys and, as values, that
            client's shard: a list of (image, label) tuples. Every shard holds
            len(data) // num_clients pairs; pairs left over when the data does
            not divide evenly are dropped so that all shards have equal size.
        raises:
            ValueError: if num_clients is smaller than 1 or larger than the
                        number of available pairs.

        Shuffling uses the module-level `random` state, so call random.seed()
        beforehand if a reproducible partition is needed.
    '''
    data = list(zip(X, y))

    # fail early with a clear message instead of a ZeroDivisionError, an
    # empty result, or a zero-step range() error further down
    if num_clients < 1:
        raise ValueError('num_clients must be at least 1, got {}'.format(num_clients))
    if num_clients > len(data):
        raise ValueError(
            'num_clients ({}) exceeds the number of samples ({})'.format(num_clients, len(data))
        )

    #create a list of client names
    client_names = ['{}_{}'.format(initial, i + 1) for i in range(num_clients)]

    #randomize the data
    random.shuffle(data)

    #split the data equally for each client; size >= 1 is guaranteed by the checks above
    size = len(data) // num_clients
    data_grp = [data[start:start + size] for start in range(0, size * num_clients, size)]

    return dict(zip(client_names, data_grp))

In [None]:
# Partition the training split into 10 shards, keyed 'client_1' ... 'client_10'.
clients = create_clients(
    X_train,
    y_train,
    num_clients=10,
    initial='client',
)

In [None]:
def batch_data(data_grp, bs=32):
    '''Turn one client's data shard into a shuffled, batched tf.data.Dataset.
    args:
        data_grp: iterable of (data, label) pairs constituting a client's shard
        bs: batch size
    return:
        tf.data.Dataset yielding (data, label) batches of up to bs samples'''
    # transpose the list of pairs into two parallel sequences
    samples, targets = zip(*data_grp)
    slices = tf.data.Dataset.from_tensor_slices((list(samples), list(targets)))
    # the shuffle buffer is as large as the shard itself
    shuffled = slices.shuffle(len(targets))
    return shuffled.batch(bs)

In [None]:
# Build one batched tf.data.Dataset per client, keyed by client name.
clients_batched = {
    client_name: batch_data(shard)
    for client_name, shard in clients.items()
}

# The test set is served as one single batch containing every test sample.
test_batched = (
    tf.data.Dataset
    .from_tensor_slices((X_test, y_test))
    .batch(len(y_test))
)

In [None]:
class SimpleMLP:
    '''Factory for a small fully connected classifier.'''

    @staticmethod
    def build(shape, classes):
        '''Create an (uncompiled) feed-forward network.
        args:
            shape: number of input features per sample
            classes: number of units in the final softmax layer
        return:
            Sequential model: Dense(200) -> relu -> Dense(200) -> relu
            -> Dense(classes) -> softmax'''
        layers = [
            Dense(200, input_shape=(shape,)),
            Activation("relu"),
            Dense(200),
            Activation("relu"),
            Dense(classes),
            Activation("softmax"),
        ]
        return Sequential(layers)

In [None]:
# Training hyperparameters.
lr = 0.01                          # initial learning rate
comms_round = 50                   # number of global communication (aggregation) rounds
loss = 'categorical_crossentropy'  # expects class probabilities and one-hot labels
metrics = ['accuracy']

# Newer Keras optimizer implementations no longer accept the `decay` keyword
# of SGD. InverseTimeDecay with decay_steps=1 expresses the same time-based
# schedule, lr / (1 + decay_rate * step), in a form that every TF 2.x
# release understands.
lr_schedule = tf.keras.optimizers.schedules.InverseTimeDecay(
    initial_learning_rate=lr,
    decay_steps=1,
    decay_rate=lr / comms_round,
)
optimizer = SGD(learning_rate=lr_schedule, momentum=0.9)

In [None]:
def weight_scaling_factor(clients_trn_data, client_name):
    '''Return the share of the total training data that is held by one client.

    The share is measured in batches: the client's batch count divided by the
    batch count of all clients together. This equals the former
    (batches * batch size) ratio, because the batch size appears in both the
    numerator and the denominator and cancels out. Reading the batch size
    required iterating a whole client dataset on every call, so it was removed.

    args:
        clients_trn_data: dict mapping client names to batched tf.data.Dataset objects
        client_name: key of the client whose scaling factor is wanted
    return:
        local batch count / global batch count
    raises:
        KeyError: if client_name is not a key of clients_trn_data
        ValueError: if a dataset reports a negative (unknown or infinite)
                    cardinality, or if there are no batches at all
    '''
    # number of batches per client, as reported by tf.data
    batch_counts = {
        name: tf.data.experimental.cardinality(dataset).numpy()
        for name, dataset in clients_trn_data.items()
    }
    local_count = batch_counts[client_name]

    # tf.data reports unknown/infinite cardinality as negative numbers
    if any(count < 0 for count in batch_counts.values()):
        raise ValueError('every client dataset must have a known, finite cardinality')

    global_count = sum(batch_counts.values())
    if global_count == 0:
        raise ValueError('the clients hold no training batches')

    return local_count / global_count


In [None]:
def scale_model_weights(weight, scalar):
    '''Multiply every entry of a model's weight list by a scalar.

    args:
        weight: sequence of weight entries (e.g. the list from model.get_weights())
        scalar: factor applied to each entry
    return:
        a new list with scalar * entry for every entry, in the original order
    '''
    return [scalar * layer_weight for layer_weight in weight]

In [None]:
def sum_scaled_weights(scaled_weight_list):
    '''Sum the clients' scaled weights layer by layer.

    Because every client's weights were already multiplied by its scaling
    factor, this sum is the weighted average of the clients' weights.

    args:
        scaled_weight_list: one list of scaled weights per client
    return:
        list with one summed entry per weight position
    '''
    # zip(*...) regroups the per-client lists so that each tuple holds the
    # entries found at the same position in every client's list
    return [
        tf.math.reduce_sum(same_layer_weights, axis=0)
        for same_layer_weights in zip(*scaled_weight_list)
    ]

In [None]:
def test_model(X_test, Y_test,  model, comm_round):
    '''Evaluate a model on held-out data and print its accuracy and loss.

    The model is expected to output class probabilities (SimpleMLP ends in a
    softmax layer), so the cross-entropy is computed with from_logits=False.
    Treating those probabilities as logits would apply softmax a second time
    and report an inflated loss.

    args:
        X_test: batch of test inputs
        Y_test: matching one-hot encoded labels
        model: model to evaluate; must provide predict()
        comm_round: index of the current communication round (only printed)
    return:
        (acc, loss): accuracy from accuracy_score and the categorical
        cross-entropy computed by tf.keras.losses.CategoricalCrossentropy
    '''
    cce = tf.keras.losses.CategoricalCrossentropy(from_logits=False)
    probs = model.predict(X_test)
    loss = cce(Y_test, probs)
    # accuracy_score expects (y_true, y_pred)
    acc = accuracy_score(tf.argmax(Y_test, axis=1), tf.argmax(probs, axis=1))
    print('comm_round: {} | global_acc: {:.3%} | global_loss: {}'.format(comm_round, acc, loss))
    return acc, loss


In [None]:
#initialize global model (784 input features, 10 output classes)
smlp_global = SimpleMLP()
global_model = smlp_global.build(784, 10)

#commence global training loop
for comm_round in range(comms_round):

    # get the global model's weights - will serve as the initial weights for all local models
    global_weights = global_model.get_weights()

    #initial list to collect local model weights after scaling
    scaled_local_weight_list = list()

    #visit the clients in a random order each round
    client_names = list(clients_batched.keys())
    random.shuffle(client_names)

    #loop through each client and create new local model
    for client in client_names:
        smlp_local = SimpleMLP()
        local_model = smlp_local.build(784, 10)

        # Every freshly built model gets its own optimizer instance, rebuilt
        # from the shared optimizer's configuration. Newer Keras optimizers
        # bind to the variables of the first model they train and refuse to
        # update another model's variables. A fresh instance also starts its
        # step counter, and therefore its learning-rate schedule, from zero
        # for each local run.
        local_optimizer = optimizer.__class__.from_config(optimizer.get_config())
        local_model.compile(loss=loss,
                            optimizer=local_optimizer,
                            metrics=metrics)

        #set local model weight to the weight of the global model
        local_model.set_weights(global_weights)

        #fit local model with client's data (one local epoch)
        local_model.fit(clients_batched[client], epochs=1, verbose=0)

        #scale the model weights by the client's share of the data and add to list
        scaling_factor = weight_scaling_factor(clients_batched, client)
        scaled_weights = scale_model_weights(local_model.get_weights(), scaling_factor)
        scaled_local_weight_list.append(scaled_weights)

        #clear the Keras session after each client's local training
        K.clear_session()

    #to get the average over all the local models, we simply take the sum of the scaled weights
    average_weights = sum_scaled_weights(scaled_local_weight_list)

    #update global model
    global_model.set_weights(average_weights)

    #test global model and print out metrics after each communications round;
    #the loop variables are named so that they do not overwrite X_test / y_test
    for (X_eval, Y_eval) in test_batched:
        global_acc, global_loss = test_model(X_eval, Y_eval, global_model, comm_round)


comm_round: 0 | global_acc: 88.775% | global_loss: 1.653571605682373
comm_round: 1 | global_acc: 90.883% | global_loss: 1.6104905605316162
comm_round: 2 | global_acc: 91.942% | global_loss: 1.5950510501861572
comm_round: 3 | global_acc: 92.417% | global_loss: 1.5853462219238281
comm_round: 4 | global_acc: 92.950% | global_loss: 1.5780524015426636
comm_round: 5 | global_acc: 93.192% | global_loss: 1.5724875926971436
comm_round: 6 | global_acc: 93.467% | global_loss: 1.568231225013733
comm_round: 7 | global_acc: 93.792% | global_loss: 1.5645662546157837
comm_round: 8 | global_acc: 93.917% | global_loss: 1.5615589618682861
comm_round: 9 | global_acc: 94.050% | global_loss: 1.5589648485183716
comm_round: 10 | global_acc: 94.308% | global_loss: 1.5565916299819946
comm_round: 11 | global_acc: 94.408% | global_loss: 1.5544527769088745
comm_round: 12 | global_acc: 94.542% | global_loss: 1.5525596141815186
comm_round: 13 | global_acc: 94.667% | global_loss: 1.5510752201080322
comm_round: 14 | g

### Task: Implement the above algorithm on the Fashion MNIST dataset ([dataset link](https://www.tensorflow.org/tutorials/keras/classification))

Note:
You just need to load the Fashion MNIST dataset and then run the code as it is to implement the model. Try changing `num_clients` and the number of local epochs to improve the accuracy.