In [1]:
import os
from collections import OrderedDict
from functools import partial

import attr
import nest_asyncio
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

from ocddetection import data, federated
from ocddetection.federated import client, server, models

In [2]:
# Allow nested asyncio event loops: the Jupyter kernel already runs a loop,
# and without this patch TFF's (asyncio-based) execution can fail with
# "This event loop is already running".
nest_asyncio.apply()

In [3]:
# Lookup table from the keys of data.LABEL2IDX to their class indices; any key
# not present in the mapping resolves to the default value 0.
# NOTE(review): confirm 0 is the intended class for unknown labels — if 0 is
# also a real class index, unknown labels are silently merged into it.
table = tf.lookup.StaticHashTable(
    initializer=tf.lookup.KeyValueTensorInitializer(
        keys=list(data.LABEL2IDX.keys()),
        values=list(data.LABEL2IDX.values()),
    ),
    default_value=0,
)

In [4]:
# Root of the OPPORTUNITY (UCI) dataset. Set the OPPORTUNITY_DATA_DIR
# environment variable to point elsewhere instead of editing the notebook;
# the default is the original hardcoded location.
DATA_DIR = os.environ.get('OPPORTUNITY_DATA_DIR', '/mnt/dsets/OpportunityUCIDataset/dataset')

# Split the recordings into train / validation / (presumably) test sets; the
# third return value is not used in this notebook.
# NOTE(review): the tuples look like (subject, run) identifiers — confirm
# against data.split.
df_train, df_val, _ = data.split(
    data.files(DATA_DIR),
    validation=[(1, 2)],
    test=[(2, 3), (2, 4), (3, 3), (3, 4)]
)

In [5]:
# Windowing / batching parameters applied to every client dataset.
WINDOW_SIZE = 30   # time steps per example; must match the 30-step input_spec in model_fn
WINDOW_SHIFT = 15  # stride between window starts (consecutive windows overlap by half)
BATCH_SIZE = 64


def make_client_dataset(dataset):
    """Turn one client's dataset into batches of fixed-length windows.

    Each element is passed through data.preprocess (given the sensor list,
    the data.MID_LEVEL label and the label lookup `table`). Elements rejected
    by data.filter_nan are dropped. The stream is cut into overlapping windows
    of WINDOW_SIZE steps, which are then batched.

    Args:
        dataset: one per-client tf.data.Dataset as returned by data.to_federated.

    Returns:
        A tf.data.Dataset of batches of up to BATCH_SIZE windowed examples
        (the final batch may be smaller).
    """
    return (
        dataset
        .map(partial(data.preprocess, sensors=data.SENSORS, label=data.MID_LEVEL, table=table))
        .filter(data.filter_nan)
        # Dataset.window keeps trailing windows shorter than WINDOW_SIZE by
        # default. drop_remainder=True discards them so every example matches
        # the fixed-length input_spec declared for the model.
        # NOTE(review): NaN rows are removed *before* windowing, so a window may
        # span a gap where rows were dropped — confirm that is acceptable.
        .window(WINDOW_SIZE, shift=WINDOW_SHIFT, drop_remainder=True)
        .flat_map(partial(data.windows, window_size=WINDOW_SIZE))
        .batch(BATCH_SIZE)
    )


train_clients, train_dict = data.to_federated(df_train)
train = {idx: make_client_dataset(dataset) for idx, dataset in train_dict.items()}

In [6]:
def model_fn():
    """Build a new TFF model wrapping the bidirectional Keras classifier.

    Every call constructs fresh Keras objects (model, loss, metric), which is
    what TFF expects of a model factory. This function is handed to
    federated.iterator below, presumably to be forwarded to TFF.

    The declared input is a (features, labels) pair:
      - x: float32, shape (batch, 30, len(data.SENSORS))
      - y: int32, shape (batch, 30), i.e. one label per time step.
    The loss is computed from logits (from_logits=True).
    """
    window_size = 30  # must match the window length used to build `train`
    num_features = len(data.SENSORS)
    # NOTE(review): positional args presumably are
    # (window, features, classes, hidden units, dropout) — confirm against
    # ocddetection.federated.models.bidirectional.
    keras_model = federated.models.bidirectional(
        window_size, num_features, len(data.LABEL2IDX), 64, .4
    )
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    input_spec = (
        tf.TensorSpec((None, window_size, num_features), dtype=tf.float32, name='x'),
        tf.TensorSpec((None, window_size), dtype=tf.int32, name='y'),
    )
    metrics = [tf.keras.metrics.SparseCategoricalAccuracy()]
    return tff.learning.from_keras_model(
        keras_model,
        loss=loss,
        input_spec=input_spec,
        metrics=metrics,
    )

def server_optimizer_fn():
    """Return a new server-side optimizer: plain SGD with learning rate 1.0.

    A fresh optimizer is created on every call. A server learning rate of 1.0
    is presumably the standard FedAvg setting — confirm in federated.iterator.
    """
    return tf.keras.optimizers.SGD(learning_rate=1.0)

def client_optimizer_fn():
    """Return a new client-side optimizer: Adam with learning rate 0.01.

    A fresh optimizer is created on every call.
    """
    return tf.keras.optimizers.Adam(learning_rate=0.01)

In [7]:
# Build the federated training process from the model / optimizer factories
# defined above (presumably a wrapper around TFF's federated-averaging builder;
# see ocddetection.federated.iterator).
iterator = federated.iterator(model_fn, server_optimizer_fn, client_optimizer_fn)



In [8]:
# Initial server state of the training process.
state = iterator.initialize()

In [10]:
# Run a single iterator.next step (one training round) over the datasets of
# all training clients. `state` is not reassigned, so re-running this cell
# repeats the same first round rather than continuing training.
next_state, metrics = iterator.next(state, list(train.values()))

In [11]:
# Training metrics reported by the round above (accuracy and loss).
metrics

OrderedDict([('sparse_categorical_accuracy', 0.71770966), ('loss', 2.0949936)])