In [1]:
# url = "https://usaupload.com/72Eb/mnist.zip?download_token=122af6eb9e746e2774d3422cba7775dd9ca2a43aaa3156f4c0b5043723a24823"
# name = url.split("/")[-1]
# fileName = name.split("?")[0]

# !wget $url
# !mv $name $fileName

In [2]:
# %%capture
# !unzip mnist.zip;

## Start

In [3]:
import numpy as np
import random
import tensorflow as tf
import cv2
import os
from imutils import paths
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder
from keras import models, layers

# from sklearn.preprocessing import LabelBinarizer
# from sklearn.metrics import accuracy_score

# from keras.models import Sequential
# from keras.layers import Activation
# from keras.layers import Dense
# from keras import optimizers
# from keras import backend as K

In [4]:
mySeed = 42
np.random.seed(mySeed)
random.seed(mySeed)
tf.random.set_seed(mySeed)
# torch.manual_seed(mySeed)

## Load Data

In [5]:
def load_mnist_bypath(lst_image_paths, verbose=-1):
    """ Expect to read images where each class is in a separate directory,
        For example: images of type 0 are in folder 0
    """

    lstData = list()
    lstLabel = list()

    for (i, imgPath) in enumerate(lst_image_paths):
        img = cv2.imread(imgPath, cv2.IMREAD_GRAYSCALE)
        img = img.flatten()
        img = img/255
        
        label = imgPath.split(os.path.sep)[-2]
        
        lstData.append(img)
        lstLabel.append(label)
        
        # show an update every `verbose` images
        if verbose > 0 and i > 0 and (i+1) % verbose == 0:
            print(f"[INFO] processed {i+1}/{len(lst_image_paths)}")
            break

    return lstData, lstLabel

In [6]:
img_path = "mnist/trainingSet/trainingSet"
# img_path = "/content/mnist/trainingSet/trainingSet"

# Generate a list of all images
lst_image_paths = list(paths.list_images(img_path))

lstData, lstLabel = load_mnist_bypath(lst_image_paths, verbose=10000)
data = np.array(lstData)
labels = np.array(lstLabel)

x_train, x_test, y_train, y_test = train_test_split(data, labels, test_size=0.2, random_state=42)

oneHotEncoder = OneHotEncoder(handle_unknown="ignore", sparse_output=False)
y_train = oneHotEncoder.fit_transform(y_train.reshape(-1, 1))
y_test = oneHotEncoder.transform(y_test.reshape(-1, 1))

print(f"Train x={x_train.shape}, y={y_train.shape}")
print(f"Test x={x_test.shape}, y={y_test.shape}")

[INFO] processed 10000/42000
Train x=(8000, 784), y=(8000, 3)
Test x=(2000, 784), y=(2000, 3)


## Training

In [7]:
def create_clients_with_data_assignment(image_list, label_list, num_clients=10, initial="client"):
    """ return: A dictionary with the customer id as the dictionary key and the value
                will be the data fragment - tuple of images and labels.
        args:
            image_list: a numpy array object with the images
            label_list: list of binarized labels (one-hot encoded)
            num_clients: number of customers (clients)
            initial: the prefix of the clients, e.g., client_1
     """

    # create list of client names
    client_names = [f"{initial}_{i+1}" for i in range(num_clients)]

    # shuffle the data
    data = list(zip(image_list, label_list))
    random.shuffle(data)

    # shard the data and split it for each customer
    size = len(data) // num_clients
    shards = [data[i: i+size]  for i in range(0, size*num_clients, size)]

    # Check if the fragment number is equal to the number of clients
    assert(len(shards) == len(client_names))

    return {client_names[i]: shards[i]  for i in range(len(client_names))}

In [8]:
def batch_data(data_shard, batch_size=32):
    """ Receives a piece of data (imgs, labels) from a client and creates a tensorflow Dataset object in it
        args:
            data_shard: data and labels that make up a customer's data shard
            batch_size: batch size
        return:
            data tensorflow Dataset object
    """
    #seperate shard into data and labels lists
    data, label = zip(*data_shard)
    dataset = tf.data.Dataset.from_tensor_slices((list(data), list(label)))
    return dataset.shuffle(len(label)).batch(batch_size)

In [9]:
def scale_model_weights(weight, scalar):
    """ Scale the model weights """
    
    weight_final = []
    for i in range(len(weight)):
        weight_final.append(weight[i] * scalar)

    return weight_final

In [10]:
class MLP:
    @staticmethod
    def build(shape, classes):
        model = models.Sequential([
            layers.Dense(100, activation="relu", input_shape=shape),
            layers.Dense(100, activation="relu"),
            layers.Dense(classes, activation="softmax"),
            ])

        return model

In [11]:
# Hyper parameters

num_clients = 2
batch_size = 32

client_select_rate = 1

comms_round = 1
client_epochs = 1
lr = 0.01
# optimizer = optimizers.Adam(learning_rate=lr, decay=lr/comms_round)
optimizer = tf.keras.optimizers.legacy.Adam(learning_rate=lr, decay=lr/comms_round)
# optimizer = "adam"
loss = "categorical_crossentropy"
metrics = ["accuracy"]

In [12]:
# Create clients and batched data

clients = create_clients_with_data_assignment(x_train, y_train, num_clients=num_clients, initial="client")

# Bached data with tensorflow data object
clients_batched = dict()
for (client_name, data) in clients.items():
    clients_batched[client_name] = batch_data(data, batch_size)

# Convert labels to tensorflow data object
test_batched = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(len(y_test))

In [22]:
mlp_global = MLP()
global_model = mlp_global.build(x_train.shape[1:], y_train.shape[-1])

# Global training loop collection
for comm_round in range(comms_round):

    # global model's weights - will serve as the initial weights for all local models
    global_weights = global_model.get_weights()

    # initial list to collect local model weights after scalling
    scaled_local_weight_list = list()

    # randomize client - using keys
    client_names = list(clients_batched.keys())
    random.shuffle(client_names)
    client_select = client_names[0: int(num_clients*client_select_rate)]

    # calculate total training data across selected clients
    # if all clients have a same length
    global_count = len(clients[client_select[0]]) * len(client_select)

    # loop through each client and create a new local model
    for client in client_select:
        mlp_local = MLP()
        local_model = mlp_local.build(x_train.shape[1:], y_train.shape[-1])
        local_model.compile(optimizer=optimizer, loss=loss, metrics=metrics)

        # set the weight of the local model from the weight of the global model
        local_model.set_weights(global_weights)

        # fit local model with client's data
        local_model.fit(clients_batched[client], epochs=client_epochs, verbose=0)

        # scale the model weights and added to the list
        local_count = len(clients[client])
        scaling_factor = local_count / global_count
        scaled_weights = scale_model_weights(local_model.get_weights(), scaling_factor)
        scaled_local_weight_list.append(scaled_weights)

        print("test")
        print(local_model.get_weights())
        
        break

#         # Check local accuracy
#         # acc_l, loss_l = check_local_loss(client, local_model)

#         # clear session to free memory after each communication round
#         K.clear_session()

#     # to get the average over all the local model, we simply take the sum of the scaled weights
#     average_weights = sum_scaled_weights(scaled_local_weight_list)

#     # update global model
#     global_model.set_weights(average_weights)

#     # test global model and print out metrics after each communications round
#     for (X_test, Y_test) in test_batched:
#         global_acc, global_loss = test_model(X_test, Y_test, global_model, comm_round)

test
[array([[-0.10273435, -0.04704752,  0.00808045, ...,  0.00588167,
        -0.00892175,  0.04568824],
       [-0.08046278, -0.08837207,  0.07953827, ..., -0.01924175,
        -0.0003405 ,  0.05012556],
       [-0.1063384 ,  0.07703272,  0.1499303 , ..., -0.06314262,
         0.07507354, -0.0030708 ],
       ...,
       [ 0.03414189,  0.04948349, -0.01683494, ..., -0.01540671,
        -0.00967073, -0.07834828],
       [ 0.01090276,  0.02770779, -0.04498865, ...,  0.01563124,
        -0.06267883,  0.01385543],
       [-0.00375625,  0.04419962, -0.04669923, ..., -0.076101  ,
         0.07397996,  0.02152166]], dtype=float32), array([ 0.03723201, -0.04715949, -0.026675  , -0.04260453,  0.03148137,
        0.04119245, -0.04683571, -0.01977016, -0.07218079, -0.01400003,
       -0.10265958,  0.01042398, -0.06595653, -0.01400225, -0.04816522,
       -0.05215771, -0.01774357, -0.03047089, -0.01886009,  0.00966556,
       -0.0549637 ,  0.03881105, -0.07077589, -0.00620759,  0.00185925,
     