In [1]:
from collections import OrderedDict

import nest_asyncio
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

from ocddetection import models
from ocddetection.data import preprocessing
from ocddetection.federated.learning.impl.personalization.layers import process, utils

In [2]:
# Patch asyncio so event loops can be nested. Jupyter already runs an event loop,
# and TFF's execution runtime presumably drives its own; without this patch, running
# TFF computations inside a notebook can fail with "This event loop is already running".
nest_asyncio.apply()

In [3]:
def __model_fn(window_size: int, hidden_size: int, dropout_rate: float) -> utils.PersonalizationLayersDecorator:
    """Build the personalized bidirectional OCD-detection model and wrap it for TFF.

    Args:
        window_size: Length of each input window (the second axis of both the
            feature and the label tensors in ``input_spec``).
        hidden_size: Forwarded to ``models.personalized_bidirectional``.
        dropout_rate: Forwarded to ``models.personalized_bidirectional``.

    Returns:
        A ``utils.PersonalizationLayersDecorator`` combining the base part, the
        personalized part and a TFF model built from the full Keras model.
    """
    num_sensors = len(preprocessing.SENSORS)
    num_labels = len(preprocessing.LABEL2IDX)

    base, personalized, keras_model = models.personalized_bidirectional(
        window_size, num_sensors, num_labels, hidden_size, dropout_rate
    )

    # Assumes the Keras model emits unnormalized logits (hence from_logits=True) — TODO confirm.
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

    # Features: windows of sensor readings; labels: one integer class index per time step.
    feature_spec = tf.TensorSpec((None, window_size, num_sensors), dtype=tf.float32)
    label_spec = tf.TensorSpec((None, window_size), dtype=tf.int32)

    metrics = [tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')]

    tff_model = tff.learning.from_keras_model(
        keras_model,
        loss=loss,
        input_spec=(feature_spec, label_spec),
        metrics=metrics
    )

    return utils.PersonalizationLayersDecorator(base, personalized, tff_model)

In [4]:
# Small configuration, used here only to inspect the model's weight structure in the next cells.
model = __model_fn(window_size=5, hidden_size=5, dropout_rate=0.4)

In [12]:
# TFF type of the base model's weights, taken from Keras `get_weights()` (a flat list of
# arrays), so trainable and non-trainable weights are NOT separated here (contrast with the next cell).
# NOTE(review): `base_weights_type` is not used in any later visible cell.
base_weights_type = tff.framework.type_from_tensors(model.base_model.get_weights())

In [13]:
# Structured (trainable / non_trainable) TFF type of the base model's weights.
# The output below lists 4 trainable and 2 non-trainable tensors.
tff.framework.type_from_tensors(tff.learning.ModelWeights.from_model(model.base_model))

StructType([('trainable', StructType([TensorType(tf.float32, [77, 5]), TensorType(tf.float32, [5]), TensorType(tf.float32, [5]), TensorType(tf.float32, [5])]) as list), ('non_trainable', StructType([TensorType(tf.float32, [5]), TensorType(tf.float32, [5])]) as list)]) as ModelWeights

In [14]:
# Weights type of the whole wrapped model (base + personalized parts). In the output below,
# the first four trainable entries match the base model's trainable weights shown above.
tff.learning.framework.weights_type_from_model(model)

StructType([('trainable', StructType([TensorType(tf.float32, [77, 5]), TensorType(tf.float32, [5]), TensorType(tf.float32, [5]), TensorType(tf.float32, [5]), TensorType(tf.float32, [5, 20]), TensorType(tf.float32, [5, 20]), TensorType(tf.float32, [20]), TensorType(tf.float32, [5, 20]), TensorType(tf.float32, [5, 20]), TensorType(tf.float32, [20]), TensorType(tf.float32, [10, 5]), TensorType(tf.float32, [5]), TensorType(tf.float32, [5]), TensorType(tf.float32, [5])]) as list), ('non_trainable', StructType([TensorType(tf.float32, [5]), TensorType(tf.float32, [5]), TensorType(tf.float32, [5]), TensorType(tf.float32, [5])]) as list)]) as ModelWeights

In [None]:
def keras_model_fn():
    """Build the small Keras model used by the two-client toy experiment.

    NOTE(review): ``model_fn`` called ``keras_model_fn`` but no visible cell
    defined it, so Restart & Run All failed with ``NameError``. This is a minimal
    stand-in whose shapes match ``model_fn``'s ``input_spec`` (2 float32
    features, one 0/1 label). Its sigmoid output matches ``BinaryCrossentropy``'s
    default ``from_logits=False``. Replace it if the original definition is recovered.

    Returns:
        ``(base, personalized, model)``: the ``base`` sub-model, the
        ``personalized`` head, and the end-to-end Keras ``model`` chaining them.
    """
    inputs = tf.keras.Input(shape=(2,), dtype=tf.float32)
    base = tf.keras.Sequential([tf.keras.layers.Dense(4, activation='relu')])
    personalized = tf.keras.Sequential([tf.keras.layers.Dense(1, activation='sigmoid')])
    model = tf.keras.Model(inputs=inputs, outputs=personalized(base(inputs)))
    return base, personalized, model


def model_fn() -> utils.PersonalizationLayersDecorator:
    """Build the TFF model for the two-client toy experiment.

    Wraps the ``(base, personalized, model)`` triple from ``keras_model_fn`` in
    ``utils.PersonalizationLayersDecorator``. The decorator exposes the
    personalized part separately (e.g. ``personalized_variables``, read below).

    Returns:
        A ``utils.PersonalizationLayersDecorator`` around a TFF model that expects
        float32 features of shape ``(None, 2)`` and int32 labels of shape ``(None, 1)``.
    """
    base, personalized, model = keras_model_fn()

    return utils.PersonalizationLayersDecorator(
        base=base,
        personalized=personalized,
        model=tff.learning.from_keras_model(
            model,
            # Default from_logits=False: the Keras model must output probabilities.
            loss=tf.keras.losses.BinaryCrossentropy(),
            input_spec=(
                tf.TensorSpec((None, 2), dtype=tf.float32),
                tf.TensorSpec((None, 1), dtype=tf.int32)
            ),
            metrics=[tf.keras.metrics.BinaryAccuracy()]
        )
    )

In [None]:
def client_optimizer_fn(learning_rate: float = 0.01) -> tf.keras.optimizers.Optimizer:
    """Create the client optimizer handed to ``iterator.iterator``.

    Args:
        learning_rate: SGD step size. Defaults to 0.01, the previously hard-coded
            value, so calling with no arguments behaves exactly as before.

    Returns:
        A new ``tf.keras.optimizers.SGD`` (default momentum 0.0) on every call.
    """
    return tf.keras.optimizers.SGD(learning_rate=learning_rate)

def server_optimizer_fn(learning_rate: float = 1.0, momentum: float = 0.9) -> tf.keras.optimizers.Optimizer:
    """Create the server optimizer handed to ``iterator.iterator``.

    Presumably applied to the aggregated client update (SGD with server-side
    momentum); confirm against the iterator implementation.

    Args:
        learning_rate: SGD step size. Defaults to the previously hard-coded 1.0.
        momentum: SGD momentum. Defaults to the previously hard-coded 0.9.

    Returns:
        A new ``tf.keras.optimizers.SGD`` on every call.
    """
    return tf.keras.optimizers.SGD(learning_rate=learning_rate, momentum=momentum)

In [None]:
# Client ids in index order: a client's position in this list is its integer index.
client_idx2ids = ["zero", "one"]
# Inverse mapping, derived rather than hand-written so the two cannot drift apart.
client_id2idx = {client_id: idx for idx, client_id in enumerate(client_idx2ids)}

In [None]:
# Fixed seed so the toy client data (and hence the training run) is reproducible.
DATA_SEED = 42


def _make_client_dataset(rng: np.random.Generator, num_examples: int, label: int) -> tf.data.Dataset:
    """Create a toy client dataset in which every example has the same label.

    Args:
        rng: Seeded generator used to draw the random features.
        num_examples: Number of examples (rows) to generate.
        label: Integer label assigned to every example.

    Returns:
        Dataset of ``(features, labels)`` pairs. Features are float32 of shape
        ``(num_examples, 2)`` and labels are int32 of shape ``(num_examples, 1)``,
        batched by 5 and repeated 5 times.
    """
    features = rng.random((num_examples, 2)).astype(np.float32)
    labels = np.full((num_examples, 1), label, dtype=np.int32)
    return tf.data.Dataset.from_tensor_slices((features, labels)).batch(5).repeat(5)


_data_rng = np.random.default_rng(DATA_SEED)
ds = {
    "zero": _make_client_dataset(_data_rng, num_examples=5, label=0),
    "one": _make_client_dataset(_data_rng, num_examples=10, label=1),
}

In [None]:
# Throwaway model instance. In the visible cells it is used only to read the
# initial personalized-layer variables in the next cell.
dummy_model = model_fn()

In [None]:
# Snapshot the personalized variables' initial values as numpy arrays. The cells below
# pass this same list as the starting weights of every client state.
initial_personalisation_weights = list(map(lambda var: var.numpy(), dummy_model.personalized_variables))

In [None]:
# One initial state per client, keyed by id; each client's index is its position in client_idx2ids.
# NOTE(review): no visible cell imports `client`; confirm which module provides `client.State`.
client_states = {
    client_id: client.State(idx, initial_personalisation_weights)
    for idx, client_id in enumerate(client_idx2ids)
}

In [None]:
def client_state_fn():
    """Return a placeholder client state holding the initial personalized weights.

    The index -1 does not correspond to any entry of ``client_idx2ids``; it is
    presumably a sentinel for "no specific client" — confirm against the iterator.
    """
    placeholder_index = -1
    return client.State(placeholder_index, initial_personalisation_weights)

In [None]:
iterator = iterator.iterator(model_fn, client_state_fn, server_optimizer_fn, client_optimizer_fn)

In [None]:
state = iterator.initialize()

In [None]:
state

In [None]:
for r in range(10):
    state, outputs, updated_client_states = iterator.next(
        state,
        [ds[i] for i in client_idx2ids],
        [client_states[i] for i in client_idx2ids]
    )
    
    for client_state in updated_client_states:
        client_id = client_idx2ids[client_state.client_index.numpy()]
        client_states[client_id] = client_state
    
    print('Round: {}'.format(r))
    print(outputs)