[SPARK-21866][ML][PYTHON][FOLLOWUP] Few cleanups and fix image test failure in Python 3.6.0 / NumPy 1.13.3 #19835

Closed
wants to merge 2 commits into from
27 changes: 24 additions & 3 deletions python/pyspark/ml/image.py
@@ -108,12 +108,23 @@ def toNDArray(self, image):
"""
Converts an image to an array with metadata.

:param image: The image to be converted.
:param `Row` image: A row that contains the image to be converted. It should
have the attributes specified in `ImageSchema.imageSchema`.
:return: a `numpy.ndarray` that is an image.

.. versionadded:: 2.3.0
"""

if not isinstance(image, Row):
raise TypeError(
"image argument should be pyspark.sql.types.Row; however, "
"it got [%s]." % type(image))

if any(not hasattr(image, f) for f in self.imageFields):
raise ValueError(
"image argument should have attributes specified in "
"ImageSchema.imageSchema [%s]." % ", ".join(self.imageFields))

height = image.height
width = image.width
nChannels = image.nChannels
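
The two checks added to `toNDArray` above (type first, then required attributes) can be sketched without a Spark installation. This is only an illustration: `SimpleNamespace` stands in for `pyspark.sql.Row`, and `IMAGE_FIELDS` and `validate_image` are hypothetical names assumed for the example, not part of the PR.

```python
from types import SimpleNamespace

# Assumed field list for illustration; the PR reads it from ImageSchema.imageFields.
IMAGE_FIELDS = ("origin", "height", "width", "nChannels", "mode", "data")


def validate_image(image, fields=IMAGE_FIELDS):
    """Reject the wrong type with TypeError, missing attributes with ValueError."""
    # Stage 1: the object must be of the expected row-like type.
    if not isinstance(image, SimpleNamespace):
        raise TypeError(
            "image argument should be SimpleNamespace; however, "
            "it got [%s]." % type(image))

    # Stage 2: every schema field must be present as an attribute.
    if any(not hasattr(image, f) for f in fields):
        raise ValueError(
            "image argument should have attributes specified in "
            "IMAGE_FIELDS [%s]." % ", ".join(fields))

    return image
```

Checking the type before the attributes keeps the two failure modes distinct, which is what lets the tests in this PR match on different exception classes and messages.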
@@ -127,15 +138,20 @@ def toImage(self, array, origin=""):
"""
Converts an array with metadata to a two-dimensional image.

:param array array: The array to convert to image.
:param `numpy.ndarray` array: The array to convert to image.
:param str origin: Path to the image, optional.
:return: a :class:`Row` that is a two dimensional image.

.. versionadded:: 2.3.0
"""

if not isinstance(array, np.ndarray):
raise TypeError(
"array argument should be numpy.ndarray; however, it got [%s]." % type(array))

if array.ndim != 3:
raise ValueError("Invalid array shape")

height, width, nChannels = array.shape
ocvTypes = ImageSchema.ocvTypes
if nChannels == 1:
@@ -146,7 +162,12 @@ def toImage(self, array, origin=""):
mode = ocvTypes["CV_8UC4"]
else:
raise ValueError("Invalid number of channels")
data = bytearray(array.astype(dtype=np.uint8).ravel())

# Running `bytearray(numpy.array([1]))` fails in specific Python versions
# with a specific Numpy version, for example in Python 3.6.0 and NumPy 1.13.3.
# Here, it avoids it by converting it to bytes.
data = bytearray(array.astype(dtype=np.uint8).ravel().tobytes())
Member Author


To be honest, I haven't been able to find the exact issue or the change behind this yet. There are too many similar or related entries in the NumPy and Python release notes, and it is even harder to track down because the exception is raised from NumPy while the cause seems to be a specific Python version (3.6.0), unless I have missed something.

Contributor


Strange, but the comment explains the issue well, and I think this is a good workaround.
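
For readers who want to try the workaround discussed here in isolation, the following is a minimal sketch assuming only that NumPy is installed. The helper name `to_image_bytes` is hypothetical; the shape check and the `tobytes()` conversion mirror the lines in the diff.

```python
import numpy as np


def to_image_bytes(array):
    """Flatten a (height, width, nChannels) array into a bytearray of uint8 values."""
    if not isinstance(array, np.ndarray):
        raise TypeError(
            "array argument should be numpy.ndarray; however, it got [%s]." % type(array))
    if array.ndim != 3:
        raise ValueError("Invalid array shape")

    # Convert to an immutable bytes object first, then wrap it in a bytearray,
    # instead of passing the ndarray to bytearray() directly.
    return bytearray(array.astype(dtype=np.uint8).ravel().tobytes())


# A 2x2 image with 3 channels holding the values 0..11.
pixels = np.arange(12).reshape(2, 2, 3)
data = to_image_bytes(pixels)
```

Going through `tobytes()` means `bytearray()` only ever receives a plain `bytes` object, so the result should not depend on how a given Python and NumPy combination handles the array's buffer.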


# Creating new Row with _create_row(), because Row(name = value, ... )
# orders fields by name, which conflicts with expected schema order
# when the new DataFrame is created by UDF
20 changes: 19 additions & 1 deletion python/pyspark/ml/tests.py
@@ -71,7 +71,7 @@
from pyspark.sql.functions import rand
from pyspark.sql.types import DoubleType, IntegerType
from pyspark.storagelevel import *
from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
from pyspark.tests import QuietTest, ReusedPySparkTestCase as PySparkTestCase

ser = PickleSerializer()

@@ -1836,6 +1836,24 @@ def test_read_images(self):
self.assertEqual(ImageSchema.imageFields, expected)
self.assertEqual(ImageSchema.undefinedImageType, "Undefined")

with QuietTest(self.sc):
Contributor


nice tests!

self.assertRaisesRegexp(
TypeError,
"image argument should be pyspark.sql.types.Row; however",
lambda: ImageSchema.toNDArray("a"))

with QuietTest(self.sc):
self.assertRaisesRegexp(
ValueError,
"image argument should have attributes specified in",
lambda: ImageSchema.toNDArray(Row(a=1)))

with QuietTest(self.sc):
self.assertRaisesRegexp(
TypeError,
"array argument should be numpy.ndarray; however, it got",
lambda: ImageSchema.toImage("a"))
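
The same style of assertion can be sketched with the standard library alone. Note that `assertRaisesRegexp`, used in the tests above, is a deprecated alias of `assertRaisesRegex` and was removed in Python 3.12, so this sketch uses the newer name in its context-manager form. `require_list` and `RequireListTest` are hypothetical names that exist only for this example.

```python
import unittest


def require_list(value):
    """Hypothetical validator: a stand-in for the type check in toImage."""
    if not isinstance(value, list):
        raise TypeError(
            "array argument should be list; however, it got [%s]." % type(value))
    return value


class RequireListTest(unittest.TestCase):

    def test_rejects_non_list(self):
        # The second argument is a regular expression searched for in the message.
        with self.assertRaisesRegex(TypeError, "should be list; however, it got"):
            require_list("a")

    def test_accepts_list(self):
        self.assertEqual(require_list([1]), [1])
```

Saved as a module, this could be run with `python -m unittest <module_name>`. Matching only a stable prefix of the message, as the PR's tests do, keeps the assertion from breaking when the trailing type representation differs between Python versions.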


class ALSTest(SparkSessionTestCase):
