In [5]:
import tensorflow as tf
import numpy as np

In [6]:
def load_and_partition_data(num_clients=3):
    """Load MNIST and split the training set into contiguous client partitions.

    Pixel values are converted to float32 and scaled by 1/255, and a trailing
    channel axis is appended to both the training and test images.

    Args:
        num_clients: Number of client partitions to create; must be >= 1.

    Returns:
        A tuple ``(partitions, (x_test, y_test))``.

        ``partitions`` is a list of ``num_clients`` ``(x, y)`` tuples that
        together cover the whole training set in its original order. When the
        sample count is not divisible by ``num_clients``, the first
        ``len(x_train) % num_clients`` partitions each receive one extra
        sample, so no training data is discarded.

        Each partition is a slice (view) of the shared training arrays, so
        copy before mutating one in place.

    Raises:
        ValueError: If ``num_clients`` is less than 1.
    """
    # Validate before the (potentially slow) dataset download.
    if num_clients < 1:
        raise ValueError(f"num_clients must be >= 1, got {num_clients}")

    # Load MNIST dataset
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

    # Normalize and add a trailing channel axis
    x_train = np.expand_dims(x_train.astype("float32") / 255.0, -1)
    x_test = np.expand_dims(x_test.astype("float32") / 255.0, -1)

    # Contiguous partitions; the remainder is spread over the first clients
    # instead of being dropped.
    base_size, remainder = divmod(len(x_train), num_clients)
    partitions = []
    start = 0
    for i in range(num_clients):
        end = start + base_size + (1 if i < remainder else 0)
        partitions.append((x_train[start:end], y_train[start:end]))
        start = end

    return partitions, (x_test, y_test)

In [7]:
# ---- Label Flipping Attack Settings ----
# Zero-based index of the attacker client; index 1 is reported as "Client 2"
# in the printed messages and saved file names.
ATTACKER_ID = 1
# The label to flip FROM (1 -> 7 here, as they are visually the closest digits)
SOURCE_LABEL = 1
# The label to flip TO
TARGET_LABEL = 7
# Fraction of SOURCE_LABEL samples to flip: 1.0 = all of them, 0.3 = only 30%
FLIP_FRACTION = 1.0

# Sanity-check the settings before they are used to poison any data.
assert ATTACKER_ID >= 0, "ATTACKER_ID must be a non-negative client index"
assert SOURCE_LABEL != TARGET_LABEL, "SOURCE_LABEL and TARGET_LABEL must differ"
assert 0.0 <= FLIP_FRACTION <= 1.0, "FLIP_FRACTION must be in [0, 1]"


In [9]:
# Seeded generator so the subset flipped when FLIP_FRACTION < 1.0 is the
# same on every Restart & Run All.
FLIP_SEED = 42
rng = np.random.default_rng(FLIP_SEED)

client_partitions, test_set = load_and_partition_data(num_clients=3)

# Fail loudly rather than silently writing attack-free data when the
# attacker index does not name an existing client.
assert 0 <= ATTACKER_ID < len(client_partitions), (
    f"ATTACKER_ID={ATTACKER_ID} is outside 0..{len(client_partitions) - 1}"
)

# Save each client's partition to a separate file
for i, (x_client, y_labels) in enumerate(client_partitions):
    # Copy the labels: partitions are slices of the shared training array,
    # so flipping in place would also modify the source data.
    y_client = y_labels.copy()

    # Apply label flipping attack ONLY for the attacker client
    if i == ATTACKER_ID:
        print(f"Client {i+1} is the ATTACKER. Performing label flipping...")

        # Find all indices with SOURCE_LABEL
        indices = np.where(y_client == SOURCE_LABEL)[0]

        # Flip only a (reproducibly chosen) fraction of them
        num_to_flip = int(len(indices) * FLIP_FRACTION)
        flip_indices = rng.choice(indices, size=num_to_flip, replace=False)

        # Perform the flip
        y_client[flip_indices] = TARGET_LABEL

        print(f"  Flipped {num_to_flip} labels from {SOURCE_LABEL} → {TARGET_LABEL}")

    # Save the (possibly modified) dataset
    file_path = f"flipped_client_{i+1}_data.npz"
    np.savez(file_path, x=x_client, y=y_client)
    print(f"Saved {file_path} with {len(x_client)} samples.")


Saved flipped_client_1_data.npz with 20000 samples.
Client 2 is the ATTACKER. Performing label flipping...
  Flipped 2282 labels from 1 → 7
Saved flipped_client_2_data.npz with 20000 samples.
Saved flipped_client_3_data.npz with 20000 samples.
