In [4]:
import numpy as np
import random
import cv2
import os
from imutils import paths
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelBinarizer
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
from sklearn.metrics import accuracy_score

In [5]:
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import MaxPooling2D
from tensorflow.keras.layers import Activation
from tensorflow.keras.layers import Flatten
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import SGD
from tensorflow.keras import backend as K

In [6]:
def load(paths, verbose=-1):
    data = list()
    labels = list()
    for (i, imgpath) in enumerate(paths):
        im_gray = cv2.imread(imgpath, cv2.IMREAD_GRAYSCALE)
        image = np.array(im_gray).flatten()
        label = imgpath.split(os.path.sep)[-2]
        data.append(image/255)
        labels.append(label)
        if verbose > 0 and i > 0 and (i + 1) % verbose == 0:
            print(i + 1, len(paths))
    return data, labels

In [8]:
img_path = '/content/archive/trainingSet/trainingSet/'

image_paths = list(paths.list_images(img_path))
image_list, label_list = load(image_paths, verbose=7000)
lb = LabelBinarizer()
label_list = lb.fit_transform(label_list)
X_train, X_test, y_train, y_test = train_test_split(image_list, 
                                                    label_list, 
                                                    test_size=0.1, 
                                                    random_state=42)

7000 42000
14000 42000
21000 42000
28000 42000
35000 42000
42000 42000


In [9]:
def create_clients(image_list, label_list, num_clients=10, initial='clients'):
    client_names = ['{}_{}'.format(initial, i+1) for i in range(num_clients)]
    data = list(zip(image_list, label_list))
    random.shuffle(data)
    size = len(data)//num_clients
    shards = [data[i:i + size] for i in range(0, size*num_clients, size)]
    assert(len(shards) == len(client_names))

    return {client_names[i] : shards[i] for i in range(len(client_names))} 

In [10]:
clients = create_clients(X_train, y_train, num_clients=10, initial='client')

In [11]:
def batch_data(data_shard, bs=32):
    data, label = zip(*data_shard)
    dataset = tf.data.Dataset.from_tensor_slices((list(data), list(label)))
    return dataset.shuffle(len(label)).batch(bs)

In [12]:
clients_batched = dict()
for (client_name, data) in clients.items():
    clients_batched[client_name] = batch_data(data)
test_batched = tf.data.Dataset.from_tensor_slices((X_test, y_test)).batch(len(y_test))

In [13]:
class SimpleMLP:
    @staticmethod
    def build(shape, classes):
        model = Sequential()
        model.add(Dense(200, input_shape=(shape,)))
        model.add(Activation("relu"))
        model.add(Dense(200))
        model.add(Activation("relu"))
        model.add(Dense(classes))
        model.add(Activation("softmax"))
        return model

In [15]:
lr = 0.01 
comms_round = 100
loss='categorical_crossentropy'
metrics = ['accuracy']
optimizer = SGD(learning_rate=lr, decay=lr / comms_round, momentum=0.9)

In [21]:
def weight_scalling_factor(clients_trn_data, client_name):
    client_names = list(clients_trn_data.keys())
    bs = list(clients_trn_data[client_name])[0][0].shape[0]
    global_count = sum([tf.data.experimental.cardinality(clients_trn_data[client_name]).numpy() for client_name in client_names])*bs
    local_count = tf.data.experimental.cardinality(clients_trn_data[client_name]).numpy()*bs
    return local_count/global_count


def sum_scaled_weights(scaled_weight_list):
    avg_grad = list()
    for grad_list_tuple in zip(*scaled_weight_list):
        layer_mean = tf.math.reduce_sum(grad_list_tuple, axis=0)
        avg_grad.append(layer_mean)
    return avg_grad

In [17]:
def scale_model_weights(weight, scalar):
    weight_final = []
    steps = len(weight)
    for i in range(steps):
        weight_final.append(scalar * weight[i])
    return weight_final

In [18]:
def test_model(X_test, Y_test,  model, comm_round):
    cce = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
    #logits = model.predict(X_test, batch_size=100)
    logits = model.predict(X_test)
    loss = cce(Y_test, logits)
    acc = accuracy_score(tf.argmax(logits, axis=1), tf.argmax(Y_test, axis=1))
    print('comm_round: {} | global_acc: {:.3%} | global_loss: {}'.format(comm_round, acc, loss))
    return acc, loss

In [22]:
smlp_global = SimpleMLP()
global_model = smlp_global.build(784, 10)
for comm_round in range(comms_round):
    global_weights = global_model.get_weights()
    scaled_local_weight_list = list()
    client_names= list(clients_batched.keys())
    random.shuffle(client_names)
    for client in client_names:
        smlp_local = SimpleMLP()
        local_model = smlp_local.build(784, 10)
        local_model.compile(loss=loss, 
                      optimizer=optimizer, 
                      metrics=metrics)
        local_model.set_weights(global_weights)
        
        local_model.fit(clients_batched[client], epochs=1, verbose=0)
        scaling_factor = weight_scalling_factor(clients_batched, client)
        scaled_weights = scale_model_weights(local_model.get_weights(), scaling_factor)
        scaled_local_weight_list.append(scaled_weights)
        K.clear_session()
    average_weights = sum_scaled_weights(scaled_local_weight_list)

    global_model.set_weights(average_weights)
    for(X_test, Y_test) in test_batched:
        global_acc, global_loss = test_model(X_test, Y_test, global_model, comm_round)

comm_round: 0 | global_acc: 87.905% | global_loss: 1.6915663480758667
comm_round: 1 | global_acc: 90.714% | global_loss: 1.6246594190597534
comm_round: 2 | global_acc: 91.881% | global_loss: 1.60403311252594
comm_round: 3 | global_acc: 92.810% | global_loss: 1.5902385711669922
comm_round: 4 | global_acc: 93.357% | global_loss: 1.58234441280365
comm_round: 5 | global_acc: 93.619% | global_loss: 1.5753751993179321
comm_round: 6 | global_acc: 93.952% | global_loss: 1.5676215887069702
comm_round: 7 | global_acc: 94.143% | global_loss: 1.563452959060669
comm_round: 8 | global_acc: 94.381% | global_loss: 1.5594509840011597
comm_round: 9 | global_acc: 94.571% | global_loss: 1.555127501487732
comm_round: 10 | global_acc: 94.952% | global_loss: 1.551578402519226
comm_round: 11 | global_acc: 95.000% | global_loss: 1.5489455461502075
comm_round: 12 | global_acc: 95.024% | global_loss: 1.546693205833435
comm_round: 13 | global_acc: 95.143% | global_loss: 1.544580340385437
comm_round: 14 | global_a