## This is a demonstration of FlexDataset usage

First, we download the dataset we want to federate, in this case MNIST.

In [None]:
from tensorflow.keras.datasets import mnist

# Load MNIST; the call returns a train pair and a test pair, each unpacked
# here into inputs (X, the digit images) and targets (y, the labels).
(train_X, train_y), (test_X, test_y) = mnist.load_data()

Show some samples

In [None]:
from matplotlib import pyplot as plt

# Preview the first nine training digits on a 3x3 grid of grayscale panels.
for cell in range(1, 10):
    plt.subplot(3, 3, cell)  # subplot positions are 1-based
    plt.axis('off')
    plt.imshow(train_X[cell - 1], cmap='gray')
plt.show()

We encapsulate the chosen dataset in a `FlexDataObject`

In [None]:
from flex.data import FlexDataObject

# Wrap each split (inputs + labels) in its own FlexDataObject. Only
# `train_dataset` is federated in the cells shown here.
train_dataset = FlexDataObject(X_data=train_X, y_data=train_y)
test_dataset = FlexDataObject(X_data=test_X, y_data=test_y)

Now we create a configuration, a `FlexDatasetConfig`, to federate our dataset. Specifically:

    - We want to split the dataset between 10 clients.
    - Each client will have only one class.
    - Each client has only 20% of its assigned class.
    - Clients do not share classes between them.

In [None]:
from flex.data import FlexDatasetConfig
import numpy as np

# Fix the seed so that the federation is reproducible.
config = FlexDatasetConfig(seed=0)
# Ten clients in total.
config.n_clients = 10
# No replacement: clients never share any data.
config.replacement = False
# One entry per distinct label, so that each client is assigned one class.
config.classes_per_client = np.unique(train_y)
# Every client keeps only 20% of the data of its assigned class.
config.weights = [0.2 for _ in range(config.n_clients)]

We apply the generated `FlexDatasetConfig` to a `FlexDataObject`, which encapsulates the centralized dataset.

In [None]:
from flex.data import FlexDataDistribution

# Distribute the centralized training data among clients as `config` prescribes.
federated_dataset = FlexDataDistribution.from_config(
    cdata=train_dataset,
    config=config,
)

Show the federated data

In [None]:
# For every client: report its classes and size, then draw up to ten of its samples.
for client in federated_dataset:
    client_data = federated_dataset[client]
    print(f"Client {client} has class {np.unique(client_data.y_data)} and {len(client_data)} elements, a sample of them is:")
    figure, panels = plt.subplots(1, 10)  # one row, ten columns
    # zip ends once the ten panels are used up, or earlier when the client
    # holds fewer than ten samples.
    for panel, (image, _label) in zip(panels, client_data):
        panel.axis('off')
        panel.imshow(image, cmap='gray')
    plt.show()

Now we try a more specialized configuration: we want to federate the dataset so that the amount of data per client follows a Gaussian distribution:

In [None]:
n_clients = 500
mu, sigma = 100, 1  # mean and standard deviation of the raw weights

# Draw one raw weight per client from a seeded generator so the result is reproducible.
weight_generator = np.random.default_rng(seed=0)
normal_weights = weight_generator.normal(loc=mu, scale=sigma, size=n_clients)
# Weights must not be negative: raise anything below zero up to zero.
normal_weights = np.maximum(normal_weights, 0)
# Rescale so that the weights add up to 1.
normal_weights = normal_weights / sum(normal_weights)

plt.hist(normal_weights, bins=15)
plt.title('Histogram of normal weights')
plt.show()

In [None]:
# Configure a federation of `n_clients` clients whose shares of the data are
# given by the normalized Gaussian weights; data is not shared between clients.
config = FlexDatasetConfig(
    seed=0,
    n_clients=n_clients,
    replacement=False,
    weights=normal_weights,
)

# Apply the configuration to the centralized training data.
normal_federated_dataset = FlexDataDistribution.from_config(
    cdata=train_dataset,
    config=config,
)

Plot histogram of data per client:

In [None]:
# Number of samples held by each client of the Gaussian-weighted federation.
datasizes_per_client = [len(normal_federated_dataset[client]) for client in normal_federated_dataset]
# In this histogram the x-axis carries the data sizes and the y-axis counts
# how many clients fall into each size bin, so the axes are labelled accordingly.
n, bins, patches = plt.hist(datasizes_per_client)
plt.xlabel('Data sizes')
plt.ylabel('Number of clients')
plt.title('Histogram of data sizes per client')
plt.show()

If we want, we can easily preprocess the dataset of each client using the `map` function from `FlexDataset`; for example, we force each client to keep only the samples with even labels:

In [None]:
import numpy as np

rng = np.random.default_rng(seed=0)  # NOTE(review): not used in the cells visible here


# TODO: try some other, more unusual operation here as an example.
def keep_given_labels(client_dataset: FlexDataObject, selected_labels):
    """Keep only the samples whose label is one of ``selected_labels``.

    The dataset is modified in place: both ``X_data`` and ``y_data`` are
    replaced by their filtered versions.

    Args:
        client_dataset: Dataset of a single client, exposing ``X_data`` and
            ``y_data`` attributes that accept boolean-mask indexing.
        selected_labels: Labels to keep.

    Returns:
        The same ``client_dataset`` object, after filtering.
    """
    # Build the mask once, from the labels as they are before any filtering,
    # and apply it to inputs and labels alike so they stay aligned.
    keep_mask = np.isin(client_dataset.y_data, selected_labels)
    client_dataset.X_data = client_dataset.X_data[keep_mask]
    client_dataset.y_data = client_dataset.y_data[keep_mask]
    return client_dataset

# Run `keep_given_labels` over the federation, keeping only the even digits.
randomly_transformed_federated_dataset = normal_federated_dataset.map(
    None,  # apply to all clients
    8,  # number of parallel processes
    keep_given_labels,  # function to apply
    [0, 2, 4, 6, 8],  # argument for the function: the labels to keep
)

# Report the classes and size of the first few clients after the transformation
# and draw up to ten samples of each one.
for client in randomly_transformed_federated_dataset:
    client_data = randomly_transformed_federated_dataset[client]
    print(f"Client {client} has classes {np.unique(client_data.y_data)} and {len(client_data)} elements, a sample of them is:")
    fig, ax = plt.subplots(1, 10)  # rows, cols
    # zip ends once the ten panels are used up, or earlier when the client
    # holds fewer than ten samples.
    for panel, (x, y) in zip(ax, client_data):
        panel.axis('off')
        panel.imshow(x, cmap=plt.get_cmap('gray'))
    # Show the figure BEFORE checking the stop condition; otherwise the last
    # client's figure is created and announced but never displayed.
    plt.show()
    # Stop after the client with id 10 (the comparison assumes numeric ids).
    if client >= 10:
        break