# Federated SLDA Tutorial: `tf.data` and Custom Datasets
This tutorial trains an SLDA model with OpenFL's interactive API, feeding data through the `tf.data` API.

In [1]:
# Install TensorFlow if it is not already available. TF 2.7 or newer is recommended.
# %pip install tensorflow==2.8

## Imports

In [2]:
%env TF_FORCE_GPU_ALLOW_GROWTH=true
import tensorflow as tf

import tensorflow_datasets as tfds

# Config/Options
from openfl.cl.config import Decoders
from openfl.cl.config import IMG_AUGMENT_LAYERS

# Model/Loss definitions
from openfl.cl.models.slda import SLDA
from openfl.cl.models import losses
from openfl.cl.models.utils import extract_features

# Dataset handling (synthesize/build/query)
from openfl.cl.lib.dataset.repository import DatasetRepository
from openfl.cl.lib.dataset.utils import as_tuple, decode_example, get_label_distribution
from openfl.cl.lib.dataset.synthesizer import synthesize_by_sharding_over_labels

env: TF_FORCE_GPU_ALLOW_GROWTH=true


2022-11-29 18:15:01.653252: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
  from .autonotebook import tqdm as notebook_tqdm
2022-11-29 18:15:03.195954: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-11-29 18:15:03.723960: W tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:42] Overriding orig_value setting because the TF_FORCE_GPU_ALLOW_GROWTH environment variable is set. Original config value was 0.
2022-11-29 18:15:03.724011: I tensorflow/core/common_runtime/gpu/gpu_device.cc

In [3]:
import tensorflow as tf
# Report which TensorFlow version this kernel is running.
print(f'TensorFlow {tf.__version__}')

TensorFlow 2.9.0


## Experiment Options


In [4]:
# Dataset source: either a public TFDS dataset name, or the path of a local
# TFRecord dataset directory (uncomment the second line to use one).
DATASET = 'cifar10'
# DATASET = '/tmp/repository/vege'

IMG_SIZE = (32, 32)     # spatial size (H, W); the backbone input is (*IMG_SIZE, 3)
BATCH_SIZE = 32         # NOTE(review): not referenced elsewhere in this notebook
SHUFFLE_BUFFER = 16384  # NOTE(review): not referenced elsewhere in this notebook

## Connect to the Federation

Start the `Director` and the `Envoy` nodes before running this cell.

This cell connects this notebook to the Federation.

In [5]:
from openfl.interface.interactive_api.federation import Federation

# Use the same identifier that was used in the signed certificate.
client_id = 'api'
cert_dir = 'cert'  # NOTE(review): not passed to Federation below while tls=False
director_node_fqdn = 'localhost'
director_port = 50051

# Create a Federation handle pointing at the Director (TLS disabled).
federation = Federation(
    client_id=client_id,
    director_node_fqdn=director_node_fqdn,
    director_port=director_port,
    tls=False,
)

## Query Datasets from Shard Registry

In [6]:
# Query the federation's shard registry and display it (one entry per registered shard).
shard_registry = federation.get_shard_registry()
shard_registry

{'ENVS1': {'shard_info': node_info {
    name: "ENVS1"
  }
  shard_description: "CIFAR10 dataset, shard number 1/2.\nSamples [Train/Valid]: [25000/10000]"
  sample_shape: "32"
  sample_shape: "32"
  sample_shape: "3"
  target_shape: "1",
  'is_online': True,
  'is_experiment_running': False,
  'last_updated': '2022-11-29 18:14:31',
  'current_time': '2022-11-29 18:15:04',
  'valid_duration': seconds: 120,
  'experiment_name': 'ExperimentName Mock'},
 'ENVS2': {'shard_info': node_info {
    name: "ENVS2"
  }
  shard_description: "CIFAR10 dataset, shard number 2/2.\nSamples [Train/Valid]: [25000/10000]"
  sample_shape: "32"
  sample_shape: "32"
  sample_shape: "3"
  target_shape: "1",
  'is_online': True,
  'is_experiment_running': False,
  'last_updated': '2022-11-29 18:14:42',
  'current_time': '2022-11-29 18:15:04',
  'valid_duration': seconds: 120,
  'experiment_name': 'ExperimentName Mock'}}

In [7]:
# First, request a dummy shard descriptor (size=10) that holds information about the federated dataset.
dummy_shard_desc = federation.get_dummy_shard_descriptor(size=10)
dummy_shard_dataset = dummy_shard_desc.get_dataset('train')
# Inspect the first (sample, target) pair to confirm the per-example shapes.
sample, target = dummy_shard_dataset[0]
f"Sample shape: {sample.shape}, target shape: {target.shape}"

'Sample shape: (32, 32, 3), target shape: (1,)'

## Describe the FL experiment

In [8]:
from openfl.interface.interactive_api.experiment import TaskInterface
from openfl.interface.interactive_api.experiment import ModelInterface
from openfl.interface.interactive_api.experiment import FLExperiment

### Register dataset

In [9]:
# Load the dataset. DATASET may be either:
#   * a local directory holding a TFRecord dataset repository, or
#   * the name of a publicly-hosted TFDS dataset (e.g. 'cifar10').
# The local branch was previously commented out, so the local-path option offered
# in the options cell silently fell through to tfds.load and failed there.
# NOTE(review): a directory in the working directory whose name equals a TFDS
# dataset name would now be treated as a local repository.
if tf.io.gfile.isdir(DATASET):
    repo = DatasetRepository(data_dir=DATASET)
    builder = repo.get_builder()  # Builds all versions by default
    ds_info = builder.info
    raw_train_ds, raw_test_ds = builder.as_dataset(split=['train', 'test'],
                                                   decoders=Decoders.SIMPLE_DECODER)
else:
    # Load TFDS dataset by name (publicly-hosted on TF)
    (raw_train_ds, raw_test_ds), ds_info = tfds.load(DATASET,
                                                     split=['train', 'test'],
                                                     with_info=True,
                                                     decoders=Decoders.SIMPLE_DECODER)
print('About: ', ds_info)
print('Element Spec: ', raw_train_ds.element_spec)
print('Training samples: ', len(raw_train_ds))
print('Testing samples: ', len(raw_test_ds))

About:  tfds.core.DatasetInfo(
    name='cifar10',
    full_name='cifar10/3.0.2',
    description="""
    The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images.
    """,
    homepage='https://www.cs.toronto.edu/~kriz/cifar.html',
    data_path='/home/sunilach/tensorflow_datasets/cifar10/3.0.2',
    download_size=162.17 MiB,
    dataset_size=132.40 MiB,
    features=FeaturesDict({
        'id': Text(shape=(), dtype=tf.string),
        'image': Image(shape=(32, 32, 3), dtype=tf.uint8),
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10),
    }),
    supervised_keys=('image', 'label'),
    disable_shuffling=False,
    splits={
        'test': <SplitInfo num_examples=10000, num_shards=1>,
        'train': <SplitInfo num_examples=50000, num_shards=1>,
    },
    citation="""@TECHREPORT{Krizhevsky09learningmultiple,
        author = {Alex Krizhevsky},
        title = {Learn

### Define Feature Extractor


In [10]:
# Pretrained EfficientNetV2-B0 without its classification head; global average
# pooling turns each image into a single feature vector.
backbone = tf.keras.applications.EfficientNetV2B0(
    include_top=False,
    weights='imagenet',
    input_shape=(*IMG_SIZE, 3),
    pooling='avg',
)
backbone.trainable = False  # freeze the pretrained weights

# Put an explicit input layer in front of the frozen backbone.
# (No augmentation layers are added here.)
backbone_input_shape = backbone.input_shape[1:]
feature_extractor = tf.keras.Sequential(
    [tf.keras.layers.InputLayer(backbone_input_shape), backbone],
    name='feature_extractor',
)

feature_extractor.summary()

Model: "feature_extractor"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 efficientnetv2-b0 (Function  (None, 1280)             5919312   
 al)                                                             
                                                                 
Total params: 5,919,312
Trainable params: 0
Non-trainable params: 5,919,312
_________________________________________________________________


In [11]:
from openfl.interface.interactive_api.experiment import DataInterface

class CIFAR10FedDataset(DataInterface):
    """Federated data interface that serves splits taken from a shard descriptor.

    No data is loaded at construction time; the loaders and sizes are filled in
    when ``shard_descriptor`` is assigned.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @property
    def shard_descriptor(self):
        """The shard descriptor most recently assigned to this interface."""
        return self._shard_descriptor

    @shard_descriptor.setter
    def shard_descriptor(self, shard_descriptor):
        """
        Describe per-collaborator procedures or sharding.

        This method will be called during a collaborator initialization.
        Local shard_descriptor will be set by Envoy.

        shard_descriptor.get_split(...) returns a tf.data.Dataset
        (check cifar10_shard_descriptor.py for details).
        """
        self._shard_descriptor = shard_descriptor
        split_of = self._shard_descriptor.get_split

        # Loaders handed to the FL tasks.
        self.train_set = split_of('train')
        self.valid_set = split_of('valid')

        # Sizes reported for aggregation.
        # NOTE(review): the validation set comes from 'valid' while its size comes
        # from 'test_size' -- confirm this matches the shard descriptor's keys.
        self.train_size = split_of('train_size')
        self.valid_size = split_of('test_size')

    def get_train_loader(self):
        """Output of this method will be provided to tasks with optimizer in contract"""
        return self.train_set

    def get_valid_loader(self):
        """Output of this method will be provided to tasks without optimizer in contract"""
        return self.valid_set

    def get_train_data_size(self) -> int:
        """Information for aggregation"""
        return self.train_size

    def get_valid_data_size(self) -> int:
        """Information for aggregation"""
        return self.valid_size

### Create CIFAR10 federated dataset

In [12]:
# Instantiate the federated data interface. No shard descriptor is attached here;
# per the shard_descriptor setter's docstring, the Envoy sets it on each collaborator.
fed_dataset = CIFAR10FedDataset()

### Register Model


In [13]:
# Size the SLDA classifier by the backbone's feature width and the dataset's label count.
embedding_dim = feature_extractor.output_shape[-1]
label_count = ds_info.features['label'].num_classes
model = SLDA(n_components=embedding_dim, num_classes=label_count)
model.compile(metrics=['accuracy'])

# Wrap the model for OpenFL via the Keras framework adapter (no optimizer is passed).
framework_adapter = 'openfl.plugins.frameworks_adapters.keras_adapter.FrameworkAdapterPlugin'
MI = ModelInterface(model=model, optimizer=None, framework_plugin=framework_adapter)

BEFORE BUILD SLDA


2022-11-29 18:15:06.360415: I tensorflow/core/util/cuda_solvers.cc:179] Creating GpuSolver handles for stream 0x1caa7620


AFTER BUILD SLDA


2022-11-29 18:15:07.830594: I tensorflow/stream_executor/cuda/cuda_blas.cc:1786] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.


In [14]:
import copy

In [15]:
# Working copy of the SLDA model used by the `train` task: the task copies the
# federated weights into it, fits it, and copies the fitted weights back.
res_model = copy.deepcopy(model)
res_model.compile(metrics=['accuracy'])

INFO:tensorflow:Assets written to: ram://70a00af8-e427-4082-87e7-37b43bf72701/assets


INFO:tensorflow:Assets written to: ram://70a00af8-e427-4082-87e7-37b43bf72701/assets


## Define and register FL tasks

In [16]:
from tensorflow.keras.utils import Progbar

TI = TaskInterface()


@TI.register_fl_task(model='model', data_loader='dataset', optimizer='optimizer', device='device')
def train(model, dataset, optimizer, device, warmup=False):
    """Fit SLDA for one epoch on the local shard and report training accuracy.

    The fit runs on the notebook-level ``res_model`` (a deep copy of the SLDA
    model). ``model``'s weights are copied into it before fitting, and the fitted
    weights are copied back into ``model`` afterwards.

    Args:
        model: federated model; its weights are replaced by the fitted weights.
        dataset: the train loader (``get_train_loader`` output, since this task
            declares an optimizer in its contract).
        optimizer: unused by this task.
        device: unused by this task.
        warmup: unused; kept for signature compatibility.

    Returns:
        dict with key ``'train_acc'`` holding the second value returned by
        ``model.evaluate``.
    """
    # Re-assign the SLDA class on the working copy before fitting.
    # NOTE(review): presumably needed because deepcopy round-trips the model through
    # a SavedModel (see the "Assets written to ram://" logs) -- confirm.
    res_model.__class__ = SLDA
    res_model.set_weights(model.get_weights())
    res_model.fit(dataset, epochs=1)
    model.set_weights(res_model.get_weights())

    # Assumes evaluate() returns [loss, accuracy] given compile(metrics=['accuracy']).
    train_acc = model.evaluate(dataset)
    print("Train accuracy: %.4f" % (float(train_acc[1]),))
    return {'train_acc': train_acc[1],}


@TI.register_fl_task(model='model', data_loader='dataset', device='device')
def validate(model, dataset, device):
    """Evaluate the federated model on the local validation loader.

    Args:
        model: federated model to evaluate.
        dataset: the validation loader (``get_valid_loader`` output, since this
            task declares no optimizer in its contract).
        device: unused by this task.

    Returns:
        dict with key ``'validation_accuracy'`` holding the second value returned
        by ``model.evaluate``.
    """
    results = model.evaluate(dataset)
    accuracy = results[1]
    print("Validation acc: %.4f" % (float(accuracy),))
    return {'validation_accuracy': accuracy,}

## Time to start a federated learning experiment

In [17]:
# Create an experiment in the federation.
experiment_name = 'cifar10_experiment'
fl_experiment = FLExperiment(
    federation=federation,
    experiment_name=experiment_name,
)

In [18]:
# Starting the experiment zips the workspace and Python requirements so they can be
# transferred to the collaborator nodes; stream_metrics() then follows the run.
ROUNDS_TO_TRAIN = 2
fl_experiment.start(
    model_provider=MI,
    task_keeper=TI,
    data_loader=fed_dataset,
    rounds_to_train=ROUNDS_TO_TRAIN,
    opt_treatment='CONTINUE_GLOBAL',
)
fl_experiment.stream_metrics()

SeRIALIZING
model_interface_file
INFO:tensorflow:Assets written to: ram://f35c9273-9edc-4dad-b259-631e63d485d1/assets


INFO:tensorflow:Assets written to: ram://f35c9273-9edc-4dad-b259-631e63d485d1/assets


SeRIALIZING
tasks_interface_file
INFO:tensorflow:Assets written to: ram://d59c7ae0-3f53-4f55-a9b8-f2b8c9902910/assets


INFO:tensorflow:Assets written to: ram://d59c7ae0-3f53-4f55-a9b8-f2b8c9902910/assets


SeRIALIZING
dataloader_interface_file
SeRIALIZING
aggregation_function_interface_file
SeRIALIZING
task_assigner_file


