In [131]:
import numpy as np
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf

In [132]:
# Problem dimensions shared by every cell below.
BATCH_SIZE = 4    # number of independent sequences processed together
# Time steps per sequence. With a single step the hidden state entering the
# GRU cell is the all-zero initial state, so the recurrent kernel does not
# influence the compared outputs.
SEQ_LEN = 1
INPUT_SIZE = 2    # features per time step
HIDDEN_SIZE = 3   # GRU units (width of the hidden state)

In [133]:
# Random inputs and weights shared by both GRU implementations.
input_tensor = np.random.randn(BATCH_SIZE, SEQ_LEN, INPUT_SIZE).astype(np.float32)
# The three gates (update z, reset r, candidate h) are stored side by side
# along the last axis, HIDDEN_SIZE columns each.
kernel = np.random.randn(INPUT_SIZE, 3 * HIDDEN_SIZE).astype(np.float32)
recurrent_kernel = np.random.randn(HIDDEN_SIZE, 3 * HIDDEN_SIZE).astype(np.float32)
# Both GRU functions are called with `bias`, so it must exist on a fresh kernel.
# Row 0 is the input bias and row 1 the recurrent bias, which is the layout
# Keras GRU uses with its default reset_after=True. All zeros keeps the
# comparison bias-free.
bias = np.zeros((2, 3 * HIDDEN_SIZE), dtype=np.float32)

In [134]:
def tf_gru(input_, kernel_, recurrent_kernel_, bias_=None):
    """Run the reference Keras GRU layer with explicitly supplied weights.

    Args:
        input_: float32 array of shape (batch, seq_len, input_size).
        kernel_: input weights, shape (input_size, 3 * hidden_size).
        recurrent_kernel_: recurrent weights, shape (hidden_size, 3 * hidden_size).
        bias_: optional bias matching the layer's bias weight; ``None`` means
            an all-zero bias.

    Returns:
        The layer output for every time step, shape (batch, seq_len, hidden_size).

    Raises:
        ValueError: raised by ``set_weights`` if a weight has the wrong shape.
    """
    # Take the sizes from the arguments so the function does not depend on
    # notebook-level constants.
    seq_len, input_size = input_.shape[1], input_.shape[2]
    hidden_size = recurrent_kernel_.shape[0]

    inputs = tf.keras.layers.Input(shape=(seq_len, input_size), dtype=tf.float32)
    gru_layer = tf.keras.layers.GRU(hidden_size, return_sequences=True)
    outputs = gru_layer(inputs)
    model = tf.keras.Model(inputs, outputs)

    # With the default reset_after=True, Keras keeps separate input and
    # recurrent biases, so the bias weight is 2-D (one row for each). This is
    # the documented layout, not a bug. Query the layer for the exact shape
    # rather than hard-coding it.
    bias_shape = tuple(gru_layer.weights[2].shape)
    if bias_ is None:
        bias = np.zeros(bias_shape, dtype=np.float32)
    else:
        bias = np.asarray(bias_, dtype=np.float32)
    gru_layer.set_weights([kernel_, recurrent_kernel_, bias])
    return model(input_)

# Reference result from the Keras GRU layer on the shared inputs and weights.
tf_output = tf_gru(
    input_=input_tensor,
    kernel_=kernel,
    recurrent_kernel_=recurrent_kernel,
    bias_=bias,
)
print(tf_output)

tf.Tensor(
[[[-0.1105953  -0.00810354  0.25300372]]

 [[-0.15829751 -0.13685134 -0.16579199]]

 [[ 0.160389    0.29439434  0.08103199]]

 [[-0.2037702   0.09957132 -0.26256567]]], shape=(4, 1, 3), dtype=float32)


In [135]:
def _sigmoid(x):
    """Logistic function via tanh, which does not overflow for large |x|."""
    return 0.5 * (1.0 + np.tanh(0.5 * x))


def scratch_gru(input_, kernel_, recurrent_kernel_, bias_=None):
    """GRU forward pass in plain NumPy, starting from an all-zero hidden state.

    The gates are stored in the order update (z), reset (r), candidate (h)
    along the last axis of the weights. The reset gate is applied after the
    recurrent matrix product (the Keras ``reset_after=True`` variant):

        z_t  = sigmoid(x_t W_z + b_iz + h_{t-1} R_z + b_rz)
        r_t  = sigmoid(x_t W_r + b_ir + h_{t-1} R_r + b_rr)
        hh_t = tanh(x_t W_h + b_ih + r_t * (h_{t-1} R_h + b_rh))
        h_t  = (1 - z_t) * hh_t + z_t * h_{t-1}

    Args:
        input_: array of shape (batch, seq_len, input_size).
        kernel_: input weights, shape (input_size, 3 * hidden_size).
        recurrent_kernel_: recurrent weights, shape (hidden_size, 3 * hidden_size).
        bias_: optional array of shape (2, 3 * hidden_size); row 0 is the input
            bias and row 1 the recurrent bias. ``None`` means no bias.

    Returns:
        float32 ndarray of shape (batch, seq_len, hidden_size) holding the
        hidden state after every time step.

    Raises:
        ValueError: if ``bias_`` is given with a shape other than
            (2, 3 * hidden_size).
    """
    input_ = np.asarray(input_)
    kernel_ = np.asarray(kernel_)
    recurrent_kernel_ = np.asarray(recurrent_kernel_)

    # Sizes come from the arguments, not from notebook-level constants.
    batch_size, seq_len = input_.shape[0], input_.shape[1]
    hidden_size = recurrent_kernel_.shape[0]

    if bias_ is None:
        input_bias = recurrent_bias = np.zeros(3 * hidden_size, dtype=kernel_.dtype)
    else:
        bias = np.asarray(bias_, dtype=kernel_.dtype)
        if bias.shape != (2, 3 * hidden_size):
            raise ValueError(
                f"bias_ must have shape {(2, 3 * hidden_size)}, got {bias.shape}"
            )
        input_bias, recurrent_bias = bias

    hidden = np.zeros((batch_size, hidden_size), dtype=np.float32)
    output = np.zeros((batch_size, seq_len, hidden_size), dtype=np.float32)
    for seq_idx in range(seq_len):
        step_input = input_[:, seq_idx, :]
        # One matrix product per source, then slice out the three gates.
        x_z, x_r, x_h = np.split(step_input @ kernel_ + input_bias, 3, axis=-1)
        h_z, h_r, h_h = np.split(
            hidden @ recurrent_kernel_ + recurrent_bias, 3, axis=-1
        )
        gate_z = _sigmoid(x_z + h_z)
        gate_r = _sigmoid(x_r + h_r)
        # The reset gate scales the already-projected hidden state.
        hidden_hat = np.tanh(x_h + gate_r * h_h)
        hidden = (1 - gate_z) * hidden_hat + gate_z * hidden
        output[:, seq_idx, :] = hidden
    return output

# Result of the hand-written GRU on the same inputs and weights.
scratch_output = scratch_gru(
    input_=input_tensor,
    kernel_=kernel,
    recurrent_kernel_=recurrent_kernel,
    bias_=bias,
)
print(scratch_output)

[[[-0.1105953  -0.00810354  0.25300375]]

 [[-0.15829751 -0.13685134 -0.16579199]]

 [[ 0.160389    0.29439434  0.08103199]]

 [[-0.2037702   0.09957132 -0.26256567]]]


In [136]:
# Fail loudly if the two implementations disagree; a printed `False` is easy
# to overlook. The absolute tolerance is set explicitly because the default
# (1e-8) is tighter than float32 precision for outputs close to zero.
np.testing.assert_allclose(tf_output, scratch_output, rtol=1e-5, atol=1e-6)
print(np.allclose(tf_output, scratch_output, rtol=1e-5, atol=1e-6))

True
