[Reference](https://medium.com/biased-algorithms/a-step-by-step-guide-to-federated-learning-in-computer-vision-0984e4a7f8d5)

# Step 1: Setting Up the Environment

In [1]:
!pip install tensorflow tensorflow-federated

Collecting tensorflow-federated
  Downloading tensorflow_federated-0.87.0-py3-none-manylinux_2_31_x86_64.whl.metadata (19 kB)
Collecting attrs~=23.1 (from tensorflow-federated)
  Downloading attrs-23.2.0-py3-none-any.whl.metadata (9.5 kB)
Collecting dp-accounting==0.4.3 (from tensorflow-federated)
  Downloading dp_accounting-0.4.3-py3-none-any.whl.metadata (1.8 kB)
Collecting google-vizier==0.1.11 (from tensorflow-federated)
  Downloading google_vizier-0.1.11-py3-none-any.whl.metadata (10 kB)
Collecting jaxlib==0.4.14 (from tensorflow-federated)
  Downloading jaxlib-0.4.14-cp310-cp310-manylinux2014_x86_64.whl.metadata (2.0 kB)
Collecting jax==0.4.14 (from tensorflow-federated)
  Downloading jax-0.4.14.tar.gz (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m15.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml

In [2]:
!pip install torch syft

Collecting syft
  Downloading syft-0.9.2-py2.py3-none-any.whl.metadata (17 kB)
Collecting typing-extensions>=4.8.0 (from torch)
  Downloading typing_extensions-4.12.2-py3-none-any.whl.metadata (3.0 kB)
Collecting bcrypt==4.1.2 (from syft)
  Downloading bcrypt-4.1.2-cp39-abi3-manylinux_2_28_x86_64.whl.metadata (9.5 kB)
Collecting boto3==1.34.56 (from syft)
  Downloading boto3-1.34.56-py3-none-any.whl.metadata (6.6 kB)
Collecting forbiddenfruit==0.1.4 (from syft)
  Downloading forbiddenfruit-0.1.4.tar.gz (43 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.8/43.8 kB[0m [31m1.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting packaging>=23.0 (from syft)
  Downloading packaging-24.1-py3-none-any.whl.metadata (3.2 kB)
Collecting pycapnp==2.0.0 (from syft)
  Downloading pycapnp-2.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (21 kB)
Collecting pydantic==2.6.0 (from pydantic[email]==2.6.0->syf

# Step 2: Preparing the Dataset

In [None]:
import numpy as np
from sklearn.model_selection import train_test_split
from tensorflow.keras.datasets import cifar10

# Load CIFAR-10 dataset
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# Function to partition the dataset into non-IID subsets
def create_noniid_partitions(x_data, y_data, num_clients=10):
    # Create non-IID partitions by randomly selecting class-specific samples for each client
    partitions = []
    num_classes = len(np.unique(y_data))

    for i in range(num_clients):
        indices = np.array([], dtype=int)
        for c in range(num_classes):
            # Select a subset of samples from each class (non-uniformly)
            class_indices = np.where(y_data == c)[0]
            chosen_indices = np.random.choice(class_indices, size=int(len(class_indices) / num_clients), replace=False)
            indices = np.concatenate([indices, chosen_indices])

        partitions.append((x_data[indices], y_data[indices]))

    return partitions

# Simulate 10 clients with non-IID data
client_partitions = create_noniid_partitions(x_train, y_train, num_clients=10)



# Step 3: Defining the Model

In [1]:
import tensorflow as tf
from tensorflow.keras import layers

def create_cnn_model():
    model = tf.keras.Sequential([
        layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.Flatten(),
        layers.Dense(64, activation='relu'),
        layers.Dense(10, activation='softmax')  # 10 classes for CIFAR-10
    ])
    return model

ERROR:jax._src.xla_bridge:Jax plugin configuration error: Exception when calling jax_plugins.xla_cuda12.initialize()
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/xla_bridge.py", line 438, in discover_pjrt_plugins
    plugin_module.initialize()
  File "/usr/local/lib/python3.10/dist-packages/jax_plugins/xla_cuda12/__init__.py", line 85, in initialize
    options = xla_client.generate_pjrt_gpu_plugin_options()
AttributeError: module 'jaxlib.xla_client' has no attribute 'generate_pjrt_gpu_plugin_options'


# Step 4: Training the Model Locally

In [2]:
def train_local_model(model, train_data, train_labels, epochs=5, batch_size=32):
    # Compile the model
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    # Train the model locally
    model.fit(train_data, train_labels, epochs=epochs, batch_size=batch_size, verbose=1)

    return model

# Step 5: Aggregating Model Updates

In [3]:
def aggregate_weights(client_weights):
    # Initialize an empty list to store the averaged weights
    avg_weights = []

    # Zip all weights together and compute the mean for each layer
    for weights in zip(*client_weights):
        avg_weights.append(np.mean(weights, axis=0))

    return avg_weights

# Step 6: Model Validation and Testing

In [4]:
def validate_model(model, x_val, y_val):
    loss, accuracy = model.evaluate(x_val, y_val, verbose=0)
    print(f"Validation Loss: {loss}")
    print(f"Validation Accuracy: {accuracy}")

# Step 7: Deployment of the Federated Model

In [6]:
# Convert the model to TensorFlow Lite format
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

# Save the converted model
with open('model.tflite', 'wb') as f:
    f.write(tflite_model)

In [7]:
# Convert to a quantized TensorFlow Lite model
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
quantized_model = converter.convert()

# Save the quantized model
with open('quantized_model.tflite', 'wb') as f:
    f.write(quantized_model)

In [8]:
def weighted_average(client_weights, client_sizes):
    total_size = sum(client_sizes)
    weighted_updates = [
        w * (size / total_size) for w, size in zip(client_weights, client_sizes)
    ]
    return sum(weighted_updates) / len(client_weights)

In [9]:
# Train for multiple epochs locally before sending updates
def federated_training(client_model, local_data, num_local_epochs=5):
    for epoch in range(num_local_epochs):
        client_model.fit(local_data, epochs=1)
    return client_model.get_weights()

In [10]:
class TrainingLogger(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        print(f"Epoch {epoch}: loss = {logs['loss']}, accuracy = {logs['accuracy']}")

# Apply logger during local training
client_model.fit(local_data, epochs=5, callbacks=[TrainingLogger()])