Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
Allow option to only reset some states in _EncoderBase (#2967)
Browse files Browse the repository at this point in the history
This PR adds the ability for only some states to be reset when `_EncoderBase.reset_states()` is called, which can be useful when not all sequences in a batch terminate. Closes #2828.
  • Loading branch information
rloganiv authored and brendan-ai2 committed Jun 28, 2019
1 parent 15a9cbe commit e71618d
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 2 deletions.
31 changes: 29 additions & 2 deletions allennlp/modules/encoder_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,5 +281,32 @@ def _update_states(self,
# that there are some unused elements (zero-length) for the RNN computation.
self._states = tuple(new_states)

def reset_states(self):
    """Discard all accumulated hidden states so the encoder starts fresh on its next call."""
    self._states = None
def reset_states(self,
                 mask: torch.Tensor = None) -> None:
    """
    Resets the internal states of a stateful encoder.

    Parameters
    ----------
    mask : ``torch.Tensor``, optional.
        A tensor of shape ``(batch_size,)`` with ``1`` at positions whose
        state should be reset (zeroed) and ``0`` at positions whose state
        should be kept. If not provided, all states will be reset.

    Raises
    ------
    ValueError
        If ``mask``'s batch size does not match the batch size of the
        stored states.
    """
    if mask is None:
        self._states = None
    elif self._states is not None:
        # Guard added: previously a masked reset before any states had
        # accumulated crashed with ``TypeError: 'NoneType' object is not
        # iterable``. With no stored states there is nothing to reset,
        # so the mask is a no-op.
        #
        # Each state has shape (num_layers, batch_size, hidden_size). We
        # reshape mask to have shape (1, batch_size, 1) so that operations
        # broadcast properly over the layer and hidden dimensions.
        mask_batch_size = mask.size(0)
        mask = mask.float().view(1, mask_batch_size, 1)
        new_states = []
        for old_state in self._states:
            old_state_batch_size = old_state.size(1)
            if old_state_batch_size != mask_batch_size:
                raise ValueError(f'Trying to reset states using mask with incorrect batch size. '
                                 f'Expected batch size: {old_state_batch_size}. '
                                 f'Provided batch size: {mask_batch_size}.')
            # Zero out the masked positions; detach so the reset does not
            # keep a graph reference into previous batches.
            new_state = (1 - mask) * old_state
            new_states.append(new_state.detach())
        self._states = tuple(new_states)
31 changes: 31 additions & 0 deletions allennlp/tests/modules/encoder_base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,37 @@ def test_update_states(self):
numpy.testing.assert_array_equal(self.encoder_base._states[1][:, 4, :].data.numpy(),
index_selected_initial_states[1][:, 4, :].data.numpy())

def test_reset_states(self):
    """A masked reset zeroes only the selected states; an unmasked reset clears everything."""
    # The encoder starts with no accumulated states.
    assert self.encoder_base._states is None
    initial_states = torch.randn([1, 5, 7]), torch.randn([1, 5, 7])
    index_selected_initial_states = (initial_states[0].index_select(1, self.restoration_indices),
                                     initial_states[1].index_select(1, self.restoration_indices))
    self.encoder_base._update_states(initial_states, self.restoration_indices)

    # A mask with ones in the first two positions should zero exactly those
    # states and leave every other state untouched.
    mask = torch.FloatTensor([1, 1, 0, 0, 0])
    self.encoder_base.reset_states(mask)
    for layer in (0, 1):
        observed = self.encoder_base._states[layer]
        numpy.testing.assert_array_equal(
                observed[:, :2, :].data.numpy(),
                torch.zeros_like(initial_states[layer])[:, :2, :].data.numpy())
        numpy.testing.assert_array_equal(
                observed[:, 2:, :].data.numpy(),
                index_selected_initial_states[layer][:, 2:, :].data.numpy())

    # A mask whose batch size disagrees with the stored states is rejected.
    bad_mask = torch.FloatTensor([1, 1, 0])
    with self.assertRaises(ValueError):
        self.encoder_base.reset_states(bad_mask)

    # Calling with no mask discards all states.
    self.encoder_base.reset_states()
    assert self.encoder_base._states is None

def test_non_contiguous_initial_states_handled(self):
# Check that the encoder is robust to non-contiguous initial states.

Expand Down

0 comments on commit e71618d

Please sign in to comment.