# Federated Flax/JAX CIFAR10 Tutorial
Using `tf.data` API

In [1]:
!pip install ml_collections -q

## Imports

In [2]:
# import os
# os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'
# os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.7'

In [3]:
import tensorflow as tf

# Record the TensorFlow version this notebook was executed with.
print(f'TensorFlow {tf.__version__}')

TensorFlow 2.8.2


In [4]:
from flax import linen as nn
from flax.metrics import tensorboard
from flax.training import train_state
import jax
import jax.numpy as jnp
import ml_collections
import optax
import tensorflow_datasets as tfds
import logging

  from .autonotebook import tqdm as notebook_tqdm


## Connect to the Federation

Start `Director` and `Envoy` before proceeding with this cell. 

This cell connects this notebook to the Federation.

In [5]:
from openfl.interface.interactive_api.federation import Federation

# Connection settings for the Director.
# `client_id` must be the same identifier that was used in the signed certificate.
client_id = 'api'
cert_dir = 'cert'
director_node_fqdn, director_port = 'localhost', 50055

# Create a Federation (TLS is switched off for this connection).
federation = Federation(
    client_id=client_id,
    director_node_fqdn=director_node_fqdn,
    director_port=director_port,
    tls=False,
)

## Query Datasets from Shard Registry

In [6]:
# Fetch the shard registry from the federation and display it as the cell output.
shard_registry = federation.get_shard_registry()
shard_registry

{'ENV1': {'shard_info': node_info {
    name: "ENV1"
  }
  shard_description: "CIFAR10 dataset, shard Train segment train[0:16666] Test Segment test[0:3333] number 1/3.\nSamples [Train/Valid]: [16666/3333]"
  sample_shape: "32"
  sample_shape: "32"
  sample_shape: "3"
  target_shape: "1",
  'is_online': True,
  'is_experiment_running': False,
  'last_updated': '2022-08-10 02:10:51',
  'current_time': '2022-08-10 02:10:53',
  'valid_duration': seconds: 10,
  'experiment_name': 'ExperimentName Mock'},
 'ENV2': {'shard_info': node_info {
    name: "ENV2"
  }
  shard_description: "CIFAR10 dataset, shard Train segment train[16666:33332] Test Segment test[3333:6666] number 2/3.\nSamples [Train/Valid]: [16666/3333]"
  sample_shape: "32"
  sample_shape: "32"
  sample_shape: "3"
  target_shape: "1",
  'is_online': True,
  'is_experiment_running': False,
  'last_updated': '2022-08-10 02:10:53',
  'current_time': '2022-08-10 02:10:53',
  'valid_duration': seconds: 10,
  'experiment_name': 'Experi

In [7]:
# First, request a dummy shard descriptor (size=10) that holds information
# about the federated dataset, and take its 'train' dataset.
dummy_shard_desc = federation.get_dummy_shard_descriptor(size=10)
dummy_shard_dataset = dummy_shard_desc.get_dataset('train')
# Unpack the first (sample, target) pair to inspect the shapes.
sample, target = dummy_shard_dataset[0]
# Last expression of the cell: shown as the cell output.
f"Sample shape: {sample.shape}, target shape: {target.shape}"

'Sample shape: (32, 32, 3), target shape: (1,)'

In [8]:
# NOTE(review): `editor_relpaths` is not referenced anywhere else in the
# visible notebook — confirm it is still needed.
editor_relpaths = ('configs/default.py', 'train.py')
# Load the experiment configuration from the `configs.default` module.
# Later cells read `learning_rate`, `momentum` and `batch_size` from it.
from configs import default as config_lib
config = config_lib.get_config()

## Describing FL experiment

In [9]:
from openfl.interface.interactive_api.experiment import TaskInterface
from openfl.interface.interactive_api.experiment import ModelInterface
from openfl.interface.interactive_api.experiment import FLExperiment

### Register model

In [10]:
# Define model
class CNN(nn.Module):
    """Small convolutional network that maps an input batch to 10 logits."""

    @nn.compact
    def __call__(self, x):
        # Two identical stages: 3x3 conv -> ReLU -> 2x2 average pooling,
        # first with 32 filters, then with 64.
        for n_filters in (32, 64):
            x = nn.Conv(features=n_filters, kernel_size=(3, 3))(x)
            x = nn.avg_pool(nn.relu(x), window_shape=(2, 2), strides=(2, 2))

        # Keep the leading axis and collapse everything else into one.
        flat = x.reshape((x.shape[0], -1))

        # Classifier head: 256 hidden units, then 10 output logits.
        hidden = nn.relu(nn.Dense(features=256)(flat))
        return nn.Dense(features=10)(hidden)

def create_train_state(rng, config):
    """Build the initial `TrainState` for the `CNN` model.

    Args:
        rng: PRNG key used to initialise the model parameters.
        config: object exposing `learning_rate` and `momentum`.

    Returns:
        A `train_state.TrainState` bundling the model's apply function,
        its (unfrozen) initial parameters and an SGD optimizer.
    """
    network = CNN()
    # Initialise on a dummy all-ones input of shape [1, 32, 32, 3] and keep
    # an unfrozen copy of the 'params' collection.
    variables = network.init(rng, jnp.ones([1, 32, 32, 3]))
    initial_params = variables['params'].unfreeze()
    sgd = optax.sgd(config.learning_rate, config.momentum)
    return train_state.TrainState.create(
        apply_fn=network.apply,
        params=initial_params,
        tx=sgd,
    )


In [11]:
# Derive two keys from a fixed seed: one to carry forward, one for parameter init.
rng, init_rng = jax.random.split(jax.random.PRNGKey(0))

# state == weights == model; this makes the training and validation loop stateless
model = create_train_state(init_rng, config)

# Create ModelInterface
# No separate optimizer object is passed: the optimizer is part of the train state.
framework_adapter = 'custom_adapter.CustomFrameworkAdapter'
MI = ModelInterface(
    model=model,
    optimizer=None,
    framework_plugin=framework_adapter,
)



### Register dataset

In [12]:
from openfl.interface.interactive_api.experiment import DataInterface

class CIFAR10FedDataset(DataInterface):
    """Data interface that hands the collaborator-local CIFAR10 splits to the FL tasks."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @property
    def shard_descriptor(self):
        """Return the shard descriptor previously assigned through the setter."""
        return self._shard_descriptor

    @shard_descriptor.setter
    def shard_descriptor(self, shard_descriptor):
        """
        Describe per-collaborator procedures or sharding.

        This method will be called during a collaborator initialization.
        Local shard_descriptor will be set by Envoy.
        """
        self._shard_descriptor = shard_descriptor

        # NOTE(review): the tasks in this notebook index each split as a
        # mapping with 'image' and 'label' entries. Check
        # cifar10_shard_descriptor.py for the actual type returned by get_split.
        self.train_set = shard_descriptor.get_split('train')
        self.valid_set = shard_descriptor.get_split('valid')

    def get_train_loader(self):
        """Output of this method will be provided to tasks with optimizer in contract"""
        return self.train_set

    def get_valid_loader(self):
        """Output of this method will be provided to tasks without optimizer in contract"""
        return self.valid_set

    @staticmethod
    def _num_samples(split) -> int:
        """Return the number of samples in `split`.

        When the split is a mapping with an 'image' entry (as the train and
        validate tasks assume), `len(split)` would be the number of keys,
        not the number of samples, so the 'image' entry is measured instead.
        Anything that cannot be indexed that way falls back to `len(split)`.
        """
        try:
            return len(split['image'])
        except (KeyError, IndexError, TypeError):
            return len(split)

    def get_train_data_size(self) -> int:
        """Information for aggregation: number of training samples."""
        return self._num_samples(self.train_set)

    def get_valid_data_size(self) -> int:
        """Information for aggregation: number of validation samples."""
        return self._num_samples(self.valid_set)

### Create CIFAR10 federated dataset

In [13]:
# Instantiate the data interface without extra keyword arguments.
# Its splits are only populated once a shard descriptor is assigned.
fed_dataset = CIFAR10FedDataset()

## Define and register FL tasks

In [14]:
@jax.jit
def apply_model(state, images, labels):
    """Evaluate one batch and return its gradients, loss and accuracy.

    Args:
        state: train state providing `apply_fn` and `params`.
        images: batch of inputs fed to the model.
        labels: integer class labels for the batch.

    Returns:
        Tuple `(grads, loss, accuracy)`: gradients of the loss with respect
        to `state.params`, the mean softmax cross-entropy, and the fraction
        of predictions equal to `labels`.
    """

    def loss_fn(params):
        # The logits are returned as auxiliary output so the accuracy can be
        # computed without a second forward pass.
        batch_logits = state.apply_fn({'params': params}, images)
        targets = jax.nn.one_hot(labels, 10)
        losses = optax.softmax_cross_entropy(logits=batch_logits, labels=targets)
        return jnp.mean(losses), batch_logits

    (loss, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
    predicted = jnp.argmax(logits, -1)
    accuracy = jnp.mean(predicted == labels)
    return grads, loss, accuracy


@jax.jit
def update_model(state, grads):
    """Apply `grads` to `state` and return the resulting new state.

    Args:
        state: object exposing `apply_gradients(grads=...)`.
        grads: gradients to apply.

    Returns:
        Whatever `state.apply_gradients` returns (the updated state).
    """
    updated_state = state.apply_gradients(grads=grads)
    return updated_state

In [15]:
from tensorflow.keras.utils import Progbar

In [16]:
def train_epoch(state, train_ds, batch_size, rng):
    """Run one pass over `train_ds` in shuffled mini-batches.

    Args:
        state: train state handed to `apply_model` / `update_model`.
        train_ds: mapping with 'image' and 'label' entries that can be
            indexed with an array of sample indices.
        batch_size: number of samples per optimisation step.
        rng: PRNG key used to shuffle the sample order.

    Returns:
        Tuple `(state, train_loss, train_accuracy)`: the updated state and
        the means of the per-batch loss and accuracy as Python scalars.
    """
    num_samples = len(train_ds['image'])
    num_steps = num_samples // batch_size
    progress = Progbar(num_steps)

    # Shuffle the sample indices, drop the trailing incomplete batch and lay
    # the remainder out as one row of indices per step.
    shuffled = jax.random.permutation(rng, num_samples)
    batches = shuffled[:num_steps * batch_size].reshape((num_steps, batch_size))

    losses, accuracies = [], []
    for step, batch_idx in enumerate(batches, start=1):
        images = train_ds['image'][batch_idx, ...]
        labels = train_ds['label'][batch_idx, ...]
        grads, loss, accuracy = apply_model(state, images, labels)
        state = update_model(state, grads)
        losses.append(loss)
        accuracies.append(accuracy)
        progress.update(step, values={'epoch loss': loss, 'epoch accuracy': accuracy}.items())

    train_loss = jnp.array(losses).mean().item()
    train_accuracy = jnp.array(accuracies).mean().item()
    return state, train_loss, train_accuracy

In [17]:
# Task registry: functions are attached to it with `@TI.register_fl_task(...)`.

TI = TaskInterface()

@TI.register_fl_task(model='model', data_loader='dataset', optimizer='optimizer', device='device')
def train(model, dataset, optimizer, device, loss_fn=None, warmup=False):
    """Train for one epoch on the local data and report the training accuracy.

    Args:
        model: the train state registered with the ModelInterface.
        dataset: training data as returned by the data interface's train loader.
        optimizer: unused here; the optimizer is part of the train state.
        device: unused here.
        loss_fn: unused here; the loss is defined inside `apply_model`.
        warmup: unused here.

    Returns:
        Dict with the mean per-batch training accuracy under 'train_acc'.
    """
    # NOTE(review): `init_rng` is the same key on every call, so every call
    # shuffles the data in the same order — confirm this is intended.
    state, train_loss, train_accuracy = train_epoch(model, dataset, config.batch_size, init_rng)

    # `train_epoch` returns a NEW state and leaves `model` untouched, so the
    # trained weights must be copied back or the result of training is lost.
    # This relies on `model.params` being the unfrozen (mutable) dict built
    # in `create_train_state`.
    # NOTE(review): optimizer state and step counter of `state` are not
    # copied back — confirm whether the framework adapter needs them.
    model.params.update(state.params)

    print("Training acc over epoch: %.4f" % (float(train_accuracy),))
    return {'train_acc': train_accuracy,}


@TI.register_fl_task(model='model', data_loader='dataset', device='device')
def validate(model, dataset, device):
    """Evaluate the model on the whole validation split in a single batch.

    Args:
        model: the train state registered with the ModelInterface.
        dataset: mapping with 'image' and 'label' entries.
        device: unused here.

    Returns:
        Dict with the validation accuracy under 'validation_accuracy',
        as a plain Python float (same type the `train` task reports).
    """
    # `apply_model` returns (grads, loss, accuracy); only the accuracy is reported.
    _, _, val_accuracy = apply_model(model, dataset['image'], dataset['label'])
    val_accuracy = float(val_accuracy)
    print("Validation accuracy: %.4f" % (val_accuracy,))
    return {'validation_accuracy': val_accuracy,}

## Time to start a federated learning experiment

In [18]:
# Create an experiment in the federation
experiment_name = 'cifar10_experiment'
fl_experiment = FLExperiment(federation=federation, experiment_name=experiment_name)

In [19]:
# The following command zips the workspace and python requirements to be transferred to collaborator nodes
ROUNDS_TO_TRAIN = 2
fl_experiment.start(model_provider=MI,
                   task_keeper=TI,
                   data_loader=fed_dataset,
                   rounds_to_train=ROUNDS_TO_TRAIN,
                   opt_treatment='CONTINUE_GLOBAL')




In [None]:
# Follow the metrics reported for the experiment started above.
fl_experiment.stream_metrics()