Skip to content

Commit

Permalink
support iterators in view.map
Browse files Browse the repository at this point in the history
Some objects are not sliceable (e.g. xrange). For these, fall back
on itertools.islice.

tests included

reported on IRC by @juliantaylor
  • Loading branch information
minrk committed Dec 2, 2011
1 parent 4b8920a commit a3ee8de
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 4 deletions.
10 changes: 8 additions & 2 deletions IPython/parallel/client/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from __future__ import division

import types
from itertools import islice

from IPython.utils.data import flatten as utils_flatten

Expand Down Expand Up @@ -77,9 +78,14 @@ def getPartition(self, seq, p, q):
else:
lo.append(n*basesize + remainder)
hi.append(lo[-1] + basesize)


result = seq[lo[p]:hi[p]]
try:
result = seq[lo[p]:hi[p]]
except TypeError:
# some objects (iterators) can't be sliced,
# use islice:
result = list(islice(seq, lo[p], hi[p]))

return result

def joinPartitions(self, listOfPartitions):
Expand Down
11 changes: 11 additions & 0 deletions IPython/parallel/tests/test_lbview.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,17 @@ def slow_f(x):
# Ensure that results came in order
self.assertEquals(astheycame, reference)
self.assertEquals(amr.result, reference)

def test_map_iterable(self):
"""test map on iterables (balanced)"""
view = self.view
# 101 is prime, so it won't be evenly distributed
arr = range(101)
# so that it will be an iterator, even in Python 3
it = iter(arr)
r = view.map_sync(lambda x:x, arr)
self.assertEquals(r, list(arr))


def test_abort(self):
view = self.view
Expand Down
13 changes: 11 additions & 2 deletions IPython/parallel/tests/test_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,16 @@ def f(x):
r = view.map_sync(f, data)
self.assertEquals(r, map(f, data))

def test_map_iterable(self):
"""test map on iterables (direct)"""
view = self.client[:]
# 101 is prime, so it won't be evenly distributed
arr = range(101)
# ensure it will be an iterator, even in Python 3
it = iter(arr)
r = view.map_sync(lambda x:x, arr)
self.assertEquals(r, list(arr))

def test_scatterGatherNonblocking(self):
data = range(16)
view = self.client[:]
Expand Down Expand Up @@ -446,6 +456,5 @@ def check_unicode(a, check):
self.fail(e.evalue)
else:
raise e




0 comments on commit a3ee8de

Please sign in to comment.