In [17]:
from __future__ import absolute_import, division, print_function

import tensorflow as tf
from tensorflow.keras import Model, layers
import numpy as np

In [18]:
num_classes = 10 
num_features = 784 

learning_rate = 0.001
training_steps = 50
batch_size = 32
display_step = 5

num_input = 28 
timesteps = 28 
num_units = 32 

In [19]:
from tensorflow.keras.datasets import mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = np.array(x_train, np.float32), np.array(x_test, np.float32)
x_train, x_test = x_train.reshape([-1, 28, 28]), x_test.reshape([-1, num_features])
x_train, x_test = x_train / 255., x_test / 255.

In [20]:
train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_data = train_data.repeat().shuffle(5000).batch(batch_size).prefetch(1)

In [21]:
class LSTM(Model):
    def __init__(self):
        super(LSTM, self).__init__()
        self.lstm_layer = layers.LSTM(units=num_units)
        self.out = layers.Dense(num_classes)

    def call(self, x, is_training=False):
        # LSTM layer.
        x = self.lstm_layer(x)
        x = self.out(x)
        if not is_training:
            x = tf.nn.softmax(x)
        print(x)
        return x
lstm_net = LSTM()


In [22]:

def cross_entropy_loss(x, y):
    y = tf.cast(y, tf.int64)
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=x)
    print (y)
    return tf.reduce_mean(loss)

def accuracy(y_pred, y_true):
    correct_prediction = tf.equal(tf.argmax(y_pred, 1), tf.cast(y_true, tf.int64))
    return tf.reduce_mean(tf.cast(correct_prediction, tf.float32), axis=-1)

optimizer = tf.optimizers.Adam(learning_rate)


In [23]:
def run_optimization(x, y):
    with tf.GradientTape() as g:
        pred = lstm_net(x, is_training=True)
        loss = cross_entropy_loss(pred, y)
        
    trainable_variables = lstm_net.trainable_variables

    gradients = g.gradient(loss, trainable_variables)
    
    optimizer.apply_gradients(zip(gradients, trainable_variables))

In [None]:
for step, (batch_x, batch_y) in enumerate(train_data.take(training_steps), 1):
    run_optimization(batch_x, batch_y)
    
    if step % display_step == 0:
        pred = lstm_net(batch_x, is_training=True)
        accru = accuracy(pred, batch_y)
        print("step= %i, accuracy= %i" % (step,  accru))

tf.Tensor(
[[ 3.23154815e-02 -9.63898525e-02  8.22949708e-02 -1.48870796e-01
   6.99703693e-02 -9.59629565e-03  2.50161588e-02 -6.33109510e-02
  -3.09803970e-02 -3.19755413e-02]
 [ 2.14305930e-02 -1.16881430e-01  9.28590745e-02 -4.66790460e-02
   2.60518566e-02  1.10981032e-01  4.42214310e-04 -3.69510725e-02
  -8.39389414e-02  7.64809363e-03]
 [-6.69383854e-02 -3.73534262e-02  1.00008562e-01 -1.57941416e-01
   5.61581999e-02  1.41875193e-01 -8.43251795e-02 -1.18706189e-02
  -3.16997617e-03  7.13857636e-03]
 [ 7.14587867e-02 -8.26711953e-02  1.07707128e-01 -7.93028623e-02
   2.82038376e-02  5.27234077e-02  2.97333524e-02 -4.42803465e-03
  -4.29705605e-02 -4.77901511e-02]
 [ 3.64124328e-02 -1.17567495e-01  1.19900048e-01 -9.72013101e-02
   4.38156351e-03  9.03549865e-02 -1.75369903e-03 -1.84289161e-02
  -1.01658382e-01 -4.02856059e-03]
 [ 7.91992992e-04 -5.75190187e-02  1.02022089e-01 -1.18452504e-01
   4.11511734e-02  1.27149194e-01 -6.07771091e-02  4.37738560e-03
  -6.71663582e-02  4.0

  -3.58465463e-02 -5.45150749e-02]], shape=(32, 10), dtype=float32)
tf.Tensor([2 7 5 1 9 1 4 3 1 9 2 6 1 0 9 8 6 0 7 2 2 4 5 5 9 9 2 4 2 1 1 0], shape=(32,), dtype=int64)
tf.Tensor(
[[ 7.00878724e-02 -3.40136793e-03  6.96439818e-02 -2.38425098e-02
  -1.28265560e-01  9.65595469e-02  6.86791018e-02  3.99080617e-03
  -4.06596884e-02 -3.31988670e-02]
 [ 2.49502398e-02 -3.22320201e-02  6.03393055e-02 -3.51443402e-02
  -3.16400602e-02  5.16531020e-02  1.19948359e-02  8.72426108e-03
  -4.57579307e-02  8.46912805e-03]
 [ 4.92304154e-02 -6.79145381e-02  8.19491670e-02 -1.40927330e-01
   2.91010551e-02  1.39312670e-02 -6.43963635e-04 -6.24755863e-04
  -4.80236411e-02  4.17928328e-04]
 [-7.58134723e-02 -6.15351312e-02  1.49077550e-01 -2.76013136e-01
  -3.68873887e-02  1.71505928e-01 -7.50464946e-02 -2.27979850e-02
  -3.82173806e-02  4.83710580e-02]
 [ 2.62604114e-02 -4.66857776e-02  8.07079747e-02 -3.02933548e-02
  -5.16028441e-02  4.67970297e-02  5.52635677e-02 -1.32495202e-02
  -5.02955653e-02 

tf.Tensor([7 4 9 4 3 3 0 2 8 8 9 3 1 4 7 1 4 1 2 2 2 5 7 7 6 5 8 8 3 9 0 5], shape=(32,), dtype=int64)
tf.Tensor(
[[ 0.07928365 -0.11189632  0.11288746 -0.17517594 -0.04135509  0.03332274
  -0.01749447  0.04404282 -0.1281235   0.03477948]
 [-0.07436669 -0.08641519  0.03950489 -0.11940604  0.05317122  0.03570596
  -0.04265655 -0.00411565 -0.04962479  0.0591789 ]
 [ 0.04433995  0.06972519 -0.00084391 -0.02183327 -0.00445783  0.07376324
  -0.03852652 -0.01497676  0.02885136 -0.0307861 ]
 [ 0.00229827 -0.04368788  0.08822674 -0.10077747 -0.01631504  0.10431314
   0.00415905 -0.00460789  0.01056475  0.00939436]
 [ 0.04657036 -0.08893911  0.08320731 -0.11258953 -0.11282481  0.08464549
   0.00681098 -0.01717337 -0.10638118  0.05438447]
 [-0.03252089 -0.11631543  0.09801345 -0.1617726   0.00124667  0.08083438
  -0.05197251  0.00675791 -0.10014532  0.06582924]
 [ 0.06381424  0.03092724  0.04821987  0.00483387 -0.13422443  0.06550026
   0.06228907  0.02896656 -0.00467168 -0.03878064]
 [ 0.050081

tf.Tensor(
[[ 8.48430321e-02 -1.04172073e-01  7.51320571e-02 -7.37233013e-02
  -1.96534991e-01 -2.30387971e-02  1.91312507e-02 -3.42064686e-02
  -7.73276389e-02  1.10042626e-02]
 [-1.54635236e-01 -9.94408801e-02  1.66330056e-03 -1.83325186e-01
  -2.73192208e-02  4.72441092e-02 -1.07521206e-01 -2.45410912e-02
  -1.62183776e-01  1.10543273e-01]
 [ 1.48180407e-04 -5.97955771e-02  4.34665717e-02 -1.24275066e-01
   2.95786113e-02 -6.49139658e-02  5.85874543e-03 -3.18276417e-03
  -5.03243394e-02  2.21145265e-02]
 [ 8.47363696e-02 -9.10951421e-02  8.57678056e-02  6.02094573e-04
  -1.47243112e-01 -3.36426236e-02  1.27370834e-01 -1.31480219e-02
  -5.44472784e-02  4.28136159e-03]
 [-1.28614351e-01  1.30879149e-01 -9.23888236e-02 -5.72928749e-02
  -5.51498011e-02  4.25088257e-02 -2.23859906e-01 -7.99290556e-03
  -1.28067866e-01  6.88998029e-02]
 [-5.46010658e-02 -1.29319668e-01  7.92378187e-02  2.94689182e-03
   2.70543247e-03  7.53703862e-02  5.93767948e-02  2.44335551e-03
  -7.44886473e-02 -4.5

tf.Tensor(
[[ 0.05783003 -0.0375043   0.04875313  0.00740194 -0.08980167  0.02023211
   0.10117411  0.0415015   0.00959789 -0.03603333]
 [-0.04932344 -0.23592953  0.12067553  0.097472   -0.00875801 -0.02089883
   0.11973009  0.02593686 -0.0471371  -0.08171421]
 [ 0.12884459 -0.23446596  0.12495893 -0.063976   -0.325484   -0.10119268
   0.22863053  0.08750559 -0.16158028  0.10913246]
 [ 0.08749771 -0.06163495  0.14253579 -0.07308168 -0.16196145  0.02379148
   0.05230162  0.04621581 -0.095563    0.03282205]
 [ 0.07301773  0.10852012  0.06377453  0.12008603 -0.14430879  0.0451463
   0.04899129  0.05541737 -0.05371632 -0.02792926]
 [-0.01042339 -0.07114863  0.04269727 -0.12616725  0.0658326   0.04669568
  -0.03825939 -0.00166288 -0.04230378  0.08959875]
 [-0.02851738 -0.17294583  0.05559129  0.05835897  0.02842142  0.02889474
   0.08867425 -0.02046965 -0.06199149 -0.02845418]
 [ 0.0523947  -0.0362821   0.07395212  0.0048256  -0.12037919  0.00074816
   0.05617675  0.03230512 -0.08202079  0.

tf.Tensor([3 7 7 9 9 0 1 5 3 4 7 8 5 1 5 0 9 8 3 3 8 7 8 6 4 7 8 3 9 5 4 9], shape=(32,), dtype=int64)
tf.Tensor(
[[ 0.055345    0.02061503  0.02929107  0.02934632 -0.07367563  0.00080642
  -0.00485681  0.04914692  0.00430659 -0.01685411]
 [ 0.02605643 -0.24364166  0.09782372 -0.07730284 -0.01714579 -0.05896133
   0.12301172 -0.00928932 -0.11817417  0.06884114]
 [ 0.09906298 -0.19855742  0.08109597 -0.08512554 -0.2732966  -0.11040176
   0.16335118  0.15746948 -0.15907118  0.12653059]
 [-0.14233369  0.04250805 -0.02204836 -0.17465949  0.01244164  0.11226508
  -0.18580294  0.07530006 -0.07556468  0.09769662]
 [ 0.14071037 -0.01957405  0.09909933  0.02542004 -0.18301868 -0.00065553
   0.10362977  0.07245044 -0.04277984  0.01285318]
 [-0.14351124  0.00139436 -0.00855913 -0.21584325  0.04052125  0.10957097
  -0.2220844   0.08435034 -0.15536025  0.15468659]
 [ 0.00825355 -0.01250593  0.02180858 -0.09273119  0.02335742  0.00101696
  -0.02646713  0.03508472 -0.00146579  0.0525879 ]
 [ 0.044331

tf.Tensor(
[[-4.98183370e-02 -1.18476965e-01  4.52252142e-02 -9.90226045e-02
  -7.52443075e-03 -2.22222302e-02 -4.72296961e-06  3.62116247e-02
  -9.43962336e-02  1.05363555e-01]
 [ 7.58808479e-02 -1.63981184e-01  8.26890171e-02 -6.52662441e-02
  -1.08686239e-01 -6.29977509e-02  6.39220104e-02  1.80086717e-02
  -1.05255052e-01  5.53445108e-02]
 [ 6.85112998e-02 -4.84839752e-02  4.90187481e-02 -3.54752205e-02
  -1.45156607e-01 -4.23082709e-02  1.07496092e-02  6.70003146e-02
  -7.71833658e-02  3.49143147e-02]
 [ 7.05911517e-02 -1.18230708e-01  7.23087937e-02 -4.07303050e-02
  -9.66335237e-02 -5.49517497e-02  5.54904193e-02  8.09726194e-02
  -4.75339144e-02  4.14404050e-02]
 [ 3.72880995e-02 -1.23410292e-01  4.03400585e-02 -4.65015508e-02
  -1.41998436e-02 -4.83579561e-02  6.51967525e-02  2.72717420e-02
  -3.79740372e-02  8.00951645e-02]
 [-8.16312060e-02 -9.26797166e-02  3.72355543e-02 -1.35437712e-01
   6.51809052e-02  3.47152501e-02 -6.99928254e-02  6.49538562e-02
  -4.88593578e-02  1.0