[SPARK-21866][ML][PySpark] Adding spark image reader #19439
@@ -0,0 +1 @@
not an image
@@ -0,0 +1,13 @@
The images in the folder "kittens" are under the creative commons CC0 license, or no rights reserved:
https://creativecommons.org/share-your-work/public-domain/cc0/
The images are taken from:
https://ccsearch.creativecommons.org/image/detail/WZnbJSJ2-dzIDiuUUdto3Q==
https://ccsearch.creativecommons.org/image/detail/_TlKu_rm_QrWlR0zthQTXA==
https://ccsearch.creativecommons.org/image/detail/OPNnHJb6q37rSZ5o_L5JHQ==
https://ccsearch.creativecommons.org/image/detail/B2CVP_j5KjwZm7UAVJ3Hvw==

The chr30.4.184.jpg and grayscale.jpg images are also under the CC0 license, taken from:
https://ccsearch.creativecommons.org/image/detail/8eO_qqotBfEm2UYxirLntw==

The image under "multi-channel" directory is under the CC BY-SA 4.0 license cropped from:
https://en.wikipedia.org/wiki/Alpha_compositing#/media/File:Hue_alpha_falloff.png
@@ -0,0 +1,116 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.image

import scala.language.existentials
import scala.util.Random

import org.apache.commons.io.FilenameUtils
import org.apache.hadoop.conf.{Configuration, Configured}
import org.apache.hadoop.fs.{Path, PathFilter}
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat

import org.apache.spark.sql.SparkSession

private object RecursiveFlag {
  /**
   * Sets the spark recursive flag and then restores it.
   *
   * @param value Value to set
   * @param spark Existing spark session
   * @param f The function to evaluate after setting the flag
   * @return Returns the evaluation result T of the function
   */
  def withRecursiveFlag[T](value: Boolean, spark: SparkSession)(f: => T): T = {
    val flagName = FileInputFormat.INPUT_DIR_RECURSIVE
    val hadoopConf = spark.sparkContext.hadoopConfiguration
    val old = Option(hadoopConf.get(flagName))
    hadoopConf.set(flagName, value.toString)
    try f finally {
      old match {
        case Some(v) => hadoopConf.set(flagName, v)
        case None => hadoopConf.unset(flagName)
      }
    }
  }
}

/**
 * Filter that allows loading a fraction of HDFS files.
 */
private class SamplePathFilter extends Configured with PathFilter {
  val random = new Random()

  // Ratio of files to be read from disk
  var sampleRatio: Double = 1

  override def setConf(conf: Configuration): Unit = {
    if (conf != null) {
      sampleRatio = conf.getDouble(SamplePathFilter.ratioParam, 1)
      val seed = conf.getLong(SamplePathFilter.seedParam, 0)
      random.setSeed(seed)
    }
  }

  override def accept(path: Path): Boolean = {
    // Note: checking fileSystem.isDirectory is very slow here, so we use basic rules instead
    !SamplePathFilter.isFile(path) || random.nextDouble() < sampleRatio
  }
}

private object SamplePathFilter {
  val ratioParam = "sampleRatio"
  val seedParam = "seed"

  def isFile(path: Path): Boolean = FilenameUtils.getExtension(path.toString) != ""

  /**
   * Sets the HDFS PathFilter flag and then restores it.
   * Only applies the filter if sampleRatio is less than 1.
   *
   * @param sampleRatio Fraction of the files that the filter picks
   * @param spark Existing Spark session
   * @param seed Random number seed
   * @param f The function to evaluate after setting the flag
   * @return Returns the evaluation result T of the function
   */
  def withPathFilter[T](
      sampleRatio: Double,
      spark: SparkSession,
      seed: Long)(f: => T): T = {
    val sampleImages = sampleRatio < 1
    if (sampleImages) {
      val flagName = FileInputFormat.PATHFILTER_CLASS
      val hadoopConf = spark.sparkContext.hadoopConfiguration
      val old = Option(hadoopConf.getClass(flagName, null))
      hadoopConf.setDouble(SamplePathFilter.ratioParam, sampleRatio)
      hadoopConf.setLong(SamplePathFilter.seedParam, seed)
      hadoopConf.setClass(flagName, classOf[SamplePathFilter], classOf[PathFilter])
      try f finally {
        hadoopConf.unset(SamplePathFilter.ratioParam)
        hadoopConf.unset(SamplePathFilter.seedParam)
        old match {
          case Some(v) => hadoopConf.setClass(flagName, v, classOf[PathFilter])
          case None => hadoopConf.unset(flagName)
        }
      }
    } else {
      f
    }
  }
}
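For illustration only, a minimal sketch of how these two package-private helpers are meant to nest around a file-reading action inside the reader; the directory path, ratio, and seed below are placeholders, and end users would go through `ImageSchema.readImages` rather than calling these helpers directly:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().getOrCreate()

// Temporarily turn on recursive listing and a 10% sampling filter; both Hadoop
// configuration changes are restored by the helpers once the body finishes.
val sampledPaths: Array[String] =
  RecursiveFlag.withRecursiveFlag(true, spark) {
    SamplePathFilter.withPathFilter(0.1, spark, 42L) {
      spark.sparkContext
        .binaryFiles("/placeholder/image/dir", 2) // placeholder path, 2 partitions
        .keys
        .collect()
    }
  }
```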
@@ -0,0 +1,257 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.image

import java.awt.Color
import java.awt.color.ColorSpace
import java.io.ByteArrayInputStream
import javax.imageio.ImageIO

import scala.collection.JavaConverters._

import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.input.PortableDataStream
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.types._

/**
 * :: Experimental ::
 * Defines the image schema and methods to read and manipulate images.
 */
@Experimental
@Since("2.3.0")
object ImageSchema {

  val undefinedImageType = "Undefined"

  /**
   * (Scala-specific) OpenCV type mapping supported
   */
  val ocvTypes: Map[String, Int] = Map(
    undefinedImageType -> -1,
    "CV_8U" -> 0, "CV_8UC1" -> 0, "CV_8UC3" -> 16, "CV_8UC4" -> 24
  )

  /**
   * (Java-specific) OpenCV type mapping supported
   */
  val javaOcvTypes: java.util.Map[String, Int] = ocvTypes.asJava

  /**
   * Schema for the image column: Row(String, Int, Int, Int, Int, Array[Byte])
   */
  val columnSchema = StructType(
    StructField("origin", StringType, true) ::
    StructField("height", IntegerType, false) ::
    StructField("width", IntegerType, false) ::
    StructField("nChannels", IntegerType, false) ::
    // OpenCV-compatible type: CV_8UC3 in most cases
    StructField("mode", IntegerType, false) ::
    // Bytes in OpenCV-compatible order: row-wise BGR in most cases
    StructField("data", BinaryType, false) :: Nil)

  val imageFields: Array[String] = columnSchema.fieldNames

  /**
   * DataFrame with a single column of images named "image" (nullable)
   */
  val imageSchema = StructType(StructField("image", columnSchema, true) :: Nil)

  /**
   * Gets the origin of the image
   *
   * @return The origin of the image
   */
  def getOrigin(row: Row): String = row.getString(0)
Member:
I saw the review comment and discussion in #19439 (comment). In particular, these

Member:
I am saying this partly because I am not seeing this in the Python API.

I can only echo the discussion you point out -- these are convenience functions that allow the user not to care about indexing into the schema (which is a common source of mistakes, in my experience). We might consider adding them to the Python API too.

Contributor (Author):
@HyukjinKwon @dakirsa Agreed, it would be nice to add them to the Python API as well eventually.
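For context, a small sketch of what these accessors save the caller from writing; it is not part of this PR, and the directory path is a placeholder:

```scala
import org.apache.spark.ml.image.ImageSchema

val df = ImageSchema.readImages("/placeholder/image/dir")
val imageRow = df.select("image").head.getStruct(0)

// With the convenience accessors, the caller does not need to know field positions:
val height = ImageSchema.getHeight(imageRow)

// Without them, the caller has to remember that height is field 1 of the image struct:
val sameHeight = imageRow.getInt(1)
```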
  /**
   * Gets the height of the image
   *
   * @return The height of the image
   */
  def getHeight(row: Row): Int = row.getInt(1)

  /**
   * Gets the width of the image
   *
   * @return The width of the image
   */
  def getWidth(row: Row): Int = row.getInt(2)

  /**
   * Gets the number of channels in the image
   *
   * @return The number of channels in the image
   */
  def getNChannels(row: Row): Int = row.getInt(3)

  /**
   * Gets the OpenCV representation as an int
   *
   * @return The OpenCV representation as an int
   */
  def getMode(row: Row): Int = row.getInt(4)

  /**
   * Gets the image data
   *
   * @return The image data
   */
  def getData(row: Row): Array[Byte] = row.getAs[Array[Byte]](5)

  /**
   * Default values for the invalid image
   *
   * @param origin Origin of the invalid image
   * @return Row with the default values
   */
  private[spark] def invalidImageRow(origin: String): Row =
    Row(Row(origin, -1, -1, -1, ocvTypes(undefinedImageType), Array.ofDim[Byte](0)))

  /**
   * Convert the compressed image (jpeg, png, etc.) into OpenCV
   * representation and store it in DataFrame Row
   *
   * @param origin Arbitrary string that identifies the image
   * @param bytes Image bytes (for example, jpeg)
   * @return DataFrame Row or None (if the decompression fails)
   */
  private[spark] def decode(origin: String, bytes: Array[Byte]): Option[Row] = {

    val img = ImageIO.read(new ByteArrayInputStream(bytes))

    if (img == null) {
      None
    } else {
      val isGray = img.getColorModel.getColorSpace.getType == ColorSpace.TYPE_GRAY
      val hasAlpha = img.getColorModel.hasAlpha

      val height = img.getHeight
      val width = img.getWidth
      val (nChannels, mode) = if (isGray) {
        (1, ocvTypes("CV_8UC1"))
      } else if (hasAlpha) {
        (4, ocvTypes("CV_8UC4"))
      } else {
        (3, ocvTypes("CV_8UC3"))
      }

      val imageSize = height * width * nChannels
      assert(imageSize < 1e9, "image is too large")
      val decoded = Array.ofDim[Byte](imageSize)

      // Grayscale images in Java require special handling to get the correct intensity
      if (isGray) {
        var offset = 0
        val raster = img.getRaster
        for (h <- 0 until height) {
          for (w <- 0 until width) {
            decoded(offset) = raster.getSample(w, h, 0).toByte
            offset += 1
          }
        }
      } else {
        var offset = 0
        for (h <- 0 until height) {
          for (w <- 0 until width) {
            val color = new Color(img.getRGB(w, h))

            decoded(offset) = color.getBlue.toByte
            decoded(offset + 1) = color.getGreen.toByte
            decoded(offset + 2) = color.getRed.toByte
            if (nChannels == 4) {
              decoded(offset + 3) = color.getAlpha.toByte
            }
            offset += nChannels
          }
        }
      }

      // the internal "Row" is needed, because the image is a single DataFrame column
      Some(Row(Row(origin, height, width, nChannels, mode, decoded)))
    }
  }

  /**
   * Read the directory of images from the local or remote source
   *
   * @note If multiple jobs are run in parallel with different sampleRatio or recursive flag,
   * there may be a race condition where one job overwrites the hadoop configs of another.
   * @note If sample ratio is less than 1, sampling uses a PathFilter that is efficient but
   * potentially non-deterministic.
   *
   * @param path Path to the image directory
   * @return DataFrame with a single column "image" of images;
   *         see ImageSchema for the details
   */
  def readImages(path: String): DataFrame = readImages(path, null, false, -1, false, 1.0, 0)

  /**
   * Read the directory of images from the local or remote source
   *
   * @note If multiple jobs are run in parallel with different sampleRatio or recursive flag,
   * there may be a race condition where one job overwrites the hadoop configs of another.
   * @note If sample ratio is less than 1, sampling uses a PathFilter that is efficient but
   * potentially non-deterministic.
   *
   * @param path Path to the image directory
   * @param sparkSession Spark Session, if omitted gets or creates the session
   * @param recursive Recursive path search flag
   * @param numPartitions Number of the DataFrame partitions,
   *                      if omitted uses defaultParallelism instead
   * @param dropImageFailures Drop the files that are not valid images from the result
   * @param sampleRatio Fraction of the files loaded
   * @return DataFrame with a single column "image" of images;
   *         see ImageSchema for the details
   */
  def readImages(
      path: String,
      sparkSession: SparkSession,
      recursive: Boolean,
      numPartitions: Int,
      dropImageFailures: Boolean,
      sampleRatio: Double,
      seed: Long): DataFrame = {
    require(sampleRatio <= 1.0 && sampleRatio >= 0, "sampleRatio should be between 0 and 1")

    val session = if (sparkSession != null) sparkSession else SparkSession.builder().getOrCreate
    val partitions =
      if (numPartitions > 0) {
        numPartitions
      } else {
        session.sparkContext.defaultParallelism
      }

    RecursiveFlag.withRecursiveFlag(recursive, session) {
      SamplePathFilter.withPathFilter(sampleRatio, session, seed) {
        val binResult = session.sparkContext.binaryFiles(path, partitions)
        val streams = if (numPartitions == -1) binResult else binResult.repartition(partitions)
        val convert = (origin: String, bytes: PortableDataStream) =>
          decode(origin, bytes.toArray())
        val images = if (dropImageFailures) {
          streams.flatMap { case (origin, bytes) => convert(origin, bytes) }
        } else {
          streams.map { case (origin, bytes) =>
            convert(origin, bytes).getOrElse(invalidImageRow(origin))
          }
        }
        session.createDataFrame(images, imageSchema)
      }
    }
  }
}
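To make the API above concrete, a hedged usage sketch (not part of the diff); the HDFS path is a placeholder:

```scala
import org.apache.spark.ml.image.ImageSchema
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().getOrCreate()

// Simple form: read every file under the directory into a single "image" column.
val all = ImageSchema.readImages("hdfs:///placeholder/images")

// Full form: recurse into sub-directories, sample roughly half of the files,
// and drop files that fail to decode instead of keeping invalid rows.
val sampled = ImageSchema.readImages(
  "hdfs:///placeholder/images", spark,
  recursive = true, numPartitions = 8,
  dropImageFailures = true, sampleRatio = 0.5, seed = 1L)

sampled.select("image.origin", "image.height", "image.width", "image.nChannels").show()
```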
Tell me if this SamplePathFilter has already been discussed; I may have missed it in the many comments above. I'm worried about whether it is deterministic, but I'm also not that familiar with the Hadoop APIs being used here.
We've run into a lot of issues in both RDD and DataFrame sampling methods with non-deterministic results, so I want to be careful here.
Yes, I'm not sure whether it will be deterministic even if we set a seed, but I can try that for now. As @thunterdb suggested, we could use some sort of hash on the filename, but I'm not sure how I would make that implementation work with a specified ratio. Could you give me more info on the design?
"I would prefer that we do not use a seed and that the result is deterministic, based for example on some hash of the file name, to make it more robust to future code changes. That being said, there is no fundamental issues with the current implementation and other developers may have differing opinions, so the current implementation is fine as far as I am concerned."
Oh, it would be pretty simple to work it out this way (pseudocode). Note that a hash is a random variable between -2^31 and 2^31-1, so:
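(The pseudocode itself is not preserved in this excerpt; the sketch below is a hedged reconstruction of the idea in Scala, treating `String.hashCode` as roughly uniform over the 32-bit range:)

```scala
// Map the 32-bit hash into [0, 1) and keep the file if it falls below the ratio.
def acceptByHash(pathName: String, sampleRatio: Double): Boolean = {
  val h = pathName.hashCode.toLong                    // in [-2^31, 2^31 - 1]
  val u = (h - Int.MinValue).toDouble / (1L << 32)    // normalized to [0, 1)
  u < sampleRatio
}
```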
Actually, I take this comment back, since it would break on some pathological cases such as all the names being the same. When users want some samples, they most probably want a result that is a fraction of the original, whatever it may contain.
@jkbradley do you prefer something that may not be deterministic (using random numbers), or something deterministic that does not respect the sampling ratio in pathological cases? The only way I can think of to do both is deduplicating, which requires a shuffle.
I would prefer determinism since that's a pretty important standard in Spark. I could imagine either (a) using a file hash with a global random number or (b) using random numbers if we are certain about how PathFilters work.
For (a):
For (b):
I don't believe readImages would ever return duplicate file names; even when reading from zip files, the full path from the root is included. I'm not sure what other pathological cases are possible. The hash function needs to be very good, though: it needs to generate uniformly distributed values (e.g., if all values are above math.pow(2, 31) * 0.1234, this won't work very well, and likewise if hashed values tend to cluster in some ranges more than others). I need to look more into how the hash code is generated.
It looks like the default hashCode is computed as:
s[0]*31^(n-1) + s[1]*31^(n-2) + ... + s[n-1]
where s[i] is the i-th character of the string.
Ok, if we think that duplicate file names are not an issue, then I would prefer using the deterministic hashing-based scheme. I am also happy to make this into another PR, since this is a pretty small matter.
@imatiach-msft a good hashing function is certainly a concern, and we will need to make a trade-off between performance and correctness. If we really want to make sure that it works as expected, the best is probably to use a cryptographic hash, like SHA-256 (which has strong guarantees on the distribution of output values):
https://stackoverflow.com/questions/5531455/how-to-hash-some-string-with-sha256-in-java
We have murmur3 in the Spark source code, but it is not cryptographic and does not come with as strong guarantees; for the sake of performance, we may want to use it eventually.
Once you have the digest, which is a byte array, it can be converted to a long first (by taking 8 bytes from the whole digest):
https://stackoverflow.com/questions/4485128/how-do-i-convert-long-to-byte-and-back-in-java
and then you can convert this long to a double and compare it to the requested fraction:
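(The snippet this comment pointed to is not preserved here; below is a hedged sketch of the suggested scheme, hashing the file name with SHA-256 from the JDK and taking the first 8 bytes of the digest:)

```scala
import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets
import java.security.MessageDigest

// Deterministically decide whether to keep a file, based only on its name.
def keepFile(fileName: String, fraction: Double): Boolean = {
  val digest = MessageDigest.getInstance("SHA-256")
    .digest(fileName.getBytes(StandardCharsets.UTF_8))
  // Take the first 8 bytes of the 32-byte digest as a Long ...
  val asLong = ByteBuffer.wrap(digest, 0, 8).getLong
  // ... then normalize to roughly [0, 1) and compare with the requested fraction.
  val u = (asLong.toDouble - Long.MinValue.toDouble) /
    (Long.MaxValue.toDouble - Long.MinValue.toDouble)
  u < fraction
}
```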