In [1]:
from functools import partial

import numpy as np
import tensorflow as tf

from ocddetection.data import preprocessing

In [61]:
ds = tf.data.Dataset.from_tensor_slices(tf.random.uniform((100, 6), 0, 1, tf.float32))

In [56]:
def preprocess(dataset: tf.data.Dataset, window_size: int, batch_size):
    def split(t: tf.Tensor):
        return t[:-1], t[-1]
    
    def flatten(X: tf.data.Dataset, y: tf.data.Dataset):
        X = X.batch(window_size, drop_remainder=True)
        y = y.batch(window_size, drop_remainder=True)

        return tf.data.Dataset.zip((X, y))
    
    def label(X: tf.Tensor, y: tf.Tensor):
        return X, tf.expand_dims(tf.round(tf.math.reduce_mean(y)), axis=-1)
    
#     return dataset.map(split).window(window_size, shift=window_size // 2).flat_map(flatten).map(label)
    return dataset.map(split)

In [57]:
windows = preprocess(ds, 5, 8)

In [58]:
it = list(windows.as_numpy_iterator())

In [59]:
it[0]

(array([[8.83333921e-01, 4.47999239e-01, 9.87823129e-01, 7.49977231e-01,
         2.77589202e-01, 7.20413923e-02],
        [7.19629884e-01, 8.67615819e-01, 9.26806808e-01, 7.01342225e-01,
         4.34023261e-01, 6.80663466e-01],
        [3.55891705e-01, 1.44241691e-01, 1.67160392e-01, 7.67767668e-01,
         3.43821406e-01, 5.62742591e-01],
        [3.60177159e-01, 6.32676125e-01, 1.84718251e-01, 8.00537109e-01,
         9.27394629e-01, 6.95180655e-01],
        [3.19748521e-01, 4.93983030e-01, 6.05038404e-02, 7.28910565e-01,
         2.68368721e-02, 6.33050323e-01],
        [7.04159141e-01, 9.34039950e-01, 9.19109106e-01, 1.81986332e-01,
         7.25145936e-01, 4.97978449e-01],
        [8.19444656e-04, 7.51851916e-01, 3.21280479e-01, 4.33346272e-01,
         4.44229960e-01, 3.17025542e-01],
        [3.35283279e-02, 2.81912208e-01, 1.48193598e-01, 7.36645222e-01,
         4.84742880e-01, 8.60685110e-01],
        [7.34435081e-01, 6.43556237e-01, 9.36274767e-01, 5.03676057e-01,
       

In [79]:
model = tf.keras.Sequential([
    tf.keras.layers.Input((3, 5)),
    tf.keras.layers.GRU(8),
    tf.keras.layers.Dense(1)
])

In [82]:
model.compile(
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
    optimizer=tf.keras.optimizers.SGD(0.01)
)

In [83]:
model.evaluate(windows)



0.6971033811569214

In [84]:
model.predict(windows)

array([[0.04385411],
       [0.08331823],
       [0.12165804],
       [0.09626167],
       [0.09225644],
       [0.16774105],
       [0.08547177],
       [0.07701292],
       [0.08184889],
       [0.10375215],
       [0.19606964],
       [0.13753055],
       [0.1033686 ],
       [0.07252976],
       [0.0239902 ],
       [0.1220748 ],
       [0.05776018],
       [0.15258412],
       [0.19321215],
       [0.09743381]], dtype=float32)

In [86]:
list(iter(windows))

[(<tf.Tensor: shape=(5, 3, 5), dtype=float32, numpy=
  array([[[0.9711813 , 0.9702368 , 0.9819026 , 0.44555366, 0.53736794],
          [0.00175786, 0.42618668, 0.7764367 , 0.37295341, 0.30762243],
          [0.11227882, 0.972563  , 0.28996825, 0.7783464 , 0.15424109]],
  
         [[0.11227882, 0.972563  , 0.28996825, 0.7783464 , 0.15424109],
          [0.0236249 , 0.09400368, 0.6970799 , 0.8281772 , 0.6299257 ],
          [0.01147389, 0.08442509, 0.85933065, 0.12593842, 0.8361325 ]],
  
         [[0.01147389, 0.08442509, 0.85933065, 0.12593842, 0.8361325 ],
          [0.67315245, 0.5122876 , 0.2191062 , 0.8280934 , 0.6038486 ],
          [0.31377542, 0.3300203 , 0.3582393 , 0.16810751, 0.4082694 ]],
  
         [[0.31377542, 0.3300203 , 0.3582393 , 0.16810751, 0.4082694 ],
          [0.4944626 , 0.6701931 , 0.47612512, 0.4476341 , 0.10698342],
          [0.871716  , 0.6150317 , 0.26029062, 0.53301275, 0.79688287]],
  
         [[0.871716  , 0.6150317 , 0.26029062, 0.53301275, 0.796882