In [1]:
import numpy as np

In [2]:
# RNN Cell
class RNN():
  '''
  Single vanilla RNN cell computing
      h_now = tanh(h_prev @ Wh + x @ Wx + b)
  for one time step over a batch.
  '''
  def __init__(self, Wx, Wh, b):
    '''
    Wx : input-to-hidden weights (D x H)
    Wh : hidden-to-hidden weights (H x H)
    b  : bias (H,)
    '''
    self.params = [Wx, Wh, b]
    # Gradient buffers mirror the parameter shapes; backward overwrites them
    # in place so anything holding a reference to self.grads sees updates.
    self.grads = [np.zeros_like(Wx), np.zeros_like(Wh), np.zeros_like(b)]
    self.cache = None   # (x, h_prev, h_now), filled by forward for backward

  def forward(self, x, h_prev):
    '''
    x      : input batch for this time step (N x D)
    h_prev : hidden state from the previous time step (N x H)
    returns: new hidden state h_now (N x H)
    '''
    Wx, Wh, b = self.params
    # The bias is part of the pre-activation; backward computes db for it.
    t = np.matmul(h_prev, Wh) + np.matmul(x, Wx) + b
    h_now = np.tanh(t)

    self.cache = (x, h_prev, h_now)
    return h_now

  def backward(self, dh_next):
    '''
    dh_next : gradient of the loss w.r.t. h_now (N x H)
    returns : (dx, dh_prev), gradients w.r.t. x (N x D) and h_prev (N x H)

    Parameter gradients are written into self.grads in place.
    '''
    Wx, Wh, b = self.params
    x, h_prev, h_now = self.cache

    # d tanh(t)/dt = 1 - tanh(t)^2, reusing the cached forward output.
    dt = dh_next * (1 - h_now**2)
    db = np.sum(dt, axis=0)          # bias is broadcast over the batch
    dWh = np.matmul(h_prev.T, dt)
    dh_prev = np.matmul(dt, Wh.T)
    dWx = np.matmul(x.T, dt)
    dx = np.matmul(dt, Wx.T)

    self.grads[0][...] = dWx
    self.grads[1][...] = dWh
    self.grads[2][...] = db

    return dx, dh_prev


In [3]:
# Time RNN
class TimeRNN():
  '''
  Unrolls an RNN cell over T time steps; every step shares Wx, Wh and b.

  Wx       : input-to-hidden weights (D x H)
  Wh       : hidden-to-hidden weights (H x H)
  b        : bias (H,)
  stateful : if True, the last hidden state of one forward call is reused as
             the initial state of the next call (until reset_state); if
             False, every forward call starts from a zero state.
  '''
  def __init__(self, Wx, Wh, b, stateful=False):
    self.params = [Wx, Wh, b]
    self.grads = [np.zeros_like(Wx), np.zeros_like(Wh), np.zeros_like(b)]
    self.layers = None   # per-step RNN cells built by forward, used by backward

    # h: hidden state after the last forward step; dh: gradient w.r.t. the
    # initial hidden state, set by backward.
    self.h, self.dh = None, None
    self.stateful = stateful

  def set_state(self, h):
    '''Set the hidden state; a stateful forward uses it as its initial state.'''
    self.h = h

  def reset_state(self):
    '''Clear the hidden state so the next forward starts from zeros.'''
    self.h = None

  def forward(self, xs):
    '''
    xs      : input sequence batch (N x T x D)
    returns : hidden states for every step, hs (N x T x H), float32
    '''
    Wx = self.params[0]
    N, T, D = xs.shape
    H = Wx.shape[1]

    self.layers = []
    hs = np.empty((N, T, H), dtype='f')   # (batch, sequences, hidden)

    if not self.stateful or self.h is None:
      self.h = np.zeros((N, H), dtype='f')

    for t in range(T):
      layer = RNN(*self.params)
      self.h = layer.forward(xs[:, t, :], self.h)
      hs[:, t, :] = self.h
      self.layers.append(layer)

    return hs

  def backward(self, dhs):
    '''
    dhs     : gradients of the loss w.r.t. hs (N x T x H)
    returns : gradients w.r.t. xs, dxs (N x T x D)

    Parameter gradients are summed over all time steps and written into
    self.grads in place; the gradient w.r.t. the initial hidden state is
    stored in self.dh.
    '''
    Wx = self.params[0]
    N, T, H = dhs.shape
    D = Wx.shape[0]

    dxs = np.empty((N, T, D), dtype='f')
    dh = 0
    grads = [0, 0, 0]

    for t in reversed(range(T)): # BPTT
      layer = self.layers[t]
      # Each step receives its own upstream gradient plus the gradient
      # flowing back from the following step's hidden state.
      dx, dh = layer.backward(dhs[:, t, :] + dh)
      dxs[:, t, :] = dx

      # The first += on the int 0 creates a fresh array, so the running sums
      # never alias a cell's own gradient buffers.
      for i, grad in enumerate(layer.grads):
        grads[i] += grad

    # Write in place (like RNN.backward) so references to self.grads stay valid.
    for i, grad in enumerate(grads):
      self.grads[i][...] = grad
    self.dh = dh

    return dxs