In [3]:
!pip install tensorflow-federated


Collecting tensorflow-federated
  Downloading tensorflow_federated-0.87.0-py3-none-manylinux_2_31_x86_64.whl.metadata (19 kB)
Collecting attrs~=23.1 (from tensorflow-federated)
  Downloading attrs-23.2.0-py3-none-any.whl.metadata (9.5 kB)
Collecting dm-tree==0.1.8 (from tensorflow-federated)
  Downloading dm_tree-0.1.8-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.9 kB)
Collecting dp-accounting==0.4.3 (from tensorflow-federated)
  Downloading dp_accounting-0.4.3-py3-none-any.whl.metadata (1.8 kB)
Collecting google-vizier==0.1.11 (from tensorflow-federated)
  Downloading google_vizier-0.1.11-py3-none-any.whl.metadata (10 kB)
Collecting jaxlib==0.4.14 (from tensorflow-federated)
  Downloading jaxlib-0.4.14-cp311-cp311-manylinux2014_x86_64.whl.metadata (2.0 kB)
Collecting jax==0.4.14 (from tensorflow-federated)
  Downloading jax-0.4.14.tar.gz (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m28.3 MB/s[0m eta [36m0:00:00

In [1]:
import tensorflow as tf
import tensorflow_federated as tff
import collections


ERROR:jax._src.xla_bridge:Jax plugin configuration error: Plugin module %s could not be loaded
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/xla_bridge.py", line 428, in discover_pjrt_plugins
    plugin_module = importlib.import_module(plugin_module_name)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/importlib/__init__.py", line 126, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<frozen importlib._bootstrap>", line 1204, in _gcd_import
  File "<frozen importlib._bootstrap>", line 1176, in _find_and_load
  File "<frozen importlib._bootstrap>", line 1147, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 690, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 940, in exec_module
  File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_r

In [6]:
# Load a small federated dataset
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

print("Clients:", emnist_train.client_ids[:5])


Downloading emnist_all.sqlite.lzma: 100%|██████████| 170507172/170507172 [00:44<00:00, 3811020.27it/s]


Clients: ['f0000_14', 'f0001_41', 'f0005_26', 'f0006_12', 'f0008_45']


In [7]:
# Sample a few client IDs
client_ids = emnist_train.client_ids[:5]
print("Example client IDs:", client_ids)

Example client IDs: ['f0000_14', 'f0001_41', 'f0005_26', 'f0006_12', 'f0008_45']


## Looking for client's data

In [8]:
# Load dataset for one client
example_dataset = emnist_train.create_tf_dataset_for_client(client_ids[0])

# Inspect sample
for example in example_dataset.take(1):
    print(example)


OrderedDict([('label', <tf.Tensor: shape=(), dtype=int32, numpy=1>), ('pixels', <tf.Tensor: shape=(28, 28), dtype=float32, numpy=
array([[1.        , 1.        , 1.        , 1.        , 1.        ,
        1.        , 1.        , 1.        , 1.        , 1.        ,
        1.        , 1.        , 1.        , 1.        , 1.        ,
        1.        , 1.        , 1.        , 1.        , 1.        ,
        1.        , 1.        , 1.        , 1.        , 1.        ,
        1.        , 1.        , 1.        ],
       [1.        , 1.        , 1.        , 1.        , 1.        ,
        1.        , 1.        , 1.        , 1.        , 1.        ,
        1.        , 1.        , 1.        , 1.        , 1.        ,
        1.        , 1.        , 1.        , 1.        , 1.        ,
        1.        , 1.        , 1.        , 1.        , 1.        ,
        1.        , 1.        , 1.        ],
       [1.        , 1.        , 1.        , 1.        , 1.        ,
        1.        , 1.        , 

## Preprocess

In [9]:
import tensorflow as tf

BATCH_SIZE = 20
NUM_EPOCHS = 1

def preprocess(dataset):
    def batch_format_fn(element):
        return (
            tf.expand_dims(element['pixels'], -1),  # shape: (28, 28, 1)
            element['label']
        )
    return (dataset
            .map(batch_format_fn)
            .shuffle(buffer_size=100)
            .batch(BATCH_SIZE)
            .repeat(NUM_EPOCHS))


In [12]:
def make_federated_data(client_data, client_ids):
    return [preprocess(client_data.create_tf_dataset_for_client(x)) for x in client_ids]


In [27]:
def create_keras_model():
    return tf.keras.models.Sequential([
        tf.keras.layers.Input(shape=(28, 28, 1)),
        tf.keras.layers.Conv2D(32, kernel_size=3, activation='relu'),
        tf.keras.layers.MaxPooling2D(pool_size=2),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(10, activation='softmax')
    ])

def model_fn():
    keras_model = create_keras_model()
    # Define the input specification manually
    input_spec = (
        tf.TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32),  # x (features)
        tf.TensorSpec(shape=(None,), dtype=tf.int32)              # y (labels)
    )
    return tff.learning.models.from_keras_model(
        keras_model,
        input_spec=input_spec,
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]
    )

In [21]:
import tensorflow_federated as tff
import tensorflow as tf
print(f"TFF Version: {tff.__version__}")
print(f"TF Version: {tf.__version__}")

TFF Version: 0.87.0
TF Version: 2.14.1


In [30]:
federated_train_data = [
    preprocess(emnist_train.create_tf_dataset_for_client(client_id))
    for client_id in client_ids
]

In [31]:
import tensorflow_federated as tff
import tensorflow as tf

# This should work with TFF 0.87.0
iterative_process = tff.learning.algorithms.build_weighted_fed_avg(
    model_fn=model_fn,
    # Use TFF's optimizer builder instead of a Keras optimizer directly.
    client_optimizer_fn=tff.learning.optimizers.build_sgdm(learning_rate=0.02),
)
state = iterative_process.initialize()

# The key difference - handle the new return type
for round_num in range(1, 6):
    result = iterative_process.next(state, federated_train_data)
    state = result.state  # Extract state from result object
    metrics = result.metrics  # Extract metrics from result object
    print(f'Round {round_num}, Metrics={metrics}')

Round 1, Metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.09166667), ('loss', 2.38833), ('num_examples', 480), ('num_batches', 26)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])
Round 2, Metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.08958333), ('loss', 2.3750854), ('num_examples', 480), ('num_batches', 26)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])
Round 3, Metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.083333336), ('loss', 2.3588126), ('num_examples', 480), ('num_batches', 26)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', 