## GRUの実装

In [1]:
import numpy as np

def sigmoid(x):
    """Return the logistic function 1 / (1 + exp(-x)), elementwise.

    Args:
        x: A scalar or array-like of real numbers.

    Returns:
        Values in [0, 1] with the same shape as ``x`` (a NumPy scalar for
        scalar input).
    """
    x = np.asarray(x)
    # log(1 + exp(-x)) is evaluated by logaddexp without ever forming
    # exp(-x), so large negative inputs underflow to 0 instead of overflowing.
    return np.exp(-np.logaddexp(0.0, -x))

### resetゲート

In [2]:
# Sizes used throughout: N samples, T steps per sample, D input features,
# H hidden units.
N, T, D, H = 32, 16, 100, 100

# Random input of shape (N, T, D), drawn from a standard normal.
xs = np.random.standard_normal((N, T, D))

# Slice out index 0 along the second axis -> shape (N, D).
x = xs[:, 0]

# Previous hidden state starts as all zeros, shape (N, H).
h_prev = np.zeros(shape=(N, H))

In [3]:
# Reset-gate parameters: Wh_r multiplies the previous hidden state (H, H),
# Wx_r multiplies the input (D, H).
Wh_r = np.random.randn(H, H)
Wx_r = np.random.randn(D, H)
# One bias value per hidden unit, shared by every sample. Shape (1, H)
# broadcasts against the (N, H) pre-activation, so the parameter no longer
# depends on the batch size N.
b_r = np.random.randn(1, H)

In [5]:
# Reset gate: sigmoid of the input projection plus the recurrent projection
# plus the bias.
r = sigmoid(x.dot(Wx_r) + h_prev.dot(Wh_r) + b_r)

### updateゲート

In [6]:
# Update-gate parameters.
Wh_z = np.random.randn(H, H)
Wx_z = np.random.randn(D, H)
# Bias is one value per hidden unit, shared across the batch: shape (1, H)
# broadcasts over the N rows instead of giving each sample its own bias.
b_z = np.random.randn(1, H)

# Candidate-hidden-state parameters.
Wh_h = np.random.randn(H, H)
Wx_h = np.random.randn(D, H)
b_h = np.random.randn(1, H)

# Update gate z, values in (0, 1).
z = sigmoid(np.dot(x, Wx_z) + np.dot(h_prev, Wh_z) + b_z)
# Candidate state: the reset gate r scales h_prev elementwise before the
# recurrent projection is applied.
h_hat = np.tanh(np.dot(x, Wx_h) + np.dot(r * h_prev, Wh_h) + b_h)

In [7]:
# New hidden state: elementwise blend of the previous state (weight 1 - z)
# and the candidate state (weight z).
h_next = h_prev * (1 - z) + h_hat * z

### 処理をまとめる

In [8]:
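# Reference only: the nine per-gate arrays created in the cells above.
# The next cell replaces them with three arrays (Wh, Wx, b) that each have
# 3 * H columns.
# Wh_r = np.random.randn(H, H)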
# Wh_z = np.random.randn(H, H)
# Wh_h = np.random.randn(H, H)

# Wx_r = np.random.randn(D, H)
# Wx_z = np.random.randn(D, H)
# Wx_h = np.random.randn(D, H)

# b_r = np.random.randn(N, H)
# b_z = np.random.randn(N, H)
# b_h = np.random.randn(N, H)

In [9]:
# Fused parameters: three H-wide column blocks side by side, one per gate.
Wh = np.random.randn(H, 3 * H)
Wx = np.random.randn(D, 3 * H)
# Bias holds one value per hidden unit and gate, shared across the batch.
# It is kept 2-D as (1, 3 * H) so column slicing such as b[:, :H] still works
# and the result broadcasts over the N rows.
b = np.random.randn(1, 3 * H)

In [10]:
class GRU:
    """Single-step GRU cell (forward computation).

    The three parameter arrays each hold three H-wide blocks along their
    last axis, in the order [update gate z | reset gate r | candidate h].
    """

    def __init__(self, Wx, Wh, b):
        """Store the parameters and allocate zero-filled gradient buffers.

        Args:
            Wx: Input weights, shape (D, 3H).
            Wh: Recurrent weights, shape (H, 3H).
            b: Bias whose last axis has length 3H. It may be (3H,),
                (1, 3H) or (N, 3H); it is broadcast against the batch.
        """
        self.params = [Wx, Wh, b]
        self.grads = [np.zeros_like(Wx), np.zeros_like(Wh), np.zeros_like(b)]
        self.cache = None

    def forward(self, x, h_prev):
        """Compute the next hidden state for one time step.

        Args:
            x: Input at this step, shape (N, D).
            h_prev: Previous hidden state, shape (N, H).

        Returns:
            The next hidden state, shape (N, H).

        Side effects:
            Stores ``(x, h_prev, z, r, h_hat)`` in ``self.cache``.
        """
        Wx, Wh, b = self.params
        H = Wh.shape[0]

        Wx_z, Wx_r, Wx_h = Wx[:, :H], Wx[:, H:2*H], Wx[:, 2*H:]
        Wh_z, Wh_r, Wh_h = Wh[:, :H], Wh[:, H:2*H], Wh[:, 2*H:]
        # Slice the last axis with Ellipsis so a 1-D bias of shape (3H,)
        # works as well as a 2-D one; `b[:, :H]` would raise on 1-D input.
        b_z, b_r, b_h = b[..., :H], b[..., H:2*H], b[..., 2*H:]

        z = sigmoid(np.dot(x, Wx_z) + np.dot(h_prev, Wh_z) + b_z)
        r = sigmoid(np.dot(x, Wx_r) + np.dot(h_prev, Wh_r) + b_r)
        # The reset gate scales h_prev before the recurrent projection.
        h_hat = np.tanh(np.dot(x, Wx_h) + np.dot(r * h_prev, Wh_h) + b_h)

        # Blend the previous state (weight 1 - z) with the candidate (weight z).
        h_next = (1 - z) * h_prev + z * h_hat

        self.cache = (x, h_prev, z, r, h_hat)

        return h_next

In [11]:
# Take index 0 along the second axis of xs and run it through a newly built
# GRU cell, starting from h_prev.
x = xs[:, 0]
gru_unit = GRU(Wx=Wx, Wh=Wh, b=b)
h_next = gru_unit.forward(x, h_prev=h_prev)

In [12]:
# Display the shape of the new hidden state.
np.shape(h_next)

(32, 100)