In [1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets

In [3]:
# Load MNIST and keep only the first (training) split; the second split is discarded.
(x, y), _ = datasets.mnist.load_data()

# Convert both arrays to tensors in one step: float32 images, int32 labels.
# NOTE(review): pixel values are not rescaled here — confirm that feeding raw
# intensities to the model is intended.
x, y = (
    tf.convert_to_tensor(x, dtype=tf.float32),
    tf.convert_to_tensor(y, dtype=tf.int32),
)

# Sanity check: shapes and dtypes of the images and labels.
print(x.shape, y.shape, x.dtype, y.dtype)

(60000, 28, 28) (60000,) <dtype: 'float32'> <dtype: 'int32'>


In [4]:
# Slice the (images, labels) tensors into per-example pairs, then group the
# examples into batches of 128.
train_db = tf.data.Dataset.from_tensor_slices((x, y))
train_db = train_db.batch(128)

In [5]:
# Create an iterator over the batched dataset and pull a single batch from it
# so its structure can be inspected in the next cell.
train_iter = iter(train_db)
sample = next(train_iter)

In [6]:
# Show the shape of each component of the sampled batch (images first, then labels).
print('batch:', *(part.shape for part in sample))

batch: (128, 28, 28) (128,)


In [7]:
# Layer shapes: [b, 784] => [b, 256] => [b, 128] => [b, 10]
# Each layer has a weight of shape [dim_in, dim_out] and a bias of shape [dim_out]

In [17]:
# Trainable parameters of the three dense layers: 784 -> 256 -> 128 -> 10.
# Wrapping them in tf.Variable makes tf.GradientTape track them automatically,
# so gradients can be requested for them without calling tape.watch(). The
# values themselves are NOT updated automatically; the training loop must
# apply the gradients explicitly.
#
# Seed the global generator first so that re-running this cell (or the whole
# notebook from a fresh kernel) produces the same initial weights.
SEED = 42
tf.random.set_seed(SEED)

# Weights: small truncated-normal noise; biases: zeros.
w1 = tf.Variable(tf.random.truncated_normal([784, 256], stddev=0.01))
b1 = tf.Variable(tf.zeros([256]))
w2 = tf.Variable(tf.random.truncated_normal([256, 128], stddev=0.01))
b2 = tf.Variable(tf.zeros([128]))
w3 = tf.Variable(tf.random.truncated_normal([128, 10], stddev=0.01))
b3 = tf.Variable(tf.zeros([10]))

In [16]:
# One pass of mini-batch gradient descent over train_db.
lr = 1e-3  # learning rate

# Parameters listed in the order their gradients are requested and applied.
params = [w1, b1, w2, b2, w3, b3]

# The loop targets are deliberately NOT named `x`/`y`: those names hold the
# full-dataset tensors, and overwriting them would break any re-run of the
# cell that builds train_db.
for step, (x_batch, y_batch) in enumerate(train_db):
    # x_batch: [b, 28, 28] -> [b, 28*28]; y_batch: [b]
    x_batch = tf.reshape(x_batch, [-1, 28*28])

    with tf.GradientTape() as tape:
        # Three stacked affine layers. The bias is broadcast over the batch
        # axis by `+`, so no explicit tf.broadcast_to is needed.
        # NOTE(review): there is no activation between the layers, so the
        # composition is still a single affine map — confirm this is intended.
        h1 = x_batch@w1 + b1
        h2 = h1@w2 + b2
        out = h2@w3 + b3  # out: [b, 10]

        # One-hot encode the labels so they match out's shape: [b, 10].
        y_onehot = tf.one_hot(y_batch, depth=10)

        # MSE: mean of (y_onehot - out)^2 over every batch entry and class.
        loss = tf.reduce_mean(tf.square(y_onehot - out))

    grads = tape.gradient(loss, params)

    # Gradient step: p <- p - lr * grad_p.
    # assign_sub updates the existing tf.Variable in place; writing
    # `p = p - lr * grad_p` would bind the name to a new object instead.
    for param, grad in zip(params, grads):
        param.assign_sub(lr * grad)

    # Report the batch loss every 100 steps.
    if step % 100 == 0:
        print(step, 'loss:', float(loss))
        

0 loss: 0.1672275960445404
100 loss: 0.06931205838918686
200 loss: 0.0552520677447319
300 loss: 0.05026322603225708
400 loss: 0.05261535570025444


In [18]:
# Ten passes of mini-batch gradient descent over train_db.
# NOTE(review): training continues from whatever values w1..b3 currently hold;
# re-run the weight-initialisation cell first if a fresh start is wanted.
lr = 1e-3  # learning rate

# Parameters listed in the order their gradients are requested and applied.
params = [w1, b1, w2, b2, w3, b3]

for epoch in range(10):
    # The loop targets are deliberately NOT named `x`/`y`: those names hold the
    # full-dataset tensors, and overwriting them would break any re-run of the
    # cell that builds train_db.
    for step, (x_batch, y_batch) in enumerate(train_db):
        # x_batch: [b, 28, 28] -> [b, 28*28]; y_batch: [b]
        x_batch = tf.reshape(x_batch, [-1, 28*28])

        with tf.GradientTape() as tape:
            # Three stacked affine layers. The bias is broadcast over the
            # batch axis by `+`, so no explicit tf.broadcast_to is needed.
            # NOTE(review): there is no activation between the layers, so the
            # composition is still a single affine map — confirm this is intended.
            h1 = x_batch@w1 + b1
            h2 = h1@w2 + b2
            out = h2@w3 + b3  # out: [b, 10]

            # One-hot encode the labels so they match out's shape: [b, 10].
            y_onehot = tf.one_hot(y_batch, depth=10)

            # MSE: mean of (y_onehot - out)^2 over every batch entry and class.
            loss = tf.reduce_mean(tf.square(y_onehot - out))

        grads = tape.gradient(loss, params)

        # Gradient step: p <- p - lr * grad_p.
        # assign_sub updates the existing tf.Variable in place; writing
        # `p = p - lr * grad_p` would bind the name to a new object instead.
        for param, grad in zip(params, grads):
            param.assign_sub(lr * grad)

        # Report the batch loss every 100 steps of each epoch.
        if step % 100 == 0:
            print(epoch, step, 'loss:', float(loss))


0 0 loss: 0.15991362929344177
0 100 loss: 0.06630662083625793
0 200 loss: 0.05309118703007698
0 300 loss: 0.04925220459699631
0 400 loss: 0.05069420859217644
1 0 loss: 0.0455445721745491
1 100 loss: 0.049549125134944916
1 200 loss: 0.044769685715436935
1 300 loss: 0.04442111775279045
1 400 loss: 0.047204989939928055
2 0 loss: 0.04212489724159241
2 100 loss: 0.046403925865888596
2 200 loss: 0.042999159544706345
2 300 loss: 0.043066926300525665
2 400 loss: 0.0458727590739727
3 0 loss: 0.040701307356357574
3 100 loss: 0.044885262846946716
3 200 loss: 0.042186539620161057
3 300 loss: 0.04235023260116577
3 400 loss: 0.045080192387104034
4 0 loss: 0.03985799476504326
4 100 loss: 0.04398693889379501
4 200 loss: 0.04170404002070427
4 300 loss: 0.04187215119600296
4 400 loss: 0.044541094452142715
5 0 loss: 0.0392804779112339
5 100 loss: 0.04338859021663666
5 200 loss: 0.041374437510967255
5 300 loss: 0.04151427000761032
5 400 loss: 0.04415082186460495
6 0 loss: 0.03885024040937424
6 100 loss: 0