In [1]:
!pip install --quiet --upgrade tensorflow-federated
!pip install --quiet --upgrade nest-asyncio
import nest_asyncio
nest_asyncio.apply()

In [20]:
import jax
import numpy as np
import tensorflow_federated as tff

In [4]:
tff.__version__

'0.39.0'

In [5]:
@tff.jax_computation(np.int32, np.int32)
def add_numbers(x, y):
  return jax.numpy.add(x, y)



In [6]:
comp_pb = tff.framework.serialize_computation(add_numbers)
comp_pb.WhichOneof('computation')

'xla'

In [7]:
xla_code = jax.lib.xla_client.XlaComputation(comp_pb.xla.hlo_module.value)
print(xla_code.as_hlo_text())


HloModule xla_computation_add_numbers.0, entry_computation_layout={((s32[], s32[]))->(s32[])}

ENTRY main.6 {
  arg_tuple.1 = (s32[], s32[]) parameter(0)
  get-tuple-element.2 = s32[] get-tuple-element(arg_tuple.1), index=0
  get-tuple-element.3 = s32[] get-tuple-element(arg_tuple.1), index=1
  add.4 = s32[] add(get-tuple-element.2, get-tuple-element.3)
  ROOT tuple.5 = (s32[]) tuple(add.4)
}




In [8]:
tff.backends.xla.set_local_python_execution_context()


In [9]:
add_numbers(2, 3)


5

In [10]:
import collections

BATCH_TYPE = collections.OrderedDict([
    ('pixels', tff.TensorType(np.float32, (50, 784))),
    ('labels', tff.TensorType(np.int32, (50,)))
])

MODEL_TYPE = collections.OrderedDict([
    ('weights', tff.TensorType(np.float32, (784, 10))),
    ('bias', tff.TensorType(np.float32, (10,)))
])

In [11]:
def loss(model, batch):
  y = jax.nn.softmax(
      jax.numpy.add(
          jax.numpy.matmul(batch['pixels'], model['weights']), model['bias']))
  targets = jax.nn.one_hot(jax.numpy.reshape(batch['labels'], -1), 10)
  return -jax.numpy.mean(jax.numpy.sum(targets * jax.numpy.log(y), axis=1))

In [32]:
from types import ModuleType
def recursive_dir(module):
    current = {}
    if not isinstance(module, ModuleType):
        return current
    for name in dir(module):
        current[name] = recursive_dir(name)
    return current

In [33]:
res =recursive_dir(tff.learning)

{'BatchOutput': {},
 'ClientWeighting': {},
 'MetricFinalizersType': {},
 'Model': {},
 'ModelWeights': {},
 '__builtins__': {},
 '__cached__': {},
 '__doc__': {},
 '__file__': {},
 '__loader__': {},
 '__name__': {},
 '__package__': {},
 '__path__': {},
 '__spec__': {},
 'add_debug_measurements': {},
 'add_debug_measurements_with_mixed_dtype': {},
 'algorithms': {},
 'build_federated_evaluation': {},
 'build_local_evaluation': {},
 'build_personalization_eval': {},
 'client_weight_lib': {},
 'compression_aggregator': {},
 'ddp_secure_aggregator': {},
 'debug_measurements': {},
 'deprecation': {},
 'dp_aggregator': {},
 'entropy_compression_aggregator': {},
 'federated_aggregate_keras_metric': {},
 'federated_evaluation': {},
 'framework': {},
 'from_keras_model': {},
 'keras_utils': {},
 'metrics': {},
 'model': {},
 'model_update_aggregator': {},
 'models': {},
 'optimizers': {},
 'personalization_eval': {},
 'reconstruction': {},
 'robust_aggregator': {},
 'secure_aggregator': {},
 '

In [25]:
# from tensorflow_federated.python.tests import jax_components

STEP_SIZE = 0.001

print(dir(tff.learning))
trainer = tff.learning.build_jax_federated_averaging_process(
    BATCH_TYPE, MODEL_TYPE, loss, STEP_SIZE)

['BatchOutput', 'ClientWeighting', 'MetricFinalizersType', 'Model', 'ModelWeights', '__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__path__', '__spec__', 'add_debug_measurements', 'add_debug_measurements_with_mixed_dtype', 'algorithms', 'build_federated_evaluation', 'build_local_evaluation', 'build_personalization_eval', 'client_weight_lib', 'compression_aggregator', 'ddp_secure_aggregator', 'debug_measurements', 'deprecation', 'dp_aggregator', 'entropy_compression_aggregator', 'federated_aggregate_keras_metric', 'federated_evaluation', 'framework', 'from_keras_model', 'keras_utils', 'metrics', 'model', 'model_update_aggregator', 'models', 'optimizers', 'personalization_eval', 'reconstruction', 'robust_aggregator', 'secure_aggregator', 'state_with_new_model_weights', 'templates']


AttributeError: module 'tensorflow_federated.python.learning' has no attribute 'build_jax_federated_averaging_process'