Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
Merge remote-tracking branch 'origin/master' into vision
Browse files Browse the repository at this point in the history
  • Loading branch information
dirkgr committed Jul 24, 2020
2 parents 6cc508d + d73f8a9 commit 3137961
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 26 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Removed unnecessary warning about deadlocks in `DataLoader`.
- Use slower tqdm intervals when output is being piped or redirected.
- Fixed testing models that only return a loss when they are in training mode.


## [v1.1.0rc1](https://github.com/allenai/allennlp/releases/tag/v1.1.0rc1) - 2020-07-14
Expand Down Expand Up @@ -44,6 +45,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Adjust beam search to support multi-layer decoder.
- A method to ModelTestCase for running basic model tests when you aren't using config files.
- Added some convenience methods for reading files.
- Added an option to `file_utils.cached_path` to automatically extract archives.
Expand Down
13 changes: 8 additions & 5 deletions allennlp/common/testing/model_test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,17 +166,19 @@ def ensure_model_can_train_save_and_load(
print("Predicting with loaded model")
loaded_model_predictions = loaded_model(**loaded_batch)

# Check loaded model's loss exists and we can compute gradients, for continuing training.
loaded_model_loss = loaded_model_predictions["loss"]
assert loaded_model_loss is not None
loaded_model_loss.backward()

# Both outputs should have the same keys and the values for these keys should be close.
for key in model_predictions.keys():
self.assert_fields_equal(
model_predictions[key], loaded_model_predictions[key], name=key, tolerance=tolerance
)

# Check loaded model's loss exists and we can compute gradients, for continuing training.
loaded_model.train()
loaded_model_predictions = loaded_model(**loaded_batch)
loaded_model_loss = loaded_model_predictions["loss"]
assert loaded_model_loss is not None
loaded_model_loss.backward()

return model, loaded_model

def ensure_model_can_train(
Expand Down Expand Up @@ -276,6 +278,7 @@ def check_model_computes_gradients_correctly(
):
print("Checking gradients")
model.zero_grad()
model.train()

original_dropouts: Dict[str, float] = {}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ def tokens_to_indices(self, tokens: List[Token], vocabulary: Vocabulary) -> Inde
def indices_to_tokens(
self, indexed_tokens: IndexedTokenList, vocabulary: Vocabulary
) -> List[Token]:
self._add_encoding_to_vocabulary_if_needed(vocabulary)

token_ids = indexed_tokens["token_ids"]
type_ids = indexed_tokens.get("type_ids")

Expand Down
76 changes: 55 additions & 21 deletions allennlp/nn/beam_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from allennlp.common.checks import ConfigurationError


StateType = Dict[str, torch.Tensor]
StepFunctionType = Callable[[torch.Tensor, StateType, int], Tuple[torch.Tensor, StateType]]
StepFunctionTypeNoTimestep = Callable[[torch.Tensor, StateType], Tuple[torch.Tensor, StateType]]
Expand All @@ -24,6 +23,7 @@ class BeamSearch:
max_steps : `int`, optional (default = `50`)
The maximum number of decoding steps to take, i.e. the maximum length
of the predicted sequences.
beam_size : `int`, optional (default = `10`)
The width of the beam used.
per_node_beam_size : `int`, optional (default = `beam_size`)
Expand Down Expand Up @@ -196,13 +196,27 @@ def new_step(
for key, state_tensor in state.items():
if state_tensor is None:
continue
_, *last_dims = state_tensor.size()
# shape: (batch_size * beam_size, *)
state[key] = (
state_tensor.unsqueeze(1)
.expand(batch_size, self.beam_size, *last_dims)
.reshape(batch_size * self.beam_size, *last_dims)
)
multilayer_rnn_decoder = state_tensor.dim() == 3 and key in {
"decoder_hidden",
"decoder_context",
}

if multilayer_rnn_decoder:
# shape: (num_layers, batch_size * beam_size, *)
num_layers, _, *last_dims = state_tensor.size()
state[key] = (
state_tensor.unsqueeze(2)
.expand(num_layers, batch_size, self.beam_size, *last_dims)
.reshape(num_layers, batch_size * self.beam_size, *last_dims)
)
else:
# shape: (batch_size * beam_size, *)
_, *last_dims = state_tensor.size()
state[key] = (
state_tensor.unsqueeze(1)
.expand(batch_size, self.beam_size, *last_dims)
.reshape(batch_size * self.beam_size, *last_dims)
)

for timestep in range(self.max_steps - 1):
# shape: (batch_size * beam_size,)
Expand Down Expand Up @@ -284,7 +298,7 @@ def new_step(
# dividing by per_node_beam_size gives the ancestor. (Note that this is integer
# division as the tensor is a LongTensor.)
# shape: (batch_size, beam_size)
backpointer = restricted_beam_indices / self.per_node_beam_size
backpointer = restricted_beam_indices // self.per_node_beam_size

backpointers.append(backpointer)

Expand All @@ -293,18 +307,38 @@ def new_step(
for key, state_tensor in state.items():
if state_tensor is None:
continue
_, *last_dims = state_tensor.size()
# shape: (batch_size, beam_size, *)
expanded_backpointer = backpointer.view(
batch_size, self.beam_size, *([1] * len(last_dims))
).expand(batch_size, self.beam_size, *last_dims)

# shape: (batch_size * beam_size, *)
state[key] = (
state_tensor.reshape(batch_size, self.beam_size, *last_dims)
.gather(1, expanded_backpointer)
.reshape(batch_size * self.beam_size, *last_dims)
)
multilayer_rnn_decoder = state_tensor.dim() == 3 and key in {
"decoder_hidden",
"decoder_context",
}
if multilayer_rnn_decoder:
# shape: (num_layers, batch_size * beam_size, *)
num_layers, _, *last_dims = state_tensor.size()
expanded_backpointer = backpointer.view(
batch_size, self.beam_size, *([1] * len(last_dims))
).expand(batch_size, self.beam_size, *last_dims)
expanded_backpointer = expanded_backpointer.unsqueeze(0).repeat(
num_layers, 1, 1, 1
)
# shape: (num_layers, batch_size * beam_size, *)
state[key] = (
state_tensor.reshape(num_layers, batch_size, self.beam_size, *last_dims)
.gather(2, expanded_backpointer)
.reshape(num_layers, batch_size * self.beam_size, *last_dims)
)
else:
_, *last_dims = state_tensor.size()
# shape: (batch_size, beam_size, *)
expanded_backpointer = backpointer.view(
batch_size, self.beam_size, *([1] * len(last_dims))
).expand(batch_size, self.beam_size, *last_dims)

# shape: (batch_size * beam_size, *)
state[key] = (
state_tensor.reshape(batch_size, self.beam_size, *last_dims)
.gather(1, expanded_backpointer)
.reshape(batch_size * self.beam_size, *last_dims)
)

if not torch.isfinite(last_log_probabilities).all():
warnings.warn(
Expand Down
36 changes: 36 additions & 0 deletions tests/nn/beam_search_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,42 @@ def test_finished_state(self):
for key, array in expected_finished_state.items():
np.testing.assert_allclose(state[key].numpy(), array)

def test_diff_shape_state(self):
    """Check the finished ``decoder_hidden`` state when it has a leading extra dimension.

    The initial state stacks the same per-example rows twice along a new
    first axis, runs the shared ``_check_results`` helper with it, and then
    compares the (in-place updated) state against the expected array.
    """
    num_stacked = 2
    per_example_rows = [[1, 0, 1], [2, 0, 1], [0, 0, 1], [1, 1, 1], [0, 0, 0]]

    # shape: (2, batch_size, 3)
    stacked_hidden = torch.tensor(per_example_rows).unsqueeze(0).repeat(num_stacked, 1, 1)
    state = {"decoder_hidden": stacked_hidden}

    # Every per-example row is expected three times in a row.
    # NOTE(review): presumably three is the beam size configured by the test
    # fixture -- confirm against the class setup.
    repeated_rows = [row for row in per_example_rows for _ in range(3)]

    # shape: (2, batch_size x beam_size, 3)
    expected_finished_state = {
        "decoder_hidden": np.array([repeated_rows] * num_stacked),
    }

    self._check_results(state=state)

    # check finished state.
    for state_key, expected_array in expected_finished_state.items():
        np.testing.assert_allclose(state[state_key].numpy(), expected_array)

def test_batch_size_of_one(self):
self._check_results(batch_size=1)

Expand Down

0 comments on commit 3137961

Please sign in to comment.