In [0]:
!pip install -q tensorflow tensorflow-datasets tensorflow-federated
!pip install -q jax dm-haiku

In [0]:
from typing import Any, Generator, Tuple, Mapping, Sequence, Optional, Callable
from collections import namedtuple
import functools, inspect, time

from absl import app
import haiku as hk
import jax
from jax.experimental import optix
from jax.tree_util import tree_multimap, tree_map, tree_reduce
import jax.numpy as jnp
from jax.lax import fori_loop
import tensorflow_datasets as tfds
import tensorflow_federated as tff
import tensorflow as tf

In [0]:
# Type aliases shared by the client and server code.
Batch = Mapping[str, jnp.ndarray]           # one minibatch: field name -> array.
ClientData = Generator[Batch, None, None]   # stream of minibatches of one client.
LossFunction = Callable[[hk.Params, Batch], jnp.ndarray]
OptState = Any


# define hyperparameters format.
ServerHyperParams = namedtuple(
    "ServerHyperParams",
    [
        "sampled_clients",  # number of clients picked in every round.
        "batch_size",       # minibatch size used on the clients.
        "num_epochs",       # local passes over a client's data per round.
        "num_rounds",       # number of federated rounds.
        "seed",             # seed for the PRNG key and data shuffling.
    ])

# message to the client from server.
ClientMessage = namedtuple("ClientMessage", ["params", "opt_init_input"])

# message to the server from client.
ServerMessage = namedtuple(
    "ServerMessage", ["aggregator_input", "stateupdater_input"])


# message to the server from client for book keeping.
DiagnosticsMessage = namedtuple(
    "DiagnosticsMessage",
    ["train_loss", "train_acc", "test_loss", "test_acc", "weight"])
ClientOutput = Tuple[ServerMessage, DiagnosticsMessage]

# extracts messages from a list of client outputs.
# NOTE: `functools.partial` replaces `jax.partial`, which was only an
# accidental re-export of the same function and was later removed from JAX.
@functools.partial(jax.jit, static_argnums=(1,))  # fix extractor.
def extract_from_cout(
    couts: Sequence[ClientOutput],
    extractor: Callable[[ClientOutput], Any]
    ) -> Sequence[Any]:
  """Applies `extractor` to every client output, preserving order.

  Args:
    couts: client outputs, one entry per client.
    extractor: maps a single client output to the message of interest. It is
      a static (non-traced) jit argument, so it has to be hashable and every
      new function object leads to a new trace.

  Returns:
    A list holding `extractor(cout)` for each `cout` in `couts`.
  """
  # Executed only while jit traces the function, i.e. on (re)compilation.
  print("compiling: {}".format(inspect.currentframe().f_code.co_name))
  # TODO: use pytrees.transpose?
  return [extractor(cout) for cout in couts]

# Construct data for multiple clients.
def make_federated_data(
    client_data: tff.simulation.ClientData,
    client_ids: Sequence[str],
    batch_size: Optional[int] = 1024,
    train: Optional[bool] = True,
    num_epochs: Optional[int] = 1,
    seed: Optional[int] = 0
    ) -> Sequence[ClientData]:
  """Builds one batched, numpy-yielding dataset per requested client.

  Args:
    client_data: federated dataset the client datasets are created from.
    client_ids: ids of the clients to build datasets for.
    batch_size: number of examples per minibatch.
    train: if True the data is repeated `num_epochs` times and shuffled with
      a buffer of `10 * batch_size`; otherwise it is only batched.
    num_epochs: number of repetitions of the client data (training only).
    seed: shuffle seed (training only).

  Returns:
    A list with one `tfds.as_numpy` dataset per entry of `client_ids`, in the
    same order.
  """

  def numpy_batches_for(client_id: str) -> ClientData:
    # Construct a tf.data.Dataset for client.
    dataset = client_data.create_tf_dataset_for_client(client_id)
    if train:
      dataset = dataset.repeat(num_epochs)
      dataset = dataset.shuffle(10*batch_size, seed)
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    return tfds.as_numpy(dataset)

  per_client = []
  for client_id in client_ids:
    per_client.append(numpy_batches_for(client_id))
  return per_client

In [0]:
# Neural network model.
def net_fn(batch: Batch) -> jnp.ndarray:
  """Standard LeNet-300-100 MLP network.

  Reads `batch["pixels"]`, rescales it by 1/255 as float32 and returns the
  log-softmax output of a 300-100-10 fully connected network.
  """
  pixels = batch["pixels"].astype(jnp.float32) / 255.
  layers = [
      hk.Flatten(),
      hk.Linear(300),
      jax.nn.relu,
      hk.Linear(100),
      jax.nn.relu,
      hk.Linear(10),
      jax.nn.log_softmax,
  ]
  return hk.Sequential(layers)(pixels)
net: hk.Transformed = hk.transform(net_fn)
  
# Initialize neural network parameters
def init(rng: jax.random.PRNGKey, batch: Batch) -> hk.Params:
  """Returns freshly initialized parameters of `net` for the example `batch`."""
  initial_params = net.init(rng, batch)
  return initial_params

# get predictions from model.
def forward(params: hk.Params, batch: Batch):
  """Runs `net` on `batch` with `params` through a jit-wrapped `net.apply`."""
  jitted_apply = jax.jit(net.apply)
  return jitted_apply(params, batch)


# Training loss (cross-entropy).
@jax.jit
def loss(params: hk.Params, batch: Batch) -> jnp.ndarray:
  """Compute the loss of the network, including L2.

  Args:
    params: network parameters.
    batch: minibatch; `batch["label"]` holds the integer class labels.

  Returns:
    Scalar: mean per-example cross-entropy plus `1e-4` times half the squared
    L2 norm of all parameters.
  """
  # Executed only while jit traces the function, i.e. on (re)compilation.
  print("compiling: {}".format(inspect.currentframe().f_code.co_name))
  log_probs = forward(params, batch)  # the network ends in log_softmax.
  labels = hk.one_hot(batch["label"], 10)
  # TODO: Put weight decay into optimizer
  l2_loss = 0.5 * sum(
      jnp.sum(jnp.square(p)) for p in jax.tree_util.tree_leaves(params))
  # Sum over the class axis, then average over the batch. Taking the mean over
  # every entry would also divide by the number of classes and silently shrink
  # the cross-entropy relative to the L2 term.
  softmax_xent = -jnp.mean(jnp.sum(labels * log_probs, axis=-1))
  return softmax_xent + 1e-4 * l2_loss

# Evaluation metric (classification accuracy).
@jax.jit
def accuracy(params: hk.Params, batch: Batch) -> jnp.ndarray:
  """Returns the fraction of examples whose arg-max prediction equals the label."""
  print("compiling: {}".format(inspect.currentframe().f_code.co_name))
  scores = forward(params, batch)
  is_correct = jnp.argmax(scores, axis=-1) == batch["label"]
  return jnp.mean(is_correct)

In [0]:
# one local update step.
# NOTE: `functools.partial` replaces `jax.partial`, which was only an
# accidental re-export of the same function and was later removed from JAX.
@functools.partial(jax.jit, static_argnums=(2, 4))  # fix loss function and optimizer.
def run_one_step(
    params: hk.Params,
    batch: Batch,
    client_opt: optix.InitUpdate,
    opt_state: OptState,
    loss: LossFunction
    ) -> Tuple[hk.Params, OptState]:
  """Learning rule (stochastic gradient descent).

  Args:
    params: current parameters.
    batch: minibatch the gradient is computed on.
    client_opt: optimizer; static jit argument, so it must be hashable and a
      new optimizer object triggers a new trace.
    opt_state: current optimizer state.
    loss: loss function, differentiated with respect to its first argument;
      static jit argument.

  Returns:
    The updated parameters and the updated optimizer state.
  """
  # Executed only while jit traces the function, i.e. on (re)compilation.
  print("compiling: {}".format(inspect.currentframe().f_code.co_name))
  grads = jax.grad(loss)(params, batch)
  updates, opt_state = client_opt.update(grads, opt_state)
  new_params = optix.apply_updates(params, updates)
  return new_params, opt_state

# perform client updates.
def client_updater(
    msg: ClientMessage,
    ds: ClientData,
    client_opt: optix.InitUpdate,
    loss: LossFunction
    ) -> ClientOutput:
  """Trains locally on one client's data and reports the parameter change.

  Args:
    msg: message from the server; `msg.params` are the starting parameters
      and `msg.opt_init_input` is unpacked into `client_opt.init`.
    ds: iterable of minibatches of this client.
    client_opt: optimizer used for the local steps.
    loss: loss function minimized locally.

  Returns:
    A `(ServerMessage, DiagnosticsMessage)` pair. The server message carries
    `new_params - msg.params`; the diagnostics are placeholder constants.
  """
  local_params = msg.params
  local_opt_state = client_opt.init(*msg.opt_init_input)
  # iterate through data making updates.
  for minibatch in ds:
    local_params, local_opt_state = run_one_step(
        local_params, minibatch, client_opt, local_opt_state, loss)

  # compute and return the change in parameters.
  delta = tree_multimap(lambda new, old: new - old, local_params, msg.params)

  # TODO: replace with a function which constructs message
  return (
      ServerMessage(aggregator_input=delta, stateupdater_input=None),
      DiagnosticsMessage(
          train_loss=0, train_acc=1, test_loss=0, test_acc=1, weight=1),
  )

In [0]:
# Load the federated EMNIST splits and build small training datasets for the
# first three clients; later cells use them as example batches.
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()
example_train = make_federated_data(
    client_data=emnist_train,
    client_ids=emnist_train.client_ids[:3],
    batch_size=32,
    num_epochs=1,
    seed=0,
)

# # Testing our client updater.
# params = init(jax.random.PRNGKey(42), next(example_train[0]))
# cmsg = ClientMessage(params, [params])
# new_params = client_updater(
#                 cmsg,
#                 example_train[1],
#                 optix.sgd(0.1),
#                 loss
#                 );

  collections.OrderedDict((name, ds.value) for name, ds in sorted(


In [0]:
# Sanity check: do the train and test splits expose the same client ids?
emnist_train.client_ids == emnist_test.client_ids

True

In [0]:
# aggregate client updates.
# TODO: make the aggregator stateful.
@jax.jit
def average_params(params_list: Sequence[hk.Params]) -> hk.Params:
  """Returns the element-wise mean of a list of parameter trees.

  Args:
    params_list: non-empty list of parameter trees with identical structure.

  Returns:
    A tree of the same structure whose leaves are the mean of the matching
    leaves of all trees in `params_list`.

  Raises:
    ValueError: if `params_list` is empty.
  """
  # Executed only while jit traces the function, i.e. on (re)compilation.
  print("compiling: {}".format(inspect.currentframe().f_code.co_name))
  num_params = len(params_list)
  if num_params == 0:
    raise ValueError("average_params needs at least one parameter tree.")
  # Add matching leaves element-wise. The builtin `sum` must NOT be used as the
  # leaf function: `sum(x, y)` iterates over the leading axis of `x` with `y`
  # as start value, i.e. it computes `y + x.sum(axis=0)` instead of `x + y`.
  params_sum = functools.reduce(
      lambda t1, t2: tree_multimap(lambda x, y: x + y, t1, t2), params_list)
  params_avg = tree_map(lambda x: x / num_params, params_sum)
  return params_avg

In [0]:
# aggregate diagnostics.
@jax.jit
def agg_diagnostics(
    client_outputs: Sequence[ClientOutput]) -> DiagnosticsMessage:
  """Pulls the diagnostics out of the client outputs.

  No aggregation happens yet: the diagnostics of the first client are
  returned unchanged.
  """
  def diagnostics_of(cout):
    # A client output is a (server message, diagnostics message) pair.
    return cout[1]

  all_diagnostics = extract_from_cout(client_outputs, diagnostics_of)
  # TODO: aggregate and report statistics
  first_diagnostics = all_diagnostics[0]
  return first_diagnostics


# update server params.
# NOTE: `functools.partial` replaces `jax.partial`, which was only an
# accidental re-export of the same function and was later removed from JAX.
@functools.partial(jax.jit, static_argnums=(2, 4))  # fix optimizer and aggregator.
def server_updater(
    server_params: hk.Params,
    client_outputs: Sequence[ClientOutput],
    server_opt: optix.InitUpdate,
    opt_state: OptState,
    aggregator: Callable[[Sequence[Any]], hk.Params]
    ) -> Tuple[hk.Params, OptState]:
  """Applies the aggregated client updates to the server parameters.

  Args:
    server_params: current server parameters.
    client_outputs: outputs of all clients that took part in this round.
    server_opt: server optimizer; static jit argument (must be hashable).
    opt_state: current state of `server_opt`.
    aggregator: combines the clients' `aggregator_input` values into a single
      update; static jit argument (must be hashable).

  Returns:
    The updated server parameters and the updated optimizer state.
  """
  # Executed only while jit traces the function, i.e. on (re)compilation.
  print("compiling: {}".format(inspect.currentframe().f_code.co_name))
  agg_inputs = extract_from_cout(client_outputs,
                                 lambda cout: cout[0].aggregator_input)
  agg_update = aggregator(agg_inputs)
  # The aggregated update is negated so that it can be handed to the server
  # optimizer in place of a gradient.
  eff_grads = tree_map(lambda x: -1.0*x, agg_update)  # effective gradient.
  updates, opt_state = server_opt.update(eff_grads, opt_state)
  # TODO: allow opt_state to also be explicitly updated by state_updater
  server_params = optix.apply_updates(server_params, updates)
  return server_params, opt_state

# one round of federated learning.
def run_one_round(
    server_params: hk.Params,
    hyperparams: ServerHyperParams,
    client_data: tff.simulation.ClientData,
    client_opt: optix.InitUpdate,
    server_opt: optix.InitUpdate,
    opt_state: OptState,
    aggregator: Callable[[Sequence[Any]], hk.Params],
    loss: LossFunction,
    rng: jax.random.PRNGKey
) -> Tuple[hk.Params, OptState]:
  """Samples clients, trains them locally and updates the server once.

  Args:
    server_params: server parameters at the start of the round.
    hyperparams: round settings (sampled clients, batch size, epochs, seed).
    client_data: federated dataset the clients are drawn from.
    client_opt: optimizer used on the clients.
    server_opt: optimizer used on the server.
    opt_state: current state of `server_opt`.
    aggregator: combines the client updates into one update.
    loss: loss function minimized by the clients.
    rng: PRNG key used to pick the participating clients.

  Returns:
    The server parameters and server optimizer state after the round.
  """
  # choose `sampled_clients` ids at random out of all available clients.
  all_client_ids = client_data.client_ids
  shuffled_positions = jax.random.shuffle(
      rng, jnp.arange(len(all_client_ids)))
  chosen_positions = shuffled_positions[:hyperparams.sampled_clients]
  active_client_ids = [all_client_ids[i] for i in chosen_positions]

  print("making datasets for sampled clients.")
  sampled_dss = make_federated_data(
      client_data,
      active_client_ids,
      batch_size=hyperparams.batch_size,
      train=True,
      num_epochs=hyperparams.num_epochs,
      seed=hyperparams.seed)

  print("computing updates from active clients.")
  # TODO: replace with a function which constructs message
  msg_to_clients = ClientMessage(
      params=server_params,
      opt_init_input=[server_params])
  client_outputs = [
      client_updater(msg_to_clients, ds, client_opt, loss)
      for ds in sampled_dss
  ]

  print("aggregating client updates.")
  new_server_params, new_opt_state = server_updater(
      server_params, client_outputs, server_opt, opt_state, aggregator)

  # TODO: aggregate and incorporate new diagnostics. needs state!
  # The result is not used yet; the call is kept so diagnostics are computed.
  agg_diagnostics(client_outputs)

  return new_server_params, new_opt_state

In [0]:
# Testing one round of federated averaging.
# Parameters are initialized from one example batch of the first client; the
# server uses SGD with step size 1.0 and the clients SGD with step size 0.1.
params = init(jax.random.PRNGKey(42), next(example_train[0]))
client_opt = optix.sgd(0.1)
server_opt = optix.sgd(1.0)
opt_state = server_opt.init(params)
rng = jax.random.PRNGKey(0)
hyperparams = ServerHyperParams(
    sampled_clients=3,
    batch_size=32,
    num_epochs=5,
    num_rounds=3,
    seed=7,
)

# The trailing `;` keeps the returned (params, opt_state) out of the cell output.
run_one_round(
    server_params=params,
    hyperparams=hyperparams,
    client_data=emnist_train,
    client_opt=client_opt,
    server_opt=server_opt,
    opt_state=opt_state,
    aggregator=average_params,
    loss=loss,
    rng=rng,
);



making datasets for sampled clients.
computing updates from active clients.
compiling: run_one_step
compiling: loss


  collections.OrderedDict((name, ds.value) for name, ds in sorted(


compiling: run_one_step
compiling: loss
compiling: run_one_step
compiling: loss
compiling: run_one_step
compiling: loss
aggregating client updates.
compiling: server_updater
compiling: extract_from_cout
compiling: average_params
compiling: extract_from_cout


In [0]:

def federated_learning(
    hyperparams: ServerHyperParams,
    client_data: tff.simulation.ClientData,
    client_opt: optix.InitUpdate,
    server_opt: optix.InitUpdate,
    aggregator: Callable[[Sequence[hk.Params]], hk.Params],
    loss: LossFunction,
    init
    ) -> hk.Params:
  """Runs `hyperparams.num_rounds` rounds of federated learning.

  Args:
    hyperparams: training settings; `seed` seeds the PRNG key and
      `num_rounds` sets the number of rounds.
    client_data: federated dataset to train on.
    client_opt: optimizer used on the clients.
    server_opt: optimizer used on the server.
    aggregator: combines the client updates into one update.
    loss: loss function minimized by the clients.
    init: callable `(rng, batch) -> params` creating the initial parameters.

  Returns:
    The server parameters after the last round.

  Raises:
    ValueError: if `client_data` has no clients.
  """
  if len(client_data.client_ids) == 0:
    raise ValueError("client_data has no clients to train on.")

  # initialize random generator, params, opt_state.
  # One key is reserved for initialization and one per round, so that no key
  # is used twice.
  rng = jax.random.PRNGKey(hyperparams.seed)
  keys = jax.random.split(rng, hyperparams.num_rounds + 1)
  init_rng, round_rngs = keys[0], keys[1:]

  # The example batch for `init` is taken from `client_data` itself rather
  # than from a notebook-level generator, which would be consumed a bit more
  # on every call and would tie this function to hidden kernel state.
  # TODO: use synthetic data for init
  init_ds = make_federated_data(
      client_data, client_data.client_ids[:1], hyperparams.batch_size,
      train=False)[0]
  server_params = init(init_rng, next(iter(init_ds)))
  opt_state = server_opt.init(server_params)

  for round_num, round_rng in enumerate(round_rngs):
    print("\nrunning round {}".format(round_num))
    server_params, opt_state = run_one_round(server_params, hyperparams,
                                             client_data, client_opt,
                                             server_opt, opt_state,
                                             aggregator, loss, round_rng)
  return server_params
  



In [0]:
# testing federated learning code.
# Three rounds with three sampled clients each; every client runs 100 local
# epochs with batches of 32 examples.
hyperparams = ServerHyperParams(
    sampled_clients=3,
    batch_size=32,
    num_epochs=100,
    num_rounds=3,
    seed=22,
)

# The trailing `;` keeps the returned parameters out of the cell output.
federated_learning(
    hyperparams=hyperparams,
    client_data=emnist_train,
    client_opt=optix.sgd(0.1),
    server_opt=optix.sgd(1.0),
    aggregator=average_params,
    loss=loss,
    init=init,
);


running round 0
making datasets for sampled clients.
computing updates from active clients.
compiling: run_one_step


  collections.OrderedDict((name, ds.value) for name, ds in sorted(


compiling: run_one_step
compiling: loss
aggregating client updates.
compiling: server_updater
compiling: extract_from_cout

running round 1
making datasets for sampled clients.
computing updates from active clients.
compiling: run_one_step
compiling: loss
compiling: run_one_step
compiling: loss
compiling: run_one_step
aggregating client updates.

running round 2
making datasets for sampled clients.
computing updates from active clients.
aggregating client updates.


# Questions

1. Does `net.init` also initialize the output layer based on the batch size?