From 76eec0bcb70af69f76190d39984b62bfa3fc5fa1 Mon Sep 17 00:00:00 2001 From: CyberZHG Date: Mon, 3 Jun 2019 19:28:47 +0800 Subject: [PATCH] Add beam search --- README.md | 27 ++++++++++++++++++++++++ README.zh-CN.md | 27 ++++++++++++++++++++++++ keras_transformer/transformer.py | 17 +++++++++++++++- setup.py | 2 +- tests/test_decode.py | 35 ++++++++++++++++++++++++++++++++ 5 files changed, 106 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index eaaa7bc..7896229 100644 --- a/README.md +++ b/README.md @@ -3,6 +3,15 @@ [![Travis](https://travis-ci.org/CyberZHG/keras-transformer.svg)](https://travis-ci.org/CyberZHG/keras-transformer) [![Coverage](https://coveralls.io/repos/github/CyberZHG/keras-transformer/badge.svg?branch=master)](https://coveralls.io/github/CyberZHG/keras-transformer) [![Version](https://img.shields.io/pypi/v/keras-transformer.svg)](https://pypi.org/project/keras-transformer/) +![Downloads](https://img.shields.io/pypi/dm/keras-transformer.svg) +![License](https://img.shields.io/pypi/l/keras-transformer.svg) + +![](https://img.shields.io/badge/keras-tensorflow-blue.svg) +![](https://img.shields.io/badge/keras-theano-blue.svg) +![](https://img.shields.io/badge/keras-tf.keras-blue.svg) +![](https://img.shields.io/badge/keras-tf.keras/eager-blue.svg) + + \[[中文](https://github.com/CyberZHG/keras-transformer/blob/master/README.zh-CN.md)|[English](https://github.com/CyberZHG/keras-transformer/blob/master/README.md)\] Implementation of [transformer](https://arxiv.org/pdf/1706.03762.pdf) for seq2seq tasks. @@ -173,3 +182,21 @@ decoded = decode( print(''.join(map(lambda x: target_token_dict_inv[x], decoded[0][1:-1]))) print(''.join(map(lambda x: target_token_dict_inv[x], decoded[1][1:-1]))) ``` + +### Beam Search + +In `decode`, the word with top probability is selected as the predicted token by default. 
You can enable beam search by setting `top_k` and `temperature`: the next token is then sampled from the `top_k` most probable candidates, a higher temperature makes the candidates more equally likely, and a temperature very close to zero gives the same result as `top_k=1`: + +```python +decoded = decode( + model, + encode_input, + start_token=target_token_dict['<START>'], + end_token=target_token_dict['<END>'], + pad_token=target_token_dict['<PAD>'], + top_k=10, + temperature=1.0, +) +print(''.join(map(lambda x: target_token_dict_inv[x], decoded[0][1:-1]))) +print(''.join(map(lambda x: target_token_dict_inv[x], decoded[1][1:-1]))) +``` diff --git a/README.zh-CN.md b/README.zh-CN.md index af5e57f..fa1ecbd 100644 --- a/README.zh-CN.md +++ b/README.zh-CN.md @@ -3,6 +3,15 @@ [![Travis](https://travis-ci.org/CyberZHG/keras-transformer.svg)](https://travis-ci.org/CyberZHG/keras-transformer) [![Coverage](https://coveralls.io/repos/github/CyberZHG/keras-transformer/badge.svg?branch=master)](https://coveralls.io/github/CyberZHG/keras-transformer) [![Version](https://img.shields.io/pypi/v/keras-transformer.svg)](https://pypi.org/project/keras-transformer/) +![Downloads](https://img.shields.io/pypi/dm/keras-transformer.svg) +![License](https://img.shields.io/pypi/l/keras-transformer.svg) + +![](https://img.shields.io/badge/keras-tensorflow-blue.svg) +![](https://img.shields.io/badge/keras-theano-blue.svg) +![](https://img.shields.io/badge/keras-tf.keras-blue.svg) +![](https://img.shields.io/badge/keras-tf.keras/eager-blue.svg) + + \[[中文](https://github.com/CyberZHG/keras-transformer/blob/master/README.zh-CN.md)|[English](https://github.com/CyberZHG/keras-transformer/blob/master/README.md)\] [Transformer](https://arxiv.org/pdf/1706.03762.pdf)的实现。 @@ -173,3 +182,21 @@ decoded = decode( print(''.join(map(lambda x: target_token_dict_inv[x], decoded[0][1:-1]))) print(''.join(map(lambda x: target_token_dict_inv[x], decoded[1][1:-1]))) ``` + +### 柱搜索 + +默认参数下,`decode`只使用概率最高的词作为结果。通过调整`top_k`和`temperature`可以启用柱搜索,较高的温度会使每个词被选中的概率更为平均,而极为接近零的温度相当于`top_k`为1的结果: + +```python +decoded = decode( + model, + encode_input, + start_token=target_token_dict['<START>'], + end_token=target_token_dict['<END>'], + 
pad_token=target_token_dict['<PAD>'], + top_k=10, + temperature=1.0, +) +print(''.join(map(lambda x: target_token_dict_inv[x], decoded[0][1:-1]))) +print(''.join(map(lambda x: target_token_dict_inv[x], decoded[1][1:-1]))) +``` diff --git a/keras_transformer/transformer.py b/keras_transformer/transformer.py index 2fd1049..6960f1c 100644 --- a/keras_transformer/transformer.py +++ b/keras_transformer/transformer.py @@ -420,6 +420,8 @@ def decode(model, start_token, end_token, pad_token, + top_k=1, + temperature=1.0, max_len=10000, max_repeat=10, max_repeat_block=10): @@ -430,6 +432,8 @@ def decode(model, :param start_token: The token that represents the start of a sentence. :param end_token: The token that represents the end of a sentence. :param pad_token: The token that represents padding. + :param top_k: Choose the last token from top K. + :param temperature: Randomness in boltzmann distribution. :param max_len: Maximum length of decoded list. :param max_repeat: Maximum number of repeating blocks. :param max_repeat_block: Maximum length of the repeating block. 
@@ -457,7 +461,18 @@ def decode(model, batch_inputs[i] += [pad_token] * (max_input_len - len(batch_inputs[i])) predicts = model.predict([np.array(batch_inputs), np.array(batch_outputs)]) for i in range(len(predicts)): - last_token = np.argmax(predicts[i][-1]) + if top_k == 1: + last_token = predicts[i][-1].argmax(axis=-1) + else: + probs = [(prob, i) for i, prob in enumerate(predicts[i][-1])] + probs.sort(reverse=True) + probs = probs[:top_k] + indices, probs = list(map(lambda x: x[1], probs)), list(map(lambda x: x[0], probs)) + probs = np.array(probs) / temperature + probs = probs - np.max(probs) + probs = np.exp(probs) + probs = probs / np.sum(probs) + last_token = np.random.choice(indices, p=probs) decoder_inputs[index_map[i]].append(last_token) if last_token == end_token or\ (max_len is not None and output_len >= max_len) or\ diff --git a/setup.py b/setup.py index 151f766..18ad6e9 100644 --- a/setup.py +++ b/setup.py @@ -12,7 +12,7 @@ setup( name='keras-transformer', - version='0.25.0', + version='0.26.0', packages=find_packages(), url='https://github.com/CyberZHG/keras-transformer', license='MIT', diff --git a/tests/test_decode.py b/tests/test_decode.py index da0842e..3dee6b1 100644 --- a/tests/test_decode.py +++ b/tests/test_decode.py @@ -100,3 +100,38 @@ def test_decode(self): self.assertTrue(len(decoded[i]) <= 4, decoded[i]) for j in range(len(decoded[i])): self.assertEqual(decoder_inputs[i][j], decoded[i][j], decoded) + + decoded_top_5 = decode( + model, + encoder_inputs_no_padding, + start_token=token_dict[''], + end_token=token_dict[''], + pad_token=token_dict[''], + max_len=4, + top_k=5, + temperature=1e-10, + ) + has_diff = False + for i in range(len(decoded)): + s1 = ' '.join(map(lambda x: token_dict_rev[x], decoded[i][1:-1])) + s5 = ' '.join(map(lambda x: token_dict_rev[x], decoded_top_5[i][1:-1])) + if s1 != s5: + has_diff = True + self.assertFalse(has_diff) + + decoded_top_5 = decode( + model, + encoder_inputs_no_padding, + 
start_token=token_dict[''], + end_token=token_dict[''], + pad_token=token_dict[''], + max_len=4, + top_k=5, + ) + has_diff = False + for i in range(len(decoded)): + s1 = ' '.join(map(lambda x: token_dict_rev[x], decoded[i][1:-1])) + s5 = ' '.join(map(lambda x: token_dict_rev[x], decoded_top_5[i][1:-1])) + if s1 != s5: + has_diff = True + self.assertTrue(has_diff)