Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
Added batched versions of scatter and fill to util.py (#4598)
Browse files Browse the repository at this point in the history
* added batched_index_fill

* added batched_index_scatter

* added batched_index_scatter..

* added batched_index_scatter..

* fixed util.  typo

* fixed test

* ran linter

* ran linter (again....)

* renamed methods and other small changes

* removed target.clone() and  changed from scatter_ to scatter

* Update allennlp/nn/util.py

Co-authored-by: Dirk Groeneveld <groeneveld@gmail.com>

* ran black

Co-authored-by: Your Name <you@example.com>
Co-authored-by: Dirk Groeneveld <groeneveld@gmail.com>
  • Loading branch information
3 people committed Aug 26, 2020
1 parent 2c54cf8 commit dbc3c3f
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 0 deletions.
89 changes: 89 additions & 0 deletions allennlp/nn/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -1270,6 +1270,95 @@ def batched_index_select(
return selected_targets


def masked_index_fill(
    target: torch.Tensor, indices: torch.LongTensor, mask: torch.BoolTensor, fill_value: int = 1
) -> torch.Tensor:
    """
    Returns a copy of `target` in which the positions named by the unmasked entries of
    `indices` are set to `fill_value`. `target` itself is not modified.

    # Parameters

    target : `torch.Tensor`, required.
        A 2 dimensional tensor of shape (batch_size, sequence_length).
        This is the tensor to be filled.
    indices : `torch.LongTensor`, required
        A 2 dimensional tensor of shape (batch_size, num_indices),
        These are the indices that will be filled in the original tensor.
    mask : `torch.BoolTensor`, required.
        A 2 dimensional tensor of shape (batch_size, num_indices). Only entries of `indices`
        whose mask value is true are filled. Below, `nonzero_indices` stands for `mask.sum()`.
    fill_value : `int`, optional (default = `1`)
        The value we fill the tensor with.

    # Returns

    filled_target : `torch.Tensor`
        A tensor with shape (batch_size, sequence_length) where 'indices' are filled with `fill_value`
    """
    mask = mask.bool()
    prev_shape = target.size()
    # Masked-off entries are multiplied to 0 before the batch shift, so placeholder values
    # (for example -1) are never handed to the flattening helper. They are dropped again below.
    # Shape: (batch_size * num_indices)
    flattened_indices = flatten_and_batch_shift_indices(indices * mask, target.size(1))
    # `reshape` instead of `view`: `view` raises for non-contiguous tensors.
    # Shape: (batch_size * num_indices)
    mask = mask.reshape(-1)
    # Shape: (batch_size * sequence_length, 1)
    flattened_target = target.reshape(-1, 1)
    # Shape: (nonzero_indices, 1)
    unmasked_indices = flattened_indices[mask].unsqueeze(-1)

    # Out-of-place `scatter` returns a new tensor, leaving the caller's `target` untouched.
    flattened_target = flattened_target.scatter(0, unmasked_indices, fill_value)

    # Shape: (batch_size, sequence_length)
    filled_target = flattened_target.reshape(prev_shape)

    return filled_target


def masked_index_replace(
    target: torch.Tensor,
    indices: torch.LongTensor,
    mask: torch.BoolTensor,
    replace: torch.Tensor,
) -> torch.Tensor:
    """
    Returns a copy of `target` in which the rows named by the unmasked entries of `indices`
    are overwritten with the vectors at the same (batch, position) in `replace`.
    `target` itself is not modified.

    # Parameters

    target : `torch.Tensor`, required.
        A 3 dimensional tensor of shape (batch_size, sequence_length, embedding_dim).
        This is the tensor to be replaced into.
    indices : `torch.LongTensor`, required
        A 2 dimensional tensor of shape (batch_size, num_indices),
        These are the indices that will be replaced in the original tensor.
    mask : `torch.BoolTensor`, required.
        A 2 dimensional tensor of shape (batch_size, num_indices). Only entries of `indices`
        whose mask value is true are replaced. Below, `nonzero_indices` stands for `mask.sum()`.
    replace : `torch.Tensor`, required.
        A 3 dimensional tensor of shape (batch_size, num_indices, embedding_dim),
        The tensor to take the replacement vectors from.

    # Returns

    replaced_target : `torch.Tensor`
        A tensor with shape (batch_size, sequence_length, embedding_dim) where 'indices'
        are replaced with the corresponding vector from `replace`
    """
    mask = mask.bool()
    prev_shape = target.size()
    # Masked-off entries are multiplied to 0 before the batch shift, so placeholder values
    # (for example -1) are never handed to the flattening helper. They are dropped again below.
    # Shape: (batch_size * num_indices)
    flattened_indices = flatten_and_batch_shift_indices(indices * mask, target.size(1))
    # Clone first so the in-place assignment below never writes into the caller's tensor.
    # `reshape` instead of `view`: `view` raises for non-contiguous tensors.
    # Shape: (batch_size * sequence_length, embedding_dim)
    flattened_target = target.clone().reshape(-1, target.size(-1))
    # Shape: (batch_size * num_indices)
    mask = mask.reshape(-1)
    # Shape: (batch_size * num_indices, embedding_dim)
    flattened_replace = replace.reshape(-1, replace.size(-1))
    # NOTE(review): if one batch element lists the same unmasked index twice, nothing here
    # decides which replacement vector wins - confirm callers pass unique indices.
    flattened_target[flattened_indices[mask]] = flattened_replace[mask]
    # Shape: (batch_size, sequence_length, embedding_dim)
    replaced_target = flattened_target.reshape(prev_shape)
    return replaced_target


def batched_span_select(target: torch.Tensor, spans: torch.LongTensor) -> torch.Tensor:
"""
The given `spans` of size `(batch_size, num_spans, 2)` indexes into the sequence
Expand Down
33 changes: 33 additions & 0 deletions tests/nn/util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1102,6 +1102,39 @@ def test_batched_index_select(self):
with pytest.raises(ConfigurationError):
util.batched_index_select(targets, indices)

def test_masked_index_fill(self):
    # Negative entries are placeholders; the mask built from `>= 0` excludes them.
    indices = torch.tensor([[4, 2, 3, -1], [0, 1, -1, -1], [1, 3, -1, -1]])
    targets = torch.zeros([3, 5])
    mask = indices >= 0

    filled = util.masked_index_fill(targets, indices, mask)

    # Each row has the default fill value (1) at exactly its non-negative indices.
    expected = [[0, 0, 1, 1, 1], [1, 1, 0, 0, 0], [0, 1, 0, 1, 0]]
    numpy.testing.assert_array_equal(filled, expected)

def test_masked_index_replace(self):
    # Negative entries are placeholders; the mask built from `>= 0` excludes them.
    indices = torch.tensor([[4, 2, 3, -1], [0, 1, -1, -1], [3, 1, -1, -1]])
    targets = torch.zeros([3, 5, 2])

    # Every replacement vector holds its own flat position (0..11) in both components,
    # so the expected output shows which source slot landed where.
    flat_positions = torch.arange(indices.numel()).float().reshape(indices.shape)
    replace_with = flat_positions.unsqueeze(-1).expand(indices.shape + (2,))

    mask = indices >= 0
    replaced = util.masked_index_replace(targets, indices, mask, replace_with)

    expected = [
        [[0, 0], [0, 0], [1, 1], [2, 2], [0, 0]],
        [[4, 4], [5, 5], [0, 0], [0, 0], [0, 0]],
        [[0, 0], [9, 9], [0, 0], [8, 8], [0, 0]],
    ]
    numpy.testing.assert_array_equal(replaced, expected)

def test_batched_span_select(self):
# Each element is a vector of its index.
targets = torch.ones([3, 12, 2]).cumsum(1) - 1
Expand Down

0 comments on commit dbc3c3f

Please sign in to comment.