In [1]:
!pip install sklearn

You should consider upgrading via the '/usr/bin/python3 -m pip install --upgrade pip' command.[0m


In [13]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.datasets import fashion_mnist

from sklearn.preprocessing import LabelBinarizer
from sklearn.metrics import accuracy_score

In [14]:
(X_train, y_train), (X_test, y_test) = fashion_mnist.load_data()

In [15]:
X_train = X_train / 255.0
X_test = X_test / 255.0

# altera de uma matriz para um vetor
X_train = X_train.reshape(-1, 28*28)
X_test = X_test.reshape(-1, 28*28)
X_train.shape, X_test.shape

((60000, 784), (10000, 784))

In [16]:
lb = LabelBinarizer()
y_train = lb.fit_transform(y_train)
y_test = lb.fit_transform(y_test)

In [17]:
y_train.shape

(60000, 10)

In [18]:
def create_clients(X, y, num_clients=10, initial='clients'):
    data = list(zip(X_train, y_train))
    np.random.shuffle(data)
    
    size = len(data)//num_clients
    shards = [data[i:i + size] for i in range(0, size*num_clients, size)]

    return {'{}_{}'.format(initial, i+1): shards[i] for i in range(len(shards))} 

In [19]:
clients = create_clients(X_train, y_train)

In [20]:
len(clients['clients_1']), len(clients['clients_1'][0]), len(clients['clients_1'][0][0])

(6000, 2, 784)

In [21]:
class SimpleMLP:
    def build(self, shape=784, n_classes=10):
        model = tf.keras.models.Sequential()
        model.add(tf.keras.layers.Dense(units=256, activation='relu', input_shape=(shape, )))
        model.add(tf.keras.layers.Dropout(0.2))
        model.add(tf.keras.layers.Dense(units=128, activation='relu'))
        model.add(tf.keras.layers.Dropout(0.2))
        model.add(tf.keras.layers.Dense(units=n_classes, activation='softmax'))
        return model

In [22]:
lr = 0.01 
comms_round = 100
loss='categorical_crossentropy'
metrics = ['accuracy']
optimizer = tf.keras.optimizers.SGD(lr=lr, decay= lr/comms_round, momentum=0.9)    

In [23]:
def batch_data(data_shard, bs=32):
    '''Takes in a clients data shard and create a tfds object off it
    args:
        shard: a data, label constituting a client's data shard
        bs:batch size
    return:
        tfds object'''
    #seperate shard into data and labels lists
    data, label = zip(*data_shard)
    dataset = tf.data.Dataset.from_tensor_slices((list(data), list(label)))
    return dataset.shuffle(len(label)).batch(bs)

In [24]:
clients_batched = dict()
for (client_name, data) in clients.items():
    clients_batched[client_name] = batch_data(data)
    
#process and batch the test set  
test_batched = tf.data.Dataset.from_tensor_slices((X_test, y_test)).batch(len(y_test))

In [41]:
def weight_scalling_factor(clients_trn_data, client_name):
    client_names = list(clients_trn_data.keys())
    #get the bs
    bs = list(clients_trn_data[client_name])[0][0].shape[0]
    #first calculate the total training data points across clinets
    global_count = sum([tf.data.experimental.cardinality(clients_trn_data[client_name]).numpy() for client_name in client_names])*bs
    # get the total number of data points held by a client
    local_count = tf.data.experimental.cardinality(clients_trn_data[client_name]).numpy()*bs
    return local_count/global_count


def scale_model_weights(weight, scalar):
    '''function for scaling a models weights'''
    weight_final = []
    steps = len(weight)
    for i in range(steps):
        weight_final.append(scalar * weight[i])
    return weight_final



def sum_scaled_weights(scaled_weight_list):
    '''Return the sum of the listed scaled weights. The is equivalent to scaled avg of the weights'''
    avg_grad = list()
    #get the average grad accross all client gradients
    for grad_list_tuple in zip(*scaled_weight_list):
        layer_mean = tf.math.reduce_sum(grad_list_tuple, axis=0)
        avg_grad.append(layer_mean)
        
    return avg_grad

def test_model(X_test, Y_test,  model, comm_round):
    cce = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
    #logits = model.predict(X_test, batch_size=100)
    logits = model.predict(X_test)
    loss = cce(Y_test, logits)
    acc = accuracy_score(tf.argmax(logits, axis=1), tf.argmax(Y_test, axis=1))
    print('comm_round: {} | global_acc: {:.3%} | global_loss: {}'.format(comm_round, acc, loss))
    return acc, loss

In [96]:
#initialize global model
smlp_global = SimpleMLP()
global_model = smlp_global.build(784, 10)
#initial list to collect local model weights after scalling
scaled_local_weight_list = list()

In [76]:
len(global_model.get_weights()), global_model.get_weights()

(6,
 [array([[ 0.04492731, -0.0377527 ,  0.02601371, ...,  0.01681063,
          -0.03858424, -0.07051905],
         [ 0.0234365 , -0.00465466,  0.04922183, ...,  0.03234618,
          -0.02290316,  0.03442608],
         [-0.01247863, -0.01517874,  0.00987884, ..., -0.01090468,
          -0.03614107, -0.02121803],
         ...,
         [ 0.05744471, -0.04410993, -0.07162853, ...,  0.0328931 ,
           0.00021853,  0.00152937],
         [ 0.02754694, -0.04638512,  0.07462248, ..., -0.07387449,
           0.02322656, -0.02798993],
         [-0.0335606 ,  0.05307056,  0.00978773, ...,  0.00063949,
          -0.00276811,  0.05166396]], dtype=float32),
  array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 

In [116]:
client = 'clients_1'

global_weights = global_model.get_weights()

smlp_local = SimpleMLP()
local_model = smlp_local.build(784, 10)
local_model.compile(loss=loss, 
              optimizer=optimizer, 
              metrics=metrics)

#set local model weight to the weight of the global model
local_model.set_weights(global_weights)

#fit local model with client's data
local_model.fit(clients_batched[client], epochs=1, verbose=0)

<tensorflow.python.keras.callbacks.History at 0x7f8274472780>

In [117]:
len(local_model.get_weights()), local_model.get_weights()

(6,
 [array([[ 0.02202396, -0.0673397 ,  0.01472675, ...,  0.07201938,
           0.011646  , -0.01588138],
         [ 0.02838603, -0.02794591, -0.04186536, ..., -0.06186739,
           0.0335608 ,  0.00673139],
         [-0.07395984, -0.00890434,  0.05034409, ...,  0.01630476,
           0.05079816, -0.02315992],
         ...,
         [-0.06528834,  0.05466234,  0.01821261, ..., -0.02343147,
          -0.06722378, -0.00811572],
         [ 0.04238575, -0.04846166, -0.00087806, ...,  0.0162291 ,
          -0.03502939,  0.05050121],
         [ 0.07224447,  0.07412121,  0.04664997, ..., -0.07057964,
           0.06417582, -0.04777728]], dtype=float32),
  array([-5.66468854e-03, -7.04928680e-05, -1.65084917e-02,  7.52111385e-03,
         -8.69235769e-03,  6.48062816e-03,  9.33363801e-04,  1.50237959e-02,
         -2.06794147e-03, -1.92101346e-04, -6.69149915e-03,  1.28754713e-02,
          3.67250363e-03,  9.68204986e-04,  2.58780969e-03,  1.60559709e-03,
          3.93112144e-03,  2.0117

In [78]:
scaled_weights = scale_model_weights(local_model.get_weights(), scaling_factor)
len(scaled_weights), scaled_weights

(6,
 [array([[ 4.4937083e-03, -3.7752700e-03,  2.6014070e-03, ...,
           1.6806001e-03, -3.8574692e-03, -7.0519052e-03],
         [ 2.3445585e-03, -4.6546609e-04,  4.9220785e-03, ...,
           3.2343380e-03, -2.2896016e-03,  3.4430702e-03],
         [-1.2509534e-03, -1.5177681e-03,  9.8988530e-04, ...,
          -1.0903846e-03, -3.6127183e-03, -2.1178494e-03],
         ...,
         [ 5.7146423e-03, -4.4721742e-03, -7.2167851e-03, ...,
           3.2227102e-03, -2.2313397e-04,  8.5649801e-05],
         [ 2.7495711e-03, -4.6563563e-03,  7.4629569e-03, ...,
          -7.4194120e-03,  2.2505997e-03, -2.8252930e-03],
         [-3.3505410e-03,  5.3045028e-03,  9.7986904e-04, ...,
           6.1521059e-05, -2.9147349e-04,  5.1598852e-03]], dtype=float32),
  array([-9.74121445e-04,  2.28618970e-03,  4.50373208e-03, -2.84800283e-03,
         -1.25335611e-03,  4.89986083e-03,  6.64546434e-03,  4.32040868e-03,
         -1.82566268e-03,  9.78314434e-04, -1.81013125e-03,  9.91427369e-05,
  

In [118]:
scaled_local_weight_list.append(scaled_weights)
scaled_local_weight_list

[[array([[ 4.4938847e-03, -3.7752700e-03,  2.6006282e-03, ...,
           1.6824458e-03, -3.8572466e-03, -7.0530167e-03],
         [ 2.3447538e-03, -4.6547022e-04,  4.9214177e-03, ...,
           3.2362861e-03, -2.2889711e-03,  3.4439038e-03],
         [-1.2484844e-03, -1.5200365e-03,  9.8493916e-04, ...,
          -1.0867817e-03, -3.6084354e-03, -2.1195828e-03],
         ...,
         [ 5.4660020e-03, -4.3894388e-03, -7.2130519e-03, ...,
           2.8649487e-03, -5.1749096e-04,  4.7402389e-04],
         [ 2.5357793e-03, -4.5854552e-03,  7.4377134e-03, ...,
          -7.5678853e-03,  2.1743998e-03, -2.5743032e-03],
         [-3.3758432e-03,  5.3068148e-03,  9.6991286e-04, ...,
           4.8039219e-05, -3.0098375e-04,  5.1878202e-03]], dtype=float32),
  array([-8.1193317e-03,  1.1999806e-02,  1.4020452e-02, -8.6388560e-03,
          7.9238545e-03,  2.1863133e-02,  1.9013379e-02,  8.5544208e-04,
         -5.1560095e-03,  1.4535682e-02, -1.8285362e-03,  2.2878146e-04,
          1.268049

In [None]:
#initialize global model
smlp_global = SimpleMLP()
global_model = smlp_global.build(784, 10)
#initial list to collect local model weights after scalling
scaled_local_weight_list = list()
second_list = list()

#loop through each client and create new local model
for client in client_names:
    smlp_local = SimpleMLP()
    local_model = smlp_local.build(784, 10)
    local_model.compile(loss=loss, 
                  optimizer=optimizer, 
                  metrics=metrics)

    #set local model weight to the weight of the global model
    local_model.set_weights(global_weights)

    #fit local model with client's data
    local_model.fit(clients_batched[client], epochs=1, verbose=0)
    
    second_list.append(local_model.get_weights())

    #scale the model weights and add to list
    scaling_factor = weight_scalling_factor(clients_batched, client)
    scaled_weights = scale_model_weights(local_model.get_weights(), scaling_factor)
    scaled_local_weight_list.append(scaled_weights)

In [None]:
average_weights = sum_scaled_weights(scaled_local_weight_list)
average_weights

In [None]:
second_list

In [None]:
global_weights = 0.1 * np.array(second_list[0])
for i in range(1,10):
    scaled_local_weight_list[i] = 0.1 * np.array(second_list[i])
    global_weights += np.array(second_list[i])

In [None]:
global_weights

In [None]:
average_weights

In [15]:
#commence global training loop
for comm_round in range(comms_round):
            
    # get the global model's weights - will serve as the initial weights for all local models
    global_weights = global_model.get_weights()
    
    #initial list to collect local model weights after scalling
    scaled_local_weight_list = list()

    #randomize client data - using keys
    client_names= list(clients_batched.keys())
    np.random.shuffle(client_names)
    
    #loop through each client and create new local model
    for client in client_names:
        smlp_local = SimpleMLP()
        local_model = smlp_local.build(784, 10)
        local_model.compile(loss=loss, 
                      optimizer=optimizer, 
                      metrics=metrics)
        
        #set local model weight to the weight of the global model
        local_model.set_weights(global_weights)
        
        #fit local model with client's data
        local_model.fit(clients_batched[client], epochs=1, verbose=0)
        
        #scale the model weights and add to list
        scaling_factor = weight_scalling_factor(clients_batched, client)
        scaled_weights = scale_model_weights(local_model.get_weights(), scaling_factor)
        scaled_local_weight_list.append(scaled_weights)
        
    #to get the average over all the local model, we simply take the sum of the scaled weights
    average_weights = sum_scaled_weights(scaled_local_weight_list)
    
    #update global model 
    global_model.set_weights(average_weights)
    
    #test global model and print out metrics after each communications round
    for(X_test, Y_test) in test_batched:
        global_acc, global_loss = test_model(X_test, Y_test, global_model, comm_round)

comm_round: 0 | global_acc: 76.620% | global_loss: 1.7872625589370728
comm_round: 1 | global_acc: 79.810% | global_loss: 1.7324069738388062
comm_round: 2 | global_acc: 81.900% | global_loss: 1.7141972780227661
comm_round: 3 | global_acc: 82.590% | global_loss: 1.702681303024292
comm_round: 4 | global_acc: 83.220% | global_loss: 1.6902883052825928
comm_round: 5 | global_acc: 83.520% | global_loss: 1.6865720748901367
comm_round: 6 | global_acc: 84.060% | global_loss: 1.6780565977096558
comm_round: 7 | global_acc: 84.200% | global_loss: 1.676525354385376
comm_round: 8 | global_acc: 84.360% | global_loss: 1.6692383289337158
comm_round: 9 | global_acc: 84.620% | global_loss: 1.6687312126159668
comm_round: 10 | global_acc: 84.770% | global_loss: 1.665055274963379
comm_round: 11 | global_acc: 84.820% | global_loss: 1.6633487939834595
comm_round: 12 | global_acc: 85.110% | global_loss: 1.6609630584716797
comm_round: 13 | global_acc: 85.080% | global_loss: 1.6586968898773193
comm_round: 14 | gl