Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Memory fixes. Resolves #10867, and resolves #14080 (#14372)
Browse files Browse the repository at this point in the history
* Fixes for memory leak when reshaping executor

* Fixed Adam Optimizer memory leak

* Cleanup for PR

* Added unit test for new ResourceScope method

* Removing import that was added by overzealous ide

* Add back in an import

* Added flags for executor to know whether or not it owns NDArrays for disposal

* Moving to ResourceScope.using implementation

* Changes to make ResourceScope.using work with existing scope

* Updating ResourceScope to work with existing scopes via usingIfScopeExists method

* Fix clojure unit tests

* Fixes to be compatibile with how clojure is using ResourceScope

* Removing some unnecessary changes

* Adding scope assertion in unit test
  • Loading branch information
andrewfayres authored and nswamy committed Mar 28, 2019
1 parent 645c778 commit 102b46f
Show file tree
Hide file tree
Showing 7 changed files with 165 additions and 89 deletions.
47 changes: 39 additions & 8 deletions scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala
Expand Up @@ -45,29 +45,47 @@ object Executor {
* @see Symbol.bind : to create executor
*/
class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
private[mxnet] val symbol: Symbol) extends NativeResource {
private[mxnet] var argArrays: Array[NDArray] = null
private[mxnet] var gradArrays: Array[NDArray] = null
private[mxnet] var auxArrays: Array[NDArray] = null
private[mxnet] val symbol: Symbol,
private[mxnet] var argArrays: Array[NDArray] = null,
private[mxnet] var gradArrays: Array[NDArray] = null,
private[mxnet] var auxArrays: Array[NDArray] = null,
private var _ctx: Context = null,
private var _gradsReq: Iterable[_] = null,
private var _group2ctx: Map[String, Context] = null
) extends NativeResource {

val outputs: Array[NDArray] = getOutputs
protected var _argDict: Map[String, NDArray] = null
protected var _gradDict: Map[String, NDArray] = null
protected var _auxDict: Map[String, NDArray] = null
protected var monitorCallback: MXMonitorCallback = null
private[mxnet] var _ctx: Context = null
private[mxnet] var _gradsReq: Iterable[_] = null
private[mxnet] var _group2ctx: Map[String, Context] = null
private val logger: Logger = LoggerFactory.getLogger(classOf[Executor])

private[mxnet] var ownsArgArrays = false
private[mxnet] var ownsGradArrays = false
private[mxnet] var ownsAuxArrays = false

override def nativeAddress: CPtrAddress = handle
override def nativeDeAllocator: (CPtrAddress => Int) = _LIB.mxExecutorFree
// cannot determine the off-heap size of this object
override val bytesAllocated: Long = 0
override val ref: NativeResourceRef = super.register()

override def dispose(): Unit = {
if (!super.isDisposed) {
super.dispose()
outputs.foreach(o => o.dispose())
// Symbol.bind clones symbol when creating the executor so we need to dispose of the clone
symbol.dispose()
if (ownsArgArrays && argArrays != null) {argArrays.foreach(a => a.dispose())}
if (ownsGradArrays && gradArrays != null) {gradArrays.foreach(
// Symbol will sometimes fill this with nulls so we've got to check the elements too
a => if (a != null) {a.dispose()})
}
if (ownsAuxArrays && auxArrays != null) {auxArrays.foreach(a => a.dispose())}
if (_argDict != null) {_argDict.foreach(a => a._2.dispose())}
if (_gradDict != null) {_gradDict.foreach(a => a._2.dispose())}
if (_auxDict != null) {_auxDict.foreach(a => a._2.dispose())}
}
}

Expand All @@ -86,6 +104,9 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
*/
def reshape(partialShaping: Boolean = false, allowUpSizing: Boolean = false,
kwargs: Map[String, Shape]): Executor = {
var setArgOwner = false
var setAuxOwner = false
var setGradOwner = false
val (argShapes, _, auxShapes) = this.symbol.inferShape(kwargs)
// TODO: more precise error message should be provided by backend
require(argShapes != null, "Shape inference failed." +
Expand All @@ -107,8 +128,10 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
"If you really want to up size, set allowUpSizing = true " +
"to enable allocation of new arrays.")
newArgDict = newArgDict + (name -> NDArray.empty(newShape, arr.context, arr.dtype))
setArgOwner = true
if (dArr != null) {
newGradDict = newGradDict + (name -> NDArray.empty(newShape, dArr.context, dArr.dtype))
setGradOwner = true
}
} else {
newArgDict = newArgDict + (name -> arr.reshape(newShape.toArray))
Expand All @@ -135,6 +158,7 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
"If you really want to up size, set allowUpSizing = true " +
"to enable allocation of new arrays.")
newAuxDict = newAuxDict + (name -> NDArray.empty(newShape, arr.context))
setAuxOwner = true
} else {
newAuxDict = newAuxDict + (name -> arr.reshape(newShape.toArray))
}
Expand All @@ -145,7 +169,7 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
"If this is intended, set partialShaping = true to suppress this warning.")
}
}
if (this._gradsReq.isInstanceOf[Seq[_]]) {
val reshapedExecutor = if (this._gradsReq.isInstanceOf[Seq[_]]) {
this.symbol.bind(this._ctx,
newArgDict,
newGradDict,
Expand All @@ -162,6 +186,13 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
this._group2ctx,
this)
}

// This method has created new NDArrays that will need to be managed by the new Executor
if (setArgOwner) reshapedExecutor.ownsArgArrays = true
if (setGradOwner) reshapedExecutor.ownsGradArrays = true
if (setAuxOwner) reshapedExecutor.ownsAuxArrays = true

reshapedExecutor
}

/**
Expand Down
20 changes: 11 additions & 9 deletions scala-package/core/src/main/scala/org/apache/mxnet/Optimizer.scala
Expand Up @@ -28,15 +28,17 @@ object Optimizer {
def getUpdater(optimizer: Optimizer): MXKVStoreUpdater = {
new MXKVStoreUpdater with MXKVStoreCachedStates {
override def update(index: Int, grad: NDArray, weight: NDArray): Unit = {
val state =
if (states.contains(index)) {
states.get(index).get
} else {
val newState = optimizer.createState(index, weight)
states.put(index, newState)
newState
}
optimizer.update(index, weight, grad, state)
ResourceScope.usingIfScopeExists(this.scope) {
val state =
if (states.contains(index)) {
states.get(index).get
} else {
val newState = optimizer.createState(index, weight)
states.put(index, newState)
newState
}
optimizer.update(index, weight, grad, state)
}
}

override def dispose(): Unit = {
Expand Down
Expand Up @@ -48,8 +48,10 @@ class ResourceScope extends AutoCloseable {
*/
override def close(): Unit = {
ResourceScope.removeFromThreadLocal(this)
resourceQ.foreach(resource => if (resource != null) resource.dispose(false) )
resourceQ.clear()
if (!ResourceScope.threadLocalScopes.get().contains(this)) {
resourceQ.foreach(resource => if (resource != null) resource.dispose(false))
resourceQ.clear()
}
}

/**
Expand Down Expand Up @@ -145,7 +147,7 @@ object ResourceScope {
null.asInstanceOf[A] // we'll throw in finally
} finally {
var toThrow: Throwable = retThrowable
if (retThrowable eq null) curScope.close()
if (retThrowable eq null) curScope.close
else {
try {
curScope.close
Expand All @@ -160,6 +162,17 @@ object ResourceScope {
}
}

private[mxnet] def usingIfScopeExists[A](scope: Option[ResourceScope])(body: => A): A = {
if (scope == None) {
body
} else {
ResourceScope.addToThreadLocal(scope.get)
ResourceScope.using(scope.get){
body
}
}
}

// thread local Scopes
private[mxnet] val threadLocalScopes = new ThreadLocal[ArrayBuffer[ResourceScope]] {
override def initialValue(): ArrayBuffer[ResourceScope] =
Expand All @@ -179,7 +192,7 @@ object ResourceScope {
* @param r ResourceScope to remove
*/
private[mxnet] def removeFromThreadLocal(r: ResourceScope): Unit = {
threadLocalScopes.get() -= r
threadLocalScopes.get().remove(threadLocalScopes.get().lastIndexOf(r))
}

/**
Expand Down
21 changes: 13 additions & 8 deletions scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
Expand Up @@ -803,18 +803,23 @@ class Symbol private(private[mxnet] val handle: SymbolHandle) extends NativeReso
auxArgsHandle,
sharedHandle,
execHandle))
val executor = new Executor(execHandle.value, this.clone())
executor.argArrays = argsNDArray
executor.gradArrays = argsGradNDArray
executor.auxArrays = auxStatesNDArray
executor._ctx = new Context(ctx.deviceType, ctx.deviceId)
executor._gradsReq = gradsReq
executor._group2ctx =

val executorGroup2ctx =
if (group2ctx == null) null
else group2ctx.map { case (key, value) =>
key -> new Context(value.deviceType, value.deviceId)
}
executor

// If this is in a scope then we want to create the clone in the same scope
var newSymbol: Symbol = null
ResourceScope.usingIfScopeExists(this.scope) {
newSymbol = this.clone()
}

new Executor(execHandle.value, newSymbol, argsNDArray, argsGradNDArray,
auxStatesNDArray, new Context(ctx.deviceType, ctx.deviceId),
gradsReq, executorGroup2ctx)

}

/**
Expand Down
Expand Up @@ -299,7 +299,6 @@ class DataParallelExecutorGroup private[module](

private var batchSize: Int = -1
private var slices: Array[(Int, Int)] = null
private var _defaultExecs: Array[Executor] = null
private var execs: Array[Executor] = null
private var dataArrays: Seq[Array[((Int, Int), NDArray)]] = null
private var labelArrays: Option[Seq[Array[((Int, Int), NDArray)]]] = None
Expand Down Expand Up @@ -373,7 +372,12 @@ class DataParallelExecutorGroup private[module](
val labelShapesSliced = labelShapes.map(slicedShape(_, i, labelLayouts))
val inputShapes
= dataShapesSliced.toMap ++ labelShapesSliced.getOrElse(Map.empty[String, Shape])
execs(i) = _defaultExecs(i).reshape(allowUpSizing = true, kwargs = inputShapes)

ResourceScope.usingIfScopeExists(execs(i).scope) {
val tmpExec = execs(i).reshape(allowUpSizing = true, kwargs = inputShapes)
execs(i).dispose()
execs(i) = tmpExec
}
}
} else {
execs = (0 until contexts.length).map(i =>
Expand Down Expand Up @@ -434,9 +438,6 @@ class DataParallelExecutorGroup private[module](
*/
def reshape(dataShapes: IndexedSeq[DataDesc], labelShapes: Option[IndexedSeq[DataDesc]]): Unit = {
if (!(dataShapes == this.dataShapes && labelShapes == this.labelShapes)) {
if (this._defaultExecs == null) {
this._defaultExecs = this.execs.map(x => x)
}
this.bindExec(dataShapes, labelShapes, None, reshape = true)
}
}
Expand Down
Expand Up @@ -19,7 +19,7 @@ package org.apache.mxnet.optimizer

import org.apache.mxnet.NDArrayConversions._
import org.apache.mxnet.util.SerializerUtils
import org.apache.mxnet.{LRScheduler, NDArray, Optimizer}
import org.apache.mxnet.{LRScheduler, NDArray, Optimizer, ResourceScope}

/**
* Adam optimizer as described in [King2014]
Expand Down Expand Up @@ -57,63 +57,54 @@ class Adam(val learningRate: Float = 0.002f, beta1: Float = 0.9f, beta2: Float =
* The auxiliary state used in optimization.
*/
override def update(index: Int, weight: NDArray, grad: NDArray, state: AnyRef): Unit = {
var lr =
(if (lrScheduler != null) {
val scheduledLr = lrScheduler(numUpdate)
updateCount(index)
scheduledLr
} else {
this.learningRate
})
lr = getLr(index, lr)

val (mean, variance) = state.asInstanceOf[(NDArray, NDArray)]

// increment time only when the first parameters is called
timeFirstIndex match {
case Some(idx) =>
if (idx == index) time += 1
case None =>
timeFirstIndex = Option(index)
time = 0 // all parameters share the same time
}

val t1: Int = time + 1
val learningRate = (lr *
math.sqrt(1.0 - math.pow(beta2, t1)) /
(1.0 - math.pow(beta1, t1))).toFloat
val beta1t = beta1 * math.pow(decayFactor, t1 - 1).toFloat

var resdGrad = grad * rescaleGrad
if (clipGradient != 0f) {
val oldResdGrad = resdGrad
resdGrad = NDArray.clip(resdGrad, -clipGradient, clipGradient)
oldResdGrad.dispose()
}

val meanT = (beta1t * mean + (1.0 - beta1t) * resdGrad)
.disposeDepsExcept(mean, resdGrad)
val varianceT = (beta2 * variance + (1.0f - beta2) * resdGrad * resdGrad)
.disposeDepsExcept(variance, resdGrad)
ResourceScope.using() {
var lr =
(if (lrScheduler != null) {
val scheduledLr = lrScheduler(numUpdate)
updateCount(index)
scheduledLr
} else {
this.learningRate
})
lr = getLr(index, lr)

val step = (learningRate * meanT / (NDArray.sqrt(varianceT) + epsilon))
.disposeDepsExcept(meanT, varianceT)
val (mean, variance) = state.asInstanceOf[(NDArray, NDArray)]

val wd = this.getWd(index, this.wd)
if (wd > 0.0f) {
val stepDelta = lr * wd * weight
step += stepDelta
stepDelta.dispose()
// increment time only when the first parameters is called
timeFirstIndex match {
case Some(idx) =>
if (idx == index) time += 1
case None =>
timeFirstIndex = Option(index)
time = 0 // all parameters share the same time
}

val t1: Int = time + 1
val learningRate = (lr * math.sqrt(1.0 - math.pow(beta2, t1)) /
(1.0 - math.pow(beta1, t1))).toFloat
val beta1t = beta1 * math.pow(decayFactor, t1 - 1).toFloat

var resdGrad = grad * rescaleGrad
if (clipGradient != 0f) {
val oldResdGrad = resdGrad
resdGrad = NDArray.clip(resdGrad, -clipGradient, clipGradient)
}

val meanT = (beta1t * mean + (1.0 - beta1t) * resdGrad)
val varianceT = (beta2 * variance + (1.0f - beta2) * resdGrad * resdGrad)
val step = (learningRate * meanT / (NDArray.sqrt(varianceT) + epsilon))

val wd = this.getWd(index, this.wd)
if (wd > 0.0f) {
val stepDelta = lr * wd * weight
step += stepDelta
}

weight -= step
mean.set(meanT)
variance.set(varianceT)
(mean, variance)
}

weight -= step
mean.set(meanT)
variance.set(varianceT)

meanT.dispose()
varianceT.dispose()
step.dispose()
resdGrad.dispose()
}

// Create additional optimizer state: mean, variance
Expand Down

0 comments on commit 102b46f

Please sign in to comment.