In [1]:
import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.layers as kl
import tensorflow.keras.backend as K
import tensorflow.distributions as tfd
import numpy as np

In [2]:
import matplotlib.pyplot as pl

In [3]:
from tensorflow.python.keras import activations
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers

In [4]:
class Minimal2D_RNNCell(kl.Layer):
    """Minimal per-pixel simple-RNN cell meant to be wrapped by ``ConvRNN2D``.

    For each step it computes

        h_t = activation(x_t . kernel + h_{t-1} . recurrent_kernel [+ bias])

    where ``kernel`` has shape ``(input_channels, filters)`` and
    ``recurrent_kernel`` has shape ``(filters, filters)``. The products contract
    only the last (channel) axis, so there is no spatial mixing: it acts like a
    1x1-convolutional simple RNN.

    The class attributes below are not used by ``call``. NOTE(review): they, plus
    ``kernel_shape``, ``filters`` and ``input_conv``, appear to exist only so that
    ``ConvRNN2D`` can query them when computing output shapes and building the
    initial state. Confirm against the ConvRNN2D source of the installed TF.
    """

    data_format = "channels_last"
    kernel_size = (1, 1)
    padding = "valid"
    strides = (1, 1)
    dilation_rate = (1, 1)

    def __init__(
        self,
        filters,
        activation=None,
        kernel_initializer='glorot_uniform',
        bias_initializer='zeros',
        kernel_regularizer=None,
        use_bias=True,
        bias_regularizer=None,
        recurrent_initializer='uniform',
        recurrent_regularizer=None,
        **kwargs
    ):
        """Configure the cell.

        Args:
            filters: number of output channels; also the size of the state.
            activation: activation name/callable; ``None`` means linear.
            kernel_initializer: initializer for the input kernel.
            bias_initializer: initializer for the bias (if ``use_bias``).
            kernel_regularizer: optional regularizer for the input kernel.
            use_bias: whether to add a learned bias.
            bias_regularizer: optional regularizer for the bias.
            recurrent_initializer: initializer for the recurrent kernel
                (defaults to ``'uniform'``, the value previously hard-coded).
            recurrent_regularizer: optional regularizer for the recurrent kernel.
            **kwargs: forwarded to ``keras.layers.Layer`` (e.g. ``name``).
        """
        # Initialise the base Layer first so its bookkeeping (name, dtype,
        # attribute tracking) exists before any attributes are assigned.
        super().__init__(**kwargs)
        self.filters = filters
        # The recurrent state is the previous output, so it has `filters` channels.
        self.state_size = filters
        self.activation = activations.get(activation)
        self.use_bias = use_bias
        self.kernel_initializer = initializers.get(kernel_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.recurrent_initializer = initializers.get(recurrent_initializer)
        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.recurrent_regularizer = regularizers.get(recurrent_regularizer)

    def build(self, input_shape):
        """Create the weights.

        Weights are created in the order kernel, bias (optional), then
        recurrent_kernel. Keep this order: ``get_weights``/``set_weights``
        depend on it.

        Args:
            input_shape: shape of a single timestep of input; only the last
                entry (channel count) is used.
        """
        input_channels = input_shape[-1]
        # 4-D (kh, kw, in, out) view of the kernel. It is exposed for the
        # wrapping layer (see class docstring); the stored weight is 2-D.
        self.kernel_shape = self.kernel_size + (input_channels, self.filters)
        self.kernel = self.add_weight(
            'kernel',
            shape=(input_channels, self.filters),
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            dtype=self.dtype,
        )
        if self.use_bias:
            self.bias = self.add_weight(
                'bias',
                shape=(self.filters,),
                initializer=self.bias_initializer,
                regularizer=self.bias_regularizer,
                dtype=self.dtype,
            )
        else:
            self.bias = None

        self.recurrent_kernel = self.add_weight(
            name='recurrent_kernel',
            shape=(self.filters, self.filters),
            initializer=self.recurrent_initializer,
            regularizer=self.recurrent_regularizer,
            dtype=self.dtype,
        )
        self.built = True

    def call(self, inputs, states):
        """Compute one recurrent step.

        Args:
            inputs: input at the current timestep; its last axis is contracted
                with ``kernel``.
            states: list whose first element is the previous output.

        Returns:
            ``(output, [output])``: the new output is also the next state.
        """
        prev_output = states[0]
        h = K.dot(inputs, self.kernel) + K.dot(prev_output, self.recurrent_kernel)
        if self.use_bias:
            h = h + self.bias
        output = self.activation(h)
        return output, [output]

    def input_conv(self, x, w, b=None, padding='valid'):
        """Convolve ``x`` with ``w`` (plus optional bias ``b``).

        Uses the cell's fixed ``strides``/``data_format``/``dilation_rate``.
        Not used by ``call``. NOTE(review): presumably required by ConvRNN2D
        to build its initial state; confirm.
        """
        conv_out = K.conv2d(x, w, strides=self.strides,
                            padding=padding,
                            data_format=self.data_format,
                            dilation_rate=self.dilation_rate)
        if b is not None:
            conv_out = K.bias_add(conv_out, b,
                                  data_format=self.data_format)
        return conv_out

    def get_config(self):
        """Return the constructor arguments.

        This lets the cell be re-created from its config, e.g. when a model
        containing it is serialized.
        """
        config = {
            'filters': self.filters,
            'activation': activations.serialize(self.activation),
            'use_bias': self.use_bias,
            'kernel_initializer': initializers.serialize(self.kernel_initializer),
            'bias_initializer': initializers.serialize(self.bias_initializer),
            'recurrent_initializer': initializers.serialize(self.recurrent_initializer),
            'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
            'bias_regularizer': regularizers.serialize(self.bias_regularizer),
            'recurrent_regularizer': regularizers.serialize(self.recurrent_regularizer),
        }
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))


In [5]:
from tensorflow.python.keras.layers.convolutional_recurrent import ConvRNN2D

In [6]:
def get_model(batch_size=None, stateful=False, input_shape=(None, 320, 200, 3),
              filters=5, activation="sigmoid"):
    """Build a one-layer ``ConvRNN2D`` model around ``Minimal2D_RNNCell``.

    Args:
        batch_size: fixed batch size, or ``None`` for a variable batch. It is
            required when ``stateful`` is True.
        stateful: if True, the RNN keeps its hidden state between calls until
            ``reset_states()`` is called.
        input_shape: per-sample input shape ``(timesteps, rows, cols, channels)``.
            ``timesteps`` may be ``None`` for variable-length sequences.
        filters: number of output (and recurrent-state) channels.
        activation: activation used by the cell.

    Returns:
        A ``keras.Model`` that maps the input sequence to the full output
        sequence (``return_sequences=True``).
    """
    if stateful:
        assert batch_size is not None, "In stateful RNN, `batch_size` must be known."
    inputs = kl.Input(list(input_shape), batch_size=batch_size)

    rnn_cell = Minimal2D_RNNCell(filters, activation=activation)
    rnn_layer = ConvRNN2D(rnn_cell, return_sequences=True, stateful=stateful)

    outputs = rnn_layer(inputs)
    return keras.Model(inputs, outputs)

In [7]:
# Reference (non-stateful) model with a variable batch size.
model = get_model(batch_size=None, stateful=False)

5


In [8]:
# Run one 10-step all-zero sequence through the reference model. Indexing [0]
# drops only the batch axis; .squeeze() would also drop any other size-1 axis
# (e.g. the channel axis if filters == 1).
res = model.predict(np.zeros([1, 10, 320, 200, 3]))[0]

In [9]:
np.shape(res)

(10, 320, 200, 5)

In [10]:
# Output of pixel (0, 0) at every timestep, all channels.
res[:, 0, 0]

array([[0.5       , 0.5       , 0.5       , 0.5       , 0.5       ],
       [0.8341756 , 0.7784132 , 0.80142903, 0.67302865, 0.7627833 ],
       [0.92426056, 0.87568986, 0.89744836, 0.7567    , 0.8596069 ],
       [0.9425277 , 0.89887655, 0.91924775, 0.7808389 , 0.8834182 ],
       [0.9462007 , 0.9037547 , 0.923757  , 0.78619486, 0.8884894 ],
       [0.9469449 , 0.90475357, 0.9246757 , 0.78730404, 0.8895314 ],
       [0.94709605, 0.9049569 , 0.92486244, 0.7875304 , 0.88974357],
       [0.94712675, 0.90499824, 0.9249004 , 0.7875765 , 0.88978666],
       [0.94713306, 0.90500665, 0.92490816, 0.7875858 , 0.8897954 ],
       [0.94713426, 0.90500844, 0.92490965, 0.78758776, 0.8897973 ]],
      dtype=float32)

In [11]:
# Stateful single-sample "serving" model that copies the reference model's
# (randomly initialised) weights, so the two models' outputs are comparable.
model_serv = get_model(stateful=True, batch_size=1)
model_serv.set_weights(model.get_weights())

5


In [12]:
# Feed the same 10 zero frames one timestep at a time. The stateful model
# carries its hidden state across predict() calls, so each step should
# reproduce the matching timestep of the full-sequence result `res`.
# A distinct name is used so `res` is not clobbered for the cells above.
model_serv.reset_states()
step_outputs = []
for _ in range(10):
    # [0, 0] drops the batch and time axes.
    step_res = model_serv.predict(np.zeros([1, 1, 320, 200, 3]))[0, 0]
    step_outputs.append(step_res)
    print(step_res[0, 0, :])
# Streaming inference must match the batch run.
np.testing.assert_allclose(np.stack(step_outputs), res, rtol=1e-5, atol=1e-6)

[0.5 0.5 0.5 0.5 0.5]
[0.8341756  0.7784132  0.80142903 0.67302865 0.7627833 ]
[0.92426056 0.87568986 0.89744836 0.7567     0.8596069 ]
[0.9425277  0.89887655 0.91924775 0.7808389  0.8834182 ]
[0.9462007  0.9037547  0.923757   0.78619486 0.8884894 ]
[0.9469449  0.90475357 0.9246757  0.78730404 0.8895314 ]
[0.94709605 0.9049569  0.92486244 0.7875304  0.88974357]
[0.94712675 0.90499824 0.9249004  0.7875765  0.88978666]
[0.94713306 0.90500665 0.92490816 0.7875858  0.8897954 ]
[0.94713426 0.90500844 0.92490965 0.78758776 0.8897973 ]
