Skip to content
This repository has been archived by the owner on Jul 4, 2023. It is now read-only.

Commit

Permalink
Merge pull request #33 from PetrochukM/memory
Browse files Browse the repository at this point in the history
Sampler biggest_batch_first upgrade
  • Loading branch information
PetrochukM committed May 6, 2018
2 parents 83f3491 + ba97bf1 commit 9925127
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 22 deletions.
21 changes: 9 additions & 12 deletions tests/samplers/test_bucket_batch_sampler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import torch

from torchnlp.samplers import BucketBatchSampler


def test_bucket_batch_sampler():
def test_bucket_batch_sampler_length():
data_source = [[1], [2], [3], [4], [5], [6]]
sort_key = lambda r: len(r)
batch_size = 2
Expand All @@ -12,7 +14,7 @@ def test_bucket_batch_sampler():
assert len(sampler) == 3


def test_bucket_batch_sampler_uneven():
def test_bucket_batch_sampler_uneven_length():
data_source = [[1], [2], [3], [4], [5]]
sort_key = lambda r: len(r)
batch_size = 2
Expand All @@ -29,19 +31,14 @@ def test_bucket_batch_sampler_uneven():


def test_bucket_batch_sampler_last_batch_first():
data_source = [[1], [2], [3], [4], [5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]]
data_source = [torch.tensor([j for j in range(i)]) for i in range(100)]
sort_key = lambda r: len(r)
batch_size = 2
batch_size = 1
batches = list(
BucketBatchSampler(
data_source,
batch_size,
sort_key=sort_key,
drop_last=False,
biggest_batches_first=True,
bucket_size_multiplier=2))
data_source, batch_size, sort_key=sort_key, drop_last=False, bucket_size_multiplier=2))
# Largest batch (4) is in first batch
assert 4 in batches[0]
assert 99 == batches[0][0]


def test_bucket_batch_sampler_sorted():
Expand All @@ -54,7 +51,7 @@ def test_bucket_batch_sampler_sorted():
batch_size,
sort_key=sort_key,
drop_last=False,
biggest_batches_first=False,
biggest_batches_first=None,
bucket_size_multiplier=1))
# Largest batch (4) is in first batch
for i, batch in enumerate(batches):
Expand Down
41 changes: 41 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,47 @@
from torchnlp.utils import resplit_datasets
from torchnlp.utils import shuffle
from torchnlp.utils import torch_equals_ignore_index
from torchnlp.utils import get_tensors


class GetTensorsObjectMock(object):
    """ Fixture object exposing tensors through several kinds of attributes.

    Tensors are reachable via a class attribute, an instance attribute holding a ``frozenset``,
    a property that builds a new tensor on every access and, optionally, a nested instance of
    this same class. Plain ``int`` and ``str`` attributes are included as noise.

    Args:
        recurse (bool, optional): If ``True``, attach a nested ``GetTensorsObjectMock`` (itself
            created with ``recurse=False``) as ``self.object_``.
    """

    # Shared by every instance of the mock.
    class_attribute = torch.tensor([4, 5])

    def __init__(self, recurse=True):
        # Non-tensor attributes.
        self.noise_int = 3
        self.noise_str = 'abc'

        # A tensor wrapped in an iterable container.
        wrapped_tensor = torch.tensor([6, 7])
        self.instance_attribute = frozenset((wrapped_tensor,))

        # One level of nesting only; the nested mock does not nest again.
        if recurse:
            self.object_ = GetTensorsObjectMock(recurse=False)

    @property
    def property_(self):
        """ Build and return a new tensor on each access. """
        return torch.tensor([7, 8])


def test_get_tensors_list():
    """ A flat list holding two tensors yields two tensors. """
    first = torch.tensor([1, 2])
    second = torch.tensor([2, 3])
    found = get_tensors([first, second])
    assert len(found) == 2


def test_get_tensors_dict():
    """ A tensor stored as a dict value inside a list is found alongside a bare tensor. """
    nested = {'t': torch.tensor([1, 2])}
    bare = torch.tensor([2, 3])
    found = get_tensors([nested, bare])
    assert len(found) == 2


def test_get_tensors_tuple():
    """ A tuple is searched the same way as a list: both contained tensors are found. """
    nested = {'t': torch.tensor([1, 2])}
    bare = torch.tensor([2, 3])
    found = get_tensors((nested, bare))
    assert len(found) == 2


def test_get_tensors_object():
    """ Tensors reachable through an arbitrary object's attributes are collected. """
    found = get_tensors(GetTensorsObjectMock())
    assert len(found) == 6


def test_pad_tensor():
Expand Down
22 changes: 12 additions & 10 deletions torchnlp/samplers/bucket_batch_sampler.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import heapq
import pickle
import math

from torch.utils.data.sampler import BatchSampler
from torch.utils.data.sampler import RandomSampler

from torchnlp.samplers.sorted_sampler import SortedSampler
from torchnlp.samplers.shuffle_batch_sampler import ShuffleBatchSampler
from torchnlp.utils import get_tensors


class BucketBatchSampler(object):
Expand Down Expand Up @@ -35,9 +35,9 @@ class BucketBatchSampler(object):
comparison key from each list element
drop_last (bool): If ``True``, the sampler will drop the last batch if its size would be
less than ``batch_size``.
biggest_batch_first (bool, optional): If ``True``, the sampler will use cPickle to
approximate the memory footprint of each batch and attempt to return the 5 biggest
batches first.
biggest_batch_first (callable or None, optional): If a callable is provided, the sampler
approximates the memory footprint of tensors in each batch, returning the 5 biggest
batches first. Callable must return a number, given an example.
This is largely for testing, to see how large of a batch you can safely use with your
GPU. This will let you try out the biggest batch that you have in the data `first`, so
Expand All @@ -63,7 +63,7 @@ def __init__(
batch_size,
drop_last,
sort_key=lambda e: e,
biggest_batches_first=True,
biggest_batches_first=lambda o: sum([t.numel() for t in get_tensors(o)]),
bucket_size_multiplier=100,
shuffle=True,
):
Expand Down Expand Up @@ -97,17 +97,19 @@ def get_batches():

yield batch

if not self.biggest_batches_first:
if self.biggest_batches_first is None:
return get_batches()
else:
batches = list(get_batches())
indices = heapq.nlargest(
biggest_batches = heapq.nlargest(
5,
range(len(batches)),
key=lambda i: len(pickle.dumps([self.data[j] for j in batches[i]])))
front = [batches[i] for i in indices]
for i in sorted(indices, reverse=True):
key=lambda i: sum([self.biggest_batches_first(self.data[j]) for j in batches[i]]))
front = [batches[i] for i in biggest_batches]
# Remove ``biggest_batches`` from data
for i in sorted(biggest_batches, reverse=True):
batches.pop(i)
# Move them to the front
batches[0:0] = front
return iter(batches)

Expand Down
34 changes: 34 additions & 0 deletions torchnlp/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import logging
import inspect
import collections

import random
import torch
Expand All @@ -8,6 +10,38 @@
logger = logging.getLogger(__name__)


def get_tensors(object_):
    """ Get all tensors associated with ``object_``

    Mappings are searched through their values, other iterables through their elements, and
    any other object through the non-callable, non-``None`` members reported by
    ``inspect.getmembers``. The search is recursive.

    Args:
        object_ (any): Any object to look for tensors.

    Returns:
        (list or set of torch.tensor): Tensors that are associated with ``object_``. A ``list``
        is returned when ``object_`` is itself a tensor, a ``str``, a ``float`` or an ``int``;
        in every other case a ``set`` is returned.
    """
    # The abstract container classes live in ``collections.abc``; the ``collections.Mapping``
    # style aliases were removed in Python 3.10. Imported locally so this function does not
    # depend on the module-level import block.
    from collections.abc import Callable, Iterable, Mapping

    if torch.is_tensor(object_):
        return [object_]
    elif isinstance(object_, (str, float, int)):
        # Strings are iterable, so they are answered here before the ``Iterable`` branch.
        return []

    tensors = set()

    if isinstance(object_, Mapping):
        for value in object_.values():
            tensors.update(get_tensors(value))
    elif isinstance(object_, Iterable):
        for value in object_:
            tensors.update(get_tensors(value))
    else:
        # Arbitrary object: look through its members, ignoring callables and ``None``.
        members = [
            value for _, value in inspect.getmembers(object_)
            if not isinstance(value, (Callable, type(None)))
        ]
        tensors.update(get_tensors(members))

    return tensors


def sampler_to_iterator(dataset, sampler):
""" Given a batch sampler or sampler returns examples instead of indices
Expand Down

0 comments on commit 9925127

Please sign in to comment.