Skip to content
This repository has been archived by the owner on Jul 4, 2023. It is now read-only.

Commit

Permalink
Samplers Tests
Browse files Browse the repository at this point in the history
  • Loading branch information
PetrochukM committed Feb 26, 2018
1 parent fd46926 commit 616fd74
Show file tree
Hide file tree
Showing 7 changed files with 129 additions and 13 deletions.
4 changes: 3 additions & 1 deletion lib/samplers/bucket_batch_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ def __init__(self,
self.last_batch_first = last_batch_first
self.shuffle = shuffle
super().__init__(
NoisySortedSampler(data_source, sort_key, sort_key_noise), batch_size, drop_last)
NoisySortedSampler(
data_source=data_source, sort_key=sort_key, sort_key_noise=sort_key_noise),
batch_size, drop_last)

def __iter__(self):
batches = list(super().__iter__())
Expand Down
26 changes: 17 additions & 9 deletions lib/samplers/noisy_sorted_sampler.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,36 @@
import random

from lib.samplers.sorted_sampler import SortedSampler
from torch.utils.data.sampler import Sampler


class NoisySortedSampler(SortedSampler):
"""Samples elements sequentially, always in the same order.
class NoisySortedSampler(Sampler):
"""Samples elements sequentially with noise.
Reference and inspiration:
https://github.com/allenai/allennlp/blob/e125a490b71b21e914af01e70e9b00b165d64dcd/allennlp/data/iterators/bucket_iterator.py
Arguments:
data_source (Dataset): dataset to sample from
sort_key (callable -> int): callable that returns from one row of the data_source a int
sort_key (callable): specifies a function of one argument that is used to extract a
comparison key from each list element
"""

def __init__(self, data_source, sort_key, sort_key_noise=0.1):
def __init__(self, data_source, sort_key, sort_key_noise=0.25):
super().__init__(data_source)
self.data_source = data_source
self.sort_key = sort_key
zip = []
zip_ = []
for i, row in enumerate(self.data_source):
value = self.sort_key(row)
noise_value = value * sort_key_noise
noise = random.uniform(-noise_value, noise_value)
value = noise + value
zip.append(tuple([i, value]))
zip = sorted(zip, key=lambda r: r[1])
self.sorted_indexes = [item[0] for item in zip]
zip_.append(tuple([i, value]))
zip_ = sorted(zip_, key=lambda r: r[1])
self.sorted_indexes = [item[0] for item in zip_]

def __iter__(self):
return iter(self.sorted_indexes)

def __len__(self):
return len(self.data_source)
7 changes: 4 additions & 3 deletions lib/samplers/sorted_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@ class SortedSampler(Sampler):
Arguments:
data_source (Dataset): dataset to sample from
sort_key (callable): callable that returns from one row of the data_source a sortable
value
sort_key (callable): specifies a function of one argument that is used to extract a
comparison key from each list element
"""

def __init__(self, data_source, sort_key, sort_noise=0.1):
def __init__(self, data_source, sort_key):
super().__init__(data_source)
self.data_source = data_source
self.sort_key = sort_key
zip = [(i, self.sort_key(row)) for i, row in enumerate(self.data_source)]
Expand Down
45 changes: 45 additions & 0 deletions tests/samplers/test_bucket_batch_sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from lib.samplers import BucketBatchSampler


def test_bucket_batch_sampler():
data_source = [[1], [2], [3], [4], [5], [6]]
sort_key = lambda r: len(r)
batch_size = 2
batches = list(BucketBatchSampler(data_source, sort_key, batch_size))
assert len(batches) == 3


def test_bucket_batch_sampler_uneven():
data_source = [[1], [2], [3], [4], [5]]
sort_key = lambda r: len(r)
batch_size = 2
batches = list(BucketBatchSampler(data_source, sort_key, batch_size))
assert len(batches) == 3
batches = list(BucketBatchSampler(data_source, sort_key, batch_size, drop_last=True))
assert len(batches) == 2


def test_bucket_batch_sampler_last_batch_first():
data_source = [[1], [2], [3], [4], [5, 6, 7, 8, 9, 10]]
sort_key = lambda r: len(r)
batch_size = 2
batches = list(BucketBatchSampler(data_source, sort_key, batch_size, last_batch_first=True))
# Largest batch (4) is in first batch
assert 4 in batches[0]


def test_bucket_batch_sampler_sorted():
data_source = [[1], [2], [3], [4], [5]]
sort_key = lambda r: r[0]
batch_size = 1
batches = list(
BucketBatchSampler(
data_source,
sort_key,
batch_size,
shuffle=False,
last_batch_first=False,
sort_key_noise=0.0))
# Largest batch (4) is in first batch
for i, batch in enumerate(batches):
assert batch[0] == i
29 changes: 29 additions & 0 deletions tests/samplers/test_noisy_sorted_sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from lib.samplers import NoisySortedSampler


def test_noisy_sorted_sampler():
data_source = [[1], [2], [3], [4], [5], [6]]
sort_key = lambda r: r[0]
batch_size = 2
indexes = list(NoisySortedSampler(data_source, sort_key))
assert len(indexes) == len(data_source)


def test_noisy_sorted_sampler_sorted():
data_source = [[1], [2], [3], [4], [5], [6]]
sort_key = lambda r: r[0]
batch_size = 2
indexes = list(NoisySortedSampler(data_source, sort_key, sort_key_noise=0.0))
assert len(indexes) == len(data_source)
for i, j in enumerate(indexes):
assert i == j


def test_noisy_sorted_sampler_sort_key_noise():
data_source = [[2], [6], [10]]
sort_key = lambda r: r[0]
batch_size = 2
# `sort_key_noise` does not affect values 2, 6, 10
indexes = list(NoisySortedSampler(data_source, sort_key, sort_key_noise=0.25))
for i, j in enumerate(indexes):
assert i == j
20 changes: 20 additions & 0 deletions tests/samplers/test_random_batch_sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from lib.samplers import RandomBatchSampler

from lib.samplers import SortedSampler


def test_random_batch_sampler():
data_source = [[1], [2], [3], [4], [5], [6]]
sort_key = lambda r: len(r)
batch_size = 2
batches = list(RandomBatchSampler(SortedSampler(data_source, sort_key), batch_size, False))
assert len(batches) == 3


def test_random_batch_sampler_drop_last():
data_source = [[1], [2], [3], [4], [5]]
sort_key = lambda r: len(r)
batch_size = 2
batches = list(
RandomBatchSampler(SortedSampler(data_source, sort_key), batch_size, drop_last=True))
assert len(batches) == 2
11 changes: 11 additions & 0 deletions tests/samplers/test_sorted_sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from lib.samplers import SortedSampler


def test_sorted_sampler():
data_source = [[1], [2], [3], [4], [5], [6]]
sort_key = lambda r: r[0]
batch_size = 2
indexes = list(SortedSampler(data_source, sort_key))
assert len(indexes) == len(data_source)
for i, j in enumerate(indexes):
assert i == j

0 comments on commit 616fd74

Please sign in to comment.