In [2]:
import tensorflow as tf
import tensorflow_federated as tff
import nest_asyncio

In [3]:
import os  # needed here: on a fresh kernel `os` is otherwise only imported in a later cell

# Let TFF's asyncio-based executor run inside Jupyter's already-running event loop.
nest_asyncio.apply()
# Hide all GPUs so TensorFlow runs on CPU only. This only takes effect if TF
# has not yet initialized its devices, so keep it before any TF op executes.
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

In [4]:
def get_perceptron():
    """Build a one-unit linear Keras model (no activation) over 2-D inputs.

    Its single output is used as a logit by the losses/metrics in this notebook.
    """
    layers = [tf.keras.layers.Dense(1, input_shape=(2,))]
    return tf.keras.Sequential(layers)


def tff_perceptron_model_fn():
    """Wrap a freshly built perceptron as a ``tff.learning`` model.

    A brand-new Keras model is constructed on every call instead of being
    captured from an enclosing scope, because TFF calls this function inside
    several different graph contexts.
    """
    keras_model = get_perceptron()
    features_spec = tf.TensorSpec(shape=[None, 2], dtype=tf.float64)
    labels_spec = tf.TensorSpec(shape=[None,], dtype=tf.float64)
    return tff.learning.from_keras_model(
        keras_model,
        input_spec=(features_spec, labels_spec),
        loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
        # The model emits logits, so a 0.0 threshold separates the two classes.
        metrics=[tf.keras.metrics.BinaryAccuracy(threshold=0.0)],
    )

In [5]:
# Single indirection point: the TFF computations below build their models via
# model_fn, so switching architectures only requires rebinding this name.
model_fn = tff_perceptron_model_fn

In [6]:
# SERVER_STATE = {model weights}
@tff.tf_computation
def server_init():
    """Initial server state: the trainable variables of a freshly built model."""
    fresh_model = model_fn()
    initial_weights = fresh_model.trainable_variables
    return initial_weights


# {model weights}@SERVER
@tff.federated_computation
def initialize_fn():
    """Place the initial model weights at the SERVER."""
    weights = server_init()
    return tff.federated_value(weights, tff.SERVER)

In [7]:
# Defining type signatures - 1
# The model-weights type is whatever server_init produces.
model_weights_type = server_init.type_signature.result
# Throwaway model instance, used to read the expected per-example input spec.
dummy_model = model_fn()
tf_dataset_type = tff.SequenceType(dummy_model.input_spec)

# Placed (federated) versions: server state lives at SERVER, data at CLIENTS.
federated_server_state_type = tff.FederatedType(
    model_weights_type, tff.SERVER)
federated_dataset_type = tff.FederatedType(
    tf_dataset_type, tff.CLIENTS)

In [8]:
@tf.function
def client_update(model, dataset, server_weights, lr, accumulator):
    """Run local SGD on one client's dataset, starting from the server weights.

    Args:
        model: a ``tff.learning`` model; this function uses its
            ``trainable_variables``, ``forward_pass`` and ``report_local_outputs``.
        dataset: the client's (batched) dataset, iterated once.
        server_weights: structure matching ``model.trainable_variables``; copied
            into the model before training starts.
        lr: scalar learning rate.
        accumulator: variables with the same structure as the weights. After each
            step it holds that step's update ``-lr * grad``. Its previous value is
            never read, so the optimizer is plain SGD (no momentum).

    Returns:
        ``(client_weights, out_data, n)``:
        - ``client_weights``: the model's trainable variables after training.
        - ``out_data``: ``model.report_local_outputs()``.
        - ``n``: ``out_data['loss'][1]``.
    """
    client_weights = model.trainable_variables
    # Initialize the client model with the current server weights.
    tf.nest.map_structure(lambda var, value: var.assign(value),
                          client_weights, server_weights)

    # Update the local model using SGD.
    for batch in dataset:
        # Only one gradient is taken per tape, so a non-persistent tape suffices.
        # persistent=True would keep the tape's resources alive needlessly.
        with tf.GradientTape() as tape:
            outputs = model.forward_pass(batch)
        grads = tape.gradient(outputs.loss, client_weights)

        # Plain SGD step: delta = -lr * grad. Mapping over the accumulator as well
        # keeps the structural check that it matches the gradients.
        deltas = tf.nest.map_structure(lambda _acc, g: -lr * g, accumulator, grads)
        updated_weights = tf.nest.map_structure(lambda w, d: w + d, client_weights, deltas)

        tf.nest.map_structure(lambda var, value: var.assign(value), client_weights, updated_weights)
        tf.nest.map_structure(lambda var, value: var.assign(value), accumulator, deltas)

    # dictionary containing {metric: [sum, count], ..}
    # (NOTE(review): layout depends on the TFF version's report_local_outputs; confirm.)
    out_data = model.report_local_outputs()

    # Assumes the second element of the 'loss' entry is the example count.
    return client_weights, out_data, out_data['loss'][1]

In [9]:
@tff.tf_computation(tf_dataset_type, model_weights_type, tf.float32)
def client_update_fn(tf_dataset, server_weights_at_client, learning_rate):
    """TFF wrapper that runs ``client_update`` for a single client.

    Builds a fresh model and a zero-initialized accumulator shaped like the
    incoming weights. Returns ``(client_weights, out_data, n)`` exactly as
    ``client_update`` produces it.
    """
    local_model = model_fn()

    def _zeros_variable_like(weight):
        # New variable with the same shape and dtype as `weight`, filled with zeros.
        return tf.Variable(tf.zeros(weight.shape, weight.dtype))

    # Optimizer scratch state (intended for SGD with momentum; client_update
    # currently only records the last step in it).
    accumulator = tf.nest.map_structure(_zeros_variable_like, server_weights_at_client)

    trained_weights, metrics, num_examples = client_update(
        local_model, tf_dataset, server_weights_at_client, learning_rate, accumulator)
    return trained_weights, metrics, num_examples

In [10]:
@tf.function
def server_update(model, mean_client_weights):
    """Overwrite the server model's trainable weights with the averaged client weights.

    Returns the model's trainable variables after the assignment.
    """
    server_vars = model.trainable_variables
    # Copy each averaged tensor into the matching server variable.
    tf.nest.map_structure(lambda var, avg: var.assign(avg),
                          server_vars, mean_client_weights)
    return server_vars

@tff.tf_computation(model_weights_type)
def server_update_fn(mean_client_weights):
    """Build a fresh server model, load ``mean_client_weights`` into it, and return its weights."""
    server_model = model_fn()
    new_weights = server_update(server_model, mean_client_weights)
    return new_weights

In [11]:
# Defining type signatures - 2
# Per-client float32 scalars that may differ between clients (all_equal=False):
# each client's learning rate, and each client's aggregation weight p_i.
client_learning_rates_type = tff.FederatedType(
    tf.float32, tff.CLIENTS, all_equal=False)
client_agg_weights_type = tff.FederatedType(
    tf.float32, tff.CLIENTS, all_equal=False)

In [12]:
# Helper functions for data extraction and processing
@tff.tf_computation(client_update_fn.type_signature.result)
def extract_weights(tp_wts_mts):
    """Return ``(client_weights, n)`` from a ``client_update_fn`` result triple."""
    weights = tp_wts_mts[0]
    num_examples = tp_wts_mts[2]
    return weights, num_examples

@tff.tf_computation(client_update_fn.type_signature.result)
def extract_only_weights(tp_wts_mts):
    """Return just the client weights (element 0) of a ``client_update_fn`` result."""
    weights = tp_wts_mts[0]
    return weights

@tff.tf_computation(client_update_fn.type_signature.result)
def extract_training_metrics(tp_wts_mts):
    """Return just the metrics dictionary (element 1) of a ``client_update_fn`` result."""
    metrics = tp_wts_mts[1]
    return metrics

# Receives a dictionary {metric: [sum_all_samples, total_samples]}
# Return dictionary of means for every metric
@tff.tf_computation(client_update_fn.type_signature.result[1])
def get_mean(metric_dict):
    """Divide each metric's summed value by its summed sample count."""
    return {name: totals[0] / totals[1] for name, totals in metric_dict.items()}

In [13]:
@tff.federated_computation(
    federated_server_state_type,
    federated_dataset_type,
    client_learning_rates_type,
    client_agg_weights_type)  # client_agg_weights: list of p_i = n_i/n
def next_fn(
        server_state,
        federated_dataset,
        client_learning_rates,
        client_agg_weights):
    """One round of weighted federated averaging.

    1. Broadcast the server weights to the clients.
    2. Run ``client_update_fn`` on each client with its own learning rate.
    3. Average the resulting client weights, weighted by ``client_agg_weights``.
    4. Average the training metrics as sum(metric sums) / sum(counts) over clients.

    Returns:
        ``(new_server_state, mean_metrics)``.
    """
    # Server -> clients: the current global model.
    broadcast_weights = tff.federated_broadcast(server_state)

    # Local training. Learning rates are supplied by the orchestrator (us)
    # rather than by the server, for simulation purposes.
    client_outputs = tff.federated_map(
        client_update_fn,
        (federated_dataset, broadcast_weights, client_learning_rates))

    per_client_weights = tff.federated_map(extract_only_weights, client_outputs)
    per_client_metrics = tff.federated_map(extract_training_metrics, client_outputs)

    # Weighted averaging of client models: sum_i p_i * w_i (normalized by sum_i p_i).
    averaged_weights = tff.federated_mean(per_client_weights, client_agg_weights)

    # Pool the [sum, count] pairs across clients, then divide once at the server.
    pooled_metrics = tff.federated_sum(per_client_metrics)
    mean_metrics = tff.federated_map(get_mean, pooled_metrics)

    # The server adopts the averaged weights as its new state.
    new_server_state = tff.federated_map(server_update_fn, averaged_weights)
    return new_server_state, mean_metrics

In [14]:
import json
import os
import numpy as np
from math import floor, ceil

In [15]:

class SyntheticData:
    """Loads a synthetic federated dataset from ``train.json`` and ``test.json``.

    ``train.json`` must provide ``users`` (client ids), ``num_samples`` (per-client
    counts, indexed in parallel with ``users``) and ``user_data``
    (``{client_id: {'x': [...], 'y': [...]}}``). Only ``user_data`` is read from
    ``test.json``.
    """

    def __init__(self, train_dir, test_dir):
        train_d = self._load_json(os.path.join(train_dir, 'train.json'))
        self.client_ids = train_d['users']
        self.num_samples = train_d['num_samples']
        self.train_data = train_d['user_data']

        test_d = self._load_json(os.path.join(test_dir, 'test.json'))
        self.test_data = test_d['user_data']

    @staticmethod
    def _load_json(path):
        """Read and parse one JSON file. JSON is UTF-8 by spec, so don't rely on the locale."""
        with open(path, encoding='utf-8') as f:
            return json.load(f)

    @staticmethod
    def _stack_user_data(user_data):
        """Concatenate every client's ``x`` and ``y`` lists into two numpy arrays (client order kept)."""
        xs = []
        ys = []
        for data in user_data.values():
            xs.extend(data['x'])
            ys.extend(data['y'])
        return np.array(xs), np.array(ys)

    @staticmethod
    def _to_float64(features, labels):
        """Cast a (features, labels) pair to float64 to match the model's input_spec."""
        return tf.cast(features, tf.float64), tf.cast(labels, tf.float64)

    def get_client_ids(self):
        """Return the client ids listed in ``train.json``."""
        return self.client_ids

    def create_dataset_for_client(self, client_id):
        """Unbatched float64 ``(x, y)`` dataset of one client's training data."""
        client_data = self.train_data[client_id]
        return tf.data.Dataset.from_tensor_slices(
            (client_data['x'], client_data['y'])).map(self._to_float64)

    def create_train_dataset_for_all_clients(self):
        """Unbatched float64 ``(x, y)`` dataset pooling every client's training data."""
        xs, ys = self._stack_user_data(self.train_data)
        # Cast like the other builders so this dataset also matches the float64
        # input_spec (previously it kept whatever dtype numpy inferred).
        return tf.data.Dataset.from_tensor_slices((xs, ys)).map(self._to_float64)

    def create_test_dataset_for_all_clients(self):
        """Unbatched float64 ``(x, y)`` dataset pooling every client's test data."""
        xs, ys = self._stack_user_data(self.test_data)
        return tf.data.Dataset.from_tensor_slices((xs, ys)).map(self._to_float64)

#------------------------------------------------------------------------------
def make_federated_data(dataset, preprocess_fn, client_ids, client_num_samples, client_capacities, batch_size, round_num):
    """Build one preprocessed dataset per selected client, in ``client_ids`` order.

    Entry ``idx`` of ``client_num_samples`` and ``client_capacities`` describes
    ``client_ids[idx]``. Each client's raw dataset is passed through
    ``preprocess_fn(raw, batch_size, capacity, num_samples, round_num)``.
    """
    federated = []
    for idx, cid in enumerate(client_ids):
        raw = dataset.create_dataset_for_client(cid)
        federated.append(preprocess_fn(
            raw, batch_size, client_capacities[idx], client_num_samples[idx], round_num))
    return federated

#------------------------------------------------------------------------------
def preprocess(dataset, b, u, n, r):
    """Turn one client's data into ``u`` shuffled batches of size ``b``.

    Args:
        dataset: the client's unbatched dataset.
        b: batch size.
        u: local step budget (number of batches to take).
        n: number of examples in ``dataset``; also used as the shuffle buffer size.
        r: shuffle seed (the caller passes the round number).

    If the client has at least ``u`` full batches (``n // b``), the first ``u`` are
    taken from a single shuffled pass. Otherwise the data is repeated
    ``ceil(b*u / n)`` times, which gives at least ``u`` full batches. A client with
    no examples (``n <= 0``) gets its (empty) dataset batched instead of a division
    by zero or a zero-size shuffle buffer.
    """
    if n <= 0:
        # Nothing to train on: a zero-size shuffle buffer is invalid and
        # ceil(b*u/n) would divide by zero.
        return dataset.batch(b).take(u)

    full_batches = n // b
    if u <= full_batches:
        return dataset.shuffle(n, seed=r).batch(b).take(u)

    # Integer ceiling division: exact, unlike ceil() on a float quotient.
    repeats = -(-(b * u) // n)
    return dataset.repeat(repeats).shuffle(n, seed=r).batch(b).take(u)

In [16]:
def get_client_agg_weights(budgets, client_num_samples, batch_size):
    """Per-client aggregation weights p_i, proportional to the examples credited to each client.

    Client i, with budget ``tau_i`` local steps and ``n_i`` examples, has
    ``n_i // batch_size`` full batches available without repeating data.
    - If ``tau_i`` fits within that, it is credited with ``tau_i * batch_size`` examples.
    - Otherwise it is credited with ``n_i`` (its dataset size).

    Args:
        budgets: per-client local step budgets.
        client_num_samples: per-client example counts, in the same order as ``budgets``.
        batch_size: local batch size.

    Returns:
        List of weights in client order that sums to 1 (``[]`` when there are no clients).

    Raises:
        ValueError: if the two lists differ in length, or if every client is
            credited with zero examples (the weights would be undefined).
    """
    if len(budgets) != len(client_num_samples):
        raise ValueError('budgets and client_num_samples must have the same length')

    seen_training_samples = []
    for tau_i, n_i in zip(budgets, client_num_samples):
        tau_possible_i = n_i // batch_size
        if tau_i <= tau_possible_i:
            seen_training_samples.append(tau_i * batch_size)
        else:
            seen_training_samples.append(n_i)

    total_seen_training_samples = sum(seen_training_samples)
    if seen_training_samples and total_seen_training_samples == 0:
        # With numpy integers this division would silently yield NaN weights.
        raise ValueError('cannot compute aggregation weights: no client contributes any training samples')

    return [x / total_seen_training_samples for x in seen_training_samples]

In [30]:
# Bundle the two federated computations into a TFF iterative process:
# initialize() -> server state; next(state, data, lrs, agg_weights) -> (state, metrics).
iterative_process = tff.templates.IterativeProcess(initialize_fn, next_fn)

In [18]:
# Root of the synthetic dataset. Set SYNTHETIC_DATA_DIR to point elsewhere
# instead of editing this machine-specific absolute path.
SYNTHETIC_DATA_DIR = os.environ.get(
    'SYNTHETIC_DATA_DIR', '/mnt/nfs/dhasade/optml/data/synthetic_data')
train_dir = os.path.join(SYNTHETIC_DATA_DIR, 'train')
test_dir = os.path.join(SYNTHETIC_DATA_DIR, 'test')

In [19]:
# Load the train/test JSON for every client once; later cells slice per client.
dataset = SyntheticData(train_dir=train_dir, test_dir=test_dir)

In [20]:
# Size of the client population: one num_samples entry per training client.
num_samples_per_client = dataset.num_samples
total_clients = len(num_samples_per_client)

In [21]:
# Each client's local-step budget is drawn uniformly from [lower_bound, upper_bound].
lower_bound = 2
upper_bound = 2
# Number of clients sampled per round.
num_clients = 5


def lr_schedule(round_num):
    """Client learning rate for a round: 0.5 / (1 + round_num)."""
    return 0.5 / (1 + round_num)

In [22]:
# Seeded generator: client sampling and budgets are reproducible across runs.
rng = np.random.default_rng(seed=1)

In [33]:
# Initial server state (the freshly initialized model weights).
state = iterative_process.initialize()

In [34]:
# Local batch size, shared by the aggregation weights and the preprocessing so
# both agree on how many examples each client trains on.
batch_size = 5
num_rounds = 100

for round_num in range(num_rounds):
    # Sample client indices uniformly, with replacement (a client may appear
    # more than once in a round).
    client_indexes = rng.integers(
        low=0, high=total_clients, size=num_clients)
    client_ids = [dataset.client_ids[i] for i in client_indexes]
    client_num_samples = [dataset.num_samples[i] for i in client_indexes]

    # Per-client local-step budgets in [lower_bound, upper_bound] (inclusive).
    budgets = rng.integers(
        low=lower_bound, high=upper_bound + 1, size=num_clients)

    lr_to_clients = [lr_schedule(round_num)] * num_clients
    client_agg_weights = get_client_agg_weights(
        budgets, client_num_samples, batch_size)

    federated_train_data = make_federated_data(
        dataset, preprocess, client_ids, client_num_samples, budgets, batch_size, round_num)

    state, metrics = iterative_process.next(
        state,
        federated_train_data,
        lr_to_clients,
        client_agg_weights)

    for name, value in metrics.items():
        # No-op unless a default tf.summary writer has been set.
        tf.summary.scalar('train_' + name, value, step=round_num)
        if 'loss' in name:
            print('[Train loss', value, ']')
        else:
            print('[Train accuracy', value, ']')


[Train accuracy 0.72 ]
[Train loss 0.48016942 ]
[Train accuracy 0.72 ]
[Train loss 0.4644083 ]
[Train accuracy 0.88 ]
[Train loss 0.32609066 ]
[Train accuracy 0.82 ]
[Train loss 0.33324128 ]
[Train accuracy 0.86 ]
[Train loss 0.38307858 ]
[Train accuracy 0.78000003 ]
[Train loss 0.41607076 ]
[Train accuracy 0.78000003 ]
[Train loss 0.45950142 ]
[Train accuracy 0.78000003 ]
[Train loss 0.42677802 ]
[Train accuracy 0.82 ]
[Train loss 0.3205042 ]
[Train accuracy 0.9 ]
[Train loss 0.38034537 ]
[Train accuracy 0.74000007 ]
[Train loss 0.45960006 ]
[Train accuracy 0.82 ]
[Train loss 0.46016908 ]
[Train accuracy 0.94000006 ]
[Train loss 0.31331843 ]
[Train accuracy 0.73999995 ]
[Train loss 0.5235738 ]
[Train accuracy 0.85999995 ]
[Train loss 0.4059956 ]
[Train accuracy 0.84 ]
[Train loss 0.38186267 ]
[Train accuracy 0.82000005 ]
[Train loss 0.4100829 ]
[Train accuracy 0.91999996 ]
[Train loss 0.358624 ]
[Train accuracy 0.8 ]
[Train loss 0.45920867 ]
[Train accuracy 0.93999994 ]
[Train loss 0.

In [35]:
def evaluate_synthetic(server_model_weights, central_test_dataset):
    """Evaluate the given perceptron weights on a centralized test dataset.

    The model is compiled with the same loss and accuracy metric as the
    federated model. Returns the result of ``keras_model.evaluate``, i.e. the
    loss followed by the binary accuracy.
    """
    keras_model = get_perceptron()
    loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    accuracy = tf.keras.metrics.BinaryAccuracy(threshold=0.0)
    keras_model.compile(loss=loss_fn, metrics=[accuracy])
    keras_model.set_weights(server_model_weights)
    return keras_model.evaluate(central_test_dataset)

In [40]:
# Pool every client's test split into one dataset, evaluated in batches of 100.
central_test_dataset = (
    dataset
    .create_test_dataset_for_all_clients()
    .batch(100)
)

In [41]:
# Centralized evaluation of the final server weights: [loss, accuracy].
evaluate_synthetic(server_model_weights=state, central_test_dataset=central_test_dataset)



[0.3944089114665985, 0.8475000262260437]