Commit
Update stack_and_pad_tensors spec
PetrochukM committed Apr 7, 2019
1 parent fca395e commit 0793c2c
Showing 2 changed files with 17 additions and 3 deletions.
15 changes: 13 additions & 2 deletions tests/encoders/text/test_text_encoder.py
@@ -24,10 +24,21 @@ def test_pad_tensor_multiple_dim_float_tensor():
     assert padded.type() == 'torch.FloatTensor'
 
 
-def test_pad_batch():
+def test_stack_and_pad_tensors():
     batch = [torch.LongTensor([1, 2, 3]), torch.LongTensor([1, 2]), torch.LongTensor([1])]
     padded, lengths = stack_and_pad_tensors(batch, DEFAULT_PADDING_INDEX)
     padded = [r.tolist() for r in padded]
     assert padded == [[1, 2, 3], [1, 2, DEFAULT_PADDING_INDEX],
                       [1, DEFAULT_PADDING_INDEX, DEFAULT_PADDING_INDEX]]
-    assert lengths == [3, 2, 1]
+    assert lengths.tolist() == [3, 2, 1]
+
+
+def test_stack_and_pad_tensors__dim():
+    batch_size = 3
+    batch = [torch.LongTensor([1, 2, 3, 4]), torch.LongTensor([1, 2, 3]), torch.LongTensor([1, 2])]
+    padded, lengths = stack_and_pad_tensors(batch, DEFAULT_PADDING_INDEX, dim=1)
+    assert padded.shape == (4, batch_size)
+    assert lengths.shape == (1, batch_size)
+    assert lengths.tolist() == [[4, 3, 2]]
+    assert padded.tolist() == [[1, 1, 1], [2, 2, 2], [3, 3, DEFAULT_PADDING_INDEX],
+                               [4, DEFAULT_PADDING_INDEX, DEFAULT_PADDING_INDEX]]
5 changes: 4 additions & 1 deletion torchnlp/encoders/text/text_encoder.py
@@ -32,12 +32,15 @@ def stack_and_pad_tensors(batch, padding_index=DEFAULT_PADDING_INDEX, dim=0):
         dim (int, optional): Dimension on to which to concatenate the batch of tensors.
 
     Returns
-        torch.Tensor, list of int: Padded tensors and original lengths of tensors.
+        torch.Tensor, torch.Tensor: Padded tensors and original lengths of tensors.
     """
     lengths = [tensor.shape[0] for tensor in batch]
     max_len = max(lengths)
     padded = [pad_tensor(tensor, max_len, padding_index) for tensor in batch]
+    lengths = torch.tensor(lengths)
     padded = torch.stack(padded, dim=dim).contiguous()
+    for _ in range(dim):
+        lengths = lengths.unsqueeze(0)
     return padded, lengths


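For context, a minimal usage sketch of the updated function (not part of the commit). The import path follows the module shown in the diff above, and the padded values in the comments assume DEFAULT_PADDING_INDEX is 0; it illustrates that lengths is now returned as a torch.Tensor and gains a leading dimension when dim is non-zero.

    import torch
    from torchnlp.encoders.text.text_encoder import (DEFAULT_PADDING_INDEX,
                                                      stack_and_pad_tensors)

    batch = [torch.LongTensor([1, 2, 3]), torch.LongTensor([1, 2])]

    # Default dim=0: one padded row per sequence, lengths is a 1-D tensor.
    padded, lengths = stack_and_pad_tensors(batch, DEFAULT_PADDING_INDEX)
    # padded  -> tensor([[1, 2, 3], [1, 2, 0]])  (assuming DEFAULT_PADDING_INDEX == 0)
    # lengths -> tensor([3, 2])

    # dim=1: sequences become columns, and lengths is unsqueezed to shape (1, batch).
    padded, lengths = stack_and_pad_tensors(batch, DEFAULT_PADDING_INDEX, dim=1)
    # padded.shape  -> torch.Size([3, 2])
    # lengths.shape -> torch.Size([1, 2])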
