Skip to content
Browse files

support iterators in view.map

Some objects are not sliceable (e.g. xrange). For these, fall back
on itertools.islice.

tests included

reported on IRC by @juliantaylor
  • Loading branch information...
1 parent 4b8920a commit a3ee8de989ec5949a78aa09c52ed7e2d2cb0c166 @minrk minrk committed Dec 2, 2011
Showing with 30 additions and 4 deletions.
  1. +8 −2 IPython/parallel/client/map.py
  2. +11 −0 IPython/parallel/tests/test_lbview.py
  3. +11 −2 IPython/parallel/tests/test_view.py
View
10 IPython/parallel/client/map.py
@@ -27,6 +27,7 @@
from __future__ import division
import types
+from itertools import islice
from IPython.utils.data import flatten as utils_flatten
@@ -77,9 +78,14 @@ def getPartition(self, seq, p, q):
else:
lo.append(n*basesize + remainder)
hi.append(lo[-1] + basesize)
-
- result = seq[lo[p]:hi[p]]
+ try:
+ result = seq[lo[p]:hi[p]]
+ except TypeError:
+ # some objects (iterators) can't be sliced,
+ # use islice:
+ result = list(islice(seq, lo[p], hi[p]))
+
return result
def joinPartitions(self, listOfPartitions):
View
11 IPython/parallel/tests/test_lbview.py
@@ -95,6 +95,17 @@ def slow_f(x):
# Ensure that results came in order
self.assertEquals(astheycame, reference)
self.assertEquals(amr.result, reference)
+
+ def test_map_iterable(self):
+ """test map on iterables (balanced)"""
+ view = self.view
+ # 101 is prime, so it won't be evenly distributed
+ arr = range(101)
+ # so that it will be an iterator, even in Python 3
+ it = iter(arr)
+ r = view.map_sync(lambda x:x, arr)
+ self.assertEquals(r, list(arr))
+
def test_abort(self):
view = self.view
View
13 IPython/parallel/tests/test_view.py
@@ -240,6 +240,16 @@ def f(x):
r = view.map_sync(f, data)
self.assertEquals(r, map(f, data))
+ def test_map_iterable(self):
+ """test map on iterables (direct)"""
+ view = self.client[:]
+ # 101 is prime, so it won't be evenly distributed
+ arr = range(101)
+ # ensure it will be an iterator, even in Python 3
+ it = iter(arr)
+ r = view.map_sync(lambda x:x, arr)
+ self.assertEquals(r, list(arr))
+
def test_scatterGatherNonblocking(self):
data = range(16)
view = self.client[:]
@@ -446,6 +456,5 @@ def check_unicode(a, check):
self.fail(e.evalue)
else:
raise e
-
-
+

0 comments on commit a3ee8de

Please sign in to comment.
Something went wrong with that request. Please try again.