This repository has been archived by the owner on Jul 4, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 258
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Samplers Update Docs, Tests and Bucket Sampler
- Loading branch information
1 parent
dca4f84
commit 2a1a685
Showing
10 changed files
with
314 additions
and
80 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
from torchnlp.samplers import NoisySortedBatchSampler | ||
|
||
|
||
def test_noisy_sorted_batch_sampler():
    """Six single-element rows with ``batch_size=2`` must yield three batches."""
    rows = [[1], [2], [3], [4], [5], [6]]
    sampler = NoisySortedBatchSampler(rows, 2, len)
    assert len(list(sampler)) == 3
|
||
|
||
def test_noisy_sorted_batch_sampler_uneven():
    """Five rows with ``batch_size=2``: the short final batch is kept by default
    and discarded when ``drop_last=True``."""
    rows = [[1], [2], [3], [4], [5]]
    assert len(list(NoisySortedBatchSampler(rows, 2, len))) == 3
    assert len(list(NoisySortedBatchSampler(rows, 2, len, drop_last=True))) == 2
|
||
|
||
def test_noisy_sorted_batch_sampler_last_batch_first():
    """With ``last_batch_first=True`` the batch holding the largest example leads.

    Row index 4 is the only long row; with ``sort_key=len`` it sorts last, so
    its batch is popped off the end and re-inserted at the front.
    """
    data_source = [[1], [2], [3], [4], [5, 6, 7, 8, 9, 10]]
    sort_key = lambda r: len(r)
    batch_size = 2
    batches = list(
        NoisySortedBatchSampler(data_source, batch_size, sort_key, last_batch_first=True))
    # Index 4 (the largest example) must land in the first batch returned.
    assert 4 in batches[0]
|
||
|
||
def test_noisy_sorted_batch_sampler_sorted():
    """With shuffling off and zero noise the sampler degenerates to a plain sort."""
    data_source = [[1], [2], [3], [4], [5]]
    sort_key = lambda r: r[0]
    batch_size = 1
    batches = list(
        NoisySortedBatchSampler(
            data_source,
            batch_size,
            sort_key,
            shuffle=False,
            last_batch_first=False,
            sort_key_noise=0.0))
    # Batch i must contain exactly index i, i.e. batches come back in sorted order.
    for i, batch in enumerate(batches):
        assert batch[0] == i
10 changes: 5 additions & 5 deletions
10
tests/samplers/test_random_batch_sampler.py → tests/samplers/test_shuffle_batch_sampler.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,20 +1,20 @@ | ||
from torchnlp.samplers import RandomBatchSampler | ||
from torchnlp.samplers import ShuffleBatchSampler | ||
|
||
from torchnlp.samplers import SortedSampler | ||
|
||
|
||
def test_shuffle_batch_sampler():
    """Six items in batches of two give three batches regardless of shuffling."""
    rows = [[1], [2], [3], [4], [5], [6]]
    ordered = SortedSampler(rows, len)
    assert len(list(ShuffleBatchSampler(ordered, 2, False))) == 3
|
||
|
||
def test_shuffle_batch_sampler_drop_last():
    """``drop_last=True`` discards the final undersized batch: five items -> two batches."""
    rows = [[1], [2], [3], [4], [5]]
    ordered = SortedSampler(rows, len)
    assert len(list(ShuffleBatchSampler(ordered, 2, drop_last=True))) == 2
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,10 @@ | ||
from torchnlp.samplers.bucket_batch_sampler import BucketBatchSampler | ||
from torchnlp.samplers.sorted_sampler import SortedSampler | ||
from torchnlp.samplers.noisy_sorted_sampler import NoisySortedSampler | ||
from torchnlp.samplers.random_batch_sampler import RandomBatchSampler | ||
from torchnlp.samplers.shuffle_batch_sampler import ShuffleBatchSampler | ||
from torchnlp.samplers.noisy_sorted_batch_sampler import NoisySortedBatchSampler | ||
|
||
__all__ = ['NoisySortedSampler', 'RandomBatchSampler', 'SortedSampler', 'BucketBatchSampler'] | ||
__all__ = [ | ||
'NoisySortedSampler', 'ShuffleBatchSampler', 'SortedSampler', 'BucketBatchSampler', | ||
'NoisySortedBatchSampler' | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,45 +1,108 @@ | ||
import random | ||
import heapq | ||
import pickle | ||
|
||
from torch.utils.data.sampler import BatchSampler | ||
from torch.utils.data.sampler import RandomSampler | ||
|
||
from torchnlp.samplers.noisy_sorted_sampler import NoisySortedSampler | ||
from torchnlp.samplers.sorted_sampler import SortedSampler | ||
from torchnlp.samplers.shuffle_batch_sampler import ShuffleBatchSampler | ||
|
||
|
||
class BucketBatchSampler(object):
    """Samples noisy sorted mini-batches of indices from a data source.

    Noise is introduced into sorted mini-batches with a bucketing technique
    from `torchtext`: data is partitioned into random buckets of
    ``bucket_size_multiplier * batch_size`` indices, each bucket is sorted
    with ``sort_key`` and cut into mini-batches, and the mini-batches are
    shuffled — so each batch holds similarly-sized examples (less padding)
    without the epoch order being deterministic.

    Background:
        ``BucketBatchSampler`` is similar to the ``BucketIterator`` found in
        popular libraries like `AllenNLP` and `torchtext`, replicated here as
        a sampler usable with a ``torch.utils.data.DataLoader``.
        AllenNLP Implementation:
        https://github.com/allenai/allennlp/blob/e125a490b71b21e914af01e70e9b00b165d64dcd/allennlp/data/iterators/bucket_iterator.py
        torchtext Implementation:
        https://github.com/pytorch/text/blob/master/torchtext/data/iterator.py#L225

    Args:
        data (iterable): Data to sample from.
        batch_size (int): Size of each mini-batch.
        sort_key (callable): Specifies a function of one argument that is used
            to extract a comparison key from each list element.
        drop_last (bool, optional): If ``True``, the sampler will drop the last
            batch if its size would be less than ``batch_size``.
        biggest_batches_first (bool, optional): If ``True``, use pickle to
            approximate the memory footprint of each batch and return the 5
            biggest batches first. This is largely for testing: if you are
            going to run out of GPU memory, you find out early in the epoch
            instead of at its end.
        bucket_size_multiplier (int, optional): Bucket size is
            ``batch_size * bucket_size_multiplier``.
    """

    def __init__(self,
                 data,
                 batch_size,
                 sort_key,
                 drop_last=False,
                 biggest_batches_first=True,
                 bucket_size_multiplier=100):
        self.biggest_batches_first = biggest_batches_first
        self.sort_key = sort_key
        self.bucket_size_multiplier = bucket_size_multiplier
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.data = data
        # Buckets never drop their short tail here; `drop_last` is enforced
        # per mini-batch inside __iter__ instead.
        self.bucket_sampler = BatchSampler(
            RandomSampler(data), batch_size * bucket_size_multiplier, False)

    def __iter__(self):

        def get_batches():
            """Yield shuffled mini-batches of dataset indices, bucket by bucket."""
            for bucket in self.bucket_sampler:
                for batch in ShuffleBatchSampler(
                        SortedSampler(bucket, lambda i: self.sort_key(self.data[i])),
                        self.batch_size,
                        drop_last=self.drop_last,
                        shuffle=True):
                    # Map positions within the bucket back to dataset indices.
                    batch = [bucket[i] for i in batch]

                    # Should only be triggered once: the final bucket may be
                    # short, so its last batch can still be undersized even
                    # though ShuffleBatchSampler already applied drop_last.
                    if len(batch) < self.batch_size and self.drop_last:
                        continue

                    yield batch

        if not self.biggest_batches_first:
            return get_batches()

        batches = list(get_batches())
        # Approximate each batch's memory footprint via its pickled size and
        # move the 5 largest batches to the front.
        indices = heapq.nlargest(
            5,
            range(len(batches)),
            key=lambda i: len(pickle.dumps([self.data[j] for j in batches[i]])))
        front = [batches[i] for i in indices]
        for i in sorted(indices, reverse=True):
            batches.pop(i)
        batches[0:0] = front
        return iter(batches)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
import random | ||
|
||
from torch.utils.data.sampler import BatchSampler | ||
|
||
from torchnlp.samplers.noisy_sorted_sampler import NoisySortedSampler | ||
|
||
|
||
class NoisySortedBatchSampler(BatchSampler):
    """Samples a noisy sorted mini-batch of indices from a data source.

    In order to introduce noise into a sorted mini-batch, all elements are
    sorted by a number to which noise drawn from a uniform distribution is
    added, so each batch holds similarly-sized examples without the order
    being fully deterministic.

    Background:
        ``NoisySortedBatchSampler`` is similar to a ``BucketIterator`` found in
        popular libraries like `AllenNLP` and `torchtext`. A BucketIterator
        pools together examples with a similar size length to reduce the
        padding required for each batch, and typically includes the ability to
        add noise to the pooling.
        AllenNLP Implementation:
        https://github.com/allenai/allennlp/blob/e125a490b71b21e914af01e70e9b00b165d64dcd/allennlp/data/iterators/bucket_iterator.py
        torchtext Implementation:
        https://github.com/pytorch/text/blob/master/torchtext/data/iterator.py#L225

    Args:
        data (iterable): Iterable data.
        batch_size (int): Size of mini-batch.
        sort_key (callable): Specifies a function of one argument that is used
            to extract a numerical comparison key from each list element.
        sort_key_noise (float): Maximum noise added to the numerical `sort_key`.
        last_batch_first (bool, optional): If ``True``, the sampler yields the
            last batch first. Only helpful if the `sort_key` approximates GPU
            memory: you try the biggest batch `first`, so if you are going to
            run out of memory you know it early, instead of crashing at the
            end of the epoch.
            Credits:
            https://github.com/allenai/allennlp/blob/3d100d31cc8d87efcf95c0b8d162bfce55c64926/allennlp/data/iterators/bucket_iterator.py#L43
        shuffle (bool, optional): If ``True``, the batches are shuffled.
        drop_last (bool, optional): If ``True``, the sampler will drop the last
            batch if its size would be less than ``batch_size``.

    Example:
        >>> sampler = NoisySortedBatchSampler(range(10), 3, lambda i: i)
        >>> sorted(len(batch) for batch in sampler)  # doctest: +SKIP
        [1, 3, 3, 3]
    """

    def __init__(self,
                 data,
                 batch_size,
                 sort_key,
                 sort_key_noise=0.25,
                 last_batch_first=True,
                 shuffle=True,
                 drop_last=False):
        self.last_batch_first = last_batch_first
        self.shuffle = shuffle
        super().__init__(
            NoisySortedSampler(data=data, sort_key=sort_key, sort_key_noise=sort_key_noise),
            batch_size, drop_last)

    def __iter__(self):
        batches = list(super().__iter__())
        # Set the last (largest-key) batch aside so shuffling cannot move it.
        # Guard the empty case: popping from an empty batch list would raise.
        last_batch = batches.pop() if self.last_batch_first and batches else None
        if self.shuffle:
            random.shuffle(batches)
        if last_batch is not None:
            batches.insert(0, last_batch)
        return iter(batches)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.