In [59]:
import os
import time
import numpy as np
import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.layers as layers

# 1. Basic training loops
训练神经网络的一般步骤为 (1)获取数据集，(2)搭建模型、定义损失函数和度量指标，(3)将数据送入网络、计算损失和度量，(4)根据损失反向传播计算梯度、更新参数；下面展示如何利用基础类通过简单的线性回归得到直线的斜率和截距：

```python
# Ground-truth slope/intercept and the number of synthetic samples.
TRUE_W = 3.0
TRUE_B = 2.0
NUM_EXAMPLES = 1000
# Inputs are standard-normal draws; targets lie on the true line plus
# standard-normal noise.
x = tf.random.normal(shape=[NUM_EXAMPLES])
y = TRUE_B + TRUE_W * x + tf.random.normal(shape=[NUM_EXAMPLES])

class LinearRegresModel(tf.Module):
    """Linear model ``w * x + b`` with one trainable slope and intercept."""

    def __init__(self, init_fn, **kwargs):
        """Create ``w`` and ``b`` from ``init_fn(shape, dtype)``.

        Extra keyword arguments are forwarded to ``tf.Module``.
        """
        super().__init__(**kwargs)
        param_shape, param_dtype = [1], "float32"
        self.w = tf.Variable(init_fn(param_shape, param_dtype))
        self.b = tf.Variable(init_fn(param_shape, param_dtype))

    def __call__(self, x):
        """Return the model prediction for input ``x``."""
        scaled = self.w * x
        return scaled + self.b

def loss_fn(y, y_pred):
    """Return the mean squared error between ``y`` and ``y_pred``."""
    squared_error = tf.square(y - y_pred)
    return tf.reduce_mean(squared_error)

def train(model, x, y, lr, epochs):
    """Fit ``model`` by full-batch gradient descent and log every epoch.

    Args:
        model: object exposing variables ``w`` and ``b`` and callable on ``x``.
        x: input samples.
        y: target values.
        lr: learning rate applied to each gradient.
        epochs: number of full passes over ``(x, y)``.
    """
    params = [model.w, model.b]
    for epoch in range(epochs):
        with tf.GradientTape() as tape:
            loss = loss_fn(y, model(x))
        grads = tape.gradient(loss, params)
        # Plain SGD step: param <- param - lr * grad.
        for param, grad in zip(params, grads):
            param.assign_sub(lr * grad)
        # The printed loss was computed before this epoch's update.
        print("Epoch %2d/%d: W=%1.2f b=%1.2f, loss=%2.5f" %
          (epoch, epochs, model.w, model.b, loss))

# Initialise the parameters from a random-normal initializer, then run 16
# full-batch gradient-descent epochs.
init_fn = tf.random_normal_initializer()
model = LinearRegresModel(init_fn=init_fn)
train(model, x, y, epochs=16, lr=0.1)
```

利用 keras 实现该模型，并利用内置的`compile()`设置参数，利用`fit()`方法对网络进行训练：
```python
class KerasModel(tf.keras.Model):
    """Keras version of the linear model ``w * x + b``."""

    def __init__(self, init_fn, **kwargs):
        """Create ``w`` and ``b`` from ``init_fn(shape, dtype)``.

        Extra keyword arguments are forwarded to ``tf.keras.Model``.
        """
        super().__init__(**kwargs)
        self.w = tf.Variable(init_fn([1,], "float32"))
        self.b = tf.Variable(init_fn([1,], "float32"))

    # Keras models implement the forward pass in `call`; overriding
    # `__call__` bypasses the framework's own `__call__` wrapper.
    # `model(x)` still works because the base `__call__` dispatches here.
    def call(self, x, **kwargs):
        """Return the prediction for ``x``; extra kwargs are ignored."""
        return self.w * x + self.b

init_fn = tf.random_normal_initializer()
# `KerasModel.__init__` requires the initializer; calling it with no
# arguments raises a TypeError.
keras_model = KerasModel(init_fn)
keras_model.compile(
    run_eagerly=False,
    optimizer=keras.optimizers.SGD(learning_rate=0.1),
    loss=keras.losses.mean_squared_error,
)
# batch_size equals the data set size, so each epoch is one full-batch step.
keras_model.fit(x, y, epochs=16, batch_size=1000)
```
以上仅是一个极其简单的问题，有关更实用的介绍，请参见[自定义训练教程](https://www.tensorflow.org/tutorials/customization/custom_training_walkthrough)；更多关于 keras 内置训练函数的信息参见下面第 2 小节；如何人工实现训练循环请参见[相关指导](https://www.tensorflow.org/guide/keras/writing_a_training_loop_from_scratch)；有关自定义分布式训练的方法，请参见[相关指导](https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_custom_training_loops)

# 

# 

# 2. Using built-in methods

本节不包含分布式训练，如需要了解分布式训练相关内容，请参阅[相关指导](https://www.tensorflow.org/guide/distributed_training)；

本节基于 MNIST 数据集演示利用 keras 内置函数实现模型的训练、验证和评估的过程，数据准备和模型定义如下所示：
```python
# Load MNIST, flatten each image to a 784-vector and scale pixels to [0, 1].
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 784).astype(np.float32) / 255.0
x_test = x_test.reshape(10000, 784).astype(np.float32) / 255.0
y_train, y_test = y_train.astype(np.float32), y_test.astype(np.float32)

# Hold out the last 10,000 training samples as the validation split.
x_train, x_val = x_train[:-10000], x_train[-10000:]
y_train, y_val = y_train[:-10000], y_train[-10000:]

def mnist_model():
    """Build an uncompiled 784 -> 64 -> 64 -> 10 softmax classifier."""
    digits = keras.Input(shape=(784,), name="digits")
    hidden = layers.Dense(64, activation="relu", name="dense1")(digits)
    hidden = layers.Dense(64, activation="relu", name="dense2")(hidden)
    probs = layers.Dense(10, activation="softmax", name="dense3")(hidden)
    return keras.Model(inputs=digits, outputs=probs)

def mnist_model_compiled():
    """Build the MNIST classifier and compile it for integer labels."""
    compiled = mnist_model()
    compile_options = {
        "optimizer": "rmsprop",
        "loss": "sparse_categorical_crossentropy",
        "metrics": ["sparse_categorical_accuracy"],
    }
    compiled.compile(**compile_options)
    return compiled

# One uncompiled instance (compiled by later examples) and one that is
# ready for training.
model = mnist_model()
model_compiled = mnist_model_compiled()
```

给定一个`keras.Model`模型实例，可以利用内置的`.compile()`方法定义优化算法、损失函数、度量值，利用`.fit()`方法对模型进行训练，利用`.evaluate()`方法对模型进行评估 (即测试)；简单的示例如上面代码`mnist_model_compiled`函数所示；




## 2.1 The `compile()` method
`compile()`方法通过接收`optimizer`、`loss`、`metrics`参数，以将模型与优化器 (即优化算法)、损失函数、度量值进行包装；其中`optimizer`可以是表示优化器的字符串，或`keras.optimizers`模块中的优化器实例；`loss`可以是表示目标函数的字符串，或自定义的目标函数，或`keras.losses.Loss`实例；`metrics`应为一个列表，其每个元素可以是表示度量的字符串，或自定义函数，或`keras.metrics.Metric`实例；



### 2.1.1 Custom losses
实现自定义损失的方法有两种——为`loss`参数传递一个自定义函数，或传递一个继承了`keras.losses.Loss`类的实例



- 自定义的损失函数应满足范式`fn(y_true, y_pred)`，示例如下：

    ```python
    def custom_mse(y_true, y_pred):
        """Mean squared error following the ``fn(y_true, y_pred)`` protocol."""
        squared_error = tf.square(y_true - y_pred)
        return tf.math.reduce_mean(squared_error)

    # A plain function can be passed directly as the loss.
    model.compile(loss=custom_mse, optimizer=keras.optimizers.Adam())
    ```


- 若想要定义一个接收`y_true`和`y_pred`之外的参数的损失函数，则需要继承`keras.losses.Loss`类，同时对其`__init__(self, ...)`方法和`call(self, y_true, y_pred)`方法进行实现；下面假设我们希望在 MSE 损失中添加一个损失项，该损失项会惩罚远离 0.5 的预测值，由于使用 MSE 时分类目标总是 one-hot 编码，于是不难看出，这个附加项会使得模型对结果预判不过于绝对，进而有助于减少过拟合，其实现方法如下：

    ```python
    class CustomMSE(keras.losses.Loss):
        """MSE plus a penalty on predictions that are far from 0.5."""

        def __init__(self, reg_factor=0.1, name="custom_mse"):
            """Store ``reg_factor``, the weight of the penalty term."""
            super().__init__(name=name)
            self.reg_factor = reg_factor

        def call(self, y_true, y_pred):
            """Return ``mse + reg_factor * mean((0.5 - y_pred) ** 2)``."""
            squared_error = tf.square(y_true - y_pred)
            squared_offset = tf.square(0.5 - y_pred)
            penalty = self.reg_factor * tf.math.reduce_mean(squared_offset)
            return tf.math.reduce_mean(squared_error) + penalty

    model.compile(loss=CustomMSE(), optimizer=keras.optimizers.Adam())
    ```



### 2.1.2 Custom metrics
自定义度量标准则需要继承`keras.metrics.Metric`类，同时实现以下方法
- `__init__(self, custom_args)`：用于定义所需参数
- `update_state(self, y_true, y_pred, sample_weight=None)`：利用`y_true`、`y_pred`对状态变量进行更新；
- `result(self)`：利用状态变量得到最终度量结果；
- `reset_states(self)`：重新初始化

考虑到有时对度量结果的计算需要很大的计算量，并且只能定期进行，进而通常将对度量状态的更新和得到最终度量结果的方法分开；下面以“被正确分类的样本个数”为度量，演示添加自定义度量示例：

```python
class CateTruePositive(keras.metrics.Metric):
    """Metric that counts correctly classified samples.

    The running count lives in one scalar weight, so it accumulates over
    batches until ``reset_states`` is called.
    """

    def __init__(self, name="categorical_true_positive", **kwargs):
        super(CateTruePositive, self).__init__(name=name, **kwargs)
        # Must be named `true_positives`: every other method reads it
        # under that name.
        self.true_positives = self.add_weight(
            name="ctp", initializer="zeros"
        )

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Add this batch's (optionally weighted) number of correct samples.

        Args:
            y_true: integer class labels.
            y_pred: per-class scores; the arg-max over axis 1 is the
                predicted class.
            sample_weight: optional per-sample weights multiplied into the
                0/1 correctness indicator.
        """
        y_pred = tf.reshape(tf.argmax(y_pred, axis=1), shape=(-1, 1))
        values = tf.cast(y_true, "int32") == tf.cast(y_pred, "int32")
        values = tf.cast(values, "float32")
        if sample_weight is not None:
            sample_weight = tf.cast(sample_weight, "float32")
            values = tf.multiply(values, sample_weight)
        self.true_positives.assign_add(tf.reduce_sum(values))

    def result(self):
        """Return the accumulated count."""
        return self.true_positives

    def reset_states(self):
        """Reset the accumulated count to zero."""
        self.true_positives.assign(0.0)

# Train with the custom metric tracked next to the loss.
model.compile(
    metrics=[CateTruePositive()],
    loss=keras.losses.SparseCategoricalCrossentropy(),
    optimizer=keras.optimizers.RMSprop(learning_rate=0.001),
)
model.fit(x_train, y_train, epochs=3, batch_size=64)
```



### 2.1.3 Handling those don't fit the standard signature
对于无法利用`y_true`和`y_pred`计算的损失和度量，其可以在`call()`方法内调用`add_loss()`和`add_metric()`方法进行添加，这种方式添加的损失会在`fit()`阶段自动被添加至总损失中，模型也会自动对度量进行追踪；下面是添加对激活正则项的损失，以及添加激活标准差的度量的示例：

```python
class ActRegulAndMetric(layers.Layer):
    """Identity layer that registers an activity loss and a metric."""

    def call(self, inputs):
        """Record the loss and metric for ``inputs``, then return it as is."""
        # Activity regularisation: 0.1 times the sum of all activations.
        activity_penalty = 0.1 * tf.reduce_sum(inputs)
        self.add_loss(activity_penalty)
        # Track the standard deviation of the activations, averaged.
        activation_std = keras.backend.std(inputs)
        self.add_metric(
            activation_std, name="std_of_activation", aggregation="mean"
        )
        return inputs

# Same MNIST classifier, with the logging layer after the first dense layer.
inputs = keras.Input(name="digits", shape=(784,))
x = layers.Dense(units=64, activation="relu", name="dense1")(inputs)
x = ActRegulAndMetric()(x)
x = layers.Dense(units=64, activation="relu", name="dense2")(x)
# The last layer emits raw logits, hence `from_logits=True` below.
outputs = layers.Dense(units=10, name="predictions")(x)
model = keras.Model(inputs=inputs, outputs=outputs)
model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.RMSprop(learning_rate=1e-3),
)
model.fit(x_train, y_train, epochs=1, batch_size=64)
```

此外，也可以通过这两个方法对模型的总损失和度量进行定义，进而在模型编码阶段无需再指明`loss`和`metrics`参数；例如下面的示例，其在`LogisticEndpoint`层中定义了交叉熵损失和准确率度量：

```python
class LogisticEndpoint(keras.layers.Layer):
    """Endpoint layer that registers its own loss and accuracy metric."""

    def __init__(self, name=None):
        super().__init__(name=name)
        self.loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)
        self.accuracy_fn = keras.metrics.BinaryAccuracy()

    def call(self, targets, logits, sample_weights=None):
        """Register loss and accuracy for ``(targets, logits)``.

        Returns the softmax of ``logits`` so that ``.predict()`` yields
        predictions.
        """
        self.add_loss(
            self.loss_fn(targets, logits, sample_weight=sample_weights)
        )
        self.add_metric(
            self.accuracy_fn(targets, logits, sample_weight=sample_weights)
        )
        return tf.nn.softmax(logits)

inputs = keras.Input(shape=(3,), name="inputs")
targets = keras.Input(shape=(10,), name="targets")
logits = keras.layers.Dense(10)(inputs)
# `LogisticEndpoint.call` takes (targets, logits). Passing them swapped
# still runs, because both have shape (None, 10), but the loss would treat
# the logits as labels.
predictions = LogisticEndpoint(name="predictions")(targets, logits)
model = keras.Model(inputs=[inputs, targets], outputs=predictions)
model.compile(optimizer="adam")  # No loss argument!
# The targets are fed as a model input, so `fit` gets one dict and no `y`.
data = {
    "inputs": np.random.random((3, 3)),
    "targets": np.random.random((3, 10)),
}
model.fit(data)
```



### 2.1.4 Compile multi-input, multi-output models
本小节以一个同时具有图像输入和时间序列输入、并产生两个输出的模型为例，说明多输入、多输出模型的编译方式。

对于有多个输出的模型，可以通过为`loss`传递一个由损失函数组成的列表，以为每个输出单独指定损失，该列表中损失函数排列顺序应与构建模型时传递给`outputs`的输出排列顺序相同；同样地也可以给`metrics`传递一个由度量组成的列表，以为每个输出单独指定度量，对于一个输出使用多个度量的情况，则可以使用嵌套列表`[[metric11, metric12], [metric21]]`；完整示例见本小节下方的代码。

若对输出指定了名称，则可以以字典的形式传递损失函数和度量

Consider the following model, which has an image input of shape (32, 32, 3) (that's (height, width, channels)) and a timeseries input of shape (None, 10) (that's (timesteps, features)). Our model will have two outputs computed from the combination of these inputs: a "score" (of shape (1,)) and a probability distribution over five classes (of shape (5,)).

At compilation time, we can specify different losses to different outputs, by passing the loss functions as a list:

If we only passed a single loss function to the model, the same loss function would be applied to every output (which is not appropriate here).

Likewise for metrics:

```python
# Two inputs: a 32x32 RGB image and a variable-length sequence of 10 features.
img_input = keras.Input(shape=(32, 32, 3), name="img_input")
timeseries_input = keras.Input(shape=(None, 10), name="ts_input")

# Each branch is reduced to a fixed-size vector by global max pooling,
# then the two vectors are concatenated.
x1 = layers.Conv2D(3, 3)(img_input)
x1 = layers.GlobalMaxPooling2D()(x1)
x2 = layers.Conv1D(3, 3)(timeseries_input)
x2 = layers.GlobalMaxPooling1D()(x2)
x = layers.concatenate([x1, x2])

# Two heads: a scalar score and a 5-way class output.
score_output = layers.Dense(1, name="score_output")(x)
class_output = layers.Dense(5, name="class_output")(x)

# The image input variable is `img_input`; `image_input` was never defined.
model = keras.Model(
    inputs=[img_input, timeseries_input],
    outputs=[score_output, class_output]
)
# Per-output losses and metrics as lists, in the same order as the model
# outputs: [score_output, class_output].
model.compile(
    metrics=[
        [
            keras.metrics.MeanAbsolutePercentageError(),
            keras.metrics.MeanAbsoluteError(),
        ],
        [keras.metrics.CategoricalAccuracy()],
    ],
    loss=[
        keras.losses.MeanSquaredError(),
        keras.losses.CategoricalCrossentropy(),
    ],
    optimizer=keras.optimizers.RMSprop(learning_rate=1e-3),
)

```

Since we gave names to our output layers, we could also specify per-output losses and metrics via a dict:

```python

# Per-output losses and metrics keyed by output-layer name.
model.compile(
    optimizer=keras.optimizers.RMSprop(learning_rate=1e-3),
    loss=dict(
        score_output=keras.losses.MeanSquaredError(),
        class_output=keras.losses.CategoricalCrossentropy(),
    ),
    metrics=dict(
        score_output=[
            keras.metrics.MeanAbsolutePercentageError(),
            keras.metrics.MeanAbsoluteError(),
        ],
        class_output=[keras.metrics.CategoricalAccuracy()],
    ),
)
```
We recommend the use of explicit names and dicts if you have more than 2 outputs.

It's possible to give different weights to different output-specific losses (for instance, one might wish to privilege the "score" loss in our example, by giving to 2x the importance of the class loss), using the loss_weights argument:
```python
# As above, with the score loss weighted twice as heavily as the class loss.
model.compile(
    optimizer=keras.optimizers.RMSprop(learning_rate=1e-3),
    loss_weights=dict(score_output=2.0, class_output=1.0),
    loss=dict(
        score_output=keras.losses.MeanSquaredError(),
        class_output=keras.losses.CategoricalCrossentropy(),
    ),
    metrics=dict(
        score_output=[
            keras.metrics.MeanAbsolutePercentageError(),
            keras.metrics.MeanAbsoluteError(),
        ],
        class_output=[keras.metrics.CategoricalAccuracy()],
    ),
)
```

You could also choose not to compute a loss for certain outputs, if these outputs are meant for prediction but not for training:

```python

# List form: `None` marks an output that gets no loss.
model.compile(
    loss=[None, keras.losses.CategoricalCrossentropy()],
    optimizer=keras.optimizers.RMSprop(learning_rate=1e-3),
)

# Dict form: outputs without an entry get no loss.
model.compile(
    loss=dict(class_output=keras.losses.CategoricalCrossentropy()),
    optimizer=keras.optimizers.RMSprop(learning_rate=1e-3),
)
```

Passing data to a multi-input or multi-output model in fit works in a similar way as specifying a loss function in compile: you can pass lists of NumPy arrays (with 1:1 mapping to the outputs that received a loss function) or dicts mapping output names to NumPy arrays.

```python

model.compile(
    loss=[
        keras.losses.MeanSquaredError(),
        keras.losses.CategoricalCrossentropy(),
    ],
    optimizer=keras.optimizers.RMSprop(learning_rate=1e-3),
)

# Dummy NumPy data: 100 samples per input and target.
img_data = np.random.random_sample((100, 32, 32, 3))
ts_data = np.random.random_sample((100, 20, 10))
score_targets = np.random.random_sample((100, 1))
class_targets = np.random.random_sample((100, 5))

# Lists follow the order of the model inputs and outputs.
model.fit(
    [img_data, ts_data],
    [score_targets, class_targets],
    epochs=1,
    batch_size=32,
)

# Dicts are matched by input and output names.
model.fit(
    dict(img_input=img_data, ts_input=ts_data),
    dict(score_output=score_targets, class_output=class_targets),
    epochs=1,
    batch_size=32,
)
```

Here's the Dataset use case: similarly as what we did for NumPy arrays, the Dataset should return a tuple of dicts.
```python
# Each Dataset element is a (inputs-dict, targets-dict) pair keyed by the
# model's input and output names.
train_dataset = tf.data.Dataset.from_tensor_slices(
    (
        dict(img_input=img_data, ts_input=ts_data),
        dict(score_output=score_targets, class_output=class_targets),
    )
).shuffle(buffer_size=1024).batch(64)

model.fit(train_dataset, epochs=1)
```




## 2.2 The `fit()` method
`fit()`, `evaluate()`, `predict()`方法均接收参数`x`、`y`作为训练、验证和预测的数据，其支持的格式包括 Numpy 数组、TF张量、字典、`tf.data.Dataset`实例、生成器或`keras.utils.Sequence`实例；此外`fit()`方法还接收参数`validation_split`、`validation_data`分别用于指明从训练集分割出作为验证集的比例、以及用于指明验证数据；需要说明的是，`validation_split`仅支持输入数据可被切片的类型；下面展示利用`validation_split`指明验证集的模型训练过程：
```python
# Keep 20% of the training arrays aside for validation.
model = mnist_model_compiled()
model.fit(x_train, y_train, epochs=1, batch_size=64, validation_split=0.2)
```



### 2.2.1 Using `tf.data Dataset` as an input
`tf.data`模块包含了 TF2 中用于加载和预处理数据的接口；有关其更详细的指导，参见[这里](https://www.tensorflow.org/guide/data)；

由于`Dataset`产生的是输入-输出对，以内部属性指定 batch 大小，且不支持切片操作，进而当以`Dataset`实例作为输入传递给`x`时，无需指定参数`y`、`batch_size`参数，且不支持参数`validation_split`；此时若需要指定验证集，则应通过参数`validation_data`传递，示例如下：

```python
# Only the training set is shuffled; all three sets use batches of 64.
trainset = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(buffer_size=1024)
    .batch(64)
)
valset = tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(64)
testset = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(64)

model = mnist_model_compiled()
# With Dataset input, validation data must come through `validation_data`.
model.fit(trainset, validation_data=valset, epochs=3)
result = model.evaluate(testset)
```

需要说明的是，`fit()`还提供了`initial_epoch`参数，其默认值为 0，进而上面的训练共进行了 3 个 epoch；一般情况下，训练在`range(initial_epoch, epochs)`内进行；此外，默认每个 epoch 训练执行的步数为总样本数与 batch 大小的商，结束后训练集的`Dataset`会重置；然而若指明了`steps_per_epoch`参数，训练集的`Dataset`在 epoch 结束后则不重置，且在下一个 epoch 中会从上次结束的位置继续向后迭代，生成数据；类似地，`validation_steps`则用于指定每次验证执行的步数，不同的是，每次验证时都会先将验证集`Dataset`重置，以确保每次验证使用了相同的数据；





### 2.2.2 Using `keras.utils.Sequence` as an input
通过继承`keras.utils.Sequence`得到的生成器对象不仅支持多进程处理，还支持`fit()`方法的`shuffle`参数；继承`keras.utils.Sequence`生成器时必须实现其`__getitem__()`方法和`__len__()`方法，其中调用前者时应返回一个 batch 所含的数据，调用后者返回总 batch 个数；如果需要在 epoch 结束时对数据集进行调整，则应将`on_epoch_end`进行实现；

```python
from skimage.io import imread
from skimage.transform import resize

# The bare name `Sequence` is never imported in this file, so the base
# class is referenced through the `keras` alias.
class CIFAR10Sequence(keras.utils.Sequence):
    """Batch generator that loads and resizes image files on demand.

    Args:
        files: sequence of image file paths.
        labels: sequence of labels aligned with ``files``.
        batch_size: number of samples per batch.
    """

    def __init__(self, files, labels, batch_size):
        super().__init__()
        self.files, self.labels = files, labels
        self.batch_size = batch_size

    def __len__(self):
        """Return the number of batches; the last one may be partial."""
        return int(np.ceil(len(self.files) / float(self.batch_size)))

    def __getitem__(self, idx):
        """Return batch ``idx`` as ``(images, labels)`` NumPy arrays."""
        start = idx * self.batch_size
        stop = start + self.batch_size
        batch_x = self.files[start:stop]
        batch_y = self.labels[start:stop]
        # Every image is read from disk and resized to 200x200.
        return np.array([
            resize(imread(file), (200, 200)) for file in batch_x
        ]), np.array(batch_y)

# NOTE(review): `files`, `labels` and `batch_size` are not defined in the
# visible part of this file; presumably the reader supplies them.
sequence = CIFAR10Sequence(files=files, labels=labels, batch_size=batch_size)
model.fit(sequence, epochs=10)
```



### 2.2.3 Other input formats supported
Besides NumPy arrays, eager tensors, and TensorFlow Datasets, it's possible to train a Keras model using Pandas dataframes, or from Python generators that yield batches of data & labels.

In particular, the keras.utils.Sequence class offers a simple interface to build Python data generators that are multiprocessing-aware and can be shuffled.

In general, we recommend that you use:

NumPy input data if your data is small and fits in memory
Dataset objects if you have large datasets and you need to do distributed training
Sequence objects if you have large datasets and you need to do a lot of custom Python-side processing that cannot be done in TensorFlow (e.g. if you rely on external libraries for data loading or preprocessing).




### 2.2.4 Using sample weighting and class weighting
默认情况下，样本权重由其在数据集中的出现频率决定；除此之外也可以通过`fit()`方法的`class_weight`参数和`sample_weight`参数指定样本权重；其中`class_weight`是将类别索引映射至相应权重的字典，进而可以弥补类别的不平衡问题；`sample_weight`适用于输入为 Numpy 数组的情况，其自身是长度与样本个数相同的 Numpy 数组，其常用于处理类别不平衡问题，当`sample_weight`为由 0、1 组成的数组时，可以视为损失的掩码使用；需要注意的是，对于输入为`tf.data.Dataset`的情况，`fit()`不支持`sample_weight`参数，然而可以通过让`Dataset`产生`(input_batch, label_batch, sample_weight_batch)`元组来实现对样本的加权；

下面是利用`class_weight`和`sample_weight`对 MNIST 数据集中类别属于 5 的数据进行加权的示例：

```python
model = mnist_model_compiled()

# `class_weight`: class "5" counts twice as much as every other digit.
class_weight = {digit: (2.0 if digit == 5 else 1.0) for digit in range(10)}
model.fit(x_train, y_train, batch_size=64, class_weight=class_weight)

# `sample_weight` with NumPy inputs: one weight per sample, 2.0 for the
# samples labelled 5 and 1.0 otherwise.
sample_weight = np.where(y_train == 5, 2.0, 1.0)
model.fit(x_train, y_train, batch_size=64, sample_weight=sample_weight)

# `sample_weight` with Dataset inputs: the weights travel as the third
# element of every dataset item. `Dataset` lives under `tf.data`, and the
# dataset built here, `trainset`, is the one that has to be fitted.
trainset = tf.data.Dataset.from_tensor_slices(
    (x_train, y_train, sample_weight)
)
trainset = trainset.shuffle(buffer_size=1024).batch(64)
model.fit(trainset, epochs=1)
```



### 2.2.5 Using callbacks
Keras 中的回调 (callback) 指在训练过程中不同时刻 (例如在 epoch 开始时、一个 step 结束时、一个 epoch 结束时等) 调用的对象，它可以用于实现诸如
- 在训练的不同时刻进行验证，这里的验证与内置的验证并非同一个验证；
- 在模型执行一定步数后或当其精度超过一定阈值后对模型记录检查点；
- 在训练趋近于饱和时改变学习率或对顶层网络进行微调；
- 训练结束或模型超过某个性能阈值时发送电子邮件或消息通知
- Etc.

`keras.callbacks`模块提供了很多实现回调的 API，例如`CSVLogger`、`EarlyStopping`、`TensorBoard`、`ModelCheckpoint`等；实现时只需将需要进行的回调对象以列表的形式传递给`fit()`方法的`callbacks`参数即可；



#### 2.2.5.1 Checkpointing models

训练中对模型保存检查点的最简单方法是使用`ModelCheckpoint`回调；下面是利用`ModelCheckpoint`回调实现代码容错的示例，即在训练被随机中断后，可以从最近保存的模型状态继续训练；

```python
ckpt_dir = "./ckpt"  # dir to store all the checkpoints.
# `exist_ok=True` replaces the check-then-create pair, which could fail if
# the directory appeared between the check and the call.
os.makedirs(ckpt_dir, exist_ok=True)

def make_or_restore_model():
    """Return the most recently saved checkpoint, or a fresh compiled model.

    Every entry of ``ckpt_dir`` is treated as a checkpoint; the one with
    the largest ``os.path.getctime`` is loaded.
    """
    ckpts = []
    for name in os.listdir(ckpt_dir):
        ckpts.append(ckpt_dir + "/" + name)
    if not ckpts:
        print("Creating a new model")
        return mnist_model_compiled()
    latest_ckpt = max(ckpts, key=os.path.getctime)
    print("Restoring from", latest_ckpt)
    return keras.models.load_model(latest_ckpt)

model = make_or_restore_model()
# Save a checkpoint every 100 batches; the file name embeds the current loss.
callbacks = [
    keras.callbacks.ModelCheckpoint(
        save_freq=100,
        filepath=ckpt_dir + "/ckpt-loss={loss:.2f}",
    ),
]
model.fit(x_train, y_train, callbacks=callbacks, epochs=1)
```
当然也可以自定义回调函数来保存和恢复模型；关于模型序列化和保存的方法，请参阅[相关指导](https://www.tensorflow.org/guide/keras/save_and_serialize/)



#### 2.2.5.2 Using learning rate schedules

学习率静态衰减中，学习率通常是当前 epoch 数值或 batch 索引的一个函数；学习率动态衰减中，学习率是根据模型当前的行为来动态调整的，例如在验证损失不再提高时减少学习率；

实现学习率静态衰减的方式之一便是利用`keras.optimizers.schedules`模块内置的静态学习率衰减 API，如`ExponentialDecay`、`PiecewiseConstantDecay`、`PolynomialDecay`、`InverseTimeDecay`；使用时只需将其实例化对象传递给优化器的`learning_rate`参数即可

```python
# Staircase exponential decay: multiply the rate by 0.96 every 100000 steps.
init_lr = 0.1
lr_schedule = keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=init_lr,
    decay_steps=100000,
    decay_rate=0.96,
    staircase=True,
)
# The schedule object is passed where a constant learning rate would go.
optimizer = keras.optimizers.RMSprop(learning_rate=lr_schedule)
```

然而由于优化器无法访问验证指标，进而`keras.optimizers.schedules`模块中的 API 无法实现学习率的动态衰减；不过`keras.callbacks`模块中的 API 均可以对验证指标进行访问，例如内置的`ReduceLROnPlateau`类便可实现；



#### 2.2.5.3 Visualizing loss and metrics during training

TensorBoard 可以实时地绘制训练和验证的损失和度量图、对层模型的激活进行可视化、对嵌入层学习到的嵌入空间进行 3 维可视化；使用`TensorBoard`回调可以对这些概念进行快速实现，示例如下；更多信息请参阅`TensorBoard`回调的[相关文档](https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/tensorboard/)

```python
# Builds the callback only; pass it to `fit(callbacks=[...])` to use it.
keras.callbacks.TensorBoard(
    update_freq="epoch",
    embeddings_freq=0,
    histogram_freq=0,
    log_dir="/full_path_to_your_logs",
)
```

#### 2.2.5.4 Writing your own callback

可以通过继承`keras.callbacks.Callback`类来创建自定义回调，回调可以通过类属性`.model`访问其关联的模型；更多信息请参见[相关指导](https://www.tensorflow.org/guide/keras/custom_callback/)；下面是在训练期间将每个 batch 的损失以列表形式保存的简单示例：

```python
class LossHistory(keras.callbacks.Callback):
    """Callback that records the loss reported after every batch."""

    # `logs` defaults to None, as in the base-class hooks, so both hooks
    # can be called without a logs dict.
    def on_train_begin(self, logs=None):
        """Start a fresh loss list at the beginning of training."""
        self.per_batch_losses = []

    def on_batch_end(self, batch, logs=None):
        """Append the batch loss, or None if no loss was reported."""
        logs = logs or {}
        self.per_batch_losses.append(logs.get("loss"))
```

# 

# 

In [106]:
class CustomModel(keras.Model):
    """Three-layer classifier whose `train_step` tracks its own metrics.

    The metric objects `loss_tracker` and `mae_metric` are module-level
    names that must be defined before training starts.
    """

    def __init__(self, **kwargs):
        super(CustomModel, self).__init__(**kwargs)
        self.dense1 = layers.Dense(64, activation="relu")
        self.dense2 = layers.Dense(64, activation="relu")
        self.dense3 = layers.Dense(10, activation="softmax")

    def call(self, x, **kwargs):
        """Run the three dense layers in sequence."""
        y = self.dense3(self.dense2(self.dense1(x)))
        return y

    def train_step(self, data):
        """Run one optimisation step and return the running metrics.

        Args:
            data: ``(x, y)`` or ``(x, y, sample_weight)`` as passed by
                ``fit()``.

        Returns:
            Dict with the running mean loss and mean absolute error.
        """
        if len(data) == 3:
            x, y, sample_weight = data
        else:
            (x, y), sample_weight = data, None
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            # Per-sample mean squared error, one value per sample.
            loss = keras.losses.mean_squared_error(y, y_pred)
            if sample_weight is not None:
                # The weights were previously unpacked and then dropped;
                # apply them to the per-sample loss so they affect training.
                weights = tf.reshape(
                    tf.cast(sample_weight, loss.dtype), [-1]
                )
                loss = loss * weights
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(
            zip(grads, self.trainable_variables)
        )
        loss_tracker.update_state(loss)
        mae_metric.update_state(y, y_pred, sample_weight=sample_weight)
        return {"loss": loss_tracker.result(), "mae": mae_metric.result()}

    @property
    def metrics(self):
        # Listing the metrics here lets `fit()` reset them at the start of
        # every epoch.
        return [loss_tracker, mae_metric]


# Module-level metric objects that `CustomModel.train_step` looks up by name.
loss_tracker = keras.metrics.Mean(name="loss")
mae_metric = keras.metrics.MeanAbsoluteError(name="mae")
# No loss is passed to `compile`: `train_step` computes it itself.
model = CustomModel()
model.compile(optimizer="adam")


Model: "custom_model_19"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_52 (Dense)             multiple                  50240     
_________________________________________________________________
dense_53 (Dense)             multiple                  4160      
_________________________________________________________________
dense_54 (Dense)             multiple                  650       
Total params: 55,050
Trainable params: 55,050
Non-trainable params: 0
_________________________________________________________________


In [104]:
# Random inputs, targets and per-sample weights, 1000 samples each.
x = np.random.random(size=(1000, 784))
y = np.random.random(size=(1000, 1))
weight = np.random.random(size=(1000, 1))
model.fit(x, y, epochs=5, sample_weight=weight)

Epoch 1/5
Tensor("IteratorGetNext:2", shape=(None, 1), dtype=float32)
Tensor("IteratorGetNext:2", shape=(None, 1), dtype=float32)
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<tensorflow.python.keras.callbacks.History at 0x23f6b8a3310>

# 3. Customize what happens in `Model.fit`

本节介绍的方法对`Sequential`模型、functional 模型、类继承模型均适用；以下的操作需要 TF2.2 及以上版本；

## Customize training
通常情况下，调用`compile()`和`fit()`方法便可满足大部分训练要求；若需要自定义训练过程，一种方式是利用`GradientTape`完全手动实现；然而若仍希望利用`fit()`函数的便捷性，可以通过重写`Model`类`train_step()`方法来实现；具体而言，首先创建一个继承`keras.Model`的类，同时对其`train_step(self, data)`方法进行重写，其中参数`data`指被传递至`fit()`方法的训练数据组成的二元元组，对于指明了`sample_weight`参数的情况，`data`则为包含了相应的权重的三元元组；需要说明的是，尽管`fit()`可以接收多种格式的训练数据，但当其内部调用`train_step()`方法时，总是将输入与输出以及样本权重组成一个元组传递给`data`参数，进而这里只需将参数设置为`data`即可；

`train_step()`方法内部则需要实现模型参数更新；这里使用`self.compiled_loss()`计算损失，使用`self.compiled_metrics.update_state()`计算度量，前提是损失函数和度量已提前通过调用`compile()`方法封装至模型中；当然也可以通过传递`keras.losses`模块和`keras.metrics`模块中的 API 来手动计算损失和度量，后者的实现方法参见第二个`CustomModel`示例

最后通过访问`self.metrics`返回一个将度量名称映射至度量值的字典，这里度量包括了损失；

```python
class CustomModel(keras.Model):
    """Model whose `train_step` uses the compiled loss and metrics."""

    def train_step(self, data):
        """Run one optimisation step and return the current metric values.

        Args:
            data: ``(x, y)`` or ``(x, y, sample_weight)`` as passed by
                ``fit()``.

        Returns:
            Dict mapping each metric name to its current value.
        """
        sample_weight = None
        if len(data) == 3:
            x, y, sample_weight = data
        else:
            x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            # The loss configured in `compile()`; `self.losses` is passed
            # as the regularisation losses.
            loss = self.compiled_loss(
                y,
                y_pred,
                sample_weight=sample_weight,
                regularization_losses=self.losses,
            )
        trainable = self.trainable_variables
        grads = tape.gradient(loss, trainable)
        self.optimizer.apply_gradients(zip(grads, trainable))
        self.compiled_metrics.update_state(
            y, y_pred, sample_weight=sample_weight
        )
        logs = {}
        for metric in self.metrics:
            logs[metric.name] = metric.result()
        return logs

# Functional graph wrapped in `CustomModel`, so `fit()` runs the custom
# `train_step`.
inputs = keras.Input(name="digits", shape=(784,))
x = layers.Dense(units=64, activation="relu", name="dense1")(inputs)
x = layers.Dense(units=64, activation="relu", name="dense2")(x)
outputs = layers.Dense(units=10, activation="softmax", name="dense3")(x)
model = CustomModel(inputs, outputs)
model.compile(loss="mse", metrics=["mae"], optimizer="adam")

# Random inputs, targets and per-sample weights, 1000 samples each.
x = np.random.random(size=(1000, 784))
y = np.random.random(size=(1000, 10))
weight = np.random.random(size=(1000, 1))
model.fit(x, y, epochs=5, sample_weight=weight)
```

对于自定义损失函数和度量的情况，便无需在`compile()`函数中再指定损失和度量，同时需要在一个训练步结束后调用`metric.update_state()`来更新度量值；需要注意的是，实际上所有度量在一个 epoch 训练结束后应该进行重置，一种方式是手动调用`metric.reset_states()`方法，另一种方式是将所有度量以列表的形式重写至`metrics`属性，进而`fit()`函数会在每个 epoch 一开始时对`metrics`属性中所有度量进行重置；示例如下，这里演示了使用类继承实现的方式：

```python
class CustomModel(keras.Model):
    """Three-layer classifier whose `train_step` tracks its own metrics.

    The metric objects `loss_tracker` and `mae_metric` are module-level
    names that must be defined before training starts.
    """

    def __init__(self, **kwargs):
        super(CustomModel, self).__init__(**kwargs)
        self.dense1 = layers.Dense(64, activation="relu")
        self.dense2 = layers.Dense(64, activation="relu")
        self.dense3 = layers.Dense(10, activation="softmax")

    def call(self, x, **kwargs):
        """Run the three dense layers in sequence."""
        y = self.dense3(self.dense2(self.dense1(x)))
        return y

    def train_step(self, data):
        """Run one optimisation step and return the running metrics.

        Args:
            data: ``(x, y)`` or ``(x, y, sample_weight)`` as passed by
                ``fit()``.

        Returns:
            Dict with the running mean loss and mean absolute error.
        """
        if len(data) == 3:
            x, y, sample_weight = data
        else:
            (x, y), sample_weight = data, None
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            # Per-sample mean squared error, one value per sample.
            loss = keras.losses.mean_squared_error(y, y_pred)
            if sample_weight is not None:
                # The weights were previously unpacked and then dropped;
                # apply them to the per-sample loss so they affect training.
                weights = tf.reshape(
                    tf.cast(sample_weight, loss.dtype), [-1]
                )
                loss = loss * weights
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(
            zip(grads, self.trainable_variables)
        )
        loss_tracker.update_state(loss)
        mae_metric.update_state(y, y_pred, sample_weight=sample_weight)
        return {"loss": loss_tracker.result(), "mae": mae_metric.result()}

    @property
    def metrics(self):
        # Listing the metrics here lets `fit()` reset them at the start of
        # every epoch.
        return [loss_tracker, mae_metric]

model = CustomModel()
# No loss or metrics here: `train_step` computes and tracks them itself.
model.compile(optimizer="adam")
loss_tracker = keras.metrics.Mean(name="loss")
mae_metric = keras.metrics.MeanAbsoluteError(name="mae")

x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
weight = np.random.random((1000, 1))
# The weights array is named `weight`; `sw` was never defined.
model.fit(x, y, sample_weight=weight, epochs=5)
```


## Customize evaluation
如果需要对`model.evaluate()`进行自定义，可以通过重写`test_step()`方法实现：

```python
class CustomModel(keras.Model):
    """Model with a custom evaluation step used by ``evaluate()``."""

    def test_step(self, data):
        """Evaluate one batch and return the current metric values."""
        x, y = data
        y_pred = self(x, training=False)
        # Called for its effect on the tracked loss; the value is unused.
        self.compiled_loss(y, y_pred, regularization_losses=self.losses)
        self.compiled_metrics.update_state(y, y_pred)
        logs = {}
        for metric in self.metrics:
            logs[metric.name] = metric.result()
        return logs

# Single dense layer wrapped in `CustomModel`; only evaluation is run.
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(units=1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(metrics=["mae"], loss="mse")

x = np.random.random(size=(1000, 32))
y = np.random.random(size=(1000, 1))
model.evaluate(x, y)
```

## Wrapping up: an end-to-end GAN example

下面以一个在 MNIST 数据集上训练的 GAN 为例，演示自定义`fit()`函数的训练环节的过程；训练循环包括 2 个分支：

- 对判别器的训练：在隐空间中随机抽取一个 batch 的点 $\rightarrow$ 利用生成器将这些点变成假图像 $\rightarrow$ 获取一批真实图像，并与生成的图像进行组合 $\rightarrow$ 训练判别器以使其能够正确的对生成的图像和真实图像进行分类；

- 对生成器的训练：在隐空间中随机抽取一个 batch 的点 $\rightarrow$ 利用生成器将这些点变成假图像 $\rightarrow$ 获取一批真实图像，并与生成的图像进行组合 $\rightarrow$ 训练生成器，以使其生成的图片被判别器误以为是真的图片

实现如下：

```python
def get_discriminator_generator(latent_dim=128):
    """Build the GAN's two networks.

    Args:
        latent_dim: size of the generator's latent input vector.

    Returns:
        ``(discriminator, generator)``: the discriminator maps a 28x28x1
        image to one logit; the generator maps a latent vector to a
        28x28x1 image with values in [0, 1].
    """
    disc_layers = [
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(64, 3, strides=2, padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2D(128, 3, strides=2, padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.GlobalMaxPooling2D(),
        layers.Dense(1),
    ]
    discriminator = keras.Sequential(disc_layers, name="discriminator")

    gen_layers = [
        keras.Input(shape=(latent_dim,)),
        # Project to a 7x7x128 feature map, then upsample twice to 28x28.
        layers.Dense(7 * 7 * 128),
        layers.LeakyReLU(alpha=0.2),
        layers.Reshape((7, 7, 128)),
        layers.Conv2DTranspose(128, 4, strides=2, padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2DTranspose(128, 4, strides=2, padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2D(1, 7, padding="same", activation="sigmoid"),
    ]
    generator = keras.Sequential(gen_layers, name="generator")
    return discriminator, generator


class GAN(keras.Model):
    """GAN wrapper that trains discriminator and generator inside `fit()`."""

    def __init__(self, discriminator, generator, latent_dim=128):
        """Store the two networks and the latent-vector size."""
        super(GAN, self).__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        """Store one optimizer per network and the shared loss function."""
        super(GAN, self).compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def train_step(self, real_imgs):
        """Run one discriminator update followed by one generator update.

        Args:
            real_imgs: a batch of real images; if a tuple is passed, only
                its first element is used.

        Returns:
            Dict with the discriminator loss and the generator loss.
        """
        if isinstance(real_imgs, tuple):
            real_imgs = real_imgs[0]
        batch_size = tf.shape(real_imgs)[0]
        # ==> Train the discriminator
        latent_vecs = tf.random.normal([batch_size, self.latent_dim])
        # Generated outside the tape: only discriminator weights get
        # gradients in this phase.
        fake_imgs = self.generator(latent_vecs)
        combined_imgs = tf.concat([fake_imgs, real_imgs], axis=0)
        # Label convention in this block: 1 = generated, 0 = real.
        labels = tf.concat(
            [tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))],
            axis=0
        )
        # Add random noise to the labels - important trick!
        labels += 0.05 * tf.random.uniform(tf.shape(labels))
        with tf.GradientTape() as tape:
            preds = self.discriminator(combined_imgs)
            d_loss = self.loss_fn(labels, preds)
        grads = tape.gradient(
            d_loss, self.discriminator.trainable_weights
        )
        self.d_optimizer.apply_gradients(
            zip(grads, self.discriminator.trainable_weights)
        )
        # ==> Train the generator
        latent_vecs = tf.random.normal([batch_size, self.latent_dim])
        # All-zero labels say "real": the generator is rewarded when the
        # discriminator is fooled.
        fake_labels = tf.zeros((batch_size, 1))
        with tf.GradientTape() as tape:
            preds = self.discriminator(self.generator(latent_vecs))
            g_loss = self.loss_fn(fake_labels, preds)
        # Only the generator weights are updated in this phase.
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(
            zip(grads, self.generator.trainable_weights)
        )
        return {"d_loss": d_loss, "g_loss": g_loss}


def load_data(batch_size=64):
    """Return all MNIST digits (train + test) as a shuffled, batched Dataset.

    Images are scaled to [0, 1] and shaped ``(28, 28, 1)``; labels are
    discarded.
    """
    (train_imgs, _), (test_imgs, _) = keras.datasets.mnist.load_data()
    digits = np.concatenate([train_imgs, test_imgs]).astype("float32") / 255.0
    digits = digits.reshape((-1, 28, 28, 1))
    return (
        tf.data.Dataset.from_tensor_slices(digits)
        .shuffle(buffer_size=1024)
        .batch(batch_size)
    )

dataset = load_data(batch_size=64)
discriminator, generator = get_discriminator_generator(latent_dim=128)
gan = GAN(discriminator=discriminator, generator=generator, latent_dim=128)
# The discriminator emits raw logits, hence `from_logits=True`.
gan.compile(
    loss_fn=keras.losses.BinaryCrossentropy(from_logits=True),
    d_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
    g_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
)
# Train on the first 100 batches only, to keep the demo short.
gan.fit(dataset.take(100), epochs=1)
```

# 

# 

# 4. Writing a training loop from scratch

利用 keras 的内置函数`fit()`和`evaluate()`可以实现快速的训练和评估，相关指导可参见第 2 节；通过继承`Model`类并实现`train_step()`方法，可以在自定义模型的学习算法同时，依旧利用`fit()`函数进行训练，相关指导可参见第 3 节；下面仍以基于 MNIST 数据集的一个简单的模型为示例，介绍如何对训练和评估进行底层实现，即从头编写训练的循环代码；

下面是一个简单的示例：

```python
def mnist_model():
    """Build a small fully connected classifier for flattened MNIST digits.

    The network maps a 784-element input through two 64-unit ReLU layers to
    a 10-unit output layer. The output layer has no activation, so the model
    returns raw logits and must be paired with a loss created with
    `from_logits=True`; applying softmax here as well would normalise the
    scores twice.

    Returns:
        An uncompiled `keras.Model`.
    """
    inputs = keras.Input(shape=(784,), name="digits")
    x = layers.Dense(64, activation="relu", name="dense1")(inputs)
    x = layers.Dense(64, activation="relu", name="dense2")(x)
    # No activation: emit logits to match SparseCategoricalCrossentropy(
    # from_logits=True) used by the training loop.
    outputs = layers.Dense(10, name="dense3")(x)
    return keras.Model(inputs=inputs, outputs=outputs)

def load_dataset(batch_size):
    """Load MNIST and split it into batched training and validation sets.

    Every image is flattened to a 784-element vector. The last 10,000
    samples of the MNIST training split are held out for validation; the
    remaining samples are shuffled (buffer of 1024) and batched for
    training. The MNIST test split is not used.

    Args:
        batch_size: Number of samples per batch in both datasets.

    Returns:
        A `(trainset, valset)` tuple of `tf.data.Dataset` objects, each
        yielding `(images, labels)` batches.
    """
    # `keras.datasets` is used because no bare `datasets` name is imported.
    (x_train, y_train), _ = keras.datasets.mnist.load_data()
    x_train = np.reshape(x_train, (-1, 784))

    # Hold out the tail of the training split for validation.
    x_train, x_val = x_train[:-10000], x_train[-10000:]
    y_train, y_val = y_train[:-10000], y_train[-10000:]

    trainset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    trainset = trainset.shuffle(buffer_size=1024).batch(batch_size)
    valset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
    valset = valset.batch(batch_size)
    return trainset, valset

def train(epochs):
    """Run a hand-written training loop over `trainset`.

    Relies on the module-level `model`, `loss_fn`, `optimizer` and
    `trainset`. For every batch the loss is computed under a gradient tape,
    the gradients with respect to the model's trainable weights are applied
    by the optimizer, and the batch loss is printed every 200 steps.

    Args:
        epochs: Number of passes over `trainset`.
    """
    for epoch in range(epochs):
        print("\n==> Epoch %d:" % (epoch,))
        for step, (x_batch_train, y_batch_train) in enumerate(trainset):
            weights = model.trainable_weights
            # The forward pass must run inside the tape to be differentiable.
            with tf.GradientTape() as tape:
                loss = loss_fn(
                    y_batch_train, model(x_batch_train, training=True)
                )
            optimizer.apply_gradients(
                zip(tape.gradient(loss, weights), weights)
            )
            if step % 200 != 0:
                continue
            print(
                "Training loss (for one batch) at step %d: %.4f"
                % (step, float(loss))
            )

# train() reads trainset, model, optimizer and loss_fn from module scope,
# so they must all be defined before it is called.
trainset, valset = load_dataset(64)
model = mnist_model()
optimizer = keras.optimizers.SGD(learning_rate=1e-3)
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
# train() only accepts `epochs`; passing the model and dataset as arguments
# would raise a TypeError.
train(epochs=2)
```





## 4.1 Low-level handling of metrics & losses

若需要在循环中对度量进行追踪，考虑到度量值是对整个数据集而言的，进而需要**在每个 batch 训练完成后调用`metric.update_state()`以更新度量值**，以及在每个 epoch 训练结束后，调用`metric.reset_states()`对其清零；此外可以利用`metric.result()`获取训练期间度量值以对模型性能进行监测；

在前向传播的过程中，层和模型对象会递归地调用网络层的`self.add_loss(value)`方法，来对期间产生的所有损失进行追踪，所有通过`.add_loss()`方法添加的损失所组成的列表可以通过`.losses`属性获得；进而在人工实现训练循环时，只需在得到损失函数返回的损失值后，将其与模型前向传播期间得到的损失值相加即可；

对上述模型记录`SparseCategoricalAccuracy`度量，并添加激活正则项的方式如下：

```python
def mnist_model_act_reg():
    """Build the MNIST classifier with an activity-regularisation layer.

    The model takes 784-element inputs, applies a 64-unit ReLU layer, an
    identity layer that registers a penalty on its input through
    `add_loss`, another 64-unit ReLU layer and a final 10-unit layer with
    no activation (raw logits).

    Returns:
        An uncompiled `keras.Model`; the registered penalty is exposed
        through its `losses` attribute after a forward pass.
    """

    class ActivityRegularizationLayer(layers.Layer):
        """Pass-through layer that records 1e-2 * sum(inputs) as a loss."""

        def call(self, inputs):
            penalty = 1e-2 * tf.reduce_sum(inputs)
            self.add_loss(penalty)
            return inputs

    digits = keras.Input(shape=(784,), name="digits")
    hidden = layers.Dense(64, activation="relu")(digits)
    hidden = ActivityRegularizationLayer()(hidden)
    hidden = layers.Dense(64, activation="relu")(hidden)
    logits = layers.Dense(10, name="predictions")(hidden)
    return keras.Model(inputs=digits, outputs=logits)

def train(epochs):
    """Train and validate the module-level `model` for `epochs` epochs.

    The loss of each batch is the `loss_fn` value plus the sum of
    `model.losses` (losses registered through `add_loss` during the forward
    pass). Accuracy is accumulated batch by batch in `train_acc_metric`,
    then read and reset at the end of the epoch; the same is done with
    `val_acc_metric` over `valset`. Relies on the globals `model`,
    `loss_fn`, `optimizer`, `trainset`, `valset` and both metric objects.

    Args:
        epochs: Number of passes over `trainset`.
    """
    for epoch in range(epochs):
        print("\n==> Epoch %d:" % (epoch,))
        start_time = time.time()

        # ---- Training pass ----
        for step, (x_batch, y_batch) in enumerate(trainset):
            weights = model.trainable_weights
            with tf.GradientTape() as tape:
                logits = model(x_batch, training=True)
                # Add the losses registered during this forward pass.
                loss = loss_fn(y_batch, logits) + sum(model.losses)
            optimizer.apply_gradients(
                zip(tape.gradient(loss, weights), weights)
            )
            train_acc_metric.update_state(y_batch, logits)
            if step % 200 == 0:
                print("step: %3d, loss: %.4f" % (step, loss))
        train_acc = float(train_acc_metric.result())
        # Clear the accumulator so the next epoch starts from zero.
        train_acc_metric.reset_states()
        print("Training acc over epoch: %.4f" % (train_acc,))

        # ---- Validation pass ----
        for x_batch, y_batch in valset:
            val_acc_metric.update_state(
                y_batch, model(x_batch, training=False)
            )
        val_acc = float(val_acc_metric.result())
        val_acc_metric.reset_states()
        print("Validation acc: %.4f" % (val_acc,))
        print("Time taken: %.2fs" % (time.time() - start_time))

# Everything below is read by train() from module scope.
trainset, valset = load_dataset(64)
model = mnist_model_act_reg()
optimizer = keras.optimizers.SGD(1e-3)
# The model's last layer has no activation, hence from_logits=True.
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
# Separate accumulators so training and validation accuracy never mix.
train_acc_metric = keras.metrics.SparseCategoricalAccuracy()
val_acc_metric = keras.metrics.SparseCategoricalAccuracy()

train(2)
```




## 4.2 Speeding-up training with `tf.function`
TF2 默认的运行时 (runtime) 是即时执行 (eager execution) 模式，尽管这对调试很友好，然而利用计算图编译具有明显的性能优势 —— 如果将整个模型计算描述为静态图，则可以使计算框架对全局性能进行优化；而如果计算框架不得不贪婪地逐个执行操作，无法获知后面的计算，显然无法对整体进行很好的加速；

TF2 中，通过添加`@tf.function`修饰符，可以将任何一个以张量为输入的函数编译为静态计算图，进而实现加速；以上面添加了激活正则项的 MNIST 模型为例，演示如下：

```python
@tf.function
def train_step(x, y):
    """Run one optimisation step on a single batch as a compiled graph.

    Uses the module-level `model`, `loss_fn`, `optimizer` and
    `train_acc_metric`.

    Args:
        x: Batch of inputs fed to `model`.
        y: Batch of targets passed to `loss_fn` and the accuracy metric.

    Returns:
        The scalar loss of the batch: the `loss_fn` value plus the sum of
        the losses the model registered during the forward pass.
    """
    with tf.GradientTape() as tape:
        y_pred = model(x, training=True)
        loss = loss_fn(y, y_pred)
        # Losses added through add_loss() belong to this forward pass; they
        # must be summed here, inside the tape and inside the compiled
        # function, to contribute to the gradients.
        loss += sum(model.losses)
    grads = tape.gradient(loss, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    train_acc_metric.update_state(y, y_pred)
    return loss


@tf.function
def test_step(x, y):
    """Update `val_acc_metric` with the model's predictions for one batch.

    Args:
        x: Batch of inputs fed to `model` in inference mode.
        y: Batch of targets for the accuracy metric.
    """
    y_pred = model(x, training=False)
    val_acc_metric.update_state(y, y_pred)


def train(epochs):
    """Train and validate with the graph-compiled step functions.

    Relies on the globals `trainset`, `valset`, `train_acc_metric` and
    `val_acc_metric`; both metrics are read and reset after each epoch.

    Args:
        epochs: Number of passes over `trainset`.
    """
    for epoch in range(epochs):
        print("==> Epoch: %d" % epoch)
        start_time = time.time()
        for step, (x, y) in enumerate(trainset):
            # train_step() already returns the total loss, including the
            # model's own regularisation losses.
            loss = train_step(x, y)
            if step % 200 == 0:
                print("step: %3d, loss: %.4f" % (step, float(loss)))
        train_acc = train_acc_metric.result()
        train_acc_metric.reset_states()
        print("Training acc over the epoch: %.4f" % float(train_acc))

        for x, y in valset:
            test_step(x, y)
        val_acc = val_acc_metric.result()
        val_acc_metric.reset_states()
        print("Validation acc: %.4f" % float(val_acc))
        print("Time taken: %.2f" % (time.time() - start_time))

# trainset and valset are not rebuilt here; they are expected to still exist
# from the previous section.
model = mnist_model_act_reg()
optimizer = keras.optimizers.SGD(learning_rate=1e-3)
# The model's last layer has no activation, hence from_logits=True.
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
train_acc_metric = keras.metrics.SparseCategoricalAccuracy()
val_acc_metric = keras.metrics.SparseCategoricalAccuracy()

train(2)
```




## 4.3 An end-to-end Example

下面以一个在 MNIST 数据集上训练的 GAN 为例，演示人工从头实现训练循环的过程；训练循环包括 2 个分支：

- 对判别器的训练：在隐空间中随机抽取一个 batch 的点 $\rightarrow$ 利用生成器将这些点变成假图像 $\rightarrow$ 获取一批真实图像，并与生成的图像进行组合 $\rightarrow$ 训练判别器以使其能够正确的对生成的图像和真实图像进行分类；

- 对生成器的训练：在隐空间中随机抽取一个 batch 的点 $\rightarrow$ 利用生成器将这些点变成假图像 $\rightarrow$ 将假图像送入判别器，并赋予其与真实图像相同的标签 $\rightarrow$ 仅更新生成器的参数，以使其生成的图片被判别器误以为是真的图片

实现如下：

```python
def load_dataset(batch_size=64):
    """Return all MNIST digit images as a shuffled, batched dataset.

    The train and test images are merged (labels are dropped), converted to
    float32, divided by 255 and reshaped to (N, 28, 28, 1).

    Args:
        batch_size: Number of images per batch.

    Returns:
        A `tf.data.Dataset` that yields batches of images only.
    """
    (train_imgs, _), (test_imgs, _) = keras.datasets.mnist.load_data()
    merged = np.concatenate([train_imgs, test_imgs])
    merged = merged.astype("float32") / 255.0
    merged = merged.reshape((-1, 28, 28, 1))
    images = tf.data.Dataset.from_tensor_slices(merged)
    return images.shuffle(buffer_size=1024).batch(batch_size)

def GAN_model(latent_dim=128):
    """Create the discriminator and generator networks of the GAN.

    The discriminator takes (28, 28, 1) images through two stride-2
    convolutions with LeakyReLU, global max pooling and a single-unit dense
    layer without activation. The generator takes a `latent_dim`-element
    vector, projects it to a 7x7x128 volume, upsamples it with two stride-2
    transposed convolutions and finishes with a one-channel sigmoid
    convolution.

    Args:
        latent_dim: Length of the generator's input vector.

    Returns:
        A `(discriminator, generator)` tuple of `keras.Sequential` models.
    """
    # Layers are created discriminator-first so that automatically generated
    # layer names keep their numbering.
    d_layers = [
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.GlobalMaxPooling2D(),
        layers.Dense(1),
    ]
    discriminator = keras.Sequential(d_layers, name="discriminator")

    g_layers = [
        keras.Input(shape=(latent_dim,)),
        layers.Dense(7 * 7 * 128),
        layers.LeakyReLU(alpha=0.2),
        layers.Reshape((7, 7, 128)),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2D(1, (7, 7), padding="same", activation="sigmoid"),
    ]
    generator = keras.Sequential(g_layers, name="generator")

    return discriminator, generator

@tf.function
def train_step(real_imgs, latent_dim=128, batch_size=64):
    """Perform one discriminator update followed by one generator update.

    Uses the module-level `generator`, `discriminator`, `loss_fn`,
    `d_optimizer` and `g_optimizer`.

    Args:
        real_imgs: Batch of real images; its static first dimension sets the
            number of "real" labels.
        latent_dim: Length of each latent vector fed to the generator.
        batch_size: Number of latent vectors (generated images) per update.

    Returns:
        A `(d_loss, g_loss, fake_imgs)` tuple: both losses and the images
        generated for the discriminator update.
    """
    # ==> Discriminator update
    noise = tf.random.normal([batch_size, latent_dim])
    # Generated outside any tape: this update must not reach the generator.
    fake_imgs = generator(noise)
    combined_imgs = tf.concat([fake_imgs, real_imgs], axis=0)
    # Generated images are labelled 1 and real images 0.
    fake_targets = tf.ones((batch_size, 1))
    real_targets = tf.zeros((real_imgs.shape[0], 1))
    labels = tf.concat([fake_targets, real_targets], axis=0)
    # Add random noise to labels -- important trick!
    labels = labels + 0.05 * tf.random.uniform(labels.shape)
    d_weights = discriminator.trainable_weights
    with tf.GradientTape() as d_tape:
        d_loss = loss_fn(labels, discriminator(combined_imgs))
    d_optimizer.apply_gradients(
        zip(d_tape.gradient(d_loss, d_weights), d_weights)
    )

    # ==> Generator update
    noise = tf.random.normal(shape=(batch_size, latent_dim))
    # All-zero targets: the label used for real images above, so that the
    # generator is pushed towards fooling the discriminator.
    fake_labels = tf.zeros((batch_size, 1))
    g_weights = generator.trainable_weights
    with tf.GradientTape() as g_tape:
        g_loss = loss_fn(fake_labels, discriminator(generator(noise)))
    # Only the generator's weights are updated here.
    g_optimizer.apply_gradients(
        zip(g_tape.gradient(g_loss, g_weights), g_weights)
    )
    return d_loss, g_loss, fake_imgs

def train(epochs=16, save_dir=None):
    """Train the GAN with the hand-written `train_step`.

    Iterates over the module-level `dataset`; every 200 steps the two losses
    are printed and, when `save_dir` is given, the first generated image of
    the batch is written there as a PNG.

    Args:
        epochs: Number of passes over `dataset`.
        save_dir: Directory for the sample images, created if it does not
            exist. Nothing is saved when it is `None` or empty.
    """
    if save_dir:
        # Saving the image does not create missing directories, so make sure
        # the target exists before the first sample is written.
        os.makedirs(save_dir, exist_ok=True)

    for epoch in range(epochs):
        print("==> Epoch:", epoch)
        for step, real_imgs in enumerate(dataset):
            d_loss, g_loss, fake_imgs = train_step(real_imgs)
            if step % 200 == 0:
                print(
                    "step: {},"
                    "discriminator loss: {},"
                    "generator loss: {}".format(step, d_loss, g_loss)
                )
                if save_dir:
                    # The generator ends in a sigmoid, so rescale to the
                    # 0-255 range before converting to an image.
                    img = tf.keras.preprocessing.image.array_to_img(
                        fake_imgs[0] * 255.0, scale=False
                    )
                    # NOTE: the file name only encodes the step, so a later
                    # epoch overwrites the samples of an earlier one.
                    img.save(os.path.join(
                        save_dir, "generated_img" + str(step) + ".png"
                    ))


# train_step() and train() read these names from module scope.
dataset = load_dataset()
discriminator, generator = GAN_model()
d_optimizer = keras.optimizers.Adam(learning_rate=0.0003)
g_optimizer = keras.optimizers.Adam(learning_rate=0.0004)
# The discriminator's last layer has no activation, hence from_logits=True.
loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)
# save_dir must be passed by keyword: a positional argument after
# `epochs=2` is a SyntaxError.
train(
    epochs=2,
    save_dir="../../test/Guide/Keras/02.Train_a_Model/GAN_example",
)
```
注意，由于涉及到卷积运算，建议在 GPU 上运行；