# In [1]:
import numpy as np

def sigmoid(x):
    """Logistic sigmoid 1 / (1 + exp(-x)), evaluated element-wise.

    Computed as 1 / (1 + e) for x >= 0 and e / (1 + e) for x < 0, with
    e = exp(-|x|). exp therefore never receives a positive argument, so there
    is no overflow (and no RuntimeWarning) for large negative x. Relative
    precision is kept in both tails.

    Args:
        x: scalar or array-like of pre-activations.

    Returns:
        Values in [0, 1] with the shape of ``x`` (a NumPy scalar for scalar
        input).
    """
    x = np.asarray(x)
    e = np.exp(-np.abs(x))  # always in (0, 1]; cannot overflow
    out = np.where(x >= 0, 1 / (1 + e), e / (1 + e))
    # np.where returns a 0-d array for scalar input; unwrap it to a scalar.
    return out[()] if out.ndim == 0 else out

def d_sigmoid(x):
    """Derivative of the sigmoid, written in terms of the sigmoid's OUTPUT.

    ``x`` must already be ``s = sigmoid(a)``; the result is
    ds/da = s * (1 - s). Passing the raw pre-activation ``a`` instead gives
    a meaningless value.
    """
    complement = 1 - x
    return x * complement

def tanh(x):
    """Hyperbolic tangent, evaluated element-wise.

    Delegates to ``np.tanh``. The explicit (e^x - e^-x) / (e^x + e^-x) form
    overflows to inf / inf = nan once |x| exceeds roughly 710 in float64.
    ``np.tanh`` saturates correctly to +/-1 and needs no extra exp calls.

    Args:
        x: scalar or array-like.

    Returns:
        Values in [-1, 1] with the shape of ``x``.
    """
    return np.tanh(x)

def d_tanh(x):
    """Derivative of tanh, written in terms of tanh's OUTPUT.

    ``x`` must already be ``t = tanh(a)``; the result is
    dt/da = 1 - t**2. Passing the raw pre-activation ``a`` instead gives
    a meaningless value.
    """
    squared = x ** 2
    return 1 - squared

def softmax(x, axis=0):
    """Numerically stable softmax along ``axis`` (default 0).

    Each slice along ``axis`` is shifted by its own maximum before
    exponentiating. Shifting by the array-wide maximum instead would be
    mathematically equivalent, but every entry of a slice whose scores sit
    far below that maximum would underflow to 0, giving 0 / 0 = nan.

    Args:
        x: array-like of scores. With 2-D input and the default ``axis=0``,
            each column is normalised independently.
        axis: axis along which the returned probabilities sum to 1.

    Returns:
        Array with the same shape as ``x``.
    """
    x = np.asarray(x)
    shifted = x - np.max(x, axis=axis, keepdims=True)
    e_x = np.exp(shifted)
    # keepdims keeps the sum broadcastable for any axis, not just the last.
    return e_x / e_x.sum(axis=axis, keepdims=True)

# In [2]:
class MyGRUCell:
    """Single GRU cell without bias terms, trained by plain in-place SGD.

    Equations (sigma = sigmoid, * = element-wise product):
        r     = sigma(W_r x + U_r h_prev)        reset gate
        z     = sigma(W_z x + U_z h_prev)        update gate
        h_hat = tanh(W x + U (r * h_prev))       candidate state
        h     = z * h_prev + (1 - z) * h_hat

    Shapes: W, W_r, W_z are (hidden_size, input_size); U, U_r, U_z are
    (hidden_size, hidden_size). ``backward`` expects 2-D arrays, i.e. column
    vectors of shape (n, 1) or (n, batch) matrices. In the batch case the
    ``np.dot(a, b.T)`` weight gradients sum over the batch. With 1-D arrays
    those products collapse to inner products and the weight gradients are
    wrong.
    """

    def __init__(self, input_size, hidden_size, lr):
        """Create a cell with standard-normal (np.random.randn) weights.

        Args:
            input_size: length of the input vector x_t.
            hidden_size: length of the hidden state h.
            lr: learning rate that ``backward`` passes to ``update``.
        """
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.lr = lr
        # Input-to-hidden weights: candidate, reset gate, update gate.
        self.W = np.random.randn(hidden_size, input_size)
        self.W_r = np.random.randn(hidden_size, input_size)
        self.W_z = np.random.randn(hidden_size, input_size)
        # Hidden-to-hidden weights: candidate, reset gate, update gate.
        self.U = np.random.randn(hidden_size, hidden_size)
        self.U_r = np.random.randn(hidden_size, hidden_size)
        self.U_z = np.random.randn(hidden_size, hidden_size)

    def forward(self, x_t, h_prev):
        """Advance the cell by one time step.

        Args:
            x_t: input, shape (input_size, 1) (or (input_size, batch)).
            h_prev: previous hidden state, shape (hidden_size, 1)
                (or (hidden_size, batch)).

        Returns:
            (h, cache): the new hidden state, with the same shape as
            ``h_prev``, and the tuple (x_t, h_prev, r, z, h_hat) consumed by
            ``backward``.
        """
        r = sigmoid(np.dot(self.W_r, x_t) + np.dot(self.U_r, h_prev))
        z = sigmoid(np.dot(self.W_z, x_t) + np.dot(self.U_z, h_prev))
        # The reset gate scales h_prev element-wise; a matrix product here
        # would be wrong and would disagree with the gradients in backward.
        h_hat = tanh(np.dot(self.W, x_t) + np.dot(self.U, r * h_prev))
        h = z * h_prev + (1 - z) * h_hat
        cache = (x_t, h_prev, r, z, h_hat)
        return h, cache

    def backward(self, cache, dh_next):
        """Backpropagate one time step, then apply an SGD step to the weights.

        Args:
            cache: the tuple returned by ``forward`` for this step.
            dh_next: gradient of the loss w.r.t. this step's output ``h``
                (same shape as ``h``).

        Returns:
            (dx, dh_prev): gradients w.r.t. ``x_t`` and ``h_prev``.

        Note:
            All gradients are computed with the current weights, and then
            ``update`` modifies the weights in place before returning. If
            one cell is unrolled over several steps, earlier steps are
            therefore backpropagated through already-updated weights.
        """
        x_t, h_prev, r, z, h_hat = cache

        # h = z * h_prev + (1 - z) * h_hat
        dh_hat = dh_next * (1 - z)
        dz = dh_next * (h_prev - h_hat)

        # Candidate: h_hat = tanh(W x + U (r * h_prev)).
        # d_tanh takes the tanh output, which is what h_hat is.
        dtanh = dh_hat * d_tanh(h_hat)
        dW = np.dot(dtanh, x_t.T)
        dU = np.dot(dtanh, (r * h_prev).T)
        dx_1 = np.dot(self.W.T, dtanh)
        d_rh = np.dot(self.U.T, dtanh)  # gradient w.r.t. (r * h_prev)
        dh_prev_1 = d_rh * r

        # Update gate: z = sigmoid(W_z x + U_z h_prev).
        # d_sigmoid takes the sigmoid output, which is what z is.
        dz_pre = dz * d_sigmoid(z)
        dW_z = np.dot(dz_pre, x_t.T)
        dU_z = np.dot(dz_pre, h_prev.T)
        dx_2 = np.dot(self.W_z.T, dz_pre)
        dh_prev_2 = np.dot(self.U_z.T, dz_pre)

        # Reset gate: r = sigmoid(W_r x + U_r h_prev).
        dr_pre = d_rh * h_prev * d_sigmoid(r)
        dW_r = np.dot(dr_pre, x_t.T)
        dU_r = np.dot(dr_pre, h_prev.T)
        dx_3 = np.dot(self.W_r.T, dr_pre)
        dh_prev_3 = np.dot(self.U_r.T, dr_pre)

        dx = dx_1 + dx_2 + dx_3
        # dh_next * z is the direct (un-gated-by-tanh) path h_prev -> h.
        dh_prev = dh_prev_1 + dh_prev_2 + dh_prev_3 + dh_next * z

        self.update(dW, dU, dW_z, dU_z, dW_r, dU_r, self.lr)
        return dx, dh_prev

    def update(self, dW, dU, dW_z, dU_z, dW_r, dU_r, lr):
        """Apply one in-place SGD step: ``param -= lr * grad`` for each weight."""
        self.W -= lr * dW
        self.U -= lr * dU
        self.W_z -= lr * dW_z
        self.U_z -= lr * dU_z
        self.W_r -= lr * dW_r
        self.U_r -= lr * dU_r