Implement repeat and beam_update via map_batch_fn
guillaumekln committed Oct 30, 2018
1 parent 82d8bb0 commit 9fef5b2
Showing 5 changed files with 26 additions and 64 deletions.
14 changes: 4 additions & 10 deletions onmt/decoders/cnn_decoder.py
@@ -137,20 +137,14 @@ def __init__(self, memory_bank, enc_hidden):
self.init_src = (memory_bank + enc_hidden) * SCALE_WEIGHT
self.previous_input = None

@property
def _all(self):
"""
Contains attributes that need to be updated in self.beam_update().
"""
return (self.previous_input,)

def detach(self):
self.previous_input = self.previous_input.detach()

def update_state(self, new_input):
""" Called for every decoder forward pass. """
self.previous_input = new_input

def repeat_beam_size_times(self, beam_size):
""" Repeat beam_size times along batch dimension. """
self.init_src = self.init_src.data.repeat(1, beam_size, 1)
def map_batch_fn(self, fn):
self.init_src = fn(self.init_src, 1)
if self.previous_input is not None:
self.previous_input = fn(self.previous_input, 1)
28 changes: 0 additions & 28 deletions onmt/decoders/decoder.py
@@ -390,23 +390,6 @@ def detach(self):
self.hidden = tuple([_.detach() for _ in self.hidden])
self.input_feed = self.input_feed.detach()

def beam_update(self, idx, positions, beam_size):
""" Need to document this """
for e in self._all:
sizes = e.size()
br = sizes[1]
if len(sizes) == 3:
sent_states = e.view(sizes[0], beam_size, br // beam_size,
sizes[2])[:, :, idx]
else:
sent_states = e.view(sizes[0], beam_size,
br // beam_size,
sizes[2],
sizes[3])[:, :, idx]

sent_states.data.copy_(
sent_states.data.index_select(1, positions))

def map_batch_fn(self, fn):
raise NotImplementedError()

@@ -433,10 +416,6 @@ def __init__(self, hidden_size, rnnstate):
self.input_feed = self.hidden[0].data.new(*h_size).zero_() \
.unsqueeze(0)

@property
def _all(self):
return self.hidden + (self.input_feed,)

def update_state(self, rnnstate, input_feed, coverage):
""" Update decoder state """
if not isinstance(rnnstate, tuple):
@@ -446,13 +425,6 @@ def update_state(self, rnnstate, input_feed, coverage):
self.input_feed = input_feed
self.coverage = coverage

def repeat_beam_size_times(self, beam_size):
""" Repeat beam_size times along batch dimension. """
vars = [e.data.repeat(1, beam_size, 1)
for e in self._all]
self.hidden = tuple(vars[:-1])
self.input_feed = vars[-1]

def map_batch_fn(self, fn):
self.hidden = tuple(map(lambda x: fn(x, 1), self.hidden))
self.input_feed = fn(self.input_feed, 1)
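For the RNN state the mapped tensors are the hidden tuple plus `input_feed`, all with the batch on dimension 1. A rough equivalence check between the deleted `repeat_beam_size_times` and the map-based call; the shapes and the repeat lambda below are made up for illustration:

```python
import torch

beam_size = 4
hidden = (torch.randn(2, 3, 8), torch.randn(2, 3, 8))  # e.g. LSTM (h, c)
input_feed = torch.randn(1, 3, 8)

# Old behaviour: repeat every state tensor beam_size times along dim 1.
old_hidden = tuple(e.repeat(1, beam_size, 1) for e in hidden)
old_feed = input_feed.repeat(1, beam_size, 1)

# New behaviour: the same repeat expressed as a generic fn passed to map_batch_fn.
fn = lambda x, dim: x.repeat(*[beam_size if d == dim else 1 for d in range(x.dim())])
new_hidden = tuple(map(lambda x: fn(x, 1), hidden))
new_feed = fn(input_feed, 1)

assert all(torch.equal(o, n) for o, n in zip(old_hidden, new_hidden))
assert torch.equal(old_feed, new_feed) and new_feed.shape == (1, 12, 8)
```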
9 changes: 2 additions & 7 deletions onmt/decoders/ensemble.py
@@ -20,14 +20,9 @@ class EnsembleDecoderState(DecoderState):
def __init__(self, model_decoder_states):
self.model_decoder_states = tuple(model_decoder_states)

def beam_update(self, idx, positions, beam_size):
def map_batch_fn(self, fn):
for model_state in self.model_decoder_states:
model_state.beam_update(idx, positions, beam_size)

def repeat_beam_size_times(self, beam_size):
""" Repeat beam_size times along batch dimension. """
for model_state in self.model_decoder_states:
model_state.repeat_beam_size_times(beam_size)
model_state.map_batch_fn(fn)

def __getitem__(self, index):
return self.model_decoder_states[index]
21 changes: 4 additions & 17 deletions onmt/decoders/transformer.py
@@ -260,19 +260,6 @@ def __init__(self, src):
self.previous_layer_inputs = None
self.cache = None

@property
def _all(self):
"""
Contains attributes that need to be updated in self.beam_update().
"""
if (self.previous_input is not None
and self.previous_layer_inputs is not None):
return (self.previous_input,
self.previous_layer_inputs,
self.src)
else:
return (self.src,)

def detach(self):
if self.previous_input is not None:
self.previous_input = self.previous_input.detach()
@@ -306,10 +293,6 @@ def _init_cache(self, memory_bank, num_layers, self_attn_type):
layer_cache["self_values"] = None
self.cache["layer_{}".format(l)] = layer_cache

def repeat_beam_size_times(self, beam_size):
""" Repeat beam_size times along batch dimension. """
self.src = self.src.data.repeat(1, beam_size, 1)

def map_batch_fn(self, fn):
def _recursive_map(struct, batch_dim=0):
for k, v in struct.items():
@@ -320,5 +303,9 @@ def _recursive_map(struct, batch_dim=0):
struct[k] = fn(v, batch_dim)

self.src = fn(self.src, 1)
if self.previous_input is not None:
self.previous_input = fn(self.previous_input, 1)
if self.previous_layer_inputs is not None:
self.previous_layer_inputs = fn(self.previous_layer_inputs, 1)
if self.cache is not None:
_recursive_map(self.cache)
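The transformer cache is a dict of per-layer dicts whose leaves are tensors or `None`, so the same `fn` has to be pushed through nested structure. The collapsed hunk above hides the recursion body, so the following is only a standalone sketch of what such a traversal can look like: `fn` is passed explicitly here (in the repository it is a closure from `map_batch_fn`), and the layer names and shapes are invented:

```python
import torch

def _recursive_map(struct, fn, batch_dim=0):
    # Walk a nested dict and apply fn to every tensor leaf in place.
    for k, v in struct.items():
        if v is not None:
            if isinstance(v, dict):
                _recursive_map(v, fn, batch_dim)
            else:
                struct[k] = fn(v, batch_dim)

# Hypothetical 2-layer cache with batch-first leaves (batch_dim=0).
cache = {
    "layer_0": {"memory_keys": torch.zeros(2, 4, 8), "self_keys": None},
    "layer_1": {"memory_keys": torch.zeros(2, 4, 8), "self_keys": None},
}
repeat3 = lambda x, dim: x.repeat(*[3 if d == dim else 1 for d in range(x.dim())])
_recursive_map(cache, repeat3, batch_dim=0)
print(cache["layer_0"]["memory_keys"].shape)  # torch.Size([6, 4, 8])
```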
18 changes: 16 additions & 2 deletions onmt/translate/translator.py
@@ -566,6 +566,11 @@ def bottle(m):
def unbottle(m):
return m.view(beam_size, batch_size, -1)

def _repeat_beam_size_times(x, dim):
repeats = [1] * x.dim()
repeats[dim] = beam_size
return x.repeat(*repeats)

# (1) Run the encoder on the src.
src, enc_states, memory_bank, src_lengths = self._run_encoder(
batch, data_type)
@@ -580,7 +585,7 @@ def unbottle(m):
else:
memory_bank = rvar(memory_bank.data)
memory_lengths = src_lengths.repeat(beam_size)
dec_states.repeat_beam_size_times(beam_size)
dec_states.map_batch_fn(_repeat_beam_size_times)

# (3) run the decoder to generate sentences, using beam search.
for i in range(self.max_length):
@@ -631,10 +636,19 @@ def unbottle(m):
beam_attn = unbottle(attn["copy"])

# (c) Advance each beam.
select_indices_array = []
for j, b in enumerate(beam):
b.advance(out[:, j],
beam_attn.data[:, j, :memory_lengths[j]])
dec_states.beam_update(j, b.get_current_origin(), beam_size)
select_indices_array.append(
b.get_current_origin() * batch_size + j)
select_indices = torch.cat(select_indices_array) \
.view(batch_size, beam_size) \
.transpose(0, 1) \
.contiguous() \
.view(-1)
dec_states.map_batch_fn(
lambda state, dim: state.index_select(dim, select_indices))
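The index arithmetic above assumes the beam-major layout produced by `rvar`/`_repeat_beam_size_times` (flattened batch dimension ordered as `beam_idx * batch_size + batch_idx`), with `get_current_origin()` giving each surviving hypothesis the beam index of its parent. A small self-contained sketch with made-up sizes and origins, showing how `select_indices` is built and then applied through `index_select`:

```python
import torch

# Hypothetical sizes and back-pointers for illustration only.
batch_size, beam_size = 2, 3
origins = [torch.tensor([2, 0, 0]),   # sentence j=0: parent beam of each hypothesis
           torch.tensor([1, 1, 2])]   # sentence j=1

select_indices_array = [origins[j] * batch_size + j for j in range(batch_size)]
select_indices = torch.cat(select_indices_array) \
    .view(batch_size, beam_size) \
    .transpose(0, 1) \
    .contiguous() \
    .view(-1)
print(select_indices)  # tensor([4, 3, 0, 3, 0, 5]) -> parents in beam-major order

# A dummy state with batch dim 1 laid out beam-major: entry b*batch_size+j
# belongs to beam b of sentence j; one index_select reorders every hypothesis.
state = torch.arange(beam_size * batch_size).view(1, -1, 1).float()
reordered = state.index_select(1, select_indices)
```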

# (4) Extract sentences from beam.
ret = self._from_beam(beam)
