# Federated Continual Learning - SLDA Tutorial (CIFAR-10)
Using `tf.data` API

In [None]:
# Install TensorFlow if it is not already installed. We recommend TF 2.7 or greater.
# !pip install tensorflow==2.9.0

### Imports

In [None]:
# %env TF_FORCE_GPU_ALLOW_GROWTH=true # Incremental growth of mem utilization on 'cuda' devices
%env CUDA_VISIBLE_DEVICES="-1" # Enforce 'cpu' platform execution

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from typing import Callable, Optional, Any


In [None]:
import tensorflow as tf
# Report the TensorFlow version this notebook is running against.
print('TensorFlow', tf.__version__)

### Define Decoders class and Image Augment Layers

In [None]:
"""Configure data extraction/loading primitives"""

class Decoders:
    """Namespace of TFDS decoding presets.

    SIMPLE_DECODER
        Image/label-only decoder intended to be performant and
        memory-efficient, so that as much data as possible can be cached in
        memory.

        - Only the `image` and `label` features are loaded from the TFRecord
          dataset.
        - Decoding of `image` is skipped, i.e. it is loaded as a raw string.
        - The caller must decode `image` into a `tf.Tensor` afterwards, e.g.
          inside a `dataset.map()` function.
    """
    SIMPLE_DECODER = tfds.decode.PartialDecoding(
        dict(image=True, label=True),
        decoders=dict(image=tfds.decode.SkipDecoding()),
    )

# Configure data augmentation while training: a single random horizontal flip.
IMG_AUGMENT_LAYERS = tf.keras.Sequential(name='augment_layers')
IMG_AUGMENT_LAYERS.add(tf.keras.layers.RandomFlip('horizontal'))

### SLDA Model Definition

In [None]:
class SLDA(tf.keras.Model):
    """
    Lifelong machine learning with deep streaming Linear Discriminant Analysis.
    T. L. Hayes et al. (CVPR'20): https://arxiv.org/abs/1909.01520

    LDA, at its simplest, is a linear model that learns a simple n-dim multivariate
    Gaussian for each class, with a single covariance matrix shared by all classes.

    It can be interpreted as a Generative Model over **features** rather than high
    dimensional input images. Features can be extracted using a backbone
    pretrained on ImageNet, for instance.

    State kept as non-trainable weights (in creation order): `means`, `counts`,
    `sigma`, `sigma_inv`, `steps`, plus an unnamed flag variable that marks the
    cached inverse covariance as stale.

    Example usage:
    ```python
    # Extract features from images using a backbone of your choice
    train_features = backbone_model(train_images)
    test_features = backbone_model(test_images)

    # LDA is a linear model over features as input, mapping to `num_classes` output classes.
    # Each feature could be `n_components` dimensional.
    slda_model = SLDA(n_components=512, num_classes=10)
    slda_model.compile(metrics=['accuracy'])

    # SLDA updates mean/covariance 1 sample at a time during training
    slda_model.fit(train_features.batch(1))

    # Evaluate over test dataset (any batch size works)
    slda_model.evaluate(test_features.batch(64))
    ```
    """

    def __init__(self, n_components: int, num_classes: int, shrinkage: float = 1e-4):
        """Instantiate an LDA model.

        Args:
            n_components (int):
              Input 1D feature dimension size.
            num_classes (int):
              Total output classes.
            shrinkage (float, optional):
              Shrinkage regularization factor. Defaults to 1e-4.
        """
        super().__init__()
        self.n_components = n_components
        self.num_classes = num_classes
        self.shrinkage = shrinkage

        # Parameters (all non-trainable: they are updated in closed form by
        # `train_step`, never by gradient descent).
        self.means = self.add_weight(
            shape=(self.num_classes, self.n_components),
            initializer="zeros",
            trainable=False,
            name="means",
        )
        self.counts = self.add_weight(
            shape=(self.num_classes,),
            initializer="zeros",
            trainable=False,
            name="counts",
        )
        self.sigma = self.add_weight(
            shape=(self.n_components, self.n_components),
            initializer="zeros",
            trainable=False,
            name="sigma",
        )
        # Same static shape as `sigma`; spelled out as a tuple instead of
        # passing a `tf.shape` tensor as the weight shape.
        self.sigma_inv = self.add_weight(
            shape=(self.n_components, self.n_components),
            initializer="zeros",
            trainable=False,
            name="sigma_inv",
        )
        self._steps = self.add_weight(
            shape=(), initializer="zeros", trainable=False, name="steps"
        )
        # 1.0 means `sigma_inv` is stale and must be recomputed before inference.
        self._require_update = tf.Variable(initial_value=1.0, trainable=False)

        # Build by call
        self(tf.random.uniform((1, self.n_components)))

    def fit(self, X, **kwargs):
        """Train on a `tf.data.Dataset` batched with batch size 1.

        Args:
            X: Training data. Only `tf.data.Dataset` inputs are trained on; the
              first element is peeked to verify the batch size.
            **kwargs: Forwarded unchanged to `tf.keras.Model.fit`.

        Returns:
            The `History` object returned by `tf.keras.Model.fit`, or `None`
            when `X` is not a `tf.data.Dataset` (nothing is trained then).

        Raises:
            ValueError: If the first batch holds more than one sample.
        """
        if isinstance(X, tf.data.Dataset):
            (x, _) = next(iter(X))
            if x.shape[0] > 1:
                raise ValueError(
                    "batch>1 for training dataset is not supported (expected batch=1)"
                )
            return super().fit(X, **kwargs)

        # Non-dataset inputs were silently ignored before; keep the no-op (the
        # batch=1 contract cannot be checked here) but make it visible.
        import warnings

        warnings.warn(
            "SLDA.fit expects a `tf.data.Dataset`; received "
            f"{type(X).__name__}, so no training was performed.",
            stacklevel=2,
        )
        return None

    def train_step(self, data):
        """Update mean/covariance for the given (x,y) pair"""
        # Unpack
        # assumes x is (1, n_components) and y is (1,) -- `fit` enforces batch=1
        x, y = data

        # Calculate scatter: deviation from the current class mean and its
        # outer product, down-weighted by steps / (steps + 1).
        x_minus_mu = x - tf.gather(self.means, y)
        scatter = tf.matmul(tf.transpose(x_minus_mu, [1, 0]), x_minus_mu)
        delta = scatter * tf.cast(self._steps / (self._steps + 1), tf.float32)

        # Update means, counts, sigma
        # sigma <- (steps * sigma + delta) / (steps + 1): running average.
        self.sigma.assign(
            (tf.cast(self._steps, tf.float32) * self.sigma + delta)
            / tf.cast(self._steps + 1, tf.float32)
        )
        # mean[y] <- mean[y] + (x - mean[y]) / (count[y] + 1)
        self.means.assign(
            tf.tensor_scatter_nd_add(
                self.means, [y], x_minus_mu / (tf.gather(self.counts, y) + 1)
            )
        )
        self.counts.assign(tf.tensor_scatter_nd_add(self.counts, [y], [1]))
        # The covariance changed, so the cached inverse is now stale.
        self._require_update.assign(1.0)
        self._steps.assign_add(1)

        # No metrics are reported during training.
        history = dict()
        return history

    def test_step(self, data):
        """Run inference on a batch and report the compiled metrics."""
        x, y = data
        y_pred = self(x)
        self.compiled_metrics.update_state(y, y_pred)
        history = {m.name: m.result() for m in self.metrics}
        return history

    @tf.function
    def update_inv(self):
        """Update inverse of regularized covariance"""
        # Shrinkage blends sigma with the identity before inverting.
        reg_sigma = (1 - self.shrinkage) * self.sigma + self.shrinkage * tf.eye(
            tf.shape(self.sigma)[0]
        )
        self.sigma_inv.assign(tf.linalg.pinv(reg_sigma))
        self._require_update.assign(0.0)

    @tf.function
    def forward(self, x):
        """Compute class logits `x @ W + b` from the cached inverse covariance."""
        # Forward pass
        m_T = tf.transpose(self.means, [1, 0])
        W = tf.matmul(self.sigma_inv, m_T)
        b = -0.5 * tf.reduce_sum(m_T * W, axis=0)
        logits = tf.matmul(x, W) + b
        return logits

    def call(self, x):
        """Inference step"""
        # Recompute the inverse covariance only when training has marked it stale.
        tf.cond(
            tf.cast(self._require_update, tf.bool),
            true_fn=lambda: self.update_inv(),
            false_fn=lambda: tf.no_op(),
        )
        return self.forward(x)


### Extract Features

In [None]:
def extract_features(dataset: tf.data.Dataset, model: Any) -> tf.data.Dataset:
    """Compute a feature embedding for every image in a labelled dataset.

    Args:
        dataset (tf.data.Dataset):
          A batched `tf.data.Dataset` whose elements are `(image, label)` tuples.
          It is iterated twice (once for prediction, once for the labels), so
          it presumably has to yield elements in the same order on both passes
          for features and labels to stay aligned -- TODO confirm for shuffled
          inputs.
        model (Any):
          A `tf.keras.Sequential`, `tf.keras.Model` or equivalent object whose
          `predict` method accepts the dataset and returns the feature
          embeddings.

    Returns:
        A `tf.data.Dataset` with one element per sample, each a dict holding the
        feature embedding under `"image"` and the label under `"label"`.
    """
    embeddings = model.predict(dataset, verbose=1)

    def _keep_label(_image, label):
        return label

    label_stream = dataset.map(_keep_label).unbatch()
    label_array = np.array(list(label_stream.as_numpy_iterator()))
    return tf.data.Dataset.from_tensor_slices({"image": embeddings, "label": label_array})

### Experiment Options

In [None]:
# Name handed to `tfds.load` when loading a public TensorFlow dataset.
DATASET = 'cifar10'

# Spatial size (height, width) used for the backbone's `input_shape`.
IMG_SIZE = (32, 32)
# NOTE(review): BATCH_SIZE and SHUFFLE_BUFFER are not referenced by the cells
# visible in this notebook; presumably they mirror the shard descriptor's
# data-pipeline settings -- confirm before changing.
BATCH_SIZE = 32
SHUFFLE_BUFFER = 16384

## Connect to the Federation

Start `Director` and `Envoy` before proceeding with this cell. 

This cell connects this notebook to the Federation.

In [None]:
from openfl.interface.interactive_api.federation import Federation

# Use the same identifier that was used in the signed certificate.
client_id = 'api'
cert_dir = 'cert'
director_node_fqdn = 'localhost'
director_port = 50051

# Create a Federation (TLS is disabled here, so `cert_dir` is not passed).
federation = Federation(
    client_id=client_id,
    director_node_fqdn=director_node_fqdn,
    director_port=director_port, 
    tls=False
)

## Query Datasets from Shard Registry

In [None]:
# Fetch the shard registry from the Federation and display it as the cell output.
shard_registry = federation.get_shard_registry()
shard_registry

In [None]:
# First, request a dummy shard descriptor that holds information about the federated dataset.
dummy_shard_desc = federation.get_dummy_shard_descriptor(size=10)
dummy_shard_dataset = dummy_shard_desc.get_dataset('train')
# Inspect the first sample to learn the input/target shapes.
sample, target = dummy_shard_dataset[0]
f"Sample shape: {sample.shape}, target shape: {target.shape}"

### Describing FL experiment

In [None]:
from openfl.interface.interactive_api.experiment import TaskInterface
from openfl.interface.interactive_api.experiment import ModelInterface
from openfl.interface.interactive_api.experiment import FLExperiment

### Register dataset

In [None]:
# Load the TFDS dataset by name (publicly hosted by TensorFlow Datasets).
# `Decoders.SIMPLE_DECODER` keeps only `image`/`label` and leaves images undecoded.
(raw_train_ds, raw_test_ds), ds_info = tfds.load(DATASET,
                                                 split=['train', 'test'],
                                                 with_info=True,
                                                 decoders=Decoders.SIMPLE_DECODER)
# Summarize what was loaded; `ds_info` is reused later for the number of classes.
print('About: ', ds_info)
print('Element Spec: ', raw_train_ds.element_spec)
print('Training samples: ', len(raw_train_ds))
print('Testing samples: ', len(raw_test_ds))

### Define backbone & feature extractor

In [None]:
# Frozen, ImageNet-pretrained EfficientNetV2-B0 without its classification head;
# global average pooling makes it emit one feature vector per image.
backbone = tf.keras.applications.EfficientNetV2B0(
    include_top=False,
    weights='imagenet',
    input_shape=(*IMG_SIZE, 3),
    pooling='avg',
)
backbone.trainable = False

# Wrap the backbone behind an explicit input layer.
# NOTE: no augmentation layers are added to this model.
feature_extractor = tf.keras.Sequential(
    [
        tf.keras.layers.InputLayer(backbone.input_shape[1:]),
        backbone,
    ],
    name='feature_extractor',
)

feature_extractor.summary()

### Shard Descriptor

In [None]:
from openfl.interface.interactive_api.experiment import TaskInterface
from openfl.interface.interactive_api.experiment import ModelInterface
from openfl.interface.interactive_api.experiment import FLExperiment
from openfl.interface.interactive_api.experiment import DataInterface

In [None]:
class CIFAR10FedDataset(DataInterface):
    """Data interface exposing a shard descriptor's splits to the FL tasks."""

    def __init__(self, **kwargs):
        """Forward all keyword arguments to `DataInterface`."""
        super().__init__(**kwargs)

    @property
    def shard_descriptor(self):
        """Return the shard descriptor most recently assigned to this dataset."""
        return self._shard_descriptor

    @shard_descriptor.setter
    def shard_descriptor(self, shard_descriptor):
        """
        Describe per-collaborator procedures or sharding.

        This method will be called during a collaborator initialization.
        Local shard_descriptor will be set by Envoy.

        Besides storing the descriptor, it caches the train/validation splits
        and their sizes so that the loader and size getters can return them.
        """
        self._shard_descriptor = shard_descriptor
        
        # shard_descriptor.get_split(...) returns a tf.data.Dataset
        # Check cifar10_shard_descriptor.py for details
        
        self.train_set = self._shard_descriptor.get_split('train')
        self.valid_set = self._shard_descriptor.get_split('valid')
        # NOTE(review): the size lookups reuse `get_split`; presumably the
        # '*_size' keys return sample counts rather than datasets -- confirm
        # in cifar10_shard_descriptor.py.
        self.train_size = self._shard_descriptor.get_split('train_size')
        # NOTE(review): the validation split is keyed 'valid' but its size is
        # keyed 'test_size' -- confirm this key against the shard descriptor.
        self.valid_size = self._shard_descriptor.get_split('test_size')
        
    def get_train_loader(self):
        """Output of this method will be provided to tasks with optimizer in contract"""
        return self.train_set

    def get_valid_loader(self):
        """Output of this method will be provided to tasks without optimizer in contract"""
        return self.valid_set
    
    def get_train_data_size(self) -> int:
        """Information for aggregation"""
        return self.train_size

    def get_valid_data_size(self) -> int:
        """Information for aggregation"""
        return self.valid_size

In [None]:
# Data interface handed to the experiment; its shard descriptor is assigned later on each collaborator.
fed_dataset = CIFAR10FedDataset()

### Register the SLDA model with the Model Interface

In [None]:
# The SLDA input width is the feature extractor's output width; the class
# count comes from the dataset metadata.
model = SLDA(n_components=feature_extractor.output_shape[-1],
             num_classes=ds_info.features['label'].num_classes)

# Only metrics are compiled: SLDA updates its statistics in `train_step`
# without a loss or optimizer.
model.compile(metrics=['accuracy'])

# Create ModelInterface
framework_adapter = 'openfl.plugins.frameworks_adapters.keras_adapter.FrameworkAdapterPlugin'
# No optimizer is registered (`optimizer=None`).
MI = ModelInterface(model=model, optimizer=None, framework_plugin=framework_adapter)

### Define Aggregation Algorithm

In [None]:
"""Federated SLDA Model Aggregation module."""

from openfl.interface.aggregation_functions.core import AggregationFunction

class FedSLDAAggregation(AggregationFunction):
    """Aggregation rule for the weights of a federated SLDA model."""

    def call(
        self, local_tensors, db_iterator, tensor_name, fl_round, tags
    ) -> np.ndarray:
        """Combine the collaborators' copies of one tensor.

        The rule depends on the tensor name:
          - names containing "means:0": weighted sum of the local tensors,
            each scaled by its `weight`;
          - names containing "sigma:0": plain (unweighted) sum;
          - anything else: element-wise mean.

        Args:
            local_tensors(list[openfl.utilities.LocalTensor]): List of local tensors to aggregate.
            db_iterator: iterator over history of all tensors (unused here). Columns:
                - 'tensor_name': name of the tensor.
                    Examples for `torch.nn.Module`s: 'conv1.weight', 'fc2.bias'.
                - 'round': 0-based number of round corresponding to this tensor.
                - 'tags': tuple of tensor tags. Tags that can appear:
                    - 'model' indicates that the tensor is a model parameter.
                    - 'trained' indicates that tensor is a part of a training result.
                        These tensors are passed to the aggregator node after local learning.
                    - 'aggregated' indicates that tensor is a result of aggregation.
                        These tensors are sent to collaborators for the next round.
                    - 'delta' indicates that value is a difference between rounds
                        for a specific tensor.
                    One of the tags is also the collaborator name
                    when the tensor is the result of a local task.
                - 'nparray': value of the tensor.
            tensor_name: name of the tensor
            fl_round: round number (unused here)
            tags: tuple of tags for this tensor (unused here)

        Returns:
            np.ndarray: aggregated tensor
        """
        # Mean of means: each collaborator's class means scaled by its weight.
        if "means:0" in tensor_name:
            scaled = [update.tensor * update.weight for update in local_tensors]
            return np.array(scaled).sum(axis=0)

        # Covariances are added up across collaborators.
        if "sigma:0" in tensor_name:
            return np.array([update.tensor for update in local_tensors]).sum(axis=0)

        # Open question kept from the original author: 'count', 'variables' and
        # 'step' do not really need aggregating, yet they are part of the model
        # update, so some consolidation has to be returned. Should they be
        # excluded from the update instead? Averaging them, as done below, is
        # acknowledged to not make sense.
        stacked = np.array([update.tensor for update in local_tensors])
        return stacked.mean(axis=0)


## Define and register FL tasks

In [None]:
from tensorflow.keras.utils import Progbar

# Custom aggregation rule and the registry that collects the FL tasks below.
agg_fn = FedSLDAAggregation()
TI = TaskInterface()

# Extra keyword arguments injected into every task so that each task can
# rebuild an SLDA model of the right size.
task_params = {
        'n_components': feature_extractor.output_shape[-1],
        'n_classes': ds_info.features['label'].num_classes
    }

@TI.add_kwargs(**task_params)
@TI.register_fl_task(model='model', data_loader='dataset', optimizer='optimizer', device='device')
@TI.set_aggregation_function(agg_fn)
def train(model, dataset, optimizer, device, n_components, n_classes, warmup=False):
    """Run one local SLDA training pass and report the training accuracy.

    A fresh `SLDA` instance is seeded with the incoming model's weights, fitted
    for a single epoch on `dataset`, evaluated on the same data re-batched to
    128 samples, and its weights are finally copied back into `model`.

    Args:
        model: Model whose weights are read at the start and overwritten at the end.
        dataset: Training `tf.data.Dataset`; it is passed to `SLDA.fit`, which
          only accepts batch size 1.
        optimizer: Unused; present because the task is registered with an optimizer.
        device: Unused.
        n_components: Feature dimension used to build the local SLDA model.
        n_classes: Number of classes used to build the local SLDA model.
        warmup: Unused.

    Returns:
        dict: `{'train_acc': <result of evaluate>}`.
    """
    local_model = SLDA(n_components, n_classes)
    local_model.set_weights(model.get_weights())
    local_model.compile(metrics=['accuracy'])
    local_model.fit(dataset, epochs=1)

    # Evaluation has no batch-size restriction, so use larger batches.
    eval_batches = dataset.unbatch().batch(128)
    metrics = {'train_acc': local_model.evaluate(eval_batches)}

    # Hand the updated statistics back to the caller's model before returning.
    model.set_weights(local_model.get_weights())
    return metrics


@TI.add_kwargs(**task_params)
@TI.register_fl_task(model='model', data_loader='dataset', device='device')
@TI.set_aggregation_function(agg_fn)
def validate(model, dataset, device, n_components, n_classes):
    """Evaluate the current weights on the validation data.

    A fresh `SLDA` instance is seeded with the incoming model's weights and
    evaluated on `dataset`; `model` itself is left untouched.

    Args:
        model: Model whose weights are evaluated.
        dataset: Validation `tf.data.Dataset`.
        device: Unused.
        n_components: Feature dimension used to build the local SLDA model.
        n_classes: Number of classes used to build the local SLDA model.

    Returns:
        dict: `{'validation_accuracy': <result of evaluate>}`.
    """
    eval_model = SLDA(n_components, n_classes)
    eval_model.set_weights(model.get_weights())
    eval_model.compile(metrics=['accuracy'])

    return {'validation_accuracy': eval_model.evaluate(dataset)}

## Time to start a federated learning experiment

In [None]:
# Create an experiment in the federation.
experiment_name = 'cifar10_experiment'
fl_experiment = FLExperiment(federation=federation, experiment_name=experiment_name)

In [None]:
# The following command zips the workspace and Python requirements to be transferred to collaborator nodes.
ROUNDS_TO_TRAIN = 5
fl_experiment.start(model_provider=MI,
                   task_keeper=TI,
                   data_loader=fed_dataset,
                   rounds_to_train=ROUNDS_TO_TRAIN,
                   opt_treatment='CONTINUE_GLOBAL', )


In [None]:
# Stream the metrics reported by the running experiment into the notebook.
fl_experiment.stream_metrics()