Skip to content

Commit

Permalink
Merge pull request #24 from yaolu/patch-1
Browse files Browse the repository at this point in the history
12 times speedup for inference
  • Loading branch information
Alex Fabbri committed May 12, 2020
2 parents 0066d99 + 1e7f4c8 commit f6476d1
Showing 1 changed file with 33 additions and 19 deletions.
52 changes: 33 additions & 19 deletions code/Hi_MAP/onmt/translate/beam.py
Expand Up @@ -34,11 +34,13 @@ def __init__(self, size, pad, bos, eos,

# The backpointers at each time-step.
self.prev_ks = []
self.prev_ks_cpu = []

# The outputs at each time-step.
self.next_ys = [self.tt.LongTensor(size)
.fill_(pad)]
self.next_ys[0][0] = bos
self.next_ys_cpu = [elem.tolist() for elem in self.next_ys]

# Has EOS topped the beam yet.
self._eos = eos
Expand Down Expand Up @@ -89,30 +91,26 @@ def advance(self, word_probs, attn_out):
# force the output to be longer than self.min_length
cur_len = len(self.next_ys)
if cur_len < self.min_length:
for k in range(len(word_probs)):
word_probs[k][self._eos] = -1e20
word_probs[:, self._eos] = -1e20
# Sum the previous scores.
if len(self.prev_ks) > 0:
beam_scores = word_probs + \
self.scores.unsqueeze(1).expand_as(word_probs)
# Don't let EOS have children.
for i in range(self.next_ys[-1].size(0)):
if self.next_ys[-1][i] == self._eos:
beam_scores[i] = -1e20

# Don't let EOS have children.
beam_scores[self.next_ys[-1] == self._eos] = -1e20
# Block ngram repeats
if self.block_ngram_repeat > 0:
ngrams = []
le = len(self.next_ys)
for j in range(self.next_ys[-1].size(0)):
hyp, _ = self.get_hyp(le - 1, j)
hyp, _ = self.get_hyp(le - 1, j, requires_attn=False)
ngrams = set()
fail = False
gram = []
for i in range(le - 1):
# Last n tokens, n = block_ngram_repeat
gram = (gram +
[hyp[i].item()])[-self.block_ngram_repeat:]
[hyp[i]])[-self.block_ngram_repeat:]
# Skip the blocking if it is in the exclusion list
if set(gram) & self.exclusion_tokens:
continue
Expand All @@ -134,15 +132,27 @@ def advance(self, word_probs, attn_out):
# word and beam each score came from
prev_k = best_scores_id / num_words
self.prev_ks.append(prev_k)
self.prev_ks_cpu.append(prev_k.tolist())

self.next_ys.append((best_scores_id - prev_k * num_words))
self.next_ys_cpu.append((best_scores_id - prev_k * num_words).tolist())

self.attn.append(attn_out.index_select(0, prev_k))
self.global_scorer.update_global_state(self)

for i in range(self.next_ys[-1].size(0)):
if self.next_ys[-1][i] == self._eos:
global_scores = self.global_scorer.score(self, self.scores)
s = global_scores[i]

eos_indicator = self.next_ys[-1] == self._eos
if eos_indicator.any():
global_scores = self.global_scorer.score(self, self.scores)
global_scores_eos = global_scores[eos_indicator]
i_indexes = torch.where(eos_indicator)[0]
for s, i, in zip(global_scores_eos.tolist(), i_indexes.tolist()):
self.finished.append((s, len(self.next_ys) - 1, i))

# for i in range(self.next_ys[-1].size(0)):
# if self.next_ys[-1][i] == self._eos:
# global_scores = self.global_scorer.score(self, self.scores)
# s = global_scores[i]
# self.finished.append((s, len(self.next_ys) - 1, i))

# End condition is when top-of-beam is EOS and no global score.
if self.next_ys[-1][0] == self._eos:
Expand All @@ -167,16 +177,20 @@ def sort_finished(self, minimum=None):
ks = [(t, k) for _, t, k in self.finished]
return scores, ks

def get_hyp(self, timestep, k):
def get_hyp(self, timestep, k, requires_attn=True):
"""
Walk back to construct the full hypothesis.
"""
hyp, attn = [], []
for j in range(len(self.prev_ks[:timestep]) - 1, -1, -1):
hyp.append(self.next_ys[j + 1][k])
attn.append(self.attn[j][k])
k = self.prev_ks[j][k]
return hyp[::-1], torch.stack(attn[::-1])
hyp.append(self.next_ys_cpu[j + 1][k])
if requires_attn:
attn.append(self.attn[j][k])
k = self.prev_ks_cpu[j][k]

if requires_attn:
attn = torch.stack(attn[::-1])
return hyp[::-1], attn


class GNMTGlobalScorer(object):
Expand Down

0 comments on commit f6476d1

Please sign in to comment.