In [None]:
!pip install --quiet --upgrade tensorflow_federated
!pip install nest_asyncio

import nest_asyncio
nest_asyncio.apply()


In [None]:
import collections

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

# Seed NumPy's global random generator with a fixed value.
np.random.seed(0)


# Smoke test: define and immediately invoke a trivial federated computation.
# The call is the last expression, so its result is displayed as the cell output.
@tff.federated_computation
def hello_world():
  return 'Hello, World!'

hello_world()

In [3]:
# Load the federated EMNIST dataset as a (train, test) pair of per-client datasets.
# NOTE(review): called with default arguments; presumably this is the
# digits-only variant, which would match NUM_CLASS = 10 below. Confirm.
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()


In [None]:
# Dataset holding the examples of the first training client.
example_dataset = emnist_train.create_tf_dataset_for_client(emnist_train.client_ids[0])

it = iter(example_dataset)

# Advance the iterator five times; `example_element` ends up as the fifth example.
for x in range(5):
    example_element = next(it)

# Show that example's label as a NumPy value (this is the cell's output).
example_element['label'].numpy()

#modeling

In [5]:
# ---- Experiment configuration ----
NUM_CLIENTS = 5            # how many clients are sampled for training
NUM_EPOCHS = 25            # how many times each client's data is repeated
BATCH_SIZE = 50            # examples per batch
SHUFFLE_BUFFER = 100       # buffer size used when shuffling examples
PREFETCH_BUFFER = 10       # buffer size intended for dataset prefetching
NUM_CLASS = 10             # width of the model's softmax output layer
input_shape = (28, 28, 1)  # model input: height, width, channels


def preprocess(dataset):
  """Turn a raw client dataset into shuffled, batched `(x, y)` training batches.

  The pipeline repeats the data `NUM_EPOCHS` times, shuffles it with a buffer
  of `SHUFFLE_BUFFER` examples, groups it into batches of `BATCH_SIZE`,
  reformats every batch, and prefetches up to `PREFETCH_BUFFER` batches.

  Args:
    dataset: A `tf.data.Dataset` whose elements are dict-like and carry
      `'pixels'` and `'label'` entries.

  Returns:
    A `tf.data.Dataset` whose elements are `OrderedDict(x=..., y=...)` batches.
  """

  def batch_format_fn(element):
    """Add a channel axis to a batch of images and rename the fields to x/y."""
    # A new axis is inserted at index 3 (after batch, height, width) so the
    # images match the model's (28, 28, 1) `input_shape`. The result is built
    # as a fresh value instead of being written back into `element`, so the
    # mapped function never mutates its input.
    return collections.OrderedDict(
        x=tf.expand_dims(element['pixels'], 3),
        y=element['label'])

  return (dataset.repeat(NUM_EPOCHS)
          .shuffle(SHUFFLE_BUFFER)
          .batch(BATCH_SIZE)
          .map(batch_format_fn)
          # Overlap input preparation with training; element structure is unchanged.
          .prefetch(PREFETCH_BUFFER))

In [None]:
# Run the example client's data through the training input pipeline.
preprocessed_example_dataset = preprocess(example_dataset)

# Take the first batch and convert every tensor in it to a NumPy array.
sample_batch = tf.nest.map_structure(
    lambda tensor: tensor.numpy(),
    next(iter(preprocessed_example_dataset)))

# Shape of the image batch produced by the pipeline.
print(sample_batch['x'].shape)

def make_federated_data(client_data, client_ids):
  """Build one preprocessed dataset per requested client.

  Args:
    client_data: Object exposing `create_tf_dataset_for_client(client_id)`.
    client_ids: Iterable of client ids to include.

  Returns:
    A list with `preprocess(...)` applied to each client's dataset, in the
    same order as `client_ids`.
  """
  federated_data = []
  for client_id in client_ids:
    client_dataset = client_data.create_tf_dataset_for_client(client_id)
    federated_data.append(preprocess(client_dataset))
  return federated_data

# Use the first NUM_CLIENTS client ids as the fixed training cohort.
sample_clients = emnist_train.client_ids[:NUM_CLIENTS]

# One preprocessed dataset per selected client.
federated_train_data = make_federated_data(
    emnist_train,
    sample_clients)


In [7]:
def create_keras_model():
    """Build an uncompiled CNN classifier for images of shape `input_shape`.

    Architecture: two 3x3 ReLU convolutions (32 then 64 filters), 2x2 max
    pooling, dropout 0.25, flatten, a 128-unit ReLU dense layer, dropout 0.5,
    and a softmax output layer with `NUM_CLASS` units.

    Returns:
        A new `tf.keras.models.Sequential` model.
    """
    layers = tf.keras.layers

    convolutional_stack = [
        layers.Input(shape=input_shape),
        layers.Conv2D(32, kernel_size=(3, 3), activation='relu'),
        layers.Conv2D(64, kernel_size=(3, 3), activation='relu'),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Dropout(0.25),
    ]
    classifier_head = [
        layers.Flatten(),
        layers.Dense(128, activation='relu'),
        layers.Dropout(0.5),
        layers.Dense(NUM_CLASS, activation='softmax'),
    ]
    return tf.keras.models.Sequential(convolutional_stack + classifier_head)


In [8]:
def model_fn():
  """Wrap a freshly built Keras model as a TFF learning model.

  The Keras model is constructed inside this function, so every call creates
  a new model. The input spec is taken from the preprocessed example dataset.

  Returns:
    The result of `tff.learning.from_keras_model` for the new model, using
    sparse categorical cross-entropy loss and sparse categorical accuracy.
  """
  model = create_keras_model()
  batch_spec = preprocessed_example_dataset.element_spec
  loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
  accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
  return tff.learning.from_keras_model(
      model,
      input_spec=batch_spec,
      loss=loss_fn,
      metrics=[accuracy])

In [9]:
# Build the federated averaging process from `model_fn`. Each optimizer is
# passed as a zero-argument factory: plain SGD with learning rate 0.02 for the
# clients and plain SGD with learning rate 1.0 for the server.
iterative_process = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))

# Initial state; it is fed to `iterative_process.next` in the training loop.
state = iterative_process.initialize()

In [11]:
import attr
@attr.s(eq=False, frozen=True)
class ServerState:
  """Frozen container for the state kept on the server.

  Attributes:
    model: The model weights (built via `ModelWeights.from_tff_result` when
      created through `from_tff_result`).
    optimizer_state: List of the optimizer's variables.
  """
  model = attr.ib()
  optimizer_state = attr.ib()

  @classmethod
  def from_tff_result(cls, anon_tuple):
    """Create a `ServerState` from a TFF result structure.

    Args:
      anon_tuple: Structure exposing `model` and `optimizer_state` attributes.

    Returns:
      A `ServerState` whose `model` is converted with
      `tff.learning.framework.ModelWeights.from_tff_result` and whose
      `optimizer_state` is a list copy of `anon_tuple.optimizer_state`.
    """
    model_weights = tff.learning.framework.ModelWeights.from_tff_result(
        anon_tuple.model)
    optimizer_variables = list(anon_tuple.optimizer_state)
    return cls(model=model_weights, optimizer_state=optimizer_variables)


In [41]:
#research folder contains the research files downloaded from Tensorflow Github page.
!cd './research'

In [None]:
from utils import checkpoint_manager

In [None]:
cd ..

In [None]:
# Delete any checkpoint location left over from a previous run. Despite the
# `.h5` suffix, this path is created as a directory further down.
!rm -rf './model_emnist_ff.h5'

In [32]:
!mkdir './model_emnist_ff.h5'

In [None]:
# Checkpoints are written under this directory.
ckpt_manager = checkpoint_manager.FileCheckpointManager("./model_emnist_ff.h5")

import time

# Wall-clock timestamp taken just before training starts.
first = time.time()
print("first since epoch =", first)

# NUM_ROUNDS is the exclusive upper bound of the round counter used below, so
# the loop performs NUM_ROUNDS - 1 (= 100) federated rounds.
NUM_ROUNDS = 101
SIM_METRICS = [NUM_ROUNDS]

USERS_PER_ROUND = NUM_CLIENTS

for round_num in range(1, NUM_ROUNDS):
  # One federated round over the fixed client cohort; `state` is replaced by
  # the updated state returned from the round.
  state, metrics = iterative_process.next(state, federated_train_data)
  print('round {:2d}, metrics={}'.format(round_num, metrics))

# Label the checkpoint with the last round that actually ran (NUM_ROUNDS - 1),
# not with the loop's exclusive upper bound.
ckpt_manager.save_checkpoint(state, round_num=NUM_ROUNDS - 1)
seconds = time.time()
print("Time diff =", seconds - first)

#Model Restore

In [23]:
# Load the most recent checkpoint from the checkpoint manager.
# NOTE(review): `state` is presumably passed as the structure template, and the
# result is indexed with [0] below, so it is presumably a (state, round_num)
# pair. Confirm against `checkpoint_manager`.
restored_state = ckpt_manager.load_latest_checkpoint(state)

# Copy the restored model weights into a newly built Keras model.
model_last = create_keras_model()
restored_state[0].model.assign_weights_to(model_last)

#Testing 

In [24]:
def count_iterable(i):
    """Return the number of items produced by iterable `i`.

    The iterable is consumed in the process, so a one-shot iterator is
    exhausted afterwards.
    """
    item_count = 0
    for _item in i:
        item_count += 1
    return item_count

def _count_examples_per_client(client_data):
    """Count the examples of every client in `client_data`.

    Returns a list with one single-item list `[n_examples]` per client, in
    `client_data.client_ids` order (the nested shape is kept because the
    counts are read as `counts[client_index][0]`).
    """
    return [
        [count_iterable(iter(client_data.create_tf_dataset_for_client(client_id)))]
        for client_id in client_data.client_ids
    ]


# Number of clients, read from the dataset instead of hard-coding 3383, so the
# cell cannot index past the end of `client_ids` if the dataset differs.
all_clients = len(emnist_train.client_ids)

# The helper keeps its working variables local, so this cell no longer
# overwrites the notebook-level `example_dataset` and `it` defined earlier.
all_train_counts = _count_examples_per_client(emnist_train)
all_test_counts = _count_examples_per_client(emnist_test)


In [None]:
# Evaluate on every test client. The count is read from the dataset instead of
# hard-coding 3383, so it stays correct if the dataset variant changes.
CLIENT_MAX = len(emnist_test.client_ids)

def calculate_metric(model, client_num, batch_size=256):
    """Return the mean per-client test accuracy of `model`, in percent.

    For each of the first `client_num` clients in `emnist_test.client_ids`,
    every test example of that client is classified and the client's accuracy
    (0-100) is computed; the unweighted average over clients is returned.

    Args:
        model: Keras model that maps a batch of `[N, 28, 28, 1]` images to
            per-class scores.
        client_num: Number of test clients to evaluate, taken from the start
            of `emnist_test.client_ids`.
        batch_size: Number of examples sent to the model at once. It only
            affects speed and memory, not the result.

    Returns:
        The average of the per-client accuracies in percent, or NaN when no
        evaluated client has any test example.

    Raises:
        IndexError: If `client_num` exceeds the number of test clients.
    """
    per_client_accuracy = []

    for client_index in range(client_num):
        client_id = emnist_test.client_ids[client_index]
        client_dataset = emnist_test.create_tf_dataset_for_client(client_id)

        n_correct = 0
        n_total = 0
        # Iterate over the dataset itself so that every example is scored; no
        # separately computed example count is needed and none is skipped.
        for batch in client_dataset.batch(batch_size):
            labels = batch['label'].numpy()
            images = tf.reshape(batch['pixels'], [-1, 28, 28, 1])
            # The predicted class is the arg-max over the model's class
            # scores; this replaces `predict_classes`, which newer Keras
            # versions no longer provide.
            predictions = np.argmax(model.predict_on_batch(images), axis=-1)
            n_correct += int(np.sum(predictions == labels))
            n_total += len(labels)

        # A client without test examples has no defined accuracy; skip it
        # rather than dividing by zero.
        if n_total == 0:
            continue
        per_client_accuracy.append(100.0 * n_correct / n_total)

    if not per_client_accuracy:
        return float('nan')
    return np.average(per_client_accuracy)


# Evaluate the restored model on the first CLIENT_MAX test clients and print
# the resulting average per-client accuracy (in percent).
model_accuracy = calculate_metric(model_last, CLIENT_MAX)
print(model_accuracy)
