Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,9 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
protected var _argDict: Map[String, NDArray] = null
protected var _auxDict: Map[String, NDArray] = null
protected var monitorCallback: MXMonitorCallback = null
private[mxnet] var _ctx: Context = null
private[mxnet] var _gradsReq: Iterable[_] = null
private[mxnet] var _group2ctx: Map[String, Context] = null

private var disposed = false

Expand All @@ -139,6 +142,98 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
}
}

/**
* Return a new executor with the same symbol and shared memory,
* but different input/output shapes.
* For runtime reshaping, variable length sequences, etc.
* The returned executor shares state with the current one,
* and cannot be used in parallel with it.
* @param partialShaping Whether to allow changing the shape of unspecified arguments.
* @param allowUpSizing Whether to allow allocating new ndarrays that's larger than the original.
* @param kwargs Map of string to Shape.
* - new shape for arguments.
* @return
* executor A new executor that shares memory with this.
*/
def reshape(partialShaping: Boolean = false, allowUpSizing: Boolean = false,
kwargs: Map[String, Shape]): Executor = {
val (argShapes, _, auxShapes) = this.symbol.inferShape(kwargs)
require(argShapes != null, "Insufficient argument shapes provided.")

var newArgDict = Map[String, NDArray]()
var newGradDict = Map[String, NDArray]()

this.symbol.listArguments().zipWithIndex.foreach { case (name, i) =>
val newShape = argShapes(i)
val arr = this.argArrays(i)
val dArr = if (this.gradArrays == null) null else this.gradArrays(i)
if (partialShaping || kwargs.contains(name) || newShape.equals(arr.shape)) {
if (newShape.product > arr.shape.product) {
require(allowUpSizing, s"New shape of arg:$name larger than original. " +
"First making a big executor and then down sizing it " +
"is more efficient than the reverse." +
"If you really want to up size, set allowUpSizing = true " +
"to enable allocation of new arrays.")
newArgDict = newArgDict + (name -> NDArray.empty(newShape, arr.context))
if (dArr != null) {
newGradDict = newGradDict + (name -> NDArray.empty(newShape, dArr.context))
}
} else {
newArgDict = newArgDict + (name -> arr.reshape(newShape.toArray))
if (dArr != null) {
newGradDict = newGradDict + (name -> dArr.reshape(newShape.toArray))
}
}
} else {
import java.lang.AssertionError
throw new AssertionError(s"Shape of unspecified array arg:$name changed." +
"This can cause the new executor to not share parameters " +
"with the old one. Please check for error in network." +
"If this is intended, set partialShaping = true to suppress this warning.")
}
}

var newAuxDict = Map[String, NDArray]()
val zip3 = (this.symbol.listAuxiliaryStates, auxShapes, this.auxArrays).zipped
zip3.foreach { case (name, newShape, arr) =>
if (partialShaping || newShape.equals(arr.shape)) {
if (newShape.product > arr.shape.product) {
require(allowUpSizing, s"New shape of aux:$name larger than original. " +
"First making a big executor and then down sizing it " +
"is more efficient than the reverse." +
"If you really want to up size, set allowUpSizing = true " +
"to enable allocation of new arrays.")
newAuxDict = newAuxDict + (name -> NDArray.empty(newShape, arr.context))
} else {
newAuxDict = newAuxDict + (name -> arr.reshape(newShape.toArray))
}
} else {
import java.lang.AssertionError
throw new AssertionError(s"Shape of unspecified array aux:$name changed." +
"This can cause the new executor to not share parameters " +
"with the old one. Please check for error in network." +
"If this is intended, set partialShaping = true to suppress this warning.")
}
}
if (this._gradsReq.isInstanceOf[Seq[_]]) {
this.symbol.bind(this._ctx,
newArgDict,
newGradDict,
this._gradsReq.asInstanceOf[Seq[String]],
newAuxDict,
this._group2ctx,
this)
} else {
this.symbol.bind(this._ctx,
newArgDict,
newGradDict,
this._gradsReq.asInstanceOf[Map[String, String]],
newAuxDict,
this._group2ctx,
this)
}
}

/**
* list all the output ndarray
* @return A list of ndarray binded to the heads of executor.
Expand Down
14 changes: 14 additions & 0 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,20 @@ class LibInfo {
reqsArray: Array[Int],
auxArgsHandle: Array[NDArrayHandle],
out: ExecutorHandleRef): Int
@native def mxExecutorBindEX(handle: SymbolHandle,
deviceTypeId: Int,
deviceID: Int,
numCtx: Int,
ctxMapKeys: Array[String],
ctxMapDevTypes: Array[Int],
ctxMapDevIDs: Array[Int],
numArgs: Int,
argsHandle: Array[NDArrayHandle],
argsGradHandle: Array[NDArrayHandle],
reqsArray: Array[Int],
auxArgsHandle: Array[NDArrayHandle],
sharedExec: ExecutorHandle,
out: ExecutorHandleRef): Int
// scalastyle:on parameterNum
@native def mxSymbolSaveToFile(handle: SymbolHandle, fname: String): Int
@native def mxSymbolCreateFromFile(fname: String, handle: SymbolHandleRef): Int
Expand Down
Loading