In [2]:
import tensorflow as tf
from tensorflow_core.examples.tutorials.mnist import input_data
import numpy as np
class MNISTLoader_my_download():
    def __init__(self):
        # 读取数据，预先已经下载了相应的数据直
        mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
        self.train_data = mnist.train.images
        self.train_label = mnist.train.labels
        self.test_data = mnist.test.images
        self.test_label = mnist.test.labels
        
        # MNIST中的图像默认为uint8（0-255的数字）。以下代码将其归一化到0-1之间的浮点数，并在最后增加一维作为颜色通道
        self.train_data = np.expand_dims(self.train_data.astype(np.float32) / 255.0, axis=-1)      # [60000, 784, 1]
        self.test_data = np.expand_dims(self.test_data.astype(np.float32) / 255.0, axis=-1)        # [10000, 784, 1]
        self.train_label = self.train_label.astype(np.int32)    # [60000]
        self.test_label = self.test_label.astype(np.int32)      # [10000]
        self.num_train_data, self.num_test_data = self.train_data.shape[0], self.test_data.shape[0]

    def get_batch(self, batch_size):
        # 从数据集中随机取出batch_size个元素并返回
        index = np.random.randint(0, self.num_train_data, batch_size)
        return self.train_data[index, :], self.train_label[index]

In [6]:
class MLP(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.flatten = tf.keras.layers.Flatten()    # Flatten层将除第一维（batch_size）以外的维度展平
        self.dense1 = tf.keras.layers.Dense(units=100, activation=tf.nn.relu)  # 第一层神经元的个数为100
        self.dense2 = tf.keras.layers.Dense(units=10)   # 第二层神经元的个数为10,输出一个样本的维度为10

    def call(self, inputs):         # [batch_size, 28, 28, 1]
        x = self.flatten(inputs)    # [batch_size, 784]
        x = self.dense1(x)          # [batch_size, 100]
        x = self.dense2(x)          # [batch_size, 10]
        output = tf.nn.softmax(x)
        return output

In [7]:
# 定义一些模型超参数：
num_epochs = 5
batch_size = 50
learning_rate = 0.001

# 实例化模型和数据读取类，并实例化一个 tf.keras.optimizer 的优化器（这里使用常用的 Adam 优化器）：
model = MLP()
# data_loader = MNISTLoader() # 导入数据 
data_loader = MNISTLoader_my_download()  # 导入数据
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)  # 更新梯度

# num_batches = int(mnist.train.num_examples // batch_size * num_epochs)
num_batches = int(data_loader.num_train_data // batch_size * num_epochs)
for batch_index in range(num_batches):
    X, y = data_loader.get_batch(batch_size)
    with tf.GradientTape() as tape:
        y_pred = model(X)
        loss = tf.keras.losses.categorical_crossentropy(y_true=y, y_pred=y_pred)
        loss = tf.reduce_mean(loss)
        print("batch %d: loss %f" % (batch_index, loss.numpy()))
    grads = tape.gradient(loss, model.variables)
    optimizer.apply_gradients(grads_and_vars=zip(grads, model.variables))

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
batch 0: loss 2.302219
batch 1: loss 2.301422
batch 2: loss 2.300723
batch 3: loss 2.300955
batch 4: loss 2.300592
batch 5: loss 2.302301
batch 6: loss 2.302447
batch 7: loss 2.302885
batch 8: loss 2.299646
batch 9: loss 2.299504
batch 10: loss 2.298020
batch 11: loss 2.297356
batch 12: loss 2.299600
batch 13: loss 2.298584
batch 14: loss 2.299991
batch 15: loss 2.295949
batch 16: loss 2.295848
batch 17: loss 2.296138
batch 18: loss 2.293093
batch 19: loss 2.293828
batch 20: loss 2.292792
batch 21: loss 2.292058
batch 22: loss 2.296293
batch 23: loss 2.291264
batch 24: loss 2.289793
batch 25: loss 2.290593
batch 26: loss 2.292573
batch 27: loss 2.288534
batch 28: loss 2.293505
batch 29: loss 2.297284
batch 30: loss 2.287549
batch 31: loss 2.286821
batch 32: loss 2.287719
batch 33: loss 2.291282


batch 329: loss 1.725568
batch 330: loss 1.765414
batch 331: loss 1.690040
batch 332: loss 1.720434
batch 333: loss 1.753545
batch 334: loss 1.775700
batch 335: loss 1.823447
batch 336: loss 1.836362
batch 337: loss 1.819796
batch 338: loss 1.753536
batch 339: loss 1.724267
batch 340: loss 1.748280
batch 341: loss 1.644292
batch 342: loss 1.805654
batch 343: loss 1.616218
batch 344: loss 1.670159
batch 345: loss 1.778253
batch 346: loss 1.597938
batch 347: loss 1.742141
batch 348: loss 1.674846
batch 349: loss 1.682862
batch 350: loss 1.679306
batch 351: loss 1.667018
batch 352: loss 1.700643
batch 353: loss 1.761148
batch 354: loss 1.779939
batch 355: loss 1.671694
batch 356: loss 1.796482
batch 357: loss 1.716781
batch 358: loss 1.702067
batch 359: loss 1.693311
batch 360: loss 1.661478
batch 361: loss 1.632875
batch 362: loss 1.617589
batch 363: loss 1.667943
batch 364: loss 1.717641
batch 365: loss 1.699959
batch 366: loss 1.684882
batch 367: loss 1.706889
batch 368: loss 1.762512


batch 990: loss 0.914851
batch 991: loss 0.983410
batch 992: loss 0.879827
batch 993: loss 0.800982
batch 994: loss 0.973723
batch 995: loss 0.673251
batch 996: loss 0.888513
batch 997: loss 0.902901
batch 998: loss 0.872275
batch 999: loss 0.884556
batch 1000: loss 0.963140
batch 1001: loss 0.789567
batch 1002: loss 0.957051
batch 1003: loss 0.822445
batch 1004: loss 0.836480
batch 1005: loss 0.647777
batch 1006: loss 0.947477
batch 1007: loss 0.857541
batch 1008: loss 0.953716
batch 1009: loss 0.735502
batch 1010: loss 0.714304
batch 1011: loss 0.943559
batch 1012: loss 0.840628
batch 1013: loss 0.768259
batch 1014: loss 0.748090
batch 1015: loss 0.703898
batch 1016: loss 0.818043
batch 1017: loss 0.637366
batch 1018: loss 0.722528
batch 1019: loss 0.783805
batch 1020: loss 0.890520
batch 1021: loss 0.855393
batch 1022: loss 0.693065
batch 1023: loss 0.919418
batch 1024: loss 0.926174
batch 1025: loss 0.799007
batch 1026: loss 0.817318
batch 1027: loss 0.721477
batch 1028: loss 0.961

batch 1316: loss 0.658525
batch 1317: loss 0.671394
batch 1318: loss 0.727253
batch 1319: loss 0.665925
batch 1320: loss 0.646469
batch 1321: loss 0.615763
batch 1322: loss 0.648285
batch 1323: loss 0.601521
batch 1324: loss 0.690069
batch 1325: loss 0.808954
batch 1326: loss 0.760732
batch 1327: loss 0.638628
batch 1328: loss 0.727024
batch 1329: loss 0.780128
batch 1330: loss 0.720718
batch 1331: loss 0.661415
batch 1332: loss 0.636843
batch 1333: loss 0.686686
batch 1334: loss 0.747302
batch 1335: loss 0.648433
batch 1336: loss 1.074362
batch 1337: loss 0.621543
batch 1338: loss 0.696614
batch 1339: loss 0.942074
batch 1340: loss 0.523304
batch 1341: loss 0.685202
batch 1342: loss 0.753509
batch 1343: loss 0.753131
batch 1344: loss 0.859449
batch 1345: loss 0.644387
batch 1346: loss 0.842469
batch 1347: loss 0.614612
batch 1348: loss 0.823629
batch 1349: loss 0.638849
batch 1350: loss 0.565703
batch 1351: loss 0.774390
batch 1352: loss 0.759620
batch 1353: loss 0.635825
batch 1354: 

batch 1647: loss 0.575888
batch 1648: loss 0.566537
batch 1649: loss 0.852583
batch 1650: loss 0.509603
batch 1651: loss 0.515211
batch 1652: loss 0.616841
batch 1653: loss 0.808108
batch 1654: loss 0.688520
batch 1655: loss 0.421160
batch 1656: loss 0.564218
batch 1657: loss 0.505799
batch 1658: loss 0.482542
batch 1659: loss 0.615862
batch 1660: loss 0.582146
batch 1661: loss 0.726576
batch 1662: loss 0.584119
batch 1663: loss 0.650148
batch 1664: loss 0.838927
batch 1665: loss 0.644927
batch 1666: loss 0.740786
batch 1667: loss 0.589867
batch 1668: loss 0.644680
batch 1669: loss 0.757301
batch 1670: loss 0.507625
batch 1671: loss 0.545814
batch 1672: loss 0.590056
batch 1673: loss 0.498751
batch 1674: loss 0.653624
batch 1675: loss 0.527967
batch 1676: loss 0.542075
batch 1677: loss 0.592759
batch 1678: loss 0.513007
batch 1679: loss 0.735367
batch 1680: loss 0.599252
batch 1681: loss 0.549289
batch 1682: loss 0.762347
batch 1683: loss 0.750583
batch 1684: loss 0.776802
batch 1685: 

batch 1979: loss 0.749557
batch 1980: loss 0.490687
batch 1981: loss 0.510103
batch 1982: loss 0.536599
batch 1983: loss 0.469121
batch 1984: loss 0.497206
batch 1985: loss 0.472424
batch 1986: loss 0.389275
batch 1987: loss 0.501917
batch 1988: loss 0.352632
batch 1989: loss 0.564049
batch 1990: loss 0.527379
batch 1991: loss 0.500119
batch 1992: loss 0.480271
batch 1993: loss 0.627980
batch 1994: loss 0.480041
batch 1995: loss 0.565999
batch 1996: loss 0.641105
batch 1997: loss 0.508551
batch 1998: loss 0.481240
batch 1999: loss 0.481201
batch 2000: loss 0.495671
batch 2001: loss 0.544312
batch 2002: loss 0.345390
batch 2003: loss 0.375512
batch 2004: loss 0.654940
batch 2005: loss 0.370940
batch 2006: loss 0.592065
batch 2007: loss 0.617299
batch 2008: loss 0.435508
batch 2009: loss 0.593200
batch 2010: loss 0.597404
batch 2011: loss 0.689003
batch 2012: loss 0.523281
batch 2013: loss 0.481638
batch 2014: loss 0.456115
batch 2015: loss 0.491198
batch 2016: loss 0.747737
batch 2017: 

batch 2312: loss 0.423153
batch 2313: loss 0.530026
batch 2314: loss 0.382237
batch 2315: loss 0.684850
batch 2316: loss 0.473524
batch 2317: loss 0.558720
batch 2318: loss 0.479632
batch 2319: loss 0.659803
batch 2320: loss 0.378776
batch 2321: loss 0.606426
batch 2322: loss 0.619408
batch 2323: loss 0.417759
batch 2324: loss 0.473819
batch 2325: loss 0.397069
batch 2326: loss 0.589830
batch 2327: loss 0.643786
batch 2328: loss 0.563173
batch 2329: loss 0.546069
batch 2330: loss 0.481209
batch 2331: loss 0.673139
batch 2332: loss 0.606036
batch 2333: loss 0.752144
batch 2334: loss 0.355222
batch 2335: loss 0.565494
batch 2336: loss 0.614276
batch 2337: loss 0.547794
batch 2338: loss 0.742819
batch 2339: loss 0.493450
batch 2340: loss 0.574988
batch 2341: loss 0.631656
batch 2342: loss 0.440691
batch 2343: loss 0.480379
batch 2344: loss 0.423455
batch 2345: loss 0.469851
batch 2346: loss 0.498084
batch 2347: loss 0.478952
batch 2348: loss 0.389382
batch 2349: loss 0.474937
batch 2350: 

batch 2639: loss 0.482717
batch 2640: loss 0.432695
batch 2641: loss 0.662827
batch 2642: loss 0.445412
batch 2643: loss 0.460272
batch 2644: loss 0.402905
batch 2645: loss 0.453567
batch 2646: loss 0.492583
batch 2647: loss 0.385350
batch 2648: loss 0.344275
batch 2649: loss 0.607630
batch 2650: loss 0.453704
batch 2651: loss 0.624153
batch 2652: loss 0.363900
batch 2653: loss 0.473910
batch 2654: loss 0.540908
batch 2655: loss 0.385924
batch 2656: loss 0.339264
batch 2657: loss 0.582204
batch 2658: loss 0.520542
batch 2659: loss 0.475700
batch 2660: loss 0.363865
batch 2661: loss 0.326507
batch 2662: loss 0.605712
batch 2663: loss 0.641088
batch 2664: loss 0.336313
batch 2665: loss 0.368913
batch 2666: loss 0.367411
batch 2667: loss 0.349145
batch 2668: loss 0.608670
batch 2669: loss 0.602477
batch 2670: loss 0.523100
batch 2671: loss 0.538784
batch 2672: loss 0.556191
batch 2673: loss 0.363321
batch 2674: loss 0.586998
batch 2675: loss 0.480675
batch 2676: loss 0.377284
batch 2677: 

batch 2970: loss 0.281784
batch 2971: loss 0.430651
batch 2972: loss 0.289647
batch 2973: loss 0.555661
batch 2974: loss 0.444470
batch 2975: loss 0.505618
batch 2976: loss 0.469611
batch 2977: loss 0.296204
batch 2978: loss 0.528664
batch 2979: loss 0.778321
batch 2980: loss 0.247367
batch 2981: loss 0.500571
batch 2982: loss 0.375129
batch 2983: loss 0.606787
batch 2984: loss 0.396143
batch 2985: loss 0.527923
batch 2986: loss 0.674322
batch 2987: loss 0.410016
batch 2988: loss 0.411899
batch 2989: loss 0.381646
batch 2990: loss 0.416068
batch 2991: loss 0.464380
batch 2992: loss 0.508538
batch 2993: loss 0.320713
batch 2994: loss 0.477883
batch 2995: loss 0.374011
batch 2996: loss 0.227590
batch 2997: loss 0.390063
batch 2998: loss 0.588759
batch 2999: loss 0.289316
batch 3000: loss 0.576974
batch 3001: loss 0.273091
batch 3002: loss 0.454285
batch 3003: loss 0.401500
batch 3004: loss 0.439024
batch 3005: loss 0.573588
batch 3006: loss 0.375842
batch 3007: loss 0.355414
batch 3008: 

batch 3305: loss 0.412538
batch 3306: loss 0.338086
batch 3307: loss 0.547395
batch 3308: loss 0.590162
batch 3309: loss 0.244342
batch 3310: loss 0.431840
batch 3311: loss 0.363955
batch 3312: loss 0.322640
batch 3313: loss 0.554611
batch 3314: loss 0.369382
batch 3315: loss 0.314319
batch 3316: loss 0.719661
batch 3317: loss 0.459410
batch 3318: loss 0.391389
batch 3319: loss 0.718531
batch 3320: loss 0.455417
batch 3321: loss 0.310999
batch 3322: loss 0.337593
batch 3323: loss 0.348585
batch 3324: loss 0.376282
batch 3325: loss 0.519015
batch 3326: loss 0.487672
batch 3327: loss 0.399542
batch 3328: loss 0.419411
batch 3329: loss 0.651986
batch 3330: loss 0.527283
batch 3331: loss 0.459104
batch 3332: loss 0.519444
batch 3333: loss 0.416387
batch 3334: loss 0.370711
batch 3335: loss 0.444055
batch 3336: loss 0.382469
batch 3337: loss 0.427856
batch 3338: loss 0.558194
batch 3339: loss 0.343968
batch 3340: loss 0.273463
batch 3341: loss 0.523900
batch 3342: loss 0.317139
batch 3343: 

batch 3634: loss 0.315041
batch 3635: loss 0.428744
batch 3636: loss 0.295056
batch 3637: loss 0.482995
batch 3638: loss 0.381560
batch 3639: loss 0.285773
batch 3640: loss 0.346374
batch 3641: loss 0.491439
batch 3642: loss 0.436141
batch 3643: loss 0.218787
batch 3644: loss 0.419180
batch 3645: loss 0.231663
batch 3646: loss 0.280087
batch 3647: loss 0.340209
batch 3648: loss 0.545818
batch 3649: loss 0.375387
batch 3650: loss 0.461132
batch 3651: loss 0.397539
batch 3652: loss 0.349121
batch 3653: loss 0.411919
batch 3654: loss 0.282165
batch 3655: loss 0.453682
batch 3656: loss 0.347188
batch 3657: loss 0.429344
batch 3658: loss 0.401562
batch 3659: loss 0.348232
batch 3660: loss 0.539543
batch 3661: loss 0.343375
batch 3662: loss 0.325386
batch 3663: loss 0.406274
batch 3664: loss 0.471190
batch 3665: loss 0.330080
batch 3666: loss 0.382547
batch 3667: loss 0.446478
batch 3668: loss 0.474614
batch 3669: loss 0.315799
batch 3670: loss 0.452830
batch 3671: loss 0.383004
batch 3672: 

batch 3964: loss 0.386218
batch 3965: loss 0.424322
batch 3966: loss 0.524059
batch 3967: loss 0.302708
batch 3968: loss 0.487854
batch 3969: loss 0.327433
batch 3970: loss 0.409593
batch 3971: loss 0.365979
batch 3972: loss 0.287426
batch 3973: loss 0.318753
batch 3974: loss 0.369342
batch 3975: loss 0.271539
batch 3976: loss 0.275945
batch 3977: loss 0.542150
batch 3978: loss 0.344318
batch 3979: loss 0.405079
batch 3980: loss 0.442062
batch 3981: loss 0.195965
batch 3982: loss 0.273658
batch 3983: loss 0.333468
batch 3984: loss 0.428682
batch 3985: loss 0.267171
batch 3986: loss 0.329988
batch 3987: loss 0.451824
batch 3988: loss 0.270979
batch 3989: loss 0.297140
batch 3990: loss 0.427469
batch 3991: loss 0.239448
batch 3992: loss 0.450628
batch 3993: loss 0.306219
batch 3994: loss 0.351463
batch 3995: loss 0.607751
batch 3996: loss 0.653694
batch 3997: loss 0.470637
batch 3998: loss 0.544904
batch 3999: loss 0.193482
batch 4000: loss 0.449796
batch 4001: loss 0.339613
batch 4002: 

batch 4291: loss 0.309474
batch 4292: loss 0.378057
batch 4293: loss 0.307088
batch 4294: loss 0.489785
batch 4295: loss 0.278384
batch 4296: loss 0.410208
batch 4297: loss 0.549360
batch 4298: loss 0.522572
batch 4299: loss 0.419292
batch 4300: loss 0.402801
batch 4301: loss 0.380424
batch 4302: loss 0.410837
batch 4303: loss 0.414971
batch 4304: loss 0.431970
batch 4305: loss 0.405408
batch 4306: loss 0.569391
batch 4307: loss 0.435038
batch 4308: loss 0.413860
batch 4309: loss 0.271008
batch 4310: loss 0.358764
batch 4311: loss 0.293947
batch 4312: loss 0.270990
batch 4313: loss 0.357443
batch 4314: loss 0.492562
batch 4315: loss 0.389309
batch 4316: loss 0.443402
batch 4317: loss 0.492626
batch 4318: loss 0.229525
batch 4319: loss 0.321333
batch 4320: loss 0.435126
batch 4321: loss 0.402847
batch 4322: loss 0.315383
batch 4323: loss 0.422726
batch 4324: loss 0.482013
batch 4325: loss 0.485354
batch 4326: loss 0.609487
batch 4327: loss 0.231309
batch 4328: loss 0.459526
batch 4329: 

batch 4621: loss 0.299835
batch 4622: loss 0.375484
batch 4623: loss 0.400967
batch 4624: loss 0.214634
batch 4625: loss 0.369022
batch 4626: loss 0.286843
batch 4627: loss 0.309143
batch 4628: loss 0.403598
batch 4629: loss 0.394657
batch 4630: loss 0.416860
batch 4631: loss 0.372907
batch 4632: loss 0.330731
batch 4633: loss 0.276249
batch 4634: loss 0.380058
batch 4635: loss 0.351811
batch 4636: loss 0.428431
batch 4637: loss 0.386466
batch 4638: loss 0.351659
batch 4639: loss 0.354523
batch 4640: loss 0.494377
batch 4641: loss 0.450769
batch 4642: loss 0.323137
batch 4643: loss 0.624979
batch 4644: loss 0.211322
batch 4645: loss 0.321506
batch 4646: loss 0.408596
batch 4647: loss 0.310691
batch 4648: loss 0.840021
batch 4649: loss 0.408998
batch 4650: loss 0.119663
batch 4651: loss 0.380339
batch 4652: loss 0.636624
batch 4653: loss 0.316965
batch 4654: loss 0.327380
batch 4655: loss 0.158354
batch 4656: loss 0.347643
batch 4657: loss 0.230026
batch 4658: loss 0.515220
batch 4659: 

batch 4954: loss 0.426022
batch 4955: loss 0.357043
batch 4956: loss 0.239253
batch 4957: loss 0.280199
batch 4958: loss 0.317207
batch 4959: loss 0.357459
batch 4960: loss 0.362327
batch 4961: loss 0.250988
batch 4962: loss 0.368882
batch 4963: loss 0.283142
batch 4964: loss 0.278075
batch 4965: loss 0.374975
batch 4966: loss 0.551075
batch 4967: loss 0.361621
batch 4968: loss 0.231932
batch 4969: loss 0.337216
batch 4970: loss 0.367625
batch 4971: loss 0.321157
batch 4972: loss 0.301117
batch 4973: loss 0.429934
batch 4974: loss 0.533546
batch 4975: loss 0.296307
batch 4976: loss 0.490444
batch 4977: loss 0.520848
batch 4978: loss 0.490120
batch 4979: loss 0.293328
batch 4980: loss 0.401180
batch 4981: loss 0.293670
batch 4982: loss 0.419897
batch 4983: loss 0.373288
batch 4984: loss 0.261106
batch 4985: loss 0.257050
batch 4986: loss 0.300601
batch 4987: loss 0.379706
batch 4988: loss 0.691899
batch 4989: loss 0.576500
batch 4990: loss 0.612205
batch 4991: loss 0.441731
batch 4992: 

batch 5281: loss 0.464448
batch 5282: loss 0.333797
batch 5283: loss 0.266728
batch 5284: loss 0.331454
batch 5285: loss 0.238499
batch 5286: loss 0.530285
batch 5287: loss 0.281276
batch 5288: loss 0.267661
batch 5289: loss 0.438047
batch 5290: loss 0.180621
batch 5291: loss 0.508271
batch 5292: loss 0.307319
batch 5293: loss 0.287689
batch 5294: loss 0.428738
batch 5295: loss 0.293536
batch 5296: loss 0.400318
batch 5297: loss 0.267971
batch 5298: loss 0.378808
batch 5299: loss 0.297762
batch 5300: loss 0.542674
batch 5301: loss 0.308477
batch 5302: loss 0.424807
batch 5303: loss 0.362610
batch 5304: loss 0.483146
batch 5305: loss 0.223527
batch 5306: loss 0.331224
batch 5307: loss 0.488449
batch 5308: loss 0.244557
batch 5309: loss 0.408495
batch 5310: loss 0.258482
batch 5311: loss 0.247604
batch 5312: loss 0.608301
batch 5313: loss 0.395165
batch 5314: loss 0.377494
batch 5315: loss 0.318112
batch 5316: loss 0.447782
batch 5317: loss 0.348895
batch 5318: loss 0.277059
batch 5319: 

In [3]:
categorical_accuracy = tf.keras.metrics.CategoricalAccuracy()
num_batches = int(data_loader.num_test_data // batch_size)
for batch_index in range(num_batches):
    start_index, end_index = batch_index * batch_size, (batch_index + 1) * batch_size
    y_pred = model.predict(data_loader.test_data[start_index: end_index])
    categorical_accuracy.update_state(y_true=data_loader.test_label[start_index: end_index], y_pred=y_pred)
print("test accuracy: %f" % categorical_accuracy.result())

NameError: name 'data_loader' is not defined