In [1]:
import numpy as np

In [5]:
def softmax(x):
    """Softmax along axis 0.

    A column vector of logits (e.g. shape ``(n_y, 1)``) maps to a column of
    probabilities summing to 1. For a 2-D input each column is normalized
    independently.

    Parameters
    ----------
    x : array_like
        Logits; the softmax is taken along axis 0.

    Returns
    -------
    numpy.ndarray
        Same shape as ``x``; entries along axis 0 are non-negative and sum to 1.
    """
    x = np.asarray(x)
    # Shift each column by its own max (not the global max). Softmax is
    # invariant to a per-column shift, and this guarantees every column has an
    # exp(0) = 1 term, so no column can underflow to all zeros (0/0 -> NaN).
    ex = np.exp(x - np.max(x, axis=0, keepdims=True))
    return ex / ex.sum(axis=0, keepdims=True)

In [9]:
def rnn_cell_forward(xt, a_prev, parameters):
    """Run a single time step of the tanh RNN cell.

    Computes::

        a_next = tanh(Wax . xt + Waa . a_prev + ba)
        y_pred = softmax(Wya . a_next + by)

    Parameters
    ----------
    xt : numpy.ndarray
        Input for this time step.
    a_prev : numpy.ndarray
        Hidden state from the previous time step.
    parameters : dict
        Must provide "Wax", "Waa", "Wya", "ba" and "by".

    Returns
    -------
    tuple
        ``(a_next, y_pred)``: the new hidden state and the softmax output.
    """
    Wax, Waa, Wya, ba, by = (
        parameters[name] for name in ("Wax", "Waa", "Wya", "ba", "by")
    )

    hidden_preactivation = np.dot(Wax, xt) + np.dot(Waa, a_prev) + ba
    new_state = np.tanh(hidden_preactivation)

    output_logits = np.dot(Wya, new_state) + by
    return new_state, softmax(output_logits)

In [10]:
def rnn_forward(X, Y, a0, parameters, vocab_size = 50):
    """Unroll the RNN over a sequence and sum the cross-entropy loss.

    Parameters
    ----------
    X : sequence of int or None
        Input token index per time step. Each index is one-hot encoded into a
        ``(vocab_size, 1)`` column; a ``None`` entry is encoded as an all-zero
        column (no input token) instead of being used as an index.
    Y : sequence of int
        Target token index per time step; ``Y[t]`` is the label for step ``t``.
        Assumes ``len(Y) >= len(X)`` -- TODO confirm against callers.
    a0 : numpy.ndarray
        Initial hidden state.
    parameters : dict
        RNN weights, forwarded unchanged to ``rnn_cell_forward``.
    vocab_size : int, default 50
        Length of the one-hot input vectors.

    Returns
    -------
    loss : scalar
        Sum over all time steps of ``-log(y_hat[t][Y[t], 0])`` (0 for empty X).
    cache : tuple
        ``(y_hat, a, x)``: dicts keyed by time step. ``a[-1]`` holds a copy of
        ``a0``.
    """
    x, a, y_hat = {}, {}, {}
    # Storing the initial state under key -1 lets step 0 read a[t - 1]
    # exactly like every other step.
    a[-1] = np.copy(a0)
    loss = 0

    for t in range(len(X)):
        x[t] = np.zeros((vocab_size, 1))
        # x[t][None] = 1 would fill the entire vector, so treat None as
        # "no input token" and leave the column all zeros.
        if X[t] is not None:
            x[t][X[t]] = 1

        a[t], y_hat[t] = rnn_cell_forward(x[t], a[t - 1], parameters)

        # Accumulate (not overwrite) so the loss covers every time step.
        loss -= np.log(y_hat[t][Y[t], 0])

    cache = (y_hat, a, x)
    return loss, cache

In [12]:
def rnn_cell_backward(dy, gradients, parameters, x, a, a_prev):
    """Backpropagate through one time step of the tanh RNN cell.

    Adds this step's contributions into the running gradient sums in
    ``gradients``; the dict is mutated in place and also returned.
    ``gradients["da_next"]`` is *replaced* by the gradient to feed into the
    previous time step.

    Parameters
    ----------
    dy : numpy.ndarray
        Gradient of the loss w.r.t. this step's output logits.
    gradients : dict
        Running gradients "dWya", "dby", "dWax", "dWaa", "da_next", plus the
        hidden-bias gradient under "dba" (the legacy key "db" is used instead
        when "dba" is absent).
    parameters : dict
        Must provide "Wya" and "Waa".
    x : numpy.ndarray
        Input at this time step.
    a : numpy.ndarray
        Hidden state produced at this step (the tanh output).
    a_prev : numpy.ndarray
        Hidden state from the previous step.

    Returns
    -------
    dict
        The same ``gradients`` object.
    """
    gradients["dWya"] += np.dot(dy, a.T)
    gradients["dby"] += dy
    # Total gradient into a: from this step's output plus from the next step.
    da = np.dot(parameters["Wya"].T, dy) + gradients["da_next"]
    # Backprop through tanh: d tanh(z)/dz = 1 - tanh(z)**2 = 1 - a**2.
    dz = (1 - a ** 2) * da
    # The bias parameter is "ba", so its gradient key is "dba"; fall back to
    # "db" for callers that initialize the older key.
    bias_key = "dba" if "dba" in gradients else "db"
    gradients[bias_key] += dz
    gradients["dWax"] += np.dot(dz, x.T)
    gradients["dWaa"] += np.dot(dz, a_prev.T)
    # Assign, don't accumulate. da_next must be only the gradient flowing into
    # a_prev from *this* step, because the previous step adds its own output
    # gradient on top. Accumulating would re-count later steps' gradients.
    gradients["da_next"] = np.dot(parameters["Waa"].T, dz)
    return gradients

In [13]:
def rnn_backward(X, Y, parameters, cache):
    """Backpropagation through time over a whole sequence.

    Parameters
    ----------
    X : sequence
        The input sequence; only its length is used here (number of steps).
    Y : sequence of int
        Target token index per time step.
    parameters : dict
        Must provide "Wax", "Waa", "Wya", "ba" and "by".
    cache : tuple
        ``(y_hat, a, x)`` dicts keyed by time step, with ``a[-1]`` the initial
        hidden state (as returned by ``rnn_forward``).

    Returns
    -------
    gradients : dict
        Gradient sums "dWax", "dWaa", "dWya", "dby", "dba" (each shaped like
        its parameter), plus the "da_next" entry threaded through
        ``rnn_cell_backward``.
    a : dict
        The hidden states from ``cache``, returned unchanged.
    """
    (y_hat, a, x) = cache
    Waa = parameters["Waa"]
    Wax = parameters["Wax"]
    Wya = parameters["Wya"]
    by = parameters["by"]
    ba = parameters["ba"]

    # np.zeros_like must be *called* on the parameter arrays; subscripting it
    # (np.zeros_like["by"]) raises TypeError.
    gradients = {
        "dWax": np.zeros_like(Wax),
        "dWaa": np.zeros_like(Waa),
        "dWya": np.zeros_like(Wya),
        "dby": np.zeros_like(by),
        "dba": np.zeros_like(ba),
        "da_next": np.zeros_like(a[0]),
    }

    for t in reversed(range(len(X))):
        # Gradient of softmax + cross-entropy w.r.t. the logits:
        # y_hat - one_hot(Y[t]).
        dy = np.copy(y_hat[t])
        dy[Y[t]] -= 1
        gradients = rnn_cell_backward(dy, gradients, parameters, x[t], a[t], a[t - 1])
    return gradients, a