In [1]:
!pip install tensorflow_federated

Collecting tensorflow_federated
  Downloading tensorflow_federated-0.19.0-py2.py3-none-any.whl (602 kB)
[K     |████████████████████████████████| 602 kB 12.9 MB/s 
[?25hCollecting attrs~=19.3.0
  Downloading attrs-19.3.0-py2.py3-none-any.whl (39 kB)
Collecting tensorflow-privacy~=0.5.0
  Downloading tensorflow_privacy-0.5.2-py3-none-any.whl (192 kB)
[K     |████████████████████████████████| 192 kB 47.9 MB/s 
Collecting tensorflow~=2.5.0
  Downloading tensorflow-2.5.2-cp37-cp37m-manylinux2010_x86_64.whl (454.4 MB)
[K     |████████████████████████████████| 454.4 MB 25 kB/s 
Collecting grpcio~=1.34.0
  Downloading grpcio-1.34.1-cp37-cp37m-manylinux2014_x86_64.whl (4.0 MB)
[K     |████████████████████████████████| 4.0 MB 49.9 MB/s 
[?25hCollecting tqdm~=4.28.1
  Downloading tqdm-4.28.1-py2.py3-none-any.whl (45 kB)
[K     |████████████████████████████████| 45 kB 3.5 MB/s 
[?25hCollecting cachetools~=3.1.1
  Downloading cachetools-3.1.1-py2.py3-none-any.whl (11 kB)
Collecting tensorf

In [3]:
import sys

from google.colab import drive

# Mount the user's Google Drive inside the Colab runtime.
drive.mount('/content/gdrive')

# Make the Drive root importable; presumably this is where the project-local
# modules imported later (config, dataset, utils, ...) live -- confirm.
sys.path.append('/content/gdrive/My Drive')

Mounted at /content/gdrive


In [None]:
import tensorflow as tf
import tensorflow_federated as tff
import fed_compression
import dnn_models as dnn

from config import *
from pathlib import Path
from utils import plot_graph
from datetime import datetime
from dataset import load_dataset
from matplotlib import pyplot as plt
from tensorflow.keras import losses, metrics, optimizers

# Timestamp that gives every run its own sub-folder for models and results.
now = datetime.now()
date_time = now.strftime("%d.%m.%Y__%H.%M.%S")

this_dir = Path.cwd()
# Use the formatted timestamp string here. str(datetime) would stringify the
# class itself ("<class 'datetime.datetime'>"), so all runs would share a single
# oddly named directory and overwrite each other's artefacts.
# `name_dt` is not defined in this cell; presumably it comes from `config`.
model_dir = this_dir / "saved_models" / name_dt / date_time
output_dir = this_dir / "results" / name_dt / date_time

# exist_ok=True makes re-running the cell safe without a separate existence check.
model_dir.mkdir(parents=True, exist_ok=True)
output_dir.mkdir(parents=True, exist_ok=True)


# Training data handed to every federated round, plus a sample dataset whose
# element_spec is used as the model's input spec in model_fn.
federated_train_data, preprocessed_sample_dataset = load_dataset(phase='train')


def model_fn():
    """Return a new TFF learning model wrapping a freshly built Keras network.

    The Keras model has to be constructed inside this function on every call
    rather than captured from an enclosing scope, because TFF invokes the
    function within different graph contexts.
    """
    network = dnn.keras_model(model_name)
    element_spec = preprocessed_sample_dataset.element_spec
    loss_fn = losses.CategoricalCrossentropy()
    metric_fns = [metrics.CategoricalAccuracy()]
    return tff.learning.from_keras_model(
        network,
        input_spec=element_spec,
        loss=loss_fn,
        metrics=metric_fns)


def _make_client_optimizer():
    """Create the optimizer used for the client-side updates."""
    return optimizers.Adam(learning_rate=client_lr)


def _make_server_optimizer():
    """Create the optimizer used for the server-side update."""
    return optimizers.SGD(learning_rate=server_lr)


# Build the federated averaging process from the project's fed_compression module.
iterative_process = fed_compression.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=_make_client_optimizer,
    server_optimizer_fn=_make_server_optimizer)


# Show the type signature of the initialisation computation, then create the
# initial server state that the training rounds start from.
print(iterative_process.initialize.type_signature)
state = iterative_process.initialize()


# Held-out data used to evaluate the global model after every round.
x_test, y_test = load_dataset(phase='test')

# Per-round history: one entry is appended to each list per federated round.
tff_train_acc = []
tff_val_acc = []
tff_train_loss = []
tff_val_loss = []

# Stays None until the first round has run, so the code after the loop can
# tell whether any training happened at all.
eval_model = None
for round_num in range(1, NUM_ROUNDS+1):
    state, tff_metrics = iterative_process.next(state, federated_train_data)

    if eval_model is None:
        # Build and compile the evaluation network once and bind it to
        # `eval_model`, the name every later statement uses. Its weights are
        # overwritten from the server state each round, so rebuilding the
        # network per round is unnecessary.
        eval_model = dnn.keras_model(model_name)
        # NOTE(review): model_fn trains with CategoricalCrossentropy while the
        # evaluation uses the sparse variants -- confirm the label encoding
        # returned by load_dataset(phase='test').
        eval_model.compile(optimizer=optimizers.Adam(learning_rate=client_lr),
                           loss=losses.SparseCategoricalCrossentropy(),
                           metrics=[metrics.SparseCategoricalAccuracy()])

    # Copy the current global weights into the evaluation network. Some TFF
    # releases no longer ship the module-level helper; there the weights object
    # is expected to provide assign_weights_to() instead (assumes state.model is
    # a tff.learning.ModelWeights -- TODO confirm against fed_compression).
    if hasattr(tff.learning, "assign_weights_to_keras_model"):
        tff.learning.assign_weights_to_keras_model(eval_model, state.model)
    else:
        state.model.assign_weights_to(eval_model)

    # evaluate() returns [loss, accuracy] for the compiled loss and metric.
    ev_result = eval_model.evaluate(x_test, y_test, verbose=0)
    print('round {:2d}, metrics={}'.format(round_num, tff_metrics))
    print(f"Eval loss : {ev_result[0]} and Eval accuracy : {ev_result[1]}")
    # NOTE(review): model_fn registers CategoricalAccuracy; verify that the
    # metrics returned by fed_compression really expose an attribute named
    # `sparse_categorical_accuracy`.
    tff_train_acc.append(float(tff_metrics.sparse_categorical_accuracy))
    tff_val_acc.append(ev_result[1])
    tff_train_loss.append(float(tff_metrics.loss))
    tff_val_loss.append(ev_result[0])

# Training and validation curves keyed by metric name; written to disk below.
metric_collection = {"sparse_categorical_accuracy": tff_train_acc,
                     "val_sparse_categorical_accuracy": tff_val_acc,
                     "loss": tff_train_loss,
                     "val_loss": tff_val_loss}

# eval_model is only set once a training round has run. Raise instead of
# calling exit(): in a notebook exit() asks the kernel to shut down, which
# discards all state, whereas an exception simply stops execution here.
if eval_model is None:
    raise RuntimeError("training didn't start: there is no model to save")

# Persist the model holding the weights of the last completed round.
eval_model.save(model_dir / (name_dt + ".h5"))

# One value is recorded per federated round, so the x axis is the round number.
# It is derived from the recorded history instead of a fixed five-point list,
# so x and y always have the same length whatever NUM_ROUNDS is.
round_numbers = list(range(1, len(tff_train_acc) + 1))

# Accuracy curves (training vs. held-out evaluation).
plt.figure(figsize=(10, 6))
plot_graph(round_numbers, tff_train_acc, label='Train Accuracy')
plot_graph(round_numbers, tff_val_acc, label='Validation Accuracy')
plt.legend()
plt.savefig(output_dir / "federated_model_Accuracy.png")

# Loss curves (training vs. held-out evaluation).
plt.figure(figsize=(10, 6))
plot_graph(round_numbers, tff_train_loss, label='Train loss')
plot_graph(round_numbers, tff_val_loss, label='Validation loss')
plt.legend()
plt.savefig(output_dir / "federated_model_loss.png")



# Write the metric histories to a text file: one line per metric, holding the
# metric name followed by its space-separated values.

txt_file_path = output_dir / (name_dt + ".txt")
with open(txt_file_path.as_posix(), "w") as handle:
    handle.write("\n".join(
        metric_name + " " + " ".join(map(str, series))
        for metric_name, series in metric_collection.items()
    ))

[1mDownloading and preparing dataset mnist/3.0.1 (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /root/tensorflow_datasets/mnist/3.0.1...[0m


local data directory. If you'd instead prefer to read directly from our public
GCS bucket (recommended if you're running on GCP), you can instead pass
`try_gcs=True` to `tfds.load` or set `data_dir=gs://tfds-data/datasets`.



HBox(children=(IntProgress(value=0, description='Dl Completed...', max=4, style=ProgressStyle(description_widt…



[1mDataset mnist downloaded and prepared to /root/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.[0m
Number of client datasets: 10
First dataset: <PrefetchDataset shapes: OrderedDict([(x, (None, 32, 32, 1)), (y, (None, 10))]), types: OrderedDict([(x, tf.float32), (y, tf.int32)])>
ResNet model is built with tf.keras.application.ResNet101
Instructions for updating:
Colocations handled automatically by placer.


Instructions for updating:
Colocations handled automatically by placer.


ResNet model is built with tf.keras.application.ResNet101
ResNet model is built with tf.keras.application.ResNet101
ResNet model is built with tf.keras.application.ResNet101
loading the model 
