This repository has been archived by the owner on Jul 4, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 258
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
fd46926
commit 616fd74
Showing
7 changed files
with
129 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,28 +1,36 @@ | ||
import random | ||
|
||
from lib.samplers.sorted_sampler import SortedSampler | ||
from torch.utils.data.sampler import Sampler | ||
|
||
|
||
class NoisySortedSampler(SortedSampler): | ||
"""Samples elements sequentially, always in the same order. | ||
class NoisySortedSampler(Sampler): | ||
"""Samples elements sequentially with noise. | ||
Reference and inspiration: | ||
https://github.com/allenai/allennlp/blob/e125a490b71b21e914af01e70e9b00b165d64dcd/allennlp/data/iterators/bucket_iterator.py | ||
Arguments: | ||
data_source (Dataset): dataset to sample from | ||
sort_key (callable -> int): callable that returns from one row of the data_source a int | ||
sort_key (callable): specifies a function of one argument that is used to extract a | ||
comparison key from each list element | ||
""" | ||
|
||
def __init__(self, data_source, sort_key, sort_key_noise=0.1): | ||
def __init__(self, data_source, sort_key, sort_key_noise=0.25): | ||
super().__init__(data_source) | ||
self.data_source = data_source | ||
self.sort_key = sort_key | ||
zip = [] | ||
zip_ = [] | ||
for i, row in enumerate(self.data_source): | ||
value = self.sort_key(row) | ||
noise_value = value * sort_key_noise | ||
noise = random.uniform(-noise_value, noise_value) | ||
value = noise + value | ||
zip.append(tuple([i, value])) | ||
zip = sorted(zip, key=lambda r: r[1]) | ||
self.sorted_indexes = [item[0] for item in zip] | ||
zip_.append(tuple([i, value])) | ||
zip_ = sorted(zip_, key=lambda r: r[1]) | ||
self.sorted_indexes = [item[0] for item in zip_] | ||
|
||
def __iter__(self): | ||
return iter(self.sorted_indexes) | ||
|
||
def __len__(self): | ||
return len(self.data_source) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
from lib.samplers import BucketBatchSampler | ||
|
||
|
||
def test_bucket_batch_sampler(): | ||
data_source = [[1], [2], [3], [4], [5], [6]] | ||
sort_key = lambda r: len(r) | ||
batch_size = 2 | ||
batches = list(BucketBatchSampler(data_source, sort_key, batch_size)) | ||
assert len(batches) == 3 | ||
|
||
|
||
def test_bucket_batch_sampler_uneven(): | ||
data_source = [[1], [2], [3], [4], [5]] | ||
sort_key = lambda r: len(r) | ||
batch_size = 2 | ||
batches = list(BucketBatchSampler(data_source, sort_key, batch_size)) | ||
assert len(batches) == 3 | ||
batches = list(BucketBatchSampler(data_source, sort_key, batch_size, drop_last=True)) | ||
assert len(batches) == 2 | ||
|
||
|
||
def test_bucket_batch_sampler_last_batch_first(): | ||
data_source = [[1], [2], [3], [4], [5, 6, 7, 8, 9, 10]] | ||
sort_key = lambda r: len(r) | ||
batch_size = 2 | ||
batches = list(BucketBatchSampler(data_source, sort_key, batch_size, last_batch_first=True)) | ||
# Longest row (index 4) is in the first batch | ||
assert 4 in batches[0] | ||
|
||
|
||
def test_bucket_batch_sampler_sorted(): | ||
data_source = [[1], [2], [3], [4], [5]] | ||
sort_key = lambda r: r[0] | ||
batch_size = 1 | ||
batches = list( | ||
BucketBatchSampler( | ||
data_source, | ||
sort_key, | ||
batch_size, | ||
shuffle=False, | ||
last_batch_first=False, | ||
sort_key_noise=0.0)) | ||
# With no shuffling and no sort-key noise, batches come out in sorted order | ||
for i, batch in enumerate(batches): | ||
assert batch[0] == i |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
from lib.samplers import NoisySortedSampler | ||
|
||
|
||
def test_noisy_sorted_sampler(): | ||
data_source = [[1], [2], [3], [4], [5], [6]] | ||
sort_key = lambda r: r[0] | ||
batch_size = 2 | ||
indexes = list(NoisySortedSampler(data_source, sort_key)) | ||
assert len(indexes) == len(data_source) | ||
|
||
|
||
def test_noisy_sorted_sampler_sorted(): | ||
data_source = [[1], [2], [3], [4], [5], [6]] | ||
sort_key = lambda r: r[0] | ||
batch_size = 2 | ||
indexes = list(NoisySortedSampler(data_source, sort_key, sort_key_noise=0.0)) | ||
assert len(indexes) == len(data_source) | ||
for i, j in enumerate(indexes): | ||
assert i == j | ||
|
||
|
||
def test_noisy_sorted_sampler_sort_key_noise(): | ||
data_source = [[2], [6], [10]] | ||
sort_key = lambda r: r[0] | ||
batch_size = 2 | ||
# +/-25% noise keeps 2, 6, 10 within [1.5, 2.5], [4.5, 7.5], [7.5, 12.5], so the sorted order cannot change | ||
indexes = list(NoisySortedSampler(data_source, sort_key, sort_key_noise=0.25)) | ||
for i, j in enumerate(indexes): | ||
assert i == j |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
from lib.samplers import RandomBatchSampler | ||
|
||
from lib.samplers import SortedSampler | ||
|
||
|
||
def test_random_batch_sampler(): | ||
data_source = [[1], [2], [3], [4], [5], [6]] | ||
sort_key = lambda r: len(r) | ||
batch_size = 2 | ||
batches = list(RandomBatchSampler(SortedSampler(data_source, sort_key), batch_size, False)) | ||
assert len(batches) == 3 | ||
|
||
|
||
def test_random_batch_sampler_drop_last(): | ||
data_source = [[1], [2], [3], [4], [5]] | ||
sort_key = lambda r: len(r) | ||
batch_size = 2 | ||
batches = list( | ||
RandomBatchSampler(SortedSampler(data_source, sort_key), batch_size, drop_last=True)) | ||
assert len(batches) == 2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
from lib.samplers import SortedSampler | ||
|
||
|
||
def test_sorted_sampler(): | ||
data_source = [[1], [2], [3], [4], [5], [6]] | ||
sort_key = lambda r: r[0] | ||
batch_size = 2 | ||
indexes = list(SortedSampler(data_source, sort_key)) | ||
assert len(indexes) == len(data_source) | ||
for i, j in enumerate(indexes): | ||
assert i == j |