## Time Interval Based Transformer Model

The original Transformer based recommender does not take into account of the time interval between two successive interactions. The paper Time Interval Aware Self-Attention for Sequential Recommendation, Jiacheng Li, Yujie Wang, Julian McAuley, WSDM, 2020 introduced the logic of including the time information. 

The original Git repo with TF 1.x is https://github.com/JiachengLi1995/TiSASRec

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys
import json
import re
import random
import copy
from tqdm import tqdm
import pandas as pd
import numpy as np
import pickle

from collections import defaultdict, Counter

sys.path.insert(0, "/recsys_data/RecSys/SASRec-tf2/")

import download_and_process_amazon as dpa

ModuleNotFoundError: No module named 'retrying'

In [9]:
data_dir = "/recsys_data/RecSys/SASRec-tf2/data"
meta_filename = 'meta_Electronics.json'
encoded_file = "ae_v3.txt"

# 5-core
category = "Electronics"
download_url = f"http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/reviews_{category}_5.json.gz"
reviews_name = f"reviews_{category}_5.json"
reviews_file = os.path.join(data_dir, reviews_name)

In [4]:
print(f"Generating data for ***{category}***")
dpa.download_and_extract(reviews_name, reviews_file)

Generating data for ***Electronics***


'/recsys_data/RecSys/SASRec-tf2/data/reviews_Electronics_5.json'

In [5]:
if not os.path.exists(reviews_file + '_output'):
    reviews_output_file = dpa._reviews_preprocessing(reviews_file)

start reviews preprocessing...
Processed data in /recsys_data/RecSys/SASRec-tf2/data/reviews_Electronics_5.json_output


In [20]:
def data_process_with_time(fname, pname, K=3, sep="\t", file_write=False, add_time=False):
    User = defaultdict(list)
    Items = set()
    user_dict, item_dict = {}, {}

    with open(fname, 'r') as fr:
        for line in fr:
            u, i, t = line.rstrip().split(sep)
            User[u].append((i, t))
            Items.add(i)
    
    print(len(User), len(Items))
    item_count = 1
    for item in Items:
        item_dict[item] = item_count
        item_count += 1

    count_del = 0
    user_count = 1
    if file_write:
        print(f"Writing data in {pname}")
        with open(pname, 'w') as fw:
            for user in User.keys():
                if len(User[user]) < K:
#                     del User[user]
                    count_del += 1
                else:
                    # user_dict[user] = user_count
                    items = sorted(User[user], key=lambda x: x[1])
                    timestamps = [x[1] for x in items]
                    items = [item_dict[x[0]] for x in items]
                    for i, t in zip(items, timestamps):
                        out_txt = [str(user_count), str(i)]
                        if add_time:
                            out_txt.append(str(t))
                        fw.write(sep.join(out_txt) + "\n")
                    user_dict[user] = user_count
                    user_count += 1
    else:
        for user in User.keys():
            if len(User[user]) < K:
                # del User[user]
                count_del += 1
            else:
                User[user] = sorted(User[user], key=lambda x: x[1])
                user_dict[user] = user_count
                user_count += 1
        
    print(user_count-1, count_del)
    return user_dict, item_dict, User

In [12]:
udict, idict = dpa.data_process_with_time(reviews_output_file,
                                      os.path.join(data_dir, encoded_file),
                                      K=5,
                                      sep="\t",
                                      item_set=None,
                                      add_time=True)
len(udict), len(idict)

Read 192403 users and 63001 items
Total 192403 users and 63001 items
Total 192403 users, 0 removed
Processed model input data in /recsys_data/RecSys/SASRec-tf2/data/ae_v3.txt


In [15]:
udict, idict = dpa.data_process_with_time(os.path.join(data_dir, "ae_original.txt"),
                                          os.path.join(data_dir, encoded_file),
                                          K=3,
                                          sep="\t",
                                          item_set=None,
                                          add_time=True)
len(udict), len(idict)

Read 63161 users and 85930 items
27773 items have less than 3 interactions
47 users have less than 3 interactions
Total 63114 users and 58157 items
Total 63073 users, 41 removed
Processed model input data in /recsys_data/RecSys/SASRec-tf2/data/ae_v3.txt


(63073, 58157)

In [21]:
udict, idict, user_history = data_process_with_time(os.path.join(data_dir, "ae_original.txt"),
                                                    os.path.join(data_dir, encoded_file),
                                                    K=3,
                                                    sep="\t",
                                                    file_write=True,
                                                    add_time=True)
print(f"Retained {len(udict)} users with {len(idict)} items from {len(user_history)} users")

63161 85930
Writing data in /recsys_data/RecSys/SASRec-tf2/data/ae_v3.txt
63114 47
Retained 63114 users with 85930 items from 63161 users


In [24]:
User = dpa.data_partition(os.path.join(data_dir, "ae_v3.txt"))

Preparing data...


In [26]:
User[1]

[[71865, 1276992000.0],
 [73699, 1295568000.0],
 [76752, 1305504000.0],
 [70038, 1339632000.0],
 [52031, 1369440000.0],
 [5655, 1370736000.0],
 [67712, 1370736000.0],
 [36497, 1382659200.0],
 [54084, 1390176000.0],
 [76563, 1390176000.0],
 [58972, 1390176000.0],
 [26213, 1390176000.0],
 [62645, 1390176000.0],
 [39023, 1392076800.0],
 [49569, 1393113600.0],
 [11443, 1395187200.0],
 [83584, 1395878400.0],
 [38275, 1403740800.0]]

In [29]:
[user_train, user_valid, user_test, usernum, itemnum, timenum] = dpa.data_partition(os.path.join(data_dir, "ae_v3.txt"))

Preparing data...
Preparing done...


In [30]:
user_train[1]

[[33306, 1],
 [34165, 28],
 [35596, 42],
 [32497, 92],
 [24147, 135],
 [2640, 137],
 [31413, 137],
 [16854, 154],
 [25087, 165],
 [35505, 165],
 [27370, 165],
 [12196, 165],
 [29087, 165],
 [18026, 167],
 [23003, 169],
 [5337, 172]]

In [31]:
1295568000 - 1276992000

18576000

In [3]:
import tensorflow as tf


In [4]:
n_timesteps, n_features = 96, 6
in1 = tf.keras.Input(shape=(n_timesteps, n_features))
conv1 = tf.keras.layers.Conv1D(2, 2, strides=1)(in1)
model = tf.keras.Model(inputs=in1, outputs=conv1)

In [5]:
model.summary()

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 96, 6)]           0         
_________________________________________________________________
conv1d (Conv1D)              (None, 95, 2)             26        
Total params: 26
Trainable params: 26
Non-trainable params: 0
_________________________________________________________________


In [6]:
model.weights

[<tf.Variable 'conv1d/kernel:0' shape=(2, 6, 2) dtype=float32, numpy=
 array([[[ 0.30539536,  0.32176018],
         [-0.24371243,  0.42594963],
         [-0.28002533, -0.2270734 ],
         [-0.21885604, -0.07609427],
         [ 0.14368159, -0.5911813 ],
         [ 0.51014   ,  0.13580382]],
 
        [[ 0.27753323, -0.5951712 ],
         [-0.29096928,  0.05894744],
         [-0.31501025,  0.29697037],
         [ 0.08256078,  0.20293069],
         [-0.45090458, -0.5907528 ],
         [ 0.04409683,  0.6015304 ]]], dtype=float32)>,
 <tf.Variable 'conv1d/bias:0' shape=(2,) dtype=float32, numpy=array([0., 0.], dtype=float32)>]

In [None]:
w = w - grad * lr
grad = average(grad_i), i = 1.,, batch_size

In [60]:
class myLSTMCell(tf.keras.layers.Layer):
    def __init__(self, out_dim):
        super(myLSTMCell, self).__init__()
        self.big_matrix = tf.keras.layers.Dense(
                            units=4*out_dim, activation=None, use_bias=True,
                            kernel_initializer='glorot_uniform',
                            bias_initializer='zeros', kernel_regularizer=None,
                            bias_regularizer=None, activity_regularizer=None, kernel_constraint=None,
                            bias_constraint=None
                        )
        self.sigmoid = tf.math.sigmoid
        self.phi = tf.math.tanh
        
    def call(self, x, states):
        h, c = states
        all_y = self.big_matrix(tf.concat([x, h], axis=-1))
        ys = tf.split(all_y, 4, axis=-1)
        fk = self.sigmoid(ys[0])
        ik = self.sigmoid(ys[1])
        ck = tf.math.multiply(fk, c) + tf.math.multiply(ik, self.phi(ys[2]))
        ok = self.sigmoid(ys[3])
        hk = tf.math.multiply(ok, self.phi(ck))
        return ok, (hk, ck)
    
class myLSTM(tf.keras.layers.Layer):
    def __init__(self, out_dim, return_sequences=False, return_state=False):
        super(myLSTM, self).__init__()
        self.cell = myLSTMCell(out_dim)
        self.out_dim = out_dim
        self.return_sequences = return_sequences
        self.return_state = return_state
        
    def call(self, x):
        b, s, h = x.shape
        prev_state = (tf.random.normal([b, h], 0, 1, tf.float32), 
                      tf.random.normal([b, h], 0, 1, tf.float32))

        if self.return_sequences:
            all_ys =[]
            if self.return_state:
                all_hs = []
                all_cs = []
                
        for ii in range(s):
            x_i = x[:, ii, :]
            y_i, new_state = self.cell(x_i, prev_state)
            if self.return_sequences:
                all_ys.append(y_i)
                if self.return_state:
                    all_hs.append(new_state[0])
                    all_cs.append(new_state[1])
            prev_state = new_state
            
        if self.return_sequences:
            if self.return_state:
                return all_ys, (all_hs, all_cs)
            else:
                return all_ys
        else:
            if self.return_state:
                return y_i, new_state
            else:
                return y_i
    
    

In [61]:
in_dim, out_dim, seq_len = 32, 32, 10
batch_dim = 8

In [64]:
tf.random.set_seed(5)
x = tf.random.normal([batch_dim, seq_len, in_dim], 0, 1, tf.float32)
lstm = myLSTM(out_dim)
y = lstm(x)

In [65]:
y

<tf.Tensor: shape=(8, 32), dtype=float32, numpy=
array([[0.5599368 , 0.23224984, 0.49254808, 0.5926059 , 0.5832714 ,
        0.5794184 , 0.5536603 , 0.44985098, 0.3392653 , 0.4666337 ,
        0.5440046 , 0.51614356, 0.52931595, 0.48723206, 0.45453623,
        0.5113864 , 0.60725   , 0.5892867 , 0.46545208, 0.5244474 ,
        0.5753214 , 0.4265751 , 0.6043889 , 0.6281752 , 0.47321272,
        0.34909117, 0.34366786, 0.65943074, 0.3755182 , 0.27139494,
        0.6502329 , 0.47100565],
       [0.30076072, 0.63117933, 0.38542297, 0.43152928, 0.64638555,
        0.49973753, 0.66102844, 0.52569866, 0.42737994, 0.51383024,
        0.3163207 , 0.5394806 , 0.35988444, 0.76918477, 0.33002552,
        0.44780493, 0.41091138, 0.5326197 , 0.33237004, 0.5004793 ,
        0.40655237, 0.68841195, 0.3764247 , 0.30550352, 0.7358771 ,
        0.6590759 , 0.37678885, 0.45530462, 0.5368906 , 0.3561387 ,
        0.49229378, 0.3945621 ],
       [0.38124165, 0.6049407 , 0.6947774 , 0.62442875, 0.5124052 ,
 

In [66]:
lstm2 = tf.keras.layers.LSTM(out_dim)
output = lstm2(x)
output

<tf.Tensor: shape=(8, 32), dtype=float32, numpy=
array([[ 0.05627489,  0.0765749 , -0.17148174,  0.04548818,  0.03238086,
         0.09054439,  0.29433888,  0.02722451, -0.10929022, -0.1042285 ,
         0.10032048, -0.08055569,  0.182501  ,  0.03270156, -0.05111875,
         0.15232173, -0.1829652 , -0.19298   ,  0.00699575, -0.2746634 ,
         0.3781533 ,  0.10242426,  0.2531554 , -0.37442407,  0.16805875,
         0.00163681,  0.15103143, -0.19702393, -0.10702232,  0.02750298,
         0.01500319, -0.28751376],
       [-0.08968791,  0.05153225, -0.14298752, -0.21957736, -0.10381615,
        -0.12201183,  0.24632297,  0.20631096,  0.18002822, -0.22806312,
        -0.22915427, -0.18283159,  0.20275609, -0.06353931,  0.22910093,
        -0.12281308, -0.24024928, -0.09029342, -0.2840215 , -0.06294989,
         0.04115437,  0.29820126, -0.37069237,  0.04251494,  0.26545495,
         0.22440068, -0.01250129, -0.25229958, -0.11035067,  0.33077276,
         0.03193088, -0.36009035],
     

In [None]:
inputs = tf.random.normal([32, 10, 8])
lstm = tf.keras.layers.LSTM(4)
output = lstm(inputs)
print(output.shape)

lstm = tf.keras.layers.LSTM(4, return_sequences=True, return_state=True)
whole_seq_output, final_memory_state, final_carry_state = lstm(inputs)
print(whole_seq_output.shape)

In [22]:
import numpy as np
import tensorflow as tf

from keras import activations
from keras import backend
from keras import constraints
from keras import initializers
from keras import regularizers
from keras.engine.base_layer import Layer
# from keras.engine.input_spec import InputSpec
# from keras.saving.saved_model import layer_serialization
# from keras.utils import control_flow_util
# from keras.utils import generic_utils
# from keras.utils import tf_utils
# from tensorflow.python.platform import tf_logging as logging
# from tensorflow.python.util.tf_export import keras_export
# from tensorflow.tools.docs import doc_controls

# Check the documentation for customized layer
# https://keras.io/api/layers/base_layer/
class LSTMCell(tf.keras.layers.Layer):
    def __init__(self,
               units,
               activation='tanh',
               recurrent_activation='hard_sigmoid',
               use_bias=True,
               kernel_initializer='glorot_uniform',
               recurrent_initializer='orthogonal',
               bias_initializer='zeros',
               unit_forget_bias=True,
               kernel_regularizer=None,
               recurrent_regularizer=None,
               bias_regularizer=None,
               kernel_constraint=None,
               recurrent_constraint=None,
               bias_constraint=None,
               dropout=0.,
               recurrent_dropout=0.,
               **kwargs):
        if units < 0:
            raise ValueError(f'Received an invalid value for units, expected '
                           f'a positive integer, got {units}.')
        # By default use cached variable under v2 mode, see b/143699808.
        if tf.compat.v1.executing_eagerly_outside_functions():
            self._enable_caching_device = kwargs.pop('enable_caching_device', True)
        else:
            self._enable_caching_device = kwargs.pop('enable_caching_device', False)
        super(LSTMCell, self).__init__(**kwargs)
        self.units = units
        self.activation = activations.get(activation)
        self.recurrent_activation = activations.get(recurrent_activation)
        self.use_bias = use_bias

        self.kernel_initializer = initializers.get(kernel_initializer)
        self.recurrent_initializer = initializers.get(recurrent_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.unit_forget_bias = unit_forget_bias

        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)

        self.kernel_constraint = constraints.get(kernel_constraint)
        self.recurrent_constraint = constraints.get(recurrent_constraint)
        self.bias_constraint = constraints.get(bias_constraint)

        self.dropout = min(1., max(0., dropout))
        self.recurrent_dropout = min(1., max(0., recurrent_dropout))
        implementation = kwargs.pop('implementation', 1)
        if self.recurrent_dropout != 0 and implementation != 1:
            logging.debug(RECURRENT_DROPOUT_WARNING_MSG)
            self.implementation = 1
        else:
            self.implementation = implementation
        self.state_size = [self.units, self.units]
        self.output_size = self.units

    def build(self, input_shape):
        """
        Required to add weights that are input shape dependent
        """
        default_caching_device = _caching_device(self)
        input_dim = input_shape[-1]
        self.kernel = self.add_weight(
            shape=(input_dim, self.units * 4),
            name='kernel',
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
            caching_device=default_caching_device)
        self.recurrent_kernel = self.add_weight(
            shape=(self.units, self.units * 4),
            name='recurrent_kernel',
            initializer=self.recurrent_initializer,
            regularizer=self.recurrent_regularizer,
            constraint=self.recurrent_constraint,
            caching_device=default_caching_device)
        if self.use_bias:
            if self.unit_forget_bias:

                def bias_initializer(_, *args, **kwargs):
                    return backend.concatenate([
                      self.bias_initializer((self.units,), *args, **kwargs),
                      initializers.get('ones')((self.units,), *args, **kwargs),
                      self.bias_initializer((self.units * 2,), *args, **kwargs),
                  ])
            else:
                bias_initializer = self.bias_initializer
            self.bias = self.add_weight(
                          shape=(self.units * 4,),
                          name='bias',
                          initializer=bias_initializer,
                          regularizer=self.bias_regularizer,
                          constraint=self.bias_constraint,
                          caching_device=default_caching_device)
        else:
            self.bias = None
        self.built = True

        
        
    def call(self, inputs, states, training=None):
        h_tm1 = states[0]  # previous memory state
        c_tm1 = states[1]  # previous carry state

#         dp_mask = self.get_dropout_mask_for_cell(inputs, training, count=4)
#         rec_dp_mask = self.get_recurrent_dropout_mask_for_cell(
#             h_tm1, training, count=4)

        if self.implementation == 1:
            if 0 < self.dropout < 1.:
                inputs_i = inputs * dp_mask[0]
                inputs_f = inputs * dp_mask[1]
                inputs_c = inputs * dp_mask[2]
                inputs_o = inputs * dp_mask[3]
            else:
                inputs_i = inputs
                inputs_f = inputs
                inputs_c = inputs
                inputs_o = inputs
            k_i, k_f, k_c, k_o = tf.split(
              self.kernel, num_or_size_splits=4, axis=1)
            x_i = backend.dot(inputs_i, k_i)
            x_f = backend.dot(inputs_f, k_f)
            x_c = backend.dot(inputs_c, k_c)
            x_o = backend.dot(inputs_o, k_o)
            if self.use_bias:
                b_i, b_f, b_c, b_o = tf.split(
                    self.bias, num_or_size_splits=4, axis=0)
                x_i = backend.bias_add(x_i, b_i)
                x_f = backend.bias_add(x_f, b_f)
                x_c = backend.bias_add(x_c, b_c)
                x_o = backend.bias_add(x_o, b_o)

            if 0 < self.recurrent_dropout < 1.:
                h_tm1_i = h_tm1 * rec_dp_mask[0]
                h_tm1_f = h_tm1 * rec_dp_mask[1]
                h_tm1_c = h_tm1 * rec_dp_mask[2]
                h_tm1_o = h_tm1 * rec_dp_mask[3]
            else:
                h_tm1_i = h_tm1
                h_tm1_f = h_tm1
                h_tm1_c = h_tm1
                h_tm1_o = h_tm1
            x = (x_i, x_f, x_c, x_o)
            h_tm1 = (h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o)
            c, o = self._compute_carry_and_output(x, h_tm1, c_tm1)
        else:
            if 0. < self.dropout < 1.:
                inputs = inputs * dp_mask[0]
            z = backend.dot(inputs, self.kernel)
            z += backend.dot(h_tm1, self.recurrent_kernel)
            if self.use_bias:
                z = backend.bias_add(z, self.bias)

            z = tf.split(z, num_or_size_splits=4, axis=1)
            c, o = self._compute_carry_and_output_fused(z, c_tm1)

        h = o * self.activation(c)
        return h, [h, c]

    def _compute_carry_and_output(self, x, h_tm1, c_tm1):
        """Computes carry and output using split kernels."""
        x_i, x_f, x_c, x_o = x
        h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o = h_tm1
        i = self.recurrent_activation(
            x_i + backend.dot(h_tm1_i, self.recurrent_kernel[:, :self.units]))
        f = self.recurrent_activation(x_f + backend.dot(
            h_tm1_f, self.recurrent_kernel[:, self.units:self.units * 2]))
        c = f * c_tm1 + i * self.activation(x_c + backend.dot(
            h_tm1_c, self.recurrent_kernel[:, self.units * 2:self.units * 3]))
        o = self.recurrent_activation(
            x_o + backend.dot(h_tm1_o, self.recurrent_kernel[:, self.units * 3:]))
        return c, o

    
def _caching_device(rnn_cell):
    """Returns the caching device for the RNN variable.
    This is useful for distributed training, when variable is not located as same
    device as the training worker. By enabling the device cache, this allows
    worker to read the variable once and cache locally, rather than read it every
    time step from remote when it is needed.
    Note that this is assuming the variable that cell needs for each time step is
    having the same value in the forward path, and only gets updated in the
    backprop. It is true for all the default cells (SimpleRNN, GRU, LSTM). If the
    cell body relies on any variable that gets updated every time step, then
    caching device will cause it to read the stall value.
    Args:
    rnn_cell: the rnn cell instance.
    """
    if tf.executing_eagerly():
        # caching_device is not supported in eager mode.
        return None
    if not getattr(rnn_cell, '_enable_caching_device', False):
        return None
    # Don't set a caching device when running in a loop, since it is possible that
    # train steps could be wrapped in a tf.while_loop. In that scenario caching
    # prevents forward computations in loop iterations from re-reading the
    # updated weights.
    if control_flow_util.IsInWhileLoop(tf.compat.v1.get_default_graph()):
        logging.warning(
            'Variable read device caching has been disabled because the '
            'RNN is in tf.while_loop loop context, which will cause '
            'reading stalled value in forward path. This could slow down '
            'the training due to duplicated variable reads. Please '
            'consider updating your code to remove tf.while_loop if possible.')
        return None
    if (rnn_cell._dtype_policy.compute_dtype !=
        rnn_cell._dtype_policy.variable_dtype):
        logging.warning(
            'Variable read device caching has been disabled since it '
            'doesn\'t work with the mixed precision API. This is '
            'likely to cause a slowdown for RNN training due to '
            'duplicated read of variable for each timestep, which '
            'will be significant in a multi remote worker setting. '
            'Please consider disabling mixed precision API if '
            'the performance has been affected.')
        return None
    # Cache the value on the device that access the variable.
    return lambda op: op.device



### Check the Cell

In [28]:
batch_size = 3
sentence_max_length = 5
n_features = 4
out_features = 8
new_shape = (batch_size, sentence_max_length, n_features)
# x = tf.constant(np.reshape(np.arange(30), new_shape), dtype = tf.float32)

x = tf.constant(np.reshape(np.arange(batch_size * n_features), [batch_size, n_features]), dtype = tf.float32)
states = [tf.zeros([batch_size, out_features]), tf.zeros([batch_size, out_features])]

model = LSTMCell(out_features)
model(x, states)

(<tf.Tensor: shape=(3, 8), dtype=float32, numpy=
 array([[-0.10476895,  0.01909216,  0.04259723,  0.22393304, -0.06510702,
          0.0607237 , -0.10680267, -0.3796199 ],
        [ 0.        ,  0.03360127,  0.2594817 ,  0.06467079,  0.        ,
          0.58613276, -0.02267295, -0.760425  ],
        [ 0.        ,  0.        ,  0.21239248,  0.        ,  0.        ,
          0.7520532 , -0.        , -0.7615706 ]], dtype=float32)>,
 [<tf.Tensor: shape=(3, 8), dtype=float32, numpy=
  array([[-0.10476895,  0.01909216,  0.04259723,  0.22393304, -0.06510702,
           0.0607237 , -0.10680267, -0.3796199 ],
         [ 0.        ,  0.03360127,  0.2594817 ,  0.06467079,  0.        ,
           0.58613276, -0.02267295, -0.760425  ],
         [ 0.        ,  0.        ,  0.21239248,  0.        ,  0.        ,
           0.7520532 , -0.        , -0.7615706 ]], dtype=float32)>,
  <tf.Tensor: shape=(3, 8), dtype=float32, numpy=
  array([[-0.22255842,  0.0431564 ,  0.08070901,  0.68405485, -0.120381

### A Sequential Model Using the Cell

In [52]:
from tensorflow.keras.layers import RNN
batch_size = 3
seq_len = 5
inp_dim = 4
out_dim = 8

lstm_layer = tf.keras.layers.RNN(LSTMCell(out_features),
                                 input_shape=(None, n_features),
                                 return_sequences=True,
                                 return_state=False
                                )
model = tf.keras.models.Sequential(
        [
            lstm_layer,
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Dense(out_features),
        ]
    )

new_shape = (batch_size, seq_len, inp_dim)
x = tf.constant(np.reshape(np.arange(60), new_shape), dtype = tf.float32)
model(x)

<tf.Tensor: shape=(3, 5, 8), dtype=float32, numpy=
array([[[-0.3685529 ,  0.03548042, -0.24419327, -0.2727273 ,
         -0.0813837 ,  0.0989939 , -0.11930474, -0.20282263],
        [-0.7565586 , -0.05964788, -0.45405662, -0.51776266,
         -0.58944976,  0.14515461, -0.151285  , -0.52873456],
        [-0.7174803 ,  0.02329051, -0.41103694, -0.5624658 ,
         -0.72346914,  0.06634909, -0.11814035, -0.56606746],
        [-0.6727449 ,  0.15986285, -0.3403043 , -0.49626243,
         -0.65444446,  0.11354627, -0.08819287, -0.56675744],
        [-0.6417773 ,  0.24228156, -0.27740368, -0.4014975 ,
         -0.56822044,  0.1695401 , -0.09530976, -0.53987384]],

       [[-0.5910057 ,  0.26778895, -0.21793066, -0.30680513,
         -0.5112453 ,  0.23893647, -0.0790079 , -0.5313454 ],
        [-0.56336236,  0.31183356, -0.18396243, -0.2624165 ,
         -0.5085331 ,  0.27645475, -0.05093649, -0.5521839 ],
        [-0.54038346,  0.332192  , -0.15500873, -0.20631021,
         -0.48263872,  0.

### A General Model Using the Cell

In [58]:
class spLSTM(tf.keras.Model):

    def __init__(self, **kwargs):
        super(spLSTM, self).__init__()

        self.input_dim = kwargs.get("input_dim", None)
        self.out_features = kwargs.get("out_dim", None)
        
        self.cell = LSTMCell(out_features)
        self.lstm_layer = tf.keras.layers.RNN(self.cell,
                                              input_shape=(None, self.input_dim),
                                              return_sequences=True,
                                              return_state=False
                                             )
        self.bn = tf.keras.layers.BatchNormalization()
        self.linear = tf.keras.layers.Dense(self.out_features)
        
    def call(self, x, training=None):
        y = self.lstm_layer(x)
        y = self.bn(y)
        y = self.linear(y)
        return y


new_shape = (batch_size, seq_len, inp_dim)
x = tf.constant(np.reshape(np.arange(60), new_shape), dtype = tf.float32)
model = spLSTM(input_dim=inp_dim, out_dim=out_dim)
model(x)

<tf.Tensor: shape=(3, 5, 8), dtype=float32, numpy=
array([[[ 0.05392185,  0.11244936,  0.0817205 , -0.2527224 ,
          0.09619954,  0.00545769, -0.20760998, -0.1781682 ],
        [-0.01148473,  0.26539844,  0.08757333, -0.3929581 ,
          0.12792118,  0.12631851, -0.19791119, -0.25857073],
        [-0.03495827,  0.25147653,  0.08493275, -0.41013473,
          0.09135444,  0.10869552, -0.17762847, -0.24655099],
        [-0.04157254,  0.2296985 ,  0.07936808, -0.39513916,
          0.0662499 ,  0.09229817, -0.15479581, -0.22584179],
        [-0.04791046,  0.20781992,  0.07372073, -0.37942111,
          0.04157673,  0.07604742, -0.13209486, -0.20501654]],

       [[-0.05181512,  0.1892882 ,  0.06866545, -0.362992  ,
          0.02328804,  0.06334292, -0.11399756, -0.18727951],
        [-0.0605114 ,  0.16408485,  0.06241901, -0.3478551 ,
         -0.00762257,  0.04361196, -0.08676881, -0.16338247],
        [-0.0668653 ,  0.1421248 ,  0.05674921, -0.33206227,
         -0.03237402,  0.

## Time-Aware LSTM Cell

Based on Yu et al., 2019, Adaptive User Modeling with Long and Short Term Preferences for Personalized Recommendation. 

Original LSTM equations are modified as follows: 

$$c_k = f_k \odot c_{k-1} + i_k \odot \phi(x_kW_c + h_{k-1}U_c + b_c)$$
$$c_k = f_k \odot T_\delta \odot c_{k-1} + i_k \odot T_s \odot \phi(x_kW_c + h_{k-1}U_c + b_c)$$

$$o_k = \sigma \left(x_kW_o + h_{k-1}U_o + b_o \right)$$
$$o_k = \sigma \left(x_kW_o + \underbrace{\delta_{tk}W_{\delta o} + s_{tk}W_{so}}  + h_{k-1}U_o + b_o \right)$$

where 

$$\delta_{tk} = \phi\left( W_{\delta}\log(t_k - t_{k-1}) + b_{\delta} \right) $$
$$s_{tk} = \phi\left( W_s\log(t_p - t_k) + b_s \right) $$
$$T_\delta = \sigma \left( x_kW_{x\delta} + \delta_{tk}W_{t\delta} + b_{t\delta}\right)$$
$$T_s = \sigma \left( x_kW_{xs} + s_{tk}W_{ts} + b_{ts}\right)$$


In [99]:
class TALSTMCell(tf.keras.layers.Layer):
    """
    Time Aware LSTM Cell
    """
    def __init__(self,
               units,
               activation='tanh',
               recurrent_activation='hard_sigmoid',
               use_bias=True,
               kernel_initializer='glorot_uniform',
               recurrent_initializer='orthogonal',
               bias_initializer='zeros',
               unit_forget_bias=True,
               kernel_regularizer=None,
               recurrent_regularizer=None,
               bias_regularizer=None,
               kernel_constraint=None,
               recurrent_constraint=None,
               bias_constraint=None,
               dropout=0.,
               recurrent_dropout=0.,
               **kwargs):
        if units < 0:
            raise ValueError(f'Received an invalid value for units, expected '
                           f'a positive integer, got {units}.')
        # By default use cached variable under v2 mode, see b/143699808.
        if tf.compat.v1.executing_eagerly_outside_functions():
            self._enable_caching_device = kwargs.pop('enable_caching_device', True)
        else:
            self._enable_caching_device = kwargs.pop('enable_caching_device', False)
        super(TALSTMCell, self).__init__(**kwargs)
        self.units = units
        self.activation = activations.get(activation)
        self.recurrent_activation = activations.get(recurrent_activation)
        self.use_bias = use_bias

        self.kernel_initializer = initializers.get(kernel_initializer)
        self.recurrent_initializer = initializers.get(recurrent_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.unit_forget_bias = unit_forget_bias

        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)

        self.kernel_constraint = constraints.get(kernel_constraint)
        self.recurrent_constraint = constraints.get(recurrent_constraint)
        self.bias_constraint = constraints.get(bias_constraint)

        self.dropout = min(1., max(0., dropout))
        self.recurrent_dropout = min(1., max(0., recurrent_dropout))
        implementation = kwargs.pop('implementation', 1)
        if self.recurrent_dropout != 0 and implementation != 1:
            logging.debug(RECURRENT_DROPOUT_WARNING_MSG)
            self.implementation = 1
        else:
            self.implementation = implementation
        self.state_size = [self.units, self.units]
        self.output_size = self.units

    def build(self, input_shape):
        """
        Required to add weights that are input shape dependent
        """
        default_caching_device = _caching_device(self)
        input_dim = input_shape[-1]-2
        self.kernel = self.add_weight(
            shape=(input_dim, self.units * 4),
            name='kernel',
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
            caching_device=default_caching_device)
        self.recurrent_kernel = self.add_weight(
            shape=(self.units, self.units * 4),
            name='recurrent_kernel',
            initializer=self.recurrent_initializer,
            regularizer=self.recurrent_regularizer,
            constraint=self.recurrent_constraint,
            caching_device=default_caching_device)
        # time related
        # W_xdelta, W_tdelta, W_xs and W_ts, Eqs. (9) & (10)
        self.time_kernel = self.add_weight(
            shape=(input_dim, self.units * 4),
            name='time_kernel',
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
            caching_device=default_caching_device    
        )
        
        # W_delta and W_s for Eqs.(7) & (8)
        self.time_kernel2 = self.add_weight(
            shape=(2, input_dim),
            name='time_kernel2',
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
            caching_device=default_caching_device    
        )
        
        # W_delta,o and W_so for Eq.(12)
        self.time_kernel3 = self.add_weight(
            shape=(input_dim, self.units * 2),
            name='time_kernel3',
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
            caching_device=default_caching_device    
        )
        if self.use_bias:
            if self.unit_forget_bias:

                def bias_initializer(_, *args, **kwargs):
                    return backend.concatenate([
                      self.bias_initializer((self.units,), *args, **kwargs),
                      initializers.get('ones')((self.units,), *args, **kwargs),
                      self.bias_initializer((self.units * 2,), *args, **kwargs),
                  ])
            else:
                bias_initializer = self.bias_initializer
                
            self.bias = self.add_weight(
                          shape=(self.units * 4,),
                          name='bias',
                          initializer=bias_initializer,
                          regularizer=self.bias_regularizer,
                          constraint=self.bias_constraint,
                          caching_device=default_caching_device)
            # time related
            self.time_bias1 = self.add_weight(
                          shape=(input_dim * 2,),
                          name='time_bias1',
                          initializer=self.bias_initializer,
                          regularizer=self.bias_regularizer,
                          constraint=self.bias_constraint,
                          caching_device=default_caching_device)
            self.time_bias2 = self.add_weight(
                          shape=(self.units * 2,),
                          name='time_bias2',
                          initializer=self.bias_initializer,
                          regularizer=self.bias_regularizer,
                          constraint=self.bias_constraint,
                          caching_device=default_caching_device)
        else:
            self.bias = None
        self.built = True


    def call(self, inputs, states, training=None):
        h_tm1 = states[0]  # previous memory state
        c_tm1 = states[1]  # previous carry state

#         dp_mask = self.get_dropout_mask_for_cell(inputs, training, count=4)
#         rec_dp_mask = self.get_recurrent_dropout_mask_for_cell(
#             h_tm1, training, count=4)

        # take out the time features
        tf1 = tf.expand_dims(inputs[:, -1], -1)
        tf2 = tf.expand_dims(inputs[:, -2], -1)
        inputs = inputs[:, :-2]

        if self.implementation == 1:
            if 0 < self.dropout < 1.:
                inputs_i = inputs * dp_mask[0]
                inputs_f = inputs * dp_mask[1]
                inputs_c = inputs * dp_mask[2]
                inputs_o = inputs * dp_mask[3]
            else:
                inputs_i = inputs
                inputs_f = inputs
                inputs_c = inputs
                inputs_o = inputs
                inputs_d = inputs # delta
                inputs_s = inputs # s
            k_i, k_f, k_c, k_o = tf.split(
              self.kernel, num_or_size_splits=4, axis=1)
            W_xd, W_xs, W_td, W_ts = tf.split(
              self.time_kernel, num_or_size_splits=4, axis=1)
            W_d, W_s = tf.split(
              self.time_kernel2, num_or_size_splits=2, axis=0)
            x_i = backend.dot(inputs_i, k_i)
            x_f = backend.dot(inputs_f, k_f)
            x_c = backend.dot(inputs_c, k_c)
            x_o = backend.dot(inputs_o, k_o)
            # time related
            delta_tk = backend.dot(tf1, W_d)
            s_tk = backend.dot(tf2, W_s)
            T_d = backend.dot(inputs_d, W_xd)  # T_delta
            T_s = backend.dot(inputs_s, W_xs)  # T_s
            if self.use_bias:
                b_i, b_f, b_c, b_o = tf.split(
                    self.bias, num_or_size_splits=4, axis=0)
                
                x_i = backend.bias_add(x_i, b_i)
                x_f = backend.bias_add(x_f, b_f)
                x_c = backend.bias_add(x_c, b_c)
                x_o = backend.bias_add(x_o, b_o)

                b_d, b_s = tf.split(
                    self.time_bias1, num_or_size_splits=2, axis=0)
                delta_tk = backend.bias_add(delta_tk, b_d)
                s_tk = backend.bias_add(s_tk, b_s)                
                
                b_td, b_ts = tf.split(
                    self.time_bias2, num_or_size_splits=2, axis=0)    
                T_d = backend.bias_add(T_d, b_td)
                T_s = backend.bias_add(T_s, b_ts)

            delta_tk = self.activation(delta_tk)  # complete
            s_tk = self.activation(s_tk)  # complete

            T_d = T_d + backend.dot(delta_tk, W_td)
            T_s = T_s + backend.dot(s_tk, W_ts)
            T_delta = self.recurrent_activation(T_d)
            T_s = self.recurrent_activation(T_s)
            
            # Eq.(12)
            W_do, W_so = tf.split(
              self.time_kernel3, num_or_size_splits=2, axis=1)
            x_o = x_o + backend.dot(delta_tk, W_do)
            x_o = x_o + backend.dot(s_tk, W_so)

            if 0 < self.recurrent_dropout < 1.:
                h_tm1_i = h_tm1 * rec_dp_mask[0]
                h_tm1_f = h_tm1 * rec_dp_mask[1]
                h_tm1_c = h_tm1 * rec_dp_mask[2]
                h_tm1_o = h_tm1 * rec_dp_mask[3]
            else:
                h_tm1_i = h_tm1
                h_tm1_f = h_tm1
                h_tm1_c = h_tm1
                h_tm1_o = h_tm1
            x = (x_i, x_f, x_c, x_o)
            h_tm1 = (h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o)
            c, o = self._compute_carry_and_output(x, h_tm1, c_tm1, T_delta, T_s)
        else:
            if 0. < self.dropout < 1.:
                inputs = inputs * dp_mask[0]
            z = backend.dot(inputs, self.kernel)
            z += backend.dot(h_tm1, self.recurrent_kernel)
            if self.use_bias:
                z = backend.bias_add(z, self.bias)

            z = tf.split(z, num_or_size_splits=4, axis=1)
            c, o = self._compute_carry_and_output_fused(z, c_tm1)

        h = o * self.activation(c)
        return h, [h, c]

    def _compute_carry_and_output(self, x, h_tm1, c_tm1, T_delta, T_s):
        """Computes carry and output using split kernels."""
        x_i, x_f, x_c, x_o = x
        h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o = h_tm1
        i = self.recurrent_activation(
            x_i + backend.dot(h_tm1_i, self.recurrent_kernel[:, :self.units]))
        f = self.recurrent_activation(x_f + backend.dot(
            h_tm1_f, self.recurrent_kernel[:, self.units:self.units * 2]))
        c = f * T_delta * c_tm1 + i * T_s * self.activation(x_c + backend.dot(
            h_tm1_c, self.recurrent_kernel[:, self.units * 2:self.units * 3]))
        o = self.recurrent_activation(
            x_o + backend.dot(h_tm1_o, self.recurrent_kernel[:, self.units * 3:]))
        return c, o


## Check the Cell

In [100]:
batch_size = 3
sentence_max_length = 5
n_features = 4 + 2
out_features = 8
new_shape = (batch_size, sentence_max_length, n_features)
# x = tf.constant(np.reshape(np.arange(30), new_shape), dtype = tf.float32)

x = tf.constant(np.reshape(np.arange(batch_size * n_features), [batch_size, n_features]), dtype = tf.float32)
states = [tf.zeros([batch_size, out_features]), tf.zeros([batch_size, out_features])]

model = TALSTMCell(out_features)
model(x, states)

(<tf.Tensor: shape=(3, 8), dtype=float32, numpy=
 array([[ 0.01555797,  0.10276201,  0.00688872, -0.03902443, -0.00844985,
         -0.05784162,  0.04423954, -0.01630299],
        [ 0.05028533,  0.2974947 ,  0.        , -0.00615981,  0.        ,
          0.02751383, -0.02384676,  0.        ],
        [ 0.00584659,  0.27857697,  0.        , -0.        ,  0.        ,
          0.        ,  0.        ,  0.        ]], dtype=float32)>,
 [<tf.Tensor: shape=(3, 8), dtype=float32, numpy=
  array([[ 0.01555797,  0.10276201,  0.00688872, -0.03902443, -0.00844985,
          -0.05784162,  0.04423954, -0.01630299],
         [ 0.05028533,  0.2974947 ,  0.        , -0.00615981,  0.        ,
           0.02751383, -0.02384676,  0.        ],
         [ 0.00584659,  0.27857697,  0.        , -0.        ,  0.        ,
           0.        ,  0.        ,  0.        ]], dtype=float32)>,
  <tf.Tensor: shape=(3, 8), dtype=float32, numpy=
  array([[ 0.03604736,  0.1325443 ,  0.01069185, -0.08066487, -0.053534

In [101]:
x

<tf.Tensor: shape=(3, 6), dtype=float32, numpy=
array([[ 0.,  1.,  2.,  3.,  4.,  5.],
       [ 6.,  7.,  8.,  9., 10., 11.],
       [12., 13., 14., 15., 16., 17.]], dtype=float32)>

In [103]:
a, b, c = (1, 2, 3)

In [105]:
a, b, c

(1, 2, 3)