In [1]:
from functools import partial

import tensorflow as tf
import tensorflow_federated as tff

from ocddetection import metrics
from ocddetection.learning.federated.stateless import evaluation
from ocddetection.learning.federated.stateless.impl import averaging

In [2]:
window_size = 150
batch_size = 128
hidden_size = 64

In [3]:
val = evaluation.__load_data(
    '/opportunity/augmented/including_original',
    window_size,
    batch_size
)

In [4]:
struct, model_fn = averaging.create(
    window_size,
    hidden_size,
    evaluation.__optimizer_fn,
    evaluation.__metrics_fn
)

In [5]:
ckpt_manager = tff.simulation.FileCheckpointManager('./checkpoints/federated/averaging')

In [6]:
weights = ckpt_manager.load_latest_checkpoint(struct)[0].model

In [7]:
model = model_fn()

In [8]:
weights.assign_weights_to(model)

In [9]:
cm = val.data[val.clients[0]].reduce(
    tf.zeros((2, 2), dtype=tf.int32),
    partial(evaluation.__evaluation_step, model=model)
)

In [10]:
cm

<tf.Tensor: shape=(2, 2), dtype=int32, numpy=
array([[399,   3],
       [ 34,  46]], dtype=int32)>

In [17]:
outputs = model.report_local_outputs()

In [18]:
outputs['auc']

[<tf.Tensor: shape=(200,), dtype=float32, numpy=
 array([80., 71., 69., 67., 67., 66., 65., 61., 61., 61., 61., 60., 60.,
        60., 60., 60., 57., 57., 57., 56., 56., 56., 56., 56., 56., 55.,
        54., 54., 54., 54., 54., 54., 54., 54., 54., 54., 54., 54., 54.,
        54., 54., 54., 54., 54., 54., 54., 54., 54., 54., 53., 53., 53.,
        53., 53., 53., 53., 52., 52., 52., 51., 51., 51., 51., 51., 51.,
        49., 49., 49., 49., 49., 49., 49., 49., 49., 49., 49., 48., 48.,
        48., 48., 48., 48., 48., 48., 48., 47., 47., 47., 47., 47., 47.,
        47., 47., 47., 47., 46., 46., 46., 46., 46., 46., 46., 46., 46.,
        46., 46., 46., 45., 45., 45., 45., 43., 43., 43., 43., 43., 43.,
        42., 42., 42., 42., 41., 41., 41., 41., 40., 40., 40., 40., 40.,
        40., 40., 40., 39., 39., 39., 39., 39., 39., 39., 39., 38., 38.,
        38., 38., 38., 38., 38., 38., 38., 38., 38., 37., 37., 37., 37.,
        37., 37., 37., 36., 36., 36., 35., 35., 35., 35., 34., 33., 33.,
  

In [12]:
metrics = evaluation.__metrics_fn()

In [19]:
metrics[0].result()

<tf.Tensor: shape=(), dtype=float32, numpy=0.0>

In [22]:
for variable, value in zip(metrics[0].variables, outputs['auc']):
    variable.assign(value)

In [23]:
metrics[0].result()

<tf.Tensor: shape=(), dtype=float32, numpy=0.8509942>

In [24]:
for variable, value in zip(metrics[1].variables, outputs['precision']):
    variable.assign(value)

In [25]:
metrics[1].result()

<tf.Tensor: shape=(200,), dtype=float32, numpy=
array([0.16597511, 0.6283186 , 0.71875   , 0.752809  , 0.7882353 ,
       0.825     , 0.84415585, 0.8472222 , 0.8472222 , 0.87142855,
       0.87142855, 0.8695652 , 0.8695652 , 0.8695652 , 0.88235295,
       0.88235295, 0.8769231 , 0.8769231 , 0.8769231 , 0.875     ,
       0.875     , 0.875     , 0.875     , 0.875     , 0.8888889 ,
       0.88709676, 0.8852459 , 0.8852459 , 0.8852459 , 0.8852459 ,
       0.8852459 , 0.8852459 , 0.8852459 , 0.8852459 , 0.8852459 ,
       0.8852459 , 0.8852459 , 0.8852459 , 0.8852459 , 0.8852459 ,
       0.8852459 , 0.8852459 , 0.8852459 , 0.8852459 , 0.8852459 ,
       0.8852459 , 0.9       , 0.9       , 0.9       , 0.89830506,
       0.89830506, 0.89830506, 0.89830506, 0.89830506, 0.9137931 ,
       0.9137931 , 0.9122807 , 0.9122807 , 0.9122807 , 0.91071427,
       0.91071427, 0.91071427, 0.91071427, 0.91071427, 0.91071427,
       0.9074074 , 0.9074074 , 0.9074074 , 0.9074074 , 0.9074074 ,
       0.90740