Updating MultiTask example to use new infer api and adding test for CI (apache#11605)

* Updating MultiTask example to use new infer api and adding test for CI

* Moved common multitask code into methods. Changed multitask test to scala suite convention
andrewfayres authored and nswamy committed Jul 22, 2018
1 parent 6f374db commit a6833dd
Showing 2 changed files with 211 additions and 115 deletions.
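The heart of the "new infer api" migration is visible in buildNetwork below: the string-keyed Symbol constructors are replaced by the typed Symbol.api builders with named parameters. A minimal before/after sketch, using the same layer that appears in the diff:

import org.apache.mxnet.Symbol

val data = Symbol.Variable("data")

// Before: layer name plus a Map of stringly-typed attributes
// val fc1 = Symbol.FullyConnected("fc1")()(Map("data" -> data, "num_hidden" -> 128))

// After: typed builder with named, optional parameters
val fc1 = Symbol.api.FullyConnected(data = Some(data), num_hidden = 128)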
@@ -17,9 +17,16 @@

package org.apache.mxnetexamples.multitask

import java.io.File
import java.net.URL

import org.kohsuke.args4j.{CmdLineParser, Option}
import org.slf4j.LoggerFactory

import scala.collection.JavaConverters._

import org.apache.commons.io.FileUtils

import org.apache.mxnet.Symbol
import org.apache.mxnet.DataIter
import org.apache.mxnet.DataBatch
@@ -29,8 +36,10 @@ import org.apache.mxnet.EvalMetric
import org.apache.mxnet.Context
import org.apache.mxnet.Xavier
import org.apache.mxnet.optimizer.RMSProp
import org.apache.mxnet.Executor

import scala.collection.immutable.ListMap
import scala.sys.process.Process

/**
* Example of multi-task
@@ -41,13 +50,13 @@ object ExampleMultiTask {

def buildNetwork(): Symbol = {
val data = Symbol.Variable("data")
val fc1 = Symbol.FullyConnected("fc1")()(Map("data" -> data, "num_hidden" -> 128))
val act1 = Symbol.Activation("relu1")()(Map("data" -> fc1, "act_type" -> "relu"))
val fc2 = Symbol.FullyConnected("fc2")()(Map("data" -> act1, "num_hidden" -> 64))
val act2 = Symbol.Activation("relu2")()(Map("data" -> fc2, "act_type" -> "relu"))
val fc3 = Symbol.FullyConnected("fc3")()(Map("data" -> act2, "num_hidden" -> 10))
val sm1 = Symbol.SoftmaxOutput("softmax1")()(Map("data" -> fc3))
val sm2 = Symbol.SoftmaxOutput("softmax2")()(Map("data" -> fc3))
val fc1 = Symbol.api.FullyConnected(data = Some(data), num_hidden = 128)
val act1 = Symbol.api.Activation(data = Some(fc1), act_type = "relu")
val fc2 = Symbol.api.FullyConnected(data = Some(act1), num_hidden = 64)
val act2 = Symbol.api.Activation(data = Some(fc2), act_type = "relu")
val fc3 = Symbol.api.FullyConnected(data = Some(act2), num_hidden = 10)
val sm1 = Symbol.api.SoftmaxOutput(data = Some(fc3))
val sm2 = Symbol.api.SoftmaxOutput(data = Some(fc3))

val softmax = Symbol.Group(sm1, sm2)

@@ -133,7 +142,7 @@ object ExampleMultiTask {

for (i <- labels.indices) {
val (pred, label) = (preds(i), labels(i))
val predLabel = NDArray.argmax_channel(pred)
val predLabel = NDArray.api.argmax_channel(data = pred)
require(label.shape == predLabel.shape,
s"label ${label.shape} and prediction ${predLabel.shape}" +
s"should have the same length.")
@@ -191,131 +200,154 @@ object ExampleMultiTask {
}
}

def main(args: Array[String]): Unit = {
val lesk = new ExampleMultiTask
val parser: CmdLineParser = new CmdLineParser(lesk)
try {
parser.parseArgument(args.toList.asJava)
assert(lesk.dataPath != null)
def getTrainingData: String = {
val baseUrl = "https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci"
val tempDirPath = System.getProperty("java.io.tmpdir")
val modelDirPath = tempDirPath + File.separator + "multitask/"
val tmpFile = new File(tempDirPath + "/multitask/mnist.zip")
if (!tmpFile.exists()) {
FileUtils.copyURLToFile(new URL(baseUrl + "/mnist/mnist.zip"),
tmpFile)
}

val batchSize = 100
val numEpoch = 100
val ctx = if (lesk.gpu != -1) Context.gpu(lesk.gpu) else Context.cpu()
val lr = 0.001f
val network = buildNetwork()
val (trainIter, valIter) =
Data.mnistIterator(lesk.dataPath, batchSize = batchSize, inputShape = Shape(784))
val trainMultiIter = new MultiMnistIterator(trainIter)
val valMultiIter = new MultiMnistIterator(valIter)

val datasAndLabels = trainMultiIter.provideData ++ trainMultiIter.provideLabel
val (argShapes, outputShapes, auxShapes) = network.inferShape(datasAndLabels)

val initializer = new Xavier(factorType = "in", magnitude = 2.34f)

val argNames = network.listArguments()
val argDict = argNames.zip(argShapes.map(NDArray.empty(_, ctx))).toMap
val auxNames = network.listAuxiliaryStates()
val auxDict = auxNames.zip(auxShapes.map(NDArray.empty(_, ctx))).toMap

val gradDict = argNames.zip(argShapes).filter { case (name, shape) =>
!datasAndLabels.contains(name)
}.map(x => x._1 -> NDArray.empty(x._2, ctx) ).toMap

argDict.foreach { case (name, ndArray) =>
if (!datasAndLabels.contains(name)) {
initializer.initWeight(name, ndArray)
}
}
// TODO: Need to confirm with Windows
Process("unzip " + tempDirPath + "/multitask/mnist.zip -d "
+ tempDirPath + "/multitask/") !

modelDirPath
}

val data = argDict("data")
val label1 = argDict("softmax1_label")
val label2 = argDict("softmax2_label")
def train(batchSize: Int, numEpoch: Int, ctx: Context, modelDirPath: String):
(Executor, MultiAccuracy) = {
val lr = 0.001f
val network = ExampleMultiTask.buildNetwork()
val (trainIter, valIter) =
Data.mnistIterator(modelDirPath, batchSize = batchSize, inputShape = Shape(784))
val trainMultiIt = new MultiMnistIterator(trainIter)
val valMultiIter = new MultiMnistIterator(valIter)

val maxGradNorm = 0.5f
val executor = network.bind(ctx, argDict, gradDict)
val datasAndLabels = trainMultiIt.provideData ++ trainMultiIt.provideLabel

val opt = new RMSProp(learningRate = lr, wd = 0.00001f)
val (argShapes, outputShapes, auxShapes) = network.inferShape(trainMultiIt.provideData("data"))
val initializer = new Xavier(factorType = "in", magnitude = 2.34f)

val paramsGrads = gradDict.toList.zipWithIndex.map { case ((name, grad), idx) =>
(idx, name, grad, opt.createState(idx, argDict(name)))
val argNames = network.listArguments
val argDict = argNames.zip(argShapes.map(NDArray.empty(_, ctx))).toMap

val gradDict = argNames.zip(argShapes).filter { case (name, shape) =>
!datasAndLabels.contains(name)
}.map(x => x._1 -> NDArray.empty(x._2, ctx)).toMap

argDict.foreach { case (name, ndArray) =>
if (!datasAndLabels.contains(name)) {
initializer.initWeight(name, ndArray)
}
}

val evalMetric = new MultiAccuracy(num = 2, name = "multi_accuracy")
val batchEndCallback = new Speedometer(batchSize, 50)

for (epoch <- 0 until numEpoch) {
// Training phase
val tic = System.currentTimeMillis
evalMetric.reset()
var nBatch = 0
var epochDone = false
// Iterate over training data.
trainMultiIter.reset()

while (!epochDone) {
var doReset = true
while (doReset && trainMultiIter.hasNext) {
val dataBatch = trainMultiIter.next()

data.set(dataBatch.data(0))
label1.set(dataBatch.label(0))
label2.set(dataBatch.label(1))

executor.forward(isTrain = true)
executor.backward()

val norm = Math.sqrt(paramsGrads.map { case (idx, name, grad, optimState) =>
val l2Norm = NDArray.norm(grad / batchSize).toScalar
l2Norm * l2Norm
}.sum).toFloat

paramsGrads.foreach { case (idx, name, grad, optimState) =>
if (norm > maxGradNorm) {
grad.set(grad.toArray.map(_ * (maxGradNorm / norm)))
opt.update(idx, argDict(name), grad, optimState)
} else opt.update(idx, argDict(name), grad, optimState)
}
val data = argDict("data")
val label1 = argDict("softmaxoutput0_label")
val label2 = argDict("softmaxoutput1_label")
val maxGradNorm = 0.5f
val executor = network.bind(ctx, argDict, gradDict)

// evaluate at end, so out_cpu_array can lazy copy
evalMetric.update(dataBatch.label, executor.outputs)
val opt = new RMSProp(learningRate = lr, wd = 0.00001f)

nBatch += 1
batchEndCallback.invoke(epoch, nBatch, evalMetric)
}
if (doReset) {
trainMultiIter.reset()
val paramsGrads = gradDict.toList.zipWithIndex.map { case ((name, grad), idx) =>
(idx, name, grad, opt.createState(idx, argDict(name)))
}

val evalMetric = new ExampleMultiTask.MultiAccuracy(num = 2, name = "multi_accuracy")
val batchEndCallback = new ExampleMultiTask.Speedometer(batchSize, 50)

for (epoch <- 0 until numEpoch) {
// Training phase
val tic = System.currentTimeMillis
evalMetric.reset()
var nBatch = 0
var epochDone = false
// Iterate over training data.
trainMultiIt.reset()

while (!epochDone) {
var doReset = true
while (doReset && trainMultiIt.hasNext) {
val dataBatch = trainMultiIt.next()

data.set(dataBatch.data(0))
label1.set(dataBatch.label(0))
label2.set(dataBatch.label(1))

executor.forward(isTrain = true)
executor.backward()

val norm = Math.sqrt(paramsGrads.map { case (idx, name, grad, optimState) =>
val l2Norm = NDArray.api.norm(data = (grad / batchSize)).toScalar
l2Norm * l2Norm
}.sum).toFloat

paramsGrads.foreach { case (idx, name, grad, optimState) =>
if (norm > maxGradNorm) {
grad.set(grad.toArray.map(_ * (maxGradNorm / norm)))
opt.update(idx, argDict(name), grad, optimState)
} else opt.update(idx, argDict(name), grad, optimState)
}
// this epoch is done
epochDone = true

// evaluate at end, so out_cpu_array can lazy copy
evalMetric.update(dataBatch.label, executor.outputs)

nBatch += 1
batchEndCallback.invoke(epoch, nBatch, evalMetric)
}
var nameVals = evalMetric.get
nameVals.foreach { case (name, value) =>
logger.info(s"Epoch[$epoch] Train-$name=$value")
if (doReset) {
trainMultiIt.reset()
}
val toc = System.currentTimeMillis
logger.info(s"Epoch[$epoch] Time cost=${toc - tic}")
// this epoch is done
epochDone = true
}
var nameVals = evalMetric.get
nameVals.foreach { case (name, value) =>
logger.info(s"Epoch[$epoch] Train-$name=$value")
}
val toc = System.currentTimeMillis
logger.info(s"Epoch[$epoch] Time cost=${toc - tic}")

evalMetric.reset()
valMultiIter.reset()
while (valMultiIter.hasNext) {
val evalBatch = valMultiIter.next()
evalMetric.reset()
valMultiIter.reset()
while (valMultiIter.hasNext) {
val evalBatch = valMultiIter.next()

data.set(evalBatch.data(0))
label1.set(evalBatch.label(0))
label2.set(evalBatch.label(1))
data.set(evalBatch.data(0))
label1.set(evalBatch.label(0))
label2.set(evalBatch.label(1))

executor.forward(isTrain = true)
executor.forward(isTrain = true)

evalMetric.update(evalBatch.label, executor.outputs)
evalBatch.dispose()
}
evalMetric.update(evalBatch.label, executor.outputs)
evalBatch.dispose()
}

nameVals = evalMetric.get
nameVals.foreach { case (name, value) =>
logger.info(s"Epoch[$epoch] Validation-$name=$value")
}
nameVals = evalMetric.get
nameVals.foreach { case (name, value) =>
logger.info(s"Epoch[$epoch] Validation-$name=$value")
}
}

(executor, evalMetric)
}

def main(args: Array[String]): Unit = {
val lesk = new ExampleMultiTask
val parser: CmdLineParser = new CmdLineParser(lesk)
try {
parser.parseArgument(args.toList.asJava)

val batchSize = 100
val numEpoch = 5
val ctx = if (lesk.gpu != -1) Context.gpu(lesk.gpu) else Context.cpu()

val modelPath = if (lesk.dataPath == null) getTrainingData else lesk.dataPath

val (executor, evalMetric) = train(batchSize, numEpoch, ctx, modelPath)
executor.dispose()

} catch {
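The training loop in this file clips by global gradient norm: it sums the squared L2 norms of the batch-averaged gradients and, when the total norm exceeds maxGradNorm (0.5), rescales every gradient by maxGradNorm / norm before the RMSProp update. A plain-Scala sketch of that arithmetic, with hypothetical Array[Float] stand-ins for the NDArray gradients:

// Hypothetical helper mirroring the clipping logic in ExampleMultiTask.train;
// each array plays the role of one batch-averaged gradient NDArray.
def clipByGlobalNorm(grads: Seq[Array[Float]], maxGradNorm: Float): Seq[Array[Float]] = {
  // Global norm = sqrt of the sum of each gradient's squared L2 norm
  val norm = math.sqrt(grads.map(g => g.map(v => v.toDouble * v).sum).sum).toFloat
  if (norm > maxGradNorm) grads.map(_.map(_ * (maxGradNorm / norm)))
  else grads
}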
@@ -0,0 +1,64 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mxnetexamples.multitask

import org.apache.commons.io.FileUtils
import org.apache.mxnet.Context
import org.scalatest.FunSuite
import org.slf4j.LoggerFactory
import org.apache.mxnet.Symbol
import org.apache.mxnet.DataIter
import org.apache.mxnet.DataBatch
import org.apache.mxnet.NDArray
import org.apache.mxnet.Shape
import org.apache.mxnet.EvalMetric
import org.apache.mxnet.Context
import org.apache.mxnet.Xavier
import org.apache.mxnet.optimizer.RMSProp
import java.io.File
import java.net.URL

import scala.sys.process.Process
import scala.collection.immutable.ListMap
import scala.collection.immutable.IndexedSeq
import scala.collection.mutable.{ArrayBuffer, ListBuffer}


/**
* Integration test for the MultiTask example.
* This will run as a part of "make scalatest"
*/
class MultiTaskSuite extends FunSuite {

test("Multitask Test") {
val logger = LoggerFactory.getLogger(classOf[MultiTaskSuite])
logger.info("Multitask Test...")

val batchSize = 100
val numEpoch = 10
val ctx = Context.cpu()

val modelPath = ExampleMultiTask.getTrainingData
val (executor, evalMetric) = ExampleMultiTask.train(batchSize, numEpoch, ctx, modelPath)
evalMetric.get.foreach { case (name, value) =>
assert(value >= 0.95f)
}
executor.dispose()
}

}
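The suite reuses the example's MultiAccuracy metric: for each of the two softmax outputs it argmaxes the predictions, compares them against the labels, and the test requires at least 0.95 accuracy per task after training. A plain-Scala sketch of that per-task comparison, with hypothetical Array[Float] stand-ins for the argmax-ed prediction and label NDArrays:

// Hypothetical helper illustrating the per-task accuracy that MultiAccuracy accumulates.
def taskAccuracy(predLabels: Array[Float], labels: Array[Float]): Float = {
  require(predLabels.length == labels.length,
    "label and prediction should have the same length")
  val correct = predLabels.zip(labels).count { case (p, l) => p == l }
  correct.toFloat / labels.length
}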
