In [1]:
from layers import *
import tensorflow as tf

In [2]:
import tensorflow as tf
from tensorflow.keras import layers,activations,backend,constraints,initializers,regularizers

class NFM(layers.Layer):
    def __init__(self,
                 dropout_rate=0.2,
                 **kwargs):
        self.dropout_rate=dropout_rate
        super().__init__(**kwargs)
        self.bnlayer=layers.BatchNormalization(name='bi_interaction_bn')
        self.dropoutlayer=layers.Dropout(rate=dropout_rate)

    def call(self,inputs):
        """
        inputs: (bs,field_num,emb_size)
        output: (bs,emb_size)
        """
        sum_square_part = tf.square(tf.reduce_sum(inputs, axis=1)) # (batch, emb_size)
        square_sum_part = tf.reduce_sum(tf.square(inputs), axis=1) # (batch, emb_size)
        nfm = 0.5 * (sum_square_part - square_sum_part)
        nfm = self.bnlayer(nfm)
        nfm = self.dropoutlayer(nfm)
        return nfm

if __name__ == "__main__":
    obj=NFM()
    a=tf.ones((64,23,16))
    obj(a)

2023-10-08 10:09:01.043586: I tensorflow/compiler/jit/xla_cpu_device.cc:41] Not creating XLA devices, tf_xla_enable_xla_devices not set
2023-10-08 10:09:01.043752: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-10-08 10:09:01.044784: I tensorflow/core/common_runtime/process_util.cc:146] Creating new thread pool with default inter op setting: 2. Tune using inter_op_parallelism_threads for best performance.


In [3]:
a=tf.ones((64,23,16))
b=tf.expand_dims(tf.ones((64,23)),axis=-1)

In [4]:
import tensorflow.compat.v2 as tf
from keras import activations
from keras import backend
from keras import constraints
from keras import initializers
from keras import regularizers
from keras.engine import base_layer
from keras.layers.rnn import rnn_utils
from keras.layers.rnn.dropout_rnn_cell_mixin import DropoutRNNCellMixin

from tensorflow.keras import layers

# isort: off
from tensorflow.python.platform import tf_logging as logging

RECURRENT_DROPOUT_WARNING_MSG = (
    "RNN `implementation=2` is not supported when `recurrent_dropout` is set. "
    "Using `implementation=1`."
)


class AUGRUCell(DropoutRNNCellMixin, base_layer.BaseRandomLayer):
    def __init__(
        self,
        units,
        activation="tanh",
        recurrent_activation="sigmoid",
        use_bias=True,
        kernel_initializer="glorot_uniform",
        recurrent_initializer="orthogonal",
        bias_initializer="zeros",
        kernel_regularizer=None,
        recurrent_regularizer=None,
        bias_regularizer=None,
        kernel_constraint=None,
        recurrent_constraint=None,
        bias_constraint=None,
        dropout=0.0,
        recurrent_dropout=0.0,
        reset_after=True,
        **kwargs,
    ):
        # By default use cached variable under v2 mode, see b/143699808.
        if tf.compat.v1.executing_eagerly_outside_functions():
            self._enable_caching_device = kwargs.pop(
                "enable_caching_device", True
            )
        else:
            self._enable_caching_device = kwargs.pop(
                "enable_caching_device", False
            )
        super().__init__(**kwargs)
        self.units = units
        self.activation = activations.get(activation)
        self.recurrent_activation = activations.get(recurrent_activation)
        self.use_bias = use_bias

        self.kernel_initializer = initializers.get(kernel_initializer)
        self.recurrent_initializer = initializers.get(recurrent_initializer)
        self.bias_initializer = initializers.get(bias_initializer)

        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)

        self.kernel_constraint = constraints.get(kernel_constraint)
        self.recurrent_constraint = constraints.get(recurrent_constraint)
        self.bias_constraint = constraints.get(bias_constraint)

        self.dropout = min(1.0, max(0.0, dropout))
        self.recurrent_dropout = min(1.0, max(0.0, recurrent_dropout))

        implementation = kwargs.pop("implementation", 2)
        if self.recurrent_dropout != 0 and implementation != 1:
            logging.debug(RECURRENT_DROPOUT_WARNING_MSG)
            self.implementation = 1
        else:
            self.implementation = implementation
        self.reset_after = reset_after
        self.state_size = self.units
        self.output_size = self.units

    def build(self, input_shape):
        input_shape=input_shape[0]
        super().build(input_shape)
        input_dim = input_shape[-1]
        default_caching_device = rnn_utils.caching_device(self)
        self.kernel = self.add_weight(
            shape=(input_dim, self.units * 3),
            name="kernel",
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
            caching_device=default_caching_device,
        )
        self.recurrent_kernel = self.add_weight(
            shape=(self.units, self.units * 3),
            name="recurrent_kernel",
            initializer=self.recurrent_initializer,
            regularizer=self.recurrent_regularizer,
            constraint=self.recurrent_constraint,
            caching_device=default_caching_device,
        )

        if self.use_bias:
            if not self.reset_after:
                bias_shape = (3 * self.units,)
            else:
                # separate biases for input and recurrent kernels
                # Note: the shape is intentionally different from CuDNNGRU
                # biases `(2 * 3 * self.units,)`, so that we can distinguish the
                # classes when loading and converting saved weights.
                bias_shape = (2, 3 * self.units)
            self.bias = self.add_weight(
                shape=bias_shape,
                name="bias",
                initializer=self.bias_initializer,
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
                caching_device=default_caching_device,
            )
        else:
            self.bias = None
        self.built = True

    def call(self, inputs, states, training=None):
        inputs, att_score = inputs
        h_tm1 = (
            states[0] if tf.nest.is_nested(states) else states
        )  # previous memory

        dp_mask = self.get_dropout_mask_for_cell(inputs, training, count=3)
        rec_dp_mask = self.get_recurrent_dropout_mask_for_cell(
            h_tm1, training, count=3
        )

        if self.use_bias:
            if not self.reset_after:
                input_bias, recurrent_bias = self.bias, None
            else:
                input_bias, recurrent_bias = tf.unstack(self.bias)

        if self.implementation == 1:
            if 0.0 < self.dropout < 1.0:
                inputs_z = inputs * dp_mask[0]
                inputs_r = inputs * dp_mask[1]
                inputs_h = inputs * dp_mask[2]
            else:
                inputs_z = inputs
                inputs_r = inputs
                inputs_h = inputs

            x_z = backend.dot(inputs_z, self.kernel[:, : self.units])
            x_r = backend.dot(
                inputs_r, self.kernel[:, self.units : self.units * 2]
            )
            x_h = backend.dot(inputs_h, self.kernel[:, self.units * 2 :])

            if self.use_bias:
                x_z = backend.bias_add(x_z, input_bias[: self.units])
                x_r = backend.bias_add(
                    x_r, input_bias[self.units : self.units * 2]
                )
                x_h = backend.bias_add(x_h, input_bias[self.units * 2 :])

            if 0.0 < self.recurrent_dropout < 1.0:
                h_tm1_z = h_tm1 * rec_dp_mask[0]
                h_tm1_r = h_tm1 * rec_dp_mask[1]
                h_tm1_h = h_tm1 * rec_dp_mask[2]
            else:
                h_tm1_z = h_tm1
                h_tm1_r = h_tm1
                h_tm1_h = h_tm1

            recurrent_z = backend.dot(
                h_tm1_z, self.recurrent_kernel[:, : self.units]
            )
            recurrent_r = backend.dot(
                h_tm1_r, self.recurrent_kernel[:, self.units : self.units * 2]
            )
            if self.reset_after and self.use_bias:
                recurrent_z = backend.bias_add(
                    recurrent_z, recurrent_bias[: self.units]
                )
                recurrent_r = backend.bias_add(
                    recurrent_r, recurrent_bias[self.units : self.units * 2]
                )

            z = self.recurrent_activation(x_z + recurrent_z)
            r = self.recurrent_activation(x_r + recurrent_r)

            # reset gate applied after/before matrix multiplication
            if self.reset_after:
                recurrent_h = backend.dot(
                    h_tm1_h, self.recurrent_kernel[:, self.units * 2 :]
                )
                if self.use_bias:
                    recurrent_h = backend.bias_add(
                        recurrent_h, recurrent_bias[self.units * 2 :]
                    )
                recurrent_h = r * recurrent_h
            else:
                recurrent_h = backend.dot(
                    r * h_tm1_h, self.recurrent_kernel[:, self.units * 2 :]
                )

            hh = self.activation(x_h + recurrent_h)
        else:
            if 0.0 < self.dropout < 1.0:
                inputs = inputs * dp_mask[0]

            # inputs projected by all gate matrices at once
            matrix_x = backend.dot(inputs, self.kernel)
            if self.use_bias:
                # biases: bias_z_i, bias_r_i, bias_h_i
                matrix_x = backend.bias_add(matrix_x, input_bias)

            x_z, x_r, x_h = tf.split(matrix_x, 3, axis=-1)

            if self.reset_after:
                # hidden state projected by all gate matrices at once
                matrix_inner = backend.dot(h_tm1, self.recurrent_kernel)
                if self.use_bias:
                    matrix_inner = backend.bias_add(
                        matrix_inner, recurrent_bias
                    )
            else:
                # hidden state projected separately for update/reset and new
                matrix_inner = backend.dot(
                    h_tm1, self.recurrent_kernel[:, : 2 * self.units]
                )

            recurrent_z, recurrent_r, recurrent_h = tf.split(
                matrix_inner, [self.units, self.units, -1], axis=-1
            )

            z = self.recurrent_activation(x_z + recurrent_z)
            r = self.recurrent_activation(x_r + recurrent_r)

            if self.reset_after:
                recurrent_h = r * recurrent_h
            else:
                recurrent_h = backend.dot(
                    r * h_tm1, self.recurrent_kernel[:, 2 * self.units :]
                )

            hh = self.activation(x_h + recurrent_h)
        # previous and candidate state mixed by update gate
        z= z*att_score
        h = z * h_tm1 + (1 - z) * hh
        new_state = [h] if tf.nest.is_nested(states) else h
        return h, new_state

    
    
from layers.sequential.din import DiceMLP
class ActivationUnit(layers.Layer):

    def __init__(self, units=[32, 16], dropout_rate=0.2):
        super().__init__()
        self.dicemlp = DiceMLP(1,units,dropout_rate=dropout_rate)

    def build(self,input_shape):
        self.emb_size=input_shape[2]
        self.maxseqlen=input_shape[1]
        
    def call(self, sequence, query):
        """
            query : 单独的ad的embedding mat -> batch * 1 * embed
            user_behavior : 行为特征矩阵 -> batch * seq_len * embed
        """
        key=sequence
        query=tf.tile(query,[1,self.maxseqlen]) ## (bs,emb_dim) -> (bs,emb_dim*seqlen)
        query=tf.reshape(query,[-1,self.maxseqlen,self.emb_size]) # (bs,emb_dim*seqlen) -> (bs,seqlen,emb_dim)
        ## key 和 query的一些交互,如果有其它特征other_emb 可以一并拼接
        din_all=tf.concat([query,key,query-key,query*key],axis=-1) ## (bs,seqlen,emb_dim*4)
        ## 一个2层mlp (emb_dim*2,emb_dim,1)
        din_w=self.dicemlp(din_all)
        return din_w

'''   Attention Pooling Layer   '''
class AttentionPoolingLayer(layers.Layer):

    def __init__(self,units=[32, 16], dropout_rate=0.2, return_score=False):
        super().__init__()
        self.active_unit = ActivationUnit(units,dropout_rate)
        self.return_score = return_score

    def call(self,user_behavior,query, mask_bool):
        """
            query_ad : 单独的ad的embedding mat -> batch * 1 * embed
            user_behavior : 行为特征矩阵 -> batch * seq_len * embed
            mask : 被padding为0的行为置为false -> batch * seq_len * 1
        """

        # attn weights
        attn_weights = self.active_unit(user_behavior,query)
        # mul weights and sum pooling
        if self.return_score:
            output = user_behavior * attn_weights *tf.expand_dims(tf.cast(mask_bool,tf.float32),axis=-1)
            return output

        return attn_weights
    

sequence=tf.ones((64,20,16))
query=tf.ones((64,16))
mask_bool=tf.sequence_mask(tf.ones((64,))*10,20)
obj=ActivationUnit()
obj(sequence,query)
obj=AttentionPoolingLayer(return_score=True)
tmp=obj(sequence,query,mask_bool)


ModuleNotFoundError: No module named 'keras'

In [20]:
# https://blog.csdn.net/qq_42363032/article/details/122365548
class InterestEvolutionLayer(layers.Layer):
    def __init__(self,
                 attention_units=[32,16],
                 attention_dropout=0,
                 gru_type="augru",
                 **kwargs):
        self.gru_type=gru_type
        self.attention_units=attention_units
        self.attention_dropout=attention_dropout
        super().__init__(**kwargs)
        
    def build(self,input_shape):
        emb_size=input_shape[-1]
        self.maxseqlen=input_shape[1]
        if self.gru_type.upper() == "AUGRU":
            self.attention = AttentionPoolingLayer(units=self.attention_units,
                                                   dropout_rate=self.attention_dropout,
                                                   return_score=True)
            self.rnn=layers.RNN(AUGRUCell(units=emb_size))
        elif self.gru_type.upper() == "GRU":
            self.attention = AttentionPoolingLayer(dropout=self.attention_dropout, units=self.attention_units)
            self.rnn=layers.GRU(units=emb_size,return_sequences=True)
            
        elif self.gru_type.upper() in ("AIGRU") :
            self.attention = AttentionPoolingLayer(dropout=self.attention_dropout, units=self.attention_units)
            self.rnn=layers.GRU(units=emb_size)
        else:
            raise NotImplementedError

    def call(self,gru_interests,query_ad,seqlen):
        mask_bool=tf.sequence_mask(seqlen,maxlen=self.maxseqlen)
        if self.gru_type.upper() == 'GRU':
            # GRU后接attention
            out = self.rnn(gru_interests, mask=mask_bool)  # (2000, 40, 4)
            out = self.attention(out, query_ad, mask_bool)  # (2000, 40, 4)
            out = tf.reduce_sum(out, axis=1)  # (2000, 4)
        elif self.gru_type.upper() == 'AIGRU':
            # AIGRU
            att_score = self.attention(gru_interests, query_ad, mask_bool)  # (2000, 40, 1)
            out = att_score * gru_interests  # (2000, 40, 4)
            out = self.rnn(out, mask=mask_bool)  # (2000, 4)
        elif self.gru_type.upper() == 'AUGRU':
            # AGRU or AUGRU
            att_score = self.attention(gru_interests, query_ad,  mask_bool)  # (2000, 40, 1)
            out = self.rnn((gru_interests, att_score), mask=mask_bool)  # (2000, 4)
        else:
            raise NotImplementedError
        return out
    

rnn=layers.RNN(AUGRUCell(units=16),return_sequences=True)
inputs=tf.ones((64,20,16))
query=tf.ones((64,16))
seqlen=tf.ones((64))*10
att_score=tf.ones((64,20,1))
rnn((inputs,att_score))
obj=InterestEvolutionLayer()
obj(inputs,query,seqlen)

<tf.Tensor: shape=(64, 16), dtype=float32, numpy=
array([[ 0.01445294, -0.44694364, -0.9463091 , ..., -0.29685327,
        -0.70927423,  0.48810744],
       [ 0.01445294, -0.44694364, -0.9463091 , ..., -0.29685327,
        -0.70927423,  0.48810744],
       [ 0.01445294, -0.44694364, -0.9463091 , ..., -0.29685327,
        -0.70927423,  0.48810744],
       ...,
       [ 0.01445294, -0.44694364, -0.9463091 , ..., -0.29685327,
        -0.70927423,  0.48810744],
       [ 0.01445294, -0.44694364, -0.9463091 , ..., -0.29685327,
        -0.70927423,  0.48810744],
       [ 0.01445294, -0.44694364, -0.9463091 , ..., -0.29685327,
        -0.70927423,  0.48810744]], dtype=float32)>

In [21]:
from layers.sequential.din import DiceMLP
class ActivationUnit(layers.Layer):

    def __init__(self,att_dropout=0.2, att_fc_dims=[32, 16]):
        super().__init__()
        self.dicemlp = DiceMLP(1,att_fc_dims)

    def build(self,input_shape):
        self.emb_size=input_shape[2]
        self.maxseqlen=input_shape[1]
        
    def call(self, sequence, query):
        """
            query : 单独的ad的embedding mat -> batch * 1 * embed
            user_behavior : 行为特征矩阵 -> batch * seq_len * embed
        """
        key=sequence
        query=tf.tile(query,[1,self.maxseqlen]) ## (bs,emb_dim) -> (bs,emb_dim*seqlen)
        query=tf.reshape(query,[-1,self.maxseqlen,self.emb_size]) # (bs,emb_dim*seqlen) -> (bs,seqlen,emb_dim)
        ## key 和 query的一些交互,如果有其它特征other_emb 可以一并拼接
        din_all=tf.concat([query,key,query-key,query*key],axis=-1) ## (bs,seqlen,emb_dim*4)
        ## 一个2层mlp (emb_dim*2,emb_dim,1)
        din_w=self.dicemlp(din_all)
        return din_w

'''   Attention Pooling Layer   '''
class AttentionPoolingLayer(layers.Layer):

    def __init__(self, att_dropout=0.2, att_fc_dims=[32, 16], return_score=False):
        super().__init__()
        self.active_unit = ActivationUnit(att_dropout, att_fc_dims)
        self.return_score = return_score

    def call(self,user_behavior,query, mask_bool):
        """
            query_ad : 单独的ad的embedding mat -> batch * 1 * embed
            user_behavior : 行为特征矩阵 -> batch * seq_len * embed
            mask : 被padding为0的行为置为false -> batch * seq_len * 1
        """

        # attn weights
        attn_weights = self.active_unit(user_behavior,query)
        # mul weights and sum pooling
        if self.return_score:
            output = user_behavior * attn_weights *tf.expand_dims(tf.cast(mask_bool,tf.float32),axis=-1)
            return output

        return attn_weights
    

sequence=tf.ones((64,20,16))
query=tf.ones((64,16))
mask_bool=tf.sequence_mask(tf.ones((64,))*10,20)
obj=ActivationUnit()
obj(sequence,query)
obj=AttentionPoolingLayer(return_score=True)
tmp=obj(sequence,query,mask_bool)

In [5]:
from layers import MLP
from tensorflow.keras import layers
class InterestExtractLayer(layers.Layer):
    def __init__(self, extract_units, extract_dropout=0,**kwargs):
        super().__init__(**kwargs)
        # 用一个mlp来计算 auxiliary loss
    
    def build(self,input_shape):
        emb_size=input_shape[-1]
        # 传统的GRU来抽取时序行为的兴趣表示  return_sequences=True: 返回上次的输出
        self.auxiliary_mlp = MLP(1,extract_units,dropout_rate=extract_dropout)
        self.rnn = layers.GRU(units=emb_size, activation='tanh', recurrent_activation='sigmoid', return_sequences=True)

    def call(self, user_behavior, mask_bool, neg_user_behavior=None, neg_mask_bool=None):
        """
            user_behavior : (2000, 40, 4)
            mask : (2000, 40, 1)
            neg_user_behavior : (2000, 39, 4)
            neg_mask : (2000, 39, 1)
        """
        # 将0-1遮罩变换bool
        # mask_bool = tf.cast(tf.squeeze(mask, axis=2), tf.bool)  # (2000, 40)

        gru_interests = self.rnn(user_behavior, mask=mask_bool)  # (2000, 40, 4)

        # 计算Auxiliary Loss，只在负采样的时候计算 aux loss
        if neg_user_behavior is not None:
            # 此处用户真实行为user_behavior为图中的e，GRU抽取的状态为图中的h
            gru_embed = gru_interests[:, 1:]  # (2000, 39, 4)
            #neg_mask_bool = tf.cast(tf.squeeze(neg_mask, axis=2), tf.bool)  # (2000, 39)

            # 正样本的构建  选取下一个行为作为正样本
            pos_seq = tf.concat([gru_embed, user_behavior[:, 1:]], -1)  # (2000, 39, 8)
            pos_res = self.auxiliary_mlp(pos_seq)  # (2000, 39, 1)
            pos_res = tf.sigmoid(pos_res[neg_mask_bool])  # 选择不为0的进行sigmoid  (N, 1) ex: (18290, 1)
            pos_target = tf.ones_like(pos_res, tf.float16)  # label

            # 负样本的构建  从未点击的样本中选取一个作为负样本
            neg_seq = tf.concat([gru_embed, neg_user_behavior], -1)  # (2000, 39, 8)
            neg_res = self.auxiliary_mlp(neg_seq)  # (2000, 39, 1)
            neg_res = tf.sigmoid(neg_res[neg_mask_bool])
            neg_target = tf.zeros_like(neg_res, tf.float16)

            # 计算辅助损失 二分类交叉熵
            aux_loss = tf.keras.losses.binary_crossentropy(tf.concat([pos_res, neg_res], axis=0), tf.concat([pos_target, neg_target], axis=0))
            aux_loss = tf.cast(aux_loss, tf.float32)
            aux_loss = tf.reduce_mean(aux_loss)

            return gru_interests, aux_loss

        return gru_interests, 0

In [6]:
obj=InterestExtractLayer([128,128],0.2)
user_behavior=tf.ones((20, 40, 4))
mask=tf.sequence_mask(tf.ones((20,))*30,40)
neg_user_behavior=tf.ones((20, 39, 4))
neg_mask = tf.sequence_mask(tf.ones((20,))*30,39)
obj.call(user_behavior,mask)

AttributeError: 'InterestExtractLayer' object has no attribute 'rnn'

In [6]:
import itertools

import tensorflow as tf
from tensorflow.keras import layers

class DynamicMultiRNN(layers.Layer):
    def __init__(self, num_units=None, rnn_type='LSTM', return_sequence=True, num_layers=2, num_residual_layers=1, dropout_rate=0.2,
                 forget_bias=1.0, **kwargs):

        self.num_units = num_units
        self.return_sequence = return_sequence
        self.rnn_type = rnn_type
        self.num_layers = num_layers
        self.num_residual_layers = num_residual_layers
        self.dropout = dropout_rate
        self.forget_bias = forget_bias
        super(DynamicMultiRNN, self).__init__(**kwargs)

    def build(self, input_shape):
        # Create a trainable weight variable for this layer.
        input_seq_shape = input_shape[0]
        if self.num_units is None:
            self.num_units = input_seq_shape.as_list()[-1]
        if self.rnn_type == "LSTM":
            single_cell = tf.nn.rnn_cell.BasicLSTMCell(self.num_units, forget_bias=self.forget_bias)
        elif self.rnn_type == "GRU":
            single_cell = tf.nn.rnn_cell.GRUCell(self.num_units)
        else:
            raise ValueError("Unknown unit type %s!" % self.rnn_type)
        dropout = self.dropout if tf.keras.backend.learning_phase() == 1 else 0
        single_cell = tf.nn.rnn_cell.DropoutWrapper(cell=single_cell, input_keep_prob=(1.0 - dropout))
        cell_list = []
        for i in range(self.num_layers):
            residual = (i >= self.num_layers - self.num_residual_layers)
            if residual:
                single_cell_residual = tf.nn.rnn_cell.ResidualWrapper(single_cell)
                cell_list.append(single_cell_residual)
            else:
                cell_list.append(single_cell)
        if len(cell_list) == 1:
            self.final_cell = cell_list[0]
        else:
            self.final_cell = tf.nn.rnn_cell.MultiRNNCell(cell_list)
        super(DynamicMultiRNN, self).build(input_shape)

    def call(self, input_list, mask=None, training=None):
        rnn_input, sequence_length = input_list
        with tf.name_scope("rnn"), tf.variable_scope("rnn", reuse=tf.AUTO_REUSE):
            rnn_output, hidden_state = tf.nn.dynamic_rnn(self.final_cell, inputs=rnn_input, sequence_length=tf.squeeze(sequence_length),
                                                         dtype=tf.float32, scope=self.name)
        if self.return_sequence:
            return rnn_output
        else:
            return tf.expand_dims(hidden_state, axis=1)

    def compute_output_shape(self, input_shape):
        rnn_input_shape = input_shape[0]
        if self.return_sequence:
            return rnn_input_shape
        else:
            return (None, 1, rnn_input_shape[2])

    def get_config(self, ):
        config = {'num_units': self.num_units, 'rnn_type': self.rnn_type, 'return_sequence':self.return_sequence, 'num_layers': self.num_layers,
                  'num_residual_layers': self.num_residual_layers, 'dropout_rate': self.dropout}
        base_config = super(DynamicMultiRNN, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))