# FedSLDA: With known class IDs across clients (no reconciliation)
Streaming Linear Discriminant Analysis (SLDA), is a type of generative model that learns a linear classifier over precomputed features from a frozen feature extractor.

SLDA learns a per-class Gaussian distribution with covariance matrix that is shared across all classes. 

Objective in this notebook, is to train multiple SLDA models over synthetically-generated client datasets, average them, and obtain an aggregated model. Rinse, repeat.

### Imports

In [1]:
import sys
sys.path.append('/home/sunilach/openfl/forked-intel-openfl/openfl/cl/')
# sys.path.append('/home/sunilach/openfl/fl_cl_ebm/')

In [2]:
%env TF_FORCE_GPU_ALLOW_GROWTH=true

import tensorflow as tf
import tensorflow_datasets as tfds

# Config/Options
from config import Decoders
from config import IMG_AUGMENT_LAYERS

# Model/Loss definitions
from models.slda import SLDA
from models import losses
from models.utils import extract_features

# Dataset handling (synthesize/build/query)
from lib.dataset.repository import DatasetRepository
from lib.dataset.utils import as_tuple, decode_example, get_label_distribution
from lib.dataset.synthesizer import synthesize_by_sharding_over_labels
from lib.dataset.synthesizer import synthesize_by_dirichlet_over_labels

env: TF_FORCE_GPU_ALLOW_GROWTH=true


2022-12-20 09:36:55.278918: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
  from .autonotebook import tqdm as notebook_tqdm
2022-12-20 09:36:56.997091: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-12-20 09:36:57.563031: W tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:42] Overriding orig_value setting because the TF_FORCE_GPU_ALLOW_GROWTH environment variable is set. Original config value was 0.
2022-12-20 09:36:57.563083: I tensorflow/core/common_runtime/gpu/gpu_device.cc

### Experiment Options

In [3]:
DATASET = 'cifar10'   # If loading a public TensorFlow dataset
# DATASET = '/tmp/repository/vege'  # If loading a local TFRecord dataset

IMG_SIZE = (32, 32)
BATCH_SIZE = 32
SHUFFLE_BUFFER = 16384


### Load the *entire* Dataset
We deal with `tf.data.Dataset` APIs for all our simulations.

The additional argument to note here, is the `decoders`. We supply our custom `Decoders.SIMPLE_DECODER` that partially decodes the data for two main reasons:
1. It only parses `image` and `label` keys from the dataset (we're only dealing with classification problems here).
2. It 'skips' decoding the images to tensors (hence you see it as `tf.string` type). This is for performance reasons. As you'll see, we decode it when we build our data pipeline for training/testing on-the-fly.

In [4]:
"""Load the dataset: Public or Local"""
if tf.io.gfile.isdir(DATASET):
    repo = DatasetRepository(data_dir=DATASET)
    builder = repo.get_builder()  # Builds all versions by default
    ds_info = builder.info
    (raw_train_ds, raw_test_ds) = builder.as_dataset(split=['train', 'test'],
                                                     decoders=Decoders.SIMPLE_DECODER)
else:
    # Load TFDS dataset by name (publicly-hosted on TF)
    (raw_train_ds, raw_test_ds), ds_info = tfds.load(DATASET,
                                                     split=['train', 'test'],
                                                     with_info=True,
                                                     decoders=Decoders.SIMPLE_DECODER)
print('About: ', ds_info)
print('Element Spec: ', raw_train_ds.element_spec)
print('Training samples: ', len(raw_train_ds))
print('Testing samples: ', len(raw_test_ds))

About:  tfds.core.DatasetInfo(
    name='cifar10',
    full_name='cifar10/3.0.2',
    description="""
    The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images.
    """,
    homepage='https://www.cs.toronto.edu/~kriz/cifar.html',
    data_path='/home/sunilach/tensorflow_datasets/cifar10/3.0.2',
    download_size=162.17 MiB,
    dataset_size=132.40 MiB,
    features=FeaturesDict({
        'id': Text(shape=(), dtype=tf.string),
        'image': Image(shape=(32, 32, 3), dtype=tf.uint8),
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10),
    }),
    supervised_keys=('image', 'label'),
    disable_shuffling=False,
    splits={
        'test': <SplitInfo num_examples=10000, num_shards=1>,
        'train': <SplitInfo num_examples=50000, num_shards=1>,
    },
    citation="""@TECHREPORT{Krizhevsky09learningmultiple,
        author = {Alex Krizhevsky},
        title = {Learn

### Feature Extraction
Let's choose a pretrained backbone to extract features. Since in this experiment we keep the backbone frozen and finetune only a few additional layers, it is much faster to iterate if we compute all features of all images at once.

In [5]:
"""Choose Model backbone to extract features"""
backbone = tf.keras.applications.EfficientNetV2B0(
    include_top=False,
    weights='imagenet',
    input_shape=(*IMG_SIZE, 3),
    pooling='avg'
)
backbone.trainable = False

"""Add augmentation/input layers"""
feature_extractor = tf.keras.Sequential([
    tf.keras.layers.InputLayer(backbone.input_shape[1:]),
    IMG_AUGMENT_LAYERS,
    backbone,
], name='feature_extractor')

feature_extractor.summary()

Model: "feature_extractor"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 augment_layers (Sequential)  (None, 32, 32, 3)        0         
                                                                 
 efficientnetv2-b0 (Function  (None, 1280)             5919312   
 al)                                                             
                                                                 
Total params: 5,919,312
Trainable params: 0
Non-trainable params: 5,919,312
_________________________________________________________________


In [6]:
"""Extract train/test feature embeddings"""
print(f'Extracting train set features')
train_features = extract_features(dataset=(raw_train_ds
                                        .map(decode_example(IMG_SIZE))
                                        .map(as_tuple(x='image', y='label'))
                                        .batch(BATCH_SIZE)
                                        .prefetch(tf.data.AUTOTUNE)), model=feature_extractor)
print(f'Extracting test set features')
test_features = extract_features(dataset=(raw_test_ds
                                        .map(decode_example(IMG_SIZE))
                                        .map(as_tuple(x='image', y='label'))
                                        .batch(BATCH_SIZE)
                                        .prefetch(tf.data.AUTOTUNE)), model=feature_extractor)
print('Features Dataset spec: ', train_features.element_spec)

Extracting train set features


2022-12-20 09:37:02.373892: I tensorflow/stream_executor/cuda/cuda_dnn.cc:384] Loaded cuDNN version 8700


Extracting test set features
Features Dataset spec:  {'image': TensorSpec(shape=(1280,), dtype=tf.float32, name=None), 'label': TensorSpec(shape=(), dtype=tf.int64, name=None)}


### Creating a Federated Dataset
Now that we have the extracted features, we would like to partition this entire training set into `n` parts, `n` should be reminiscent of the number of participants.

In [13]:
N_CLIENTS = 7

# This returns a dictionary of partitioned datasets for each client, keyed by client_id, an integer
client_datasets = synthesize_by_sharding_over_labels(train_features, 
                                                         num_partitions=N_CLIENTS, 
                                                         shuffle_labels=True)
# client_datasets = synthesize_by_dirichlet_over_labels(train_features, 
#                                                       num_partitions=N_CLIENTS, 
#                                                       concentration_factor=0.1)

# Check the label counts of each partition
print('Clients:', len(client_datasets))
for client_id in client_datasets:
    dist = get_label_distribution(client_datasets[client_id])
    print(f'Client {client_id}: {dist}')

Clients: 7
Client 0: {4: 5000}
Client 1: {6: 5000}
Client 2: {2: 5000}
Client 3: {7: 5000}
Client 4: {3: 5000}
Client 5: {5: 5000}
Client 6: {9: 5000}


### Define `N_CLIENTS` SLDA Models

In [14]:
# [Zero Mean/Std] Initialize N_CLIENTS candidate models
client_models = {
    i: SLDA(n_components=feature_extractor.output_shape[-1],
             num_classes=ds_info.features['label'].num_classes)
    for i in range(len(client_datasets))
}
print(client_models)

# Compile all client models. No loss/optimizer since it is a gradient-free algorithm
_ = [
    client_models[i].compile(metrics=['accuracy']) 
    for i in range(len(client_datasets))
]

BEFORE BUILD SLDA
AFTER BUILD SLDA
BEFORE BUILD SLDA
AFTER BUILD SLDA
BEFORE BUILD SLDA
AFTER BUILD SLDA
BEFORE BUILD SLDA
AFTER BUILD SLDA
BEFORE BUILD SLDA
AFTER BUILD SLDA
BEFORE BUILD SLDA
AFTER BUILD SLDA
BEFORE BUILD SLDA
AFTER BUILD SLDA
{0: <models.slda.SLDA object at 0x7f7d800d9460>, 1: <models.slda.SLDA object at 0x7f7d607b2250>, 2: <models.slda.SLDA object at 0x7f7d60782430>, 3: <models.slda.SLDA object at 0x7f7d800f16a0>, 4: <models.slda.SLDA object at 0x7f7d800e74f0>, 5: <models.slda.SLDA object at 0x7f7d606fadc0>, 6: <models.slda.SLDA object at 0x7f7d60751460>}


### Federated Helper Functions
For aggregation, broadcast

In [15]:
def aggregate(models: dict):
    """Aggregates the dictionary of SLDA models"""
    # Initialize fresh model
    aggregated_model = SLDA(n_components=feature_extractor.output_shape[-1],
                        num_classes=ds_info.features['label'].num_classes)
    
    # Total datapoints across clients
    total_datapoints = tf.math.add_n([
        models[i].counts
        for i in models
    ])
    print('Total datapoints across all clients: ', total_datapoints)
    
    # Weightage per client
    weightage_per_client = {
        i: tf.expand_dims(models[i].counts / total_datapoints, -1)
        for i in models
    }
    
    print('Weightage per client: ')
    print(weightage_per_client)

    # Mean of mean
    aggregated_model.means.assign(tf.math.add_n([
        weightage_per_client[i] * models[i].means
        for i in models
    ]))
    
    # Add up covariances (TODO: Not sure if this mathematically holds)
    aggregated_model.sigma.assign(tf.math.add_n([
        models[i].sigma
        for i in models
    ]))
    
    return aggregated_model

def broadcast(aggregated_model, client_models):
    """Copies Mean/Sigma from `aggregated_model`
    and resets Counts of SLDA models"""
    for client_id in client_models:
        client_models[client_id].means.assign(aggregated_model.means)
        client_models[client_id].sigma.assign(aggregated_model.sigma)
        client_models[client_id].counts.assign(tf.zeros_like(client_models[client_id].counts))
#         client_models[client_id]._steps.assign(0)  # Necessary?

### Federated Training Loop

In [16]:
# Build test dataset pipeline (common for all)
test_ds = (test_features
            .cache()
            .map(as_tuple(x='image', y='label'))
            .batch(BATCH_SIZE)
            .prefetch(tf.data.AUTOTUNE))

# Train multiple rounds
N_ROUNDS = 5

for r in range(N_ROUNDS):
    print(f'Round {r}')
    
    # Train each client on its respective dataset
    for client_id in client_datasets:
        print(f'Training Client {client_id}')
        client_models[client_id].fit((client_datasets[client_id]
                                    .cache()
                                    .shuffle(SHUFFLE_BUFFER)
                                    .map(as_tuple(x='image', y='label'))
                                    .batch(1)  # SLDA learns 1-sample at a time. Inference can be done on batch.
                                    .prefetch(tf.data.AUTOTUNE)), epochs=1, validation_data=test_ds)
    
    # Aggregate model at end of round
    aggregated_model = aggregate(client_models)
    aggregated_model.compile(metrics=['accuracy'])
    acc = aggregated_model.evaluate(test_ds)
    print('Aggregated model accuracy:', acc)
    
    # Broadcast models to all clients
    broadcast(aggregated_model, client_models)

Round 0
Training Client 0
Training Client 1
Training Client 2
Training Client 3
Training Client 4
Training Client 5
Training Client 6
BEFORE BUILD SLDA
AFTER BUILD SLDA
Total datapoints across all clients:  tf.Tensor([   0.    0. 5000. 5000. 5000. 5000. 5000. 5000.    0. 5000.], shape=(10,), dtype=float32)
Weightage per client: 
{0: <tf.Tensor: shape=(10, 1), dtype=float32, numpy=
array([[nan],
       [nan],
       [ 0.],
       [ 0.],
       [ 1.],
       [ 0.],
       [ 0.],
       [ 0.],
       [nan],
       [ 0.]], dtype=float32)>, 1: <tf.Tensor: shape=(10, 1), dtype=float32, numpy=
array([[nan],
       [nan],
       [ 0.],
       [ 0.],
       [ 0.],
       [ 0.],
       [ 1.],
       [ 0.],
       [nan],
       [ 0.]], dtype=float32)>, 2: <tf.Tensor: shape=(10, 1), dtype=float32, numpy=
array([[nan],
       [nan],
       [ 1.],
       [ 0.],
       [ 0.],
       [ 0.],
       [ 0.],
       [ 0.],
       [nan],
       [ 0.]], dtype=float32)>, 3: <tf.Tensor: shape=(10, 1), dtype=fl

BEFORE BUILD SLDA
AFTER BUILD SLDA
Total datapoints across all clients:  tf.Tensor([   0.    0. 5000. 5000. 5000. 5000. 5000. 5000.    0. 5000.], shape=(10,), dtype=float32)
Weightage per client: 
{0: <tf.Tensor: shape=(10, 1), dtype=float32, numpy=
array([[nan],
       [nan],
       [ 0.],
       [ 0.],
       [ 1.],
       [ 0.],
       [ 0.],
       [ 0.],
       [nan],
       [ 0.]], dtype=float32)>, 1: <tf.Tensor: shape=(10, 1), dtype=float32, numpy=
array([[nan],
       [nan],
       [ 0.],
       [ 0.],
       [ 0.],
       [ 0.],
       [ 1.],
       [ 0.],
       [nan],
       [ 0.]], dtype=float32)>, 2: <tf.Tensor: shape=(10, 1), dtype=float32, numpy=
array([[nan],
       [nan],
       [ 1.],
       [ 0.],
       [ 0.],
       [ 0.],
       [ 0.],
       [ 0.],
       [nan],
       [ 0.]], dtype=float32)>, 3: <tf.Tensor: shape=(10, 1), dtype=float32, numpy=
array([[nan],
       [nan],
       [ 0.],
       [ 0.],
       [ 0.],
       [ 0.],
       [ 0.],
       [ 1.],
       [

### Summary

Regardless of the heterogeneity and number of clients, convergence is one-shot. Why? This needs to be confirmed.