# Federated TF Keras TinyImageNet Tutorial
## Using low-level Python API

# Long-Living entities update

* We may now have a Director running on another machine.
* We use the Federation API to communicate with the Director.
* The Federation object should hold a Director client (for the user service).
* Keep in mind that several API instances may be connected to one Director.


* For now, we do not consider how the Director is started.
* But it knows the data shape and target shape for the data science problem in the Federation.
* The Director holds the list of connected envoys, so we no longer need to specify it.
* The Director and Envoys are responsible for encrypting connections, so we do not need to worry about certs.


* Yet we MUST have a cert to communicate with the Director.
* We MUST know the FQDN of the Director.
* The Director communicates the data and target shapes to the Federation interface object.


* Experiment API may use this info to construct a dummy dataset and a `shard descriptor` stub.

### Install requirements

In [1]:
!pip install tensorflow==2.6.0

Collecting tensorflow==2.6.0
  Using cached tensorflow-2.6.0-cp38-cp38-manylinux2010_x86_64.whl (458.4 MB)
Collecting grpcio<2.0,>=1.37.0
  Using cached grpcio-1.40.0-cp38-cp38-manylinux2014_x86_64.whl (4.3 MB)
Collecting clang~=5.0
  Using cached clang-5.0-py3-none-any.whl
Collecting google-pasta~=0.2
  Using cached google_pasta-0.2.0-py3-none-any.whl (57 kB)
Collecting six~=1.15.0
  Using cached six-1.15.0-py2.py3-none-any.whl (10 kB)
Collecting tensorflow-estimator~=2.6
  Using cached tensorflow_estimator-2.6.0-py2.py3-none-any.whl (462 kB)
Collecting h5py~=3.1.0
  Using cached h5py-3.1.0-cp38-cp38-manylinux1_x86_64.whl (4.4 MB)
Collecting opt-einsum~=3.3.0
  Using cached opt_einsum-3.3.0-py3-none-any.whl (65 kB)
Collecting typing-extensions~=3.7.4
  Using cached typing_extensions-3.7.4.3-py3-none-any.whl (22 kB)
Collecting flatbuffers~=1.12.0
  Using cached flatbuffers-1.12-py2.py3-none-any.whl (15 kB)
Collecting astunparse~=1.6.3
  Using cached astunparse-1.6.3-py2.py3-none-any.wh

## Connect to the Federation

In [248]:
# Create a federation
from openfl.interface.interactive_api.federation import Federation

# Please use the same identifier that was used in the signed certificate
client_id = 'api'
cert_dir = 'cert'
director_node_fqdn = 'localhost'
# 1) Run with API layer - Director mTLS
# To enable mTLS, the user must provide the CA root chain and a signed key pair to the federation interface
# cert_chain = f'{cert_dir}/root_ca.crt'
# api_certificate = f'{cert_dir}/{client_id}.crt'
# api_private_key = f'{cert_dir}/{client_id}.key'

# federation = Federation(client_id=client_id, director_node_fqdn=director_node_fqdn, director_port='50051',
#                        cert_chain=cert_chain, api_cert=api_certificate, api_private_key=api_private_key)

# --------------------------------------------------------------------------------------------------------------------

# 2) Run with TLS disabled (trusted environment)
# Federation can also determine local fqdn automatically
federation = Federation(client_id=client_id, director_node_fqdn=director_node_fqdn, director_port='50051', tls=False)

In [250]:
# Ask the Director for the shards currently registered in the federation
shard_registry = federation.get_shard_registry()
# Bare expression so that the notebook displays the registry as the cell output
shard_registry

{'env_one': {'shard_info': node_info {
    name: "env_one"
  }
  shard_description: "TinyImageNetDataset dataset, shard number 1 out of 2"
  sample_shape: "300"
  sample_shape: "400"
  sample_shape: "3"
  target_shape: "300"
  target_shape: "400",
  'is_online': True,
  'is_experiment_running': False,
  'last_updated': '2021-09-22 22:53:45',
  'current_time': '2021-09-22 22:53:48',
  'valid_duration': seconds: 120}}

## Dataset

In [253]:
import numpy as np
import tensorflow as tf

In [254]:
tf.__version__

'2.6.0'

In [255]:
# Turn on memory growth for every GPU that TensorFlow can see
gpu_devices = tf.config.experimental.list_physical_devices('GPU')
for device in gpu_devices:
    # Memory growth lets TensorFlow allocate GPU memory as needed rather than reserving it all up front
    tf.config.experimental.set_memory_growth(device, True)

In [256]:
from openfl.interface.interactive_api.experiment import TaskInterface, DataInterface, ModelInterface, FLExperiment

In [257]:
import math


def train_transform(img):
    """Turn a training image into a 3-channel tensor.

    Args:
        img: Image accepted by ``img_to_array`` that exposes a ``mode``
            attribute (presumably a PIL image -- confirm against the shard
            descriptor).

    Returns:
        Tensor holding the image data; inputs whose mode is ``'L'`` are
        expanded from grayscale to RGB first.
    """
    pixels = tf.keras.preprocessing.image.img_to_array(img)
    if img.mode == 'L':
        # Grayscale input: replicate the single channel into three channels
        pixels = tf.image.grayscale_to_rgb(tf.constant(pixels))
    # Reshaping to the same shape keeps the data as is but always yields a tensor
    return tf.reshape(pixels, pixels.shape)


def val_transform(img):
    """Turn a validation image into a 3-channel tensor.

    Args:
        img: Image accepted by ``img_to_array`` that exposes a ``mode``
            attribute (presumably a PIL image -- confirm against the shard
            descriptor).

    Returns:
        Tensor holding the image data; inputs whose mode is ``'L'`` are
        expanded from grayscale to RGB first.
    """
    image_data = tf.keras.preprocessing.image.img_to_array(img)
    needs_rgb = img.mode == 'L'
    if needs_rgb:
        grayscale = tf.constant(image_data)
        image_data = tf.image.grayscale_to_rgb(grayscale)
    # Same-shape reshape: data is unchanged, result is always a tensor
    target_shape = image_data.shape
    return tf.reshape(image_data, target_shape)


class TransformedDataset(tf.keras.utils.Sequence):
    """Batched, optionally transformed view over an indexable dataset.

    Wraps a dataset whose items are ``(image, label)`` pairs and serves them
    batch by batch through the ``tf.keras.utils.Sequence`` protocol.
    """

    def __init__(self, dataset, batch_size, transform=None):
        """Initialize the batched view.

        Args:
            dataset: Object supporting ``len()`` and integer indexing, where
                each item unpacks into ``(image, label)``.
            batch_size: Maximum number of samples per batch.
            transform: Optional callable applied to every image. When it is
                ``None`` the images are passed through unchanged.
        """
        self.dataset = dataset
        self.batch_size = batch_size
        self.transform = transform

    def __len__(self):
        """Return the number of batches (the last one may be partial)."""
        return math.ceil(len(self.dataset) / self.batch_size)

    def __getitem__(self, index):
        """Return batch number ``index`` as ``(images, labels)`` numpy arrays."""
        first_id = index * self.batch_size
        # Clamp the upper bound so the final batch stops at the end of the data
        last_id = min(first_id + self.batch_size, len(self.dataset))

        batch_x = []
        batch_y = []

        for i in range(first_id, last_id):
            img, label = self.dataset[i]
            # Without a transform keep the raw image instead of discarding it
            if self.transform is not None:
                img = self.transform(img)
            batch_x.append(img)
            batch_y.append(label)

        return np.array(batch_x), np.array(batch_y)

In [258]:
class TinyImageNetDataset(DataInterface):
    """Data interface that serves batched train/validation data to FL tasks.

    The loaders are created when a shard descriptor is assigned; until then
    ``train_set`` and ``valid_set`` do not exist.
    """

    def __init__(self, **kwargs):
        """Store the loader settings.

        Args:
            **kwargs: Must contain ``train_bs`` and ``valid_bs``, the batch
                sizes of the training and validation loaders.

        Raises:
            KeyError: If ``train_bs`` or ``valid_bs`` is missing.
        """
        self.kwargs = kwargs
        self.train_bs = kwargs['train_bs']
        self.valid_bs = kwargs['valid_bs']

    @property
    def shard_descriptor(self):
        """Return the shard descriptor previously assigned to this interface."""
        return self._shard_descriptor

    @shard_descriptor.setter
    def shard_descriptor(self, shard_descriptor):
        """
        Describe per-collaborator procedures or sharding.

        This method will be called during a collaborator initialization.
        Local shard_descriptor  will be set by Envoy.

        Besides storing the descriptor, it builds the batched training and
        validation loaders from the descriptor's 'train' and 'val' datasets.
        """
        self._shard_descriptor = shard_descriptor

        self.train_set = TransformedDataset(
            self._shard_descriptor.get_dataset('train'),
            batch_size=self.train_bs,
            transform=train_transform
        )
        self.valid_set = TransformedDataset(
            self._shard_descriptor.get_dataset('val'),
            batch_size=self.valid_bs,
            transform=val_transform
        )

    def get_train_loader(self, **kwargs):
        """
        Output of this method will be provided to tasks with optimizer in contract
        """
        return self.train_set

    def get_valid_loader(self, **kwargs):
        """
        Output of this method will be provided to tasks without optimizer in contract
        """
        return self.valid_set

    def get_train_data_size(self):
        """
        Information for aggregation

        Returns the number of training samples. ``len(self.train_set)`` would
        be the number of batches, which is rounded up and therefore distorts
        the relative size of shards.
        """
        return len(self.train_set.dataset)

    def get_valid_data_size(self):
        """
        Information for aggregation

        Returns the number of validation samples rather than the number of
        batches (see ``get_train_data_size``).
        """
        return len(self.valid_set.dataset)

## Model

In [261]:
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential


# Number of target classes; sizes the output layer of the model
num_classes = 200


# Random flip/rotation pipeline. It is built here but not attached to the
# model: the line that would add it is commented out below.
data_augmentation = tf.keras.Sequential()
data_augmentation.add(layers.experimental.preprocessing.RandomFlip("horizontal_and_vertical"))
data_augmentation.add(layers.experimental.preprocessing.RandomRotation(0.2))


# Small CNN: pixel rescaling, three Conv2D + MaxPooling2D stages, then a
# Dense output layer without activation.
model = Sequential()
model.add(layers.experimental.preprocessing.Rescaling(1. / 255))
# model.add(data_augmentation)
model.add(layers.Conv2D(16, 3, padding='same', activation='relu'))
model.add(layers.MaxPooling2D())
model.add(layers.Conv2D(32, 3, padding='same', activation='relu'))
model.add(layers.MaxPooling2D())
model.add(layers.Conv2D(64, 3, padding='same', activation='relu'))
model.add(layers.MaxPooling2D())
model.add(layers.Flatten())
model.add(layers.Dense(num_classes))

In [262]:
# Batch sizes of the training and validation loaders
train_bs = 8
valid_bs = 8
optimizer = tf.keras.optimizers.Adam()
# The model's final Dense layer has no activation, so it outputs raw logits.
# The loss must be told so; by default it would treat them as probabilities.
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Prepare the metrics.
train_acc_metric = keras.metrics.SparseCategoricalAccuracy()
val_acc_metric = keras.metrics.SparseCategoricalAccuracy()

In [263]:
# Build the model for 64x64 inputs with 3 channels; the batch dimension is given as train_bs
model.build(input_shape=(train_bs, 64, 64, 3))

In [264]:
# Print each layer with its output shape and parameter count
model.summary()

Model: "sequential_79"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
rescaling_49 (Rescaling)     (8, 64, 64, 3)            0         
_________________________________________________________________
conv2d_136 (Conv2D)          (8, 64, 64, 16)           448       
_________________________________________________________________
max_pooling2d_138 (MaxPoolin (8, 32, 32, 16)           0         
_________________________________________________________________
conv2d_137 (Conv2D)          (8, 32, 32, 32)           4640      
_________________________________________________________________
max_pooling2d_139 (MaxPoolin (8, 16, 16, 32)           0         
_________________________________________________________________
conv2d_138 (Conv2D)          (8, 16, 16, 64)           18496     
_________________________________________________________________
max_pooling2d_140 (MaxPoolin (8, 8, 8, 64)           

In [265]:
# Task registry: the functions decorated with register_fl_task below are registered on this object
task_interface = TaskInterface()


@task_interface.register_fl_task(model='model', data_loader='train_dataset',
                                 device='device', optimizer='optimizer')
def train(model, train_dataset, optimizer, device, loss_fn=loss_fn, warmup=False):
    """Run one pass over ``train_dataset`` and update the model weights.

    Args:
        model: Keras model to train; called with ``training=True``.
        train_dataset: Iterable yielding ``(inputs, labels)`` batches.
        optimizer: Optimizer used to apply the gradients.
        device: Part of the registered task contract; not used here.
        loss_fn: Callable ``loss_fn(labels, predictions)``. Defaults to the
            module-level ``loss_fn`` bound when this function is defined.
        warmup: Accepted for interface compatibility; not used here.

    Returns:
        Dict with ``'train_acc'`` (accuracy accumulated over the pass) and
        ``'loss'`` (mean batch loss over the pass; 0.0 if there were no
        batches).
    """
    total_loss = tf.constant(0.0)
    num_batches = 0

    # Iterate over the batches of the dataset.
    for x_batch_train, y_batch_train in train_dataset:
        y = tf.convert_to_tensor(y_batch_train)
        with tf.GradientTape() as tape:
            y_pred = model(x_batch_train, training=True)  # Forward pass
            loss = loss_fn(y, y_pred)

        # Compute gradients
        trainable_vars = model.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update metrics
        train_acc_metric.update_state(y, y_pred)
        total_loss += tf.cast(loss, total_loss.dtype)
        num_batches += 1

    # Read the accuracy, then reset the shared metric so the next call starts clean
    train_acc = train_acc_metric.result()
    train_acc_metric.reset_states()

    # Report the mean loss of the pass; guard against an empty loader
    mean_loss = total_loss / num_batches if num_batches else total_loss
    return {'train_acc': train_acc, 'loss': mean_loss}


@task_interface.register_fl_task(model='model', data_loader='val_dataset', device='device')
def validate(model, val_dataset, device):
    """Measure the model's accuracy over ``val_dataset``.

    Args:
        model: Keras model to evaluate; called with ``training=False``.
        val_dataset: Iterable yielding ``(inputs, labels)`` batches.
        device: Part of the registered task contract; not used here.

    Returns:
        Dict with a single key, ``'validation_accuracy'``.
    """
    for inputs, labels in val_dataset:
        targets = tf.convert_to_tensor(labels)
        # Inference-mode forward pass
        predictions = model(inputs, training=False)
        # Fold this batch into the running accuracy
        val_acc_metric.update_state(targets, predictions)

    # Read the result, then clear the shared metric for the next call
    accuracy = val_acc_metric.result()
    val_acc_metric.reset_states()
    return {'validation_accuracy': accuracy}

In [None]:
# Data interface for the experiment; only the batch sizes are set here, no shard descriptor yet
fed_dataset = TinyImageNetDataset(train_bs=train_bs, valid_bs=valid_bs)

In [266]:
# # Prepare data for local tests
# from tinyimagenet_shard_descriptor import TinyImageNetShardDescriptor


# sd = TinyImageNetShardDescriptor('tinyimagenet_data')
# fed_dataset.shard_descriptor = sd


In [267]:
# # Run local test
# for i in range(5):
#     print(i)
#     train(model=model, optimizer=optimizer, device='CPU', train_dataset=fed_dataset.get_train_loader())

In [268]:
# # Prepare data for local tests

# !pip install matplotlib
# import matplotlib.pyplot as plt


# def draw(history, epochs):
#     acc = history.history['accuracy']
#     val_acc = history.history['val_accuracy']

#     loss = history.history['loss']
#     val_loss = history.history['val_loss']

#     epochs_range = range(epochs)

#     plt.figure(figsize=(8, 8))
#     plt.subplot(1, 2, 1)
#     plt.plot(epochs_range, acc, label='Training Accuracy')
#     plt.plot(epochs_range, val_acc, label='Validation Accuracy')
#     plt.legend(loc='lower right')
#     plt.title('Training and Validation Accuracy')

#     plt.subplot(1, 2, 2)
#     plt.plot(epochs_range, loss, label='Training Loss')
#     plt.plot(epochs_range, val_loss, label='Validation Loss')
#     plt.legend(loc='upper right')
#     plt.title('Training and Validation Loss')
#     plt.show()
    
    
# def get_model(lr=0.001):
#     model = Sequential([
#         layers.experimental.preprocessing.Rescaling(1./255, input_shape=(64, 64, 3)),
# #         data_augmentation,
#         layers.Conv2D(16, 3, padding='same', activation='relu'),
#         layers.MaxPooling2D(),
#         layers.Conv2D(32, 3, padding='same', activation='relu'),
#         layers.MaxPooling2D(),
#         layers.Conv2D(64, 3, padding='same', activation='relu'),
#         layers.MaxPooling2D(),
#         layers.Conv2D(128, 3, padding='same', activation='relu'),
#         layers.MaxPooling2D(),
#         layers.Conv2D(256, 3, padding='same', activation='relu'),
#         layers.MaxPooling2D(),
#         layers.Flatten(),
#         layers.Dense(512, activation='relu'),
#         layers.Dense(num_classes)
#     ])
# #     base_model = tf.keras.applications.mobilenet_v2.MobileNetV2(
# #         input_shape=(64, 64, 3), alpha=0.75, include_top=False, weights='imagenet',
# #     )
# #     base_model.layers[0].trainable = False
# #     x = base_model.output
# #     x = layers.MaxPooling2D()(x)
# #     x = layers.Flatten()(x)
# #     last_layer = tf.keras.layers.Dense(num_classes, activation = 'softmax')(x)
# #     model = tf.keras.models.Model(inputs = base_model.input, outputs = last_layer)
#     optimizer = tf.keras.optimizers.Adam(learning_rate=lr)
# #     model.build(input_shape=[train_bs, 64, 64, 3])
#     model.compile(
#         optimizer=optimizer,
#         loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
#         metrics=['accuracy'],
#     )
#     return model

In [269]:
# # Run local test

# models = {}
# for lr in [0.0001]:
#     model = get_model(lr)
#     models[lr] = model
#     epochs=5
#     print(f'{lr=}')
#     history = model.fit(
#       fed_dataset.get_train_loader(),
#       validation_data=fed_dataset.get_valid_loader(),
#       epochs=epochs
#     )
#     draw(history, epochs)

### Register model


In [270]:
# Dotted path of the Keras framework adapter plugin, passed to OpenFL as a string
framework_adapter = 'openfl.plugins.frameworks_adapters.keras_adapter.FrameworkAdapterPlugin'
# Bundle the model and optimizer together with the adapter plugin path
model_interface = ModelInterface(model=model, optimizer=optimizer, framework_plugin=framework_adapter)

## Time to start a federated learning experiment

In [271]:
# Create an experiment in the federation
experiment_name = 'tinyimagenet_test_experiment'
fl_experiment = FLExperiment(federation=federation, experiment_name=experiment_name)

In [272]:
# The following command zips the workspace and python requirements to be transferred to collaborator nodes
fl_experiment.start(
    model_provider=model_interface, 
    task_keeper=task_interface,
    data_loader=fed_dataset,
    rounds_to_train=5,
    opt_treatment='CONTINUE_GLOBAL'
)