Skip to content

Commit

Permalink
Add DeterministicList
Browse files Browse the repository at this point in the history
  • Loading branch information
aleju committed Nov 1, 2019
1 parent ed64a39 commit f795777
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 0 deletions.
3 changes: 3 additions & 0 deletions changelogs/master/added/20191101_deterministic_list.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Added DeterministicList

* Added `imgaug.parameters.DeterministicList`.
46 changes: 46 additions & 0 deletions imgaug/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,52 @@ def __str__(self):
return "Deterministic(%s)" % (str(self.value),)


# TODO tests
class DeterministicList(StochasticParameter):
"""Parameter that repeats elements from a list in the given order.
E.g. of samples of shape ``(A, B, C)`` are requested, this parameter will
return the first ``A*B*C`` elements, reshaped to ``(A, B, C)`` from the
provided list. If the list contains less than ``A*B*C`` elements, it
will (by default) be tiled until it is long enough (i.e. the sampling
will start again at the first element, if necessary multiple times).
Parameters
----------
values : iterable
An iterable of values to sample from in the order within the iterable.
"""

def __init__(self, values, cycle=True):
super(DeterministicList, self).__init__()

assert ia.is_iterable(values), (
"Expected to get an iterable as input, got type %s." % (
type(values).__name__,))
values = np.array(values).flatten()
assert len(values) > 0, ("Expected to get at least one value, got "
"zero.")
self.values = values

def _draw_samples(self, size, random_state):
nb_requested = int(np.prod(size))
if nb_requested > self.values.size:
# we don't use itertools.cycle() here, as that would require
# running through a loop potentially many times (as `size` can
# be very large), which would be slow
multiplier = int(np.ceil(nb_requested / self.values.size))
values = np.tile(self.values, (multiplier,))
return values[:nb_requested].reshape(size)

def __repr__(self):
return self.__str__()

def __str__(self):
return "DeterministicList(%s, cycle=%s)" % (str(self.value.tolist()),
self.cycle)


class Choice(StochasticParameter):
"""Parameter that samples value from a list of allowed values.
Expand Down

0 comments on commit f795777

Please sign in to comment.