In [1]:
from keras.models import Model
from keras.layers import Input, LSTM, GRU
import numpy as np
import matplotlib.pyplot as plt

In [2]:
# Best-effort: when the TensorFlow backend reports at least one GPU, rebind the
# names LSTM / GRU to the cuDNN-backed layers so every demo below uses them
# without further changes. On any failure, the generic layers imported in the
# first cell stay in place. Failures include a missing backend, a Keras build
# without `tensorflow_backend`, or missing CuDNN* layers.
# NOTE(review): `_get_available_gpus` is a private helper and may not exist in
# every Keras version; the cuDNN layers may also not be numerically identical to
# the generic ones, so printed values can differ between CPU and GPU runs.
try:
  import keras.backend as K
  if len(K.tensorflow_backend._get_available_gpus()) > 0:
    from keras.layers import CuDNNLSTM as LSTM
    from keras.layers import CuDNNGRU as GRU
except Exception:
  # Narrowed from a bare `except:` so KeyboardInterrupt / SystemExit are no
  # longer swallowed; ordinary errors still fall back to the generic layers.
  pass

In [3]:
# Toy problem sizes shared by every demo below.
T = 8   # sequence length (number of time steps)
D = 2   # input feature dimension per time step
M = 3   # number of hidden units in each RNN layer
SEED = 42  # fixes the sample input so re-runs see the same X

# Seed NumPy's legacy global RNG so X is reproducible across kernel restarts.
# NOTE(review): this presumably does not seed Keras/TensorFlow weight
# initialisation, so printed o/h/c values may still vary between runs.
np.random.seed(SEED)
# One input sequence of shape (1, T, D), matching Input(shape=(T, D)).
X = np.random.randn(1, T, D)

In [5]:
def lstm1():
  """Print what an LSTM layer returns with `return_state=True` only.

  Builds a one-layer model over the global sample `X` and prints the three
  arrays it yields: the layer output `o`, the final hidden state `h` and the
  final cell state `c`. Without `return_sequences`, `o` is only the last time
  step's output, so it matches `h`.
  """
  seq_in = Input(shape=(T, D))
  layer = LSTM(M, return_state=True)
  model = Model(inputs=seq_in, outputs=layer(seq_in))

  out, hidden, cell = model.predict(X)
  for label, value in (("o:", out), ("h:", hidden), ("c:", cell)):
    print(label, value)

In [4]:
def lstm2():
  """Print LSTM outputs with both `return_state=True` and `return_sequences=True`.

  Prints `o`, the output at every time step (shape (1, T, M)), followed by the
  final hidden state `h` and the final cell state `c`. The last row of `o`
  matches `h`.
  """
  input_ = Input(shape=(T, D))
  rnn = LSTM(M, return_state=True, return_sequences=True)
  # The GRU counterpart of this demo is gru2(). A GRU has no cell state, so it
  # returns two arrays and cannot share the three-way unpacking below.
  x = rnn(input_)

  model = Model(inputs=input_, outputs=x)
  o, h, c = model.predict(X)
  print("o:", o)
  print("h:", h)
  print("c:", c)

In [6]:
def gru1():
  """Print what a GRU layer returns with `return_state=True` only.

  A GRU carries a single state vector, so the model yields two arrays: the
  layer output `o` (last time step only) and the final hidden state `h`.
  """
  seq_in = Input(shape=(T, D))
  layer = GRU(M, return_state=True)
  model = Model(inputs=seq_in, outputs=layer(seq_in))

  out, hidden = model.predict(X)
  for label, value in (("o:", out), ("h:", hidden)):
    print(label, value)

In [7]:
def gru2():
  """Print GRU outputs with both `return_state=True` and `return_sequences=True`.

  Prints `o`, the output at every time step, followed by the final hidden
  state `h`.
  """
  seq_in = Input(shape=(T, D))
  layer = GRU(M, return_state=True, return_sequences=True)
  model = Model(inputs=seq_in, outputs=layer(seq_in))

  out, hidden = model.predict(X)
  for label, value in (("o:", out), ("h:", hidden)):
    print(label, value)

In [8]:
# Run each demo in turn, preceded by a "<name>:" header line.
for demo in (lstm1, lstm2, gru1, gru2):
  print(demo.__name__ + ":")
  demo()

lstm1:
o: [[-0.04634568 -0.24632916  0.16547537]]
h: [[-0.04634568 -0.24632916  0.16547537]]
c: [[-0.11470257 -0.5127182   0.43763816]]
lstm2:
o: [[[ 0.03518438 -0.09564848 -0.18087968]
  [ 0.07648514 -0.10668113 -0.23878014]
  [ 0.06359713 -0.22380613 -0.17309335]
  [ 0.05865982 -0.22219503 -0.17065279]
  [ 0.02365636 -0.13207385 -0.08748368]
  [-0.01816387  0.01292667  0.06137877]
  [-0.03318233  0.01575479  0.13983685]
  [-0.06528048  0.15393232  0.21760711]]]
h: [[-0.06528048  0.15393232  0.21760711]]
c: [[-0.12588961  0.24542606  0.49227926]]
gru1:
o: [[ 0.09764599 -0.01560204 -0.02160428]]
h: [[ 0.09764599 -0.01560204 -0.02160428]]
gru2:
o: [[[-7.6937623e-02  1.6142216e-01 -4.7158368e-02]
  [-3.6489528e-01  2.9570019e-01  1.9905674e-01]
  [-1.1167477e-01  1.7655867e-01  9.8255411e-02]
  [ 2.6206428e-01 -9.0722837e-02 -1.3905734e-01]
  [-3.5534924e-04  9.7891085e-02  2.5586945e-01]
  [ 2.1725178e-01 -3.2570031e-01 -1.1598748e-02]
  [ 3.3343086e-01 -4.7246063e-01 -2.3159818e-01]
  