In [1]:
import jax.numpy as jnp
from jax import random, vmap, jit, grad

key = random.PRNGKey(0)

In [2]:
class OriginalVersion:
    @staticmethod
    def normal_cell(x, h0, 
                    w_hh, w_xh, b_h, 
                    w_hy, b_y):
        '''
        Input
        -----
        x: (S, B, I)
        h0: (B, H)
        q_hh: (H, H)
        w_xh: (I, H)
        b_h: (H)
        w_hy: (H, O)
        b_y: (H)

        Output
        ------
        res: (S, B, O)
        h: (S, B, H)
        '''
        steps, batch_size, input_dim = x.shape  # S, B, I
        _, hidden_dim = w_hh.shape  # H, H
        _, output_dim = w_hy.shape  # H, O

        res = jnp.zeros((steps, batch_size, output_dim))  # S, B, O
        h = jnp.zeros((steps, batch_size, hidden_dim))  # S, B, H
        h = h.at[-1].set(h0)
        for ix in range(steps):
            h = h.at[ix].set(
                jnp.tanh(h[ix - 1] @ w_hh + x[ix] @ w_xh + b_h)
            )
            res = res.at[ix].set(
                h[ix] @ w_hy + b_y
            )

        return res, h

    @staticmethod
    def lstm_cell(x, h0, c0,
                    Ws, Us, Bs):
        '''
        Input
        -----
        x: (S, B, I)
        h0: (B, H)
        c0: (B, H)
        Ws: (4, I, H)
        Us: (4, H, H)
        Bs: (4, H)

        Output
        ------
        res: (S, B, H)
        h: (S, B, H)
        c: (S, B, H)
        '''

        def sigmoid(x):
            return 1 / (1 + jnp.exp(-x))

        w_i, w_f, w_c, w_o = Ws  # (I, H)
        u_i, u_f, u_c, u_o = Us  # (H, H)
        b_i, b_f, b_c, b_o = Bs  # (H)

        steps, batch_size, input_dim = x.shape  # S, B, I
        _, hidden_dim = w_i.shape

        res = jnp.zeros((steps, batch_size, hidden_dim))  # S, B, H
        h = jnp.zeros((steps, batch_size, hidden_dim)) # S, B, H
        h = h.at[-1].set(h0)
        c = jnp.zeros((steps, batch_size, hidden_dim))  # S, B, H
        c = c.at[-1].set(c0)

        for ix in range(steps):
            I = sigmoid(x[ix] @ w_i + h[ix - 1] @ u_i + b_i)
            F = sigmoid(x[ix] @ w_f + h[ix - 1] @ u_f + b_f)
            C = jnp.tanh(x[ix] @ w_c + h[ix - 1] @ u_c + b_c)
            O = sigmoid(x[ix] @ w_o + h[ix - 1] @ u_o + b_o)

            c = c.at[ix].set(
                F*c[ix - 1] + I*C
            )
            h = h.at[ix].set(
                O*jnp.tanh(C)
            )
            res = res.at[ix].set(
                O
            )

        return res, h, c

    @staticmethod
    def gru_cell(x, h0, 
                    Ws, Us, Bs):
        '''
        Input
        -----
        x: (S, B, I)
        h0: (S, B, H)
        Ws: 3 * (I, H)
        Us: 3 * (H, H)
        Bs: 3 * (H)

        Output
        ------
        res: (S, B, H)
        h: (S, B, H)
        '''

        def sigmoid(x):
            return 1 / (1 + jnp.exp(-x))

        w_z, w_r, w_h = Ws  # (I, H)
        u_z, u_r, u_h = Us  # (H, H)
        b_z, b_r, b_h = Bs  # (H)

        steps, batch_size, input_dim = x.shape  # S, B, I
        _, hidden_dim = w_z.shape

        res = jnp.zeros((steps, batch_size, hidden_dim))  # S, B, H
        h = jnp.zeros((steps, batch_size, hidden_dim)) # S, B, H
        h = h.at[-1].set(h0)

        for ix in range(steps):
            R = sigmoid(x[ix] @ w_r + h[ix - 1] @ u_r + b_r)
            Z = sigmoid(x[ix] @ w_z + h[ix - 1] @ u_z + b_z)

            H = jnp.tanh(x[ix] @ w_h + (R * h[ix - 1]) @ u_h + b_h)

            h = h.at[ix].set(
                (1 - Z) * h[ix - 1] + Z * H
            )

        return h

In [3]:
batch_size = 1000
time_steps = 128
input_dim = 9
output_dim = 6
hidden_dim = 64

x = random.normal(key, (time_steps, batch_size, input_dim))


def softmax(logits):
    logits_stable = logits - jnp.max(logits, axis=1, keepdims=True)
    exp_logits = jnp.exp(logits_stable)
    return exp_logits / jnp.sum(exp_logits, axis=1, keepdims=True)

def cross_entropy_loss(y, y_pred):
    epsilon = 1e-9
    y_pred_clipped = jnp.clip(y_pred, epsilon, 1. - epsilon)  # clip here is very important, or you will get Nan when you training. 
    loss = -jnp.sum(y * jnp.log(y_pred_clipped), axis=1)
    return loss.mean()

y_true = random.randint(key, (1000,), 0, 5)

def one_hot(x, num_class):
    res = jnp.zeros((x.shape[0], num_class))
    res = res.at[jnp.arange(x.shape[0]), x].set(1)
    return res

y_true = one_hot(y_true, output_dim)

print(y_true.shape)
print(y_true[:10])

(1000, 6)
[[0. 0. 0. 0. 1. 0.]
 [1. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0.]
 [0. 1. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0.]
 [0. 0. 0. 1. 0. 0.]
 [0. 0. 0. 1. 0. 0.]]


# # test for normal cell

In [4]:
h0 = random.normal(key, (batch_size, hidden_dim))

w_hh = random.normal(key, (hidden_dim, hidden_dim))
w_xh = random.normal(key, (input_dim, hidden_dim))
b_h = random.normal(key, (hidden_dim))
w_hy = random.normal(key, (hidden_dim, output_dim))
b_y = random.normal(key, (output_dim))

@jit
def loss(w_hh, w_xh, b_h, w_hy, b_y):
    res, _ = OriginalVersion.normal_cell(x, h0, w_hh, w_xh, b_h, w_hy, b_y)
    logits = res[-1]
    proba = softmax(logits)
    lo = cross_entropy_loss(y_true, proba)

    return lo

In [5]:
res, h = OriginalVersion.normal_cell(x, h0, w_hh, w_xh, b_h, w_hy, b_y)

print(res.shape)
print(f'res: \n {res[0]}')
print(f'res: \n {res[-1]}')
print()
print(h.shape)
print(f'h: \n {h[0]}')
print(f'h: \n {h[-1]}')
print()

(128, 1000, 6)
res: 
 [[  8.66937     -1.9780154   -9.532201    11.172364     1.1787364
    2.5646725 ]
 [  6.774772    -4.0882893   -0.89915705  -3.9126399   -2.888041
   -2.898172  ]
 [  3.0921602  -14.344307     3.7016938  -10.14012      8.677878
   -7.758292  ]
 ...
 [ -4.9551196    4.8915215   10.054241   -13.020678     2.0951357
   -1.1433754 ]
 [  1.3927512   10.878891   -17.823631     8.056272    -2.5641198
   -4.1883535 ]
 [ -5.8027096   -1.3530209   -0.57130706   0.7142793    2.584575
   13.048897  ]]
res: 
 [[  1.0190579   3.9091406  -9.910453    2.5703552   9.322914    6.368651 ]
 [ 13.950739   -5.5121355  -8.111554    2.3537805   2.5091462  -3.8911343]
 [  4.0426297  -2.2920299  -8.372124   -0.3052504   8.704442    7.8723865]
 ...
 [ -1.2945749  -3.6627593 -10.36847   -10.822591    1.5282776   7.798837 ]
 [ -4.212191    9.458644   -5.5591736   3.7447376  -2.4790366   9.87258  ]
 [ -8.179989   -0.2797513   3.0643513   3.6697965   0.7604062   7.428778 ]]

(128, 1000, 64)
h: 

In [7]:
import time

s = time.time()
lo = loss(w_hh, w_xh, b_h, w_hy, b_y)
print(f'cost: {time.time() - s}')

cost: 0.00043582916259765625


In [10]:
s = time.time()
d_w_hh = grad(loss, argnums=0)(w_hh, w_xh, b_h, w_hy, b_y)
print(f'cost: {time.time() - s}')

cost: 0.012967348098754883


# # test for LSTM cell

In [11]:
x = random.normal(key, (time_steps, batch_size, input_dim))
h0 = random.normal(key, (batch_size, hidden_dim))
c0 = random.normal(key, (batch_size, hidden_dim))

Ws = [random.normal(key, (input_dim, hidden_dim)) for _ in range(4)]
Us = [random.normal(key, (hidden_dim, hidden_dim)) for _ in range(4)]
Bs = [random.normal(key, (hidden_dim)) for _ in range(4)]

# fc layer
w = random.normal(key, (hidden_dim, output_dim))
b = random.normal(key, (output_dim))

@jit
def loss(Ws, Us, Bs, w, b):
    res, _, _ = OriginalVersion.lstm_cell(x, h0, c0, Ws, Us, Bs)
    logits = res[-1] @ w + b
    proba = softmax(logits)
    lo = cross_entropy_loss(y_true, proba)

    return lo

In [12]:
res, h, c = OriginalVersion.lstm_cell(x, h0, c0, Ws, Us, Bs)

print(res.shape)
print(f'res: \n{res[0]}')
print(f'res: \n{res[-1]}')
print()
print(h.shape)
print(f'res: \n{h[0]}')
print(f'res: \n{h[-1]}')
print()
print(c.shape)
print(f'res: \n{c[0]}')
print(f'res: \n{c[-1]}')
print()

(128, 1000, 64)
res: 
[[9.69917059e-01 8.92088935e-03 8.38588595e-01 ... 2.38286518e-02
  1.13636515e-05 9.99993205e-01]
 [9.37866569e-01 9.84946728e-01 3.56435962e-03 ... 9.99959230e-01
  8.61522481e-02 9.93443251e-01]
 [9.99974728e-01 1.00000000e+00 1.84506334e-07 ... 8.99734277e-06
  1.88667193e-01 7.30012107e-05]
 ...
 [9.63957965e-01 9.90295529e-01 6.44707739e-01 ... 5.12012139e-06
  6.41178712e-03 6.39698002e-03]
 [1.04544124e-08 2.62359686e-06 3.22645038e-01 ... 4.98917669e-01
  1.95793733e-02 6.93655511e-06]
 [6.30637764e-09 2.62113311e-03 9.99986768e-01 ... 8.90103693e-04
  2.37014814e-04 9.99993920e-01]]
res: 
[[9.9962735e-01 8.1208688e-01 9.7131777e-01 ... 9.9817574e-01
  5.4641280e-02 2.9876885e-01]
 [9.6802384e-01 9.9897939e-01 6.7911730e-03 ... 9.9219877e-01
  4.0285251e-01 1.5354355e-01]
 [9.4644415e-01 9.9155831e-01 2.3182090e-01 ... 9.3571478e-01
  2.7067055e-05 9.9094427e-01]
 ...
 [9.9469912e-01 9.9858010e-01 9.4757545e-01 ... 9.7912121e-01
  6.7130703e-01 7.3915243e

In [14]:
s = time.time()
lo = loss(Ws, Us, Bs, w, b)

print(f'cost: {time.time() - s}')

cost: 0.0003101825714111328


In [19]:
s = time.time()
gd = grad(loss, argnums=0)(Ws, Us, Bs, w, b)

print(f'cost: {time.time() - s}')

cost: 0.027780771255493164


# # test for GRU cell

In [50]:
x = random.normal(key, (time_steps, batch_size, input_dim))
h0 = random.normal(key, (batch_size, hidden_dim))

Ws = [random.normal(key, (input_dim, hidden_dim)) for _ in range(3)]
Us = [random.normal(key, (hidden_dim, hidden_dim)) for _ in range(3)]
Bs = [random.normal(key, (hidden_dim)) for _ in range(3)]

# fc layer
w = random.normal(key, (hidden_dim, output_dim))
b = random.normal(key, (output_dim))

@jit
def loss(Ws, Us, Bs, w, b):
    res = OriginalVersion.gru_cell(x, h0, Ws, Us, Bs)
    logits = res[-1] @ w + b
    proba = softmax(logits)
    lo = cross_entropy_loss(y_true, proba)

    return lo

In [51]:
res = OriginalVersion.gru_cell(x, h0, Ws, Us, Bs)

print(res.shape)
print(f'res: \n{res[0]}')
print(f'res: \n{res[-1]}')
print()

(128, 1000, 64)
res: 
[[ 1.0187309   1.9993063   0.71674156 ... -0.52912116 -1.6692089
   0.9993596 ]
 [-0.39627454  0.95510256 -0.9753823  ...  0.9998114   0.20660144
   0.9908003 ]
 [ 1.0000037   1.          0.7234152  ...  0.01733283  0.13955556
  -1.0713109 ]
 ...
 [ 0.97045386  0.97538894  0.8471778  ...  0.12468851 -0.09145955
  -2.6021852 ]
 [-1.8666751   0.38999814 -0.10754696 ...  0.01773185 -1.330422
   0.6966255 ]
 [-0.24632064 -1.3629702   0.9991948  ...  0.69544685 -0.4143857
   0.99951446]]
res: 
[[ 0.9999984   0.9998255   0.9998114  ...  0.9507475   0.6827151
   0.7985372 ]
 [ 0.9999942   0.99759394  0.99878114 ...  0.7878383   0.99831545
   0.9857281 ]
 [ 0.9977804   0.99998593  0.59166765 ...  0.99699116  0.05741335
   0.9959601 ]
 ...
 [ 0.994446    0.9997472   0.9999921  ...  0.5379023   0.9994596
   0.9772161 ]
 [ 1.          0.99999624  0.6755853  ... -0.1219774   0.4803075
  -0.5124174 ]
 [ 0.98375964  0.99986595  0.28833    ... -0.4977213   0.581353
   0.98541087

In [53]:
s = time.time()
lo = loss(Ws, Us, Bs, w, b)

print(f'cost: {time.time() - s}')

cost: 0.0004401206970214844


In [55]:
s = time.time()
gd = grad(loss, argnums=0)(Ws, Us, Bs, w, b)

print(f'cost: {time.time() - s}')

cost: 0.06224656105041504
