In [2]:
!pip install tensorflow_federated

Collecting tensorflow_federated
  Downloading tensorflow_federated-0.87.0-py3-none-manylinux_2_31_x86_64.whl.metadata (19 kB)
Collecting attrs~=23.1 (from tensorflow_federated)
  Downloading attrs-23.2.0-py3-none-any.whl.metadata (9.5 kB)
Collecting dm-tree==0.1.8 (from tensorflow_federated)
  Downloading dm_tree-0.1.8-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.9 kB)
Collecting dp-accounting==0.4.3 (from tensorflow_federated)
  Downloading dp_accounting-0.4.3-py3-none-any.whl.metadata (1.8 kB)
Collecting google-vizier==0.1.11 (from tensorflow_federated)
  Downloading google_vizier-0.1.11-py3-none-any.whl.metadata (10 kB)
Collecting jaxlib==0.4.14 (from tensorflow_federated)
  Downloading jaxlib-0.4.14-cp311-cp311-manylinux2014_x86_64.whl.metadata (2.0 kB)
Collecting jax==0.4.14 (from tensorflow_federated)
  Downloading jax-0.4.14.tar.gz (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m8.6 MB/s[0m eta [36m0:00:00

In [8]:
# Example code structure to get started
import tensorflow as tf
import tensorflow_federated as tff

# Load and preprocess MNIST data
def preprocess_data():
    """Load MNIST through ``tf.keras.datasets`` and rescale the pixel values.

    Image arrays are divided by 255.0, which maps 8-bit intensities into
    [0, 1]. No reshaping is performed: the images keep the layout returned by
    ``mnist.load_data()`` (no channel axis is added), and labels are passed
    through unchanged.

    Returns:
        ``((x_train, y_train), (x_test, y_test))``.
    """
    train_split, test_split = tf.keras.datasets.mnist.load_data()
    images_train, labels_train = train_split
    images_test, labels_test = test_split
    pixel_max = 255.0
    return (images_train / pixel_max, labels_train), (images_test / pixel_max, labels_test)

# Create a simple CNN model
def create_model(input_shape=(28, 28, 1), num_classes=10):
    """Build the (uncompiled) CNN classifier used in this notebook.

    Architecture: Conv2D(32, 3x3, ReLU) -> MaxPool -> Conv2D(64, 3x3, ReLU)
    -> MaxPool -> Flatten -> Dense(64, ReLU) -> Dense(num_classes, softmax).

    Args:
        input_shape: Shape of a single example, without the batch axis.
            Defaults to 28x28 single-channel images.
        num_classes: Size of the softmax output layer.

    Returns:
        A ``tf.keras.Sequential`` model. It is NOT compiled; call
        ``compile()`` before ``fit()``.
    """
    return tf.keras.Sequential([
        # An explicit Input layer replaces passing ``input_shape`` to the first
        # Conv2D, which Keras 3 deprecates with a UserWarning.
        tf.keras.Input(shape=input_shape),
        tf.keras.layers.Conv2D(32, 3, activation='relu'),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Conv2D(64, 3, activation='relu'),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(num_classes, activation='softmax'),
    ])

In [2]:
# Create client data (simulating distributed data)
def create_client_data(x_train, y_train, num_clients=10, drop_remainder=True):
    """Partition a training set into contiguous per-client shards.

    Client ``i`` receives a contiguous slice of ``x_train`` / ``y_train``.
    The split follows the input order only. It is label-skewed (non-IID)
    only if the caller passes data sorted or grouped by label. Otherwise the
    shards are roughly IID.

    Args:
        x_train: Sliceable sequence of features (e.g. a NumPy array), split
            along its first axis.
        y_train: Sliceable sequence of labels aligned with ``x_train``.
        num_clients: Number of shards to create; must be >= 1.
        drop_remainder: If True (default, original behaviour), every client
            gets ``len(x_train) // num_clients`` samples and the trailing
            ``len(x_train) % num_clients`` samples are discarded. If False,
            those leftover samples go one each to the first clients, so no
            data is lost.

    Returns:
        List of ``num_clients`` dicts with keys ``'x'`` and ``'y'``.

    Raises:
        ValueError: if ``num_clients < 1`` or feature/label counts differ.
    """
    if num_clients < 1:
        raise ValueError(f"num_clients must be >= 1, got {num_clients}")
    num_samples = len(x_train)
    if len(y_train) != num_samples:
        raise ValueError(
            f"x_train has {num_samples} samples but y_train has {len(y_train)}")

    base_size, remainder = divmod(num_samples, num_clients)
    if drop_remainder:
        remainder = 0

    client_data = []
    start_idx = 0
    for i in range(num_clients):
        # The first `remainder` clients absorb one extra sample each.
        end_idx = start_idx + base_size + (1 if i < remainder else 0)
        client_data.append({
            'x': x_train[start_idx:end_idx],
            'y': y_train[start_idx:end_idx]
        })
        start_idx = end_idx
    return client_data

# Implement FedAvg algorithm
def federated_averaging(client_weights, client_sizes=None):
    """Aggregate per-client model weights into one set of global weights.

    Args:
        client_weights: Sequence with one entry per client. Each entry is that
            client's list of layer weights (e.g. from ``model.get_weights()``).
            All clients must use the same layer order and shapes.
        client_sizes: Optional per-client example counts. When given, each
            client is weighted by its share of the total, which is the FedAvg
            weighting. When omitted, clients are averaged with equal weight
            (original behaviour).

    Returns:
        List of tensors, one per layer, holding the averaged weights.

    Raises:
        ValueError: if there are no clients, ``client_sizes`` has the wrong
            length, or the sizes do not sum to a positive number.
    """
    client_weights = list(client_weights)
    if len(client_weights) == 0:
        raise ValueError("federated_averaging needs at least one client")

    if client_sizes is None:
        # zip(*...) regroups the per-client lists into per-layer tuples.
        return [
            tf.reduce_mean(tf.stack(layer_weights, axis=0), axis=0)
            for layer_weights in zip(*client_weights)
        ]

    client_sizes = list(client_sizes)
    if len(client_sizes) != len(client_weights):
        raise ValueError(
            f"got {len(client_sizes)} client_sizes for {len(client_weights)} clients")
    total = float(sum(client_sizes))
    if total <= 0:
        raise ValueError("client_sizes must sum to a positive number")
    fractions = [size / total for size in client_sizes]

    avg_weights = []
    for layer_weights in zip(*client_weights):
        stacked = tf.stack(layer_weights, axis=0)
        # Shape the coefficients as (num_clients, 1, 1, ...) so they broadcast
        # over every layer dimension.
        coeffs = tf.reshape(
            tf.constant(fractions, dtype=stacked.dtype),
            [-1] + [1] * (len(stacked.shape) - 1),
        )
        avg_weights.append(tf.reduce_sum(stacked * coeffs, axis=0))
    return avg_weights

In [3]:
# Implement differential privacy
def add_noise(gradients, noise_multiplier=1.0, l2_norm_clip=1.0):
    """Clip a list of tensors by their joint L2 norm, then add Gaussian noise.

    Every tensor is scaled by ``min(1, l2_norm_clip / global_norm)``, where
    ``global_norm`` is the L2 norm of all tensors concatenated. Gaussian
    noise with stddev ``noise_multiplier * l2_norm_clip`` is then added
    elementwise, in each tensor's own dtype.

    NOTE: this only perturbs the values. No privacy accounting
    (epsilon/delta) is performed here.

    Args:
        gradients: List of same-dtype floating tensors/arrays.
        noise_multiplier: Noise stddev as a multiple of ``l2_norm_clip``;
            must be >= 0.
        l2_norm_clip: Maximum joint L2 norm after clipping; must be > 0.

    Returns:
        List of noisy tensors with the same shapes and dtypes as the inputs.

    Raises:
        ValueError: on a non-positive clip norm or a negative multiplier.
    """
    if not l2_norm_clip > 0:
        raise ValueError(f"l2_norm_clip must be > 0, got {l2_norm_clip}")
    if noise_multiplier < 0:
        raise ValueError(f"noise_multiplier must be >= 0, got {noise_multiplier}")
    gradients = list(gradients)
    if len(gradients) == 0:
        return []

    clipped, _ = tf.clip_by_global_norm(gradients, l2_norm_clip)

    noise_stddev = noise_multiplier * l2_norm_clip
    # Match the noise dtype to each tensor; the float32 default would make
    # `g + noise` fail for float64 inputs.
    return [
        g + tf.random.normal(tf.shape(g), stddev=noise_stddev, dtype=g.dtype)
        for g in clipped
    ]

In [4]:
# Implement model compression
def compress_weights(weights, compression_ratio=0.1):
    """Magnitude-prune a list of weight tensors, keeping the largest entries.

    All tensors are pooled, and the ``compression_ratio`` fraction of
    entries with the largest absolute value (across all tensors) is kept.
    Every other entry is set to zero. Shapes are preserved, so the result is
    still dense; realising a bandwidth saving would need a sparse encoding on
    top.

    Args:
        weights: List of weight arrays/tensors of one floating dtype
            (e.g. ``model.get_weights()``).
        compression_ratio: Fraction of entries to keep, in ``[0, 1]``.
            ``0`` zeroes everything and ``1`` keeps everything. Entries tied
            with the smallest kept magnitude are all kept, so slightly more
            than the target can survive.

    Returns:
        List of tensors with the same shapes as ``weights``.

    Raises:
        ValueError: if ``compression_ratio`` is outside ``[0, 1]``.
    """
    if not 0.0 <= compression_ratio <= 1.0:
        raise ValueError(
            f"compression_ratio must be in [0, 1], got {compression_ratio}")
    if len(weights) == 0:
        return []

    magnitudes = tf.concat([tf.reshape(tf.abs(w), [-1]) for w in weights], axis=0)
    num_params = int(magnitudes.shape[0])
    num_keep = int(round(num_params * compression_ratio))

    if num_keep == 0:
        return [tf.zeros_like(w) for w in weights]
    if num_keep >= num_params:
        return [tf.convert_to_tensor(w) for w in weights]

    # The k-th largest magnitude is the smallest value that survives. Using
    # `>=` against it keeps exactly k entries when there are no ties.
    threshold = tf.math.top_k(magnitudes, k=num_keep).values[-1]
    return [tf.where(tf.abs(w) >= threshold, w, tf.zeros_like(w)) for w in weights]

# Add monitoring
def monitor_training(metrics, round_number):
    """Print a three-line progress report for one training round.

    Args:
        metrics: Mapping that must contain numeric ``'accuracy'`` and
            ``'loss'`` entries; both are printed with four decimals.
        round_number: Round label shown in the header line.
    """
    print(f"Round {round_number}")
    accuracy = metrics['accuracy']
    print(f"Training accuracy: {accuracy:.4f}")
    loss = metrics['loss']
    print(f"Loss: {loss:.4f}")

In [5]:
# Implement client selection strategy
def select_clients(clients, round_number, selection_fraction=0.3, strategy="largest"):
    """Choose which clients participate in a federated training round.

    Args:
        clients: Sequence of client dicts; ``len(client['x'])`` is taken as
            the client's sample count.
        round_number: Index of the current round. Used by ``"round_robin"``
            and ``"random"``; ignored by ``"largest"``.
        selection_fraction: Fraction of clients to select; at least one
            client is always requested.
        strategy: How clients are chosen:

            - ``"largest"`` (default, original behaviour) picks the clients
              with the most samples. It returns the *same* clients every
              round, so the rest never contribute.
            - ``"round_robin"`` walks through ``clients`` in order, advancing
              by the selection size each round, so every client takes part
              periodically.
            - ``"random"`` samples without replacement using a
              ``random.Random`` seeded with ``round_number``, so each round's
              choice is reproducible.

    Returns:
        List of the selected client dicts (the same objects, not copies).

    Raises:
        ValueError: for an unknown ``strategy``.
    """
    import random

    num_selected = max(1, int(len(clients) * selection_fraction))

    if strategy == "largest":
        # sorted() is stable, so clients of equal size keep their input order.
        return sorted(clients, key=lambda c: len(c['x']), reverse=True)[:num_selected]

    if strategy == "round_robin":
        if len(clients) == 0:
            return []
        count = min(num_selected, len(clients))
        start = (round_number * num_selected) % len(clients)
        return [clients[(start + offset) % len(clients)] for offset in range(count)]

    if strategy == "random":
        rng = random.Random(round_number)
        return rng.sample(list(clients), min(num_selected, len(clients)))

    raise ValueError(f"unknown client selection strategy: {strategy!r}")

# Equalize per-client sample counts by down-sampling every client to the smallest one (labels are not rebalanced)
def balance_client_data(client_data, seed=None):
    """Down-sample every client to the sample count of the smallest client.

    Each client's example indices are shuffled and the first ``min_samples``
    are kept, so all returned clients have the same number of examples.
    This equalises *quantity* only; the label mix inside each client is not
    rebalanced.

    Args:
        client_data: List of dicts with aligned ``'x'`` and ``'y'`` entries,
            such as the output of ``create_client_data``.
        seed: Optional op-level seed forwarded to ``tf.random.shuffle``.
            Combine with ``tf.random.set_seed`` for run-to-run
            reproducibility.

    Returns:
        List of dicts with ``'x'`` / ``'y'`` tensors (via ``tf.gather``), in
        the input client order. Empty input returns an empty list.

    Raises:
        ValueError: if any client's ``'x'`` and ``'y'`` lengths differ.
    """
    if len(client_data) == 0:
        return []
    # Validate up front: on GPU, tf.gather fills out-of-range indices with
    # zeros instead of failing, which would silently corrupt labels.
    for client_idx, data in enumerate(client_data):
        if len(data['x']) != len(data['y']):
            raise ValueError(
                f"client {client_idx}: {len(data['x'])} features vs {len(data['y'])} labels")

    min_samples = min(len(data['x']) for data in client_data)
    balanced_data = []
    for data in client_data:
        # Shuffle first so the kept subset is a random sample, not a prefix.
        indices = tf.random.shuffle(tf.range(len(data['x'])), seed=seed)[:min_samples]
        balanced_data.append({
            'x': tf.gather(data['x'], indices),
            'y': tf.gather(data['y'], indices)
        })
    return balanced_data

In [6]:
# Implement local fine-tuning
def personalize_model(global_model, client_data, epochs=5, optimizer='adam',
                      loss='sparse_categorical_crossentropy', metrics=('accuracy',)):
    """Fine-tune a private copy of the global model on one client's data.

    The global model is not modified. Its architecture is cloned and its
    current weights are copied into the clone before training.

    Args:
        global_model: ``tf.keras.Model`` to start from.
        client_data: Dict with ``'x'`` inputs and ``'y'`` labels for one
            client, such as an entry produced by ``create_client_data``.
        epochs: Number of local fine-tuning epochs.
        optimizer: Optimizer identifier or a *fresh* optimizer instance for
            the clone. Reusing another model's optimizer instance would share
            its state.
        loss: Fine-tuning loss. The default assumes integer class labels and
            a softmax-probability output (as from ``create_model``); override
            it for other setups.
        metrics: Metrics tracked during fine-tuning.

    Returns:
        The fine-tuned clone (a new ``tf.keras.Model``).
    """
    model = tf.keras.models.clone_model(global_model)
    model.set_weights(global_model.get_weights())

    # clone_model() returns an uncompiled model; without compiling here,
    # fit() raises.
    model.compile(optimizer=optimizer, loss=loss, metrics=list(metrics or ()))

    model.fit(
        client_data['x'],
        client_data['y'],
        epochs=epochs,
        verbose=0
    )
    return model