In [27]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
import numpy as np

In [62]:
INPUT_SIZE = 2
SEQ_LEN = 1
HIDDEN_SIZE = 2
BATCH_SIZE = 4

In [72]:
input_tensor = np.random.randn(BATCH_SIZE, SEQ_LEN, INPUT_SIZE).astype(np.float32)
kernel = np.random.randn(INPUT_SIZE, HIDDEN_SIZE * 4).astype(np.float32)
recurrent_kernel = np.random.randn(INPUT_SIZE, HIDDEN_SIZE * 4).astype(np.float32)
bias = np.random.randn(HIDDEN_SIZE * 4).astype(np.float32)

In [73]:
def tf_lstm(input_, kernel_, recurrent_kernel_, bias_):
    inputs = tf.keras.layers.Input(shape=(SEQ_LEN, INPUT_SIZE), dtype=tf.float32)
    lstm_layer = tf.keras.layers.LSTM(HIDDEN_SIZE, return_sequences=True)
    outputs = lstm_layer(inputs)
    model = tf.keras.Model(inputs, outputs)
    lstm_layer.set_weights([kernel_, recurrent_kernel_, bias_])
    return model(input_)

tf_output = tf_lstm(input_tensor, kernel, recurrent_kernel, bias)
print(tf_output)
    

tf.Tensor(
[[[-0.00261656  0.7161094 ]]

 [[-0.00569702  0.62866473]]

 [[-0.00151875  0.5827352 ]]

 [[-0.01549257  0.5073756 ]]], shape=(4, 1, 2), dtype=float32)


In [74]:
def scratch_lstm(input_, kernel_, recurrent_kernel_, bias_):
    hidden_state = np.zeros((BATCH_SIZE, HIDDEN_SIZE), dtype=np.float32)
    cell_state = np.zeros((BATCH_SIZE, HIDDEN_SIZE), dtype=np.float32)
    output = np.zeros((BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE), dtype=np.float32)
    w_i = kernel_[:, 0 * HIDDEN_SIZE: 1 * HIDDEN_SIZE]
    w_f = kernel_[:, 1 * HIDDEN_SIZE: 2 * HIDDEN_SIZE]
    w_c = kernel_[:, 2 * HIDDEN_SIZE: 3 * HIDDEN_SIZE]
    w_o = kernel_[:, 3 * HIDDEN_SIZE: 4 * HIDDEN_SIZE]
    r_i = recurrent_kernel_[:, 0 * HIDDEN_SIZE: 1 * HIDDEN_SIZE]
    r_f = recurrent_kernel_[:, 1 * HIDDEN_SIZE: 2 * HIDDEN_SIZE]
    r_c = recurrent_kernel_[:, 2 * HIDDEN_SIZE: 3 * HIDDEN_SIZE]
    r_o = recurrent_kernel_[:, 3 * HIDDEN_SIZE: 4 * HIDDEN_SIZE]
    b_i = bias_[0 * HIDDEN_SIZE: 1 * HIDDEN_SIZE]
    b_f = bias_[1 * HIDDEN_SIZE: 2 * HIDDEN_SIZE]    
    b_c = bias_[2 * HIDDEN_SIZE: 3 * HIDDEN_SIZE]
    b_o = bias_[3 * HIDDEN_SIZE: 4 * HIDDEN_SIZE]
    for seq_idx in range(SEQ_LEN):
        forget_gate = tf.linalg.matmul(input_[:, seq_idx, :], w_f) + tf.linalg.matmul(hidden_state, r_f) + b_f
        forget_gate = tf.math.sigmoid(forget_gate)
        
        input_gate = tf.linalg.matmul(input_[:, seq_idx, :], w_i) + tf.linalg.matmul(hidden_state, r_i) + b_i
        input_gate = tf.math.sigmoid(input_gate)
        
        output_gate = tf.linalg.matmul(input_[:, seq_idx, :], w_o) + tf.linalg.matmul(hidden_state, r_o) + b_o
        output_gate = tf.math.sigmoid(output_gate)
        
        cell_tmp = tf.linalg.matmul(input_[:, seq_idx, :], w_c) + tf.linalg.matmul(hidden_state, r_c) + b_c
        cell_state = forget_gate * cell_state + input_gate * tf.math.tanh(cell_tmp)
        
        hidden_state = output_gate * tf.math.tanh(cell_state)
        output[:, seq_idx, :] = hidden_state
    return output

scratch_output = scratch_lstm(input_tensor, kernel, recurrent_kernel, bias)
print(scratch_output)
        

[[[-0.00261656  0.7161094 ]]

 [[-0.00569702  0.62866473]]

 [[-0.00151875  0.5827352 ]]

 [[-0.01549257  0.5073756 ]]]


In [75]:
print(np.allclose(tf_output.numpy(), scratch_output))

True
