Adapt to fill_constant_batch_size_like and align outputs between fast_infer and the original Python infer in Transformer
guoshengCS committed Jun 14, 2018
1 parent 3e9fcce commit 0712608
Showing 1 changed file with 25 additions and 14 deletions.
fluid/neural_machine_translation/transformer/model.py
@@ -115,6 +115,12 @@ def scaled_dot_product_attention(q, k, v, attn_bias, d_model, dropout_rate):
"""
scaled_q = layers.scale(x=q, scale=d_model**-0.5)
product = layers.matmul(x=scaled_q, y=k, transpose_y=True)
# global FLAG
# if FLAG and attn_bias:
# print "hehehehehe"
# layers.Print(product, message="product")
# layers.Print(attn_bias, message="bias")
# FLAG = False
weights = layers.reshape(
x=layers.elementwise_add(
x=product, y=attn_bias) if attn_bias else product,
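For orientation, here is a minimal NumPy sketch of the attention computation this hunk touches; it is an editor's illustration, not the Fluid graph (dropout and the reshape are omitted, and shapes are assumed to be [batch, len, d]):

```python
import numpy as np

def scaled_dot_product_attention(q, k, v, d_model, attn_bias=None):
    # Scale queries by d_model**-0.5 before the dot product, matching
    # layers.scale(x=q, scale=d_model**-0.5) above.
    product = (q * d_model ** -0.5) @ k.transpose(0, 2, 1)
    if attn_bias is not None:
        # attn_bias is an additive mask, e.g. large negatives on padding.
        product = product + attn_bias
    exp = np.exp(product - product.max(axis=-1, keepdims=True))
    weights = exp / exp.sum(axis=-1, keepdims=True)  # row-wise softmax
    return weights @ v
```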
@@ -598,7 +604,7 @@ def wrap_decoder(trg_vocab_size,
bias_attr=False,
num_flatten_dims=2),
shape=[-1, trg_vocab_size],
act="softmax") # if dec_inputs is None else None)
act="softmax" if dec_inputs is None else None)
return predict
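The one-line fix above re-enables a conditional that had been commented out: the output projection applies softmax only when wrap_decoder builds its own decoder inputs (dec_inputs is None); when the beam-search caller supplies dec_inputs, it presumably gets raw logits back and normalizes them itself just before topk, as the final hunk below suggests. A runnable sketch of the pattern, with a toy matrix standing in for the fc projection (names here are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.standard_normal((8, 16))  # toy projection: d_model=8, vocab=16

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def project(dec_output, dec_inputs=None):
    logits = dec_output @ W
    # Normalize only on the standalone path; the beam-search caller passes
    # dec_inputs and applies softmax itself, right before topk.
    return softmax(logits) if dec_inputs is None else logits
```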


@@ -656,27 +662,31 @@ def beam_search():
with while_op.block():
pre_ids = layers.array_read(array=ids, i=step_idx)
pre_scores = layers.array_read(array=scores, i=step_idx)
pre_pos = layers.elementwise_mul(
x=layers.fill_constant_batch_size_like(
input=pre_ids, value=1, shape=[-1, 1], dtype=pre_ids.dtype),
y=layers.increment(
x=step_idx, value=1.0, in_place=False),
axis=0)
pre_src_attn_bias = layers.sequence_expand(
x=trg_src_attn_bias, y=pre_scores)
# layers.Print(pre_src_attn_bias)
pre_enc_output = layers.sequence_expand(x=enc_output, y=pre_scores)
pre_caches = [{
"k": layers.sequence_expand(
x=cache["k"], y=pre_scores),
"v": layers.sequence_expand(
x=cache["v"], y=pre_scores),
} for cache in caches]
# layers.Print(pre_ids)
# layers.Print(pre_pos)
# layers.Print(pre_enc_output)
# layers.Print(pre_src_attn_bias)
# layers.Print(pre_caches[0]["k"])
# layers.Print(pre_caches[0]["v"])
pre_pos = layers.elementwise_mul(
x=layers.fill_constant_batch_size_like(
input=pre_enc_output,  # can't use pre_ids here since it has LoD
value=1,
shape=[-1, 1],
dtype=pre_ids.dtype),
y=layers.increment(
x=step_idx, value=1.0, in_place=False),
axis=0)
# layers.Print(pre_ids, summarize=10)
# layers.Print(pre_pos, summarize=10)
# layers.Print(pre_enc_output, summarize=10)
# layers.Print(pre_src_attn_bias, summarize=10)
# layers.Print(pre_caches[0]["k"], summarize=10)
# layers.Print(pre_caches[0]["v"], summarize=10)
# layers.Print(slf_attn_post_softmax_shape)
logits = wrap_decoder(
trg_vocab_size,
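The substantive change in this hunk moves the pre_pos computation after the sequence_expand calls and feeds fill_constant_batch_size_like from pre_enc_output instead of pre_ids, since pre_ids carries LoD that the op cannot take its batch size from (per the code comment). What the computation produces is simply the next position id broadcast over the batch; a NumPy sketch, ignoring LoD entirely (an editor's illustration):

```python
import numpy as np

def make_pre_pos(batch_size: int, step_idx: int) -> np.ndarray:
    # fill_constant_batch_size_like(..., value=1, shape=[-1, 1]) yields a
    # [batch_size, 1] tensor of ones, batch dim inferred from the input.
    ones = np.ones((batch_size, 1), dtype=np.int64)
    # elementwise_mul with increment(step_idx, 1) broadcasts the next
    # position index over the batch.
    return ones * (step_idx + 1)

print(make_pre_pos(4, 0))  # [[1] [1] [1] [1]] -> position ids at step 1
```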
@@ -695,7 +705,8 @@
enc_output=pre_enc_output,
caches=pre_caches)
# layers.Print(logits)
topk_scores, topk_indices = layers.topk(logits, k=beam_size)
topk_scores, topk_indices = layers.topk(
input=layers.softmax(logits), k=beam_size)
# layers.Print(topk_scores)
# layers.Print(topk_indices)
accu_scores = layers.elementwise_add(
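The final hunk (truncated above just as accu_scores begins accumulating) normalizes the logits before topk so the per-step scores are probabilities that line up with the original Python inference. Softmax is monotonic, so the selected indices are unchanged; only the score values differ. A NumPy sketch of the selection (an editor's illustration):

```python
import numpy as np

def topk_over_softmax(logits: np.ndarray, beam_size: int):
    # Softmax first, as layers.softmax(logits) in the updated code; the
    # ranking matches topk over raw logits, but the scores are now
    # normalized probabilities suitable for accumulation with pre_scores.
    e = np.exp(logits - logits.max(axis=-1, keepdims=True))
    probs = e / e.sum(axis=-1, keepdims=True)
    topk_indices = np.argsort(-probs, axis=-1)[..., :beam_size]
    topk_scores = np.take_along_axis(probs, topk_indices, axis=-1)
    return topk_scores, topk_indices
```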
