In [1]:
%load_ext autoreload
%autoreload 2

import warnings
warnings.filterwarnings("ignore")
import sys
sys.path.append("/home/lcz/lenlp/text-generation/py_project")

## 配置GPU使用

In [2]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = "0"

## 加载数据集

In [3]:
from utils.data_loader import load_dataset, load_test_dataset
from utils.linux_config import embedding_matrix_path
from utils.wv_loader import load_embedding_matrix, load_vocab

root path: /home/lcz/lenlp/text-generation


In [4]:
train_X, train_Y, test_X = load_dataset()

In [5]:
embedding_matrix = load_embedding_matrix("./../data/embedding_matrix")

In [6]:
vocab_path = "./../data/vocab.txt"
vocab, reverse_vocab = load_vocab(vocab_path)

## 配置参数

In [7]:
units = 1024
params = {}
params["vocab_size"] = len(vocab)
params["embed_size"] = 300
params["enc_units"] = units
params["attn_units"] = units
params["dec_units"] = units
params["batch_size"] = 64
params["epochs"] = 2
params["max_enc_len"] = 200
params["max_dec_len"] = 41

## 构建数据集

In [8]:
import tensorflow as tf
steps_per_epoch = len(train_X) // params["batch_size"]
dataset = tf.data.Dataset.from_tensor_slices((train_X, train_Y)).shuffle(params["batch_size"])
dataset = dataset.batch(params["batch_size"], drop_remainder=True)

## 构建模型

In [9]:
from seq2seq.layers import Encoder, BahdanauAttention, Decoder

root path: /home/lcz/lenlp/text-generation


### Encoder

In [10]:
encoder = Encoder(params["vocab_size"], params["embed_size"], embedding_matrix, params["enc_units"], params["batch_size"])
enc_hidden = encoder.initialize_hidden_state()
example_input_batch = tf.ones(shape=(params["batch_size"], params["max_enc_len"]), dtype=tf.int32)
sample_output, sample_hidden = encoder(example_input_batch, enc_hidden)
sample_output.shape
sample_hidden.shape

(64, 200, 300)


TensorShape([64, 1024])

### Decoder

In [11]:
decoder = Decoder(params["vocab_size"], params["embed_size"], embedding_matrix, params["enc_units"], params["batch_size"])
sample_decoder_output, _, _ = decoder(tf.random.uniform((64, 1)), sample_hidden, sample_output)
sample_decoder_output.shape

TensorShape([64, 32909])

## 保存点设置

In [12]:
optimizer = tf.keras.optimizers.Adam(name='Adam',learning_rate=0.001)
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')


pad_index=vocab['<PAD>']
nuk_index=vocab['<UNK>']

def loss_function(real, pred):
    pad_mask = tf.math.equal(real, pad_index)
    nuk_mask = tf.math.equal(real, nuk_index)
    mask = tf.math.logical_not(tf.math.logical_or(pad_mask,nuk_mask))
    
    loss_ = loss_object(real, pred)

    mask = tf.cast(mask, dtype=loss_.dtype)
    loss_ *= mask

    return tf.reduce_mean(loss_)

checkpoint_dir = "./../data/checkpoints/beam_search_training_checkpoints_mask_loss_dim300_seq"
import os
if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(optimizer=optimizer, encoder=encoder, decoder=decoder)

## 训练模型

In [13]:
@tf.function
def train_step(inp, targ, enc_hidden):
    loss = 0
    
    with tf.GradientTape() as tape:
        # 1. 构建encoder
        enc_output, enc_hidden = encoder(inp, enc_hidden)
        # 2. 复制
        dec_hidden = enc_hidden
        # 3. <START> * BATCH_SIZE 
        dec_input = tf.expand_dims([vocab['<START>']] * params["batch_size"], 1)
        
        # Teacher forcing - feeding the target as the next input
        for t in range(1, targ.shape[1]):
            # decoder(x, hidden, enc_output)
            predictions, dec_hidden, _ = decoder(dec_input, dec_hidden, enc_output)
            
            loss += loss_function(targ[:, t], predictions)

            # using teacher forcing
            dec_input = tf.expand_dims(targ[:, t], 1)

        batch_loss = (loss / int(targ.shape[1]))

        variables = encoder.trainable_variables + decoder.trainable_variables

        gradients = tape.gradient(loss, variables)

        optimizer.apply_gradients(zip(gradients, variables))

        return batch_loss

In [14]:
import time
epochs = params["epochs"]
# 如果检查点存在，则恢复最新的检查点。
if checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir)):
    checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
    print ('Latest checkpoint restored!!')
    
for epoch in range(epochs):
    start = time.time()
    total_loss = 0
    enc_hidden = encoder.initialize_hidden_state()

    for (batch, (inp, targ)) in enumerate(dataset.take(steps_per_epoch)):
        batch_loss = train_step(inp, targ, enc_hidden)
        total_loss += batch_loss

        if batch % 1 == 0:
            print('Epoch {} Step {} Loss {:.4f}'.format(epoch + 1,
                                                         batch,
                                                         batch_loss.numpy()))
    # saving (checkpoint) the model every 2 epochs
    if (epoch + 1) % 2 == 0:
        checkpoint.save(file_prefix = checkpoint_prefix)
        print ('Saving checkpoint for epoch {} at {}'.format(epoch+1,
                                                             checkpoint_prefix))

    print('Epoch {} Loss {:.4f}'.format(epoch + 1,
                                      total_loss / steps_per_epoch))
    print('Time taken for 1 epoch {} sec\n'.format(time.time() - start))

Latest checkpoint restored!!
(64, 260, 300)
(64, 260, 300)
Epoch 1 Step 0 Loss 2.4456
Epoch 1 Step 1 Loss 2.1865
Epoch 1 Step 2 Loss 2.2299
Epoch 1 Step 3 Loss 1.8850
Epoch 1 Step 4 Loss 2.0126
Epoch 1 Step 5 Loss 1.7054
Epoch 1 Step 6 Loss 1.3285
Epoch 1 Step 7 Loss 1.3981
Epoch 1 Step 8 Loss 1.1917
Epoch 1 Step 9 Loss 1.9490
Epoch 1 Step 10 Loss 1.8808
Epoch 1 Step 11 Loss 1.7249
Epoch 1 Step 12 Loss 1.4042
Epoch 1 Step 13 Loss 1.5886
Epoch 1 Step 14 Loss 1.9183
Epoch 1 Step 15 Loss 2.1339
Epoch 1 Step 16 Loss 1.2378
Epoch 1 Step 17 Loss 1.1591
Epoch 1 Step 18 Loss 1.0553
Epoch 1 Step 19 Loss 0.9926
Epoch 1 Step 20 Loss 1.2389
Epoch 1 Step 21 Loss 1.3238
Epoch 1 Step 22 Loss 1.2575
Epoch 1 Step 23 Loss 1.8039
Epoch 1 Step 24 Loss 1.6078
Epoch 1 Step 25 Loss 1.2910
Epoch 1 Step 26 Loss 1.7685
Epoch 1 Step 27 Loss 1.9162
Epoch 1 Step 28 Loss 2.2273
Epoch 1 Step 29 Loss 2.0795
Epoch 1 Step 30 Loss 1.6395
Epoch 1 Step 31 Loss 1.6040
Epoch 1 Step 32 Loss 1.6771
Epoch 1 Step 33 Loss 1.4080

Epoch 1 Step 285 Loss 1.2565
Epoch 1 Step 286 Loss 1.2878
Epoch 1 Step 287 Loss 1.4938
Epoch 1 Step 288 Loss 1.8572
Epoch 1 Step 289 Loss 1.6122
Epoch 1 Step 290 Loss 1.6246
Epoch 1 Step 291 Loss 1.5670
Epoch 1 Step 292 Loss 1.5945
Epoch 1 Step 293 Loss 1.4165
Epoch 1 Step 294 Loss 1.3098
Epoch 1 Step 295 Loss 1.7176
Epoch 1 Step 296 Loss 1.7223
Epoch 1 Step 297 Loss 1.6943
Epoch 1 Step 298 Loss 1.6069
Epoch 1 Step 299 Loss 1.3955
Epoch 1 Step 300 Loss 1.6190
Epoch 1 Step 301 Loss 1.8479
Epoch 1 Step 302 Loss 1.5356
Epoch 1 Step 303 Loss 1.8408
Epoch 1 Step 304 Loss 1.6648
Epoch 1 Step 305 Loss 1.4143
Epoch 1 Step 306 Loss 1.7094
Epoch 1 Step 307 Loss 1.5851
Epoch 1 Step 308 Loss 1.5710
Epoch 1 Step 309 Loss 1.6333
Epoch 1 Step 310 Loss 1.4250
Epoch 1 Step 311 Loss 1.3075
Epoch 1 Step 312 Loss 1.3422
Epoch 1 Step 313 Loss 1.3995
Epoch 1 Step 314 Loss 1.7133
Epoch 1 Step 315 Loss 1.8418
Epoch 1 Step 316 Loss 1.5623
Epoch 1 Step 317 Loss 1.7417
Epoch 1 Step 318 Loss 1.5493
Epoch 1 Step 3

Epoch 1 Step 568 Loss 1.5719
Epoch 1 Step 569 Loss 1.8428
Epoch 1 Step 570 Loss 1.8017
Epoch 1 Step 571 Loss 1.3925
Epoch 1 Step 572 Loss 1.2540
Epoch 1 Step 573 Loss 1.3535
Epoch 1 Step 574 Loss 1.4959
Epoch 1 Step 575 Loss 1.6224
Epoch 1 Step 576 Loss 1.6155
Epoch 1 Step 577 Loss 1.7281
Epoch 1 Step 578 Loss 1.7099
Epoch 1 Step 579 Loss 1.5102
Epoch 1 Step 580 Loss 1.6561
Epoch 1 Step 581 Loss 1.3339
Epoch 1 Step 582 Loss 1.6021
Epoch 1 Step 583 Loss 1.6234
Epoch 1 Step 584 Loss 1.5980
Epoch 1 Step 585 Loss 1.1881
Epoch 1 Step 586 Loss 1.3655
Epoch 1 Step 587 Loss 1.0363
Epoch 1 Step 588 Loss 1.3266
Epoch 1 Step 589 Loss 1.3545
Epoch 1 Step 590 Loss 1.3855
Epoch 1 Step 591 Loss 1.4297
Epoch 1 Step 592 Loss 1.5634
Epoch 1 Step 593 Loss 1.9073
Epoch 1 Step 594 Loss 1.6187
Epoch 1 Step 595 Loss 1.7126
Epoch 1 Step 596 Loss 1.7541
Epoch 1 Step 597 Loss 1.7585
Epoch 1 Step 598 Loss 1.4500
Epoch 1 Step 599 Loss 1.1699
Epoch 1 Step 600 Loss 1.6550
Epoch 1 Step 601 Loss 1.4049
Epoch 1 Step 6

Epoch 1 Step 851 Loss 1.5697
Epoch 1 Step 852 Loss 1.4947
Epoch 1 Step 853 Loss 1.6987
Epoch 1 Step 854 Loss 1.8966
Epoch 1 Step 855 Loss 1.7274
Epoch 1 Step 856 Loss 1.7030
Epoch 1 Step 857 Loss 1.5574
Epoch 1 Step 858 Loss 1.4223
Epoch 1 Step 859 Loss 1.6918
Epoch 1 Step 860 Loss 1.5309
Epoch 1 Step 861 Loss 1.6159
Epoch 1 Step 862 Loss 1.7107
Epoch 1 Step 863 Loss 1.4933
Epoch 1 Step 864 Loss 1.6759
Epoch 1 Step 865 Loss 1.7065
Epoch 1 Step 866 Loss 1.2270
Epoch 1 Step 867 Loss 1.1809
Epoch 1 Step 868 Loss 2.0018
Epoch 1 Step 869 Loss 1.4630
Epoch 1 Step 870 Loss 1.6152
Epoch 1 Step 871 Loss 1.3056
Epoch 1 Step 872 Loss 1.4176
Epoch 1 Step 873 Loss 1.4448
Epoch 1 Step 874 Loss 1.6528
Epoch 1 Step 875 Loss 1.4799
Epoch 1 Step 876 Loss 1.8080
Epoch 1 Step 877 Loss 1.6132
Epoch 1 Step 878 Loss 1.7325
Epoch 1 Step 879 Loss 1.4573
Epoch 1 Step 880 Loss 1.6093
Epoch 1 Step 881 Loss 1.6043
Epoch 1 Step 882 Loss 1.4053
Epoch 1 Step 883 Loss 1.6657
Epoch 1 Step 884 Loss 1.9029
Epoch 1 Step 8

Epoch 1 Step 1130 Loss 1.3041
Epoch 1 Step 1131 Loss 1.5006
Epoch 1 Step 1132 Loss 1.2238
Epoch 1 Step 1133 Loss 1.5615
Epoch 1 Step 1134 Loss 1.2758
Epoch 1 Step 1135 Loss 1.4416
Epoch 1 Step 1136 Loss 1.5162
Epoch 1 Step 1137 Loss 1.7932
Epoch 1 Step 1138 Loss 1.5023
Epoch 1 Step 1139 Loss 1.3944
Epoch 1 Step 1140 Loss 1.1727
Epoch 1 Step 1141 Loss 1.3250
Epoch 1 Step 1142 Loss 1.2207
Epoch 1 Step 1143 Loss 1.2596
Epoch 1 Step 1144 Loss 1.6981
Epoch 1 Step 1145 Loss 1.4756
Epoch 1 Step 1146 Loss 1.5538
Epoch 1 Step 1147 Loss 1.8674
Epoch 1 Step 1148 Loss 1.3044
Epoch 1 Step 1149 Loss 1.6143
Epoch 1 Step 1150 Loss 1.2896
Epoch 1 Step 1151 Loss 1.6749
Epoch 1 Step 1152 Loss 1.7667
Epoch 1 Step 1153 Loss 1.4632
Epoch 1 Step 1154 Loss 1.3810
Epoch 1 Step 1155 Loss 1.6846
Epoch 1 Step 1156 Loss 1.5247
Epoch 1 Step 1157 Loss 1.5920
Epoch 1 Step 1158 Loss 1.6724
Epoch 1 Step 1159 Loss 1.7008
Epoch 1 Step 1160 Loss 1.3471
Epoch 1 Step 1161 Loss 1.4235
Epoch 1 Step 1162 Loss 1.5317
Epoch 1 St

Epoch 2 Step 115 Loss 1.2477
Epoch 2 Step 116 Loss 0.9601
Epoch 2 Step 117 Loss 1.0053
Epoch 2 Step 118 Loss 1.2999
Epoch 2 Step 119 Loss 1.4113
Epoch 2 Step 120 Loss 1.4942
Epoch 2 Step 121 Loss 1.6900
Epoch 2 Step 122 Loss 1.2280
Epoch 2 Step 123 Loss 1.2649
Epoch 2 Step 124 Loss 1.3153
Epoch 2 Step 125 Loss 1.3872
Epoch 2 Step 126 Loss 1.9023
Epoch 2 Step 127 Loss 1.2585
Epoch 2 Step 128 Loss 1.4062
Epoch 2 Step 129 Loss 1.2511
Epoch 2 Step 130 Loss 1.0929
Epoch 2 Step 131 Loss 1.2582
Epoch 2 Step 132 Loss 1.1345
Epoch 2 Step 133 Loss 1.0617
Epoch 2 Step 134 Loss 1.7335
Epoch 2 Step 135 Loss 1.8863
Epoch 2 Step 136 Loss 1.7920
Epoch 2 Step 137 Loss 1.3440
Epoch 2 Step 138 Loss 1.3865
Epoch 2 Step 139 Loss 1.3282
Epoch 2 Step 140 Loss 1.4499
Epoch 2 Step 141 Loss 1.5549
Epoch 2 Step 142 Loss 1.4428
Epoch 2 Step 143 Loss 1.2227
Epoch 2 Step 144 Loss 1.2626
Epoch 2 Step 145 Loss 1.1983
Epoch 2 Step 146 Loss 1.2423
Epoch 2 Step 147 Loss 1.1212
Epoch 2 Step 148 Loss 1.3069
Epoch 2 Step 1

Epoch 2 Step 398 Loss 1.0727
Epoch 2 Step 399 Loss 1.2374
Epoch 2 Step 400 Loss 1.2682
Epoch 2 Step 401 Loss 1.1327
Epoch 2 Step 402 Loss 1.1833
Epoch 2 Step 403 Loss 1.5080
Epoch 2 Step 404 Loss 1.5451
Epoch 2 Step 405 Loss 1.6545
Epoch 2 Step 406 Loss 1.3269
Epoch 2 Step 407 Loss 1.5589
Epoch 2 Step 408 Loss 1.5570
Epoch 2 Step 409 Loss 1.9131
Epoch 2 Step 410 Loss 1.5480
Epoch 2 Step 411 Loss 1.2065
Epoch 2 Step 412 Loss 1.3845
Epoch 2 Step 413 Loss 1.6209
Epoch 2 Step 414 Loss 1.4569
Epoch 2 Step 415 Loss 1.5612
Epoch 2 Step 416 Loss 1.2086
Epoch 2 Step 417 Loss 1.2382
Epoch 2 Step 418 Loss 1.2134
Epoch 2 Step 419 Loss 1.0444
Epoch 2 Step 420 Loss 1.3883
Epoch 2 Step 421 Loss 1.3934
Epoch 2 Step 422 Loss 1.3694
Epoch 2 Step 423 Loss 1.3063
Epoch 2 Step 424 Loss 1.2933
Epoch 2 Step 425 Loss 1.4188
Epoch 2 Step 426 Loss 1.3730
Epoch 2 Step 427 Loss 1.3535
Epoch 2 Step 428 Loss 1.1966
Epoch 2 Step 429 Loss 1.2620
Epoch 2 Step 430 Loss 1.4871
Epoch 2 Step 431 Loss 1.5845
Epoch 2 Step 4

Epoch 2 Step 681 Loss 1.5401
Epoch 2 Step 682 Loss 1.1190
Epoch 2 Step 683 Loss 1.3389
Epoch 2 Step 684 Loss 1.2423
Epoch 2 Step 685 Loss 1.4725
Epoch 2 Step 686 Loss 1.4918
Epoch 2 Step 687 Loss 1.1671
Epoch 2 Step 688 Loss 0.8373
Epoch 2 Step 689 Loss 1.2598
Epoch 2 Step 690 Loss 1.4465
Epoch 2 Step 691 Loss 1.2563
Epoch 2 Step 692 Loss 1.4369
Epoch 2 Step 693 Loss 1.3136
Epoch 2 Step 694 Loss 1.3515
Epoch 2 Step 695 Loss 1.3226
Epoch 2 Step 696 Loss 1.2912
Epoch 2 Step 697 Loss 1.3642
Epoch 2 Step 698 Loss 1.3943
Epoch 2 Step 699 Loss 1.1426
Epoch 2 Step 700 Loss 1.0780
Epoch 2 Step 701 Loss 1.2616
Epoch 2 Step 702 Loss 1.3106
Epoch 2 Step 703 Loss 1.7230
Epoch 2 Step 704 Loss 1.4909
Epoch 2 Step 705 Loss 1.2866
Epoch 2 Step 706 Loss 1.1702
Epoch 2 Step 707 Loss 0.9458
Epoch 2 Step 708 Loss 1.1398
Epoch 2 Step 709 Loss 1.3713
Epoch 2 Step 710 Loss 1.2544
Epoch 2 Step 711 Loss 1.4465
Epoch 2 Step 712 Loss 1.2190
Epoch 2 Step 713 Loss 1.2865
Epoch 2 Step 714 Loss 1.2020
Epoch 2 Step 7

Epoch 2 Step 964 Loss 1.3941
Epoch 2 Step 965 Loss 1.3389
Epoch 2 Step 966 Loss 1.4772
Epoch 2 Step 967 Loss 1.5092
Epoch 2 Step 968 Loss 1.5010
Epoch 2 Step 969 Loss 1.5444
Epoch 2 Step 970 Loss 1.5623
Epoch 2 Step 971 Loss 1.5168
Epoch 2 Step 972 Loss 1.5240
Epoch 2 Step 973 Loss 1.6286
Epoch 2 Step 974 Loss 1.3423
Epoch 2 Step 975 Loss 1.3782
Epoch 2 Step 976 Loss 1.2018
Epoch 2 Step 977 Loss 1.8924
Epoch 2 Step 978 Loss 1.6991
Epoch 2 Step 979 Loss 1.3151
Epoch 2 Step 980 Loss 1.5550
Epoch 2 Step 981 Loss 1.5112
Epoch 2 Step 982 Loss 1.2432
Epoch 2 Step 983 Loss 1.3353
Epoch 2 Step 984 Loss 1.7030
Epoch 2 Step 985 Loss 1.7976
Epoch 2 Step 986 Loss 1.2433
Epoch 2 Step 987 Loss 1.4414
Epoch 2 Step 988 Loss 1.6484
Epoch 2 Step 989 Loss 1.1843
Epoch 2 Step 990 Loss 1.3600
Epoch 2 Step 991 Loss 1.2062
Epoch 2 Step 992 Loss 1.6143
Epoch 2 Step 993 Loss 1.3164
Epoch 2 Step 994 Loss 1.2919
Epoch 2 Step 995 Loss 1.1616
Epoch 2 Step 996 Loss 1.4395
Epoch 2 Step 997 Loss 1.2829
Epoch 2 Step 9

Epoch 2 Step 1239 Loss 0.9701
Epoch 2 Step 1240 Loss 1.2955
Epoch 2 Step 1241 Loss 1.1022
Epoch 2 Step 1242 Loss 1.0581
Epoch 2 Step 1243 Loss 1.5035
Epoch 2 Step 1244 Loss 1.4444
Epoch 2 Step 1245 Loss 1.0435
Epoch 2 Step 1246 Loss 1.2159
Epoch 2 Step 1247 Loss 1.3651
Epoch 2 Step 1248 Loss 1.2752
Epoch 2 Step 1249 Loss 1.0896
Epoch 2 Step 1250 Loss 1.0371
Epoch 2 Step 1251 Loss 1.1840
Epoch 2 Step 1252 Loss 1.1478
Epoch 2 Step 1253 Loss 1.1368
Epoch 2 Step 1254 Loss 1.2614
Epoch 2 Step 1255 Loss 1.3705
Epoch 2 Step 1256 Loss 1.0719
Epoch 2 Step 1257 Loss 1.1358
Epoch 2 Step 1258 Loss 1.1116
Epoch 2 Step 1259 Loss 1.3290
Epoch 2 Step 1260 Loss 1.4149
Epoch 2 Step 1261 Loss 1.3824
Epoch 2 Step 1262 Loss 1.6379
Epoch 2 Step 1263 Loss 1.5932
Epoch 2 Step 1264 Loss 1.3232
Epoch 2 Step 1265 Loss 1.2185
Epoch 2 Step 1266 Loss 1.3543
Epoch 2 Step 1267 Loss 1.1244
Epoch 2 Step 1268 Loss 1.1191
Epoch 2 Step 1269 Loss 1.0153
Epoch 2 Step 1270 Loss 1.1007
Epoch 2 Step 1271 Loss 1.1421
Epoch 2 St

## beam search

In [None]:
class Hypothesis:
    """ Class designed to hold hypothesises throughout the beamSearch decoding """

    def __init__(self, tokens, log_probs, hidden, attn_dists):
        self.tokens = tokens  # list of all the tokens from time 0 to the current time step t
        self.log_probs = log_probs  # list of the log probabilities of the tokens of the tokens
        self.hidden = hidden  # decoder hidden state after the last token decoding
        self.attn_dists = attn_dists  # attention dists of all the tokens
        self.abstract = ""

    def extend(self, token, log_prob, hidden, attn_dist):
        """Method to extend the current hypothesis by adding the next decoded token and all the informations associated with it"""
        return Hypothesis(tokens=self.tokens + [token],  # we add the decoded token
                          log_probs=self.log_probs + [log_prob],  # we add the log prob of the decoded token
                          hidden=hidden,  # we update the state
                          attn_dists=self.attn_dists + [attn_dist])

    @property
    def latest_token(self):
        return self.tokens[-1]

    @property
    def tot_log_prob(self):
        return sum(self.log_probs)

    @property
    def avg_log_prob(self):
        return self.tot_log_prob / len(self.tokens)


def beam_decode(model, batch, vocab, params):
    # 初始化mask
    start_index = vocab.word_to_id(vocab.START_DECODING)
    stop_index = vocab.word_to_id(vocab.STOP_DECODING)
    unk_index = vocab.word_to_id(vocab.UNKNOWN_TOKEN)

    batch_size = params['batch_size']

    # 单步decoder
    def decoder_onestep(enc_output, dec_input, dec_hidden, enc_extended_inp, batch_oov_len):
        # 单个时间步 运行
        # dec_input, dec_hidden, enc_output, enc_extended_inp, batch_oov_len
        final_preds, dec_hidden, context_vector, attention_weights, p_gens = model.call_decoder_onestep(dec_input,
                                                                                                        dec_hidden,
                                                                                                        enc_output,
                                                                                                        enc_extended_inp,
                                                                                                        batch_oov_len)
        # 拿到top k个index 和 概率
        top_k_probs, top_k_ids = tf.nn.top_k(tf.squeeze(final_preds), k=params["beam_size"] * 2)
        # 计算log概率
        top_k_log_probs = tf.math.log(top_k_probs)

        results = {
            # 'final_dists': preds,
            "last_context_vector": context_vector,
            "dec_hidden": dec_hidden,
            "attention_weights": attention_weights,
            "top_k_ids": top_k_ids,
            "top_k_log_probs": top_k_log_probs,
            "p_gen": p_gens}

        # 返回需要保存的中间结果和概率
        return results

    # 测试数据的输入
    enc_input = batch[0]["enc_input"]

    # 计算第encoder的输出
    enc_output, enc_hidden = model.call_encoder(enc_input)

    # 初始化batch size个 假设对象
    hyps = [Hypothesis(tokens=[start_index],
                       log_probs=[0.0],
                       hidden=enc_hidden[0],
                       attn_dists=[],
                       ) for _ in range(batch_size)]
    # 初始化结果集
    results = []  # list to hold the top beam_size hypothesises
    # 遍历步数
    steps = 0  # initial step

    enc_extended_inp = batch[0]["extended_enc_input"]
    batch_oov_len = batch[0]["max_oov_len"]

    # 长度还不够 并且 结果还不够 继续搜索
    while steps < params['max_dec_len'] and len(results) < params['beam_size']:
        # 获取最新待使用的token
        latest_tokens = [h.latest_token for h in hyps]
        # 替换掉 oov token unknown token
        latest_tokens = [t if t in vocab.id2word else unk_index for t in latest_tokens]
        # 获取所以隐藏层状态
        hiddens = [h.hidden for h in hyps]

        dec_input = tf.expand_dims(latest_tokens, axis=1)
        dec_hidden = tf.stack(hiddens, axis=0)

        # 单步运行decoder 计算需要的值
        decoder_results = decoder_onestep(enc_output,
                                          dec_input,
                                          dec_hidden,
                                          enc_extended_inp,
                                          batch_oov_len)

        # preds = decoder_results['final_dists']
        # context_vector = decoder_results['last_context_vector']

        dec_hidden = decoder_results['dec_hidden']
        attention_weights = decoder_results['attention_weights']
        top_k_log_probs = decoder_results['top_k_log_probs']
        top_k_ids = decoder_results['top_k_ids']

        # print('top_k_ids {}'.format(top_k_ids))

        # 现阶段全部可能情况
        all_hyps = []
        # 原有的可能情况数量
        num_orig_hyps = 1 if steps == 0 else len(hyps)

        # 遍历添加所有可能结果
        for i in range(num_orig_hyps):
            h, new_hidden, attn_dist = hyps[i], dec_hidden[i], attention_weights[i]
            # 分裂 添加 beam size 种可能性
            for j in range(params['beam_size'] * 2):
                # 构造可能的情况
                new_hyp = h.extend(token=top_k_ids[i, j].numpy(),
                                   log_prob=top_k_log_probs[i, j],
                                   hidden=new_hidden,
                                   attn_dist=attn_dist)
                # 添加可能情况
                all_hyps.append(new_hyp)

        # 重置
        hyps = []
        # 按照概率来排序
        sorted_hyps = sorted(all_hyps, key=lambda h: h.avg_log_prob, reverse=True)

        # 筛选top前beam_size句话
        for h in sorted_hyps:
            if h.latest_token == stop_index:
                # 长度符合预期,遇到句尾,添加到结果集
                if steps >= params['min_dec_steps']:
                    results.append(h)
            else:
                # 未到结束 ,添加到假设集
                hyps.append(h)

            # 如果假设句子正好等于beam_size 或者结果集正好等于beam_size 就不在添加
            if len(hyps) == params['beam_size'] or len(results) == params['beam_size']:
                break

        steps += 1

    if len(results) == 0:
        results = hyps

    hyps_sorted = sorted(results, key=lambda h: h.avg_log_prob, reverse=True)
    best_hyp = hyps_sorted[0]
    best_hyp.abstract = " ".join([vocab.id_to_word(index) for index in best_hyp.tokens])
    best_hyp.text = batch[0]["article"].numpy()[0].decode()
    
    return best_hyp