Skip to content
This repository has been archived by the owner on Jul 4, 2023. It is now read-only.

Commit

Permalink
Merge pull request #30 from PetrochukM/padding
Browse files Browse the repository at this point in the history
Update ``pad_tensor`` and ``pad_batch``
  • Loading branch information
PetrochukM committed May 6, 2018
2 parents 588c340 + 55be8dc commit cf41090
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 23 deletions.
19 changes: 16 additions & 3 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
import urllib.request
import pickle

import pytest
import torch

from tqdm import tqdm

from torchnlp.datasets import Dataset
from torchnlp.text_encoders import PADDING_INDEX
from torchnlp.utils import flatten_parameters
from torchnlp.utils import get_filename_from_url
from torchnlp.utils import pad_batch
from torchnlp.utils import pad_tensor
from torchnlp.utils import reporthook
from torchnlp.utils import resplit_datasets
from torchnlp.utils import shuffle
from torchnlp.utils import torch_equals_ignore_index
from torchnlp.utils import get_filename_from_url


def test_get_filename_from_url():
Expand All @@ -24,13 +26,24 @@ def test_get_filename_from_url():


def test_pad_tensor():
    """ ``pad_tensor`` extends a 1-D tensor to the requested length with the pad value. """
    # Use the module-level PADDING_INDEX imported from torchnlp.text_encoders
    # instead of shadowing it with a local constant.
    padded = pad_tensor(torch.LongTensor([1, 2, 3]), 5, PADDING_INDEX)
    assert padded.tolist() == [1, 2, 3, PADDING_INDEX, PADDING_INDEX]


def test_pad_tensor_multiple_dim():
    """ Padding a 3-D tensor should grow only the first dimension. """
    source = torch.LongTensor(1, 2, 3)  # NOTE: contents are uninitialized memory
    result = pad_tensor(source, 5, PADDING_INDEX)
    assert result.size() == (5, 2, 3)
    # Everything beyond the original first row is filled with the padding value.
    assert result[1].sum().item() == pytest.approx(0)


def test_pad_tensor_multiple_dim_float_tensor():
    """ Padding must preserve the input dtype (FloatTensor in, FloatTensor out). """
    source = torch.FloatTensor(778, 80)  # NOTE: contents are uninitialized memory
    result = pad_tensor(source, 804, PADDING_INDEX)
    assert result.size() == (804, 80)
    # The appended rows hold the padding value.
    assert result[-1].sum().item() == pytest.approx(0)
    assert result.type() == 'torch.FloatTensor'


def test_pad_batch():
PADDING_INDEX = 0
batch = [torch.LongTensor([1, 2, 3]), torch.LongTensor([1, 2]), torch.LongTensor([1])]
padded, lengths = pad_batch(batch, PADDING_INDEX)
padded = [r.tolist() for r in padded]
Expand Down
38 changes: 18 additions & 20 deletions torchnlp/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,43 +51,41 @@ def datasets_iterator(*datasets):

def pad_tensor(tensor, length, padding_index=PADDING_INDEX):
    """ Pad a ``tensor`` to ``length`` with ``padding_index``.

    Args:
        tensor (torch.Tensor [n, *]): Tensor to pad.
        length (int): Pad the ``tensor`` up to ``length``.
        padding_index (int, optional): Index to pad tensor with.

    Returns:
        (torch.Tensor [length, *]) Padded Tensor.
    """
    n_padding = length - tensor.shape[0]
    assert n_padding >= 0
    if n_padding == 0:
        return tensor
    # ``tensor.new`` allocates the padding with the same dtype (and device) as
    # the input, so Long/Float tensors are both handled correctly.
    padding = tensor.new(n_padding, *tensor.shape[1:]).fill_(padding_index)
    return torch.cat((tensor, padding), dim=0)


def pad_batch(batch, padding_index=PADDING_INDEX):
    """ Pad a :class:`list` of ``tensors`` (``batch``) with ``padding_index``.

    Args:
        batch (:class:`list` of :class:`torch.Tensor`): Batch of tensors to pad.
        padding_index (int, optional): Index to pad tensors with.

    Returns:
        torch.Tensor, list of int: Padded tensors and original lengths of tensors.
    """
    # Original first-dimension sizes, returned so callers can mask the padding.
    lengths = [tensor.shape[0] for tensor in batch]
    max_len = max(lengths)
    padded = [pad_tensor(tensor, max_len, padding_index) for tensor in batch]
    # Stack into one [batch, max_len, *] tensor; ``contiguous`` guarantees a
    # contiguous memory layout for downstream consumers.
    padded = torch.stack(padded, dim=0).contiguous()
    return padded, lengths


def flatten_parameters(model):
    """ ``flatten_parameters`` of a RNN model loaded from disk. """
    def _flatten(module):
        # Only RNN-style modules expose ``flatten_parameters``; skip the rest.
        if hasattr(module, 'flatten_parameters'):
            module.flatten_parameters()

    model.apply(_flatten)


def shuffle(list_, random_seed=123):
""" Shuffle list deterministically based on ``random_seed``.
Expand Down

0 comments on commit cf41090

Please sign in to comment.