# How to federate datasets using FLEXible

In this notebook, we show a few of the many ways in which FLEXible can federate a centralized dataset, using the MNIST and CIFAR10 datasets.

First, we download MNIST using TensorFlow Datasets:

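```python
import tensorflow_datasets as tfds

# Load the MNIST train and test splits as (image, label) pairs
ds_train, ds_test = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    batch_size=-1,  # load each split as a single full batch
)
```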

In order to use our tools, we need to encapsulate the dataset in a `FlexDataObject`. 

Note that the features (`X_data`) and labels (`y_data`) of a `FlexDataObject` are assumed to be NumPy arrays, and the labels must be a one-dimensional NumPy array.
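
As a minimal sketch of that requirement, here is a plain NumPy helper that does not use FLEXible. The helper `to_flex_arrays` is hypothetical. It converts features and labels to NumPy arrays and flattens column-vector labels of shape `(n, 1)` into the required one-dimensional shape:

```python
import numpy as np

def to_flex_arrays(X, y):
    """Hypothetical helper: return (X, y) as NumPy arrays with one-dimensional labels."""
    X = np.asarray(X)
    y = np.asarray(y)
    if y.ndim == 2 and y.shape[1] == 1:
        y = y.ravel()  # (n, 1) column vector -> (n,)
    if y.ndim != 1:
        raise ValueError(f"labels must be one-dimensional, got shape {y.shape}")
    if len(X) != len(y):
        raise ValueError("X and y must contain the same number of samples")
    return X, y

X, y = to_flex_arrays(np.zeros((4, 28, 28)), [[0], [1], [2], [3]])
print(X.shape, y.shape)  # (4, 28, 28) (4,)
```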

```python
from flex.data import FlexDataObject

train_dataset = FlexDataObject.from_tfds_dataset(ds_train)
test_dataset = FlexDataObject.from_tfds_dataset(ds_test)
```

To federate a centralized dataset, we need to describe the federation process in a `FlexDatasetConfig` object.

A `FlexDatasetConfig` object has the following fields:


- **seed**: Optional[int]
    Seed used to make the federated dataset generated with this configuration reproducible. Default None.
- **n_clients**: Optional[int]
    Number of clients among which to split the centralized dataset. If client_names is also given, we consider the number of clients to be the minimum between n_clients and the length of client_names. Default None.
- **client_names**: Optional[List[Hashable]]
    Names to identify each client; if not provided, clients will be indexed using integers. If n_clients is also given, we consider the number of clients to be the minimum of n_clients and the length of client_names. Default None.
- **weights**: Optional[npt.NDArray], A numpy.array which provides the proportion of data to give to each client. Default None.
- **replacement**: bool, whether the sampling procedure used to split a centralized dataset is with replacement or not. Default True.
- **classes_per_client**: Optional[Union[int, npt.NDArray, Tuple[int]]], classes to assign to each client. If provided as an int, it is the number of classes per client. If provided as a tuple of ints, it establishes a minimum and a maximum number of classes per client, and a random number sampled in that interval decides the number of classes of each client. If provided as a list or array, it establishes the classes assigned to each client. Default None.
- **features_per_client**: Optional[Union[int, npt.NDArray, Tuple[int]]], features to assign to each client; it shares the same interface as classes_per_client.
- **indexes_per_client**: Optional[npt.NDArray]
    Data indexes to assign to each client. Note that this option is incompatible with the **classes_per_client** and **features_per_client** options. If replacement and weights are specified, they are ignored.

The following table shows the compatibility of each option:

| Options compatibility   | **n_clients** | **client_names** | **weights** | **weights_per_class** | **replacement** | **classes_per_client** | **features_per_client** | **indexes_per_client** |
|-------------------------|---------------|------------------|-------------|-----------------------|-----------------|------------------------|-------------------------|------------------------|
| **n_clients**           | -             | Y                | Y           | Y                     | Y               | Y                      | Y                       | Y                      |
| **client_names**        | Y             | -                | Y           | Y                     | Y               | Y                      | Y                       | Y                      |
| **weights**             | Y             | Y                | -           | N                     | Y               | Y                      | Y                       | N                      |
| **weights_per_class**   | Y             | Y                | N           | -                     | Y               | Y                      | N                       | N                      |
| **replacement**         | Y             | Y                | Y           | Y                     | -               | Y                      | Y                       | N                      |
| **classes_per_client**  | Y             | Y                | Y           | Y                     | Y               | -                      | N                       | N                      |
| **features_per_client** | Y             | Y                | Y           | N                     | Y               | N                      | -                       | N                      |
| **indexes_per_client**  | Y             | Y                | N           | N                     | N               | N                      | N                       | -                      |
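
The "N" cells of the table can be encoded as a set of forbidden pairs and checked before building a configuration. This is a minimal sketch, not FLEXible's own validation. The function `incompatible_options` is hypothetical, and it assumes you pass only the options you explicitly set, since `replacement` always has a default value.

```python
from itertools import combinations

# Incompatible option pairs, transcribed from the "N" cells of the compatibility table
INCOMPATIBLE = {
    frozenset(pair) for pair in [
        ("weights", "weights_per_class"),
        ("weights", "indexes_per_client"),
        ("weights_per_class", "features_per_client"),
        ("weights_per_class", "indexes_per_client"),
        ("replacement", "indexes_per_client"),
        ("classes_per_client", "features_per_client"),
        ("classes_per_client", "indexes_per_client"),
        ("features_per_client", "indexes_per_client"),
    ]
}

def incompatible_options(options):
    """Hypothetical checker: return the incompatible pairs among the explicitly set option names."""
    given = sorted(options)  # works for a dict (its keys) or any iterable of names
    return [(a, b) for a, b in combinations(given, 2) if frozenset((a, b)) in INCOMPATIBLE]

print(incompatible_options({"n_clients": 10, "replacement": False, "classes_per_client": 1}))  # []
print(incompatible_options({"n_clients": 10, "classes_per_client": 2, "features_per_client": 3}))
# [('classes_per_client', 'features_per_client')]
```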



Let's implement the following description:

We have 10 federated clients that do not share any instances. Each client holds data from a single class, namely 20% of the data available for that class.

```python
from flex.data import FlexDatasetConfig
import numpy as np

config = FlexDatasetConfig(seed=0)  # fix a seed to make the federation reproducible
config.n_clients = 10  # 10 clients
config.replacement = False  # ensure that clients do not share any data
config.classes_per_client = np.unique(train_dataset.y_data)  # assign one class to each client
config.weights = np.full(config.n_clients, 0.2)  # each client gets 20% of its assigned class
```
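
To make the semantics of this configuration concrete, here is a plain NumPy sketch of the split it describes. It is not FLEXible's internal implementation. The helper `one_class_per_client_split` is hypothetical, and synthetic labels stand in for MNIST. Client `k` receives a random 20% of the samples of class `k`, drawn without replacement, so no index is shared between clients.

```python
import numpy as np

def one_class_per_client_split(y, fraction=0.2, seed=0):
    """Hypothetical helper: client k gets a random `fraction` of the samples of class k,
    sampled without replacement."""
    rng = np.random.default_rng(seed)
    split = {}
    for client_id, c in enumerate(np.unique(y)):
        class_idx = np.flatnonzero(y == c)       # indexes of every sample of class c
        n_take = int(fraction * len(class_idx))  # 20% of that class
        split[client_id] = rng.choice(class_idx, size=n_take, replace=False)
    return split

# Synthetic labels: 10 classes with 100 samples each
y = np.repeat(np.arange(10), 100)
split = one_class_per_client_split(y)
print({k: len(v) for k, v in split.items()})  # every client gets 20 indexes
```

Because each client only draws from its own class, the clients are disjoint by construction. Without replacement, the weights can sum to at most 1 per class.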

We apply the generated `FlexDatasetConfig` to a `FlexDataObject`, which encapsulates the centralized dataset.

```python
from flex.data import FlexDataDistribution

federated_dataset = FlexDataDistribution.from_config(cdata=train_dataset, config=config)
```

Let's show the federated data to confirm that the split is correct:

```python
import matplotlib.pyplot as plt
import numpy as np

for client in federated_dataset:
    client_data = federated_dataset[client]
    print(f"Client {client} has class {np.unique(client_data.y_data)} and {len(client_data)} elements, a sample of them is:")
    fig, ax = plt.subplots(1, 10)  # 1 row, 10 columns
    for i, (x, y) in enumerate(client_data):
        ax[i].axis('off')
        ax[i].imshow(np.squeeze(x), cmap='gray')  # drop the trailing channel axis of MNIST images
        if i >= 9:  # show at most 10 samples per client
            break
    plt.show()
```
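
Besides inspecting the plots, the split can also be checked numerically. The sketch below uses only NumPy, with a hypothetical `split_report` helper that maps each client to a `{class: count}` dictionary. Toy labels stand in for the real data. With FLEXible, the input could presumably be built as `{c: federated_dataset[c].y_data for c in federated_dataset}`, assuming `y_data` holds NumPy labels as described earlier.

```python
import numpy as np

def split_report(labels_per_client):
    """Hypothetical helper: map each client to a {class: count} dict of its labels."""
    report = {}
    for client, labels in labels_per_client.items():
        classes, counts = np.unique(labels, return_counts=True)
        report[client] = dict(zip(classes.tolist(), counts.tolist()))
    return report

# Toy labels standing in for each client's y_data
labels_per_client = {0: np.array([0, 0, 0]), 1: np.array([1, 1]), 2: np.array([2, 2, 2, 2])}
report = split_report(labels_per_client)
print(report)  # {0: {0: 3}, 1: {1: 2}, 2: {2: 4}}
# The configuration above asks for exactly one class per client
print(all(len(counts) == 1 for counts in report.values()))  # True
```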

# Federate a dataset using weights to distribute data following a certain distribution

Now we try a more specific configuration: we want the number of samples per client to follow a Gaussian distribution, so we draw the client weights from a normal distribution:

```python
import matplotlib.pyplot as plt
import numpy as np

n_clients = 500
mu, sigma = 100, 1  # mean and standard deviation
normal_weights = np.random.default_rng(seed=0).normal(mu, sigma, n_clients)  # one weight per client
normal_weights = np.clip(normal_weights, a_min=0, a_max=None)  # clip negative values to zero
normal_weights = normal_weights / normal_weights.sum()  # normalize so the weights sum to 1

plt.hist(normal_weights, bins=15)
plt.title('Histogram of normal weights')
plt.show()
```

```python
config = FlexDatasetConfig(
    seed=0,
    n_clients=n_clients,
    replacement=False,
    weights=normal_weights,
)

normal_federated_dataset = FlexDataDistribution.from_config(cdata=train_dataset, config=config)
```
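
The exact rule FLEXible uses to round proportional weights into client sizes is not described here. As an assumption, here is one common way to turn weights into integer sizes that sum exactly to the dataset size when sampling without replacement: largest-remainder rounding. The function `sizes_from_weights` is hypothetical.

```python
import numpy as np

def sizes_from_weights(weights, n_samples):
    """Hypothetical helper: split n_samples into integer sizes proportional to weights,
    using largest-remainder rounding so the sizes sum exactly to n_samples."""
    weights = np.asarray(weights, dtype=float)
    weights = weights / weights.sum()
    raw = weights * n_samples
    sizes = np.floor(raw).astype(int)
    # Hand out the samples lost to flooring, largest fractional parts first
    leftover = n_samples - sizes.sum()
    order = np.argsort(raw - sizes)[::-1]
    sizes[order[:leftover]] += 1
    return sizes

sizes = sizes_from_weights([0.5, 0.3, 0.2], 11)
print(sizes, sizes.sum())  # [6 3 2] 11
```

With the 500 normal weights above and the 60,000 MNIST training images, each client would receive about 120 samples.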

Plot a histogram of the number of samples per client:

```python
datasizes_per_client = [len(normal_federated_dataset[client]) for client in normal_federated_dataset]
plt.hist(datasizes_per_client)
plt.xlabel('Number of samples')
plt.ylabel('Number of clients')
plt.title('Histogram of data sizes per client')
plt.show()
```

# A more complex dataset federation

Now, let's federate CIFAR10 following this description from [Personalized Federated Learning using Hypernetworks](https://paperswithcode.com/paper/personalized-federated-learning-using):

First, we sample two/ten classes for each client for CIFAR10/CIFAR100; Next, for each client i and selected class c, we sample $ \alpha_{i,c} \sim U(.4, .6)$, and assign it with $\frac{\alpha_{i,c}}{\sum_j \alpha_{j,c}}$ of the samples for this class. We repeat the above using 10, 50 and 100 clients. This procedure produces clients with different number of samples and classes.
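
The procedure can be sketched in plain NumPy, independently of FLEXible. The function `hypernetwork_split_fractions` is hypothetical. It is our reading of the formula: the sum $\sum_j \alpha_{j,c}$ runs only over the clients that picked class $c$. Each client picks its classes, draws $\alpha_{i,c} \sim U(.4, .6)$ for them, and each class column is then normalized so that its fractions sum to 1.

```python
import numpy as np

def hypernetwork_split_fractions(n_clients, n_classes, classes_per_client, seed=0):
    """Hypothetical sketch: return an (n_clients, n_classes) matrix whose entry [i, c]
    is the fraction of class c assigned to client i."""
    rng = np.random.default_rng(seed)
    alphas = np.zeros((n_clients, n_classes))
    for i in range(n_clients):
        chosen = rng.choice(n_classes, size=classes_per_client, replace=False)  # classes of client i
        alphas[i, chosen] = rng.uniform(0.4, 0.6, size=classes_per_client)      # alpha_{i,c} ~ U(.4, .6)
    totals = alphas.sum(axis=0)  # sum_j alpha_{j,c} over the clients that picked class c
    # Normalize column-wise; classes nobody picked keep a zero column instead of dividing by zero
    return np.divide(alphas, totals, out=np.zeros_like(alphas), where=totals > 0)

fractions = hypernetwork_split_fractions(n_clients=10, n_classes=10, classes_per_client=2)
print((fractions > 0).sum(axis=1))  # every client holds exactly 2 classes
print(fractions.sum(axis=0))        # 1.0 for each class some client picked, 0.0 otherwise
```

The configuration in step 2 normalizes `alphas` over all clients and leaves the class selection to `classes_per_client`, so the two approaches may distribute a class slightly differently.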

1) We download the CIFAR10 dataset using torchvision and create a ``FlexDataObject`` from it using ``from_torchvision_dataset``. Note that it is mandatory to provide at least the ``ToTensor`` transform.

```python
from torchvision import datasets, transforms
from flex.data import FlexDataObject

cifar10 = datasets.CIFAR10(
    root=".",
    train=True,
    download=True,
    transform=transforms.ToTensor(),  # required by from_torchvision_dataset
)
dataset = FlexDataObject.from_torchvision_dataset(cifar10)
```
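
For reference, torchvision's `ToTensor` converts an `(H, W, C)` image with values in `[0, 255]` into a float `(C, H, W)` tensor with values in `[0.0, 1.0]`. The NumPy function below is a hypothetical re-creation of that behaviour for `uint8` arrays. It illustrates the shape and range change only and is not torchvision's code.

```python
import numpy as np

def to_tensor_like(image):
    """Hypothetical NumPy approximation of ToTensor for uint8 images:
    (H, W, C) in [0, 255] -> float32 (C, H, W) in [0.0, 1.0]."""
    image = np.asarray(image)
    if image.ndim == 2:  # grayscale (H, W): add a channel axis
        image = image[:, :, None]
    return image.transpose(2, 0, 1).astype(np.float32) / 255.0

img = np.full((32, 32, 3), 255, dtype=np.uint8)  # a white CIFAR10-sized image
t = to_tensor_like(img)
print(t.shape, t.dtype, t.max())  # (3, 32, 32) float32 1.0
```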

2) Create a ``FlexDatasetConfig`` that fits the description given above.

```python
from flex.data import FlexDatasetConfig
import numpy as np

num_classes = 10
config = FlexDatasetConfig(seed=0)
config.n_clients = 10
config.classes_per_client = 2  # CIFAR10: two classes per client (ten for CIFAR100)
config.replacement = True  # the paper does not state whether clients share samples

# Sample alpha_{i,c} ~ U(0.4, 0.6) for each client-class pair and normalize over clients
rng = np.random.default_rng(seed=0)
alphas = rng.uniform(0.4, 0.6, size=(config.n_clients, num_classes))
alphas = alphas / alphas.sum(axis=0)
config.weights_per_class = alphas
```

3) Create the federated dataset by applying the ``FlexDatasetConfig`` to the ``FlexDataObject`` using ``FlexDataDistribution.from_config``:

```python
from flex.data import FlexDataDistribution

personalized_cifar_dataset = FlexDataDistribution.from_config(cdata=dataset, config=config)
```

We can also transform each client's dataset easily using the `map` function of `FlexDataset`. For example, we force each client of the MNIST federation created earlier (`normal_federated_dataset`) to keep only even labels:

```python
import matplotlib.pyplot as plt
import numpy as np
from flex.data import FlexDataObject

def keep_given_labels(client_dataset: FlexDataObject, selected_labels=None):
    """Return the client's data restricted to the selected labels."""
    if selected_labels is None:
        selected_labels = []
    mask = np.isin(client_dataset.y_data, selected_labels)
    return FlexDataObject(X_data=client_dataset.X_data[mask], y_data=client_dataset.y_data[mask])

even_labels_federated_dataset = normal_federated_dataset.map(
    func=keep_given_labels,  # function to apply to each client
    num_proc=1,
    selected_labels=[0, 2, 4, 6, 8],  # keyword argument forwarded to keep_given_labels
)

for i, client in enumerate(even_labels_federated_dataset):
    if i >= 10:  # only show the first 10 clients
        break
    client_data = even_labels_federated_dataset[client]
    print(f"Client {client} has classes {np.unique(client_data.y_data)} and {len(client_data)} elements, a sample of them is:")
    fig, ax = plt.subplots(1, 10)  # 1 row, 10 columns
    for j, (x, y) in enumerate(client_data):
        ax[j].axis('off')
        ax[j].imshow(np.squeeze(x), cmap='gray')  # drop the trailing channel axis
        if j >= 9:  # show at most 10 samples per client
            break
    plt.show()
```
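
FLEXible's `map` internals are not covered in this notebook. The sketch below shows the idea using plain dicts of `(X, y)` NumPy pairs, with hypothetical `map_clients` and `keep_labels` helpers that are not part of FLEXible. A per-client transformation applies the same function to every client's data independently, and no client sees another's samples.

```python
import numpy as np

def map_clients(client_data, func, **kwargs):
    """Hypothetical stand-in for a federated map: apply func(X, y, **kwargs) to every client."""
    return {client: func(X, y, **kwargs) for client, (X, y) in client_data.items()}

def keep_labels(X, y, selected_labels=()):
    """Hypothetical filter: keep only the samples whose label is in selected_labels."""
    mask = np.isin(y, selected_labels)
    return X[mask], y[mask]

clients = {
    "a": (np.arange(6).reshape(6, 1), np.array([0, 1, 2, 3, 4, 5])),
    "b": (np.arange(4).reshape(4, 1), np.array([1, 3, 8, 9])),
}
filtered = map_clients(clients, keep_labels, selected_labels=[0, 2, 4, 6, 8])
print({c: y.tolist() for c, (X, y) in filtered.items()})  # {'a': [0, 2, 4], 'b': [8]}
```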

### END
Congratulations, now you know how to federate a dataset using the *FlexDataDistribution* and *FlexDatasetConfig* classes, so you can set up multiple experimental settings that fit your hypotheses.