In [1]:
import numpy as np

In [15]:
def rnn_step_forward(x, prev_h, Wx, Wh, b):
    """
    Advance a vanilla (tanh) RNN by exactly one timestep.

    Computes next_h = tanh(x @ Wx + prev_h @ Wh + b) for a minibatch of N
    examples with input size D and hidden size H.

    Inputs:
    - x: Input for the current timestep, shape (N, D).
    - prev_h: Hidden state carried over from the previous timestep, shape (N, H).
    - Wx: Input-to-hidden weights, shape (D, H).
    - Wh: Hidden-to-hidden weights, shape (H, H).
    - b: Hidden bias, shape (H,); broadcast across the batch.

    Returns a tuple of:
    - next_h: Updated hidden state, shape (N, H).
    - cache: (Wx, Wh, b, x, prev_h, next_h), which is everything
      rnn_step_backward unpacks.
    """
    input_term = np.dot(x, Wx)
    recurrent_term = np.dot(prev_h, Wh)
    h_new = np.tanh(input_term + recurrent_term + b)

    # Keep next_h in the cache: the tanh derivative is 1 - next_h**2, so the
    # backward pass never needs the pre-activation.
    step_cache = (Wx, Wh, b, x, prev_h, h_new)
    return h_new, step_cache

In [21]:
def rnn_step_backward(dnext_h, cache):
    """
    Backpropagate through one timestep of a vanilla (tanh) RNN.

    Inputs:
    - dnext_h: Upstream gradient w.r.t. the hidden state produced at this
      step, shape (N, H).
    - cache: Tuple (Wx, Wh, b, x, prev_h, next_h) as returned by
      rnn_step_forward.

    Returns a tuple of:
    - dx: Gradient w.r.t. the step input, shape (N, D)
    - dprev_h: Gradient w.r.t. the previous hidden state, shape (N, H)
    - dWx: Gradient w.r.t. input-to-hidden weights, shape (D, H)
    - dWh: Gradient w.r.t. hidden-to-hidden weights, shape (H, H)
    - db: Gradient w.r.t. the bias, shape (H,)
    """
    w_in, w_rec, _bias, x_t, h_prev, h_out = cache

    # d tanh(z)/dz = 1 - tanh(z)^2, and tanh(z) is the cached output.
    d_pre = dnext_h * (1 - h_out ** 2)

    # Weight gradients: contract the batch axis of each layer input with d_pre.
    grad_w_in = np.dot(x_t.T, d_pre)
    grad_w_rec = np.dot(h_prev.T, d_pre)
    # The bias was broadcast over the batch, so its gradient sums over it.
    grad_bias = np.sum(d_pre, axis=0)

    # Input gradients: push d_pre back through each weight matrix.
    grad_x = np.dot(d_pre, w_in.T)
    grad_h_prev = np.dot(d_pre, w_rec.T)

    return grad_x, grad_h_prev, grad_w_in, grad_w_rec, grad_bias

In [17]:
def rnn_forward(x, h0, Wx, Wh, b):
    """
    Unroll a vanilla RNN over a full input sequence.

    The batch holds N sequences of length T with feature size D, and the
    hidden size is H. Each step is delegated to rnn_step_forward, and the
    hidden state from one step feeds the next.

    Inputs:
    - x: Whole input sequence, shape (N, T, D).
    - h0: Hidden state fed into the first timestep, shape (N, H).
    - Wx: Input-to-hidden weights, shape (D, H).
    - Wh: Hidden-to-hidden weights, shape (H, H).
    - b: Hidden bias, shape (H,).

    Returns a tuple of:
    - h: Hidden state emitted at every timestep, shape (N, T, H).
    - cache: List of length T; entry t is the cache rnn_step_forward
      returned for timestep t.
    """
    num_seq, num_steps, _input_dim = x.shape
    (hidden_dim,) = b.shape

    all_h = np.zeros((num_seq, num_steps, hidden_dim))
    step_caches = []

    h_t = h0
    for t in range(num_steps):
        h_t, step_cache = rnn_step_forward(x[:, t, :], h_t, Wx, Wh, b)
        all_h[:, t, :] = h_t
        step_caches.append(step_cache)

    return all_h, step_caches

In [18]:
def rnn_backward(dh, cache):
    """
    Backpropagate through time for a vanilla RNN unrolled over T steps.

    Inputs:
    - dh: Upstream gradients for every hidden state, shape (N, T, H). These
      are only the gradients coming from each timestep's own loss. The
      gradient flowing from step t+1 back into step t is computed here by
      walking the timesteps in reverse and calling rnn_step_backward.
    - cache: List of per-timestep caches as returned by rnn_forward. The
      input size D is read from the Wx stored in cache[0], so the list must
      be non-empty.

    Returns a tuple of:
    - dx: Gradient w.r.t. the inputs, shape (N, T, D)
    - dh0: Gradient w.r.t. the initial hidden state, shape (N, H)
    - dWx: Gradient w.r.t. input-to-hidden weights, shape (D, H)
    - dWh: Gradient w.r.t. hidden-to-hidden weights, shape (H, H)
    - db: Gradient w.r.t. the bias, shape (H,)
    """
    num_seq, num_steps, hidden_dim = dh.shape
    # cache[t] = (Wx, Wh, b, x, prev_h, next_h); Wx has shape (D, H).
    input_dim = cache[0][0].shape[0]

    dx = np.zeros((num_seq, num_steps, input_dim))
    dWx = np.zeros((input_dim, hidden_dim))
    dWh = np.zeros((hidden_dim, hidden_dim))
    db = np.zeros(hidden_dim)
    # Gradient arriving at h_t from the future (step t+1). Zero after the last step.
    h_grad_carry = np.zeros((num_seq, hidden_dim))

    for t in range(num_steps - 1, -1, -1):
        # h_t feeds both the local loss at step t and the next timestep.
        total_dh_t = h_grad_carry + dh[:, t, :]
        dx_t, h_grad_carry, dWx_t, dWh_t, db_t = rnn_step_backward(total_dh_t, cache[t])
        dx[:, t, :] = dx_t
        # Parameters are shared across timesteps, so their gradients accumulate.
        dWx += dWx_t
        dWh += dWh_t
        db += db_t

    # After stepping past t = 0, the carried gradient belongs to h0.
    return dx, h_grad_carry, dWx, dWh, db