## 1.导入依赖以及网络结构

In [72]:
# 加载网络结构
%run AI2Flutter.py

## 2.实例化网络，并设置模型输入形状

In [74]:
# 超参数
num_layers = 4
d_model = 256
dff = 512
num_heads = 8
dropout_rate = 0.1
input_node_dim = 24  # demo中均使用24维向量
target_node_dim = 24  # 

# 权重保存位置
save_weight_path = "./model_weight/model_1"
save_path = "./model/model_1"

transformer = Transformer(
    num_layers=num_layers,
    d_model=d_model,
    num_heads=num_heads,
    dff=dff,
    input_node_dim=input_node_dim,
    target_node_dim=target_node_dim,
    dropout_rate=dropout_rate)

# 优化器采用Adam，学习率自定义
learning_rate = CustomSchedule(d_model)
optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98,
                                     epsilon=1e-9)
transformer.compile(
    loss='mean_squared_error',
    optimizer=optimizer,
)
# 设置模型输入形状
transformer((tf.keras.layers.Input(shape=(None, input_node_dim,)),
             tf.keras.layers.Input(shape=(None, target_node_dim,))))
# 网络概览
transformer.summary()

Model: "transformer_13"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 encoder_13 (Encoder)        multiple                  2642048   
                                                                 
 decoder_13 (Decoder)        multiple                  4753024   
                                                                 
 dense_133 (Dense)           multiple                  3096      
                                                                 
Total params: 7,398,168
Trainable params: 7,398,168
Non-trainable params: 0
_________________________________________________________________


## 3.加载已经训练的权重，方便继续训练

In [3]:
# 加载已训练权重
transformer.load_weights(save_weight_path)

<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus at 0x139ea2c40>

## 4.加载数据集训练网络

In [75]:
%run AI2Flutter_demo_data.py
# 数据规模
train_seqs_num = 200
validation_seqs_num = 30
# 生成随机数据集
input_data, output_data, output_label = demo_generate_data(train_seqs_num)
vali_input_data, vali_output_data, vali_output_label = demo_generate_data(validation_seqs_num)

# 训练集
train_input = tf.data.Dataset.from_generator(
    lambda: input_data, 
    output_signature=(
        tf.TensorSpec(shape=(None, input_node_dim), dtype=tf.float32)))
train_output = tf.data.Dataset.from_generator(
    lambda: output_data, 
    output_signature=(
        tf.TensorSpec(shape=(None, target_node_dim), dtype=tf.float32)))
train_label = tf.data.Dataset.from_generator(
    lambda: output_label, 
    output_signature=(
        tf.TensorSpec(shape=(None, target_node_dim), dtype=tf.float32)))
train_dataset = tf.data.Dataset.zip(((train_input, train_output), train_label))
# batch设置为4
train_dataset = train_dataset.padded_batch(4)

# 验证集
vali_input = tf.data.Dataset.from_generator(
    lambda: vali_input_data, 
    output_signature=(
        tf.TensorSpec(shape=(None, input_node_dim), dtype=tf.float32)))
vali_output = tf.data.Dataset.from_generator(
    lambda: vali_output_data, 
    output_signature=(
        tf.TensorSpec(shape=(None, target_node_dim), dtype=tf.float32)))
vali_label = tf.data.Dataset.from_generator(
    lambda: vali_output_label, 
    output_signature=(
        tf.TensorSpec(shape=(None, target_node_dim), dtype=tf.float32)))
vali_dataset = tf.data.Dataset.zip(((vali_input, vali_output), vali_label))
vali_dataset = vali_dataset.padded_batch(1)

In [76]:
# 训练
transformer.fit(
    x=train_dataset,
    epochs=500,
#     validation_data=train_dataset
)

Epoch 1/500
Epoch 2/500
Epoch 3/500
Epoch 4/500
Epoch 5/500
Epoch 6/500
Epoch 7/500
Epoch 8/500
Epoch 9/500
Epoch 10/500
Epoch 11/500
Epoch 12/500
Epoch 13/500
Epoch 14/500
Epoch 15/500
Epoch 16/500
Epoch 17/500
Epoch 18/500
Epoch 19/500
Epoch 20/500
Epoch 21/500
Epoch 22/500
Epoch 23/500
Epoch 24/500
Epoch 25/500
Epoch 26/500
Epoch 27/500
Epoch 28/500
Epoch 29/500
Epoch 30/500
Epoch 31/500
Epoch 32/500
Epoch 33/500
Epoch 34/500
Epoch 35/500
Epoch 36/500
Epoch 37/500
Epoch 38/500
Epoch 39/500
Epoch 40/500
Epoch 41/500
Epoch 42/500
Epoch 43/500
Epoch 44/500
Epoch 45/500
Epoch 46/500
Epoch 47/500
Epoch 48/500
Epoch 49/500
Epoch 50/500
Epoch 51/500
Epoch 52/500
Epoch 53/500
Epoch 54/500
Epoch 55/500
Epoch 56/500
Epoch 57/500
Epoch 58/500
Epoch 59/500
Epoch 60/500
Epoch 61/500
Epoch 62/500
Epoch 63/500
Epoch 64/500
Epoch 65/500
Epoch 66/500
Epoch 67/500
Epoch 68/500
Epoch 69/500
Epoch 70/500
Epoch 71/500
Epoch 72/500
Epoch 73/500
Epoch 74/500
Epoch 75/500
Epoch 76/500
Epoch 77/500
Epoch 78

Epoch 94/500
Epoch 95/500
Epoch 96/500
Epoch 97/500
Epoch 98/500
Epoch 99/500
Epoch 100/500
Epoch 101/500
Epoch 102/500
Epoch 103/500
Epoch 104/500
Epoch 105/500
Epoch 106/500
Epoch 107/500
Epoch 108/500
Epoch 109/500
Epoch 110/500
Epoch 111/500
Epoch 112/500
Epoch 113/500
Epoch 114/500
Epoch 115/500
Epoch 116/500
Epoch 117/500
Epoch 118/500
Epoch 119/500
Epoch 120/500
Epoch 121/500
Epoch 122/500
Epoch 123/500
Epoch 124/500
Epoch 125/500
Epoch 126/500
Epoch 127/500
Epoch 128/500
Epoch 129/500
Epoch 130/500
Epoch 131/500
Epoch 132/500
Epoch 133/500
Epoch 134/500
Epoch 135/500
Epoch 136/500
Epoch 137/500
Epoch 138/500
Epoch 139/500
Epoch 140/500
Epoch 141/500
Epoch 142/500
Epoch 143/500
Epoch 144/500
Epoch 145/500
Epoch 146/500
Epoch 147/500
Epoch 148/500
Epoch 149/500
Epoch 150/500
Epoch 151/500
Epoch 152/500
Epoch 153/500
Epoch 154/500
Epoch 155/500
Epoch 156/500
Epoch 157/500
Epoch 158/500
Epoch 159/500
Epoch 160/500
Epoch 161/500
Epoch 162/500
Epoch 163/500
Epoch 164/500
Epoch 165/50

Epoch 185/500
Epoch 186/500
Epoch 187/500
Epoch 188/500
Epoch 189/500
Epoch 190/500
Epoch 191/500
Epoch 192/500
Epoch 193/500
Epoch 194/500
Epoch 195/500
Epoch 196/500
Epoch 197/500
Epoch 198/500
Epoch 199/500
Epoch 200/500
Epoch 201/500
Epoch 202/500
Epoch 203/500
Epoch 204/500
Epoch 205/500
Epoch 206/500
Epoch 207/500
Epoch 208/500
Epoch 209/500
Epoch 210/500
Epoch 211/500
Epoch 212/500
Epoch 213/500
Epoch 214/500
Epoch 215/500
Epoch 216/500
Epoch 217/500
Epoch 218/500
Epoch 219/500
Epoch 220/500
Epoch 221/500
Epoch 222/500
Epoch 223/500
Epoch 224/500
Epoch 225/500
Epoch 226/500
Epoch 227/500
Epoch 228/500
Epoch 229/500
Epoch 230/500
Epoch 231/500
Epoch 232/500
Epoch 233/500
Epoch 234/500
Epoch 235/500
Epoch 236/500
Epoch 237/500
Epoch 238/500
Epoch 239/500
Epoch 240/500
Epoch 241/500
Epoch 242/500
Epoch 243/500
Epoch 244/500
Epoch 245/500
Epoch 246/500
Epoch 247/500
Epoch 248/500
Epoch 249/500
Epoch 250/500
Epoch 251/500
Epoch 252/500
Epoch 253/500
Epoch 254/500
Epoch 255/500
Epoch 

Epoch 276/500
Epoch 277/500
Epoch 278/500
Epoch 279/500
Epoch 280/500
Epoch 281/500
Epoch 282/500
Epoch 283/500
Epoch 284/500
Epoch 285/500
Epoch 286/500
Epoch 287/500
Epoch 288/500
Epoch 289/500
Epoch 290/500
Epoch 291/500
Epoch 292/500
Epoch 293/500
Epoch 294/500
Epoch 295/500
Epoch 296/500
Epoch 297/500
Epoch 298/500
Epoch 299/500
Epoch 300/500
Epoch 301/500
Epoch 302/500
Epoch 303/500
Epoch 304/500
Epoch 305/500
Epoch 306/500
Epoch 307/500
Epoch 308/500
Epoch 309/500
Epoch 310/500
Epoch 311/500
Epoch 312/500
Epoch 313/500
Epoch 314/500
Epoch 315/500
Epoch 316/500
Epoch 317/500
Epoch 318/500
Epoch 319/500
Epoch 320/500
Epoch 321/500
Epoch 322/500
Epoch 323/500
Epoch 324/500
Epoch 325/500
Epoch 326/500
Epoch 327/500
Epoch 328/500
Epoch 329/500
Epoch 330/500
Epoch 331/500
Epoch 332/500
Epoch 333/500
Epoch 334/500
Epoch 335/500
Epoch 336/500
Epoch 337/500
Epoch 338/500
Epoch 339/500
Epoch 340/500
Epoch 341/500
Epoch 342/500
Epoch 343/500
Epoch 344/500
Epoch 345/500
Epoch 346/500
Epoch 

Epoch 367/500
Epoch 368/500
Epoch 369/500
Epoch 370/500
Epoch 371/500
Epoch 372/500
Epoch 373/500
Epoch 374/500
Epoch 375/500
Epoch 376/500
Epoch 377/500
Epoch 378/500
Epoch 379/500
Epoch 380/500
Epoch 381/500
Epoch 382/500
Epoch 383/500
Epoch 384/500
Epoch 385/500
Epoch 386/500
Epoch 387/500
Epoch 388/500
Epoch 389/500
Epoch 390/500
Epoch 391/500
Epoch 392/500
Epoch 393/500
Epoch 394/500
Epoch 395/500
Epoch 396/500
Epoch 397/500
Epoch 398/500
Epoch 399/500
Epoch 400/500
Epoch 401/500
Epoch 402/500
Epoch 403/500
Epoch 404/500
Epoch 405/500
Epoch 406/500
Epoch 407/500
Epoch 408/500
Epoch 409/500
Epoch 410/500
Epoch 411/500
Epoch 412/500
Epoch 413/500
Epoch 414/500
Epoch 415/500
Epoch 416/500
Epoch 417/500
Epoch 418/500
Epoch 419/500
Epoch 420/500
Epoch 421/500
Epoch 422/500
Epoch 423/500
Epoch 424/500
Epoch 425/500
Epoch 426/500
Epoch 427/500
Epoch 428/500
Epoch 429/500
Epoch 430/500
Epoch 431/500
Epoch 432/500
Epoch 433/500
Epoch 434/500
Epoch 435/500
Epoch 436/500
Epoch 437/500
Epoch 

Epoch 458/500
Epoch 459/500
Epoch 460/500
Epoch 461/500
Epoch 462/500
Epoch 463/500
Epoch 464/500
Epoch 465/500
Epoch 466/500
Epoch 467/500
Epoch 468/500
Epoch 469/500
Epoch 470/500
Epoch 471/500
Epoch 472/500
Epoch 473/500
Epoch 474/500
Epoch 475/500
Epoch 476/500
Epoch 477/500
Epoch 478/500
Epoch 479/500
Epoch 480/500
Epoch 481/500
Epoch 482/500
Epoch 483/500
Epoch 484/500
Epoch 485/500
Epoch 486/500
Epoch 487/500
Epoch 488/500
Epoch 489/500
Epoch 490/500
Epoch 491/500
Epoch 492/500
Epoch 493/500
Epoch 494/500
Epoch 495/500
Epoch 496/500
Epoch 497/500
Epoch 498/500
Epoch 499/500
Epoch 500/500


<keras.callbacks.History at 0x16b879100>

## 5.使用网络预测

In [79]:
# # 预测，基于原先的全是1的样本，网络似乎学会了把任何数字全部映射为1
# input1 = tf.zeros((1,5,24))
# input2 = tf.constant([[[
#     12526,-52,-97877773352,45,5,6,7,8,9,10,-22099995,-12,13,999,15,167866687,17,18,19,10,0,0,0,0
# ]]])
# # input2 = tf.ones((1, 1, 24))
# # input2 = tf.zeros((1, 1, 24))
# input2 = tf.cast(input2, tf.float32)
# re = transformer((input1, input2), training=False)
# print(input2)
# print(re)
# print(tf.reduce_sum(tf.abs(re - tf.ones((1, 1, 24)))))

input1, input2, output = demo_generate_data(1)
start = [[[-1 for i in range(output_seq_dim)]]]
p1 = transformer((tf.constant(input1), tf.constant(start)), training=False)
print(p1)
print(output[0][0])
print("***************")
d = np.array(p1 - output[0][0])
print(d)
print(np.sum(d**2))

tf.Tensor(
[[[ 1.3431368e+00  1.3543750e+00 -1.9080779e-01  1.2191708e+04
    1.4264363e+04  8.2335645e+03  7.9567627e+03  2.2426299e+03
    2.4998171e+03  4.0127897e+00  7.7040659e+03  8.9238076e+03
    3.3457300e+03  3.1872917e+03  1.4882181e+03  1.9194377e+03
    2.3226509e+01  8.5134983e+00  8.9365416e+00 -1.8620905e-01
   -1.8570238e-01 -1.8455765e-01 -1.8512613e-01 -1.8593711e-01]]], shape=(1, 1, 24), dtype=float32)
[2, 2, 1, 33521, 38596, 32852, 30431, 25163, 28216, 0, 27979, 35797, 100, 101, 109, 111, 0, 25, 29, 0, 0, 0, 0, 0]
***************
[[[-6.5686321e-01 -6.4562500e-01 -1.1908078e+00 -2.1329293e+04
   -2.4331637e+04 -2.4618436e+04 -2.2474238e+04 -2.2920371e+04
   -2.5716184e+04  4.0127897e+00 -2.0274934e+04 -2.6873191e+04
    3.2457300e+03  3.0862917e+03  1.3792181e+03  1.8084377e+03
    2.3226509e+01 -1.6486502e+01 -2.0063457e+01 -1.8620905e-01
   -1.8570238e-01 -1.8455765e-01 -1.8512613e-01 -1.8593711e-01]]]
4503267000.0


## 6.保存模型的权重，方便下一次训练

In [15]:
#保存训练权重
transformer.save_weights(save_weight_path)

## 7.保存整个模型，方便迁移到其他地方

In [16]:
# 直接加载模型
# transformer = tf.saved_model.load("model2")
# 保存模型
tf.saved_model.save(transformer, save_path)



INFO:tensorflow:Assets written to: ./model/model_1/assets


INFO:tensorflow:Assets written to: ./model/model_1/assets
