In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Layer, RNN
from tensorflow.keras import activations, initializers, regularizers, constraints

class AlphaRNNCell(Layer):
    """Cell class for AlphaRNN.

    One step computes::

        output_t = activation(inputs_t @ kernel + bias
                              + output_{t-1} @ recurrent_kernel)

    and ``output_t`` is returned both as the step output and as the new state.

    NOTE(review): despite the name, no alpha / smoothing term is implemented
    here; the update is a plain fully-connected RNN step. Confirm whether a
    smoothing parameter was intended.

    Args:
        units: Dimensionality of the output and of the recurrent state.
        activation: Activation applied to the pre-activation sum (resolved
            with ``activations.get``).
        use_bias: Whether to add a learned bias vector of size ``units``.
        kernel_initializer, recurrent_initializer, bias_initializer:
            Initializers for the input kernel, recurrent kernel and bias.
        kernel_regularizer, recurrent_regularizer, bias_regularizer:
            Optional regularizers for the corresponding weights.
        kernel_constraint, recurrent_constraint, bias_constraint:
            Optional constraints for the corresponding weights.
        dropout: Fraction of the inputs to drop while training. Only applied
            when ``0 < dropout < 1``.
        recurrent_dropout: Fraction of the previous state to drop while
            training. Only applied when ``0 < recurrent_dropout < 1``.
        **kwargs: Forwarded to ``Layer.__init__`` (e.g. ``name``, ``dtype``).

    Dropout draws a fresh mask with ``tf.nn.dropout`` at every timestep.
    """
    def __init__(self, units, activation='tanh', use_bias=True,
                 kernel_initializer='glorot_uniform',
                 recurrent_initializer='orthogonal',
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 recurrent_regularizer=None,
                 bias_regularizer=None,
                 kernel_constraint=None,
                 recurrent_constraint=None,
                 bias_constraint=None,
                 dropout=0.,
                 recurrent_dropout=0.,
                 **kwargs):
        super(AlphaRNNCell, self).__init__(**kwargs)
        self.units = units
        self.activation = activations.get(activation)
        self.use_bias = use_bias
        self.kernel_initializer = initializers.get(kernel_initializer)
        self.recurrent_initializer = initializers.get(recurrent_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.kernel_constraint = constraints.get(kernel_constraint)
        self.recurrent_constraint = constraints.get(recurrent_constraint)
        self.bias_constraint = constraints.get(bias_constraint)
        self.dropout = dropout
        self.recurrent_dropout = recurrent_dropout
        # Keras' RNN wrapper refuses cells without `state_size`; the single
        # state is the previous output, so both sizes equal `units`.
        self.state_size = self.units
        self.output_size = self.units

    def build(self, input_shape):
        """Create the input kernel, recurrent kernel and (optional) bias."""
        self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
                                      initializer=self.kernel_initializer,
                                      name='kernel',
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)
        self.recurrent_kernel = self.add_weight(shape=(self.units, self.units),
                                                initializer=self.recurrent_initializer,
                                                name='recurrent_kernel',
                                                regularizer=self.recurrent_regularizer,
                                                constraint=self.recurrent_constraint)
        if self.use_bias:
            self.bias = self.add_weight(shape=(self.units,),
                                        initializer=self.bias_initializer,
                                        name='bias',
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)
        else:
            self.bias = None
        self.built = True

    @staticmethod
    def _apply_dropout(x, rate, training):
        """Return ``x`` with dropout at ``rate`` when training, else ``x``.

        NOTE(review): assumes ``training`` is a Python bool or None (as passed
        by Keras ``fit``/``predict``); None is treated as inference.
        """
        if training and 0. < rate < 1.:
            return tf.nn.dropout(x, rate=rate)
        return x

    def call(self, inputs, states, training=None):
        """Run one timestep.

        Args:
            inputs: Input tensor for the current timestep.
            states: Previous state, either a list/tuple holding the previous
                output or the bare tensor.
            training: Whether dropout should be applied (supplied by RNN).

        Returns:
            ``(output, new_state)`` where ``new_state`` mirrors the container
            type of ``states``.
        """
        nested = isinstance(states, (list, tuple))
        prev_output = states[0] if nested else states

        # Dropout acts on the operands, so mask shapes match them exactly.
        inputs = self._apply_dropout(inputs, self.dropout, training)
        prev_output = self._apply_dropout(prev_output, self.recurrent_dropout, training)

        h = tf.matmul(inputs, self.kernel)
        if self.bias is not None:
            h = tf.nn.bias_add(h, self.bias)
        output = h + tf.matmul(prev_output, self.recurrent_kernel)
        if self.activation is not None:
            output = self.activation(output)
        return output, ([output] if nested else output)

    def get_config(self):
        """Return the constructor arguments, including base Layer settings."""
        config = super(AlphaRNNCell, self).get_config()
        config.update({
            'units': self.units,
            'activation': activations.serialize(self.activation),
            'use_bias': self.use_bias,
            'kernel_initializer': initializers.serialize(self.kernel_initializer),
            'recurrent_initializer': initializers.serialize(self.recurrent_initializer),
            'bias_initializer': initializers.serialize(self.bias_initializer),
            'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
            'recurrent_regularizer': regularizers.serialize(self.recurrent_regularizer),
            'bias_regularizer': regularizers.serialize(self.bias_regularizer),
            'kernel_constraint': constraints.serialize(self.kernel_constraint),
            'recurrent_constraint': constraints.serialize(self.recurrent_constraint),
            'bias_constraint': constraints.serialize(self.bias_constraint),
            'dropout': self.dropout,
            'recurrent_dropout': self.recurrent_dropout
        })
        return config

class AlphaRNN(RNN):
    """Fully-connected RNN layer built on ``AlphaRNNCell``.

    The cell's output at each step is fed back as its state for the next step.

    Args:
        units: Dimensionality of the output / recurrent state.
        **kwargs: Keys listed in ``_CELL_ARGS`` are routed to
            ``AlphaRNNCell``. Everything else (``return_sequences``,
            ``return_state``, ``name``, ...) goes to ``RNN.__init__``.
            ``dtype`` and ``trainable`` are also forwarded to the cell.
    """

    # Constructor arguments owned by AlphaRNNCell rather than by RNN.
    _CELL_ARGS = (
        'activation', 'use_bias',
        'kernel_initializer', 'recurrent_initializer', 'bias_initializer',
        'kernel_regularizer', 'recurrent_regularizer', 'bias_regularizer',
        'kernel_constraint', 'recurrent_constraint', 'bias_constraint',
        'dropout', 'recurrent_dropout',
    )

    def __init__(self, units, **kwargs):
        # The cell and RNN.__init__ each reject the other's keyword
        # arguments, so split them instead of forwarding everything to both.
        cell_kwargs = {}
        for key in self._CELL_ARGS:
            if key in kwargs:
                cell_kwargs[key] = kwargs.pop(key)
        cell = AlphaRNNCell(units,
                            dtype=kwargs.get('dtype'),
                            trainable=kwargs.get('trainable', True),
                            **cell_kwargs)
        super(AlphaRNN, self).__init__(cell, **kwargs)

    @property
    def units(self):
        """Number of units in the wrapped cell."""
        return self.cell.units

    def get_config(self):
        """Return a flat config that ``from_config`` can pass to ``cls(**config)``."""
        config = super(AlphaRNN, self).get_config()
        cell_config = self.cell.get_config()
        # The nested cell entry would otherwise be passed positionally next to
        # `units` by RNN.from_config; store the cell's arguments flat instead.
        config.pop('cell', None)
        config['units'] = self.units
        for key in self._CELL_ARGS:
            if key in cell_config:
                config[key] = cell_config[key]
        return config

    @classmethod
    def from_config(cls, config, custom_objects=None):
        """Rebuild the layer from the output of ``get_config``."""
        return cls(**config)

In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Layer, RNN
from tensorflow.keras import activations, initializers, regularizers, constraints

class AlphatRNNCell(Layer):
    """Cell class for the AlphatRNN layer.

    One step computes::

        h        = inputs_t @ kernel + bias[:units]
        rec_h    = output_{t-1} @ recurrent_kernel + bias[units:]
        output_t = activation(h + rec_h)

    and ``output_t`` is returned both as the step output and as the new state.

    NOTE(review): ``recurrent_activation`` is stored and serialized but not
    used by ``call``, and no time-varying alpha term is computed; the update
    is a plain fully-connected RNN step with separate input / recurrent biases.
    Confirm the intended formulation.

    Args:
        units: Dimensionality of the output and of the recurrent state.
        activation: Activation applied to ``h + rec_h``.
        recurrent_activation: Stored and serialized; currently unused.
        use_bias: Whether to add the two learned bias halves (``2 * units``).
        kernel_initializer, recurrent_initializer, bias_initializer:
            Initializers for the input kernel, recurrent kernel and bias.
        kernel_regularizer, recurrent_regularizer, bias_regularizer:
            Optional regularizers for the corresponding weights.
        kernel_constraint, recurrent_constraint, bias_constraint:
            Optional constraints for the corresponding weights.
        dropout: Fraction of the inputs to drop while training. Only applied
            when ``0 < dropout < 1``.
        recurrent_dropout: Fraction of the previous state to drop while
            training. Only applied when ``0 < recurrent_dropout < 1``.
        implementation: Stored and serialized; currently unused by ``call``.
        **kwargs: Forwarded to ``Layer.__init__`` (e.g. ``name``, ``dtype``).

    Dropout draws a fresh mask with ``tf.nn.dropout`` at every timestep.
    """
    def __init__(self, units, activation='tanh', recurrent_activation='sigmoid',
                 use_bias=True, kernel_initializer='glorot_uniform',
                 recurrent_initializer='orthogonal', bias_initializer='zeros',
                 kernel_regularizer=None, recurrent_regularizer=None,
                 bias_regularizer=None, kernel_constraint=None,
                 recurrent_constraint=None, bias_constraint=None,
                 dropout=0., recurrent_dropout=0., implementation=2, **kwargs):
        super(AlphatRNNCell, self).__init__(**kwargs)
        self.units = units
        self.activation = activations.get(activation)
        self.recurrent_activation = activations.get(recurrent_activation)
        self.use_bias = use_bias
        self.kernel_initializer = initializers.get(kernel_initializer)
        self.recurrent_initializer = initializers.get(recurrent_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.kernel_constraint = constraints.get(kernel_constraint)
        self.recurrent_constraint = constraints.get(recurrent_constraint)
        self.bias_constraint = constraints.get(bias_constraint)
        self.dropout = dropout
        self.recurrent_dropout = recurrent_dropout
        self.implementation = implementation
        # Keras' RNN wrapper refuses cells without `state_size`; the single
        # state is the previous output, so both sizes equal `units`.
        self.state_size = self.units
        self.output_size = self.units

    def build(self, input_shape):
        """Create the kernels and a ``2 * units`` bias (input half, recurrent half)."""
        self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
                                      initializer=self.kernel_initializer,
                                      name='kernel',
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)
        self.recurrent_kernel = self.add_weight(
            shape=(self.units, self.units),
            initializer=self.recurrent_initializer,
            name='recurrent_kernel',
            regularizer=self.recurrent_regularizer,
            constraint=self.recurrent_constraint)

        if self.use_bias:
            self.bias = self.add_weight(shape=(2 * self.units,),
                                        initializer=self.bias_initializer,
                                        name='bias',
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)
        else:
            self.bias = None
        self.built = True

    @staticmethod
    def _apply_dropout(x, rate, training):
        """Return ``x`` with dropout at ``rate`` when training, else ``x``.

        NOTE(review): assumes ``training`` is a Python bool or None (as passed
        by Keras ``fit``/``predict``); None is treated as inference.
        """
        if training and 0. < rate < 1.:
            return tf.nn.dropout(x, rate=rate)
        return x

    def call(self, inputs, states, training=None):
        """Run one timestep.

        Args:
            inputs: Input tensor for the current timestep.
            states: Previous state, either a list/tuple holding the previous
                output or the bare tensor.
            training: Whether dropout should be applied (supplied by RNN);
                dropout is skipped at inference.

        Returns:
            ``(output, new_state)`` where ``new_state`` mirrors the container
            type of ``states``.
        """
        nested = isinstance(states, (list, tuple))
        prev_output = states[0] if nested else states

        # Drop the matmul *operands*: the masks then have the operands' own
        # shapes, rather than an input-shaped mask applied to a
        # (batch, units) product.
        inputs = self._apply_dropout(inputs, self.dropout, training)
        prev_output = self._apply_dropout(prev_output, self.recurrent_dropout, training)

        h = tf.matmul(inputs, self.kernel)
        rec_h = tf.matmul(prev_output, self.recurrent_kernel)
        if self.bias is not None:
            # First half of the bias feeds the input path, second half the
            # recurrent path.
            h = tf.nn.bias_add(h, self.bias[:self.units])
            rec_h = tf.nn.bias_add(rec_h, self.bias[self.units:])

        output = self.activation(h + rec_h)
        return output, ([output] if nested else output)

    def get_config(self):
        """Return the constructor arguments, including base Layer settings."""
        config = super(AlphatRNNCell, self).get_config()
        config.update({
            'units': self.units,
            'activation': activations.serialize(self.activation),
            'recurrent_activation': activations.serialize(self.recurrent_activation),
            'use_bias': self.use_bias,
            'kernel_initializer': initializers.serialize(self.kernel_initializer),
            'recurrent_initializer': initializers.serialize(self.recurrent_initializer),
            'bias_initializer': initializers.serialize(self.bias_initializer),
            'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
            'recurrent_regularizer': regularizers.serialize(self.recurrent_regularizer),
            'bias_regularizer': regularizers.serialize(self.bias_regularizer),
            'kernel_constraint': constraints.serialize(self.kernel_constraint),
            'recurrent_constraint': constraints.serialize(self.recurrent_constraint),
            'bias_constraint': constraints.serialize(self.bias_constraint),
            'dropout': self.dropout,
            'recurrent_dropout': self.recurrent_dropout,
            'implementation': self.implementation
        })
        return config

class AlphatRNN(RNN):
    """Fully-connected RNN layer built on ``AlphatRNNCell``.

    The cell's output at each step is fed back as its state for the next step.
    ``call`` is inherited unchanged from ``RNN``.

    Args:
        units: Dimensionality of the output / recurrent state.
        **kwargs: Keys listed in ``_CELL_ARGS`` are routed to
            ``AlphatRNNCell``. Everything else (``return_sequences``,
            ``return_state``, ``name``, ...) goes to ``RNN.__init__``.
            ``dtype`` and ``trainable`` are also forwarded to the cell.
    """

    # Constructor arguments owned by AlphatRNNCell rather than by RNN.
    _CELL_ARGS = (
        'activation', 'recurrent_activation', 'use_bias',
        'kernel_initializer', 'recurrent_initializer', 'bias_initializer',
        'kernel_regularizer', 'recurrent_regularizer', 'bias_regularizer',
        'kernel_constraint', 'recurrent_constraint', 'bias_constraint',
        'dropout', 'recurrent_dropout', 'implementation',
    )

    def __init__(self, units, **kwargs):
        # The cell and RNN.__init__ each reject the other's keyword
        # arguments, so split them instead of forwarding everything to both.
        cell_kwargs = {}
        for key in self._CELL_ARGS:
            if key in kwargs:
                cell_kwargs[key] = kwargs.pop(key)
        cell = AlphatRNNCell(units,
                             dtype=kwargs.get('dtype'),
                             trainable=kwargs.get('trainable', True),
                             **cell_kwargs)
        super(AlphatRNN, self).__init__(cell, **kwargs)

    @property
    def units(self):
        """Number of units in the wrapped cell."""
        return self.cell.units

    def get_config(self):
        """Return a flat config that ``from_config`` can pass to ``cls(**config)``."""
        config = super(AlphatRNN, self).get_config()
        cell_config = self.cell.get_config()
        # The nested cell entry would otherwise be passed positionally next to
        # `units` by RNN.from_config; store the cell's arguments flat instead.
        config.pop('cell', None)
        config['units'] = self.units
        for key in self._CELL_ARGS:
            if key in cell_config:
                config[key] = cell_config[key]
        return config

    @classmethod
    def from_config(cls, config, custom_objects=None):
        """Rebuild the layer from the output of ``get_config``."""
        return cls(**config)