# Example tensorflow federated Notebook

Important notes so that you can actually run the notebook
```
Python 3.9 is required to run the notebook (Warning, 3.8 doesn't work!)
the requirements are
    - tensorflow
    - tensorflow-federated
    - juypter
```

To run **TFF** in the notebook, it is required to run those two lines (this is notebook-specific)
```
import nest_asyncio
nest_asyncio.apply()
```

#### Imports

In [11]:
import tensorflow as tf
import tensorflow_federated as tff
import matplotlib.pyplot as plt

""" These two lines are required to make TFF work in a notebook!!!! """
import nest_asyncio
nest_asyncio.apply()

#### Example training on MNIST data (from the TFF homepage)

In [24]:
# Load simulation data.
source, _ = tff.simulation.datasets.emnist.load_data()
def client_data(n):
  return source.create_tf_dataset_for_client(source.client_ids[n]).map(
      lambda e: (tf.reshape(e['pixels'], [-1]), e['label'])
  ).repeat(10).batch(20)

# Pick a subset of client devices to participate in training.
train_data = [client_data(n) for n in range(3)]

# Wrap a Keras model for use with TFF.
def model_fn():
  model = tf.keras.models.Sequential([
      tf.keras.layers.Dense(10, tf.nn.softmax, input_shape=(784,),
                            kernel_initializer='zeros')
  ])
  return tff.learning.from_keras_model(
      model,
      input_spec=train_data[0].element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

# Simulate a few rounds of training with the selected client devices.
trainer = tff.learning.build_federated_averaging_process(
  model_fn,
  client_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.1))
state = trainer.initialize()
_states = []
_metrics = []
T = 40
for t in range(T):
    print("progress = {}".format(t/T))
    state, metrics = trainer.next(state, train_data)
    _states.append(state)
    _metrics.append(metrics)

progress = 0.0
progress = 0.025
progress = 0.05
progress = 0.075
progress = 0.1
progress = 0.125
progress = 0.15
progress = 0.175
progress = 0.2
progress = 0.225
progress = 0.25
progress = 0.275
progress = 0.3
progress = 0.325
progress = 0.35
progress = 0.375
progress = 0.4
progress = 0.425
progress = 0.45
progress = 0.475
progress = 0.5
progress = 0.525
progress = 0.55
progress = 0.575
progress = 0.6
progress = 0.625
progress = 0.65
progress = 0.675
progress = 0.7
progress = 0.725
progress = 0.75
progress = 0.775
progress = 0.8
progress = 0.825
progress = 0.85
progress = 0.875
progress = 0.9
progress = 0.925
progress = 0.95
progress = 0.975


In [26]:
loss = [m['train']['loss'] for m in metrics]
plt.plot(loss)

13.897555

In [16]:
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

In [20]:
emnist_train.element_type_structure

OrderedDict([('label', TensorSpec(shape=(), dtype=tf.int32, name=None)),
             ('pixels',
              TensorSpec(shape=(28, 28), dtype=tf.float32, name=None))])