From 32064ceaf9a0ddfba28d3c56e2ee88eaf1ba4aca Mon Sep 17 00:00:00 2001 From: tomasatdatabricks Date: Fri, 29 Dec 2017 14:56:28 -0800 Subject: [PATCH 1/5] Added functionality for handling non-uint8-based images for ImageSchema Added test for conversion between array and image struct for all ocv types. --- .../apache/spark/ml/image/ImageSchema.scala | 59 ++++++++++--- .../spark/ml/image/ImageSchemaSuite.scala | 5 +- python/pyspark/ml/image.py | 87 ++++++++++++++----- python/pyspark/ml/tests.py | 26 +++++- 4 files changed, 137 insertions(+), 40 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala b/mllib/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala index f7850b238465b..0cd2617275726 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala @@ -37,20 +37,51 @@ import org.apache.spark.sql.types._ @Since("2.3.0") object ImageSchema { - val undefinedImageType = "Undefined" - /** - * (Scala-specific) OpenCV type mapping supported + * OpenCv type representation + * @param mode ordinal for the type + * @param dataType open cv data type + * @param nChannels number of color channels */ - val ocvTypes: Map[String, Int] = Map( - undefinedImageType -> -1, - "CV_8U" -> 0, "CV_8UC1" -> 0, "CV_8UC3" -> 16, "CV_8UC4" -> 24 - ) + case class OpenCvType(mode: Int, dataType: String, nChannels: Int) { + def name: String = "CV_" + dataType + "C" + nChannels + override def toString: String = "OpenCvType(mode = " + mode + ", name = " + name + ")" + } + + object OpenCvType { + def get(name: String): OpenCvType = { + ocvTypes.find(x => x.name == name).getOrElse( + throw new IllegalArgumentException("Unknown open cv type " + name)) + } + def get(mode: Int): OpenCvType = { + ocvTypes.find(x => x.mode == mode).getOrElse( + throw new IllegalArgumentException("Unknown open cv mode " + mode)) + } + val undefinedType = OpenCvType(-1, "N/A", -1) + } /** - * 
(Java-specific) OpenCV type mapping supported + * A Mapping of Type to Numbers in OpenCV + * + * C1 C2 C3 C4 + * CV_8U 0 8 16 24 + * CV_8S 1 9 17 25 + * CV_16U 2 10 18 26 + * CV_16S 3 11 19 27 + * CV_32S 4 12 20 28 + * CV_32F 5 13 21 29 + * CV_64F 6 14 22 30 */ - val javaOcvTypes: java.util.Map[String, Int] = ocvTypes.asJava + val ocvTypes = { + val types = + for (nc <- Array(1, 2, 3, 4); + dt <- Array("8U", "8S", "16U", "16S", "32S", "32F", "64F")) + yield (dt, nc) + val ordinals = for (i <- 0 to 3; j <- 0 to 6) yield ( i * 8 + j) + OpenCvType.undefinedType +: (ordinals zip types).map(x => OpenCvType(x._1, x._2._1, x._2._2)) + } + + val javaOcvTypes = ocvTypes.asJava /** * Schema for the image column: Row(String, Int, Int, Int, Int, Array[Byte]) @@ -121,7 +152,7 @@ object ImageSchema { * @return Row with the default values */ private[spark] def invalidImageRow(origin: String): Row = - Row(Row(origin, -1, -1, -1, ocvTypes(undefinedImageType), Array.ofDim[Byte](0))) + Row(Row(origin, -1, -1, -1, OpenCvType.undefinedType.mode, Array.ofDim[Byte](0))) /** * Convert the compressed image (jpeg, png, etc.) 
into OpenCV @@ -143,12 +174,12 @@ object ImageSchema { val height = img.getHeight val width = img.getWidth - val (nChannels, mode) = if (isGray) { - (1, ocvTypes("CV_8UC1")) + val (nChannels, mode: Int) = if (isGray) { + (1, OpenCvType.get("CV_8UC1").mode) } else if (hasAlpha) { - (4, ocvTypes("CV_8UC4")) + (4, OpenCvType.get("CV_8UC4").mode) } else { - (3, ocvTypes("CV_8UC3")) + (3, OpenCvType.get("CV_8UC3").mode) } val imageSize = height * width * nChannels diff --git a/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala index dba61cd1eb1cc..3d83b6c220523 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala @@ -36,7 +36,7 @@ class ImageSchemaSuite extends SparkFunSuite with MLlibTestSparkContext { val height = 1 val nChannels = 3 val data = Array[Byte](0, 0, 0) - val mode = ocvTypes("CV_8UC3") + val mode: Int = OpenCvType.get("CV_8UC3").mode // Internal Row corresponds to image StructType val rows = Seq(Row(Row(origin, height, width, nChannels, mode, data)), @@ -83,7 +83,8 @@ class ImageSchemaSuite extends SparkFunSuite with MLlibTestSparkContext { val bytes20 = getData(row).slice(0, 20) val (expectedMode, expectedBytes) = firstBytes20(filename) - assert(ocvTypes(expectedMode) === mode, "mode of the image is not read correctly") + assert(OpenCvType.get(expectedMode).mode === mode, + "mode of the image is not read correctly") assert(Arrays.equals(expectedBytes, bytes20), "incorrect numeric value for flattened image") } } diff --git a/python/pyspark/ml/image.py b/python/pyspark/ml/image.py index c9b840276f675..16697d01c8ea6 100644 --- a/python/pyspark/ml/image.py +++ b/python/pyspark/ml/image.py @@ -25,6 +25,8 @@ """ import numpy as np +from collections import namedtuple + from pyspark import SparkContext from pyspark.sql.types import Row, _create_row, 
_parse_datatype_json_string from pyspark.sql import DataFrame, SparkSession @@ -40,9 +42,24 @@ class _ImageSchema(object): def __init__(self): self._imageSchema = None self._ocvTypes = None + self._ocvTypesByName = None + self._ocvTypesByMode = None self._imageFields = None self._undefinedImageType = None + _OcvType = namedtuple("OcvType", ["name", "mode", "nChannels", "dataType", "nptype"]) + + _ocvToNumpyMap = { + "N/A": "N/A", + "8U": np.dtype("uint8"), + "8S": np.dtype("int8"), + "16U": np.dtype('uint16'), + "16S": np.dtype('int16'), + "32S": np.dtype('int32'), + "32F": np.dtype('float32'), + "64F": np.dtype('float64')} + _numpyToOcvMap = {x[1]: x[0] for x in _ocvToNumpyMap.items()} + @property def imageSchema(self): """ @@ -55,7 +72,7 @@ def imageSchema(self): """ if self._imageSchema is None: - ctx = SparkContext._active_spark_context + ctx = SparkContext.getOrCreate() jschema = ctx._jvm.org.apache.spark.ml.image.ImageSchema.imageSchema() self._imageSchema = _parse_datatype_json_string(jschema.json()) return self._imageSchema @@ -71,9 +88,30 @@ def ocvTypes(self): """ if self._ocvTypes is None: - ctx = SparkContext._active_spark_context - self._ocvTypes = dict(ctx._jvm.org.apache.spark.ml.image.ImageSchema.javaOcvTypes()) - return self._ocvTypes + ctx = SparkContext.getOrCreate() + ocvTypeList = ctx._jvm.org.apache.spark.ml.image.ImageSchema.javaOcvTypes() + self._ocvTypes = [self._OcvType(name=x.name(), + mode=x.mode(), + nChannels=x.nChannels(), + dataType=x.dataType(), + nptype=self._ocvToNumpyMap[x.dataType()]) + for x in ocvTypeList] + return self._ocvTypes[:] + + def ocvTypeByName(self, name): + if self._ocvTypesByName is None: + self._ocvTypesByName = {x.name: x for x in self.ocvTypes} + if name not in self._ocvTypesByName: + raise ValueError( + "Can not find matching OpenCvFormat for type = '%s'; supported formats are = %s" % + (name, str( + self._ocvTypesByName.keys()))) + return self._ocvTypesByName[name] + + def ocvTypeByMode(self, mode): + if 
self._ocvTypesByMode is None: + self._ocvTypesByMode = {x.mode: x for x in self.ocvTypes} + return self._ocvTypesByMode[mode] @property def imageFields(self): @@ -86,7 +124,7 @@ def imageFields(self): """ if self._imageFields is None: - ctx = SparkContext._active_spark_context + ctx = SparkContext.getOrCreate() self._imageFields = list(ctx._jvm.org.apache.spark.ml.image.ImageSchema.imageFields()) return self._imageFields @@ -99,7 +137,7 @@ def undefinedImageType(self): """ if self._undefinedImageType is None: - ctx = SparkContext._active_spark_context + ctx = SparkContext.getOrCreate() self._undefinedImageType = \ ctx._jvm.org.apache.spark.ml.image.ImageSchema.undefinedImageType() return self._undefinedImageType @@ -128,11 +166,17 @@ def toNDArray(self, image): height = image.height width = image.width nChannels = image.nChannels + ocvType = self.ocvTypeByMode(image.mode) + if nChannels != ocvType.nChannels: + raise ValueError( + "Image has %d channels but OcvType '%s' expects %d channels." % + (nChannels, ocvType.name, ocvType.nChannels)) + itemSz = ocvType.nptype.itemsize return np.ndarray( shape=(height, width, nChannels), - dtype=np.uint8, + dtype=ocvType.nptype, buffer=image.data, - strides=(width * nChannels, nChannels, 1)) + strides=(width * nChannels * itemSz, nChannels * itemSz, itemSz)) def toImage(self, array, origin=""): """ @@ -150,29 +194,27 @@ def toImage(self, array, origin=""): "array argument should be numpy.ndarray; however, it got [%s]." 
% type(array)) if array.ndim != 3: - raise ValueError("Invalid array shape") + raise ValueError("Invalid array shape %s" % str(array.shape)) height, width, nChannels = array.shape - ocvTypes = ImageSchema.ocvTypes - if nChannels == 1: - mode = ocvTypes["CV_8UC1"] - elif nChannels == 3: - mode = ocvTypes["CV_8UC3"] - elif nChannels == 4: - mode = ocvTypes["CV_8UC4"] - else: - raise ValueError("Invalid number of channels") + dtype = array.dtype + if dtype not in self._numpyToOcvMap: + raise ValueError( + "Unsupported array data type '%s', currently only supported formats are %s" % + (str(array.dtype), str(self._numpyToOcvMap.keys()))) + ocvName = "CV_%sC%d" % (self._numpyToOcvMap[dtype], nChannels) + ocvType = self.ocvTypeByName(ocvName) # Running `bytearray(numpy.array([1]))` fails in specific Python versions # with a specific Numpy version, for example in Python 3.6.0 and NumPy 1.13.3. # Here, it avoids it by converting it to bytes. - data = bytearray(array.astype(dtype=np.uint8).ravel().tobytes()) + data = bytearray(array.tobytes()) # Creating new Row with _create_row(), because Row(name = value, ... ) # orders fields by name, which conflicts with expected schema order # when the new DataFrame is created by UDF return _create_row(self.imageFields, - [origin, height, width, nChannels, mode, data]) + [origin, height, width, nChannels, ocvType.mode, data]) def readImages(self, path, recursive=False, numPartitions=-1, dropImageFailures=False, sampleRatio=1.0, seed=0): @@ -201,8 +243,9 @@ def readImages(self, path, recursive=False, numPartitions=-1, .. 
versionadded:: 2.3.0 """ - spark = SparkSession.builder.getOrCreate() - image_schema = spark._jvm.org.apache.spark.ml.image.ImageSchema + ctx = SparkContext.getOrCreate() + spark = SparkSession(ctx) + image_schema = ctx._jvm.org.apache.spark.ml.image.ImageSchema jsession = spark._jsparkSession jresult = image_schema.readImages(path, jsession, recursive, numPartitions, dropImageFailures, float(sampleRatio), seed) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 1af2b91da900d..628735418e15f 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -45,6 +45,7 @@ from numpy import abs, all, arange, array, array_equal, inf, ones, tile, zeros import inspect import py4j +import random from pyspark import keyword_only, SparkContext from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer, UnaryTransformer @@ -1843,6 +1844,28 @@ def tearDown(self): class ImageReaderTest(SparkSessionTestCase): + def test_ocv_types(self): + ocvList = ImageSchema.ocvTypes + self.assertEqual("Undefined", ocvList[0].name) + self.assertEqual(-1, ocvList[0].mode) + self.assertEqual("N/A", ocvList[0].dataType) + for x in ocvList: + self.assertEqual(x, ImageSchema.ocvTypeByName(x.name)) + self.assertEqual(x, ImageSchema.ocvTypeByMode(x.mode)) + + def test_conversions(self): + ary_src = [[[1e7*random.random() for z in range(4)] for y in range(10)] for x in range(10)] + for ocvType in ImageSchema.ocvTypes: + if ocvType.name == 'Undefined': + continue + x = [[ary_src[i][j][0:ocvType.nChannels] + for j in range(len(ary_src[0]))] for i in range(len(ary_src))] + npary0 = np.array(x).astype(ocvType.nptype) + img = ImageSchema.toImage(npary0) + self.assertEqual(ocvType, ImageSchema.ocvTypeByMode(img.mode)) + npary1 = ImageSchema.toNDArray(img) + np.testing.assert_array_equal(npary0, npary1) + def test_read_images(self): data_path = 'data/mllib/images/kittens' df = ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True) @@ 
-1852,8 +1875,7 @@ def test_read_images(self): self.assertEqual(len(array), first_row[1]) self.assertEqual(ImageSchema.toImage(array, origin=first_row[0]), first_row) self.assertEqual(df.schema, ImageSchema.imageSchema) - expected = {'CV_8UC3': 16, 'Undefined': -1, 'CV_8U': 0, 'CV_8UC1': 0, 'CV_8UC4': 24} - self.assertEqual(ImageSchema.ocvTypes, expected) + expected = ['origin', 'height', 'width', 'nChannels', 'mode', 'data'] self.assertEqual(ImageSchema.imageFields, expected) self.assertEqual(ImageSchema.undefinedImageType, "Undefined") From 53c4d769b73c2011c12d13c75543c3cfb5f9de08 Mon Sep 17 00:00:00 2001 From: tomasatdatabricks Date: Tue, 9 Jan 2018 11:11:58 -0800 Subject: [PATCH 2/5] Addressed reviewers comments. Fixed name method on OpenCvType to return correct name for Undefined type. Removed OpenCvType object and renamed the methods to match python side. + few cosmetic changes. --- .../apache/spark/ml/image/ImageSchema.scala | 37 ++++++++++--------- .../spark/ml/image/ImageSchemaSuite.scala | 4 +- python/pyspark/ml/image.py | 17 ++++----- python/pyspark/ml/tests.py | 9 ++--- 4 files changed, 34 insertions(+), 33 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala b/mllib/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala index 0cd2617275726..de8ccc90101cd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala @@ -44,22 +44,22 @@ object ImageSchema { * @param nChannels number of color channels */ case class OpenCvType(mode: Int, dataType: String, nChannels: Int) { - def name: String = "CV_" + dataType + "C" + nChannels - override def toString: String = "OpenCvType(mode = " + mode + ", name = " + name + ")" + def name: String = if (mode == -1) { "Undefined" } else { s"CV_$dataType" + s"C$nChannels" } + override def toString: String = s"OpenCvType(mode = $mode, name = $name)" } - object OpenCvType { - def get(name: 
String): OpenCvType = { - ocvTypes.find(x => x.name == name).getOrElse( - throw new IllegalArgumentException("Unknown open cv type " + name)) - } - def get(mode: Int): OpenCvType = { - ocvTypes.find(x => x.mode == mode).getOrElse( - throw new IllegalArgumentException("Unknown open cv mode " + mode)) - } - val undefinedType = OpenCvType(-1, "N/A", -1) + def ocvTypeByName(name: String): OpenCvType = { + ocvTypes.find(x => x.name == name).getOrElse( + throw new IllegalArgumentException("Unknown open cv type " + name)) + } + + def ocvTypeByMode(mode: Int): OpenCvType = { + ocvTypes.find(x => x.mode == mode).getOrElse( + throw new IllegalArgumentException("Unknown open cv mode " + mode)) } + val undefinedImageType = OpenCvType(-1, "N/A", -1) + /** * A Mapping of Type to Numbers in OpenCV * @@ -78,9 +78,12 @@ object ImageSchema { dt <- Array("8U", "8S", "16U", "16S", "32S", "32F", "64F")) yield (dt, nc) val ordinals = for (i <- 0 to 3; j <- 0 to 6) yield ( i * 8 + j) - OpenCvType.undefinedType +: (ordinals zip types).map(x => OpenCvType(x._1, x._2._1, x._2._2)) + undefinedImageType +: (ordinals zip types).map(x => OpenCvType(x._1, x._2._1, x._2._2)) } + /** + * (Java Specific) list of OpenCv types + */ val javaOcvTypes = ocvTypes.asJava /** @@ -152,7 +155,7 @@ object ImageSchema { * @return Row with the default values */ private[spark] def invalidImageRow(origin: String): Row = - Row(Row(origin, -1, -1, -1, OpenCvType.undefinedType.mode, Array.ofDim[Byte](0))) + Row(Row(origin, -1, -1, -1, undefinedImageType.mode, Array.ofDim[Byte](0))) /** * Convert the compressed image (jpeg, png, etc.) 
into OpenCV @@ -175,11 +178,11 @@ object ImageSchema { val height = img.getHeight val width = img.getWidth val (nChannels, mode: Int) = if (isGray) { - (1, OpenCvType.get("CV_8UC1").mode) + (1, ocvTypeByName("CV_8UC1").mode) } else if (hasAlpha) { - (4, OpenCvType.get("CV_8UC4").mode) + (4, ocvTypeByName("CV_8UC4").mode) } else { - (3, OpenCvType.get("CV_8UC3").mode) + (3, ocvTypeByName("CV_8UC3").mode) } val imageSize = height * width * nChannels diff --git a/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala index 3d83b6c220523..c45faa8b89cc7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala @@ -36,7 +36,7 @@ class ImageSchemaSuite extends SparkFunSuite with MLlibTestSparkContext { val height = 1 val nChannels = 3 val data = Array[Byte](0, 0, 0) - val mode: Int = OpenCvType.get("CV_8UC3").mode + val mode: Int = ImageSchema.ocvTypeByName("CV_8UC3").mode // Internal Row corresponds to image StructType val rows = Seq(Row(Row(origin, height, width, nChannels, mode, data)), @@ -83,7 +83,7 @@ class ImageSchemaSuite extends SparkFunSuite with MLlibTestSparkContext { val bytes20 = getData(row).slice(0, 20) val (expectedMode, expectedBytes) = firstBytes20(filename) - assert(OpenCvType.get(expectedMode).mode === mode, + assert(ImageSchema.ocvTypeByName(expectedMode).mode === mode, "mode of the image is not read correctly") assert(Arrays.equals(expectedBytes, bytes20), "incorrect numeric value for flattened image") } diff --git a/python/pyspark/ml/image.py b/python/pyspark/ml/image.py index 16697d01c8ea6..d8d0e7fc646f5 100644 --- a/python/pyspark/ml/image.py +++ b/python/pyspark/ml/image.py @@ -104,13 +104,16 @@ def ocvTypeByName(self, name): if name not in self._ocvTypesByName: raise ValueError( "Can not find matching OpenCvFormat for type = '%s'; supported formats are = %s" % 
- (name, str( - self._ocvTypesByName.keys()))) + (name, str(self._ocvTypesByName.keys()))) return self._ocvTypesByName[name] def ocvTypeByMode(self, mode): if self._ocvTypesByMode is None: self._ocvTypesByMode = {x.mode: x for x in self.ocvTypes} + if mode not in self._ocvTypesByMode: + raise ValueError( + "Invalid mode '%d'; supported modes are = %s" % + (name, str(self._ocvTypesByMode.keys()))) return self._ocvTypesByMode[mode] @property @@ -135,11 +138,8 @@ def undefinedImageType(self): .. versionadded:: 2.3.0 """ - if self._undefinedImageType is None: - ctx = SparkContext.getOrCreate() - self._undefinedImageType = \ - ctx._jvm.org.apache.spark.ml.image.ImageSchema.undefinedImageType() + self._undefinedImageType = self.ocvTypeByName("Undefined") return self._undefinedImageType def toNDArray(self, image): @@ -243,9 +243,8 @@ def readImages(self, path, recursive=False, numPartitions=-1, .. versionadded:: 2.3.0 """ - ctx = SparkContext.getOrCreate() - spark = SparkSession(ctx) - image_schema = ctx._jvm.org.apache.spark.ml.image.ImageSchema + spark = SparkSession.builder.getOrCreate() + image_schema = spark._jvm.org.apache.spark.ml.image.ImageSchema jsession = spark._jsparkSession jresult = image_schema.readImages(path, jsession, recursive, numPartitions, dropImageFailures, float(sampleRatio), seed) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 628735418e15f..07c898e823a56 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -1854,13 +1854,12 @@ def test_ocv_types(self): self.assertEqual(x, ImageSchema.ocvTypeByMode(x.mode)) def test_conversions(self): - ary_src = [[[1e7*random.random() for z in range(4)] for y in range(10)] for x in range(10)] + s = np.random.RandomState(seed=987) + ary_src = s.rand(4, 10, 10) for ocvType in ImageSchema.ocvTypes: if ocvType.name == 'Undefined': continue - x = [[ary_src[i][j][0:ocvType.nChannels] - for j in range(len(ary_src[0]))] for i in range(len(ary_src))] - npary0 = 
np.array(x).astype(ocvType.nptype) + npary0 = ary_src[..., 0:ocvType.nChannels].astype(ocvType.nptype) img = ImageSchema.toImage(npary0) self.assertEqual(ocvType, ImageSchema.ocvTypeByMode(img.mode)) npary1 = ImageSchema.toNDArray(img) @@ -1878,7 +1877,7 @@ def test_read_images(self): expected = ['origin', 'height', 'width', 'nChannels', 'mode', 'data'] self.assertEqual(ImageSchema.imageFields, expected) - self.assertEqual(ImageSchema.undefinedImageType, "Undefined") + self.assertEqual(ImageSchema.undefinedImageType.name, "Undefined") with QuietTest(self.sc): self.assertRaisesRegexp( From 490454ae3659f9f00d02fde97cf450dffd828930 Mon Sep 17 00:00:00 2001 From: tomasatdatabricks Date: Fri, 12 Jan 2018 15:10:52 -0800 Subject: [PATCH 3/5] Minor test fix - added type check to numpy array comparison --- python/pyspark/ml/tests.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 07c898e823a56..ee4121f00175a 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -1863,6 +1863,7 @@ def test_conversions(self): img = ImageSchema.toImage(npary0) self.assertEqual(ocvType, ImageSchema.ocvTypeByMode(img.mode)) npary1 = ImageSchema.toNDArray(img) + self.assertEqual(ocvType.nptype, npary1.dtype) np.testing.assert_array_equal(npary0, npary1) def test_read_images(self): From 31fef5e1f8cd240d8d39b9d56a7e6509e18017dd Mon Sep 17 00:00:00 2001 From: tomasatdatabricks Date: Mon, 15 Jan 2018 18:24:45 -0800 Subject: [PATCH 4/5] Addressed review comments: Added explicit types for ocvTypes and ocvTypesJava, fixed/added python comments. 
--- .../apache/spark/ml/image/ImageSchema.scala | 19 +++++++++++++++--- python/pyspark/ml/image.py | 20 +++++++++++++++++-- 2 files changed, 34 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala b/mllib/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala index de8ccc90101cd..f7fcd7a215f93 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala @@ -39,6 +39,7 @@ object ImageSchema { /** * OpenCv type representation + * * @param mode ordinal for the type * @param dataType open cv data type * @param nChannels number of color channels @@ -48,11 +49,23 @@ object ImageSchema { override def toString: String = s"OpenCvType(mode = $mode, name = $name)" } + /** + * Return the supported OpenCvType with matching name or raise error if there is no matching type. + * + * @param name: name of existing OpenCvType + * @return OpenCvType that matches the given name + */ def ocvTypeByName(name: String): OpenCvType = { ocvTypes.find(x => x.name == name).getOrElse( throw new IllegalArgumentException("Unknown open cv type " + name)) } + /** + * Return the supported OpenCvType with matching mode or raise error if there is no matching type. 
+ * + * @param mode: mode of existing OpenCvType + * @return OpenCvType that matches the given mode + */ def ocvTypeByMode(mode: Int): OpenCvType = { ocvTypes.find(x => x.mode == mode).getOrElse( throw new IllegalArgumentException("Unknown open cv mode " + mode)) @@ -72,7 +85,7 @@ object ImageSchema { * CV_32F 5 13 21 29 * CV_64F 6 14 22 30 */ - val ocvTypes = { + val ocvTypes: IndexedSeq[OpenCvType] = { val types = for (nc <- Array(1, 2, 3, 4); dt <- Array("8U", "8S", "16U", "16S", "32S", "32F", "64F")) @@ -82,9 +95,9 @@ object ImageSchema { } /** - * (Java Specific) list of OpenCv types + * (Java-specific) list of OpenCv types */ - val javaOcvTypes = ocvTypes.asJava + val javaOcvTypes: java.util.List[OpenCvType] = ocvTypes.asJava /** * Schema for the image column: Row(String, Int, Int, Int, Int, Array[Byte]) diff --git a/python/pyspark/ml/image.py b/python/pyspark/ml/image.py index d8d0e7fc646f5..3ab4edf892eb9 100644 --- a/python/pyspark/ml/image.py +++ b/python/pyspark/ml/image.py @@ -80,9 +80,9 @@ def imageSchema(self): @property def ocvTypes(self): """ - Returns the OpenCV type mapping supported. + Return the supported OpenCV types. - :return: a dictionary containing the OpenCV type mapping supported. + :return: a list containing the supported OpenCV types. .. versionadded:: 2.3.0 """ @@ -99,6 +99,14 @@ def ocvTypes(self): return self._ocvTypes[:] def ocvTypeByName(self, name): + """ + Return the OpenCvType with matching name or raise error if there is no matching type. + + :param: str name: OpenCv type name; must be equal to name of one of the supported types. + :return: OpenCvType with matching name. + + """ + if self._ocvTypesByName is None: self._ocvTypesByName = {x.name: x for x in self.ocvTypes} if name not in self._ocvTypesByName: @@ -108,6 +116,14 @@ def ocvTypeByName(self, name): return self._ocvTypesByName[name] def ocvTypeByMode(self, mode): + """ + Return the OpenCvType with matching mode or raise error if there is no matching type. 
+ + :param: int mode: OpenCv type mode; must be equal to mode of one of the supported types. + :return: OpenCvType with matching mode. + + """ + if self._ocvTypesByMode is None: self._ocvTypesByMode = {x.mode: x for x in self.ocvTypes} if mode not in self._ocvTypesByMode: From 5a632f5f60afd2e8c225703532d17ed9e56e47f7 Mon Sep 17 00:00:00 2001 From: tomasatdatabricks Date: Tue, 16 Jan 2018 14:31:41 -0800 Subject: [PATCH 5/5] Addressed review comments. Mostly updates to comments and variable names. --- .../apache/spark/ml/image/ImageSchema.scala | 29 +++++++++++-------- .../spark/ml/image/ImageSchemaSuite.scala | 4 +-- python/pyspark/ml/image.py | 2 +- python/pyspark/ml/tests.py | 4 +-- 4 files changed, 22 insertions(+), 17 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala b/mllib/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala index f7fcd7a215f93..ac4ac100130cf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala @@ -38,10 +38,13 @@ import org.apache.spark.sql.types._ object ImageSchema { /** - * OpenCv type representation + * OpenCv type representation. 
+ * + * @see + * OpenCv/basic_structures * * @param mode ordinal for the type - * @param dataType open cv data type + * @param dataType OpenCv data type * @param nChannels number of color channels */ case class OpenCvType(mode: Int, dataType: String, nChannels: Int) { @@ -57,7 +60,7 @@ object ImageSchema { */ def ocvTypeByName(name: String): OpenCvType = { ocvTypes.find(x => x.name == name).getOrElse( - throw new IllegalArgumentException("Unknown open cv type " + name)) + throw new IllegalArgumentException("Unknown OpenCv type " + name)) } /** @@ -68,7 +71,7 @@ object ImageSchema { */ def ocvTypeByMode(mode: Int): OpenCvType = { ocvTypes.find(x => x.mode == mode).getOrElse( - throw new IllegalArgumentException("Unknown open cv mode " + mode)) + throw new IllegalArgumentException("Unknown OpenCv mode " + mode)) } val undefinedImageType = OpenCvType(-1, "N/A", -1) @@ -76,14 +79,16 @@ object ImageSchema { /** * A Mapping of Type to Numbers in OpenCV * - * C1 C2 C3 C4 - * CV_8U 0 8 16 24 - * CV_8S 1 9 17 25 - * CV_16U 2 10 18 26 - * CV_16S 3 11 19 27 - * CV_32S 4 12 20 28 - * CV_32F 5 13 21 29 - * CV_64F 6 14 22 30 + * name | num channels + * | C1 C2 C3 C4 + * -------+-------------- + * CV_8U | 0 8 16 24 + * CV_8S | 1 9 17 25 + * CV_16U | 2 10 18 26 + * CV_16S | 3 11 19 27 + * CV_32S | 4 12 20 28 + * CV_32F | 5 13 21 29 + * CV_64F | 6 14 22 30 */ val ocvTypes: IndexedSeq[OpenCvType] = { val types = diff --git a/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala index c45faa8b89cc7..de585e29656a5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala @@ -82,8 +82,8 @@ class ImageSchemaSuite extends SparkFunSuite with MLlibTestSparkContext { val mode = getMode(row) val bytes20 = getData(row).slice(0, 20) - val (expectedMode, expectedBytes) = firstBytes20(filename) - 
assert(ImageSchema.ocvTypeByName(expectedMode).mode === mode, + val (expectedName, expectedBytes) = firstBytes20(filename) + assert(ImageSchema.ocvTypeByName(expectedName).mode === mode, "mode of the image is not read correctly") assert(Arrays.equals(expectedBytes, bytes20), "incorrect numeric value for flattened image") } diff --git a/python/pyspark/ml/image.py b/python/pyspark/ml/image.py index 3ab4edf892eb9..2cabb70fd994b 100644 --- a/python/pyspark/ml/image.py +++ b/python/pyspark/ml/image.py @@ -185,7 +185,7 @@ def toNDArray(self, image): ocvType = self.ocvTypeByMode(image.mode) if nChannels != ocvType.nChannels: raise ValueError( - "Image has %d channels but OcvType '%s' expects %d channels." % + "Image has %d channels but its OcvType '%s' expects %d channels." % (nChannels, ocvType.name, ocvType.nChannels)) itemSz = ocvType.nptype.itemsize return np.ndarray( diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index ee4121f00175a..e2232ff212384 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -1855,11 +1855,11 @@ def test_ocv_types(self): def test_conversions(self): s = np.random.RandomState(seed=987) - ary_src = s.rand(4, 10, 10) + array_src = s.rand(10, 10, 4) for ocvType in ImageSchema.ocvTypes: if ocvType.name == 'Undefined': continue - npary0 = ary_src[..., 0:ocvType.nChannels].astype(ocvType.nptype) + npary0 = array_src[..., 0:ocvType.nChannels].astype(ocvType.nptype) img = ImageSchema.toImage(npary0) self.assertEqual(ocvType, ImageSchema.ocvTypeByMode(img.mode)) npary1 = ImageSchema.toNDArray(img)