In [3]:
import numpy as np
import tensorflow as tf

In [10]:
from tensorflow.python.keras.layers import Layer, Dense, LSTM

class AttentionLSTM(Layer):
    """Encode each (sample, stock) time series into one attention-pooled vector.

    Every series is projected by a tanh ``Dense`` layer, encoded by an LSTM that
    returns all hidden states, and pooled with dot-product attention whose query
    is the last hidden state.

    Args:
        units: Hidden size of the LSTM, i.e. the size of the returned context
            vectors.
        ft_dim: Output size of the feature-transformation ``Dense`` layer.
        **kwargs: Forwarded to ``Layer`` (e.g. ``name``, ``dtype``).
    """

    def __init__(self, units, ft_dim, **kwargs):
        super(AttentionLSTM, self).__init__(**kwargs)
        self.units = units
        self.ft_dim = ft_dim

    def build(self, input_shape):
        # Sub-layers infer their own input shapes on first call. Passing this
        # layer's 4-D ``input_shape`` to Dense is only meaningful for the first
        # layer of a Sequential model, so it is deliberately omitted.
        # NOTE(review): a glorot_uniform *bias* initializer is unusual (the Keras
        # default is zeros); kept to preserve the original initialization.
        self.feature_trans_weight = Dense(self.ft_dim,
                                          kernel_initializer='glorot_uniform',
                                          bias_initializer='glorot_uniform',
                                          activation='tanh')
        self.lstm = LSTM(self.units, return_sequences=True, return_state=False)
        super(AttentionLSTM, self).build(input_shape)

    def call(self, inputs, *args, **kwargs):
        """Return one context vector per (sample, stock).

        Args:
            inputs: Tensor of shape (B, N, T, V). B is the batch size, N the
                number of stocks, T the sequence length and V the number of
                features.

        Returns:
            Tensor of shape (B, N, units).
        """
        # Use the dynamic shape for B, N and T so the layer also works when they
        # are unknown (None) at graph-construction time, e.g. in a functional model.
        dyn_shape = tf.shape(inputs)
        batch, n_stocks, seq_len = dyn_shape[0], dyn_shape[1], dyn_shape[2]
        n_features = inputs.shape[3]  # Dense needs a statically known last dim

        feature = tf.reshape(inputs, (-1, seq_len, n_features))  # (B*N, T, V)
        feature = self.feature_trans_weight(feature)  # (B*N, T, ft_dim)
        feature = self.lstm(feature)  # (B*N, T, units)

        # Score every hidden state against the last one (dot-product attention).
        query = feature[:, -1:, :]  # (B*N, 1, units)
        attn = tf.matmul(query, feature, transpose_b=True)  # (B*N, 1, T)
        attn = tf.nn.softmax(attn, axis=-1)  # (B*N, 1, T)
        attn = tf.reshape(attn, (batch, n_stocks, seq_len, 1))  # (B, N, T, 1)

        feature = tf.reshape(feature, (batch, n_stocks, seq_len, self.units))  # (B, N, T, units)
        context = tf.math.reduce_sum(attn * feature, axis=2)  # (B, N, units)
        return context

class ContextNormalize(Layer):
    """Standardise the input tensor, then apply a learned element-wise affine map.

    Args:
        epsilon: Small constant added to the variance before the square root so
            that a constant input does not produce a division by zero (NaN/Inf).
        **kwargs: Forwarded to ``Layer`` (e.g. ``name``, ``dtype``).
    """

    def __init__(self, epsilon=1e-6, **kwargs):
        super(ContextNormalize, self).__init__(**kwargs)
        self.epsilon = epsilon

    def build(self, input_shape):
        # The scale/shift parameters exclude the batch axis. Including it, as in
        # a plain ``shape=input_shape``, would tie the weights to one fixed batch
        # size (failing for an unknown batch dim or a smaller final batch) and
        # learn a meaningless per-batch-position scale.
        param_shape = tf.TensorShape(input_shape)[1:]
        # NOTE(review): 'uniform' (small random values) is an unusual init for a
        # scale/shift pair; ones/zeros is the common choice. Kept unchanged.
        self.norm_weight = self.add_weight(name='norm_weight',
                                           shape=param_shape,
                                           initializer='uniform',
                                           trainable=True)
        self.norm_bias = self.add_weight(name='norm_bias',
                                         shape=param_shape,
                                         initializer='uniform',
                                         trainable=True)
        super(ContextNormalize, self).build(input_shape)

    def call(self, context, *args, **kwargs):
        """Return ``norm_weight * standardised(context) + norm_bias``.

        Args:
            context: Input tensor; presumably (B, N, units) from AttentionLSTM.
                TODO confirm against the caller.

        Returns:
            Tensor with the same shape as ``context``.
        """
        # NOTE(review): mean/variance are reduced over the whole tensor, batch
        # axis included, so each sample's output depends on the other samples
        # in the batch. Kept as in the original; confirm this is intended.
        mean = tf.math.reduce_mean(context)
        variance = tf.math.reduce_variance(context)
        context = (context - mean) / tf.math.sqrt(variance + self.epsilon)
        context = self.norm_weight * context + self.norm_bias
        return context

In [11]:
# Smoke test: run a random batch shaped (B=32, N=50, T=128, V=7) through the layer.
x = np.random.rand(32, 50, 128, 7)
attn = AttentionLSTM(units=72, ft_dim=64)(x)

inputs: (32, 50, 128, 7)
feature_trans: (1600, 128, 64)
lstm: (1600, 128, 72)
attn: (1600, 1, 128)
attn_soft: (1600, 1, 128)
attn_reshape: (32, 50, 128, 1)
feature_reshape: (32, 50, 128, 72)
context: (32, 50, 72)
