In [36]:
import numpy as np

In [37]:
# Number of features in each per-timestep input row; it is the row count of
# every input-weight matrix U built by init_params.
INPUT_SIZE = 1
# Width of the hidden state; it sizes the columns of U, both axes of the
# recurrent matrices W, and the columns of each bias row.
HIDDEN_SIZE = 50

In [38]:
def sigmoid(x, derv=False):
    """Logistic sigmoid of a pre-activation, or its derivative.

    Evaluated in a numerically stable way: only ``exp`` of a non-positive
    number is ever taken, so large-magnitude negative inputs no longer
    overflow and trigger a RuntimeWarning (the naive ``1 / (1 + exp(-x))``
    computes ``exp(+big)`` for them).

    Args:
        x: scalar or NumPy array of pre-activation values.
        derv: when True, return the derivative with respect to ``x``
            (``s * (1 - s)``) instead of the sigmoid value ``s``.

    Returns:
        Same shape as ``x``; a NumPy scalar when ``x`` is a scalar.
    """
    # exp(-|x|) lies in (0, 1], so it can underflow to 0 but never overflow.
    e = np.exp(-np.abs(x))
    # x >= 0: 1 / (1 + exp(-x));  x < 0: exp(x) / (1 + exp(x)) -- same function.
    # The trailing [()] unwraps the 0-d array np.where returns for scalar input.
    s = np.where(x >= 0, 1 / (1 + e), e / (1 + e))[()]
    if derv:
        return s * (1 - s)
    return s

In [39]:
def tanh(x, derv=False):
    """Hyperbolic tangent of a pre-activation, or its derivative.

    Args:
        x: scalar or NumPy array of pre-activation values.
        derv: when True, return ``1 - tanh(x)**2`` instead of ``tanh(x)``.

    Returns:
        The activation (or its derivative), element-wise over ``x``.
    """
    activated = np.tanh(x)
    return 1 - activated**2 if derv else activated

In [40]:
def derv_sigmoid(x):
    """Sigmoid derivative expressed through the sigmoid's *output*.

    Args:
        x: value(s) already passed through the sigmoid (backward_cell feeds
            it the stored gate activations), not the raw pre-activation.

    Returns:
        ``x * (1 - x)``, element-wise.
    """
    complement = 1 - x
    return x * complement

In [41]:
def derv_tanh(x):
    """Tanh derivative expressed through the tanh's *output*.

    Args:
        x: value(s) already passed through tanh (backward_cell feeds it the
            stored candidate state), not the raw pre-activation.

    Returns:
        ``1 - x**2``, element-wise.
    """
    squared = x**2
    return 1 - squared

In [42]:
def MSE(y, y_pred, derv=False):
    """Mean squared error between targets and predictions, or its gradient.

    Args:
        y: target values.
        y_pred: predicted values, broadcast-compatible with ``y``.
        derv: when True, return the element-wise gradient with respect to
            ``y_pred`` instead of the scalar loss.

    Returns:
        ``mean((y_pred - y)**2)`` as a scalar, or ``2 * (y_pred - y)`` when
        ``derv`` is True. Note the gradient is *not* divided by the number of
        elements, so it is the gradient of the summed (not averaged) squared
        error.
    """
    residual = y_pred - y
    if derv:
        return 2 * residual
    return np.mean(residual**2)

In [43]:
def init_params():
    """Create one freshly initialised (U, W, b) parameter triple.

    Returns:
        U: ``(INPUT_SIZE, HIDDEN_SIZE)`` input weights, standard-normal draws
            scaled by ``sqrt(2 / (INPUT_SIZE + HIDDEN_SIZE))``.
        W: ``(HIDDEN_SIZE, HIDDEN_SIZE)`` recurrent weights, standard-normal
            draws scaled by ``sqrt(1 / HIDDEN_SIZE)``.
        b: ``(1, HIDDEN_SIZE)`` bias row of zeros.
    """
    input_scale = np.sqrt(2 / (INPUT_SIZE + HIDDEN_SIZE))
    recurrent_scale = np.sqrt(1 / HIDDEN_SIZE)

    # Input weights are drawn before recurrent weights; keeping that order
    # keeps the global NumPy random stream consumption unchanged.
    input_weights = np.random.randn(INPUT_SIZE, HIDDEN_SIZE) * input_scale
    recurrent_weights = np.random.randn(HIDDEN_SIZE, HIDDEN_SIZE) * recurrent_scale
    bias = np.zeros((1, HIDDEN_SIZE))
    return input_weights, recurrent_weights, bias

In [44]:
# One (input weights, recurrent weights, bias) triple per component, drawn in
# the order update gate (z), reset gate (r), candidate state (h).
(Uz, Wz, bz), (Ur, Wr, br), (Uh, Wh, bh) = (init_params() for _ in range(3))

# Flat list of every trainable array. The order must match the gradient tuple
# that backward hands to update_params, which walks the two by index.
params = [Uz, Wz, bz, Ur, Wr, br, Uh, Wh, bh]
len_params = len(params)

In [45]:
def update_params(grads, lr):
    """Apply one plain gradient-descent step to the module-level ``params``.

    Args:
        grads: indexable of gradients, position-aligned with ``params``.
        lr: learning rate multiplying every gradient.

    The subtraction is done in place, so the arrays held in ``params`` (the
    same objects as Uz, Wz, ... ) are modified directly.
    """
    for position in range(len_params):
        step = lr * grads[position]
        params[position] -= step

In [46]:
def forward_cell(xt, h_prev):
    """Advance the recurrent cell by one timestep.

    Args:
        xt: input row for this step (matrix-multiplied with the U matrices).
        h_prev: hidden state coming from the previous step.

    Returns:
        ``(ht, h_hat, rt, zt)``: the new hidden state, the candidate state,
        the reset-gate activation and the update-gate activation.
    """
    update_pre = xt @ Uz + h_prev @ Wz + bz
    zt = sigmoid(update_pre)

    reset_pre = xt @ Ur + h_prev @ Wr + br
    rt = sigmoid(reset_pre)

    # The reset gate scales the previous state before it enters the candidate.
    gated_prev = h_prev * rt
    h_hat = tanh(xt @ Uh + gated_prev @ Wh + bh)

    # zt blends old state and candidate: zt -> 1 favours the candidate.
    keep = 1 - zt
    ht = h_prev * keep + h_hat * zt
    return ht, h_hat, rt, zt

In [47]:
def forward(x):
    """Run the cell over a whole sequence and cache every intermediate.

    Args:
        x: array whose first axis is time; step ``t`` is fed as ``x[t:t+1]``.

    Returns:
        ``H[1:]``, the hidden state after each of the ``T`` steps, shaped
        ``(T, 1, HIDDEN_SIZE)``.

    Side effects:
        Rebinds the globals ``H`` (``T + 1`` states, index 0 being the
        all-zero initial state), ``H_hat``, ``R`` and ``Z`` (``T`` entries
        each). backward reads these caches.
    """
    global H, H_hat, R, Z
    steps = x.shape[0]
    per_step_shape = (steps, 1, HIDDEN_SIZE)

    # One extra slot up front holds the zero initial hidden state.
    H = np.zeros((steps + 1, 1, HIDDEN_SIZE))
    H_hat = np.zeros(per_step_shape)
    R = np.zeros(per_step_shape)
    Z = np.zeros(per_step_shape)

    for t in range(steps):
        new_state, candidate, reset_gate, update_gate = forward_cell(x[t:t+1], H[t])
        H[t+1] = new_state
        H_hat[t] = candidate
        R[t] = reset_gate
        Z[t] = update_gate

    # Drop the initial state so the result lines up one-to-one with the inputs.
    return H[1:]

In [48]:
# Smoke test: run 32 random input rows through the network. Besides showing
# the hidden states, this call fills the global H, H_hat, R and Z caches that
# backward reads.
forward(np.random.randn(32, INPUT_SIZE))

array([[[ 0.00119843, -0.00214816, -0.00136578, ..., -0.0016563 ,
         -0.00283519, -0.00025173]],

       [[-0.05552698,  0.0845451 ,  0.04875948, ...,  0.08012116,
          0.1440951 ,  0.01020329]],

       [[ 0.01629476, -0.02538805, -0.02305644, ..., -0.02216652,
         -0.0041121 , -0.00228248]],

       ...,

       [[ 0.01947727, -0.01873649,  0.00983924, ...,  0.02757232,
          0.0072704 , -0.02644392]],

       [[ 0.02286702, -0.03366561, -0.00218939, ...,  0.01011993,
         -0.02249714, -0.02100254]],

       [[-0.00645143,  0.0103638 ,  0.0190851 , ...,  0.04395663,
          0.03613181, -0.01028254]]])

In [49]:
def backward_cell(xt, h_prev, h_hat, r, z, dh):
    """Backpropagate through a single timestep of the cell.

    Args:
        xt: input row used at this step.
        h_prev: hidden state that entered this step.
        h_hat: candidate state produced at this step (tanh output).
        r: reset-gate activation produced at this step (sigmoid output).
        z: update-gate activation produced at this step (sigmoid output).
        dh: gradient of the loss with respect to this step's hidden state.

    Returns:
        An 11-tuple ``(dU_z, dW_z, db_z, dU_r, dW_r, db_r, dU_h, dW_h, db_h,
        dx_t, dh_prev)``: parameter gradients for the update gate, reset gate
        and candidate, followed by the gradients with respect to the input
        row and the previous hidden state.
    """
    # ht = h_prev * (1 - z) + h_hat * z  ->  split dh over its three inputs.
    grad_z = dh * (h_hat - h_prev)
    grad_h_hat = dh * z
    via_skip = dh * (1 - z)

    # Candidate branch: h_hat = tanh(xt @ Uh + (h_prev * r) @ Wh + bh).
    pre_h = grad_h_hat * derv_tanh(h_hat)
    grad_Uh = xt.T @ pre_h
    grad_Wh = (h_prev * r).T @ pre_h
    grad_bh = np.sum(pre_h, axis=0, keepdims=True)
    # Gradient reaching the product (h_prev * r); shared by both factors.
    into_gated_prev = pre_h @ Wh.T
    via_candidate = into_gated_prev * r
    grad_r = into_gated_prev * h_prev

    # Reset-gate branch: r = sigmoid(xt @ Ur + h_prev @ Wr + br).
    pre_r = grad_r * derv_sigmoid(r)
    grad_Ur = xt.T @ pre_r
    grad_Wr = h_prev.T @ pre_r
    grad_br = np.sum(pre_r, axis=0, keepdims=True)

    # Update-gate branch: z = sigmoid(xt @ Uz + h_prev @ Wz + bz).
    pre_z = grad_z * derv_sigmoid(z)
    grad_Uz = xt.T @ pre_z
    grad_Wz = h_prev.T @ pre_z
    grad_bz = np.sum(pre_z, axis=0, keepdims=True)

    # Every pre-activation saw xt and h_prev, so their gradients are sums over
    # all branches (summation order kept as candidate/reset/update for dx_t).
    dx_t = pre_h @ Uh.T + pre_r @ Ur.T + pre_z @ Uz.T
    via_update = pre_z @ Wz.T
    via_reset = pre_r @ Wr.T
    dh_prev = via_skip + via_candidate + via_update + via_reset

    return (
        grad_Uz, grad_Wz, grad_bz,
        grad_Ur, grad_Wr, grad_br,
        grad_Uh, grad_Wh, grad_bh,
        dx_t, dh_prev,
    )

In [50]:
def init_grads(U, W, b):
    """Build zero-filled gradient accumulators for one parameter triple.

    Args:
        U, W, b: parameter arrays whose shapes and dtypes are mirrored.

    Returns:
        A 3-tuple of zero arrays shaped like ``U``, ``W`` and ``b``.
    """
    return tuple(np.zeros_like(array) for array in (U, W, b))

In [51]:
def backward(x, y_true, y_pred, learn=True, lr=0.001):
    """Backpropagate through time over a whole sequence.

    Reads the global caches ``H``, ``H_hat``, ``R`` and ``Z`` written by the
    most recent ``forward`` call, so that call must have covered at least
    ``x.shape[0]`` steps.

    Args:
        x: input sequence, first axis is time.
        y_true: target for every step, sliced as ``y_true[t:t+1]``.
        y_pred: prediction for every step, sliced as ``y_pred[t:t+1]``.
        learn: when True, apply the averaged gradients through
            ``update_params`` before returning.
        lr: learning rate forwarded to ``update_params``.

    Returns:
        ``(dx, grads)`` where ``dx`` has the shape and dtype of ``x`` and
        ``grads`` is the 9-tuple ``(dUz, dWz, dbz, dUr, dWr, dbr, dUh, dWh,
        dbh)`` averaged over the number of timesteps.
    """
    steps = x.shape[0]

    # Accumulators in the same order as the global ``params`` list.
    totals = [
        *init_grads(Uz, Wz, bz),
        *init_grads(Ur, Wr, br),
        *init_grads(Uh, Wh, bh),
    ]

    dx = np.zeros_like(x)
    # Gradient flowing back from step t+1 into the hidden state of step t.
    carried = np.zeros((1, HIDDEN_SIZE))

    for t in range(steps - 1, -1, -1):
        window = slice(t, t + 1)
        # Loss gradient at this step plus whatever later steps sent back.
        dh = MSE(y_true[window], y_pred[window], derv=True) + carried

        *step_grads, dx[t], carried = backward_cell(
            x[window], H[t], H_hat[t], R[t], Z[t], dh
        )

        for total, contribution in zip(totals, step_grads):
            total += contribution

    # Average over time, in place.
    for total in totals:
        total /= steps

    grads = tuple(totals)

    if learn:
        update_params(grads, lr)

    return dx, grads

In [52]:
# Smoke test of the backward pass on random data. Things to keep in mind:
#  - y_pred is random noise here, not the output of forward;
#  - the cached activations backward reads (H, H_hat, R, Z) come from the most
#    recent forward call, which used a different random x;
#  - learn=True applies one gradient step to the module-level parameters in
#    place, so every re-run of this cell changes them again.
backward(
    x=np.random.randn(32, INPUT_SIZE),
    y_true=np.random.randn(32, HIDDEN_SIZE),
    y_pred=np.random.randn(32, HIDDEN_SIZE),
    learn=True,
    lr=0.001
)

(array([[ 0.43529756],
        [-2.14187344],
        [-3.10292159],
        [-2.82162605],
        [ 1.97750774],
        [-0.78620112],
        [-0.11968462],
        [ 2.15036071],
        [ 0.3981498 ],
        [ 1.97688798],
        [-1.88910205],
        [-2.78398205],
        [-2.91966055],
        [-1.35079056],
        [-1.75438521],
        [-1.14224462],
        [-2.76103518],
        [-1.20428206],
        [-0.15537948],
        [ 1.6116229 ],
        [-3.43159002],
        [-0.60478707],
        [-3.20638712],
        [-2.52446036],
        [ 1.21824981],
        [ 1.9664575 ],
        [ 3.04914764],
        [ 3.82120942],
        [ 5.10468362],
        [ 1.81575258],
        [ 2.15654629],
        [ 0.36915476]]),
 (array([[-3.00957470e-02, -3.39309957e-02,  6.23238145e-03,
           2.25398257e-02,  1.76284701e-02,  1.80758507e-02,
           2.85578463e-02, -4.78326335e-03, -5.76576600e-03,
           3.15411082e-02, -4.21465950e-02,  2.03402822e-02,
          -4.83238