This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-531] Custom Operator Example for Scala #11401

Merged: 8 commits merged on Jul 18, 2018
Changes from all commits
@@ -21,4 +21,5 @@ package org.apache.mxnet
* Main code will be generated during compile time through Macros
*/
object NDArrayAPI extends NDArrayAPIBase {
// TODO: Implement CustomOp for NDArray
}
@@ -16,11 +16,19 @@
*/
package org.apache.mxnet

import scala.collection.mutable


@AddSymbolAPIs(false)
/**
* typesafe Symbol API: Symbol.api._
* Main code will be generated during compile time through Macros
*/
object SymbolAPI extends SymbolAPIBase {
def Custom (op_type : String, kwargs : mutable.Map[String, Any],
name : String = null, attr : Map[String, String] = null) : Symbol = {
Member:
Should your Custom op also take an optional Seq() for positional arguments?

Member Author:
Currently, kwargs should be able to cover everything a user needs to create a custom op.

val map = kwargs
map.put("op_type", op_type)
Symbol.createSymbolGeneral("Custom", name, attr, Seq(), map.toMap)
}
}
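For context, a minimal call-site sketch for this new helper, mirroring the test code added later in this PR (the "softmax" op_type refers to the custom operator registered further down; variable names are illustrative):

    import scala.collection.mutable
    import org.apache.mxnet.Symbol

    val data = Symbol.Variable("data")
    val label = Symbol.Variable("label")
    // kwargs carries the op's symbolic inputs plus any operator parameters
    val kwargs = mutable.Map[String, Any]("data" -> data, "label" -> label)
    val out = Symbol.api.Custom(op_type = "softmax", name = "softmax", kwargs = kwargs)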
@@ -17,13 +17,8 @@

package org.apache.mxnetexamples.customop

-import org.apache.mxnet.Shape
-import org.apache.mxnet.IO
-import org.apache.mxnet.DataIter
+import org.apache.mxnet.{DataIter, IO, Shape}

-/**
- * @author Depeng Liang
- */
object Data {
// return train and val iterators for mnist
def mnistIterator(dataPath: String, batchSize: Int, inputShape: Shape): (DataIter, DataIter) = {
// ... (remaining lines collapsed in the diff)
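For reference, the returned pair plugs into the standard DataIter loop; a minimal sketch based on the training loop in ExampleCustomOp below (the path is hypothetical):

    val (trainIter, testIter) =
      Data.mnistIterator("/tmp/mnist", batchSize = 100, inputShape = Shape(784))
    while (trainIter.hasNext) {
      val batch = trainIter.next()
      // feed batch.data(0) and batch.label(0) into a bound executor here
    }
    trainIter.reset()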
@@ -17,37 +17,26 @@

package org.apache.mxnetexamples.customop

+import org.apache.mxnet.Callback.Speedometer
+import org.apache.mxnet.DType.DType
+import org.apache.mxnet.{Accuracy, Context, CustomOp, CustomOpProp, NDArray, Operator, Shape, Symbol, Xavier}
+import org.apache.mxnet.optimizer.RMSProp
import org.kohsuke.args4j.{CmdLineParser, Option}
import org.slf4j.LoggerFactory

import scala.collection.JavaConverters._
-import org.apache.mxnet.Symbol
-import org.apache.mxnet.DType.DType
-import org.apache.mxnet.DataIter
-import org.apache.mxnet.DataBatch
-import org.apache.mxnet.NDArray
-import org.apache.mxnet.Shape
-import org.apache.mxnet.EvalMetric
-import org.apache.mxnet.Context
-import org.apache.mxnet.Xavier
-import org.apache.mxnet.optimizer.RMSProp
-import org.apache.mxnet.CustomOp
-import org.apache.mxnet.CustomOpProp
-import org.apache.mxnet.Operator
-import org.apache.mxnet.optimizer.SGD
-import org.apache.mxnet.Accuracy
-import org.apache.mxnet.Callback.Speedometer
+import scala.collection.mutable

/**
- * Example of CustomOp
- * @author Depeng Liang
- */
+ * Example of CustomOp
+ */
object ExampleCustomOp {
private val logger = LoggerFactory.getLogger(classOf[ExampleCustomOp])

class Softmax(_param: Map[String, String]) extends CustomOp {

override def forward(sTrain: Boolean, req: Array[String], inData: Array[NDArray],
outData: Array[NDArray], aux: Array[NDArray]): Unit = {
val xShape = inData(0).shape
val x = inData(0).toArray.grouped(xShape(1)).toArray
val yArr = x.map { it =>
@@ -63,8 +52,8 @@ object ExampleCustomOp
}

override def backward(req: Array[String], outGrad: Array[NDArray],
inData: Array[NDArray], outData: Array[NDArray],
inGrad: Array[NDArray], aux: Array[NDArray]): Unit = {
val l = inData(1).toArray.map(_.toInt)
val oShape = outData(0).shape
val yArr = outData(0).toArray.grouped(oShape(1)).toArray
@@ -86,24 +75,121 @@ object ExampleCustomOp
override def listOutputs(): Array[String] = Array("output")

override def inferShape(inShape: Array[Shape]):
(Array[Shape], Array[Shape], Array[Shape]) = {
val dataShape = inShape(0)
val labelShape = Shape(dataShape(0))
val outputShape = dataShape
(Array(dataShape, labelShape), Array(outputShape), null)
}

override def inferType(inType: Array[DType]):
(Array[DType], Array[DType], Array[DType]) = {
(inType, inType.take(1), null)
}

override def createOperator(ctx: String, inShapes: Array[Array[Int]],
inDtypes: Array[Int]): CustomOp = new Softmax(this.kwargs)
}
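// Note added for clarity (not part of the PR): together, Softmax.forward and
// Softmax.backward implement softmax followed by a cross-entropy loss, whose
// gradient with respect to the input logits x is the usual
//   dL/dx(i) = y(i) - (if (i == label) 1 else 0)
// where y is the softmax output produced in forward.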

Operator.register("softmax", new SoftmaxProp)
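// Registering the prop under the key "softmax" is what lets the
// Symbol.api.Custom(op_type = "softmax", ...) call below resolve to this operator.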

def test(dataPath : String, ctx : Context) : Float = {
val data = Symbol.Variable("data")
val label = Symbol.Variable("label")
val fc1 = Symbol.api.FullyConnected(data = Some(data), num_hidden = 128, name = "fc1")
val act1 = Symbol.api.Activation (data = Some(fc1), "relu", name = "relu")
val fc2 = Symbol.api.FullyConnected(Some(act1), None, None, 64, name = "fc2")
val act2 = Symbol.api.Activation(data = Some(fc2), "relu", name = "relu2")
val fc3 = Symbol.api.FullyConnected(Some(act2), None, None, 10, name = "fc3")
val kwargs = mutable.Map[String, Any]("label" -> label, "data" -> fc3)
val mlp = Symbol.api.Custom(op_type = "softmax", name = "softmax", kwargs = kwargs)

val (trainIter, testIter) =
Data.mnistIterator(dataPath, batchSize = 100, inputShape = Shape(784))

val datasAndLabels = trainIter.provideData ++ trainIter.provideLabel
val (argShapes, outputShapes, auxShapes) = mlp.inferShape(datasAndLabels)

val initializer = new Xavier(factorType = "in", magnitude = 2.34f)
val argNames = mlp.listArguments()
val argDict = argNames.zip(argShapes.map(s => NDArray.empty(s, ctx))).toMap

val gradDict = argNames.zip(argShapes).filter { case (name, shape) =>
!datasAndLabels.contains(name)
}.map(x => x._1 -> NDArray.empty(x._2, ctx) ).toMap

argDict.foreach { case (name, ndArray) =>
if (!datasAndLabels.contains(name)) {
initializer.initWeight(name, ndArray)
}
}

val executor = mlp.bind(ctx, argDict, gradDict)
val lr = 0.001f
val opt = new RMSProp(learningRate = lr, wd = 0.00001f)
val paramsGrads = gradDict.toList.zipWithIndex.map { case ((name, grad), idx) =>
(idx, name, grad, opt.createState(idx, argDict(name)))
}

val evalMetric = new Accuracy
val batchEndCallback = new Speedometer(100, 100)
val numEpoch = 10
var validationAcc = 0.0f

for (epoch <- 0 until numEpoch) {
val tic = System.currentTimeMillis
evalMetric.reset()
var nBatch = 0
var epochDone = false

trainIter.reset()
while (!epochDone) {
var doReset = true
while (doReset && trainIter.hasNext) {
val dataBatch = trainIter.next()
argDict("data").set(dataBatch.data(0))
argDict("label").set(dataBatch.label(0))
executor.forward(isTrain = true)
executor.backward()
paramsGrads.foreach { case (idx, name, grad, optimState) =>
opt.update(idx, argDict(name), grad, optimState)
}
evalMetric.update(dataBatch.label, executor.outputs)
nBatch += 1
batchEndCallback.invoke(epoch, nBatch, evalMetric)
}
if (doReset) {
trainIter.reset()
}
epochDone = true
}
val (name, value) = evalMetric.get
name.zip(value).foreach { case (n, v) =>
logger.info(s"Epoch[$epoch] Train-accuracy=$v")
}
val toc = System.currentTimeMillis
logger.info(s"Epoch[$epoch] Time cost=${toc - tic}")

evalMetric.reset()
testIter.reset()
while (testIter.hasNext) {
val evalBatch = testIter.next()
argDict("data").set(evalBatch.data(0))
argDict("label").set(evalBatch.label(0))
executor.forward(isTrain = true)
evalMetric.update(evalBatch.label, executor.outputs)
evalBatch.dispose()
}
val (names, values) = evalMetric.get
names.zip(values).foreach { case (n, v) =>
logger.info(s"Epoch[$epoch] Validation-accuracy=$v")
validationAcc = Math.max(validationAcc, v)
}
}
executor.dispose()
validationAcc
}
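// Usage sketch (illustrative; the data path is hypothetical):
//   val acc = ExampleCustomOp.test("/tmp/mnist", Context.cpu())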

def main(args: Array[String]): Unit = {
val leop = new ExampleCustomOp
val parser: CmdLineParser = new CmdLineParser(leop)
@@ -115,98 +201,8 @@ object ExampleCustomOp

val dataName = Array("data")
val labelName = Array("softmax_label")
test(leop.dataPath, ctx)

-val data = Symbol.Variable("data")
-val label = Symbol.Variable("label")
-val fc1 = Symbol.FullyConnected("fc1")()(Map("data" -> data, "num_hidden" -> 128))
-val act1 = Symbol.Activation("relu1")()(Map("data" -> fc1, "act_type" -> "relu"))
-val fc2 = Symbol.FullyConnected("fc2")()(Map("data" -> act1, "num_hidden" -> 64))
-val act2 = Symbol.Activation("relu2")()(Map("data" -> fc2, "act_type" -> "relu"))
-val fc3 = Symbol.FullyConnected("fc3")()(Map("data" -> act2, "num_hidden" -> 10))
-val mlp = Symbol.Custom("softmax")()(Map("data" -> fc3,
-"label" -> label, "op_type" -> "softmax"))
-
-val (trainIter, testIter) =
-Data.mnistIterator(leop.dataPath, batchSize = 100, inputShape = Shape(784))
-
-val datasAndLabels = trainIter.provideData ++ trainIter.provideLabel
-val (argShapes, outputShapes, auxShapes) = mlp.inferShape(datasAndLabels)
-
-val initializer = new Xavier(factorType = "in", magnitude = 2.34f)
-val argNames = mlp.listArguments()
-val argDict = argNames.zip(argShapes.map(s => NDArray.empty(s, ctx))).toMap
-
-val gradDict = argNames.zip(argShapes).filter { case (name, shape) =>
-!datasAndLabels.contains(name)
-}.map(x => x._1 -> NDArray.empty(x._2, ctx) ).toMap
-
-argDict.foreach { case (name, ndArray) =>
-if (!datasAndLabels.contains(name)) {
-initializer.initWeight(name, ndArray)
-}
-}
-
-val executor = mlp.bind(ctx, argDict, gradDict)
-val lr = 0.001f
-val opt = new RMSProp(learningRate = lr, wd = 0.00001f)
-val paramsGrads = gradDict.toList.zipWithIndex.map { case ((name, grad), idx) =>
-(idx, name, grad, opt.createState(idx, argDict(name)))
-}
-
-val evalMetric = new Accuracy
-val batchEndCallback = new Speedometer(100, 100)
-val numEpoch = 20
-
-for (epoch <- 0 until numEpoch) {
-val tic = System.currentTimeMillis
-evalMetric.reset()
-var nBatch = 0
-var epochDone = false
-
-trainIter.reset()
-while (!epochDone) {
-var doReset = true
-while (doReset && trainIter.hasNext) {
-val dataBatch = trainIter.next()
-argDict("data").set(dataBatch.data(0))
-argDict("label").set(dataBatch.label(0))
-executor.forward(isTrain = true)
-executor.backward()
-paramsGrads.foreach { case (idx, name, grad, optimState) =>
-opt.update(idx, argDict(name), grad, optimState)
-}
-evalMetric.update(dataBatch.label, executor.outputs)
-nBatch += 1
-batchEndCallback.invoke(epoch, nBatch, evalMetric)
-}
-if (doReset) {
-trainIter.reset()
-}
-epochDone = true
-}
-val (name, value) = evalMetric.get
-name.zip(value).foreach { case (n, v) =>
-logger.info(s"Epoch[$epoch] Train-accuracy=$v")
-}
-val toc = System.currentTimeMillis
-logger.info(s"Epoch[$epoch] Time cost=${toc - tic}")
-
-evalMetric.reset()
-testIter.reset()
-while (testIter.hasNext) {
-val evalBatch = testIter.next()
-argDict("data").set(evalBatch.data(0))
-argDict("label").set(evalBatch.label(0))
-executor.forward(isTrain = true)
-evalMetric.update(evalBatch.label, executor.outputs)
-evalBatch.dispose()
-}
-val (names, values) = evalMetric.get
-names.zip(values).foreach { case (n, v) =>
-logger.info(s"Epoch[$epoch] Validation-accuracy=$v")
-}
-}
-executor.dispose()
} catch {
case ex: Exception => {
logger.error(ex.getMessage, ex)
// ... (remaining lines collapsed in the diff)