This repository has been archived by the owner on Jul 4, 2023. It is now read-only.

Samplers Update Docs, Tests and Bucket Sampler
PetrochukM committed Mar 11, 2018
1 parent dca4f84 commit 2a1a685
Showing 10 changed files with 314 additions and 80 deletions.
24 changes: 14 additions & 10 deletions tests/samplers/test_bucket_batch_sampler.py
@@ -5,41 +5,45 @@ def test_bucket_batch_sampler():
data_source = [[1], [2], [3], [4], [5], [6]]
sort_key = lambda r: len(r)
batch_size = 2
batches = list(BucketBatchSampler(data_source, sort_key, batch_size))
batches = list(BucketBatchSampler(data_source, batch_size, sort_key, bucket_size_multiplier=2))
assert len(batches) == 3


def test_bucket_batch_sampler_uneven():
data_source = [[1], [2], [3], [4], [5]]
sort_key = lambda r: len(r)
batch_size = 2
batches = list(BucketBatchSampler(data_source, sort_key, batch_size))
batches = list(BucketBatchSampler(data_source, batch_size, sort_key, bucket_size_multiplier=2))
assert len(batches) == 3
batches = list(BucketBatchSampler(data_source, sort_key, batch_size, drop_last=True))
batches = list(
BucketBatchSampler(
data_source, batch_size, sort_key, drop_last=True, bucket_size_multiplier=2))
assert len(batches) == 2


def test_bucket_batch_sampler_last_batch_first():
data_source = [[1], [2], [3], [4], [5, 6, 7, 8, 9, 10]]
data_source = [[1], [2], [3], [4], [5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]]
sort_key = lambda r: len(r)
batch_size = 2
batches = list(BucketBatchSampler(data_source, sort_key, batch_size, last_batch_first=True))
batches = list(
BucketBatchSampler(
data_source, batch_size, sort_key, biggest_batches_first=True,
bucket_size_multiplier=2))
# The largest example (index 4) must land in the first batch
assert 4 in batches[0]


def test_bucket_batch_sampler_sorted():
data_source = [[1], [2], [3], [4], [5]]
sort_key = lambda r: r[0]
batch_size = 1
batch_size = len(data_source)
batches = list(
BucketBatchSampler(
data_source,
sort_key,
batch_size,
shuffle=False,
last_batch_first=False,
sort_key_noise=0.0))
sort_key,
biggest_batches_first=False,
bucket_size_multiplier=1))
# Batches are returned in sorted order
for i, batch in enumerate(batches):
assert batch[0] == i
46 changes: 46 additions & 0 deletions tests/samplers/test_noisy_sorted_batch_sampler.py
@@ -0,0 +1,46 @@
from torchnlp.samplers import NoisySortedBatchSampler


def test_noisy_sorted_batch_sampler():
data_source = [[1], [2], [3], [4], [5], [6]]
sort_key = lambda r: len(r)
batch_size = 2
batches = list(NoisySortedBatchSampler(data_source, batch_size, sort_key))
assert len(batches) == 3


def test_noisy_sorted_batch_sampler_uneven():
data_source = [[1], [2], [3], [4], [5]]
sort_key = lambda r: len(r)
batch_size = 2
batches = list(NoisySortedBatchSampler(data_source, batch_size, sort_key))
assert len(batches) == 3
batches = list(NoisySortedBatchSampler(data_source, batch_size, sort_key, drop_last=True))
assert len(batches) == 2


def test_noisy_sorted_batch_sampler_last_batch_first():
data_source = [[1], [2], [3], [4], [5, 6, 7, 8, 9, 10]]
sort_key = lambda r: len(r)
batch_size = 2
batches = list(
NoisySortedBatchSampler(data_source, batch_size, sort_key, last_batch_first=True))
# The largest example (index 4) must land in the first batch
assert 4 in batches[0]


def test_noisy_sorted_batch_sampler_sorted():
data_source = [[1], [2], [3], [4], [5]]
sort_key = lambda r: r[0]
batch_size = 1
batches = list(
NoisySortedBatchSampler(
data_source,
batch_size,
sort_key,
shuffle=False,
last_batch_first=False,
sort_key_noise=0.0))
# Batches are returned in sorted order
for i, batch in enumerate(batches):
assert batch[0] == i
@@ -1,20 +1,20 @@
from torchnlp.samplers import RandomBatchSampler
from torchnlp.samplers import ShuffleBatchSampler

from torchnlp.samplers import SortedSampler


def test_random_batch_sampler():
def test_shuffle_batch_sampler():
data_source = [[1], [2], [3], [4], [5], [6]]
sort_key = lambda r: len(r)
batch_size = 2
batches = list(RandomBatchSampler(SortedSampler(data_source, sort_key), batch_size, False))
batches = list(ShuffleBatchSampler(SortedSampler(data_source, sort_key), batch_size, False))
assert len(batches) == 3


def test_random_batch_sampler_drop_last():
def test_shuffle_batch_sampler_drop_last():
data_source = [[1], [2], [3], [4], [5]]
sort_key = lambda r: len(r)
batch_size = 2
batches = list(
RandomBatchSampler(SortedSampler(data_source, sort_key), batch_size, drop_last=True))
ShuffleBatchSampler(SortedSampler(data_source, sort_key), batch_size, drop_last=True))
assert len(batches) == 2
8 changes: 6 additions & 2 deletions torchnlp/samplers/__init__.py
@@ -1,6 +1,10 @@
from torchnlp.samplers.bucket_batch_sampler import BucketBatchSampler
from torchnlp.samplers.sorted_sampler import SortedSampler
from torchnlp.samplers.noisy_sorted_sampler import NoisySortedSampler
from torchnlp.samplers.random_batch_sampler import RandomBatchSampler
from torchnlp.samplers.shuffle_batch_sampler import ShuffleBatchSampler
from torchnlp.samplers.noisy_sorted_batch_sampler import NoisySortedBatchSampler

__all__ = ['NoisySortedSampler', 'RandomBatchSampler', 'SortedSampler', 'BucketBatchSampler']
__all__ = [
'NoisySortedSampler', 'ShuffleBatchSampler', 'SortedSampler', 'BucketBatchSampler',
'NoisySortedBatchSampler'
]
131 changes: 97 additions & 34 deletions torchnlp/samplers/bucket_batch_sampler.py
@@ -1,45 +1,108 @@
import random
import heapq
import pickle

from torch.utils.data.sampler import BatchSampler
from torch.utils.data.sampler import RandomSampler

from torchnlp.samplers.noisy_sorted_sampler import NoisySortedSampler
from torchnlp.samplers.sorted_sampler import SortedSampler
from torchnlp.samplers.shuffle_batch_sampler import ShuffleBatchSampler


class BucketBatchSampler(BatchSampler):
"""
Reference:
https://github.com/allenai/allennlp/blob/e125a490b71b21e914af01e70e9b00b165d64dcd/allennlp/data/iterators/bucket_iterator.py
https://github.com/pytorch/text/tree/master/torchtext/data/iterators/#BucketIterator
class BucketBatchSampler(object):
"""Samples a noisy sorted mini-batch of indices from a data source.
In order to introduce noise into a sorted mini-batch, we use a bucketing technique from
`torchtext`. First, the data is partitioned into buckets of size `bucket_size_multiplier * batch_size`
(100 * `batch_size` by default). The examples inside each bucket are sorted using `sort_key` and
batched. Finally, those batches are shuffled.
Background:
BucketBatchSampler is similar to a BucketIterator found in popular libraries like `AllenNLP`
and `torchtext`. A BucketIterator pools together examples of similar length to
reduce the padding required for each batch. A BucketIterator also includes the ability to add
noise to the pooling.
AllenNLP Implementation:
https://github.com/allenai/allennlp/blob/e125a490b71b21e914af01e70e9b00b165d64dcd/allennlp/data/iterators/bucket_iterator.py
torchtext Implementation:
https://github.com/pytorch/text/blob/master/torchtext/data/iterator.py#L225
Args:
data (iterable): Data to sample from.
batch_size (int): Size of mini-batch.
sort_key (callable): Specifies a function of one argument that is used to extract a
comparison key from each list element.
drop_last (bool, optional): If ``True``, the sampler will drop the last batch if its size
would be less than ``batch_size``.
biggest_batches_first (bool, optional): If ``True``, the sampler will use ``pickle`` to
approximate the memory footprint of each batch and attempt to return the 5 biggest
batches first.
This is largely for testing, to see how large of a batch you can safely use with your
GPU. This will let you try out the biggest batch that you have in the data `first`, so
that if you're going to run out of memory, you know it early, instead of waiting
through the whole epoch to find out at the end that you're going to crash.
`BucketIterator` pools together examples with a similar size length to reduce the padding
required for each batch. `BucketIterator` typically also includes the ability to add noise to
the pooling.
Credits:
https://github.com/allenai/allennlp/blob/3d100d31cc8d87efcf95c0b8d162bfce55c64926/allennlp/data/iterators/bucket_iterator.py#L43
bucket_size_multiplier (int): Batch size multiplier to determine the bucket size.
Example:
>>> list(BucketBatchSampler(range(10), batch_size=3, sort_key=lambda i: i, drop_last=False))
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(BucketBatchSampler(range(10), batch_size=3, sort_key=lambda i: i, drop_last=True))
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]
The functionality has been replicated as a `Sampler` to be used with a
`torch.data.utils.DataLoader`.
"""

def __init__(self,
data_source,
sort_key,
batch_size,
sort_key_noise=0.25,
last_batch_first=True,
shuffle=True,
drop_last=False):
self.last_batch_first = last_batch_first
self.shuffle = shuffle
super().__init__(
NoisySortedSampler(
data_source=data_source, sort_key=sort_key, sort_key_noise=sort_key_noise),
batch_size, drop_last)
def __init__(
self,
data,
batch_size,
sort_key,
drop_last=False,
biggest_batches_first=True,
bucket_size_multiplier=100,
):
self.biggest_batches_first = biggest_batches_first
self.sort_key = sort_key
self.bucket_size_multiplier = bucket_size_multiplier
self.batch_size = batch_size
self.drop_last = drop_last
self.data = data

self.bucket_size_multiplier = bucket_size_multiplier
self.bucket_sampler = BatchSampler(
RandomSampler(data), batch_size * bucket_size_multiplier, False)

def __iter__(self):
batches = list(super().__iter__())
if self.last_batch_first:
last_batch = batches.pop()
if self.shuffle:
random.shuffle(batches)
if self.last_batch_first:
batches.insert(0, last_batch)
return iter(batches)

def get_batches():
""" Get bucketed batches """
for bucket in self.bucket_sampler:
for batch in ShuffleBatchSampler(
SortedSampler(bucket, lambda i: self.sort_key(self.data[i])),
self.batch_size,
drop_last=self.drop_last,
shuffle=True):
batch = [bucket[i] for i in batch]

# Should only be triggered once
if len(batch) < self.batch_size and self.drop_last:
continue

yield batch

if not self.biggest_batches_first:
return get_batches()
else:
batches = list(get_batches())
indices = heapq.nlargest(
5,
range(len(batches)),
key=lambda i: len(pickle.dumps([self.data[j] for j in batches[i]])))
front = [batches[i] for i in indices]
for i in sorted(indices, reverse=True):
batches.pop(i)
batches[0:0] = front
return iter(batches)
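
To see how the new sampler composes with a standard data pipeline, here is a minimal usage sketch. It is not part of the commit: the toy data, the lambda `sort_key`, and the identity `collate_fn` are illustrative assumptions; only the `BucketBatchSampler` signature comes from the diff above. The sampler is passed to `torch.utils.data.DataLoader` through the `batch_sampler` argument so that each yielded batch groups rows of similar length.

from torch.utils.data import DataLoader

from torchnlp.samplers import BucketBatchSampler

# Toy dataset of variable-length rows (illustrative only).
data = [[1, 2, 3], [4], [5, 6], [7, 8, 9, 10], [11, 12]]

sampler = BucketBatchSampler(
    data,
    batch_size=2,
    sort_key=lambda row: len(row),  # rows of similar length share a batch
    biggest_batches_first=False,    # keep the shuffled batch order
    bucket_size_multiplier=2)       # small buckets for a small toy dataset

# `batch_sampler` yields lists of indices; `collate_fn` receives the raw rows.
loader = DataLoader(data, batch_sampler=sampler, collate_fn=lambda batch: batch)
for batch in loader:
    print(batch)  # e.g. [[5, 6], [11, 12]] -- similar lengths grouped together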
76 changes: 76 additions & 0 deletions torchnlp/samplers/noisy_sorted_batch_sampler.py
@@ -0,0 +1,76 @@
import random

from torch.utils.data.sampler import BatchSampler

from torchnlp.samplers.noisy_sorted_sampler import NoisySortedSampler


class NoisySortedBatchSampler(BatchSampler):
""" Samples a noisy sorted mini-batch of indices from a data source.
In order to introduce noise into a sorted mini-batch, all elements are sorted by a numerical
`sort_key` value to which noise drawn from a uniform distribution is added.
Background:
NoisySortedBatchSampler is similar to a BucketIterator found in popular libraries like
`AllenNLP` and `torchtext`. A BucketIterator pools together examples of similar
length to reduce the padding required for each batch. A BucketIterator also includes the
ability to add noise to the pooling.
AllenNLP Implementation:
https://github.com/allenai/allennlp/blob/e125a490b71b21e914af01e70e9b00b165d64dcd/allennlp/data/iterators/bucket_iterator.py
torchtext Implementation:
https://github.com/pytorch/text/blob/master/torchtext/data/iterator.py#L225
Args:
data (iterable): Iterable data.
batch_size (int): Size of mini-batch.
sort_key (callable): Specifies a function of one argument that is used to extract a
numerical comparison key from each list element.
sort_key_noise (float): Maximum noise added to the numerical `sort_key`.
last_batch_first (bool, optional): If ``True``, the sampler will move the last batch to the
front. Only helpful if the `sort_key` approximates GPU memory usage.
This is largely for testing, to see how large of a batch you can safely use with your
GPU. This will let you try out the biggest batch that you have in the data `first`, so
that if you're going to run out of memory, you know it early, instead of waiting
through the whole epoch to find out at the end that you're going to crash.
Credits:
https://github.com/allenai/allennlp/blob/3d100d31cc8d87efcf95c0b8d162bfce55c64926/allennlp/data/iterators/bucket_iterator.py#L43
shuffle (bool, optional): If ``True``, the batches are shuffled.
drop_last (bool, optional): If ``True``, the sampler will drop the last batch if its size
would be less than ``batch_size``.
Example:
>>> list(NoisySortedBatchSampler(range(10), batch_size=3, sort_key=lambda i: i, drop_last=False))
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(NoisySortedBatchSampler(range(10), batch_size=3, sort_key=lambda i: i, drop_last=True))
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]
"""

def __init__(self,
data,
batch_size,
sort_key,
sort_key_noise=0.25,
last_batch_first=True,
shuffle=True,
drop_last=False):
self.last_batch_first = last_batch_first
self.shuffle = shuffle
super().__init__(
NoisySortedSampler(data=data, sort_key=sort_key, sort_key_noise=sort_key_noise),
batch_size, drop_last)

def __iter__(self):
batches = list(super().__iter__())
if self.last_batch_first:
last_batch = batches.pop()
if self.shuffle:
random.shuffle(batches)
if self.last_batch_first:
batches.insert(0, last_batch)
return iter(batches)
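
As a rough illustration of the `sort_key_noise` knob, the sketch below (toy data assumed, not from the commit) builds the same sampler twice: with noise disabled the batches follow the exact length order, while a larger noise value lets nearby examples swap places.

import random

from torchnlp.samplers import NoisySortedBatchSampler

random.seed(123)  # the noise is drawn with Python's `random` module
# Ten toy rows whose lengths are a scrambled 1..10 (illustrative only).
data = [[0] * length for length in [1, 7, 2, 9, 3, 8, 4, 6, 5, 10]]

strict = list(NoisySortedBatchSampler(
    data, batch_size=2, sort_key=len, sort_key_noise=0.0,
    last_batch_first=False, shuffle=False))
noisy = list(NoisySortedBatchSampler(
    data, batch_size=2, sort_key=len, sort_key_noise=0.5,
    last_batch_first=False, shuffle=False))

print(strict)  # expected: [[0, 2], [4, 6], [8, 7], [1, 5], [3, 9]] -- exact length order
print(noisy)   # roughly length-ordered, with occasional local swaps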
23 changes: 14 additions & 9 deletions torchnlp/samplers/noisy_sorted_sampler.py
@@ -9,18 +9,23 @@ class NoisySortedSampler(Sampler):
Reference and inspiration:
https://github.com/allenai/allennlp/blob/e125a490b71b21e914af01e70e9b00b165d64dcd/allennlp/data/iterators/bucket_iterator.py
Arguments:
data_source (datasets.Dataset): dataset to sample from
sort_key (callable): specifies a function of one argument that is used to extract a
comparison key from each list element
Args:
data (iterable): Data to sample from.
sort_key (callable): Specifies a function of one argument that is used to extract a
numerical comparison key from each list element.
sort_key_noise (float): Maximum noise added to the numerical `sort_key`.
Example:
>>> list(NoisySortedSampler(range(10), sort_key=lambda i: i, sort_key_noise=0.25))
[0, 1, 2, 3, 5, 4, 6, 8, 7, 9]
"""

def __init__(self, data_source, sort_key, sort_key_noise=0.25):
super().__init__(data_source)
self.data_source = data_source
def __init__(self, data, sort_key, sort_key_noise=0.25):
super().__init__(data)
self.data = data
self.sort_key = sort_key
zip_ = []
for i, row in enumerate(self.data_source):
for i, row in enumerate(self.data):
value = self.sort_key(row)
noise_value = value * sort_key_noise
noise = random.uniform(-noise_value, noise_value)
@@ -33,4 +38,4 @@ def __iter__(self):
return iter(self.sorted_indexes)

def __len__(self):
return len(self.data_source)
return len(self.data)
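
The noisy-key idea above fits in a few lines. The following standalone sketch (with a hypothetical `noisy_argsort` helper, not part of the module) mirrors the key computation shown in the diff: each key is perturbed by uniform noise proportional to its own magnitude, so nearby elements may swap places while the overall ordering is roughly preserved.

import random

def noisy_argsort(rows, sort_key, sort_key_noise=0.25):
    """Return indices of `rows` ordered by `sort_key` perturbed with uniform noise."""
    keyed = []
    for i, row in enumerate(rows):
        value = sort_key(row)
        noise_value = value * sort_key_noise
        noise = random.uniform(-noise_value, noise_value)
        keyed.append((value + noise, i))
    return [i for _, i in sorted(keyed)]

print(noisy_argsort(range(10), sort_key=lambda i: i))  # e.g. [0, 1, 2, 3, 5, 4, 6, 8, 7, 9]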
