In [6]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # 仅打印error

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets


def _as_tensors(images, labels):
    """Convert an image array and its label array to TensorFlow tensors.

    Pixel values are cast to float32 and divided by 255 so the [0, 255]
    range maps onto [0, 1]; labels are cast to int32.
    """
    images = tf.convert_to_tensor(images, dtype=tf.float32) / 255.
    labels = tf.convert_to_tensor(labels, dtype=tf.int32)
    return images, labels


# Training split: x [60k, 28, 28], y [60k]; plus the held-out test split.
(x, y), (x_test, y_test) = datasets.mnist.load_data()

x, y = _as_tensors(x, y)
x_test, y_test = _as_tensors(x_test, y_test)

# Sanity checks on the converted data: shapes/dtypes, then value ranges.
print(x.shape, y.shape, x.dtype, y.dtype)
print(tf.reduce_min(x), tf.reduce_max(x))
print(tf.reduce_min(y), tf.reduce_max(y))

# Unshuffled, fixed-size batches for both splits.
batch_size = 128
train_db = tf.data.Dataset.from_tensor_slices((x, y)).batch(batch_size)
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(batch_size)

# Pull a single training batch to confirm the pipeline's output shapes.
train_iter = iter(train_db)
sample = next(train_iter)
print('batch:', sample[0].shape, sample[1].shape)


def _dense_layer(dim_in, dim_out):
    """Create the trainable kernel and bias Variables for one dense layer.

    The kernel has shape [dim_in, dim_out] and is drawn from a truncated
    normal distribution with stddev 0.1; the bias has shape [dim_out] and
    starts at zero.
    """
    kernel = tf.Variable(tf.random.truncated_normal([dim_in, dim_out], stddev=0.1))
    bias = tf.Variable(tf.zeros([dim_out]))
    return kernel, bias


# Network layout: [b, 784] -> [b, 256] -> [b, 128] -> [b, 10].
# Kernels are stored as [dim_in, dim_out] so a layer is `inputs @ w + b`.
w1, b1 = _dense_layer(784, 256)
w2, b2 = _dense_layer(256, 128)
w3, b3 = _dense_layer(128, 10)

# Step size for the plain SGD updates.
# NOTE(review): with this value the logged test accuracy only reaches ~0.37
# after 4 epochs; a larger learning rate may be worth trying.
lr = 1e-3

def _forward(inputs):
    """Run the 3-layer MLP on flattened images and return raw logits.

    inputs: [b, 784] float32. Returns [b, 10] logits (no softmax applied).
    Reads the module-level Variables w1, b1, w2, b2, w3, b3.
    """
    h1 = tf.nn.relu(inputs @ w1 + b1)  # [b, 784] @ [784, 256] + [256] -> [b, 256]
    h2 = tf.nn.relu(h1 @ w2 + b2)      # [b, 256] -> [b, 128]
    return h2 @ w3 + b3                # [b, 128] -> [b, 10]


params = [w1, b1, w2, b2, w3, b3]

for epoch in range(10):  # iterate over the whole training set 10 times
    for step, (x_batch, y_batch) in enumerate(train_db):
        # x_batch: [b, 28, 28] -> [b, 784]; y_batch: [b].
        # The loop variables are deliberately NOT named x / y: that would
        # overwrite the full-dataset tensors defined earlier in the notebook.
        x_batch = tf.reshape(x_batch, [-1, 28 * 28])

        with tf.GradientTape() as tape:
            out = _forward(x_batch)                   # [b, 10] logits
            y_onehot = tf.one_hot(y_batch, depth=10)  # [b] -> [b, 10]
            # Squared error averaged over every element of the [b, 10]
            # difference, i.e. a mean over both the batch and the 10 classes.
            loss = tf.reduce_mean(tf.square(y_onehot - out))

        grads = tape.gradient(loss, params)

        # Plain SGD, p <- p - lr * grad, done in place with assign_sub so each
        # name keeps referring to the same tf.Variable (rebinding w1 = w1 - ...
        # would replace the Variable with a plain tensor).
        for param, grad in zip(params, grads):
            param.assign_sub(lr * grad)

        if step % 100 == 0:
            print(epoch, step, 'loss:', float(loss))

    # Evaluate the current parameters on the test set.
    total_correct, total_num = 0, 0
    for x_batch, y_batch in test_db:
        x_batch = tf.reshape(x_batch, [-1, 28 * 28])  # [b, 28, 28] -> [b, 784]
        out = _forward(x_batch)
        # softmax is monotonic, so argmax over the logits already gives the
        # predicted class; computing probabilities first is unnecessary.
        pred = tf.cast(tf.argmax(out, axis=1), dtype=tf.int32)
        correct = tf.reduce_sum(tf.cast(tf.equal(pred, y_batch), dtype=tf.int32))
        total_correct += int(correct)
        total_num += x_batch.shape[0]

    acc = total_correct / total_num
    print('Test acc:', acc)
        

(60000, 28, 28) (60000,) <dtype: 'float32'> <dtype: 'int32'>
tf.Tensor(0.0, shape=(), dtype=float32) tf.Tensor(1.0, shape=(), dtype=float32)
tf.Tensor(0, shape=(), dtype=int32) tf.Tensor(9, shape=(), dtype=int32)
batch: (128, 28, 28) (128,)
0 0 loss: 0.5343353152275085
0 100 loss: 0.21305398643016815
0 200 loss: 0.20807361602783203
0 300 loss: 0.1719106137752533
0 400 loss: 0.17316119372844696
Test acc: 0.1606
1 0 loss: 0.16123713552951813
1 100 loss: 0.15644954144954681
1 200 loss: 0.16277039051055908
1 300 loss: 0.13677451014518738
1 400 loss: 0.14312772452831268
Test acc: 0.233
2 0 loss: 0.13315367698669434
2 100 loss: 0.13440100848674774
2 200 loss: 0.13863633573055267
2 300 loss: 0.11750228703022003
2 400 loss: 0.12587854266166687
Test acc: 0.3108
3 0 loss: 0.11636205017566681
3 100 loss: 0.12064077705144882
3 200 loss: 0.12345468997955322
3 300 loss: 0.10520368814468384
3 400 loss: 0.11465738713741302
Test acc: 0.3723
4 0 loss: 0.10522766411304474
4 100 loss: 0.11104680597782135
