This repository has been archived by the owner on Mar 3, 2024. It is now read-only.

Commit: Add beam search
CyberZHG committed Jun 3, 2019
1 parent 59e0a39 commit 76eec0b
Showing 5 changed files with 106 additions and 2 deletions.
27 changes: 27 additions & 0 deletions README.md
@@ -3,6 +3,15 @@
[![Travis](https://travis-ci.org/CyberZHG/keras-transformer.svg)](https://travis-ci.org/CyberZHG/keras-transformer)
[![Coverage](https://coveralls.io/repos/github/CyberZHG/keras-transformer/badge.svg?branch=master)](https://coveralls.io/github/CyberZHG/keras-transformer)
[![Version](https://img.shields.io/pypi/v/keras-transformer.svg)](https://pypi.org/project/keras-transformer/)
![Downloads](https://img.shields.io/pypi/dm/keras-transformer.svg)
![License](https://img.shields.io/pypi/l/keras-transformer.svg)

![](https://img.shields.io/badge/keras-tensorflow-blue.svg)
![](https://img.shields.io/badge/keras-theano-blue.svg)
![](https://img.shields.io/badge/keras-tf.keras-blue.svg)
![](https://img.shields.io/badge/keras-tf.keras/eager-blue.svg)

\[[中文](https://github.com/CyberZHG/keras-transformer/blob/master/README.zh-CN.md)|[English](https://github.com/CyberZHG/keras-transformer/blob/master/README.md)\]

Implementation of [transformer](https://arxiv.org/pdf/1706.03762.pdf) for seq2seq tasks.

@@ -173,3 +182,21 @@ decoded = decode(
print(''.join(map(lambda x: target_token_dict_inv[x], decoded[0][1:-1])))
print(''.join(map(lambda x: target_token_dict_inv[x], decoded[1][1:-1])))
```

### Beam Search

By default, `decode` selects the token with the highest probability at each step. Setting `top_k` and `temperature` enables beam search: a higher temperature makes the top candidate tokens more equally likely to be chosen, while a temperature extremely close to zero is equivalent to `top_k=1`:

```python
decoded = decode(
model,
encode_input,
start_token=target_token_dict['<START>'],
end_token=target_token_dict['<END>'],
pad_token=target_token_dict['<PAD>'],
top_k=10,
temperature=1.0,
)
print(''.join(map(lambda x: target_token_dict_inv[x], decoded[0][1:-1])))
print(''.join(map(lambda x: target_token_dict_inv[x], decoded[1][1:-1])))
```
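The interaction between `top_k` and `temperature` can be illustrated with a small, self-contained sketch of the same sampling rule (the helper name `sample_top_k` is ours, not part of the library): the probabilities of the top K tokens are rescaled by the temperature, renormalized with a softmax, and one token is drawn from the result.

```python
# Standalone sketch of top-k temperature sampling (illustrative only;
# `sample_top_k` is a hypothetical helper, not a keras-transformer API).
import numpy as np

def sample_top_k(probs, top_k=1, temperature=1.0, rng=None):
    """Pick a token index from the `top_k` most probable entries of `probs`."""
    probs = np.asarray(probs, dtype=np.float64)
    if top_k == 1:
        return int(probs.argmax())            # greedy decoding
    if rng is None:
        rng = np.random.default_rng()
    indices = np.argsort(probs)[::-1][:top_k]  # top-k candidate token ids
    logits = probs[indices] / temperature      # rescale before softmax
    logits -= logits.max()                     # numerical stability
    weights = np.exp(logits)
    weights /= weights.sum()
    return int(rng.choice(indices, p=weights))

probs = [0.05, 0.6, 0.25, 0.1]
# A temperature close to zero pushes all renormalized weight onto the
# single best candidate, so sampling degenerates to the greedy choice:
assert sample_top_k(probs, top_k=3, temperature=1e-10) == 1
```

This is why a near-zero temperature reproduces the default greedy behaviour even when `top_k` is greater than one.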
27 changes: 27 additions & 0 deletions README.zh-CN.md
@@ -3,6 +3,15 @@
[![Travis](https://travis-ci.org/CyberZHG/keras-transformer.svg)](https://travis-ci.org/CyberZHG/keras-transformer)
[![Coverage](https://coveralls.io/repos/github/CyberZHG/keras-transformer/badge.svg?branch=master)](https://coveralls.io/github/CyberZHG/keras-transformer)
[![Version](https://img.shields.io/pypi/v/keras-transformer.svg)](https://pypi.org/project/keras-transformer/)
![Downloads](https://img.shields.io/pypi/dm/keras-transformer.svg)
![License](https://img.shields.io/pypi/l/keras-transformer.svg)

![](https://img.shields.io/badge/keras-tensorflow-blue.svg)
![](https://img.shields.io/badge/keras-theano-blue.svg)
![](https://img.shields.io/badge/keras-tf.keras-blue.svg)
![](https://img.shields.io/badge/keras-tf.keras/eager-blue.svg)

\[[中文](https://github.com/CyberZHG/keras-transformer/blob/master/README.zh-CN.md)|[English](https://github.com/CyberZHG/keras-transformer/blob/master/README.md)\]

Implementation of the [Transformer](https://arxiv.org/pdf/1706.03762.pdf).

@@ -173,3 +182,21 @@ decoded = decode(
print(''.join(map(lambda x: target_token_dict_inv[x], decoded[0][1:-1])))
print(''.join(map(lambda x: target_token_dict_inv[x], decoded[1][1:-1])))
```

### Beam Search

By default, `decode` selects only the token with the highest probability as the result. Setting `top_k` and `temperature` enables beam search: a higher temperature makes each candidate token's selection probability more uniform, while a temperature extremely close to zero is equivalent to `top_k=1`:

```python
decoded = decode(
model,
encode_input,
start_token=target_token_dict['<START>'],
end_token=target_token_dict['<END>'],
pad_token=target_token_dict['<PAD>'],
top_k=10,
temperature=1.0,
)
print(''.join(map(lambda x: target_token_dict_inv[x], decoded[0][1:-1])))
print(''.join(map(lambda x: target_token_dict_inv[x], decoded[1][1:-1])))
```
17 changes: 16 additions & 1 deletion keras_transformer/transformer.py
@@ -420,6 +420,8 @@ def decode(model,
start_token,
end_token,
pad_token,
top_k=1,
temperature=1.0,
max_len=10000,
max_repeat=10,
max_repeat_block=10):
@@ -430,6 +432,8 @@ def decode(model,
:param start_token: The token that represents the start of a sentence.
:param end_token: The token that represents the end of a sentence.
:param pad_token: The token that represents padding.
:param top_k: Sample the next token from the `top_k` most probable candidates.
:param temperature: Softmax (Boltzmann) temperature; higher values make the sampling distribution more uniform.
:param max_len: Maximum length of decoded list.
:param max_repeat: Maximum number of repeating blocks.
:param max_repeat_block: Maximum length of the repeating block.
@@ -457,7 +461,18 @@
batch_inputs[i] += [pad_token] * (max_input_len - len(batch_inputs[i]))
predicts = model.predict([np.array(batch_inputs), np.array(batch_outputs)])
for i in range(len(predicts)):
-            last_token = np.argmax(predicts[i][-1])
+            if top_k == 1:
+                last_token = predicts[i][-1].argmax(axis=-1)
+            else:
+                probs = [(prob, j) for j, prob in enumerate(predicts[i][-1])]
+                probs.sort(reverse=True)
+                probs = probs[:top_k]
+                indices, probs = [x[1] for x in probs], [x[0] for x in probs]
+                probs = np.array(probs) / temperature
+                probs = probs - np.max(probs)
+                probs = np.exp(probs)
+                probs = probs / np.sum(probs)
+                last_token = np.random.choice(indices, p=probs)
decoder_inputs[index_map[i]].append(last_token)
if last_token == end_token or\
(max_len is not None and output_len >= max_len) or\
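The docstring's claim about `temperature` can be checked in isolation. The sketch below (`topk_distribution` is a hypothetical helper, not part of the commit) reproduces the softmax-over-top-k step and shows that a large temperature flattens the sampling distribution while a small one concentrates it on the best candidate:

```python
# Sketch of how `temperature` reshapes the renormalized top-k distribution
# that the decoder samples from (illustrative helper, not a library API).
import numpy as np

def topk_distribution(probs, top_k, temperature):
    """Return the softmax-renormalized distribution over the top-k probabilities."""
    top = np.sort(np.asarray(probs, dtype=np.float64))[::-1][:top_k]
    scaled = top / temperature
    scaled -= scaled.max()          # numerical stability
    weights = np.exp(scaled)
    return weights / weights.sum()

p = [0.5, 0.3, 0.2]
hot = topk_distribution(p, top_k=3, temperature=100.0)  # nearly uniform
cold = topk_distribution(p, top_k=3, temperature=0.01)  # nearly one-hot
assert hot.max() - hot.min() < 0.01
assert cold[0] > 0.999
```

Note that the commit applies the temperature to probabilities rather than raw logits; the sketch follows that same convention.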
2 changes: 1 addition & 1 deletion setup.py
@@ -12,7 +12,7 @@

setup(
name='keras-transformer',
-    version='0.25.0',
+    version='0.26.0',
packages=find_packages(),
url='https://github.com/CyberZHG/keras-transformer',
license='MIT',
35 changes: 35 additions & 0 deletions tests/test_decode.py
@@ -100,3 +100,38 @@ def test_decode(self):
self.assertTrue(len(decoded[i]) <= 4, decoded[i])
for j in range(len(decoded[i])):
self.assertEqual(decoder_inputs[i][j], decoded[i][j], decoded)

decoded_top_5 = decode(
model,
encoder_inputs_no_padding,
start_token=token_dict['<START>'],
end_token=token_dict['<END>'],
pad_token=token_dict['<PAD>'],
max_len=4,
top_k=5,
temperature=1e-10,
)
has_diff = False
for i in range(len(decoded)):
s1 = ' '.join(map(lambda x: token_dict_rev[x], decoded[i][1:-1]))
s5 = ' '.join(map(lambda x: token_dict_rev[x], decoded_top_5[i][1:-1]))
if s1 != s5:
has_diff = True
self.assertFalse(has_diff)

decoded_top_5 = decode(
model,
encoder_inputs_no_padding,
start_token=token_dict['<START>'],
end_token=token_dict['<END>'],
pad_token=token_dict['<PAD>'],
max_len=4,
top_k=5,
)
has_diff = False
for i in range(len(decoded)):
s1 = ' '.join(map(lambda x: token_dict_rev[x], decoded[i][1:-1]))
s5 = ' '.join(map(lambda x: token_dict_rev[x], decoded_top_5[i][1:-1]))
if s1 != s5:
has_diff = True
self.assertTrue(has_diff)
