# FLEX-clash: how to use the built-in defenses

In this notebook, we will show how to use the built-in defenses of the FLEX-clash library. These defenses are implemented as aggregation functions, so using them requires only minimal changes to our FLEX experiments.
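The idea that a defense is "just an aggregation function" can be illustrated without FLEX at all. The following is a minimal NumPy sketch, not FLEX or FLEX-clash code: `fedavg` and `aggregate` are hypothetical names, and each client update is assumed to be a list of per-layer arrays, like the ones Keras `get_weights()` returns. Because the aggregation rule is an ordinary callable, a robust rule with the same signature could replace `fedavg` without touching the rest of the loop.

```python
import numpy as np


def fedavg(client_weights):
    """Layer-wise unweighted mean of the clients' weights.

    client_weights: list with one entry per client, each entry being a
    list of np.ndarray (one array per layer, same shapes for every client).
    """
    # zip(*...) groups layer i of every client together, then we average them.
    return [np.mean(np.stack(layers), axis=0) for layers in zip(*client_weights)]


def aggregate(client_weights, agg_fn=fedavg):
    # Hypothetical helper: the aggregation rule is a pluggable callable, so a
    # defense only needs to accept and return the same structures as fedavg.
    return agg_fn(client_weights)


# Two toy clients, each with a 2x2 "kernel" and a length-2 "bias".
clients = [
    [np.ones((2, 2)), np.zeros(2)],
    [3 * np.ones((2, 2)), np.ones(2)],
]
global_weights = aggregate(clients)
print(global_weights[0])  # every entry is 2.0
print(global_weights[1])  # every entry is 0.5
```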

### Setting up the experiment

Before defending ourselves, we need to set up the experiment and implement an attack. We will train a model on the `federated_emnist` dataset (digits split) and then attack it using the `@model_poisoner` decorator from the `FLEX-clash` library. For more information on this attack, please refer to the `poison_models` notebook.


In [7]:
from flex.datasets import load
import tensorflow as tf

flex_dataset = load("federated_emnist", return_test=False, split="digits")

[36m[sultan]: md5 -q ./emnist-digits.mat;[0m
[01;31m[sultan]: Unable to run 'md5 -q ./emnist-digits.mat;'[0m
[01;31m[sultan]: --{ TRACEBACK }----------------------------------------------------------------------------------------------------[0m
[01;31m[sultan]: | NoneType: None[0m
[01;31m[sultan]: | [0m
[01;31m[sultan]: -------------------------------------------------------------------------------------------------------------------[0m
[01;31m[sultan]: --{ STDERR }-------------------------------------------------------------------------------------------------------[0m
[01;31m[sultan]: | /bin/sh: 1: md5: not found[0m
[01;31m[sultan]: -------------------------------------------------------------------------------------------------------------------[0m
[33m[sultan]: The following are additional information that can be used to debug this exception.[0m
[33m[sultan]: The following is the context used to run:[0m
[33m[sultan]: 	 - cwd: None[0m
[33m[sultan]: 	 - sudo:

In [8]:
from flex.pool import FlexPool
from flex.pool.primitives_tf import init_server_model_tf


# Defining the model
def get_model():
    model = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dense(10, activation='softmax')
    ])

    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model

flex_pool = FlexPool.client_server_pool(fed_dataset=flex_dataset, init_func=init_server_model_tf, model=get_model())

clients = flex_pool.clients
server = flex_pool.servers
aggregator = flex_pool.aggregators

print(f"Number of nodes in the pool {len(flex_pool)}: {len(server)} server plus {len(clients)} clients. The server is also an aggregator")

Number of nodes in the pool 3580: 1 server plus 3579 clients. The server is also an aggregator
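Each client update that reaches the aggregator contains every weight of this network, so it helps to know how large an update is. As a worked check in plain arithmetic (independent of TensorFlow): `Flatten` has no weights, and a `Dense` layer with `n_in` inputs and `n_out` units holds `n_in * n_out` kernel entries plus `n_out` biases. The total should agree with `get_model().count_params()`.

```python
def dense_params(n_in: int, n_out: int) -> int:
    # Kernel of shape (n_in, n_out) plus one bias per unit.
    return n_in * n_out + n_out


flatten_out = 28 * 28                     # Flatten turns a 28x28 image into 784 features
hidden = dense_params(flatten_out, 128)   # 784*128 + 128 = 100480
output = dense_params(128, 10)            # 128*10 + 10 = 1290
total = hidden + output

print(total)  # 101770
```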


In [9]:
from flex.pool.primitives_tf import deploy_server_model_tf

# Select the clients that take part in this round
clients_per_round=20
selected_clients_pool = flex_pool.clients.select(clients_per_round)

server.map(deploy_server_model_tf, selected_clients_pool)

In [10]:
from flex.pool.primitives_tf import train_tf

selected_clients_pool.map(train_tf, batch_size=512, epochs=1, verbose=False)

Now, just before the clients send their weights to the aggregator/server, we will replace the weights of one randomly chosen client with random values.

In [11]:
from flexclash.model import model_poisoner
from flex.model import FlexModel
import numpy as np

randomized_weights_client = selected_clients_pool.select(1)

@model_poisoner
def weight_randomizer(client_model: FlexModel):
    rand_weights = [np.random.randn(*w.shape) for w in client_model["model"].get_weights()]
    client_model["model"].set_weights(rand_weights)
    return client_model

randomized_weights_client.map(weight_randomizer)

INFO:tensorflow:Assets written to: ram://04293d8e-fca8-46b5-aeaa-f6c6a611f361/assets
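Why does this attack stand out? Below is a minimal NumPy sketch on synthetic data (an assumption, not the notebook's real weights). Honest updates are modeled as a shared vector plus small noise, and the poisoned update is replaced by standard-normal samples, mirroring `np.random.randn` in `weight_randomizer`. The poisoned vector lands far from every honest one in Euclidean distance, which is the property distance-based defenses rely on.

```python
import numpy as np

rng = np.random.default_rng(0)
dim = 1_000      # synthetic update size (assumption, not the real model size)
n_honest = 5

# Synthetic honest updates: a common direction plus small client-specific noise.
base = rng.normal(scale=0.05, size=dim)
honest = [base + rng.normal(scale=0.01, size=dim) for _ in range(n_honest)]

# Poisoned update: every entry replaced by a standard-normal sample,
# as weight_randomizer does layer by layer.
poisoned = rng.standard_normal(dim)

# Largest distance between two honest updates.
max_honest_gap = max(
    np.linalg.norm(a - b) for i, a in enumerate(honest) for b in honest[i + 1:]
)
# Smallest distance between the poisoned update and any honest one.
min_poison_gap = min(np.linalg.norm(poisoned - h) for h in honest)

print(f"largest honest-honest distance:   {max_honest_gap:.2f}")
print(f"smallest poisoned-honest distance: {min_poison_gap:.2f}")
```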


Then we collect the weights from the clients.

In [12]:
from flex.pool.primitives_tf import collect_clients_weights_tf

# Collect the weights of both the selected clients and the poisoned client
aggregator.map(collect_clients_weights_tf, selected_clients_pool)
aggregator.map(collect_clients_weights_tf, randomized_weights_client)
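At this point the aggregator holds one update per collected client, presumably each a list of per-layer arrays like Keras `get_weights()` returns (an assumption about what `collect_clients_weights_tf` stores). Distance-based defenses usually compare updates as single vectors, so a common first step is to flatten each update. A minimal sketch with hypothetical helpers `flatten_update` and `unflatten_update`, assuming the layer shapes are known:

```python
import numpy as np


def flatten_update(layers):
    """Concatenate a list of per-layer arrays into one 1-D vector (hypothetical helper)."""
    return np.concatenate([np.asarray(w).ravel() for w in layers])


def unflatten_update(vector, shapes):
    """Split a 1-D vector back into arrays with the given shapes (hypothetical helper)."""
    layers, start = [], 0
    for shape in shapes:
        size = int(np.prod(shape))
        layers.append(vector[start:start + size].reshape(shape))
        start += size
    return layers


# Kernel and bias shapes of the notebook's model (Flatten has no weights).
shapes = [(784, 128), (128,), (128, 10), (10,)]
update = [np.full(s, float(i)) for i, s in enumerate(shapes)]

flat = flatten_update(update)
print(flat.shape)  # (101770,)
restored = unflatten_update(flat, shapes)
```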

### Defending ourselves

Now that our clients have trained their models and sent their weights to the server, we can use the built-in defenses to protect our model from the poisoned weights. In our case, we will use `multikrum` as a defense. This defense is implemented as an aggregation function in the `FLEX-clash` library, so we only need to change the aggregation function in our experiment.

Note that `FLEX-clash` implements other defenses besides `multikrum`. For the full list of implemented defenses, refer to the `FLEX-clash` documentation.
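For intuition about what a Multi-Krum defense computes, here is a minimal NumPy sketch of the Multi-Krum rule as described by Blanchard et al. (2017). It is not FLEX-clash's implementation, and its parameter names (`f` for the assumed number of malicious clients, `m` for how many updates to keep) are assumptions. Each flattened update is scored by the sum of squared distances to its `n - f - 2` nearest neighbours, and the `m` lowest-scoring updates are averaged.

```python
import numpy as np


def multikrum(updates, f, m):
    """Multi-Krum sketch over flattened client updates.

    updates: array-like of shape (n, d), one flattened update per client.
    f: assumed maximum number of malicious clients.
    m: number of lowest-scoring updates to average.
    Returns (aggregate, selected_indices).
    """
    updates = np.asarray(updates, dtype=float)
    n = len(updates)
    k = n - f - 2                      # neighbours counted in each score
    if k < 1:
        raise ValueError("Multi-Krum needs n > f + 2")
    if not 1 <= m <= n:
        raise ValueError("m must be between 1 and n")

    # Pairwise squared Euclidean distances via ||a||^2 + ||b||^2 - 2ab,
    # which avoids materializing an (n, n, d) array of differences.
    sq_norms = np.sum(updates ** 2, axis=1)
    sq_dists = sq_norms[:, None] + sq_norms[None, :] - 2.0 * updates @ updates.T
    sq_dists = np.maximum(sq_dists, 0.0)   # clip tiny negatives from round-off

    scores = np.empty(n)
    for i in range(n):
        others = np.delete(sq_dists[i], i)       # drop the distance to itself
        scores[i] = np.sort(others)[:k].sum()    # sum over the k closest neighbours

    selected = np.argsort(scores)[:m]
    return updates[selected].mean(axis=0), selected


# Toy example on synthetic data: five similar updates and one random outlier.
rng = np.random.default_rng(1)
honest = 1.0 + 0.01 * rng.standard_normal((5, 4))
outlier = 10.0 * rng.standard_normal((1, 4))
agg, chosen = multikrum(np.vstack([honest, outlier]), f=1, m=3)
print(sorted(chosen.tolist()))  # index 5 (the outlier) is not among the kept updates
```

Averaging several low-score updates (rather than keeping only the single best one, as plain Krum does) trades a little robustness for a less noisy aggregate.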

In [13]:
from flexclash.pool.defences import multikrum

aggregator.map(multikrum)