In [9]:
import numpy as np

In [10]:
def tanh(x, derv=False):
    if derv: return 1 - np.tanh(x)**2
    return np.tanh(x)

In [11]:
def derv_tanh(x):
    return 1 - x**2

In [12]:
def sigmoid(x, derv=False):
    s = 1 / (1 + np.exp(-x))
    if derv: return s * (1 - s)
    return s

In [13]:
def derv_sigmoid(x):
    return x * (1 - x)

In [14]:
def MSE(y, y_pred, derv=False):
    if derv: return 2*(y_pred-y)
    return np.mean((y_pred-y)**2)

In [15]:
INPUT_SIZE = 1
HIDDEN_SIZE = 50

In [16]:
def init_weights():
    U = np.random.randn(INPUT_SIZE, HIDDEN_SIZE) * np.sqrt(2/(INPUT_SIZE+HIDDEN_SIZE))
    W = np.random.randn(HIDDEN_SIZE, HIDDEN_SIZE) * np.sqrt(1/HIDDEN_SIZE)
    b = np.zeros((1, HIDDEN_SIZE))
    return U, W, b

In [17]:
Uf, Wf, bf = init_weights()
Ui, Wi, bi = init_weights()
Uo, Wo, bo = init_weights()
Ug, Wg, bg = init_weights()

params = [Uf, Wf, bf, Ui, Wi, bi, Uo, Wo, bo, Ug, Wg, bg]
len_params = len(params)

In [18]:
def update_params(grads, lr):
    for i in range(len_params):
        params[i] -= lr * grads[i]

In [19]:
def forward_cell(xt, h_prev, c_prev):
    ft = sigmoid(xt @ Uf + h_prev @ Wf + bf)
    it = sigmoid(xt @ Ui + h_prev @ Wi + bi)
    ot = sigmoid(xt @ Uo + h_prev @ Wo + bo)
    candt = tanh(xt @ Ug + h_prev @ Wg + bg)

    ct = ft * c_prev + it * candt
    ht = tanh(ct) * ot
    return ht, ct, candt, ot, it, ft

In [20]:
def forward(x):
    global H, C, O, I, F, Cand
    T = x.shape[0]

    H = np.zeros((T+1, 1, HIDDEN_SIZE))
    C = np.zeros((T+1, 1, HIDDEN_SIZE))


    O = np.zeros((T, 1, HIDDEN_SIZE))
    I = np.zeros((T, 1, HIDDEN_SIZE))
    F = np.zeros((T, 1, HIDDEN_SIZE))
    Cand = np.zeros((T, 1, HIDDEN_SIZE))

    for t in range(T):
        H[t+1], C[t+1], candt, ot, it, ft = forward_cell(x[t], H[t], C[t])
        Cand[t] = candt
        O[t] = ot
        I[t] = it
        F[t] = ft
        
    return H[1:] # first doesn't count

In [21]:
forward(np.random.randn(32, INPUT_SIZE))

array([[[-0.07669864,  0.058061  ,  0.01657847, ...,  0.02144948,
          0.03874634,  0.00459386]],

       [[ 0.00477208, -0.00606728,  0.00359236, ...,  0.0053797 ,
         -0.01663696, -0.00794547]],

       [[-0.05361056,  0.04972602,  0.01414172, ...,  0.02263642,
          0.02649588,  0.00071927]],

       ...,

       [[-0.0082118 ,  0.0204489 ,  0.0158995 , ...,  0.02149906,
         -0.02206349, -0.01573029]],

       [[-0.11914841,  0.10369873,  0.0344218 , ...,  0.05791529,
          0.05581425,  0.00160228]],

       [[-0.02464391,  0.04319373,  0.02070662, ...,  0.03168772,
         -0.00253206, -0.01497841]]])

In [22]:
def backward_cell(xt, h_prev, c_prev, c_current, candt, ot, it_gate, ft_gate, dh, dc_next):
    d_ot = dh * tanh(c_current)
    d_tanh_c = dh * ot
    d_c_from_h = d_tanh_c * tanh(c_current, derv=True)

    dct = d_c_from_h + dc_next
    
    d_ft = dct * c_prev
    d_it = dct * candt
    d_candt = dct * it_gate

    d_ft_pre = d_ft * derv_sigmoid(ft_gate)
    d_it_pre = d_it * derv_sigmoid(it_gate)
    d_ot_pre = d_ot * derv_sigmoid(ot)
    d_candt_pre = d_candt * derv_tanh(candt)

    dUft = xt.T @ d_ft_pre
    dWft = h_prev.T @ d_ft_pre
    dbft = np.sum(d_ft_pre, axis=0, keepdims=True)

    dUit = xt.T @ d_it_pre
    dWit = h_prev.T @ d_it_pre
    dbit = np.sum(d_it_pre, axis=0, keepdims=True)

    dUot = xt.T @ d_ot_pre
    dWot = h_prev.T @ d_ot_pre
    dbot = np.sum(d_ot_pre, axis=0, keepdims=True)

    dUgt = xt.T @ d_candt_pre
    dWgt = h_prev.T @ d_candt_pre
    dbgt = np.sum(d_candt_pre, axis=0, keepdims=True)

    dx_t = (d_ft_pre @ Uf.T +
                d_it_pre @ Ui.T +
                d_ot_pre @ Uo.T +
                d_candt_pre @ Ug.T)

    dh_prev = (d_ft_pre @ Wf.T +
                   d_it_pre @ Wi.T +
                   d_ot_pre @ Wo.T +
                   d_candt_pre @ Wg.T)
    
    dc_prev = dct * ft_gate

    return (dUft, dWft, dbft,
            dUit, dWit, dbit,
            dUot, dWot, dbot,
            dUgt, dWgt, dbgt,
            dh_prev, dc_prev, dx_t)

In [23]:
def init_grads(U, W, b):
    return np.zeros_like(U), np.zeros_like(W), np.zeros_like(b)

In [24]:
def backward(x, y_true, y_pred, learn=True, lr=0.001):
    T = x.shape[0]

    dUf, dWf, dbf = init_grads(Uf, Wf, bf)
    dUi, dWi, dbi = init_grads(Ui, Wi, bi)
    dUo, dWo, dbo = init_grads(Uo, Wo, bo)
    dUg, dWg, dbg = init_grads(Ug, Wg, bg)

    dh_next = np.zeros((1, HIDDEN_SIZE))
    dc_next = np.zeros((1, HIDDEN_SIZE))
    dx = np.zeros_like(x)

    for t in reversed(range(T)):
        xt = x[t:t+1]
        yt = y_true[t:t+1]
        outp = y_pred[t:t+1]
        h_prev = H[t]
        c_prev = C[t]
        c_t = C[t+1]
        candt = Cand[t]
        ot = O[t]
        it_gate= I[t]
        ft_gate= F[t]

        dh = MSE(yt, outp, derv=True) + dh_next

        (dUf_t, dWf_t, dbf_t,
         dUi_t, dWi_t, dbi_t,
         dUo_t, dWo_t, dbo_t,
         dUg_t, dWg_t, dbg_t,
         dh_next, dc_next, dx[t]) = backward_cell(xt, h_prev, c_prev, c_t, candt, ot, it_gate, ft_gate, dh, dc_next)

        dUf += dUf_t; dWf += dWf_t; dbf += dbf_t
        dUi += dUi_t; dWi += dWi_t; dbi += dbi_t
        dUo += dUo_t; dWo += dWo_t; dbo += dbo_t
        dUg += dUg_t; dWg += dWg_t; dbg += dbg_t

    dUf /= T; dWf /= T; dbf /= T
    dUi /= T; dWi /= T; dbi /= T
    dUo /= T; dWo /= T; dbo /= T
    dUg /= T; dWg /= T; dbg /= T

    grads = dUf, dWf, dbf, dUi, dWi, dbi, dUo, dWo, dbo, dUg, dWg, dbg
    
    if learn:
        update_params(grads, lr)

    return dx, grads

In [25]:
backward(
    x=np.random.randn(32, INPUT_SIZE),
    y_true=np.random.randn(32, HIDDEN_SIZE),
    y_pred=np.random.randn(32, HIDDEN_SIZE),
    learn=True,
    lr=0.001
)

(array([[-0.22436765],
        [ 0.97938876],
        [ 0.05361214],
        [-1.09917685],
        [-0.43596157],
        [-1.16757609],
        [-1.0722661 ],
        [ 0.08104055],
        [ 0.87143992],
        [ 0.40342399],
        [ 1.22254067],
        [ 2.20421585],
        [ 1.07417873],
        [ 0.84403903],
        [ 0.22552111],
        [-0.02014018],
        [-0.04947555],
        [ 1.81986484],
        [ 2.21482954],
        [ 0.2856446 ],
        [ 1.13207781],
        [ 0.70017221],
        [ 1.69386634],
        [-1.62385167],
        [-2.02233188],
        [-0.49515473],
        [ 0.03296934],
        [ 2.08073884],
        [-0.65899205],
        [-1.90157991],
        [-0.9585065 ],
        [-1.72247067]]),
 (array([[-0.01053794, -0.0096479 ,  0.00229446,  0.00416207, -0.00331031,
           0.00101777,  0.00403722,  0.00482895, -0.00708919,  0.00293243,
           0.00027225, -0.00059585,  0.00207418,  0.00238352,  0.00310146,
           0.000639  ,  0.0071121 ,  