
Refactor image reader to new schema

mhamilton723 committed Jan 3, 2019
1 parent a3dd0ef commit a3ab0177134cc7b1becd4a6c4b0a3f5d03412ad0
Showing with 940 additions and 1,000 deletions.
  1. +3 −3 notebooks/samples/DeepLearning - Flower Image Classification.ipynb
  2. +4 −4 notebooks/samples/OpenCV - Pipeline Image Transformations.ipynb
  3. +1 −1 notebooks/samples/SparkServing - Deploying a Classifier.ipynb
  4. +2 −2 src/cntk-model/src/test/scala/CNTKTestUtils.scala
  5. +3 −3 src/cntk-train/src/test/scala/ValidateCntkTrain.scala
  6. +6 −9 src/core/schema/src/main/scala/BinaryFileSchema.scala
  7. +0 −141 src/core/schema/src/main/scala/ImageSchema.scala
  8. +33 −0 src/core/schema/src/main/scala/ImageSchemaUtils.scala
  9. +5 −4 src/featurize/src/main/scala/AssembleFeatures.scala
  10. +3 −3 src/featurize/src/test/scala/VerifyFeaturize.scala
  11. +2 −2 src/image-featurizer/src/main/scala/ImageFeaturizer.scala
  12. +2 −2 src/image-featurizer/src/main/scala/ImageLIME.scala
  13. +1 −1 src/image-featurizer/src/main/scala/ImageSetAugmenter.scala
  14. +8 −8 src/image-featurizer/src/main/scala/Superpixel.scala
  15. +1 −2 src/image-featurizer/src/main/scala/SuperpixelTransformer.scala
  16. +22 −15 src/image-featurizer/src/test/scala/ImageFeaturizerSuite.scala
  17. +6 −7 src/image-featurizer/src/test/scala/ImageLIMESuite.scala
  18. +2 −2 src/image-featurizer/src/test/scala/ImageSetAugmenterSuite.scala
  19. +2 −5 src/image-featurizer/src/test/scala/SuperpixelSuite.scala
  20. +7 −6 src/image-transformer/src/main/python/ImageTransformer.py
  21. +36 −27 src/image-transformer/src/main/scala/ImageTransformer.scala
  22. +33 −0 src/image-transformer/src/main/scala/OpenCVUtils.scala
  23. +7 −7 src/image-transformer/src/main/scala/ResizeImageTransformer.scala
  24. +12 −13 src/image-transformer/src/main/scala/UnrollImage.scala
  25. +60 −52 src/image-transformer/src/test/scala/ImageTransformerSuite.scala
  26. +3 −2 src/image-transformer/src/test/scala/ResizeImageTransformerSuite.scala
  27. +47 −10 src/io/binary/src/main/scala/BinaryFileFormat.scala
  28. +2 −2 src/io/binary/src/main/scala/BinaryFileReader.scala
  29. +46 −18 src/io/binary/src/test/scala/BinaryFileReaderSuite.scala
  30. +2 −2 src/io/http/src/main/scala/ServingImplicits.scala
  31. +0 −122 src/io/image/src/main/python/ImageReader.py
  32. +70 −0 src/io/image/src/main/python/ImageUtils.py
  33. +0 −30 src/io/image/src/main/python/ImageWriter.py
  34. +0 −221 src/io/image/src/main/scala/Image.scala
  35. +0 −146 src/io/image/src/main/scala/ImageFileFormat.scala
  36. +177 −0 src/io/image/src/main/scala/ImageUtils.scala
  37. +13 −0 src/io/image/src/main/scala/NamespaceInjections.scala
  38. +152 −0 src/io/image/src/main/scala/PatchedImageFileFormat.scala
  39. +44 −82 src/io/image/src/test/scala/ImageReaderSuite.scala
  40. +35 −0 src/io/src/main/python/IOImplicits.py
  41. +88 −0 src/io/src/main/scala/IOImplicits.scala
  42. +0 −46 src/io/src/main/scala/Readers.scala
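
Taken together, these changes replace MMLSpark's custom ImageSchema and the spark.readImages / spark.streamImages entry points with Spark's built-in org.apache.spark.ml.image.ImageSchema, exposed through the new IOImplicits. In the image struct, path becomes origin, type becomes mode, and an nChannels field is added. A minimal before/after sketch in Scala (imagePath is a placeholder):

    import com.microsoft.ml.spark.IOImplicits._

    // Old API, removed by this commit:
    //   val images = spark.readImages(imagePath, true)
    // New API, registered on DataFrameReader by IOImplicits:
    val images = spark.read.image.load(imagePath)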
@@ -47,8 +47,8 @@
"source": [
"# Load the images\n",
"# use flowers_and_labels.parquet on larger cluster in order to get better results\n",
"imagesWithLabels = spark.read.parquet(\"wasbs://publicwasb@mmlspark.blob.core.windows.net/flowers_and_labels_small.parquet\") \\\n",
" .withColumn(\"labels\", col(\"labels\").cast(\"Double\"))\n",
"imagesWithLabels = spark.read.parquet(\"wasbs://publicwasb@mmlspark.blob.core.windows.net/flowers_and_labels2.parquet\") \\\n",
" .withColumnRenamed(\"bytes\",\"image\").sample(.1)\n",
"\n",
"imagesWithLabels.printSchema()"
]
@@ -206,7 +206,7 @@
"evaluate(deepResults,\"CNTKModel + LR\")\n",
"plt.subplot(1,2,2)\n",
"evaluate(basicResults,\"LR\")\n",
"# Note that on the larger dataset the accuracy will bump up from 44% to >90%\n",
"# Note that on the larger dataset the accuracy will bump up from 44% to >90%\n",
"display(plt.show())"
]
}
@@ -31,7 +31,7 @@
"from mmlspark import toNDArray\n",
"\n",
"imageDir = \"wasbs://publicwasb@mmlspark.blob.core.windows.net/sampleImages\"\n",
"images = spark.readImages(imageDir, recursive = True, sampleRatio = 0.1).cache()\n",
"images = spark.read.image().load(imageDir).cache()\n",
"images.printSchema()\n",
"print(images.count())"
]
@@ -52,7 +52,7 @@
"metadata": {},
"outputs": [],
"source": [
"imageStream = spark.streamImages(imageDir + \"/*\", sampleRatio = 0.1)\n",
"imageStream = spark.readStream.image().load(imageDir)\n",
"query = imageStream.select(\"image.height\").writeStream.format(\"memory\").queryName(\"heights\").start()\n",
"print(\"Streaming query activity: {}\".format(query.isActive))"
]
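
The streaming entry point follows the same pattern. A hedged Scala sketch of the cell above, assuming IOImplicits extends DataStreamReader the same way the Python readStream.image() call suggests:

    import com.microsoft.ml.spark.IOImplicits._

    val imageStream = spark.readStream.image.load(imageDir)
    val query = imageStream
      .select("image.height")
      .writeStream.format("memory").queryName("heights")
      .start()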
@@ -116,8 +116,8 @@
"im = data[2][0] # the image is in the first column of a given row\n",
"\n",
"print(\"image type: {}, number of fields: {}\".format(type(im), len(im)))\n",
"print(\"image path: {}\".format(im.path))\n",
"print(\"height: {}, width: {}, OpenCV type: {}\".format(im.height, im.width, im.type))\n",
"print(\"image path: {}\".format(im.origin))\n",
"print(\"height: {}, width: {}, OpenCV type: {}\".format(im.height, im.width, im.mode))\n",
"\n",
"arr = toNDArray(im) # convert to numpy array\n",
"Image.fromarray(arr, \"RGB\") # display the image inside notebook\n",
@@ -130,7 +130,7 @@
" .server() \\\n",
" .replyTo(\"my_api\") \\\n",
" .queryName(\"my_query\") \\\n",
" .option(\"checkpointLocation\", \"checkpoints-{}\".format(uuid.uuid1())) \\\n",
" .option(\"checkpointLocation\", \"file:///tmp/checkpoints-{}\".format(uuid.uuid1())) \\\n",
" .start()\n"
]
},
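
The cell edits above capture the struct renames: the image column now follows Spark's (origin, height, width, nChannels, mode, data) layout. A positional-access sketch in Scala, with df as a placeholder DataFrame holding the new schema:

    val im = df.select("image.*").head
    val origin    = im.getString(0)            // was path
    val height    = im.getInt(1)
    val width     = im.getInt(2)
    val nChannels = im.getInt(3)               // new field
    val mode      = im.getInt(4)               // OpenCV type; was type
    val data      = im.getAs[Array[Byte]](5)   // row-wise BGR bytes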
@@ -4,9 +4,9 @@
package com.microsoft.ml.spark

import com.microsoft.ml.spark.FileUtilities.File
- import com.microsoft.ml.spark.Readers.implicits._
import org.apache.spark.ml.linalg.DenseVector
import org.apache.spark.sql._
+ import com.microsoft.ml.spark.IOImplicits._

trait CNTKTestUtils extends TestBase {

@@ -38,7 +38,7 @@ trait CNTKTestUtils extends TestBase {
}

def testImages(spark: SparkSession): DataFrame = {
- val images = spark.readImages(imagePath, true)
+ val images = spark.read.image.load(imagePath)

val unroll = new UnrollImage().setInputCol("image").setOutputCol(inputCol)
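
The migrated helper in context: a sketch of loading images with the new reader and unrolling them into a vector column, mirroring the two lines above (the output column name is a placeholder):

    import com.microsoft.ml.spark.IOImplicits._

    val images = spark.read.image.load(imagePath)
    val unrolled = new UnrollImage()
      .setInputCol("image")
      .setOutputCol("features")
      .transform(images)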

@@ -7,7 +7,7 @@ import java.io.File
import java.net.URI

import org.apache.spark.ml.feature.{OneHotEncoderEstimator, StringIndexerModel}
- import com.microsoft.ml.spark.Readers.implicits._
+ import com.microsoft.ml.spark.IOImplicits._
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.DataFrame
import org.scalatest.{BeforeAndAfterEach, Suite}
@@ -255,12 +255,12 @@ TrainNetwork = {
val indexedLabel = "idxlabels"
val labelCol = "labels"

- val images = session.readImages(imagePath, true)
+ val images = session.read.image.load(imagePath)

// Label annotation: CIFAR is constructed here as
// 01234-01.png, meaning (len - 5, len - 3) is label
val pathLen = images.first.getStruct(0).getString(0).length
- val labeledData = images.withColumn(tmpLabel, images("image.path").substr(pathLen - 5, 2).cast("float"))
+ val labeledData = images.withColumn(tmpLabel, images("image.origin").substr(pathLen - 5, 2).cast("float"))

// Unroll images into Spark representation
val unroller = new UnrollImage().setOutputCol(inputCol).setInputCol("image")
@@ -4,7 +4,7 @@
package com.microsoft.ml.spark.schema

import org.apache.spark.sql.{DataFrame, Row}
- import org.apache.spark.sql.types.{BinaryType, StringType, StructField, StructType}
+ import org.apache.spark.sql.types._

object BinaryFileSchema {

@@ -20,13 +20,10 @@ object BinaryFileSchema {
def getPath(row: Row): String = row.getString(0)
def getBytes(row: Row): Array[Byte] = row.getAs[Array[Byte]](1)

-   /** Check if the dataframe column contains binary file data (i.e. has BinaryFileSchema)
-     *
-     * @param df
-     * @param column
-     * @return
-     */
-   def isBinaryFile(df: DataFrame, column: String): Boolean =
-     df.schema(column).dataType == columnSchema
+   /** Check if the dataframe column contains binary file data (i.e. has BinaryFileSchema) */
+   def isBinaryFile(dt: DataType): Boolean =
+     dt == columnSchema

+   def isBinaryFile(sf: StructField): Boolean =
+     isBinaryFile(sf.dataType)
}
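
The predicate now inspects a DataType (or StructField) instead of taking a DataFrame plus column name, so callers resolve the column themselves. Usage sketch, with df and "value" as placeholders:

    val sf = df.schema("value")
    BinaryFileSchema.isBinaryFile(sf)           // StructField overload
    BinaryFileSchema.isBinaryFile(sf.dataType)  // DataType overload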

This file was deleted (src/core/schema/src/main/scala/ImageSchema.scala, −141 lines).

@@ -0,0 +1,33 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark

import org.apache.spark.ml.image.ImageSchema
import org.apache.spark.sql.types._

object ImageSchemaUtils {

  val columnSchemaNullable = {
    StructType(
      StructField("origin", StringType, true) ::
        StructField("height", IntegerType, true) ::
        StructField("width", IntegerType, true) ::
        StructField("nChannels", IntegerType, true) ::
        // OpenCV-compatible type: CV_8UC3 in most cases
        StructField("mode", IntegerType, true) ::
        // Bytes in OpenCV-compatible order: row-wise BGR in most cases
        StructField("data", BinaryType, true) :: Nil)
  }

  val imageSchemaNullable = StructType(StructField("image", columnSchemaNullable, true) :: Nil)

  def isImage(dataType: DataType): Boolean = {
    dataType == ImageSchema.columnSchema ||
      dataType == columnSchemaNullable
  }

  def isImage(dataType: StructField): Boolean = {
    isImage(dataType.dataType)
  }
}
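
Spark's ImageSchema.columnSchema marks most fields non-nullable, so the nullable twin above presumably exists to tolerate schemas read back from sinks that relax nullability (e.g. Parquet); isImage accepts either variant. Usage sketch, with df and "image" as placeholders:

    if (ImageSchemaUtils.isImage(df.schema("image"))) {
      // column can safely be treated as an image struct
    }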
@@ -8,9 +8,10 @@ import java.sql.{Date, Timestamp}
import java.time.temporal.ChronoField

import com.microsoft.ml.spark.schema.DatasetExtensions._
- import com.microsoft.ml.spark.schema.{CategoricalColumnInfo, DatasetExtensions, ImageSchema}
+ import com.microsoft.ml.spark.schema.{CategoricalColumnInfo, DatasetExtensions}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.feature._
+ import org.apache.spark.ml.image.ImageSchema
import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
import org.apache.spark.ml.linalg.{SparseVector, Vectors}
import org.apache.spark.ml.param._
@@ -213,7 +214,7 @@ class AssembleFeatures(override val uid: String) extends Estimator[AssembleFeatu
columnNamesToFeaturize.colNamesToTypes += unusedColumnName -> dataType
columnNamesToFeaturize.conversionColumnNamesMap += col -> unusedColumnName
}
- case _ if ImageSchema.isImage(datasetAsDf, col) =>
+ case _ if ImageSchemaUtils.isImage(datasetAsDf.schema(col)) =>
if (!getAllowImages) {
throw new UnsupportedOperationException("Featurization of images columns disabled")
}
@@ -390,9 +391,9 @@ class AssembleFeaturesModel(val uid: String,
})
Seq(dataset(col),
extractTimeFeatures(dataset(col)).as(tmpRenamedCols, dataset.schema(col).metadata))
- case imageType if imageType == ImageSchema.columnSchema =>
+ case t if ImageSchemaUtils.isImage(t) =>
val extractImageFeatures = udf((row: Row) => {
- val image = ImageSchema.getBytes(row).map(_.toDouble)
+ val image = ImageSchema.getData(row).map(_.toDouble)
val height = ImageSchema.getHeight(row).toDouble
val width = ImageSchema.getWidth(row).toDouble
Vectors.dense((height :: (width :: image.toList)).toArray)
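
The accessor swap in isolation: the old local schema's getBytes becomes getData on Spark's ImageSchema, while getHeight and getWidth keep their names. A self-contained sketch of the UDF above:

    import org.apache.spark.ml.image.ImageSchema
    import org.apache.spark.ml.linalg.Vectors
    import org.apache.spark.sql.Row
    import org.apache.spark.sql.functions.udf

    val extractImageFeatures = udf((row: Row) => {
      val image  = ImageSchema.getData(row).map(_.toDouble)  // pixel bytes
      val height = ImageSchema.getHeight(row).toDouble
      val width  = ImageSchema.getWidth(row).toDouble
      Vectors.dense((height :: (width :: image.toList)).toArray)
    })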
@@ -9,7 +9,6 @@ import java.sql.{Date, Timestamp}
import java.util.GregorianCalendar

import com.microsoft.ml.spark.FileUtilities.File
- import com.microsoft.ml.spark.schema.ImageSchema
import org.apache.spark.ml.{Estimator, PipelineModel}
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vectors}
@@ -18,6 +17,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.types.{StructField, StructType}
import org.apache.commons.io.FileUtils
+ import org.apache.spark.ml.image.ImageSchema

class VerifyAssembleFeatures extends TestBase with EstimatorFuzzing[AssembleFeatures] {
def testObjects(): Seq[TestObject[AssembleFeatures]] = List(new TestObject(
@@ -172,8 +172,8 @@ class VerifyFeaturize extends TestBase with EstimatorFuzzing[Featurize] {
// Expected is image size with width and height
val expectedSize = imageSize + 2
val rowRDD: RDD[Row] = sc.parallelize(Seq[Row](
- Row(Row(path1, height, width, imgType, Array.fill[Byte](imageSize)(1))),
- Row(Row(path2, height, width, imgType, Array.fill[Byte](imageSize)(1)))
+ Row(Row(path1, height, width, 3, imgType, Array.fill[Byte](imageSize)(1))),
+ Row(Row(path2, height, width, 3, imgType, Array.fill[Byte](imageSize)(1)))
))
val dataset = session.createDataFrame(rowRDD, imageDFSchema)
val result: DataFrame = featurize(dataset, includeFeaturesColumns = false)
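
The test rows gain a sixth field because the new struct inserts nChannels between width and mode; the hard-coded 3 is the channel count. One such row against the new layout (variables reuse the test's names):

    // (origin, height, width, nChannels, mode, data)
    Row(Row(path1, height, width, 3, imgType, Array.fill[Byte](imageSize)(1)))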
@@ -5,7 +5,7 @@ package com.microsoft.ml.spark

import com.microsoft.CNTK.CNTKExtensions._
import com.microsoft.CNTK.{SerializableFunction => CNTKFunction}
- import com.microsoft.ml.spark.schema.{DatasetExtensions, ImageSchema}
+ import com.microsoft.ml.spark.schema.DatasetExtensions
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
import org.apache.spark.ml.param._
@@ -140,7 +140,7 @@ class ImageFeaturizer(val uid: String) extends Transformer with HasInputCol with

val inputSchema = dataset.schema(getInputCol).dataType

- val unrolledDF = if (inputSchema == ImageSchema.columnSchema) {
+ val unrolledDF = if (ImageSchemaUtils.isImage(inputSchema)) {
val prepare = new ResizeImageTransformer()
.setInputCol(getInputCol)
.setWidth(requiredSize(0).toInt)
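
The hunk is cut off here; the visible change swaps the strict equality test against ImageSchema.columnSchema for the more permissive ImageSchemaUtils.isImage. A hedged reconstruction of the surrounding prepare stage, with the setHeight call assumed by analogy with setWidth:

    val prepare = new ResizeImageTransformer()
      .setInputCol(getInputCol)
      .setWidth(requiredSize(0).toInt)
      .setHeight(requiredSize(1).toInt)  // assumed; not visible in the truncated hunk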
