Commit a3ab017

Refactor image reader to new schema

mhamilton723 committed Jan 8, 2019
1 parent a3dd0ef commit a3ab017
Showing 42 changed files with 940 additions and 1,000 deletions.

@@ -47,8 +47,8 @@
"source": [
"# Load the images\n",
"# use flowers_and_labels.parquet on larger cluster in order to get better results\n",
"imagesWithLabels = spark.read.parquet(\"wasbs://publicwasb@mmlspark.blob.core.windows.net/flowers_and_labels_small.parquet\") \\\n",
" .withColumn(\"labels\", col(\"labels\").cast(\"Double\"))\n",
"imagesWithLabels = spark.read.parquet(\"wasbs://publicwasb@mmlspark.blob.core.windows.net/flowers_and_labels2.parquet\") \\\n",
" .withColumnRenamed(\"bytes\",\"image\").sample(.1)\n",
"\n",
"imagesWithLabels.printSchema()"
]
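
For reference, the same load in the library's Scala API — a sketch assuming an active SparkSession named spark with WASB access configured:

    // Load the repackaged parquet and rename the raw byte column so
    // downstream stages find the data under the new schema's "image" name.
    // sample(0.1) keeps the demo small; drop it on a larger cluster.
    val imagesWithLabels = spark.read
      .parquet("wasbs://publicwasb@mmlspark.blob.core.windows.net/flowers_and_labels2.parquet")
      .withColumnRenamed("bytes", "image")
      .sample(0.1)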

@@ -206,7 +206,7 @@
"evaluate(deepResults,\"CNTKModel + LR\")\n",
"plt.subplot(1,2,2)\n",
"evaluate(basicResults,\"LR\")\n",
"# Note that on the larger dataset the accuracy will bump up from 44% to >90%\n",
"# Note that on the larger dataset the accuracy will bump up from 44% to >90%\n",
"display(plt.show())"
]
}

@@ -31,7 +31,7 @@
"from mmlspark import toNDArray\n",
"\n",
"imageDir = \"wasbs://publicwasb@mmlspark.blob.core.windows.net/sampleImages\"\n",
"images = spark.readImages(imageDir, recursive = True, sampleRatio = 0.1).cache()\n",
"images = spark.read.image().load(imageDir).cache()\n",
"images.printSchema()\n",
"print(images.count())"
]
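
The old spark.readImages helper becomes a reader hanging off DataFrameReader. A minimal Scala sketch of the new entry point (assuming the IOImplicits import this commit introduces elsewhere, and an active SparkSession named spark):

    import com.microsoft.ml.spark.IOImplicits._

    // New-style read: yields a DataFrame with one struct column, "image".
    val images = spark.read.image
      .load("wasbs://publicwasb@mmlspark.blob.core.windows.net/sampleImages")
      .cache()
    images.printSchema()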

@@ -52,7 +52,7 @@
"metadata": {},
"outputs": [],
"source": [
"imageStream = spark.streamImages(imageDir + \"/*\", sampleRatio = 0.1)\n",
"imageStream = spark.readStream.image().load(imageDir)\n",
"query = imageStream.select(\"image.height\").writeStream.format(\"memory\").queryName(\"heights\").start()\n",
"print(\"Streaming query activity: {}\".format(query.isActive))"
]
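
Streaming moves the same way, from streamImages to a readStream extension. A Scala sketch, assuming IOImplicits is in scope, imageDir is defined, and that the Scala side mirrors the Python readStream.image() shown above:

    import com.microsoft.ml.spark.IOImplicits._

    // Streaming variant of the new reader; the in-memory sink "heights"
    // lets a notebook poll results while the query runs.
    val imageStream = spark.readStream.image.load(imageDir)
    val query = imageStream
      .select("image.height")
      .writeStream
      .format("memory")
      .queryName("heights")
      .start()
    println(s"Streaming query active: ${query.isActive}")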

@@ -116,8 +116,8 @@
"im = data[2][0] # the image is in the first column of a given row\n",
"\n",
"print(\"image type: {}, number of fields: {}\".format(type(im), len(im)))\n",
"print(\"image path: {}\".format(im.path))\n",
"print(\"height: {}, width: {}, OpenCV type: {}\".format(im.height, im.width, im.type))\n",
"print(\"image path: {}\".format(im.origin))\n",
"print(\"height: {}, width: {}, OpenCV type: {}\".format(im.height, im.width, im.mode))\n",
"\n",
"arr = toNDArray(im) # convert to numpy array\n",
"Image.fromarray(arr, \"RGB\") # display the image inside notebook\n",

@@ -130,7 +130,7 @@
" .server() \\\n",
" .replyTo(\"my_api\") \\\n",
" .queryName(\"my_query\") \\\n",
" .option(\"checkpointLocation\", \"checkpoints-{}\".format(uuid.uuid1())) \\\n",
" .option(\"checkpointLocation\", \"file:///tmp/checkpoints-{}\".format(uuid.uuid1())) \\\n",
" .start()\n"
]
},
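
Writing the checkpoint location as an explicit file: URI keeps it from resolving against the cluster's default filesystem; any cluster-visible URI would do. A self-contained Scala sketch of the same option on a toy rate-source stream (assumes an active SparkSession named spark):

    import java.util.UUID

    // Unique, explicitly file:-scheme checkpoint directory per run.
    val stream = spark.readStream.format("rate").load()
    val query = stream.writeStream
      .queryName("my_query")
      .option("checkpointLocation", s"file:///tmp/checkpoints-${UUID.randomUUID()}")
      .format("memory")
      .start()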

4 changes: 2 additions & 2 deletions src/cntk-model/src/test/scala/CNTKTestUtils.scala
@@ -4,9 +4,9 @@
 package com.microsoft.ml.spark

 import com.microsoft.ml.spark.FileUtilities.File
-import com.microsoft.ml.spark.Readers.implicits._
 import org.apache.spark.ml.linalg.DenseVector
 import org.apache.spark.sql._
+import com.microsoft.ml.spark.IOImplicits._

 trait CNTKTestUtils extends TestBase {


@@ -38,7 +38,7 @@ trait CNTKTestUtils extends TestBase {
   }

   def testImages(spark: SparkSession): DataFrame = {
-    val images = spark.readImages(imagePath, true)
+    val images = spark.read.image.load(imagePath)

     val unroll = new UnrollImage().setInputCol("image").setOutputCol(inputCol)


6 changes: 3 additions & 3 deletions src/cntk-train/src/test/scala/ValidateCntkTrain.scala
@@ -7,7 +7,7 @@ import java.io.File
 import java.net.URI

 import org.apache.spark.ml.feature.{OneHotEncoderEstimator, StringIndexerModel}
-import com.microsoft.ml.spark.Readers.implicits._
+import com.microsoft.ml.spark.IOImplicits._
 import org.apache.spark.ml.linalg.Vectors
 import org.apache.spark.sql.DataFrame
 import org.scalatest.{BeforeAndAfterEach, Suite}

@@ -255,12 +255,12 @@ TrainNetwork = {
     val indexedLabel = "idxlabels"
     val labelCol = "labels"

-    val images = session.readImages(imagePath, true)
+    val images = session.read.image.load(imagePath)

     // Label annotation: CIFAR is constructed here as
     // 01234-01.png, meaning (len - 5, len - 3) is label
     val pathLen = images.first.getStruct(0).getString(0).length
-    val labeledData = images.withColumn(tmpLabel, images("image.path").substr(pathLen - 5, 2).cast("float"))
+    val labeledData = images.withColumn(tmpLabel, images("image.origin").substr(pathLen - 5, 2).cast("float"))

     // Unroll images into Spark representation
     val unroller = new UnrollImage().setOutputCol(inputCol).setInputCol("image")
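
Because origin carries the full source URI, the filename-based labeling only changes the column path. A self-contained Scala sketch of the substring trick with hypothetical fixed-width file names (assumes an active SparkSession named spark):

    import org.apache.spark.sql.functions.col
    import spark.implicits._

    // The two class digits sit in the 2 characters starting 5 from the
    // end of each name ("...-03.png"); substr is 1-indexed in Spark SQL.
    val df = Seq(
      "wasb:///cifar/00001-03.png",
      "wasb:///cifar/00002-07.png").toDF("origin")
    val pathLen = df.first.getString(0).length  // all origins equal length
    val labeled = df.withColumn("label",
      col("origin").substr(pathLen - 5, 2).cast("float"))
    labeled.show()  // labels 3.0 and 7.0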

15 changes: 6 additions & 9 deletions src/core/schema/src/main/scala/BinaryFileSchema.scala
@@ -4,7 +4,7 @@
 package com.microsoft.ml.spark.schema

 import org.apache.spark.sql.{DataFrame, Row}
-import org.apache.spark.sql.types.{BinaryType, StringType, StructField, StructType}
+import org.apache.spark.sql.types._

 object BinaryFileSchema {


@@ -20,13 +20,10 @@ object BinaryFileSchema {
   def getPath(row: Row): String = row.getString(0)
   def getBytes(row: Row): Array[Byte] = row.getAs[Array[Byte]](1)

-  /** Check if the dataframe column contains binary file data (i.e. has BinaryFileSchema)
-    *
-    * @param df
-    * @param column
-    * @return
-    */
-  def isBinaryFile(df: DataFrame, column: String): Boolean =
-    df.schema(column).dataType == columnSchema
+  /** Check if the dataframe column contains binary file data (i.e. has BinaryFileSchema) */
+  def isBinaryFile(dt: DataType): Boolean =
+    dt == columnSchema
+
+  def isBinaryFile(sf: StructField): Boolean =
+    isBinaryFile(sf.dataType)
 }
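
The overloads now take a DataType or StructField rather than a whole DataFrame, so callers hand over the column's schema entry directly. A sketch of the call-site change (df is a hypothetical DataFrame with a "value" column):

    import com.microsoft.ml.spark.schema.BinaryFileSchema

    // Old call site: BinaryFileSchema.isBinaryFile(df, "value")
    val viaField = BinaryFileSchema.isBinaryFile(df.schema("value"))           // StructField overload
    val viaType  = BinaryFileSchema.isBinaryFile(df.schema("value").dataType)  // DataType overload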
141 changes: 0 additions & 141 deletions src/core/schema/src/main/scala/ImageSchema.scala

This file was deleted.

33 changes: 33 additions & 0 deletions src/core/schema/src/main/scala/ImageSchemaUtils.scala
@@ -0,0 +1,33 @@
+// Copyright (C) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License. See LICENSE in project root for information.
+
+package com.microsoft.ml.spark
+
+import org.apache.spark.ml.image.ImageSchema
+import org.apache.spark.sql.types._
+
+object ImageSchemaUtils {
+
+  val columnSchemaNullable = {
+    StructType(
+      StructField("origin", StringType, true) ::
+      StructField("height", IntegerType, true) ::
+      StructField("width", IntegerType, true) ::
+      StructField("nChannels", IntegerType, true) ::
+      // OpenCV-compatible type: CV_8UC3 in most cases
+      StructField("mode", IntegerType, true) ::
+      // Bytes in OpenCV-compatible order: row-wise BGR in most cases
+      StructField("data", BinaryType, true) :: Nil)
+  }
+
+  val imageSchemaNullable = StructType(StructField("image", columnSchemaNullable, true) :: Nil)
+
+  def isImage(dataType: DataType): Boolean = {
+    dataType == ImageSchema.columnSchema ||
+      dataType == columnSchemaNullable
+  }
+
+  def isImage(dataType: StructField): Boolean = {
+    isImage(dataType.dataType)
+  }
+}
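
Accepting both Spark's image struct and this nullable variant is what lets the isImage checks in AssembleFeatures and ImageFeaturizer below tolerate either schema. A Scala sketch of a guard a transformer might add (df is any DataFrame with a candidate column):

    import com.microsoft.ml.spark.ImageSchemaUtils
    import org.apache.spark.sql.DataFrame

    // df.schema(col) returns a StructField, matching the second overload.
    def requireImageColumn(df: DataFrame, col: String): Unit =
      require(ImageSchemaUtils.isImage(df.schema(col)),
        s"Column $col does not conform to the image schema")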
9 changes: 5 additions & 4 deletions src/featurize/src/main/scala/AssembleFeatures.scala
@@ -8,9 +8,10 @@ import java.sql.{Date, Timestamp}
 import java.time.temporal.ChronoField

 import com.microsoft.ml.spark.schema.DatasetExtensions._
-import com.microsoft.ml.spark.schema.{CategoricalColumnInfo, DatasetExtensions, ImageSchema}
+import com.microsoft.ml.spark.schema.{CategoricalColumnInfo, DatasetExtensions}
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.ml.feature._
+import org.apache.spark.ml.image.ImageSchema
 import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
 import org.apache.spark.ml.linalg.{SparseVector, Vectors}
 import org.apache.spark.ml.param._

@@ -213,7 +214,7 @@
           columnNamesToFeaturize.colNamesToTypes += unusedColumnName -> dataType
           columnNamesToFeaturize.conversionColumnNamesMap += col -> unusedColumnName
         }
-      case _ if ImageSchema.isImage(datasetAsDf, col) =>
+      case _ if ImageSchemaUtils.isImage(datasetAsDf.schema(col)) =>
         if (!getAllowImages) {
           throw new UnsupportedOperationException("Featurization of images columns disabled")
         }

@@ -390,9 +391,9 @@ class AssembleFeaturesModel(val uid: String,
         })
         Seq(dataset(col),
           extractTimeFeatures(dataset(col)).as(tmpRenamedCols, dataset.schema(col).metadata))
-      case imageType if imageType == ImageSchema.columnSchema =>
+      case t if ImageSchemaUtils.isImage(t) =>
         val extractImageFeatures = udf((row: Row) => {
-          val image = ImageSchema.getBytes(row).map(_.toDouble)
+          val image = ImageSchema.getData(row).map(_.toDouble)
           val height = ImageSchema.getHeight(row).toDouble
           val width = ImageSchema.getWidth(row).toDouble
           Vectors.dense((height :: (width :: image.toList)).toArray)
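
The UDF's one functional change is getBytes -> getData; the vector layout stays height, then width, then one Double per byte of image data. A standalone Scala sketch of that vectorization against Spark's image struct:

    import org.apache.spark.ml.image.ImageSchema
    import org.apache.spark.ml.linalg.{Vector, Vectors}
    import org.apache.spark.sql.Row

    // row is one value of an image column in the new schema.
    def imageToVector(row: Row): Vector = {
      val pixels = ImageSchema.getData(row).map(_.toDouble)
      val height = ImageSchema.getHeight(row).toDouble
      val width = ImageSchema.getWidth(row).toDouble
      Vectors.dense(height +: width +: pixels)
    }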

6 changes: 3 additions & 3 deletions src/featurize/src/test/scala/VerifyFeaturize.scala
@@ -9,7 +9,6 @@ import java.sql.{Date, Timestamp}
 import java.util.GregorianCalendar

 import com.microsoft.ml.spark.FileUtilities.File
-import com.microsoft.ml.spark.schema.ImageSchema
 import org.apache.spark.ml.{Estimator, PipelineModel}
 import org.apache.spark.ml.feature.StringIndexer
 import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vectors}

@@ -18,6 +17,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
 import org.apache.spark.sql.types.{StructField, StructType}
 import org.apache.commons.io.FileUtils
+import org.apache.spark.ml.image.ImageSchema

 class VerifyAssembleFeatures extends TestBase with EstimatorFuzzing[AssembleFeatures] {
   def testObjects(): Seq[TestObject[AssembleFeatures]] = List(new TestObject(

@@ -172,8 +172,8 @@ class VerifyFeaturize extends TestBase with EstimatorFuzzing[Featurize] {
     // Expected is image size with width and height
     val expectedSize = imageSize + 2
     val rowRDD: RDD[Row] = sc.parallelize(Seq[Row](
-      Row(Row(path1, height, width, imgType, Array.fill[Byte](imageSize)(1))),
-      Row(Row(path2, height, width, imgType, Array.fill[Byte](imageSize)(1)))
+      Row(Row(path1, height, width, 3, imgType, Array.fill[Byte](imageSize)(1))),
+      Row(Row(path2, height, width, 3, imgType, Array.fill[Byte](imageSize)(1)))
     ))
     val dataset = session.createDataFrame(rowRDD, imageDFSchema)
     val result: DataFrame = featurize(dataset, includeFeaturesColumns = false)
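
Test rows now carry six fields to match the new struct, with nChannels (here 3) inserted before the OpenCV mode. A sketch of building one such row against Spark's own column schema (hypothetical origin and dimensions; 16 is OpenCV's CV_8UC3):

    import org.apache.spark.ml.image.ImageSchema
    import org.apache.spark.sql.Row
    import org.apache.spark.sql.types.{StructField, StructType}

    val (height, width) = (32, 32)
    // Field order must match ImageSchema.columnSchema:
    // origin, height, width, nChannels, mode, data
    val imageRow = Row(Row(
      "file:///tmp/fake.png", height, width, 3, 16,
      Array.fill[Byte](height * width * 3)(1)))
    val schema = StructType(Seq(StructField("image", ImageSchema.columnSchema)))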

4 changes: 2 additions & 2 deletions src/image-featurizer/src/main/scala/ImageFeaturizer.scala
@@ -5,7 +5,7 @@ package com.microsoft.ml.spark

 import com.microsoft.CNTK.CNTKExtensions._
 import com.microsoft.CNTK.{SerializableFunction => CNTKFunction}
-import com.microsoft.ml.spark.schema.{DatasetExtensions, ImageSchema}
+import com.microsoft.ml.spark.schema.DatasetExtensions
 import org.apache.spark.ml.Transformer
 import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
 import org.apache.spark.ml.param._

@@ -140,7 +140,7 @@ class ImageFeaturizer(val uid: String) extends Transformer with HasInputCol with

     val inputSchema = dataset.schema(getInputCol).dataType

-    val unrolledDF = if (inputSchema == ImageSchema.columnSchema) {
+    val unrolledDF = if (ImageSchemaUtils.isImage(inputSchema)) {
       val prepare = new ResizeImageTransformer()
         .setInputCol(getInputCol)
         .setWidth(requiredSize(0).toInt)