### [2] 第二部分：TensorFlow 模型建立与训练 <br/>
   &emsp; 2.1 模型（Model）与层（Layer） <br/>
   &emsp; 2.2 基础示例：多层感知机（MLP） <br/>
   &emsp; 2.3 卷积神经网络（CNN） <br/>
   &emsp; 2.4 循环神经网络（RNN） <br/>
   &emsp; 2.5 深度强化学习（DRL） <br/>
   &emsp; 2.6 Keras Pipeline * <br/>
   &emsp; 2.7 自定义层、损失函数和评估指标 *  <br/>

1. loss = tf.keras.losses.categorical_crossentropy(y_true=y, y_pred=y_pred)
2. loss = tf.keras.losses.sparse_categorical_crossentropy(y_true=y, y_pred=y_pred)
3. y_pred = tf.argmax(y_pred, axis=-1, output_type=tf.int32)  # 找到这一列中最大的值
4. y_pred = tf.one_hot(y_pred, depth=10)  # 转变成one_hot
5. y_pred = tf.cast(y_pred, dtype=tf.int32)  # 转化数据类型为int

## 2.1 模型（Model）与层（Layer）

Keras 模型以类的形式呈现，我们可以通过继承 tf.keras.Model 这个 Python 类来定义自己的模型。在继承类中，我们需要重写 __init__() （构造函数，初始化）和 call(input) （模型调用）两个方法，同时也可以根据需要增加自定义的方法。

In [44]:
import tensorflow as tf

In [45]:
class MyModel(tf.keras.Model):
    def __init__(self):
        super().__init__()     # Python 2 下使用 super(MyModel, self).__init__()
        # 此处添加初始化代码（包含 call 方法中会用到的层），例如
        # layer1 = tf.keras.layers.BuiltInLayer(...)
        # layer2 = MyCustomLayer(...)

    def call(self, input):
        # 此处添加模型调用的代码（处理输入并返回输出），例如
        # x = layer1(input)
        # output = layer2(x)
        return output

    # 还可以添加自定义的方法

![image.png](attachment:image.png)

继承 tf.keras.Model 后，我们同时可以使用父类的若干方法和属性，例如在实例化类 model = Model() 后，可以通过 model.variables 这一属性直接获得模型中的所有变量，**免去我们一个个显式指定变量的麻烦。**

In [46]:
### 上一章中简单的线性模型 y_pred = a * X + b ，我们可以通过模型类的方式编写如下：

X = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])  # batch == 2 , 维度 == 3  [batch = 2, index = 3]
y = tf.constant([[10.0], [20.0]])


class Linear(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(
            units=1,   # 一个神经元，隐藏层的维度 == 1
            activation=None,  # 定义激活函数
            kernel_initializer=tf.zeros_initializer(),  # 初始化变量 a  {依据 X 来显示变化维度}
            bias_initializer=tf.zeros_initializer()     # 初始化变量 b
        )

    def call(self, input):
        output = self.dense(input)
        return output


# 以下代码结构与前节类似
model = Linear()
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
for i in range(100):
    with tf.GradientTape() as tape:
        y_pred = model(X)      # 调用模型 y_pred = model(X) 而不是显式写出 y_pred = a * X + b
        loss = tf.reduce_mean(tf.square(y_pred - y))
    grads = tape.gradient(loss, model.variables)    # 使用 model.variables 这一属性直接获得模型中的所有变量
    optimizer.apply_gradients(grads_and_vars=zip(grads, model.variables))
    
print(model.variables)

[<tf.Variable 'linear/dense_14/kernel:0' shape=(3, 1) dtype=float32, numpy=
array([[0.40784496],
       [1.191065  ],
       [1.9742855 ]], dtype=float32)>, <tf.Variable 'linear/dense_14/bias:0' shape=(1,) dtype=float32, numpy=array([0.78322077], dtype=float32)>]


全连接层（Full-connected Layer，tf.keras.layers.Dense ）是 Keras 中最基础和常用的层之一，对输入矩阵 A 进行 f(AW + b) 的线性变换 + 激活函数操作。如果不指定激活函数，即是纯粹的线性变换 AW + b。具体而言，给定输入张量 input = [batch_size, input_dim] ，该层对输入张量首先进行 tf.matmul(input, kernel) + bias 的线性变换（ kernel 和 bias 是层中可训练的变量），然后对线性变换后张量的每个元素通过激活函数 activation ，从而输出形状为 [batch_size, units] 的二维张量。
![image.png](attachment:image.png)

## 2.2 基础示例：多层感知机（MLP）

(1) 数据获取及预处理

In [47]:
class MNISTLoader():
    def __init__(self):
        mnist = tf.keras.datasets.mnist
        (self.train_data, self.train_label), (self.test_data, self.test_label) = mnist.load_data()
        # MNIST中的图像默认为uint8（0-255的数字）。以下代码将其归一化到0-1之间的浮点数，并在最后增加一维作为颜色通道
        self.train_data = np.expand_dims(self.train_data.astype(np.float32) / 255.0, axis=-1)      # [60000, 28, 28, 1]
        self.test_data = np.expand_dims(self.test_data.astype(np.float32) / 255.0, axis=-1)        # [10000, 28, 28, 1]
        self.train_label = self.train_label.astype(np.int32)    # [60000]
        self.test_label = self.test_label.astype(np.int32)      # [10000]
        self.num_train_data, self.num_test_data = self.train_data.shape[0], self.test_data.shape[0]

    def get_batch(self, batch_size):
        # 从数据集中随机取出batch_size个元素并返回
        index = np.random.randint(0, self.num_train_data, batch_size)
        return self.train_data[index, :], self.train_label[index]

In [48]:
from tensorflow_core.examples.tutorials.mnist import input_data
import numpy as np
class MNISTLoader_my_download():
    def __init__(self):
        # 读取数据，预先已经下载了相应的数据直
        mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
        self.train_data = mnist.train.images
        self.train_label = mnist.train.labels
        self.test_data = mnist.test.images
        self.test_label = mnist.test.labels
        
        # MNIST中的图像默认为uint8（0-255的数字）。以下代码将其归一化到0-1之间的浮点数，并在最后增加一维作为颜色通道
        self.train_data = np.expand_dims(self.train_data.astype(np.float32) / 255.0, axis=-1)      # [60000, 784, 1]
        self.test_data = np.expand_dims(self.test_data.astype(np.float32) / 255.0, axis=-1)        # [10000, 784, 1]
        self.train_label = self.train_label.astype(np.int32)    # [60000]
        self.test_label = self.test_label.astype(np.int32)      # [10000]
        self.num_train_data, self.num_test_data = self.train_data.shape[0], self.test_data.shape[0]

    def get_batch(self, batch_size):
        # 从数据集中随机取出batch_size个元素并返回
        index = np.random.randint(0, self.num_train_data, batch_size)
        return self.train_data[index, :], self.train_label[index]

(2) 模型的构建： tf.keras.Model 和 tf.keras.layers

In [49]:
class MLP(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.flatten = tf.keras.layers.Flatten()    # Flatten层将除第一维（batch_size）以外的维度展平
        self.dense1 = tf.keras.layers.Dense(units=100, activation=tf.nn.relu)  # 第一层神经元的个数为100
        self.dense2 = tf.keras.layers.Dense(units=10)   # 第二层神经元的个数为10,输出一个样本的维度为10

    def call(self, inputs):         # [batch_size, 28, 28, 1]
        x = self.flatten(inputs)    # [batch_size, 784]
        x = self.dense1(x)          # [batch_size, 100]
        x = self.dense2(x)          # [batch_size, 10]
        output = tf.nn.softmax(x)
        return output

![image.png](attachment:image.png)

（3）模型的训练： tf.keras.losses 和 tf.keras.optimizer

In [71]:
# 定义一些模型超参数：
num_epochs = 5
batch_size = 50
learning_rate = 0.001

# 实例化模型和数据读取类，并实例化一个 tf.keras.optimizer 的优化器（这里使用常用的 Adam 优化器）：
model = MLP()
# data_loader = MNISTLoader() # 导入数据 
data_loader = MNISTLoader_my_download()  # 导入数据
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)  # 更新梯度

num_batches = int(mnist.train.num_examples // batch_size * num_epochs)
# num_batches = int(data_loader.num_train_data // batch_size * num_epochs)
for batch_index in range(num_batches):
    X, y = data_loader.get_batch(batch_size)
    with tf.GradientTape() as tape:
        y_pred = model(X)
        loss = tf.keras.losses.categorical_crossentropy(y_true=y, y_pred=y_pred)
        loss = tf.reduce_mean(loss)
        print("batch %d: loss %f" % (batch_index, loss.numpy()))
    grads = tape.gradient(loss, model.variables)
    optimizer.apply_gradients(grads_and_vars=zip(grads, model.variables))


Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
batch 0: loss 2.302737
batch 1: loss 2.302408
batch 2: loss 2.301512
batch 3: loss 2.302264
batch 4: loss 2.301797
batch 5: loss 2.300742
batch 6: loss 2.299592
batch 7: loss 2.298411
batch 8: loss 2.300669
batch 9: loss 2.300016
batch 10: loss 2.301450
batch 11: loss 2.297090
batch 12: loss 2.296710
batch 13: loss 2.296400
batch 14: loss 2.298450
batch 15: loss 2.294831
batch 16: loss 2.297379
batch 17: loss 2.296067
batch 18: loss 2.294367
batch 19: loss 2.291650
batch 20: loss 2.292817
batch 21: loss 2.295512
batch 22: loss 2.290362
batch 23: loss 2.294416
batch 24: loss 2.291180
batch 25: loss 2.288800
batch 26: loss 2.292525
batch 27: loss 2.299174
batch 28: loss 2.289680
batch 29: loss 2.288069
batch 30: loss 2.287610
batch 31: loss 2.286079
batch 32: loss 2.289759
batch 33: loss 2.294421


batch 329: loss 1.779322
batch 330: loss 1.812070
batch 331: loss 1.736311
batch 332: loss 1.799247
batch 333: loss 1.817841
batch 334: loss 1.814888
batch 335: loss 1.776089
batch 336: loss 1.740243
batch 337: loss 1.880329
batch 338: loss 1.710312
batch 339: loss 1.806216
batch 340: loss 1.722430
batch 341: loss 1.723161
batch 342: loss 1.846458
batch 343: loss 1.841937
batch 344: loss 1.729872
batch 345: loss 1.747376
batch 346: loss 1.740714
batch 347: loss 1.779876
batch 348: loss 1.825102
batch 349: loss 1.800129
batch 350: loss 1.825317
batch 351: loss 1.677890
batch 352: loss 1.820233
batch 353: loss 1.716917
batch 354: loss 1.850020
batch 355: loss 1.752982
batch 356: loss 1.789627
batch 357: loss 1.742160
batch 358: loss 1.770703
batch 359: loss 1.696355
batch 360: loss 1.729725
batch 361: loss 1.787565
batch 362: loss 1.793628
batch 363: loss 1.695866
batch 364: loss 1.737898
batch 365: loss 1.772923
batch 366: loss 1.634984
batch 367: loss 1.771467
batch 368: loss 1.686776


batch 658: loss 1.313324
batch 659: loss 1.137984
batch 660: loss 1.247879
batch 661: loss 1.267423
batch 662: loss 1.244373
batch 663: loss 1.272451
batch 664: loss 1.113421
batch 665: loss 1.205354
batch 666: loss 1.032809
batch 667: loss 1.085950
batch 668: loss 1.191219
batch 669: loss 1.075277
batch 670: loss 1.238969
batch 671: loss 1.200842
batch 672: loss 1.244428
batch 673: loss 1.170098
batch 674: loss 1.128130
batch 675: loss 1.123874
batch 676: loss 1.066667
batch 677: loss 1.171685
batch 678: loss 1.294729
batch 679: loss 1.201226
batch 680: loss 1.220352
batch 681: loss 1.328410
batch 682: loss 1.142048
batch 683: loss 1.189636
batch 684: loss 1.215866
batch 685: loss 1.140361
batch 686: loss 1.200853
batch 687: loss 1.169989
batch 688: loss 1.100171
batch 689: loss 1.211566
batch 690: loss 1.169530
batch 691: loss 1.066373
batch 692: loss 1.263639
batch 693: loss 1.145710
batch 694: loss 1.148293
batch 695: loss 1.144552
batch 696: loss 1.214194
batch 697: loss 1.034085


batch 988: loss 0.779473
batch 989: loss 0.798905
batch 990: loss 0.725760
batch 991: loss 0.826668
batch 992: loss 0.736949
batch 993: loss 0.902399
batch 994: loss 0.762141
batch 995: loss 1.010628
batch 996: loss 0.793066
batch 997: loss 0.763138
batch 998: loss 0.765709
batch 999: loss 0.797738
batch 1000: loss 0.813489
batch 1001: loss 0.942767
batch 1002: loss 0.792508
batch 1003: loss 1.028310
batch 1004: loss 0.979139
batch 1005: loss 0.902636
batch 1006: loss 0.850368
batch 1007: loss 0.870594
batch 1008: loss 0.868104
batch 1009: loss 1.046704
batch 1010: loss 0.995228
batch 1011: loss 0.832392
batch 1012: loss 0.840631
batch 1013: loss 0.821585
batch 1014: loss 0.831314
batch 1015: loss 0.844952
batch 1016: loss 1.034488
batch 1017: loss 0.975213
batch 1018: loss 0.915192
batch 1019: loss 0.945287
batch 1020: loss 0.843559
batch 1021: loss 0.937408
batch 1022: loss 0.760589
batch 1023: loss 0.966824
batch 1024: loss 0.793837
batch 1025: loss 0.788835
batch 1026: loss 0.80303

batch 1315: loss 0.828472
batch 1316: loss 0.719067
batch 1317: loss 0.749363
batch 1318: loss 0.697901
batch 1319: loss 0.730561
batch 1320: loss 0.785040
batch 1321: loss 0.657601
batch 1322: loss 0.895703
batch 1323: loss 0.784148
batch 1324: loss 0.778423
batch 1325: loss 0.822530
batch 1326: loss 0.945884
batch 1327: loss 0.651696
batch 1328: loss 0.658986
batch 1329: loss 0.812359
batch 1330: loss 0.718260
batch 1331: loss 0.854096
batch 1332: loss 0.855842
batch 1333: loss 0.659609
batch 1334: loss 0.604726
batch 1335: loss 0.697950
batch 1336: loss 0.709639
batch 1337: loss 0.530284
batch 1338: loss 0.762257
batch 1339: loss 0.782848
batch 1340: loss 0.752343
batch 1341: loss 0.694008
batch 1342: loss 0.716730
batch 1343: loss 0.703523
batch 1344: loss 0.856348
batch 1345: loss 0.735165
batch 1346: loss 0.797007
batch 1347: loss 0.733024
batch 1348: loss 0.672935
batch 1349: loss 0.836841
batch 1350: loss 0.741925
batch 1351: loss 0.608967
batch 1352: loss 0.864982
batch 1353: 

batch 1660: loss 0.742941
batch 1661: loss 0.621400
batch 1662: loss 0.571081
batch 1663: loss 0.491856
batch 1664: loss 0.527898
batch 1665: loss 0.529535
batch 1666: loss 0.536351
batch 1667: loss 0.749479
batch 1668: loss 0.598538
batch 1669: loss 0.652078
batch 1670: loss 0.460014
batch 1671: loss 0.634294
batch 1672: loss 0.600774
batch 1673: loss 0.681753
batch 1674: loss 0.603559
batch 1675: loss 0.559044
batch 1676: loss 0.724089
batch 1677: loss 0.623746
batch 1678: loss 0.608121
batch 1679: loss 0.636019
batch 1680: loss 0.543847
batch 1681: loss 0.603537
batch 1682: loss 0.528824
batch 1683: loss 0.649044
batch 1684: loss 0.670034
batch 1685: loss 0.540050
batch 1686: loss 0.681319
batch 1687: loss 0.642758
batch 1688: loss 0.547140
batch 1689: loss 0.633345
batch 1690: loss 0.709013
batch 1691: loss 0.635720
batch 1692: loss 0.631721
batch 1693: loss 0.736520
batch 1694: loss 0.662441
batch 1695: loss 0.544344
batch 1696: loss 0.463643
batch 1697: loss 0.579913
batch 1698: 

batch 1985: loss 0.502820
batch 1986: loss 0.462729
batch 1987: loss 0.592146
batch 1988: loss 0.551110
batch 1989: loss 0.603542
batch 1990: loss 0.454188
batch 1991: loss 0.535552
batch 1992: loss 0.498698
batch 1993: loss 0.614712
batch 1994: loss 0.509389
batch 1995: loss 0.327282
batch 1996: loss 0.527360
batch 1997: loss 0.583174
batch 1998: loss 0.583879
batch 1999: loss 0.483683
batch 2000: loss 0.600510
batch 2001: loss 0.581991
batch 2002: loss 0.525417
batch 2003: loss 0.576558
batch 2004: loss 0.399565
batch 2005: loss 0.704991
batch 2006: loss 0.559799
batch 2007: loss 0.522064
batch 2008: loss 0.454345
batch 2009: loss 0.635854
batch 2010: loss 0.650735
batch 2011: loss 0.358388
batch 2012: loss 0.623755
batch 2013: loss 0.618999
batch 2014: loss 0.481576
batch 2015: loss 0.506604
batch 2016: loss 0.500358
batch 2017: loss 0.572885
batch 2018: loss 0.543709
batch 2019: loss 0.456975
batch 2020: loss 0.683607
batch 2021: loss 0.533992
batch 2022: loss 0.692491
batch 2023: 

batch 2304: loss 0.538079
batch 2305: loss 0.352232
batch 2306: loss 0.473609
batch 2307: loss 0.452842
batch 2308: loss 0.521600
batch 2309: loss 0.451148
batch 2310: loss 0.462105
batch 2311: loss 0.424138
batch 2312: loss 0.539762
batch 2313: loss 0.429151
batch 2314: loss 0.362911
batch 2315: loss 0.704165
batch 2316: loss 0.516613
batch 2317: loss 0.490598
batch 2318: loss 0.408075
batch 2319: loss 0.640072
batch 2320: loss 0.525076
batch 2321: loss 0.516074
batch 2322: loss 0.656921
batch 2323: loss 0.567616
batch 2324: loss 0.462165
batch 2325: loss 0.516546
batch 2326: loss 0.534195
batch 2327: loss 0.767779
batch 2328: loss 0.554142
batch 2329: loss 0.583144
batch 2330: loss 0.388535
batch 2331: loss 0.477650
batch 2332: loss 0.357136
batch 2333: loss 0.533975
batch 2334: loss 0.337037
batch 2335: loss 0.350052
batch 2336: loss 0.440154
batch 2337: loss 0.507241
batch 2338: loss 0.569989
batch 2339: loss 0.626968
batch 2340: loss 0.536178
batch 2341: loss 0.422516
batch 2342: 

batch 2647: loss 0.522115
batch 2648: loss 0.481855
batch 2649: loss 0.432011
batch 2650: loss 0.409274
batch 2651: loss 0.390825
batch 2652: loss 0.514065
batch 2653: loss 0.490315
batch 2654: loss 0.532483
batch 2655: loss 0.499024
batch 2656: loss 0.484537
batch 2657: loss 0.455234
batch 2658: loss 0.386093
batch 2659: loss 0.425724
batch 2660: loss 0.384272
batch 2661: loss 0.508326
batch 2662: loss 0.583746
batch 2663: loss 0.357182
batch 2664: loss 0.455180
batch 2665: loss 0.402197
batch 2666: loss 0.493363
batch 2667: loss 0.540217
batch 2668: loss 0.332989
batch 2669: loss 0.457689
batch 2670: loss 0.502801
batch 2671: loss 0.353716
batch 2672: loss 0.444835
batch 2673: loss 0.327018
batch 2674: loss 0.490640
batch 2675: loss 0.483890
batch 2676: loss 0.318625
batch 2677: loss 0.500132
batch 2678: loss 0.511150
batch 2679: loss 0.509612
batch 2680: loss 0.448240
batch 2681: loss 0.501070
batch 2682: loss 0.461658
batch 2683: loss 0.405325
batch 2684: loss 0.441950
batch 2685: 

batch 2971: loss 0.348773
batch 2972: loss 0.343248
batch 2973: loss 0.532583
batch 2974: loss 0.869768
batch 2975: loss 0.378594
batch 2976: loss 0.551596
batch 2977: loss 0.388113
batch 2978: loss 0.523232
batch 2979: loss 0.380376
batch 2980: loss 0.285075
batch 2981: loss 0.489945
batch 2982: loss 0.332423
batch 2983: loss 0.514996
batch 2984: loss 0.595802
batch 2985: loss 0.388030
batch 2986: loss 0.364315
batch 2987: loss 0.663150
batch 2988: loss 0.347403
batch 2989: loss 0.355471
batch 2990: loss 0.481290
batch 2991: loss 0.648907
batch 2992: loss 0.275919
batch 2993: loss 0.523022
batch 2994: loss 0.413094
batch 2995: loss 0.458851
batch 2996: loss 0.473558
batch 2997: loss 0.630487
batch 2998: loss 0.457677
batch 2999: loss 0.559177
batch 3000: loss 0.342417
batch 3001: loss 0.405674
batch 3002: loss 0.586295
batch 3003: loss 0.431183
batch 3004: loss 0.397656
batch 3005: loss 0.405275
batch 3006: loss 0.451555
batch 3007: loss 0.491145
batch 3008: loss 0.357715
batch 3009: 

batch 3287: loss 0.440041
batch 3288: loss 0.446696
batch 3289: loss 0.640741
batch 3290: loss 0.420244
batch 3291: loss 0.348456
batch 3292: loss 0.304447
batch 3293: loss 0.587246
batch 3294: loss 0.286270
batch 3295: loss 0.380262
batch 3296: loss 0.389711
batch 3297: loss 0.299414
batch 3298: loss 0.550838
batch 3299: loss 0.364145
batch 3300: loss 0.389987
batch 3301: loss 0.767209
batch 3302: loss 0.488545
batch 3303: loss 0.222782
batch 3304: loss 0.425206
batch 3305: loss 0.371988
batch 3306: loss 0.493340
batch 3307: loss 0.420860
batch 3308: loss 0.365401
batch 3309: loss 0.349635
batch 3310: loss 0.438169
batch 3311: loss 0.394916
batch 3312: loss 0.267870
batch 3313: loss 0.640057
batch 3314: loss 0.528706
batch 3315: loss 0.451221
batch 3316: loss 0.555040
batch 3317: loss 0.473990
batch 3318: loss 0.407052
batch 3319: loss 0.384660
batch 3320: loss 0.352370
batch 3321: loss 0.471391
batch 3322: loss 0.443465
batch 3323: loss 0.357420
batch 3324: loss 0.340039
batch 3325: 

batch 3614: loss 0.412384
batch 3615: loss 0.412806
batch 3616: loss 0.304316
batch 3617: loss 0.468352
batch 3618: loss 0.197582
batch 3619: loss 0.461841
batch 3620: loss 0.300067
batch 3621: loss 0.540738
batch 3622: loss 0.268405
batch 3623: loss 0.403177
batch 3624: loss 0.503772
batch 3625: loss 0.497152
batch 3626: loss 0.327041
batch 3627: loss 0.449730
batch 3628: loss 0.322081
batch 3629: loss 0.445895
batch 3630: loss 0.372191
batch 3631: loss 0.516836
batch 3632: loss 0.396565
batch 3633: loss 0.305687
batch 3634: loss 0.407060
batch 3635: loss 0.390925
batch 3636: loss 0.389066
batch 3637: loss 0.293810
batch 3638: loss 0.418603
batch 3639: loss 0.257236
batch 3640: loss 0.489242
batch 3641: loss 0.371842
batch 3642: loss 0.393030
batch 3643: loss 0.273848
batch 3644: loss 0.249073
batch 3645: loss 0.472214
batch 3646: loss 0.315155
batch 3647: loss 0.396976
batch 3648: loss 0.706475
batch 3649: loss 0.336908
batch 3650: loss 0.349580
batch 3651: loss 0.323303
batch 3652: 

batch 3938: loss 0.473722
batch 3939: loss 0.281419
batch 3940: loss 0.443705
batch 3941: loss 0.348358
batch 3942: loss 0.296333
batch 3943: loss 0.328953
batch 3944: loss 0.432105
batch 3945: loss 0.544027
batch 3946: loss 0.251942
batch 3947: loss 0.254046
batch 3948: loss 0.492717
batch 3949: loss 0.323732
batch 3950: loss 0.246380
batch 3951: loss 0.358905
batch 3952: loss 0.392924
batch 3953: loss 0.384704
batch 3954: loss 0.294514
batch 3955: loss 0.337434
batch 3956: loss 0.452066
batch 3957: loss 0.455102
batch 3958: loss 0.539370
batch 3959: loss 0.234660
batch 3960: loss 0.331449
batch 3961: loss 0.496295
batch 3962: loss 0.406241
batch 3963: loss 0.565299
batch 3964: loss 0.251347
batch 3965: loss 0.257737
batch 3966: loss 0.428628
batch 3967: loss 0.377279
batch 3968: loss 0.289436
batch 3969: loss 0.249351
batch 3970: loss 0.370388
batch 3971: loss 0.328654
batch 3972: loss 0.256290
batch 3973: loss 0.338254
batch 3974: loss 0.599745
batch 3975: loss 0.345177
batch 3976: 

batch 4256: loss 0.374249
batch 4257: loss 0.301681
batch 4258: loss 0.240831
batch 4259: loss 0.416171
batch 4260: loss 0.411949
batch 4261: loss 0.278159
batch 4262: loss 0.395672
batch 4263: loss 0.347915
batch 4264: loss 0.609179
batch 4265: loss 0.432467
batch 4266: loss 0.283585
batch 4267: loss 0.675615
batch 4268: loss 0.348974
batch 4269: loss 0.343135
batch 4270: loss 0.402950
batch 4271: loss 0.257130
batch 4272: loss 0.409937
batch 4273: loss 0.330415
batch 4274: loss 0.489084
batch 4275: loss 0.445269
batch 4276: loss 0.306769
batch 4277: loss 0.442843
batch 4278: loss 0.545857
batch 4279: loss 0.311105
batch 4280: loss 0.400817
batch 4281: loss 0.404527
batch 4282: loss 0.438973
batch 4283: loss 0.372852
batch 4284: loss 0.448691
batch 4285: loss 0.445554
batch 4286: loss 0.646954
batch 4287: loss 0.361434
batch 4288: loss 0.415164
batch 4289: loss 0.252340
batch 4290: loss 0.461567
batch 4291: loss 0.307957
batch 4292: loss 0.400100
batch 4293: loss 0.319241
batch 4294: 

batch 4582: loss 0.352540
batch 4583: loss 0.382488
batch 4584: loss 0.271653
batch 4585: loss 0.242395
batch 4586: loss 0.638528
batch 4587: loss 0.416301
batch 4588: loss 0.367684
batch 4589: loss 0.510305
batch 4590: loss 0.286477
batch 4591: loss 0.331861
batch 4592: loss 0.469722
batch 4593: loss 0.331497
batch 4594: loss 0.302863
batch 4595: loss 0.265965
batch 4596: loss 0.399263
batch 4597: loss 0.267507
batch 4598: loss 0.346888
batch 4599: loss 0.516242
batch 4600: loss 0.395157
batch 4601: loss 0.490340
batch 4602: loss 0.268777
batch 4603: loss 0.489552
batch 4604: loss 0.588774
batch 4605: loss 0.495127
batch 4606: loss 0.282503
batch 4607: loss 0.194674
batch 4608: loss 0.411220
batch 4609: loss 0.534766
batch 4610: loss 0.571908
batch 4611: loss 0.386816
batch 4612: loss 0.215613
batch 4613: loss 0.268045
batch 4614: loss 0.188373
batch 4615: loss 0.243952
batch 4616: loss 0.241641
batch 4617: loss 0.302829
batch 4618: loss 0.257562
batch 4619: loss 0.231224
batch 4620: 

batch 4915: loss 0.332655
batch 4916: loss 0.500868
batch 4917: loss 0.385703
batch 4918: loss 0.434147
batch 4919: loss 0.269830
batch 4920: loss 0.326056
batch 4921: loss 0.182202
batch 4922: loss 0.312225
batch 4923: loss 0.428261
batch 4924: loss 0.297859
batch 4925: loss 0.582195
batch 4926: loss 0.425064
batch 4927: loss 0.289955
batch 4928: loss 0.682517
batch 4929: loss 0.529096
batch 4930: loss 0.392533
batch 4931: loss 0.324667
batch 4932: loss 0.350349
batch 4933: loss 0.857078
batch 4934: loss 0.735353
batch 4935: loss 0.404041
batch 4936: loss 0.411001
batch 4937: loss 0.157346
batch 4938: loss 0.299790
batch 4939: loss 0.393014
batch 4940: loss 0.377213
batch 4941: loss 0.360467
batch 4942: loss 0.306967
batch 4943: loss 0.360180
batch 4944: loss 0.347500
batch 4945: loss 0.461712
batch 4946: loss 0.380496
batch 4947: loss 0.388923
batch 4948: loss 0.250653
batch 4949: loss 0.474330
batch 4950: loss 0.314540
batch 4951: loss 0.300526
batch 4952: loss 0.258396
batch 4953: 

batch 5251: loss 0.260360
batch 5252: loss 0.311224
batch 5253: loss 0.478776
batch 5254: loss 0.613307
batch 5255: loss 0.331076
batch 5256: loss 0.398283
batch 5257: loss 0.383907
batch 5258: loss 0.432664
batch 5259: loss 0.393962
batch 5260: loss 0.261640
batch 5261: loss 0.313375
batch 5262: loss 0.358879
batch 5263: loss 0.229007
batch 5264: loss 0.292671
batch 5265: loss 0.265925
batch 5266: loss 0.420327
batch 5267: loss 0.556172
batch 5268: loss 0.430483
batch 5269: loss 0.311196
batch 5270: loss 0.261229
batch 5271: loss 0.349795
batch 5272: loss 0.628482
batch 5273: loss 0.567313
batch 5274: loss 0.303265
batch 5275: loss 0.433799
batch 5276: loss 0.386368
batch 5277: loss 0.359747
batch 5278: loss 0.284458
batch 5279: loss 0.278552
batch 5280: loss 0.368602
batch 5281: loss 0.234357
batch 5282: loss 0.446541
batch 5283: loss 0.249616
batch 5284: loss 0.476021
batch 5285: loss 0.276735
batch 5286: loss 0.408117
batch 5287: loss 0.246099
batch 5288: loss 0.367941
batch 5289: 

（4）模型的评估： tf.keras.metrics

我们使用测试集评估模型的性能。这里，我们使用 tf.keras.metrics 中的 SparseCategoricalAccuracy 评估器来评估模型在测试集上的性能，该评估器能够对模型预测的结果与真实结果进行比较，并输出预测正确的样本数占总样本数的比例。

In [73]:
# sparse_categorical_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
categorical_accuracy = tf.keras.metrics.CategoricalAccuracy()
num_batches = int(data_loader.num_test_data // batch_size)
for batch_index in range(num_batches):
    start_index, end_index = batch_index * batch_size, (batch_index + 1) * batch_size
    y_pred = model.predict(data_loader.test_data[start_index: end_index])
    categorical_accuracy.update_state(y_true=data_loader.test_label[start_index: end_index], y_pred=y_pred)
print("test accuracy: %f" % categorical_accuracy.result())

test accuracy: 0.909900


## 2.3 卷积神经网络（CNN）

卷积神经网络 （Convolutional Neural Network, CNN）是一种结构类似于人类或动物的 视觉系统 的人工神经网络，包含一个或多个卷积层（Convolutional Layer）、池化层（Pooling Layer）和全连接层（Fully-connected Layer）。

In [86]:
class CNN(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.conv1 = tf.keras.layers.Conv2D(
            filters=32,             # 卷积层神经元（卷积核）数目
            kernel_size=[5, 5],     # 感受野大小
            padding='same',         # padding策略（vaild 或 same） 在卷积结束后会补上0
            activation=tf.nn.relu   # 激活函数
        )
        self.pool1 = tf.keras.layers.MaxPool2D(pool_size=[2, 2], strides=2)
        self.conv2 = tf.keras.layers.Conv2D(
            filters=64,
            kernel_size=[5, 5],
            padding='same',
            activation=tf.nn.relu
        )
        self.pool2 = tf.keras.layers.MaxPool2D(pool_size=[2, 2], strides=2)
        # 后面部分和MLP类似
        self.flatten = tf.keras.layers.Reshape(target_shape=(7 * 7 * 64,))  # 打平成一维向量
        self.dense1 = tf.keras.layers.Dense(units=1024, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(units=10)

    def call(self, inputs):
        x = self.conv1(inputs)                  # [batch_size, 28, 28, 32]
        x = self.pool1(x)                       # [batch_size, 14, 14, 32]
        x = self.conv2(x)                       # [batch_size, 14, 14, 64]
        x = self.pool2(x)                       # [batch_size, 7, 7, 64]
        x = self.flatten(x)                     # [batch_size, 7 * 7 * 64]
        x = self.dense1(x)                      # [batch_size, 1024]
        x = self.dense2(x)                      # [batch_size, 10]
        output = tf.nn.softmax(x)
        return output

![image.png](attachment:image.png)

In [87]:
from tensorflow_core.examples.tutorials.mnist import input_data
import numpy as np
class MNISTLoader_my_download():
    def __init__(self):
        # 读取数据，预先已经下载了相应的数据直
        mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
        self.train_data = np.reshape(mnist.train.images,(-1, 28, 28))  # 将二维[60000，784] 转化成三维 [60000, 28 , 28]
        self.train_label = mnist.train.labels
        self.test_data = np.reshape(mnist.test.images,(-1, 28, 28))
        self.test_label = mnist.test.labels
        
        # MNIST中的图像默认为uint8（0-255的数字）。以下代码将其归一化到0-1之间的浮点数，并在最后增加一维作为颜色通道
        self.train_data = np.expand_dims(self.train_data.astype(np.float32) / 255.0, axis=-1)      # [60000, 28, 28, 1]
        self.test_data = np.expand_dims(self.test_data.astype(np.float32) / 255.0, axis=-1)        # [10000, 28, 28, 1]
        self.train_label = self.train_label.astype(np.int32)    # [60000]
        self.test_label = self.test_label.astype(np.int32)      # [10000]
        self.num_train_data, self.num_test_data = self.train_data.shape[0], self.test_data.shape[0]

    def get_batch(self, batch_size):
        # 从数据集中随机取出batch_size个元素并返回
        index = np.random.randint(0, self.num_train_data, batch_size)
        return self.train_data[index, :], self.train_label[index]

In [88]:
# 定义一些模型超参数：
num_epochs = 5
batch_size = 50
learning_rate = 0.001

# 实例化模型和数据读取类，并实例化一个 tf.keras.optimizer 的优化器（这里使用常用的 Adam 优化器）：
model = CNN()
# data_loader = MNISTLoader() # 导入数据 
data_loader = MNISTLoader_my_download()  # 导入数据
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)  # 更新梯度

num_batches = int(mnist.train.num_examples // batch_size * num_epochs)
# num_batches = int(data_loader.num_train_data // batch_size * num_epochs)
for batch_index in range(num_batches):
    X, y = data_loader.get_batch(batch_size)
    with tf.GradientTape() as tape:
        y_pred = model(X)
        loss = tf.keras.losses.categorical_crossentropy(y_true=y, y_pred=y_pred)
        loss = tf.reduce_mean(loss)
        print("batch %d: loss %f" % (batch_index, loss.numpy()))
    grads = tape.gradient(loss, model.variables)
    optimizer.apply_gradients(grads_and_vars=zip(grads, model.variables))

    
# 模型报错：原因是我下载的数据中train的维度为：found ndim=3，而需要卷积的维度为 expected ndim=4
#           下载的数据已经将原来的数据打平了。

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
batch 0: loss 2.302610
batch 1: loss 2.303320
batch 2: loss 2.302329
batch 3: loss 2.311814
batch 4: loss 2.302142
batch 5: loss 2.305665
batch 6: loss 2.299025
batch 7: loss 2.305187
batch 8: loss 2.303758
batch 9: loss 2.301320
batch 10: loss 2.304811
batch 11: loss 2.301137
batch 12: loss 2.299722
batch 13: loss 2.300665
batch 14: loss 2.300334
batch 15: loss 2.300955
batch 16: loss 2.301842
batch 17: loss 2.301183
batch 18: loss 2.297432
batch 19: loss 2.302504
batch 20: loss 2.307119
batch 21: loss 2.296802
batch 22: loss 2.286331
batch 23: loss 2.310440
batch 24: loss 2.281254
batch 25: loss 2.310944
batch 26: loss 2.310174
batch 27: loss 2.308407
batch 28: loss 2.270927
batch 29: loss 2.315083
batch 30: loss 2.317574
batch 31: loss 2.296562
batch 32: loss 2.299086
batch 33: loss 2.287718


batch 326: loss 0.339878
batch 327: loss 0.531572
batch 328: loss 0.495930
batch 329: loss 0.640722
batch 330: loss 0.544197
batch 331: loss 0.413505
batch 332: loss 0.521364
batch 333: loss 0.281159
batch 334: loss 0.266581
batch 335: loss 0.458856
batch 336: loss 0.393953
batch 337: loss 0.310898
batch 338: loss 0.711686
batch 339: loss 0.364978
batch 340: loss 0.517823
batch 341: loss 0.435396
batch 342: loss 0.387338
batch 343: loss 0.313671
batch 344: loss 0.393093
batch 345: loss 0.486100
batch 346: loss 0.428407
batch 347: loss 0.330168
batch 348: loss 0.266527
batch 349: loss 0.325472
batch 350: loss 0.270474
batch 351: loss 0.304703
batch 352: loss 0.393963
batch 353: loss 0.498106
batch 354: loss 0.261925
batch 355: loss 0.510262
batch 356: loss 0.239531
batch 357: loss 0.540869
batch 358: loss 0.399332
batch 359: loss 0.569441
batch 360: loss 0.426267
batch 361: loss 0.430040
batch 362: loss 0.312437
batch 363: loss 0.423610
batch 364: loss 0.269415
batch 365: loss 0.373814


batch 656: loss 0.279336
batch 657: loss 0.170540
batch 658: loss 0.255694
batch 659: loss 0.161557
batch 660: loss 0.418509
batch 661: loss 0.229591
batch 662: loss 0.225408
batch 663: loss 0.123892
batch 664: loss 0.269162
batch 665: loss 0.325617
batch 666: loss 0.280641
batch 667: loss 0.246556
batch 668: loss 0.297851
batch 669: loss 0.191942
batch 670: loss 0.199046
batch 671: loss 0.238438
batch 672: loss 0.281630
batch 673: loss 0.184505
batch 674: loss 0.309846
batch 675: loss 0.394855
batch 676: loss 0.193279
batch 677: loss 0.322092
batch 678: loss 0.111408
batch 679: loss 0.250559
batch 680: loss 0.274375
batch 681: loss 0.310714
batch 682: loss 0.091508
batch 683: loss 0.550545
batch 684: loss 0.350373
batch 685: loss 0.218725
batch 686: loss 0.360961
batch 687: loss 0.421443
batch 688: loss 0.165655
batch 689: loss 0.263303
batch 690: loss 0.093505
batch 691: loss 0.239689
batch 692: loss 0.203450
batch 693: loss 0.318673
batch 694: loss 0.234849
batch 695: loss 0.423512


batch 986: loss 0.282865
batch 987: loss 0.102233
batch 988: loss 0.236565
batch 989: loss 0.261108
batch 990: loss 0.175761
batch 991: loss 0.224866
batch 992: loss 0.129213
batch 993: loss 0.113882
batch 994: loss 0.121080
batch 995: loss 0.177442
batch 996: loss 0.072555
batch 997: loss 0.087875
batch 998: loss 0.211567
batch 999: loss 0.360696
batch 1000: loss 0.210690
batch 1001: loss 0.214100
batch 1002: loss 0.105322
batch 1003: loss 0.123677
batch 1004: loss 0.097079
batch 1005: loss 0.454705
batch 1006: loss 0.192189
batch 1007: loss 0.278550
batch 1008: loss 0.251727
batch 1009: loss 0.269410
batch 1010: loss 0.224247
batch 1011: loss 0.362896
batch 1012: loss 0.261096
batch 1013: loss 0.198965
batch 1014: loss 0.244445
batch 1015: loss 0.426824
batch 1016: loss 0.241730
batch 1017: loss 0.204897
batch 1018: loss 0.062970
batch 1019: loss 0.183338
batch 1020: loss 0.154945
batch 1021: loss 0.407700
batch 1022: loss 0.111065
batch 1023: loss 0.139322
batch 1024: loss 0.222580


batch 1304: loss 0.108534
batch 1305: loss 0.110516
batch 1306: loss 0.248810
batch 1307: loss 0.075219
batch 1308: loss 0.121395
batch 1309: loss 0.266211
batch 1310: loss 0.105100
batch 1311: loss 0.197124
batch 1312: loss 0.231371
batch 1313: loss 0.070434
batch 1314: loss 0.118525
batch 1315: loss 0.178785
batch 1316: loss 0.094097
batch 1317: loss 0.073105
batch 1318: loss 0.113155
batch 1319: loss 0.170287
batch 1320: loss 0.331402
batch 1321: loss 0.186673
batch 1322: loss 0.142356
batch 1323: loss 0.113340
batch 1324: loss 0.133873
batch 1325: loss 0.193989
batch 1326: loss 0.148091
batch 1327: loss 0.180050
batch 1328: loss 0.132076
batch 1329: loss 0.203228
batch 1330: loss 0.119972
batch 1331: loss 0.141144
batch 1332: loss 0.094908
batch 1333: loss 0.131663
batch 1334: loss 0.318414
batch 1335: loss 0.077589
batch 1336: loss 0.149586
batch 1337: loss 0.111740
batch 1338: loss 0.394931
batch 1339: loss 0.215845
batch 1340: loss 0.037532
batch 1341: loss 0.086363
batch 1342: 

batch 1621: loss 0.182237
batch 1622: loss 0.113747
batch 1623: loss 0.104703
batch 1624: loss 0.167026
batch 1625: loss 0.035751
batch 1626: loss 0.241225
batch 1627: loss 0.141696
batch 1628: loss 0.180174
batch 1629: loss 0.304860
batch 1630: loss 0.087998
batch 1631: loss 0.308510
batch 1632: loss 0.170212
batch 1633: loss 0.092066
batch 1634: loss 0.285016
batch 1635: loss 0.096526
batch 1636: loss 0.169885
batch 1637: loss 0.391916
batch 1638: loss 0.076566
batch 1639: loss 0.272224
batch 1640: loss 0.155875
batch 1641: loss 0.187051
batch 1642: loss 0.197506
batch 1643: loss 0.044404
batch 1644: loss 0.100116
batch 1645: loss 0.071534
batch 1646: loss 0.053658
batch 1647: loss 0.300262
batch 1648: loss 0.090683
batch 1649: loss 0.154934
batch 1650: loss 0.060559
batch 1651: loss 0.154186
batch 1652: loss 0.057431
batch 1653: loss 0.067771
batch 1654: loss 0.100502
batch 1655: loss 0.199184
batch 1656: loss 0.104434
batch 1657: loss 0.088820
batch 1658: loss 0.109902
batch 1659: 

batch 1939: loss 0.024629
batch 1940: loss 0.078229
batch 1941: loss 0.130231
batch 1942: loss 0.178590
batch 1943: loss 0.043133
batch 1944: loss 0.177355
batch 1945: loss 0.069053
batch 1946: loss 0.162950
batch 1947: loss 0.132648
batch 1948: loss 0.057127
batch 1949: loss 0.183340
batch 1950: loss 0.034289
batch 1951: loss 0.073845
batch 1952: loss 0.197067
batch 1953: loss 0.243043
batch 1954: loss 0.029411
batch 1955: loss 0.118861
batch 1956: loss 0.029694
batch 1957: loss 0.079953
batch 1958: loss 0.139888
batch 1959: loss 0.150314
batch 1960: loss 0.082202
batch 1961: loss 0.196282
batch 1962: loss 0.054403
batch 1963: loss 0.112280
batch 1964: loss 0.108556
batch 1965: loss 0.125302
batch 1966: loss 0.074918
batch 1967: loss 0.115164
batch 1968: loss 0.077864
batch 1969: loss 0.221572
batch 1970: loss 0.190760
batch 1971: loss 0.132784
batch 1972: loss 0.139354
batch 1973: loss 0.040899
batch 1974: loss 0.038691
batch 1975: loss 0.148903
batch 1976: loss 0.163668
batch 1977: 

batch 2257: loss 0.068814
batch 2258: loss 0.100255
batch 2259: loss 0.015577
batch 2260: loss 0.075108
batch 2261: loss 0.312996
batch 2262: loss 0.170141
batch 2263: loss 0.078321
batch 2264: loss 0.370417
batch 2265: loss 0.170107
batch 2266: loss 0.095927
batch 2267: loss 0.048008
batch 2268: loss 0.128240
batch 2269: loss 0.031900
batch 2270: loss 0.119135
batch 2271: loss 0.187901
batch 2272: loss 0.066178
batch 2273: loss 0.038296
batch 2274: loss 0.076854
batch 2275: loss 0.056821
batch 2276: loss 0.149914
batch 2277: loss 0.144804
batch 2278: loss 0.038547
batch 2279: loss 0.275835
batch 2280: loss 0.032584
batch 2281: loss 0.075702
batch 2282: loss 0.242900
batch 2283: loss 0.062296
batch 2284: loss 0.058781
batch 2285: loss 0.142555
batch 2286: loss 0.048448
batch 2287: loss 0.216586
batch 2288: loss 0.156460
batch 2289: loss 0.020940
batch 2290: loss 0.150316
batch 2291: loss 0.051192
batch 2292: loss 0.071909
batch 2293: loss 0.066290
batch 2294: loss 0.088793
batch 2295: 

batch 2575: loss 0.136494
batch 2576: loss 0.017955
batch 2577: loss 0.038274
batch 2578: loss 0.136383
batch 2579: loss 0.112982
batch 2580: loss 0.042058
batch 2581: loss 0.261800
batch 2582: loss 0.066336
batch 2583: loss 0.087528
batch 2584: loss 0.027654
batch 2585: loss 0.059720
batch 2586: loss 0.037154
batch 2587: loss 0.064829
batch 2588: loss 0.149172
batch 2589: loss 0.114250
batch 2590: loss 0.027031
batch 2591: loss 0.096946
batch 2592: loss 0.204933
batch 2593: loss 0.018950
batch 2594: loss 0.041892
batch 2595: loss 0.228303
batch 2596: loss 0.011323
batch 2597: loss 0.025954
batch 2598: loss 0.069627
batch 2599: loss 0.153186
batch 2600: loss 0.019346
batch 2601: loss 0.143558
batch 2602: loss 0.091937
batch 2603: loss 0.100051
batch 2604: loss 0.031220
batch 2605: loss 0.046376
batch 2606: loss 0.087089
batch 2607: loss 0.098316
batch 2608: loss 0.237831
batch 2609: loss 0.104221
batch 2610: loss 0.052263
batch 2611: loss 0.062642
batch 2612: loss 0.113358
batch 2613: 

batch 2893: loss 0.198247
batch 2894: loss 0.120127
batch 2895: loss 0.026070
batch 2896: loss 0.126628
batch 2897: loss 0.179124
batch 2898: loss 0.068504
batch 2899: loss 0.027021
batch 2900: loss 0.239748
batch 2901: loss 0.031681
batch 2902: loss 0.087251
batch 2903: loss 0.047760
batch 2904: loss 0.139899
batch 2905: loss 0.054389
batch 2906: loss 0.037106
batch 2907: loss 0.049544
batch 2908: loss 0.149640
batch 2909: loss 0.333926
batch 2910: loss 0.054562
batch 2911: loss 0.035885
batch 2912: loss 0.208905
batch 2913: loss 0.088701
batch 2914: loss 0.066914
batch 2915: loss 0.148328
batch 2916: loss 0.026992
batch 2917: loss 0.059585
batch 2918: loss 0.021181
batch 2919: loss 0.110802
batch 2920: loss 0.091217
batch 2921: loss 0.100706
batch 2922: loss 0.051876
batch 2923: loss 0.071907
batch 2924: loss 0.102662
batch 2925: loss 0.015843
batch 2926: loss 0.037924
batch 2927: loss 0.036997
batch 2928: loss 0.179136
batch 2929: loss 0.071623
batch 2930: loss 0.062264
batch 2931: 

batch 3211: loss 0.020179
batch 3212: loss 0.238521
batch 3213: loss 0.178520
batch 3214: loss 0.174140
batch 3215: loss 0.238018
batch 3216: loss 0.132443
batch 3217: loss 0.063903
batch 3218: loss 0.049039
batch 3219: loss 0.043740
batch 3220: loss 0.044449
batch 3221: loss 0.031979
batch 3222: loss 0.015414
batch 3223: loss 0.028734
batch 3224: loss 0.071322
batch 3225: loss 0.044720
batch 3226: loss 0.078683
batch 3227: loss 0.297519
batch 3228: loss 0.033846
batch 3229: loss 0.017376
batch 3230: loss 0.030334
batch 3231: loss 0.058567
batch 3232: loss 0.056023
batch 3233: loss 0.058764
batch 3234: loss 0.114655
batch 3235: loss 0.092402
batch 3236: loss 0.070573
batch 3237: loss 0.029300
batch 3238: loss 0.036312
batch 3239: loss 0.024449
batch 3240: loss 0.055913
batch 3241: loss 0.020216
batch 3242: loss 0.034163
batch 3243: loss 0.075707
batch 3244: loss 0.153273
batch 3245: loss 0.030200
batch 3246: loss 0.053985
batch 3247: loss 0.090595
batch 3248: loss 0.040730
batch 3249: 

batch 3529: loss 0.077130
batch 3530: loss 0.023725
batch 3531: loss 0.011207
batch 3532: loss 0.007546
batch 3533: loss 0.132406
batch 3534: loss 0.036045
batch 3535: loss 0.053071
batch 3536: loss 0.037004
batch 3537: loss 0.034898
batch 3538: loss 0.048929
batch 3539: loss 0.097904
batch 3540: loss 0.049507
batch 3541: loss 0.014883
batch 3542: loss 0.017634
batch 3543: loss 0.016257
batch 3544: loss 0.057056
batch 3545: loss 0.019623
batch 3546: loss 0.011446
batch 3547: loss 0.053458
batch 3548: loss 0.025673
batch 3549: loss 0.020348
batch 3550: loss 0.056762
batch 3551: loss 0.044080
batch 3552: loss 0.094130
batch 3553: loss 0.159994
batch 3554: loss 0.025857
batch 3555: loss 0.052624
batch 3556: loss 0.079621
batch 3557: loss 0.108425
batch 3558: loss 0.092982
batch 3559: loss 0.017021
batch 3560: loss 0.008153
batch 3561: loss 0.036594
batch 3562: loss 0.017706
batch 3563: loss 0.032076
batch 3564: loss 0.101515
batch 3565: loss 0.026184
batch 3566: loss 0.103399
batch 3567: 

batch 3845: loss 0.051109
batch 3846: loss 0.033133
batch 3847: loss 0.063353
batch 3848: loss 0.026299
batch 3849: loss 0.053977
batch 3850: loss 0.060085
batch 3851: loss 0.054794
batch 3852: loss 0.094857
batch 3853: loss 0.113539
batch 3854: loss 0.121612
batch 3855: loss 0.118903
batch 3856: loss 0.112074
batch 3857: loss 0.056492
batch 3858: loss 0.089329
batch 3859: loss 0.133973
batch 3860: loss 0.009429
batch 3861: loss 0.031185
batch 3862: loss 0.007194
batch 3863: loss 0.028703
batch 3864: loss 0.107487
batch 3865: loss 0.188598
batch 3866: loss 0.068947
batch 3867: loss 0.085927
batch 3868: loss 0.011912
batch 3869: loss 0.016781
batch 3870: loss 0.034943
batch 3871: loss 0.049136
batch 3872: loss 0.141700
batch 3873: loss 0.068455
batch 3874: loss 0.085523
batch 3875: loss 0.115919
batch 3876: loss 0.169109
batch 3877: loss 0.007211
batch 3878: loss 0.011644
batch 3879: loss 0.017243
batch 3880: loss 0.044515
batch 3881: loss 0.087319
batch 3882: loss 0.038673
batch 3883: 

batch 4163: loss 0.027340
batch 4164: loss 0.011342
batch 4165: loss 0.083789
batch 4166: loss 0.040270
batch 4167: loss 0.007570
batch 4168: loss 0.086470
batch 4169: loss 0.098847
batch 4170: loss 0.203955
batch 4171: loss 0.049522
batch 4172: loss 0.073282
batch 4173: loss 0.039643
batch 4174: loss 0.014799
batch 4175: loss 0.079772
batch 4176: loss 0.076849
batch 4177: loss 0.059371
batch 4178: loss 0.040792
batch 4179: loss 0.090585
batch 4180: loss 0.125637
batch 4181: loss 0.214861
batch 4182: loss 0.005542
batch 4183: loss 0.073241
batch 4184: loss 0.018001
batch 4185: loss 0.020391
batch 4186: loss 0.080988
batch 4187: loss 0.011467
batch 4188: loss 0.007013
batch 4189: loss 0.020126
batch 4190: loss 0.038734
batch 4191: loss 0.022236
batch 4192: loss 0.077826
batch 4193: loss 0.034458
batch 4194: loss 0.019527
batch 4195: loss 0.009965
batch 4196: loss 0.211863
batch 4197: loss 0.120547
batch 4198: loss 0.051580
batch 4199: loss 0.054931
batch 4200: loss 0.020263
batch 4201: 

batch 4480: loss 0.031695
batch 4481: loss 0.051495
batch 4482: loss 0.017669
batch 4483: loss 0.038978
batch 4484: loss 0.087250
batch 4485: loss 0.007678
batch 4486: loss 0.046886
batch 4487: loss 0.029899
batch 4488: loss 0.014919
batch 4489: loss 0.042650
batch 4490: loss 0.203431
batch 4491: loss 0.037215
batch 4492: loss 0.023062
batch 4493: loss 0.032773
batch 4494: loss 0.007940
batch 4495: loss 0.022988
batch 4496: loss 0.017288
batch 4497: loss 0.015852
batch 4498: loss 0.031344
batch 4499: loss 0.096065
batch 4500: loss 0.035347
batch 4501: loss 0.127164
batch 4502: loss 0.077116
batch 4503: loss 0.069699
batch 4504: loss 0.097580
batch 4505: loss 0.112566
batch 4506: loss 0.131098
batch 4507: loss 0.015718
batch 4508: loss 0.205919
batch 4509: loss 0.028893
batch 4510: loss 0.006245
batch 4511: loss 0.088470
batch 4512: loss 0.098599
batch 4513: loss 0.088160
batch 4514: loss 0.019332
batch 4515: loss 0.308668
batch 4516: loss 0.153038
batch 4517: loss 0.066528
batch 4518: 

batch 4796: loss 0.016402
batch 4797: loss 0.113384
batch 4798: loss 0.085849
batch 4799: loss 0.136712
batch 4800: loss 0.009297
batch 4801: loss 0.071733
batch 4802: loss 0.111569
batch 4803: loss 0.119761
batch 4804: loss 0.040831
batch 4805: loss 0.046658
batch 4806: loss 0.042550
batch 4807: loss 0.041347
batch 4808: loss 0.039264
batch 4809: loss 0.038418
batch 4810: loss 0.019617
batch 4811: loss 0.033066
batch 4812: loss 0.063240
batch 4813: loss 0.061903
batch 4814: loss 0.077249
batch 4815: loss 0.075535
batch 4816: loss 0.005309
batch 4817: loss 0.003091
batch 4818: loss 0.178375
batch 4819: loss 0.013666
batch 4820: loss 0.043907
batch 4821: loss 0.013359
batch 4822: loss 0.004983
batch 4823: loss 0.160524
batch 4824: loss 0.003838
batch 4825: loss 0.016221
batch 4826: loss 0.075598
batch 4827: loss 0.031656
batch 4828: loss 0.066435
batch 4829: loss 0.010994
batch 4830: loss 0.009174
batch 4831: loss 0.011138
batch 4832: loss 0.012588
batch 4833: loss 0.011770
batch 4834: 

batch 5113: loss 0.027276
batch 5114: loss 0.040076
batch 5115: loss 0.067743
batch 5116: loss 0.070514
batch 5117: loss 0.070070
batch 5118: loss 0.014139
batch 5119: loss 0.001971
batch 5120: loss 0.023901
batch 5121: loss 0.019207
batch 5122: loss 0.021131
batch 5123: loss 0.099582
batch 5124: loss 0.095998
batch 5125: loss 0.013496
batch 5126: loss 0.073024
batch 5127: loss 0.179545
batch 5128: loss 0.009622
batch 5129: loss 0.036374
batch 5130: loss 0.113486
batch 5131: loss 0.008679
batch 5132: loss 0.061154
batch 5133: loss 0.014053
batch 5134: loss 0.010052
batch 5135: loss 0.124278
batch 5136: loss 0.020569
batch 5137: loss 0.050836
batch 5138: loss 0.125671
batch 5139: loss 0.033343
batch 5140: loss 0.005982
batch 5141: loss 0.074716
batch 5142: loss 0.027612
batch 5143: loss 0.033172
batch 5144: loss 0.034675
batch 5145: loss 0.133278
batch 5146: loss 0.040267
batch 5147: loss 0.058383
batch 5148: loss 0.064736
batch 5149: loss 0.111583
batch 5150: loss 0.091052
batch 5151: 

batch 5431: loss 0.015753
batch 5432: loss 0.127509
batch 5433: loss 0.025043
batch 5434: loss 0.187638
batch 5435: loss 0.002897
batch 5436: loss 0.019337
batch 5437: loss 0.054113
batch 5438: loss 0.029730
batch 5439: loss 0.102181
batch 5440: loss 0.013624
batch 5441: loss 0.097693
batch 5442: loss 0.008078
batch 5443: loss 0.030942
batch 5444: loss 0.066598
batch 5445: loss 0.046411
batch 5446: loss 0.099952
batch 5447: loss 0.026498
batch 5448: loss 0.052223
batch 5449: loss 0.011639
batch 5450: loss 0.324012
batch 5451: loss 0.003329
batch 5452: loss 0.007185
batch 5453: loss 0.013544
batch 5454: loss 0.007988
batch 5455: loss 0.021015
batch 5456: loss 0.003848
batch 5457: loss 0.046884
batch 5458: loss 0.015437
batch 5459: loss 0.016198
batch 5460: loss 0.010160
batch 5461: loss 0.049813
batch 5462: loss 0.008031
batch 5463: loss 0.055383
batch 5464: loss 0.060648
batch 5465: loss 0.008318
batch 5466: loss 0.038649
batch 5467: loss 0.147043
batch 5468: loss 0.023772
batch 5469: 

In [89]:
# sparse_categorical_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
categorical_accuracy = tf.keras.metrics.CategoricalAccuracy()
num_batches = int(data_loader.num_test_data // batch_size)
for batch_index in range(num_batches):
    start_index, end_index = batch_index * batch_size, (batch_index + 1) * batch_size
    y_pred = model.predict(data_loader.test_data[start_index: end_index])
    categorical_accuracy.update_state(y_true=data_loader.test_label[start_index: end_index], y_pred=y_pred)
print("test accuracy: %f" % categorical_accuracy.result())

test accuracy: 0.982100


![image.png](attachment:image.png)
卷积示意图。一个单通道的 7×7 图像在通过一个感受野为 3×3 ，参数为 10 个的卷积层神经元后，得到 5×5 的矩阵作为卷积结果。

## 2.4 循环神经网络（RNN）

循环神经网络（Recurrent Neural Network, RNN）是一种适宜于处理序列数据的神经网络，被广泛用于语言模型、文本生成、机器翻译等。

首先，还是实现一个简单的 DataLoader 类来读取文本，并以字符为单位进行编码。设字符种类数为 num_chars ，则每种字符赋予一个 0 到 num_chars - 1 之间的唯一整数编号 i。

In [154]:
import re
# 2. 简单对数据进行清洗
def clean_data(string):
    # 对数据清洗
    string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string)
    string = re.sub(r"\'s", " \'s", string)
    string = re.sub(r"\'ve", " \'ve", string)
    string = re.sub(r"n\'t", " n\'t", string)
    string = re.sub(r"\'re", " \'re", string)
    string = re.sub(r"\'d", " \'d", string)
    string = re.sub(r"\'ll", " \'ll", string)
    string = re.sub(r",", " , ", string)
    string = re.sub(r"!", " ! ", string)
    string = re.sub(r"\(", " \( ", string)
    string = re.sub(r"\)", " \) ", string)
    string = re.sub(r"\?", " \? ", string)
    string = re.sub(r"\s{2,}", " ", string)  # \s表示空格，删除多个空格
    return string.strip().lower()

In [155]:
class DataLoader():
    def __init__(self, word_or_char='char'):
        path = tf.keras.utils.get_file('nietzsche.txt',
            origin='https://s3.amazonaws.com/text-datasets/nietzsche.txt')
        with open(path, encoding='utf-8') as f:
            self.raw_text = f.read().lower()
        if word_or_char == 'char':
            self.chars = sorted(list(set(self.raw_text)))
            self.char_indices = dict((c, i) for i, c in enumerate(self.chars))  # 定义 id：word
            self.indices_char = dict((i, c) for i, c in enumerate(self.chars))  # 定义 word：id
            self.text = [self.char_indices[c] for c in self.raw_text]  # 给完整的文档编码成 id
        else:
            # 以每个单词为模板进行训练
            self.raw_text = clean_data(self.raw_text)
            self.chars = sorted(list(set(self.raw_text.split(' '))))
            self.char_indices = dict((c, i) for i, c in enumerate(self.chars))  # 定义 id：word
            self.indices_char = dict((i, c) for i, c in enumerate(self.chars))  # 定义 word：id
            self.text = [self.char_indices[c] for c in self.chars]  # 给完整的文档编码成 id
            

    def get_batch(self, seq_length, batch_size):
        seq = [] 
        next_char = [] 
        for i in range(batch_size):
            index = np.random.randint(0, len(self.text) - seq_length) 
            seq.append(self.text[index:index + seq_length])   # 随机读取长度为seq_length的向量
            next_char.append(self.text[index + seq_length])   # 读取后一个字符作为label
        return np.array(seq), np.array(next_char)       # [batch_size, seq_length], [num_batch]

接下来进行模型的实现。

In [156]:
class RNN(tf.keras.Model):
    def __init__(self, num_chars, batch_size, seq_length):
        super().__init__()
        self.num_chars = num_chars  # 总的词数量
        self.seq_length = seq_length   # 每个batch长度
        self.batch_size = batch_size   # 每一轮训练输入的训练样本量
        self.cell = tf.keras.layers.LSTMCell(units=256)   # 选取RNN个数256，输出空间的维数为256
        self.dense = tf.keras.layers.Dense(units=self.num_chars)

    def call(self, inputs, from_logits=False):
        inputs = tf.one_hot(inputs, depth=self.num_chars)       # [batch_size, seq_length, num_chars]
        state = self.cell.get_initial_state(batch_size=self.batch_size, dtype=tf.float32)   # 获得 RNN 的初始状态 (确定RNN的batch_size)
        
        # 例如：seq_length = 4  循环四次：
        # [[0,1,0,0,...], [0,0,0,1,0....], [0,0,0,1,0,0...], [1,0,0,0.....]]     n[[num_chars]]) == seq_length
        # 例如 t = 1 表示选取了batch_size==50个 第2个位置上的值进行训练
        
        for t in range(self.seq_length):
            output, state = self.cell(inputs[:, t, :], state)   # 通过当前输入和前一时刻的状态，得到输出和当前时刻的状态
        logits = self.dense(output)
        if from_logits:                     # from_logits 参数控制输出是否通过 softmax 函数进行归一化
            return logits
        else:
            return tf.nn.softmax(logits)
    
    def predict(self, inputs, temperature=1.):
        batch_size, _ = tf.shape(inputs)
        logits = self(inputs, from_logits=True)                         # 调用训练好的RNN模型，预测下一个字符的概率分布
        prob = tf.nn.softmax(logits / temperature).numpy()              # 使用带 temperature 参数的 softmax 函数获得归一化的概率分布值
        return np.array([np.random.choice(self.num_chars, p=prob[i, :]) # 使用 np.random.choice 函数，
                         for i in range(batch_size.numpy())])           # 在预测的概率分布 prob 上进行随机取样

在 __init__ 方法中我们实例化一个常用的 LSTMCell 单元，以及一个线性变换用的全连接层，我们首先对序列进行 “One Hot” 操作，即将序列中的每个字符的编码 i 均变换为一个 num_char 维向量，其第 i 位为 1，其余均为 0。变换后的序列张量形状为 [seq_length, num_chars] 。

然后，我们初始化 RNN 单元的状态，存入变量 state 中。接下来，将序列从头到尾依次送入 RNN 单元，即在 t 时刻，将上一个时刻 t-1 的 RNN 单元状态 state 和序列的第 t 个元素 inputs[t, :] 送入 RNN 单元，得到当前时刻的输出 output 和 RNN 单元状态。取 RNN 单元最后一次的输出，通过全连接层变换到 num_chars 维，即作为模型的输出。

![image.png](attachment:image.png)

In [157]:
num_batches = 100
seq_length = 5 
batch_size = 50
learning_rate = 1e-3

训练过程与前节基本一致，在此复述：<br/>
（1）从 DataLoader 中随机取一批训练数据；<br/>
（2）将这批数据送入模型，计算出模型的预测值；<br/>
（3）将模型预测值与真实值进行比较，计算损失函数（loss）；<br/>
（4）计算损失函数关于模型变量的导数；<br/>
（5）使用优化器更新模型参数以最小化损失函数。<br/>

In [158]:
data_loader = DataLoader('word')

model = RNN(num_chars=len(data_loader.chars), batch_size=batch_size, seq_length=seq_length)
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)


for batch_index in range(num_batches):
    X, y = data_loader.get_batch(seq_length, batch_size)
    with tf.GradientTape() as tape:
        y_pred = model(X)
        loss = tf.keras.losses.sparse_categorical_crossentropy(y_true=y, y_pred=y_pred)
        loss = tf.reduce_mean(loss)
        print("batch %d: loss %f" % (batch_index, loss.numpy()))
    grads = tape.gradient(loss, model.variables)
    optimizer.apply_gradients(grads_and_vars=zip(grads, model.variables))  # w1 = w1 - a*grads

batch 0: loss 9.231982
batch 1: loss 9.231997
batch 2: loss 9.231779
batch 3: loss 9.232134
batch 4: loss 9.232060
batch 5: loss 9.231837
batch 6: loss 9.231556
batch 7: loss 9.231973
batch 8: loss 9.231834
batch 9: loss 9.231473
batch 10: loss 9.232635
batch 11: loss 9.232227
batch 12: loss 9.232550
batch 13: loss 9.230679
batch 14: loss 9.231379
batch 15: loss 9.230841
batch 16: loss 9.232001
batch 17: loss 9.230255
batch 18: loss 9.231311
batch 19: loss 9.231050
batch 20: loss 9.231351
batch 21: loss 9.231658
batch 22: loss 9.231020
batch 23: loss 9.232651
batch 24: loss 9.231641
batch 25: loss 9.232889
batch 26: loss 9.231209
batch 27: loss 9.231699
batch 28: loss 9.231276
batch 29: loss 9.230980
batch 30: loss 9.230420
batch 31: loss 9.226340
batch 32: loss 9.232630
batch 33: loss 9.233208
batch 34: loss 9.235175
batch 35: loss 9.231370
batch 36: loss 9.231544
batch 37: loss 9.230949
batch 38: loss 9.227948
batch 39: loss 9.230601
batch 40: loss 9.220881
batch 41: loss 9.233994
ba

通过这种方式进行 “滚雪球” 式的连续预测，即可得到生成文本。

In [159]:
X_, _ = data_loader.get_batch(seq_length, 1)
for diversity in [0.2, 0.5, 1.0, 1.2]:      # 丰富度（即temperature）分别设置为从小到大的 4 个值
    X = X_
    print("diversity %f:" % diversity)
    for t in range(400):
        y_pred = model.predict(X, diversity)    # 预测下一个字符的编号
        print(data_loader.indices_char[y_pred[0]], end=' ', flush=True)  # 输出预测的字符
        X = np.concatenate([X[:, 1:], np.expand_dims(y_pred, axis=1)], axis=-1)     # 将预测的字符接在输入 X 的末尾，并截断 X 的第一个字符，以保证 X 的长度不变
    print("\n")

diversity 0.200000:

diversity 0.500000:
demimonde slippery being maturing sentient unsuspected hans helping contentedness relative build calvin sate educational convictions complacently direct absurde come reverently hinders understand reciprocity irritability henri savagely hears superficialities pessimism physical trellis grandiose troublesome fatalists severing forgetfulness riders something epochs awakens wagnerienne uncertainty jubilantly escape headstrong pinu aggrandizement centuries threatening reproached favourably undertake classification tread pardon inclusive tedious traditionally mischievous victor great minently gradations distastefulness sets overturns obstinately preferably tranquilize contend vindictively said motions whence affirmation gives colour 38 down endeavor perspective employs hideousness helvetius duplicate contemptuous invents can cattle fumes sided 137 intelleto angry positivists soothed tribes admits humaner circulates widening efforts raffinements adject

## 2.5 深度强化学习（DRL）

In [61]:
import tensorflow as tf

## 2.6 Keras Pipeline *

不过在很多时候，我们只需要建立一个结构相对简单和典型的神经网络（比如上文中的 MLP 和 CNN），并使用常规的手段进行训练。这时，Keras 也给我们提供了另一套更为简单高效的内置方法来建立、训练和评估模型。

### 1. Keras Sequential/Functional API 模式建立模型 
通过向 tf.keras.models.Sequential() 提供一个层的列表，就能快速地建立一个 tf.keras.Model 模型并返回：

In [2]:
 model = tf.keras.models.Sequential([
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(100, activation=tf.nn.relu),
            tf.keras.layers.Dense(10),
            tf.keras.layers.Softmax()
        ])

Keras 提供了 Functional API，帮助我们建立更为复杂的模型，例如多输入 / 输出或存在参数共享的模型。其使用方法是将层作为可调用的对象并返回张量（这点与之前章节的使用方法一致），并将输入向量和输出向量提供给 tf.keras.Model 的 inputs 和 outputs 参数，示例如下：

In [None]:
inputs = tf.keras.Input(shape=(28, 28, 1))
x = tf.keras.layers.Flatten()(inputs)
x = tf.keras.layers.Dense(units=100, activation=tf.nn.relu)(x)
x = tf.keras.layers.Dense(units=10)(x)
outputs = tf.keras.layers.Softmax()(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

### 2. 使用 Keras Model 的 compile 、 fit 和 evaluate 方法训练和评估模型
当模型建立完成后，通过 tf.keras.Model 的 compile 方法配置训练过程：

In [5]:
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),    # oplimizer ：优化器，可从 tf.keras.optimizers 中选择；
    loss=tf.keras.losses.sparse_categorical_crossentropy,       # loss ：损失函数，可从 tf.keras.losses 中选择；
    metrics=[tf.keras.metrics.sparse_categorical_accuracy]      # metrics ：评估指标，可从 tf.keras.metrics 中选择。
)


接下来，可以使用 tf.keras.Model 的 fit 方法训练模型：

In [None]:
model.fit(data_loader.train_data, data_loader.train_label, epochs=num_epochs, batch_size=batch_size)
# x ：训练数据；
# y ：目标数据（数据标签）；
# epochs ：将训练数据迭代多少遍；
# batch_size ：批次的大小；
#validation_data ：验证数据，可用于在训练过程中监控模型的性能。

最后，使用 tf.keras.Model.evaluate 评估训练效果，提供测试数据及标签即可：

In [None]:
print(model.evaluate(data_loader.test_data, data_loader.test_label))

In [78]:
# 1. 继承 keras.Model类，实现函数calL()
# 2. 输入输出使用 keras.Model类 的intput 、output 方法
# 3. 使用tf.keras.models.Sequential()简单定义网络层数


from tensorflow_core.examples.tutorials.mnist import input_data
import numpy as np
class MNISTLoader_my_download():
    def __init__(self):
        # 读取数据，预先已经下载了相应的数据直
        mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
        self.train_data = mnist.train.images
        self.train_label = mnist.train.labels
        self.test_data = mnist.test.images
        self.test_label = mnist.test.labels
        
        # MNIST中的图像默认为uint8（0-255的数字）。以下代码将其归一化到0-1之间的浮点数，并在最后增加一维作为颜色通道
        self.train_data = np.expand_dims(self.train_data.astype(np.float32) / 255.0, axis=-1)      # [60000, 784, 1]
        self.test_data = np.expand_dims(self.test_data.astype(np.float32) / 255.0, axis=-1)        # [10000, 784, 1]
        self.train_label = self.train_label.astype(np.int32)    # [60000]
        self.test_label = self.test_label.astype(np.int32)      # [10000]
        self.num_train_data, self.num_test_data = self.train_data.shape[0], self.test_data.shape[0]

    def get_batch(self, batch_size):
        # 从数据集中随机取出batch_size个元素并返回
        index = np.random.randint(0, self.num_train_data, batch_size)
        return self.train_data[index, :], self.train_label[index]

num_epochs = 5
batch_size = 50
learning_rate = 0.001
data_loader = MNISTLoader_my_download()  # 导入数据


# 使用第二个方法
inputs = tf.keras.Input(shape=(784, 1))
x = tf.keras.layers.Flatten()(inputs)
x = tf.keras.layers.Dense(units=100, activation=tf.nn.relu)(x)
x = tf.keras.layers.Dense(units=10)(x)
outputs = tf.keras.layers.Softmax()(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)


# # 使用第三个方法，使用Sequential定义一个简单的网络结构
# model = tf.keras.models.Sequential([
#         tf.keras.layers.Flatten(),
#         tf.keras.layers.Dense(100, activation=tf.nn.relu),
#         tf.keras.layers.Dense(10),
#         tf.keras.layers.Softmax()
#     ])

# 通过 tf.keras.Model 的 compile 方法配置训练过程
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss=tf.keras.losses.categorical_crossentropy,
    metrics=[tf.keras.metrics.categorical_accuracy]
)

model.fit(data_loader.train_data, data_loader.train_label, epochs=num_epochs, batch_size=batch_size)


Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
Train on 55000 samples
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<tensorflow.python.keras.callbacks.History at 0x2033a5b0668>

In [79]:
print(model.evaluate(data_loader.test_data, data_loader.test_label))



[0.3354526966929436, 0.9074]


## 2.7 自定义层、损失函数和评估指标 *

我们不仅可以继承 tf.keras.Model 编写自己的模型类，也可以继承 tf.keras.layers.Layer 编写自己的层。

### 1. 自定义层
自定义层需要继承 tf.keras.layers.Layer 类，并重写 __init__ 、 build 和 call 三个方法，如下所示：

In [8]:
class MyLayer(tf.keras.layers.Layer):
    def __init__(self):
        super().__init__()
        # 初始化代码

    def build(self, input_shape):     # input_shape 是一个 TensorShape 类型对象，提供输入的形状
        # 在第一次使用该层的时候调用该部分代码，在这里创建变量可以使得变量的形状自适应输入的形状
        # 而不需要使用者额外指定变量形状。
        # 如果已经可以完全确定变量的形状，也可以在__init__部分创建变量
        self.variable_0 = self.add_weight(...)
        self.variable_1 = self.add_weight(...)

    def call(self, inputs):
        # 模型调用的代码（处理输入并返回输出）
        return output

例如，如果我们要自己实现一个 本章第一节 中的全连接层（ tf.keras.layers.Dense ），可以按如下方式编写。此代码在 build 方法中创建两个变量，并在 call 方法中使用创建的变量进行运算：

In [7]:
class LinearLayer(tf.keras.layers.Layer):
    def __init__(self, units):
        super().__init__()
        self.units = units

    def build(self, input_shape):     # 这里 input_shape 是第一次运行call()时参数inputs的形状
        self.w = self.add_weight(name='w',
            shape=[input_shape[-1], self.units], initializer=tf.zeros_initializer())
        self.b = self.add_weight(name='b',
            shape=[self.units], initializer=tf.zeros_initializer())

    def call(self, inputs):
        y_pred = tf.matmul(inputs, self.w) + self.b
        return y_pred

在定义模型的时候，我们便可以如同 Keras 中的其他层一样，调用我们自定义的层 LinearLayer：

In [9]:
class LinearModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.layer = LinearLayer(units=1)  # 定义

    def call(self, inputs):
        output = self.layer(inputs)
        return output

### 2.自定义损失函数和评估指标 

自定义损失函数需要继承 tf.keras.losses.Loss 类，重写 call 方法即可，输入真实值 y_true 和模型预测值 y_pred ，输出模型预测值和真实值之间通过自定义的损失函数计算出的损失值。下面的示例为均方差损失函数：

In [144]:
# 可以在进行调试
tf.config.experimental_run_functions_eagerly(True)

class MeanSquaredError(tf.keras.losses.Loss):
    def call(self, y_true, y_pred):
        y_pred = tf.argmax(y_pred, axis=-1, output_type=tf.int32)
        y_pred = tf.one_hot(y_pred, depth=10, dtype=tf.int32) # 转变成one_hot
        print(y_pred)
        return tf.reduce_mean(tf.square(y_pred - y_true))

自定义评估指标需要继承 tf.keras.metrics.Metric 类，并重写 __init__ 、 update_state 和 result 三个方法。下面的示例对前面用到的 SparseCategoricalAccuracy 评估指标类做了一个简单的重实现：

In [147]:
class SparseCategoricalAccuracy(tf.keras.metrics.Metric):
    def __init__(self):
        super().__init__()
        self.total = self.add_weight(name='total', dtype=tf.int32, initializer=tf.zeros_initializer())
        self.count = self.add_weight(name='count', dtype=tf.int32, initializer=tf.zeros_initializer())

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_ture = tf.argmax(y_true, axis=-1, output_type=tf.int32)
        values = tf.cast(tf.equal(y_true, tf.argmax(y_pred, axis=-1, output_type=tf.int32)), tf.int32)
        self.total.assign_add(tf.shape(y_true)[0])
        self.count.assign_add(tf.reduce_sum(values))

    def result(self):
        return self.count / self.total

In [148]:
# 可以在
tf.config.experimental_run_functions_eagerly(True)

# # 使用第二个方法
# inputs = tf.keras.Input(shape=(784, 1))
# x = tf.keras.layers.Flatten()(inputs)
# x = tf.keras.layers.Dense(units=100, activation=tf.nn.relu)(x)
# x = tf.keras.layers.Dense(units=10)(x)
# outputs = tf.keras.layers.Softmax()(x)
# model = tf.keras.Model(inputs=inputs, outputs=outputs)


# 使用第三个方法，使用Sequential定义一个简单的网络结构
model = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(100, activation=tf.nn.relu),
        tf.keras.layers.Dense(10),
        tf.keras.layers.Softmax()
    ])

# 通过 tf.keras.Model 的 compile 方法配置训练过程
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss=MeanSquaredError(),
    metrics=[SparseCategoricalAccuracy()]
)

model.fit(data_loader.train_data, data_loader.train_label, epochs=num_epochs, batch_size=batch_size)

Train on 55000 samples
Epoch 1/5
tf.Tensor(0, shape=(), dtype=int32)
   50/55000 [..............................] - ETA: 15s

NotFoundError: Could not find valid device for node.
Node:{{node DivNoNan}}
All kernels registered for op DivNoNan :
  device='CPU'; T in [DT_HALF]
  device='CPU'; T in [DT_FLOAT]
  device='CPU'; T in [DT_DOUBLE]
  device='CPU'; T in [DT_COMPLEX64]
  device='CPU'; T in [DT_COMPLEX128]
 [Op:DivNoNan] name: loss/output_1_loss/value/