# **Setting up environment**

In [None]:
# Install the nightly TensorFlow Federated build and nest-asyncio, quietly,
# upgrading any copy that is already present.
# NOTE(review): neither package is version-pinned, so a fresh run may pull a
# build whose API differs from what the cells below call — consider pinning.
!pip install --quiet --upgrade tensorflow-federated-nightly
!pip install --quiet --upgrade nest-asyncio

# Patch asyncio via nest_asyncio; presumably required because the notebook
# kernel already runs an event loop — TODO confirm.
import nest_asyncio
nest_asyncio.apply()

In [None]:
import collections

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

In [None]:
@tff.federated_computation
def hello_world():
  """Federated computation with no parameters that yields a fixed greeting."""
  greeting = 'Hello, World!'
  return greeting

# Invoke the computation; as the last expression its result is displayed.
hello_world()

# **Federated Data**

## Defining the TFF type of a **federated float** hosted by a group of client devices

In [None]:
# Build a federated type of the form tf.float32@CLIENTS.
federated_float_on_clients = tff.type_at_clients(tf.float32)

# Show the type itself, then its member type, its placement and its
# all_equal flag, one per line.
print(
    federated_float_on_clients,
    federated_float_on_clients.member,
    federated_float_on_clients.placement,
    federated_float_on_clients.all_equal,
    sep='\n')

In [None]:
# A structure type with two named float32 fields, 'a' and 'b'.
simple_regression_model_type = tff.StructType([
    ('a', tf.float32),
    ('b', tf.float32),
])

# Display the compact string form of the structure type.
str(simple_regression_model_type)

In [None]:
# Place the structure type on clients with all_equal=True and display the
# string form of the resulting federated type.
str(tff.type_at_clients(simple_regression_model_type, all_equal=True))

# **Placements**

## **Design Overview**

## **Specifying Placements**

# **Federated Computations**

## **Declaring federated computations**

In [None]:
@tff.federated_computation(tff.type_at_clients(tf.float32))
def get_average_temperature(sensor_readings):
  """Averages the temperatures reported by the sensor array.

  Args:
    sensor_readings: Federated value declared by the decorator as
      `tff.type_at_clients(tf.float32)`.

  Returns:
    The result of applying `tff.federated_mean` to `sensor_readings`.
  """
  mean_reading = tff.federated_mean(sensor_readings)
  return mean_reading

In [None]:
# Display the TFF type signature of the computation as a string.
str(get_average_temperature.type_signature)

## **Executing federated computations**

In [None]:
# Invoke the computation on a plain Python list of three readings;
# presumably each element stands for one client's value — TODO confirm.
get_average_temperature([68.5, 70.3, 69.8])

In [None]:
@tff.federated_computation(tff.type_at_clients(tf.float32))
def get_average_temperature(sensor_readings):
  """Averages the sensor readings, reporting the argument's Python type.

  The print statement runs whenever this Python body is executed and shows
  the class name of the object passed in as `sensor_readings`.

  Args:
    sensor_readings: Federated value declared by the decorator as
      `tff.type_at_clients(tf.float32)`.

  Returns:
    The result of applying `tff.federated_mean` to `sensor_readings`.
  """
  argument_type_name = type(sensor_readings).__name__
  print('Getting traced, the argument is "{}".'.format(argument_type_name))

  return tff.federated_mean(sensor_readings)

## **Composing federated computations**

# **TensorFlow Logic**

## **Declaring TensorFlow computations**

In [None]:
@tff.tf_computation(tf.float32)
def add_half(x):
  """Takes a float32 number and adds 0.5 to it."""
  half = 0.5
  return tf.add(x, half)

# Display the TFF type signature of the TensorFlow computation.
str(add_half.type_signature)

In [None]:
@tff.federated_computation(tff.type_at_clients(tf.float32))
def add_half_on_clients(x):
  """Applies `add_half` pointwise to a federated float value.

  Args:
    x: Federated value declared by the decorator as
      `tff.type_at_clients(tf.float32)`.

  Returns:
    The result of mapping `add_half` over `x` with `tff.federated_map`.
  """
  mapped = tff.federated_map(add_half, x)
  return mapped

# Display the TFF type signature of the federated computation.
str(add_half_on_clients.type_signature)

## **Executing TensorFlow computations**

In [None]:
# Invoke the computation on a plain Python list of three floats;
# presumably each element stands for one client's value — TODO confirm.
add_half_on_clients([1.0, 3.0, 2.0])

In [None]:
try:

  # Eager mode: this tensor is created outside of the computation body.
  constant_10 = tf.constant(10.)

  @tff.tf_computation(tf.float32)
  def add_ten(x):
    """Attempts to add the externally created `constant_10` to `x`."""
    total = x + constant_10
    return total

except Exception as tracing_error:
  # Presumably this cell is expected to fail; show the message rather than
  # letting the exception stop the notebook.
  print(tracing_error)

In [None]:
def get_constant_10():
  """Creates and returns a constant tensor holding 10.0."""
  ten = tf.constant(10.)
  return ten

@tff.tf_computation(tf.float32)
def add_ten(x):
  """Adds ten to a float32 input.

  The constant is obtained by calling `get_constant_10()` inside the body,
  so it is created each time this body runs.
  """
  ten = get_constant_10()
  return x + ten

# Invoke the computation on a Python float and display the result.
add_ten(5.0)

## **Working with `tf.data.Dataset`s**

In [None]:
# Sequence types whose elements are float32 and int32 respectively.
float32_sequence = tff.SequenceType(tf.float32)
int32_sequence = tff.SequenceType(tf.int32)

# Print the string form of each sequence type on its own line.
print(float32_sequence, int32_sequence, sep='\n')

In [None]:
@tff.tf_computation(tff.SequenceType(tf.float32))
def get_local_temperature_average(local_temperatures):
  """Calculates the average temperature in a single local data set.

  Args:
    local_temperatures: Sequence whose elements are float32, as declared by
      the decorator. The body calls `.reduce` on it.

  Returns:
    The sum of the readings divided by their count cast to float32.
  """

  def accumulate(running, reading):
    # running[0] is the sum of temperatures so far and running[1] the count.
    return (running[0] + reading, running[1] + 1)

  # Start from (0.0, 0) and fold every element into the running pair.
  total, count = local_temperatures.reduce((0.0, 0), accumulate)
  return total / tf.cast(count, tf.float32)

# Display the TFF type signature of the computation.
str(get_local_temperature_average.type_signature)

In [None]:
@tff.tf_computation(tff.SequenceType(tf.int32))
def foo(x):
  """Sums the elements of an int32 sequence, starting from int32 zero."""
  initial_total = np.int32(0)
  return x.reduce(initial_total, lambda total, element: total + element)

# Invoke the computation on a plain Python list of ints.
foo([1, 2, 3])

In [None]:
# Invoke the local-average computation on a plain Python list of readings.
get_local_temperature_average([68.5, 70.3, 69.8])

In [None]:
# The computation accepts a sequence of pairs with int32 fields 'A' and 'B'.
@tff.tf_computation(tff.SequenceType(collections.OrderedDict(
    [('A', tf.int32), ('B', tf.int32)])))
def foo(ds):
  """Returns the sum of the A * B products over a sequence of pairs.

  Args:
    ds: Sequence of elements with int32 fields 'A' and 'B', as declared by
      the decorator.

  Returns:
    The accumulated total, starting from int32 zero.
  """
  # element_spec allows inspecting the type of each element component.
  print('element_structure = {}'.format(ds.element_spec))

  def add_product(total, pair):
    return total + pair['A'] * pair['B']

  return ds.reduce(np.int32(0), add_product)

# Show the type signature, then run the computation on two sample pairs.
print(str(foo.type_signature))

foo([{'A': 2, 'B': 3}, {'A': 4, 'B': 5}])

# **Putting it all together**

In [None]:
@tff.federated_computation(
    tff.type_at_clients(tff.SequenceType(tf.float32)))
def get_global_temperature_average(sensor_readings):
  """Averages the per-client local temperature averages.

  Args:
    sensor_readings: Federated value declared by the decorator as a float32
      sequence placed at clients.

  Returns:
    The result of `tff.federated_mean` applied to the local averages.
  """
  # Map get_local_temperature_average over sensor_readings.
  local_averages = tff.federated_map(
      get_local_temperature_average, sensor_readings)
  return tff.federated_mean(local_averages)

# Print the TFF type signature of the federated computation.
print(str(get_global_temperature_average.type_signature))

In [None]:
# Invoke the computation on three lists of readings, presumably one list per
# client. The lists average 69.0, 71.0 and 70.0, so an unweighted mean of
# those would be 70.0 — verify against the displayed output.
get_global_temperature_average([[68.0, 70.0], [71.0], [68.0, 72.0, 70.0]])

## Exercise

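We leave it as an exercise for the reader to update the above code so that it computes a weighted average, weighting each client's local average by its number of readings; the `tff.federated_mean` operator accepts the weight as an optional second argument (expected to be a federated float).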

In [None]:
@tff.tf_computation(tff.SequenceType(tf.float32))
def wt(local_temperatures):
  """Counts the elements of a local sequence and returns the count as float32.

  Args:
    local_temperatures: Sequence whose elements are float32, as declared by
      the decorator.

  Returns:
    The number of elements, cast to float32.
  """
  # Start at 0 and add one per element; the element value itself is ignored.
  reading_count = local_temperatures.reduce(
      0, lambda count_so_far, _reading: count_so_far + 1)
  # The count is cast to float because get_global_temperature_average2
  # requires a float32 input.
  return tf.cast(reading_count, tf.float32)

In [None]:
@tff.federated_computation(
    tff.type_at_clients(tff.SequenceType(tf.float32)))
def get_global_temperature_average2(sensor_readings):
  """Averages the per-client local averages, weighted by `wt`.

  Args:
    sensor_readings: Federated value declared by the decorator as a float32
      sequence placed at clients.

  Returns:
    The result of `tff.federated_mean` called with the local averages as the
    value and the mapped `wt` results as the second (weight) argument.
  """
  local_averages = tff.federated_map(
      get_local_temperature_average, sensor_readings)
  weights = tff.federated_map(wt, sensor_readings)
  return tff.federated_mean(local_averages, weights)

In [None]:
# Invoke the weighted variant on the same three lists of readings. With
# weights equal to the list lengths (2, 1, 3) the weighted mean would be
# 419 / 6, about 69.83 — verify against the displayed output.
get_global_temperature_average2([[68.0, 70.0], [71.0], [68.0, 72.0, 70.0]])