# FLEX-clash: how to use the built-in defenses

In this notebook, we show how to use the built-in defenses of the FLEX-clash library. These defenses are implemented as aggregation functions, so using them requires only minimal changes to our FLEX experiments.
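To see why a defense can be dropped in as an aggregation function, it helps to look at what a plain aggregator does. The following is a minimal, framework-independent sketch, not FLEX's implementation. It assumes each client's weights are a list of NumPy arrays (the format returned by Keras `model.get_weights()`), and `fedavg` is a hypothetical helper that computes an unweighted layer-wise average.

```python
import numpy as np

def fedavg(client_weights):
    """Hypothetical helper: unweighted layer-wise mean of client weights.

    Each element of `client_weights` is one client's list of numpy arrays,
    with the same shapes and order as Keras `model.get_weights()`.
    """
    # zip(*...) groups the i-th layer of every client together.
    return [np.mean(np.stack(layers), axis=0) for layers in zip(*client_weights)]

# Three hypothetical clients, each with a single 2x2 "layer".
honest = [[np.full((2, 2), 1.0)], [np.full((2, 2), 1.2)]]
poisoned = [[np.full((2, 2), 100.0)]]

print(fedavg(honest)[0])             # average of the honest clients only
print(fedavg(honest + poisoned)[0])  # one extreme client drags the average far away
```

A robust aggregator keeps the same input (a list of client weight lists) and output (one weight list), which is why swapping it in does not require touching the rest of the training loop.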

### Setting up the experiment

Before defending ourselves, we need to set up the experiment and implement an attack. We will train a model on the digits split of the `federated_emnist` dataset and then attack it using the `@model_poisoner` decorator from the `FLEX-clash` library. For more information on this attack, please refer to the `poison_models` notebook.


In [27]:
from flex.datasets import load
import tensorflow as tf

flex_dataset, test_dataset = load("federated_emnist", return_test=True, split="digits")

flex_dataset["server"] = test_dataset

[36m[sultan]: md5 -q ./emnist-digits.mat;[0m
[01;31m[sultan]: Unable to run 'md5 -q ./emnist-digits.mat;'[0m
[01;31m[sultan]: --{ TRACEBACK }----------------------------------------------------------------------------------------------------[0m
[01;31m[sultan]: | NoneType: None[0m
[01;31m[sultan]: | [0m
[01;31m[sultan]: -------------------------------------------------------------------------------------------------------------------[0m
[01;31m[sultan]: --{ STDERR }-------------------------------------------------------------------------------------------------------[0m
[01;31m[sultan]: | /bin/sh: 1: md5: not found[0m
[01;31m[sultan]: -------------------------------------------------------------------------------------------------------------------[0m
[33m[sultan]: The following are additional information that can be used to debug this exception.[0m
[33m[sultan]: The following is the context used to run:[0m
[33m[sultan]: 	 - cwd: None[0m
[33m[sultan]: 	 - sudo:

In [28]:
from flex.pool import FlexPool
from flex.pool.primitives_tf import init_server_model_tf


# Defining the model
def get_model():
    model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
    ])

    model.compile(optimizer='adam',
                loss='sparse_categorical_crossentropy',
                metrics=['accuracy'])
    return model

flex_pool = FlexPool.client_server_pool(fed_dataset=flex_dataset, init_func=init_server_model_tf, model=get_model())

clients = flex_pool.clients
server = flex_pool.servers
aggregator = flex_pool.aggregators

print(f"Number of nodes in the pool {len(flex_pool)}: {len(server)} server plus {len(clients)} clients. The server is also an aggregator")

Number of nodes in the pool 3580: 1 server plus 3579 clients. The server is also an aggregator


In [29]:
from flex.pool.primitives_tf import deploy_server_model_tf

# Select clients
clients_per_round=20
selected_clients_pool = flex_pool.clients.select(clients_per_round)

server.map(deploy_server_model_tf, selected_clients_pool)

In [30]:
from flex.pool.primitives_tf import train_tf

selected_clients_pool.map(train_tf, batch_size=512, epochs=1, verbose=False)

Now, just before the clients send their weights to the aggregator/server, we will randomize the weights of one randomly selected client.

In [31]:
from flexclash.model import model_poisoner
from flex.model import FlexModel
import numpy as np

randomized_weights_client = selected_clients_pool.select(1)

@model_poisoner
def weight_randomizer(client_model: FlexModel):
    rand_weights = [np.random.randn(*w.shape) for w in client_model["model"].get_weights()]
    client_model["model"].set_weights(rand_weights)
    return client_model

randomized_weights_client.map(weight_randomizer)


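Why is a single randomized client a problem? `np.random.randn` draws every weight from a standard normal distribution, so the poisoned update lands far from the honest ones. The sketch below illustrates this with synthetic NumPy weights instead of the trained Keras model; the layer shapes, noise scales, and the `flatten` helper are illustrative assumptions, not part of FLEX.

```python
import numpy as np

rng = np.random.default_rng(0)

def flatten(weights):
    """Concatenate a list of weight arrays into a single 1-D vector."""
    return np.concatenate([w.ravel() for w in weights])

# Hypothetical layer shapes (smaller than the notebook's model, for speed).
shapes = [(784, 16), (16,), (16, 10), (10,)]
global_weights = [rng.normal(0.0, 0.05, s) for s in shapes]

# Honest clients: small local updates around the shared global model.
honest = [[w + rng.normal(0.0, 0.01, w.shape) for w in global_weights] for _ in range(5)]
# Poisoned client: same idea as `weight_randomizer`, standard normal weights.
poisoned = [rng.standard_normal(w.shape) for w in global_weights]

ref = flatten(global_weights)
honest_dist = [np.linalg.norm(flatten(c) - ref) for c in honest]
poisoned_dist = np.linalg.norm(flatten(poisoned) - ref)
print("honest distances:", np.round(honest_dist, 2))
print("poisoned distance:", round(float(poisoned_dist), 2))
```

Because the poisoned vector is an outlier in Euclidean distance, distance-based defenses such as Multi-Krum can identify and discard it.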

Then, we collect the weights from the clients:

In [32]:
from flex.pool.primitives_tf import collect_clients_weights_tf

aggregator.map(collect_clients_weights_tf, selected_clients_pool)

### Defending ourselves

Now that our clients have trained their models and sent their weights to the server, we can use the built-in defenses to protect our model from the poisoned weights. In our case, we will use `multikrum` as a defense. This defense is implemented as an aggregation function in the `FLEX-clash` library, so we only need to change the aggregation function in our experiment.

Note that `FLEX-clash` implements other defenses besides multikrum; for the full list, refer to the `FLEX-clash` documentation.
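For intuition about what `multikrum` does, here is a minimal NumPy sketch of the textbook Multi-Krum rule: each client is scored by the sum of squared distances to its `n - f - 2` nearest neighbours, and the `m` lowest-scoring clients are averaged. This is not `flexclash`'s code; the function name `multikrum_sketch` and its parameters `f` (assumed number of malicious clients) and `m` (number of clients kept) are our own, and the library's `multikrum` may use different argument names and defaults.

```python
import numpy as np

def multikrum_sketch(client_weights, f, m):
    """Hypothetical Multi-Krum over a list of client weight lists.

    client_weights: list of n clients, each a list of numpy arrays.
    f: assumed number of malicious clients (Krum requires n > 2f + 2).
    m: number of lowest-scoring clients to average.
    """
    n = len(client_weights)
    if n <= 2 * f + 2:
        raise ValueError("Krum needs n > 2f + 2")
    # One flat vector per client.
    flat = np.stack([np.concatenate([w.ravel() for w in c]) for c in client_weights])
    # Pairwise squared Euclidean distances (O(n^2 d) memory, fine for a sketch).
    diffs = flat[:, None, :] - flat[None, :, :]
    sq_dists = np.sum(diffs ** 2, axis=-1)
    k = n - f - 2
    scores = []
    for i in range(n):
        others = np.delete(sq_dists[i], i)  # drop the distance to itself
        scores.append(np.sum(np.sort(others)[:k]))
    selected = np.argsort(scores)[:m]
    # Average the selected clients layer by layer.
    aggregated = [
        np.mean(np.stack([client_weights[j][layer] for j in selected]), axis=0)
        for layer in range(len(client_weights[0]))
    ]
    return aggregated, selected

# Usage: 19 honest clients near 1.0 and one randomized client (index 19).
rng = np.random.default_rng(0)
honest = [[np.ones((3, 3)) + rng.normal(0.0, 0.01, (3, 3))] for _ in range(19)]
poisoned = [[10 * rng.standard_normal((3, 3))]]
agg, selected = multikrum_sketch(honest + poisoned, f=1, m=10)
print("poisoned client selected:", 19 in selected)
```

With one outlier among 20 clients, the randomized client gets by far the largest score and is left out of the average, so the aggregated weights stay close to the honest ones.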

In [33]:
from flexclash.pool.defences import multikrum

aggregator.map(multikrum)

### Wrapping up: the whole training round

After defending ourselves, we can proceed with the training round as usual: we set the aggregated weights on the server's model, which is then deployed to the clients selected in the next round.

In [34]:
from flex.pool.primitives_tf import set_aggregated_weights_tf
aggregator.map(set_aggregated_weights_tf, server)

Putting everything together, we have the following code:

In [39]:
from flex.pool.primitives_tf import evaluate_model_tf

def train_model(n_rounds=10):
    for _ in range(n_rounds):
        # Sample clients and send them the current global model
        selected_clients_pool = flex_pool.clients.select(clients_per_round)
        server.map(deploy_server_model_tf, selected_clients_pool)
        # Local training on each selected client
        selected_clients_pool.map(train_tf, batch_size=512, epochs=1, verbose=False)
        # Attack: randomize the weights of one of the selected clients
        randomized_weights_client = selected_clients_pool.select(1)
        randomized_weights_client.map(weight_randomizer)
        # Collect the client weights and aggregate them with the multikrum defense
        aggregator.map(collect_clients_weights_tf, selected_clients_pool)
        aggregator.map(multikrum)
        aggregator.map(set_aggregated_weights_tf, server)
        # Evaluate the model
        [(loss, acc)] = server.map(evaluate_model_tf)
        print(f"Test loss: {loss}, Test accuracy: {acc}")

In [40]:
train_model(10)

Test loss: 72.41793060302734, Test accuracy: 0.256850004196167
Test loss: 66.6341323852539, Test accuracy: 0.3015249967575073
Test loss: 52.304656982421875, Test accuracy: 0.3829500079154968
Test loss: 47.13664245605469, Test accuracy: 0.40869998931884766
Test loss: 41.96918869018555, Test accuracy: 0.45809999108314514
Test loss: 47.48677062988281, Test accuracy: 0.525825023651123
Test loss: 34.5077819