# FLEX-clash: how to backdoor data

In this notebook, we show how easy it is to create a backdoor task for a client using flexclash. In particular, we show how to use the `@backdoor_data` decorator, which lets us easily modify the data in a client's dataset before a FlexPool is created.
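Conceptually, a decorator like `@backdoor_data` turns a function that poisons one `(x, y)` sample into a function that poisons a whole client dataset. The following is a minimal, hypothetical sketch of that idea in plain Python. It is not flexclash's implementation. The names `backdoor_data_sketch` and `relabel_to_zero` are made up for illustration, and it assumes a client dataset can be modeled as a list of `(x, y)` pairs.

```python
from functools import wraps
from typing import Any, Callable, List, Tuple

Sample = Tuple[Any, Any]

def backdoor_data_sketch(
    poison_fn: Callable[[Any, Any], Sample],
) -> Callable[[List[Sample]], List[Sample]]:
    """Hypothetical stand-in for @backdoor_data: apply poison_fn to every sample."""
    @wraps(poison_fn)
    def apply_to_dataset(samples: List[Sample]) -> List[Sample]:
        # Build a new list so the caller's dataset is left unchanged.
        return [poison_fn(x, y) for x, y in samples]
    return apply_to_dataset

@backdoor_data_sketch
def relabel_to_zero(x, y):
    # Keep the features, force the label to the attacker's target class 0.
    return x, 0

poisoned = relabel_to_zero([("a", 3), ("b", 7)])
print(poisoned)  # [('a', 0), ('b', 0)]
```

Returning new samples instead of mutating the input keeps a clean copy available, which is handy when checking the effect of an attack.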

In our first example, we load the federated EMNIST dataset and modify the datasets of a few clients by setting the bottom-right pixel of every image to white and relabeling the image as 0.

As always, the first step is to create or load a `FlexDataset`, which represents a federated dataset.

In [None]:
from flex.datasets import load

fed_emnist = load("federated_emnist", return_test=False, split="digits")

Next, we select the clients to backdoor (here, the first 10 client IDs).

In [None]:
client_ids = list(fed_emnist.keys())
clients_to_backdoor = client_ids[:10]
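Taking the first 10 IDs depends on the order in which the dataset lists its clients. If a random but reproducible choice is preferred, a sketch like the following could replace that slice. The helper `pick_clients`, the seed, and the placeholder IDs are illustrative assumptions; distinct variable names are used so running it does not overwrite `client_ids` or `clients_to_backdoor`.

```python
import random
from typing import List, Sequence

def pick_clients(ids: Sequence[str], n: int, seed: int = 0) -> List[str]:
    """Return n distinct client IDs chosen uniformly at random, reproducibly."""
    rng = random.Random(seed)  # local RNG, so the global random state is untouched
    return rng.sample(list(ids), n)

# Placeholder IDs standing in for list(fed_emnist.keys()).
example_ids = [f"client_{i}" for i in range(50)]
example_selection = pick_clients(example_ids, 10, seed=42)
print(example_selection)
```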

We define the modification we want to apply to the dataset and decorate it with `@backdoor_data`. In our case, we change one pixel to white and label these images as 0.

In [None]:
import numpy as np
from flexclash.data import backdoor_data

@backdoor_data
def poison(img_array, label):
    new_label = 0
    img_array[-1,-1] = 255 # white pixel
    return img_array, new_label
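The per-sample logic can be checked on a synthetic image without touching the federated data. The sketch below re-implements the same trigger as a plain, undecorated function (`stamp_trigger`, a hypothetical name) on a 28x28 black `uint8` image, assuming pixel intensities in the 0 to 255 range. Unlike `poison`, it works on a copy, so the input array is not modified in place.

```python
import numpy as np

def stamp_trigger(img_array: np.ndarray, label: int, target_label: int = 0):
    """Set the bottom-right pixel to white and return the target label."""
    poisoned_img = np.array(img_array, copy=True)  # copy so the caller's image stays clean
    poisoned_img[-1, -1] = 255                     # white trigger pixel
    return poisoned_img, target_label

clean = np.zeros((28, 28), dtype=np.uint8)  # synthetic all-black image
poisoned_img, new_label = stamp_trigger(clean, label=7)

print(poisoned_img[-1, -1], new_label)  # 255 0
print(clean[-1, -1])                    # 0 (original untouched)
```

Whether the in-place write in `poison` also alters the stored client data depends on whether `@backdoor_data` hands it a copy, which this notebook does not state; copying explicitly avoids the question.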

Now, we apply our backdoor function to the selected clients.

In [None]:
fed_emnist = fed_emnist.map(poison, clients_ids=clients_to_backdoor)
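Restricting the transformation to `clients_to_backdoor` means every other client keeps its clean data. Here is a plain-dict sketch of that behavior, assuming a federated dataset can be modeled as a mapping from client ID to a list of `(x, y)` samples. `map_selected` and the toy data are hypothetical and are not FLEX's API.

```python
from typing import Any, Callable, Dict, Hashable, Iterable, List, Tuple

Sample = Tuple[Any, Any]

def map_selected(
    fed_data: Dict[Hashable, List[Sample]],
    fn: Callable[[Any, Any], Sample],
    client_ids: Iterable[Hashable],
) -> Dict[Hashable, List[Sample]]:
    """Apply fn to every sample of the chosen clients; copy the others unchanged."""
    selected = set(client_ids)
    return {
        cid: [fn(x, y) for x, y in samples] if cid in selected else list(samples)
        for cid, samples in fed_data.items()
    }

toy_fed = {"c0": [("img0", 4)], "c1": [("img1", 5)], "c2": [("img2", 6)]}
toy_poisoned = map_selected(toy_fed, lambda x, y: (x, 0), client_ids=["c0", "c2"])
print(toy_poisoned)  # {'c0': [('img0', 0)], 'c1': [('img1', 5)], 'c2': [('img2', 0)]}
```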

By plotting the first sample of each backdoored client's dataset, we can observe the injected backdoor task.

In [None]:
import matplotlib.pyplot as plt

for client in clients_to_backdoor:
    backdoored_dataset = fed_emnist[client]
    fig, ax = plt.subplots(1, 1) # rows, cols
    for x, y in backdoored_dataset:
        ax.set_title(f"Sample from client {client}, label {y}")
        ax.axis('off')
        ax.imshow(x, cmap=plt.get_cmap('gray'))
        break
    plt.show()
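Plotting one image per client is a quick visual check. A numeric alternative is to measure, per client, the fraction of samples that carry both the trigger pixel and the target label. The sketch below does this on synthetic arrays; `trigger_rate` is a hypothetical helper, and it assumes images are NumPy arrays whose trigger value is 255 in the bottom-right corner. If `poison` is applied to every sample of a backdoored client, one would expect a rate of 1.0 there.

```python
from typing import Iterable, Tuple

import numpy as np

def trigger_rate(samples: Iterable[Tuple[np.ndarray, int]], target_label: int = 0) -> float:
    """Fraction of samples whose bottom-right pixel is 255 and whose label is target_label."""
    samples = list(samples)
    if not samples:
        return 0.0
    hits = sum(1 for x, y in samples if x[-1, -1] == 255 and y == target_label)
    return hits / len(samples)

clean_img = np.zeros((28, 28), dtype=np.uint8)
triggered_img = clean_img.copy()
triggered_img[-1, -1] = 255

mixed = [(triggered_img, 0), (triggered_img, 0), (clean_img, 3), (clean_img, 0)]
print(trigger_rate(mixed))  # 0.5
```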

## Another backdoor
Easy, right? Now let's try something more complex, as complex as my love for cats (I see them everywhere).

For some clients, we are going to relabel every red car in their portion of the CIFAR-10 dataset as a cat. Note that we load CIFAR-10 using the torchvision package.

In [None]:
from torchvision import datasets
from flex.data import FedDatasetConfig, FedDataDistribution, Dataset

# Download the CIFAR-10 training set, without any transform
cifar10 = datasets.CIFAR10(
    root=".",
    train=True,
    download=True,
    transform=None,
)

# Split the data into 100 clients, sampling without replacement
config = FedDatasetConfig(seed=0)
config.replacement = False
config.n_clients = 100

cifar10_wrapped = Dataset.from_torchvision_dataset(cifar10)

fed_cifar = FedDataDistribution.from_config(
    centralized_data=cifar10_wrapped,
    config=config,
)
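Setting `n_clients = 100` with `replacement = False` asks for 100 clients built without resampling, which, as we understand the configuration, means no sample is assigned to more than one client. As a rough plain-NumPy illustration of that idea (not FLEX's actual algorithm, which may size or order the shards differently), one can shuffle the sample indices with a fixed seed and cut them into 100 disjoint shards. `iid_shards` is a hypothetical helper.

```python
import numpy as np

def iid_shards(n_samples: int, n_clients: int, seed: int = 0):
    """Split sample indices into n_clients disjoint, roughly equal shards."""
    rng = np.random.default_rng(seed)
    shuffled = rng.permutation(n_samples)       # each index appears exactly once
    return np.array_split(shuffled, n_clients)  # shard sizes differ by at most 1

shards = iid_shards(n_samples=50_000, n_clients=100, seed=0)  # CIFAR-10 train set size
print(len(shards), len(shards[0]))  # 100 500
```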

Now that we have loaded and federated the CIFAR-10 dataset (that was fast, right?), let's define our poisoning procedure.

In [None]:
import numpy as np
from flexclash.data import backdoor_data

cat_label = 3
car_label = 1

@backdoor_data
def red_cars_as_cats(img, label):
    if label == car_label:
        car_array = np.array(img)  # image -> H x W x 3 array
        sum_red_only = car_array[:,:,0].sum()  # total red intensity
        threshold = car_array[:,:,1:].sum()  # total green + blue intensity
        if sum_red_only > threshold: # a red car has been found :)
            return img, cat_label
    # unmodified
    return img, label
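The test inside `red_cars_as_cats` is a simple color heuristic: an image counts as red when the total of its red channel exceeds the combined total of its green and blue channels. The sketch below isolates that rule as a hypothetical helper, `is_mostly_red`, and tries it on synthetic 32x32 RGB images. The cast to `int64` is only for clarity, since NumPy already accumulates `uint8` sums in a wider integer type.

```python
import numpy as np

def is_mostly_red(img) -> bool:
    """Same rule as red_cars_as_cats: red total > green total + blue total."""
    arr = np.asarray(img, dtype=np.int64)  # H x W x 3, wide integers for the sums
    red = arr[:, :, 0].sum()
    green_blue = arr[:, :, 1:].sum()
    return bool(red > green_blue)

red_img = np.zeros((32, 32, 3), dtype=np.uint8)
red_img[:, :, 0] = 200                                # strong red, no green or blue
gray_img = np.full((32, 32, 3), 128, dtype=np.uint8)  # equal channels

print(is_mostly_red(red_img), is_mostly_red(gray_img))  # True False
```

Because the rule only compares channel totals, any car image with a strong overall red cast passes, including ones where the red comes from the background; it is a cheap trigger condition rather than a reliable red-car detector.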

We choose the first 10 clients and relabel their red cars as cats.

In [None]:
client_ids = list(fed_cifar.keys())
catified_clients = client_ids[:10]

fed_cifar = fed_cifar.map(red_cars_as_cats, clients_ids=catified_clients)
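Because the trigger here is a natural feature (color) rather than an inserted pattern, the amount of poison each client receives depends on how many red cars it happens to hold. One way to measure it is to compare a client's labels before and after the mapping, which requires keeping a copy of the original labels since `fed_cifar` is reassigned above. The sketch below uses hypothetical label lists instead of real client data, and `count_flips` is an illustrative helper.

```python
import numpy as np

CAR, CAT = 1, 3  # CIFAR-10 class indices for automobile and cat

def count_flips(labels_before, labels_after) -> int:
    """Number of samples relabeled from car to cat."""
    before_arr = np.asarray(labels_before)
    after_arr = np.asarray(labels_after)
    return int(np.sum((before_arr == CAR) & (after_arr == CAT)))

toy_before = [1, 1, 3, 1, 0]
toy_after = [3, 1, 3, 3, 0]
print(count_flips(toy_before, toy_after))  # 2
```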

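Let's enjoy the evil sight of red cars labeled as cats. Since the filter below only sees the new labels, it may also pick up a genuine cat image whose red channel dominates.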

In [None]:
import matplotlib.pyplot as plt

for client in catified_clients:
    backdoored_dataset = fed_cifar[client]
    for x, y in backdoored_dataset:
        if y == cat_label:
            car_array = np.array(x)
            sum_red_only = car_array[:,:,0].sum()
            threshold = car_array[:,:,1:].sum()
            if sum_red_only > threshold:
                fig, ax = plt.subplots(1, 1) # rows, cols
                ax.set_title(f"Sample from client {client}, label {y}")
                ax.axis('off')
                ax.imshow(x)
                plt.show()
                break