# Federated SLDA TF.Data and Custom Dataset Tutorial
Using `tf.data` API

In [1]:
# Install TensorFlow if it is not already installed. We recommend TF 2.7 or greater.
# !pip install tensorflow==2.8

## Imports

In [2]:
%env TF_FORCE_GPU_ALLOW_GROWTH=true
import tensorflow as tf

import tensorflow_datasets as tfds

# Config/Options
from openfl.cl.config import Decoders
from openfl.cl.config import IMG_AUGMENT_LAYERS

# Model/Loss definitions
from openfl.cl.models.slda import SLDA
from openfl.cl.models import losses
from openfl.cl.models.utils import extract_features

# Dataset handling (synthesize/build/query)
from openfl.cl.lib.dataset.repository import DatasetRepository
from openfl.cl.lib.dataset.utils import as_tuple, decode_example, get_label_distribution
from openfl.cl.lib.dataset.synthesizer import synthesize_by_sharding_over_labels

env: TF_FORCE_GPU_ALLOW_GROWTH=true


2022-12-20 10:26:30.868694: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
  from .autonotebook import tqdm as notebook_tqdm
2022-12-20 10:26:32.601269: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-12-20 10:26:33.152821: W tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:42] Overriding orig_value setting because the TF_FORCE_GPU_ALLOW_GROWTH environment variable is set. Original config value was 0.
2022-12-20 10:26:33.152873: I tensorflow/core/common_runtime/gpu/gpu_device.cc

In [3]:
# Report the TensorFlow version this notebook is running against.
# `tf` is already imported in the imports cell, so this re-import is a no-op.
import tensorflow as tf
print('TensorFlow', tf.__version__)

TensorFlow 2.9.0


## Experiment Options


In [4]:
# Dataset passed to tfds.load in the dataset-loading cell.
DATASET = 'cifar10'   # If loading a public TensorFlow dataset
# DATASET = '/tmp/repository/vege'  # If loading a local TFRecord dataset
# NOTE(review): the local-directory branch of the dataset-loading cell is
# commented out, so only the public-dataset form is exercised as written.

# Spatial input size; expanded to (*IMG_SIZE, 3) for the backbone input shape.
IMG_SIZE = (32, 32)
# Only referenced by commented-out cells in the visible part of this notebook.
BATCH_SIZE = 32
# Not referenced by any other cell in the visible part of this notebook.
SHUFFLE_BUFFER = 16384

## Connect to the Federation

Start `Director` and `Envoy` before proceeding with this cell. 

This cell connects this notebook to the Federation.

In [5]:
from openfl.interface.interactive_api.federation import Federation

# Please use the same identifier that was used in the signed certificate
client_id = 'api'
# NOTE(review): cert_dir is not passed to Federation below and is not
# referenced by any other visible cell.
cert_dir = 'cert'
director_node_fqdn = 'localhost'
director_port = 50051

# Create a Federation
# NOTE(review): tls=False presumably disables TLS for the Director connection;
# confirm this is only intended for local experiments.
federation = Federation(
    client_id=client_id,
    director_node_fqdn=director_node_fqdn,
    director_port=director_port, 
    tls=False
)

## Query Datasets from Shard Registry

In [6]:
# Fetch the shard registry from the federation and display it as the cell output.
shard_registry = federation.get_shard_registry()
shard_registry

{'Q1': {'shard_info': node_info {
    name: "Q1"
  }
  shard_description: "CIFAR10 dataset, shard number 1/3.\nSamples [Train/Valid]: [15000/10000]"
  sample_shape: "32"
  sample_shape: "32"
  sample_shape: "3"
  target_shape: "1",
  'is_online': True,
  'is_experiment_running': False,
  'last_updated': '2022-12-20 10:26:21',
  'current_time': '2022-12-20 10:26:33',
  'valid_duration': seconds: 120,
  'experiment_name': 'ExperimentName Mock'},
 'Q2': {'shard_info': node_info {
    name: "Q2"
  }
  shard_description: "CIFAR10 dataset, shard number 2/3.\nSamples [Train/Valid]: [15000/10000]"
  sample_shape: "32"
  sample_shape: "32"
  sample_shape: "3"
  target_shape: "1",
  'is_online': True,
  'is_experiment_running': False,
  'last_updated': '2022-12-20 10:26:26',
  'current_time': '2022-12-20 10:26:33',
  'valid_duration': seconds: 120,
  'experiment_name': 'ExperimentName Mock'},
 'Q3': {'shard_info': node_info {
    name: "Q3"
  }
  shard_description: "CIFAR10 dataset, shard number

In [7]:
# First, request a dummy_shard_desc that holds information about the federated dataset
dummy_shard_desc = federation.get_dummy_shard_descriptor(size=10)
dummy_shard_dataset = dummy_shard_desc.get_dataset('train')
# Take the first (sample, target) pair and display both shapes as the cell output.
sample, target = dummy_shard_dataset[0]
f"Sample shape: {sample.shape}, target shape: {target.shape}"

'Sample shape: (32, 32, 3), target shape: (1,)'

## Describing FL experiment

In [8]:
from openfl.interface.interactive_api.experiment import TaskInterface
from openfl.interface.interactive_api.experiment import ModelInterface
from openfl.interface.interactive_api.experiment import FLExperiment

### Register dataset

In [9]:
# """Load the dataset: Public or Local"""
# if tf.io.gfile.isdir(DATASET):
#     repo = DatasetRepository(data_dir=DATASET)
#     builder = repo.get_builder()  # Builds all versions by default
#     ds_info = builder.info
#     (raw_train_ds, raw_test_ds) = builder.as_dataset(split=['train', 'test'],
#                                                      decoders=Decoders.SIMPLE_DECODER)
# else:
# Load TFDS dataset by name (publicly-hosted on TF)
# NOTE(review): the local-repository branch preceding this statement is
# commented out, so this cell always loads DATASET by name through tfds.load.
# Returns the 'train' and 'test' splits plus the dataset metadata (with_info=True).
(raw_train_ds, raw_test_ds), ds_info = tfds.load(DATASET,
                                                 split=['train', 'test'],
                                                 with_info=True,
                                                 decoders=Decoders.SIMPLE_DECODER)
# Summarize what was loaded: metadata, per-element structure, and split sizes.
print('About: ', ds_info)
print('Element Spec: ', raw_train_ds.element_spec)
print('Training samples: ', len(raw_train_ds))
print('Testing samples: ', len(raw_test_ds))

About:  tfds.core.DatasetInfo(
    name='cifar10',
    full_name='cifar10/3.0.2',
    description="""
    The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images.
    """,
    homepage='https://www.cs.toronto.edu/~kriz/cifar.html',
    data_path='/home/sunilach/tensorflow_datasets/cifar10/3.0.2',
    download_size=162.17 MiB,
    dataset_size=132.40 MiB,
    features=FeaturesDict({
        'id': Text(shape=(), dtype=tf.string),
        'image': Image(shape=(32, 32, 3), dtype=tf.uint8),
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10),
    }),
    supervised_keys=('image', 'label'),
    disable_shuffling=False,
    splits={
        'test': <SplitInfo num_examples=10000, num_shards=1>,
        'train': <SplitInfo num_examples=50000, num_shards=1>,
    },
    citation="""@TECHREPORT{Krizhevsky09learningmultiple,
        author = {Alex Krizhevsky},
        title = {Learn

### Define Feature Extractor


In [10]:
# ImageNet-pretrained EfficientNetV2-B0 without its classification head.
# With pooling='avg' the summary below reports an output shape of (None, 1280).
backbone = tf.keras.applications.EfficientNetV2B0(
            include_top=False,
            weights='imagenet',
            input_shape=(*IMG_SIZE, 3),
            pooling='avg'
        )
# Freeze the backbone; the summary below reports 0 trainable parameters.
backbone.trainable = False

"""Add augmentation/input layers"""
# NOTE(review): no augmentation layers are added here; the Sequential model
# contains only an InputLayer and the backbone.
feature_extractor = tf.keras.Sequential([
    tf.keras.layers.InputLayer(backbone.input_shape[1:]),
    backbone,
], name='feature_extractor')

feature_extractor.summary()

Model: "feature_extractor"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 efficientnetv2-b0 (Function  (None, 1280)             5919312   
 al)                                                             
                                                                 
Total params: 5,919,312
Trainable params: 0
Non-trainable params: 5,919,312
_________________________________________________________________


In [11]:
# valid_ds = (test_features
#                     .cache()
#                     .map(as_tuple(x='image', y='label'))
#                     .batch(BATCH_SIZE)
#                     .prefetch(tf.data.AUTOTUNE))

In [12]:
# len(list(valid_ds))

In [13]:
# BATCH_SIZE

In [14]:
# test_features


In [15]:
# """Extract train/test feature embeddings"""
# print(f'Extracting train set features')
# train_features = extract_features(dataset=(raw_train_ds
#                                         .map(decode_example(IMG_SIZE))
#                                         .map(as_tuple(x='image', y='label'))
#                                         .batch(BATCH_SIZE)
#                                         .prefetch(tf.data.AUTOTUNE)), model=feature_extractor)
# print(f'Extracting test set features')
# test_features = extract_features(dataset=(raw_test_ds
#                                         .map(decode_example(IMG_SIZE))
#                                         .map(as_tuple(x='image', y='label'))
#                                         .batch(BATCH_SIZE)
#                                         .prefetch(tf.data.AUTOTUNE)), model=feature_extractor)
# print('Features Dataset spec: ', train_features.element_spec)

In [16]:
# partitioned_dataset = synthesize_by_sharding_over_labels(train_features, 
#                                                          num_partitions=4,
# #                                                          shuffle_labels=True)

In [17]:
# len(valid_ds)

In [18]:
# valid_ds = (test_features
#                     .cache()
#                     .map(as_tuple(x='image', y='label'))
#                     .prefetch(tf.data.AUTOTUNE))

In [19]:
# len(list(partitioned_dataset[]))

In [20]:
# print('Clients:', len(partitioned_dataset))
# for client_id in partitioned_dataset:
#     dist = get_label_distribution(partitioned_dataset[client_id])
#     print(f'Client {client_id}: {dist}')

In [21]:
# partitioned_dataset[0]

In [22]:
from openfl.interface.interactive_api.experiment import DataInterface

class CIFAR10FedDataset(DataInterface):
    """Data interface that hands a shard descriptor's splits to the FL tasks."""

    def __init__(self, **kwargs):
        # Nothing extra to configure: forward every option to DataInterface.
        super().__init__(**kwargs)

    @property
    def shard_descriptor(self):
        """Return the shard descriptor currently attached to this dataset."""
        return self._shard_descriptor

    @shard_descriptor.setter
    def shard_descriptor(self, shard_descriptor):
        """
        Attach a shard descriptor and cache its splits and sizes.

        Called during collaborator initialization; the local
        shard_descriptor is supplied by Envoy.
        """
        self._shard_descriptor = shard_descriptor
        descriptor = self._shard_descriptor

        # get_split(...) returns a tf.data.Dataset for the data splits;
        # check cifar10_shard_descriptor.py for details.
        self.train_set = descriptor.get_split('train')
        self.valid_set = descriptor.get_split('valid')

        # NOTE(review): the sizes are also fetched through get_split, and the
        # validation size uses the key 'test_size' while the split itself is
        # named 'valid' -- confirm both against the shard descriptor.
        self.train_size = descriptor.get_split('train_size')
        self.valid_size = descriptor.get_split('test_size')

    def get_train_loader(self):
        """Return the loader given to tasks whose contract includes an optimizer."""
        return self.train_set

    def get_valid_loader(self):
        """Return the loader given to tasks whose contract has no optimizer."""
        return self.valid_set

    def get_train_data_size(self) -> int:
        """Return the cached training-set size (information for aggregation)."""
        return self.train_size

    def get_valid_data_size(self) -> int:
        """Return the cached validation-set size (information for aggregation)."""
        return self.valid_size

### Create CIFAR10 federated dataset

In [23]:
# Instantiate the data interface with default options; per the class docstring,
# its shard_descriptor is set later by Envoy during collaborator initialization.
fed_dataset = CIFAR10FedDataset()

### Register Model


In [24]:
# # Define model
# model = tf.keras.Sequential([
#     tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)),
#     tf.keras.layers.MaxPooling2D((2, 2)),
#     tf.keras.layers.BatchNormalization(),
#     tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
#     tf.keras.layers.MaxPooling2D((2, 2)),
#     tf.keras.layers.BatchNormalization(),
#     tf.keras.layers.Flatten(),
#     tf.keras.layers.Dense(10, activation=None),
# ], name='simplecnn')
# model.summary()

# # Define optimizer
# optimizer = tf.optimizers.Adam(learning_rate=1e-4)

# # Loss and metrics. These will be used later.
# loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
# train_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
# val_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()


# SLDA classifier sized from the pipeline: n_components is the last dimension
# of the feature extractor's output and num_classes comes from the dataset's
# label feature.
model = SLDA(n_components=feature_extractor.output_shape[-1],
             num_classes=ds_info.features['label'].num_classes)

# compile() is given only a metric; no optimizer or loss is passed.
model.compile(metrics=['accuracy'])
# Create ModelInterface
framework_adapter = 'openfl.plugins.frameworks_adapters.keras_adapter.FrameworkAdapterPlugin'
# optimizer=None: no optimizer object is registered alongside the model.
MI = ModelInterface(model=model, optimizer=None, framework_plugin=framework_adapter)

BEFORE BUILD SLDA


2022-12-20 10:26:36.004181: I tensorflow/core/util/cuda_solvers.cc:179] Creating GpuSolver handles for stream 0x81ec2a0


AFTER BUILD SLDA


2022-12-20 10:26:38.383679: I tensorflow/stream_executor/cuda/cuda_blas.cc:1786] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.


In [25]:
# import copy

In [26]:
# res_model = copy.deepcopy(model)
# res_model.__class__ = SLDA
# res_model.compile(metrics=['accuracy'])

In [27]:
# import numpy as np

In [28]:
model.weights

[<tf.Variable 'means:0' shape=(10, 1280) dtype=float32, numpy=
 array([[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)>,
 <tf.Variable 'counts:0' shape=(10,) dtype=float32, numpy=array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)>,
 <tf.Variable 'sigma:0' shape=(1280, 1280) dtype=float32, numpy=
 array([[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)>,
 <tf.Variable 'sigma_inv:0' shape=(1280, 1280) dtype=float32, numpy=
 array([[10000.,     0.,     0., ...,     0.,     0.,     0.],
        [    0., 10000.,     0., ...,     0.,     0.,     0.],
        [    0.,     0.,

In [29]:
# res_model.get_weights()

In [30]:
# a = np.array(res_model.get_weights())         # save weights in a np.array of np.arrays
# res_model.set_weights(a + 1)

In [31]:
# res_model.weights

In [32]:
# model.set_weights(res_model.get_weights())

In [33]:
# model.weights

In [34]:
# import cloudpickle

In [35]:
# filepath = 'somefilemodelinterface'


In [36]:
# with open(filepath, 'wb') as f:
#     cloudpickle.dump(model, f)

In [37]:
# import cloudpickle

In [38]:
# filepath = 'model_ckpt_path'

In [39]:
# # with open(filepath, 'wb') as f:
#     cloudpickle.dump(model, f)

In [40]:
# model1 = tf.keras.Sequential([
#     tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)),
#     tf.keras.layers.MaxPooling2D((2, 2)),
#     tf.keras.layers.BatchNormalization(),
#     tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
#     tf.keras.layers.MaxPooling2D((2, 2)),
#     tf.keras.layers.BatchNormalization(),
#     tf.keras.layers.Flatten(),
#     tf.keras.layers.Dense(10, activation=None),
# ], name='simplecnn')

In [41]:
# with open(filepath, 'wb') as f:
#     cloudpickle.dump(model1, f)

### Test Serializer and Deserializer

In [42]:
# from importlib import import_module
# from os.path import splitext

In [43]:
# def plan_build(template, settings):
#         """
#         Create an instance of a openfl Component or Federated DataLoader/TaskRunner.

#         Args:
#             template: Fully qualified class template path
#             settings: Keyword arguments to class constructor

#         Returns:
#             A Python object
#         """
#         class_name = splitext(template)[1].strip('.')
#         module_path = splitext(template)[0]

# #         Plan.logger.info(f'Building [red]🡆[/] Object [red]{class_name}[/] '
# #                          f'from [red]{module_path}[/] Module.',
# #                          extra={'markup': True})
# #         Plan.logger.debug(f'Settings [red]🡆[/] {settings}',
# #                           extra={'markup': True})
# #         Plan.logger.debug(f'Override [red]🡆[/] {override}',
# #                           extra={'markup': True})

# #         settings.update(**override)

#         module = import_module(module_path)
#         instance = getattr(module, class_name)(**settings)

#         return instance

In [44]:
# serializer = plan_build('openfl.plugins.interface_serializer.cloudpickle_serializer.CloudpickleSerializer', {})
# framework_adapter = plan_build('openfl.plugins.frameworks_adapters.keras_adapter.FrameworkAdapterPlugin', {})

In [45]:
# framework_adapter.serialization_setup()

In [46]:
# framework_adapter

In [47]:
# serializer.serialize(model, 'custom_loader_ob1.pkl')

In [48]:
# m = serializer.restore_object('custom_loader_ob1.pkl')

In [49]:
# type(model)

In [50]:
# type(m)

In [51]:
# from tensorflow import keras
# import tensorflow as tf
# import cloudpickle

# filename = 'custom_loader_ob1.pkl'

# with open(filename, 'rb') as f:
#     mi = cloudpickle.load(f)

# print(type(mi))
# print(mi)


In [52]:
# mi.model.compile(metrics=['accuracy'])

# print(mi.model.get_weights())

# mi.model.fit(tf.random.uniform((1, 1280)), tf.random.uniform((1,), minval=0, maxval=10, dtype=tf.int64))

# val = mi.model.evaluate(tf.random.uniform((32, 1280)), tf.random.uniform((32,), minval=0, maxval=10, dtype=tf.int64))
# # val = mi.model.evaluate(tf.random.uniform((32, 1280)))

# print(val)

## Define and register FL tasks

In [53]:
# NOTE(review): Progbar is imported but not used by the task definitions in this cell.
from tensorflow.keras.utils import Progbar
from openfl.interface.aggregation_functions import FedSLDAAggregation

# Aggregation function attached to the `train` task via set_aggregation_function.
agg_fn = FedSLDAAggregation()
# Task registry; the FL tasks in this cell are registered on it through its decorators.
TI = TaskInterface()

@TI.register_fl_task(model='model', data_loader='dataset', optimizer='optimizer', device='device')
@TI.set_aggregation_function(agg_fn)
def train(model, dataset, optimizer, device, warmup=False):
    """
    Fit SLDA for one epoch on the local data and report training accuracy.

    A fresh SLDA(1280, 10) is created, loaded with the weights of `model`,
    fitted on `dataset`, and its weights are then written back into `model`.
    `optimizer`, `device` and `warmup` are accepted but not used in the body.

    Returns:
        dict with key 'train_acc' holding whatever `evaluate` returns for the
        training data re-batched to 128.
    """
    local_slda = SLDA(1280, 10)
    local_slda.set_weights(model.get_weights())
    local_slda.compile(metrics=['accuracy'])
    local_slda.fit(dataset, epochs=1)

    # Evaluate on the same data, re-batched to a fixed batch size of 128.
    rebatched = dataset.unbatch().batch(128)
    train_acc = local_slda.evaluate(rebatched)

    # Hand the fitted weights back to the model object supplied by the caller.
    model.set_weights(local_slda.get_weights())
    return {'train_acc': train_acc}


@TI.register_fl_task(model='model', data_loader='dataset', device='device')
def validate(model, dataset, device):
    """
    Evaluate the current weights on the validation data.

    A fresh SLDA(1280, 10) is created and loaded with the weights of `model`
    before evaluation; `model` itself is not modified. `device` is accepted
    but not used in the body.

    Returns:
        dict with key 'validation_accuracy' holding whatever `evaluate`
        returns for `dataset`.
    """
    eval_slda = SLDA(1280, 10)
    eval_slda.set_weights(model.get_weights())
    eval_slda.compile(metrics=['accuracy'])
    val_acc = eval_slda.evaluate(dataset)
    return {'validation_accuracy': val_acc}

## Time to start a federated learning experiment

In [54]:
# Create an experiment in the federation
experiment_name = 'cifar10_experiment'
fl_experiment = FLExperiment(federation=federation, experiment_name=experiment_name)

In [55]:
# The following command zips the workspace and python requirements to be transferred to collaborator nodes
ROUNDS_TO_TRAIN = 10
# Start the experiment with the model interface, registered tasks and data
# interface defined in the earlier cells.
fl_experiment.start(model_provider=MI,
                   task_keeper=TI,
                   data_loader=fed_dataset,
                   rounds_to_train=ROUNDS_TO_TRAIN,
                   opt_treatment='CONTINUE_GLOBAL', )
fl_experiment.stream_metrics()

INFO:tensorflow:Assets written to: ram://23d7aa69-3dfa-4b91-8861-267400bd042f/assets


INFO:tensorflow:Assets written to: ram://23d7aa69-3dfa-4b91-8861-267400bd042f/assets


In [56]:
# Display the variables of the local `model` object.
# NOTE(review): the recorded output matches the output shown before the
# experiment was started, so this local object does not appear to carry the
# federated result. Presumably the trained model has to be fetched from
# `fl_experiment` instead -- confirm against the OpenFL API.
model.weights

[<tf.Variable 'means:0' shape=(10, 1280) dtype=float32, numpy=
 array([[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)>,
 <tf.Variable 'counts:0' shape=(10,) dtype=float32, numpy=array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)>,
 <tf.Variable 'sigma:0' shape=(1280, 1280) dtype=float32, numpy=
 array([[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)>,
 <tf.Variable 'sigma_inv:0' shape=(1280, 1280) dtype=float32, numpy=
 array([[10000.,     0.,     0., ...,     0.,     0.,     0.],
        [    0., 10000.,     0., ...,     0.,     0.,     0.],
        [    0.,     0.,

In [57]:
# Display the local model's weights via get_weights(); the recorded output shows plain arrays.
model.get_weights()

[array([[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]], dtype=float32),
 array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 array([[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]], dtype=float32),
 array([[10000.,     0.,     0., ...,     0.,     0.,     0.],
        [    0., 10000.,     0., ...,     0.,     0.,     0.],
        [    0.,     0., 10000., ...,     0.,     0.,     0.],
        ...,
        [    0.,     0.,     0., ..., 10000.,     0.,     0.],
        [    0.,     0.,     0., ...,     0., 10000.,     0.],
        [    0.,     0.,     0., ...,     0.,     0., 10000.]],
       dtype=fl