### 6-2 训练模型的3种方法
模型的训练主要有内置fit方法、内置tran_on_batch方法、自定义训练循环。

注：fit_generator方法在tf.keras中不推荐使用，其功能已经被fit包含。

In [1]:
import tensorflow as tf 
import pandas as pd
import numpy as np
from tensorflow.keras import *

In [2]:
# 打印时间分割线
@tf.function
def printbar():
    ts = tf.timestamp()
    today_ts = ts%(24*60*60)

    hour = tf.cast(today_ts//3600+8,tf.int32)%tf.constant(24)
    minite = tf.cast((today_ts%3600)//60,tf.int32)
    second = tf.cast(tf.floor(today_ts%60),tf.int32)

    def timeformat(m):
        if tf.strings.length(tf.strings.format("{}",m))==1:
            return(tf.strings.format("0{}",m))
        else:
            return(tf.strings.format("{}",m))

    timestring = tf.strings.join([timeformat(hour),timeformat(minite),
                timeformat(second)],separator = ":")
    tf.print("=========="*8,end = "")
    tf.print(timestring)

In [3]:
MAX_LEN = 300
BATCH_SIZE = 32
(x_train, y_train), (x_test, y_test) = datasets.reuters.load_data()
x_train = preprocessing.sequence.pad_sequences(x_train, maxlen=MAX_LEN)
x_test = preprocessing.sequence.pad_sequences(x_test, maxlen=MAX_LEN)

MAX_WORDS = x_train.max() + 1
CAT_NUM = y_train.max() + 1

In [4]:
ds_train = tf.data.Dataset.from_tensor_slices((x_train,y_train)) \
          .shuffle(buffer_size = 1000).batch(BATCH_SIZE) \
          .prefetch(tf.data.experimental.AUTOTUNE).cache()

ds_test = tf.data.Dataset.from_tensor_slices((x_test,y_test)) \
          .shuffle(buffer_size = 1000).batch(BATCH_SIZE) \
          .prefetch(tf.data.experimental.AUTOTUNE).cache()

### 一、内置fit方法
该方法功能非常强大, 支持对numpy array, tf.data.Dataset以及 Python generator数据进行训练。

并且可以通过设置回调函数实现对训练过程的复杂控制逻辑。

In [5]:
tf.keras.backend.clear_session()

def create_model():
    model = models.Sequential()
    model.add(layers.Embedding(MAX_WORDS,7,input_length=MAX_LEN))
    model.add(layers.Conv1D(filters=64, kernel_size=5, activation='relu'))
    model.add(layers.MaxPooling1D(2))
    model.add(layers.Conv1D(filters=32, kernel_size=3, activation='relu'))
    model.add(layers.MaxPooling1D(2))
    model.add(layers.Flatten())
    model.add(layers.Dense(CAT_NUM, activation="softmax"))
    return model

def compile_model(model):
    model.compile(optimizer=optimizers.Nadam(),
                loss=losses.SparseCategoricalCrossentropy(),
                metrics=[metrics.SparseCategoricalAccuracy(),metrics.SparseTopKCategoricalAccuracy(5)]) 
    return(model)

model = create_model()
model.summary()
model = compile_model(model)

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
embedding (Embedding)        (None, 300, 7)            216874    
_________________________________________________________________
conv1d (Conv1D)              (None, 296, 64)           2304      
_________________________________________________________________
max_pooling1d (MaxPooling1D) (None, 148, 64)           0         
_________________________________________________________________
conv1d_1 (Conv1D)            (None, 146, 32)           6176      
_________________________________________________________________
max_pooling1d_1 (MaxPooling1 (None, 73, 32)            0         
_________________________________________________________________
flatten (Flatten)            (None, 2336)              0         
_________________________________________________________________
dense (Dense)                (None, 46)                1

In [6]:
history = model.fit(ds_train, validation_data=ds_test, epochs=6, verbose=1)

Train for 281 steps, validate for 71 steps
Epoch 1/6
Epoch 2/6
Epoch 3/6
Epoch 4/6
Epoch 5/6
Epoch 6/6


### 二、内置train_on_batch方法
该内置方法相比较fit方法更加灵活，可以不通过回调函数而直接在批次层次上更加精细地控制训练的过程

In [7]:
def train_model(model, ds_train, ds_valid, epochs):
    for epoch in tf.range(1, epochs+1):
        model.reset_metrics()
        
        # 在后期降低学习率
        if epoch == 5:
            model.optimizer.lr.assign(model.optimizer.lr/2.0)
            tf.print("Lowering optimizer Learning Rate...\n\n")
            
        for x, y in ds_train:
            train_result = model.train_on_batch(x, y)
            
        for x,y in ds_valid:
            valid_result = model.test_on_batch(x, y)
            
        if epoch % 1 == 0:
            printbar()
            tf.print("epoch = ",epoch)
            print("train:", dict(zip(model.metrics_names, train_result)))
            print("valid:", dict(zip(model.metrics_names, valid_result)))
            print("")

In [8]:
train_model(model, ds_train, ds_test, epochs=10)

epoch =  1
train: {'loss': 0.03985308, 'sparse_categorical_accuracy': 1.0, 'sparse_top_k_categorical_accuracy': 1.0}
valid: {'loss': 4.4741173, 'sparse_categorical_accuracy': 0.33333334, 'sparse_top_k_categorical_accuracy': 0.6666667}

epoch =  2
train: {'loss': 0.033269268, 'sparse_categorical_accuracy': 1.0, 'sparse_top_k_categorical_accuracy': 1.0}
valid: {'loss': 4.78131, 'sparse_categorical_accuracy': 0.33333334, 'sparse_top_k_categorical_accuracy': 0.6666667}

epoch =  3
train: {'loss': 0.026025575, 'sparse_categorical_accuracy': 1.0, 'sparse_top_k_categorical_accuracy': 1.0}
valid: {'loss': 5.027471, 'sparse_categorical_accuracy': 0.33333334, 'sparse_top_k_categorical_accuracy': 0.6666667}

epoch =  4
train: {'loss': 0.022150518, 'sparse_categorical_accuracy': 1.0, 'sparse_top_k_categorical_accuracy': 1.0}
valid: {'loss': 5.1505265, 'sparse_categorical_accuracy': 0.33333334, 'sparse_top_k_categorical_accuracy': 0.6666667}

Lowering optimizer Learning Rate...


epoch =  5
train: 

### 三、自定义训练循环
自定义训练循环无需编译模型，直接利用优化器根据损失函数反向传播迭代参数，拥有最高的灵活性。

In [9]:
optimizer = optimizers.Nadam()
loss_func = losses.SparseCategoricalCrossentropy()

train_loss = metrics.Mean(name='train_loss')
train_metric = metrics.SparseCategoricalAccuracy(name='train_accuracy')

valid_loss = metrics.Mean(name='valid_loss')
valid_metric = metrics.SparseCategoricalAccuracy(name='valid_accuracy')

@tf.function
def train_step(model, features, labels):
    with tf.GradientTape() as tape:
        predictions = model(features,training = True)
        loss = loss_func(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    train_loss.update_state(loss)
    train_metric.update_state(labels, predictions)
    
@tf.function
def valid_step(model, features, labels):
    predictions = model(features)
    batch_loss = loss_func(labels, predictions)
    valid_loss.update_state(batch_loss)
    valid_metric.update_state(labels, predictions)
    

def train_model(model,ds_train,ds_valid,epochs):
    for epoch in tf.range(1,epochs+1):
        
        for features, labels in ds_train:
            train_step(model,features,labels)

        for features, labels in ds_valid:
            valid_step(model,features,labels)

        logs = 'Epoch={},Loss:{},Accuracy:{},Valid Loss:{},Valid Accuracy:{}'
        
        if epoch%1 ==0:
            printbar()
            tf.print(tf.strings.format(logs,
            (epoch,train_loss.result(),train_metric.result(),valid_loss.result(),valid_metric.result())))
            tf.print("")
            
        train_loss.reset_states()
        valid_loss.reset_states()
        train_metric.reset_states()
        valid_metric.reset_states()

train_model(model,ds_train,ds_test,10)

Epoch=1,Loss:0.184110194,Accuracy:0.949343145,Valid Loss:3.47809172,Valid Accuracy:0.63089937

Epoch=2,Loss:0.17708835,Accuracy:0.946671128,Valid Loss:3.6346848,Valid Accuracy:0.628673196

Epoch=3,Loss:0.163738102,Accuracy:0.94900912,Valid Loss:3.66780066,Valid Accuracy:0.629118443

Epoch=4,Loss:0.150552884,Accuracy:0.949899793,Valid Loss:3.69140673,Valid Accuracy:0.626447

Epoch=5,Loss:0.143080384,Accuracy:0.951013148,Valid Loss:3.66971564,Valid Accuracy:0.623330355

Epoch=6,Loss:0.135870054,Accuracy:0.952126503,Valid Loss:3.71025276,Valid Accuracy:0.623775601

Epoch=7,Loss:0.127501383,Accuracy:0.955466509,Valid Loss:3.76818109,Valid Accuracy:0.621549428

Epoch=8,Loss:0.120428041,Accuracy:0.956579804,Valid Loss:3.81625462,Valid Accuracy:0.616651833

Epoch=9,Loss:0.116922945,Accuracy:0.958583832,Valid Loss:3.90766406,Valid Accuracy:0.614870906

Epoch=10,Loss:0.110454135,Accuracy:0.959363163,Valid Loss:3.91532326,Valid Accuracy:0.610418499

