Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
Fix out-of-bounds checking in BidirectionalEndpointSpanExtractor for empty sequences. (#2763)
Browse files Browse the repository at this point in the history
  • Loading branch information
Waleed Ammar authored and brendan-ai2 committed May 3, 2019
1 parent 9fc6f9e commit a701a0a
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def forward(self,
sequence_tensor.size(1))

# shape (batch_size, num_spans, 1)
end_sentinel_mask = (exclusive_span_ends == sequence_lengths.unsqueeze(-1)).long().unsqueeze(-1)
end_sentinel_mask = (exclusive_span_ends >= sequence_lengths.unsqueeze(-1)).long().unsqueeze(-1)

# As we added 1 to the span_ends to make them exclusive, which might have caused indices
# equal to the sequence_length to become out of bounds, we multiply by the inverse of the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,25 @@ def test_correct_sequence_elements_are_embedded_with_a_masked_sequence(self):
numpy.testing.assert_array_equal(backward_start_embeddings.data.numpy(),
correct_backward_start_embeddings.data.numpy())


def test_forward_doesnt_raise_with_empty_sequence(self):
    """Regression test: a fully-masked (empty) sequence with sentinel span
    indices of -1 must not trigger an out-of-bounds index error, and the
    extracted span representation must come out all zeros."""
    # One batch entry, two timesteps, two embedding dims — all zeros.
    embeddings = torch.FloatTensor([[[0., 0.], [0., 0.]]])
    # Every timestep is masked out, i.e. the sequence is empty.
    mask = torch.LongTensor([[0, 0]])
    # A single span whose endpoints are the -1 sentinel values.
    spans = torch.LongTensor([[[-1, -1]]])
    # The span itself is masked as well.
    spans_mask = torch.LongTensor([[0]])

    extractor = BidirectionalEndpointSpanExtractor(input_dim=2,
                                                   forward_combination="x,y",
                                                   backward_combination="x,y")
    representations = extractor(embeddings, spans,
                                sequence_mask=mask,
                                span_indices_mask=spans_mask)

    expected = torch.FloatTensor([[[0., 0., 0., 0.]]])
    numpy.testing.assert_array_equal(representations.detach(), expected)

def test_forward_raises_with_invalid_indices(self):
sequence_tensor = torch.randn([2, 5, 8])
extractor = BidirectionalEndpointSpanExtractor(input_dim=8)
Expand Down

0 comments on commit a701a0a

Please sign in to comment.