From 6d74b95590af81f9fda5de1c88f1ad9ee4ea0487 Mon Sep 17 00:00:00 2001
From: HandH1998 <44199326+HandH1998@users.noreply.github.com>
Date: Sun, 7 Mar 2021 18:30:38 +0800
Subject: [PATCH] Update transformer.py

fix position embedding
---
 transformer.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/transformer.py b/transformer.py
index aae3d58..c156278 100644
--- a/transformer.py
+++ b/transformer.py
@@ -127,7 +127,8 @@ class PositionEmbedding(keras.layers.Layer):
     def __init__(self, max_len, model_dim, n_vocab):
         super().__init__()
         pos = np.arange(max_len)[:, None]
-        pe = pos / np.power(10000, 2. * np.arange(model_dim)[None, :] / model_dim)  # [max_len, dim]
+        # pe = pos / np.power(10000, 2. * np.arange(model_dim)[None, :] / model_dim)  # [max_len, dim]
+        pe = pos / np.power(10000, 2. * (np.arange(model_dim)[None, :] // 2) / model_dim)  # [max_len, dim]  # this is what the position embedding formula actually requires
         pe[:, 0::2] = np.sin(pe[:, 0::2])
         pe[:, 1::2] = np.cos(pe[:, 1::2])
         pe = pe[None, :, :]  # [1, max_len, model_dim] for batch adding
@@ -268,4 +269,4 @@ def export_attention(model, data, name="transformer"):
 
     m = Transformer(MODEL_DIM, MAX_LEN, N_LAYER, N_HEAD, d.num_word, DROP_RATE)
     train(m, d, step=800)
-    export_attention(m, d)
\ No newline at end of file
+    export_attention(m, d)
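
Note (not part of the patch): below is a minimal standalone sketch of what the corrected line computes, assuming only numpy; the function name sinusoid_position_embedding is made up for illustration. The point of the fix is the pair index i // 2 in the exponent, so each sin/cos pair shares one frequency, as in PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)); the original line used i directly, which gives every dimension its own frequency and breaks the sin/cos pairing.

# Standalone sketch of the corrected sinusoidal position embedding (numpy only).
import numpy as np

def sinusoid_position_embedding(max_len, model_dim):
    pos = np.arange(max_len)[:, None]          # [max_len, 1]
    i = np.arange(model_dim)[None, :]          # [1, model_dim]
    # i // 2 makes each (sin, cos) pair share one frequency, so the
    # exponent is 2*(i//2)/model_dim as in the PE(pos, 2i) / PE(pos, 2i+1) formulas.
    angle = pos / np.power(10000, 2. * (i // 2) / model_dim)  # [max_len, model_dim]
    pe = np.empty_like(angle)
    pe[:, 0::2] = np.sin(angle[:, 0::2])       # even dimensions -> sin
    pe[:, 1::2] = np.cos(angle[:, 1::2])       # odd dimensions  -> cos
    return pe[None, :, :]                      # [1, max_len, model_dim] for batch adding

if __name__ == "__main__":
    print(sinusoid_position_embedding(max_len=12, model_dim=32).shape)  # (1, 12, 32)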