From 68989b68a63219a6997f9f89663938db9e37270d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A2=81=E5=BE=B7=E6=BE=8E?= Date: Sun, 28 Aug 2016 11:39:48 +0800 Subject: [PATCH] [Scala], add EvalMetric TopK, F1 and Optimizer NAG, SGLD, ccSGD (#3149) * scalapkg, add TopK and F1 EvalMetric * scalapkg, add optimizer, NAG, SGLD, ccSGD --- .../src/main/scala/ml/dmlc/mxnet/Base.scala | 5 + .../main/scala/ml/dmlc/mxnet/EvalMetric.scala | 86 +++++++++++++++++- .../main/scala/ml/dmlc/mxnet/LibInfo.scala | 14 +++ .../scala/ml/dmlc/mxnet/optimizer/NAG.scala | 91 +++++++++++++++++++ .../scala/ml/dmlc/mxnet/optimizer/SGD.scala | 5 + .../scala/ml/dmlc/mxnet/optimizer/SGLD.scala | 70 ++++++++++++++ .../scala/ml/dmlc/mxnet/optimizer/ccSGD.scala | 76 ++++++++++++++++ .../main/native/ml_dmlc_mxnet_native_c_api.cc | 84 +++++++++++++++++ 8 files changed, 430 insertions(+), 1 deletion(-) create mode 100644 scala-package/core/src/main/scala/ml/dmlc/mxnet/optimizer/NAG.scala create mode 100644 scala-package/core/src/main/scala/ml/dmlc/mxnet/optimizer/SGLD.scala create mode 100644 scala-package/core/src/main/scala/ml/dmlc/mxnet/optimizer/ccSGD.scala diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala index 6f406ce7cc59..0d5a6d5fb427 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala @@ -24,6 +24,8 @@ object Base { type ExecutorHandle = CPtrAddress type SymbolHandle = CPtrAddress type RecordIOHandle = CPtrAddress + type OptimizerCreator = CPtrAddress + type OptimizerHandle = CPtrAddress type MXUintRef = RefInt type MXFloatRef = RefFloat @@ -35,6 +37,9 @@ object Base { type ExecutorHandleRef = RefLong type SymbolHandleRef = RefLong type RecordIOHandleRef = RefLong + type OptimizerCreatorRef = RefLong + type OptimizerHandleRef = RefLong + try { try { diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/EvalMetric.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/EvalMetric.scala index 35aa2eef6ada..c6bb8fe7aadd 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/EvalMetric.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/EvalMetric.scala @@ -4,7 +4,7 @@ package ml.dmlc.mxnet * Base class of all evaluation metrics * @param name Metric name * - * @author Yuan Tang, Yizhi Liu + * @author Yuan Tang, Yizhi Liu, Depeng Liang */ abstract class EvalMetric(protected val name: String) { @@ -64,6 +64,90 @@ class Accuracy extends EvalMetric("accuracy") { } } +/** + * Calculate top k predictions accuracy + */ +class TopKAccuracy(topK: Int) extends EvalMetric("top_k_accuracy") { + require(topK > 1, "Please use Accuracy if topK is no more than 1") + + override def update(labels: IndexedSeq[NDArray], preds: IndexedSeq[NDArray]): Unit = { + require(labels.length == preds.length, + "labels and predictions should have the same length.") + + for ((pred, label) <- preds zip labels) { + val predShape = pred.shape + val dims = predShape.length + require(dims <= 2, "Predictions should be no more than 2 dims.") + val labelArray = label.toArray + val numSamples = predShape(0) + if (dims == 1) { + val predArray = pred.toArray.zipWithIndex.sortBy(_._1).reverse.map(_._2) + require(predArray.length == labelArray.length) + this.sumMetric += + labelArray.zip(predArray).map { case (l, p) => if (l == p) 1 else 0 }.sum + } else if (dims == 2) { + val numclasses = predShape(1) + val predArray = pred.toArray.grouped(numclasses).map { a => + 
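+ // rank this sample's class indices by predicted score, highest first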
a.zipWithIndex.sortBy(_._1).reverse.map(_._2) + }.toArray + require(predArray.length == labelArray.length) + val topK = Math.max(this.topK, numclasses) + for (j <- 0 until topK) { + this.sumMetric += + labelArray.zip(predArray.map(_(j))).map { case (l, p) => if (l == p) 1 else 0 }.sum + } + } + this.numInst += numSamples + } + } +} + +/** + * Calculate the F1 score of a binary classification problem. + */ +class F1 extends EvalMetric("f1") { + override def update(labels: IndexedSeq[NDArray], preds: IndexedSeq[NDArray]): Unit = { + require(labels.length == preds.length, + "labels and predictions should have the same length.") + + for ((pred, label) <- preds zip labels) { + val predLabel = NDArray.argmaxChannel(pred) + require(label.shape == predLabel.shape, + s"label ${label.shape} and prediction ${predLabel.shape}" + + s"should have the same length.") + val labelArray = label.toArray + var unique = Array[Float]() + labelArray.foreach(l => if (!unique.contains(l)) unique = unique :+ l) + require(unique.length <= 2, "F1 currently only supports binary classification.") + + var truePositives, falsePositives, falseNegatives = 0f + for ((labelElem, predElem) <- labelArray zip predLabel.toArray) { + if (predElem == 1 && labelElem == 1) truePositives += 1 + else if (predElem == 1 && labelElem == 0) falsePositives += 1 + else if (predElem == 0 && labelElem == 1) falseNegatives += 1 + } + + val precision = { + if (truePositives + falsePositives > 0) truePositives / (truePositives + falsePositives) + else 0f + } + + val recall = { + if (truePositives + falseNegatives > 0) truePositives / (truePositives + falseNegatives) + else 0f + } + + val f1Score = { + if (precision + recall > 0) (2 * precision * recall) / (precision + recall) + else 0f + } + + this.sumMetric += f1Score + this.numInst += 1 + } + } +} + // Regression metrics /** diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala index a8aada22b430..92f37c99d309 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala @@ -240,4 +240,18 @@ class LibInfo { @native def mxRecordIOReaderReadRecord(handle: RecordIOHandle, buf: RefString): Int @native def mxRecordIOWriterTell(handle: RecordIOHandle, pos: RefInt): Int @native def mxRecordIOReaderSeek(handle: RecordIOHandle, pos: Int): Int + + @native def mxOptimizerFindCreator(key: String, out: OptimizerCreatorRef): Int + @native def mxOptimizerCreateOptimizer(creator: OptimizerCreator, + numParam: Int, + keys: Array[String], + vals: Array[String], + out: OptimizerHandleRef): Int + @native def mxOptimizerFree(handle: OptimizerHandle): Int + @native def mxOptimizerUpdate(handle: OptimizerHandle, + index: Int, + weight: NDArrayHandle, + grad: NDArrayHandle, + lr: Float, + wd: Float): Int } diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/optimizer/NAG.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/optimizer/NAG.scala new file mode 100644 index 000000000000..59ea76e8b8b0 --- /dev/null +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/optimizer/NAG.scala @@ -0,0 +1,91 @@ +package ml.dmlc.mxnet.optimizer + +import ml.dmlc.mxnet.{Optimizer, LRScheduler, NDArray} +import ml.dmlc.mxnet.NDArrayConversions._ + +/** + * SGD with nesterov. + * It is implemented according to + * https://github.com/torch/optim/blob/master/sgd.lua + * + * @author Depeng Liang + * + * @param learningRate Float, Step size. 
+ * @param momentum Float, momentum value. + * @param wd Float, L2 regularization coefficient add to all the weights + * @param clipGradient Float, clip gradient in range [-clip_gradient, clip_gradient] + * @param lrScheduler The learning rate scheduler + */ +class NAG(val learningRate: Float = 0.01f, val momentum: Float = 0.0f, + val wd: Float = 0.0001f, val clipGradient: Float = 0f, + val lrScheduler: LRScheduler = null) extends Optimizer { + + if (lrScheduler != null) { + lrScheduler.baseLR = learningRate + } + + /** + * Update the parameters. + * @param index An unique integer key used to index the parameters + * @param weight weight ndarray + * @param grad grad ndarray + * @param state NDArray or other objects returned by initState + * The auxiliary state used in optimization. + */ + override def update(index: Int, weight: NDArray, grad: NDArray, state: AnyRef): Unit = { + // TODO(bing) implement wd_bias, wd_gamma, wd_beta (copy from python package) + val lr = + (if (lrScheduler != null) { + val scheduledLr = lrScheduler(numUpdate) + updateCount(index) + scheduledLr + } else { + this.learningRate + }) * lrScale.getOrElse(index, 1f) + + val wd = getWd(index, this.wd) + var resdGrad = grad * this.rescaleGrad + if (clipGradient != 0f) { + // to get rid of memory leak + val oldResdGrad = resdGrad + resdGrad = NDArray.clip(resdGrad, -clipGradient, clipGradient) + oldResdGrad.dispose() + } + + if (state != null) { + val mom = state.asInstanceOf[NDArray] + mom *= momentum + resdGrad += wd * weight + mom += resdGrad + resdGrad += momentum * mom + weight += -lr * resdGrad + } else { + require(momentum == 0f) + // adder = -lr * (resdGrad + this.wd * weight) + // we write in this way to get rid of memory leak + val adder = this.wd * weight + adder += resdGrad + adder *= (-lr) + weight += adder + adder.dispose() + } + + resdGrad.dispose() + } + + // Create additional optimizer state such as momentum. + override def createState(index: Int, weight: NDArray): AnyRef = { + if (momentum == 0.0f) { + null + } else { + NDArray.zeros(weight.shape, weight.context) + } + } + + // Dispose the state it created + override def disposeState(state: AnyRef): Unit = { + if (state != null) { + state.asInstanceOf[NDArray].dispose() + } + } +} diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/optimizer/SGD.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/optimizer/SGD.scala index 349a3ddb31a1..6e35358877e5 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/optimizer/SGD.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/optimizer/SGD.scala @@ -10,6 +10,11 @@ import ml.dmlc.mxnet.NDArrayConversions._ class SGD(private val learningRate: Float = 0.01f, private val momentum: Float = 0.0f, private val wd: Float = 0.0001f, private val clipGradient: Float = 0f, private val lrScheduler: LRScheduler = null) extends Optimizer { + + if (lrScheduler != null) { + lrScheduler.baseLR = learningRate + } + /** * Update the parameters. 
* @param index An unique integer key used to index the parameters diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/optimizer/SGLD.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/optimizer/SGLD.scala new file mode 100644 index 000000000000..a1bd5db55c3c --- /dev/null +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/optimizer/SGLD.scala @@ -0,0 +1,70 @@ +package ml.dmlc.mxnet.optimizer + +import ml.dmlc.mxnet.{Optimizer, LRScheduler, NDArray} +import ml.dmlc.mxnet.NDArrayConversions._ +import ml.dmlc.mxnet.Random + +/** + * Stochastic Langevin Dynamics Updater to sample from a distribution. + * + * @author Depeng Liang + * + * @param learningRate Float, Step size. + * @param rescaleGradient Float, rescaling factor of gradient. + * @param wd Float, L2 regularization coefficient add to all the weights + * @param clipGradient Float, clip gradient in range [-clip_gradient, clip_gradient] + * @param lrScheduler The learning rate scheduler + */ +class SGLD(val learningRate: Float = 0.01f, val rescaleGradient: Float = 1.0f, + val wd: Float = 0.0001f, val clipGradient: Float = 0f, + val lrScheduler: LRScheduler = null) extends Optimizer { + + if (lrScheduler != null) { + lrScheduler.baseLR = learningRate + } + + /** + * Update the parameters. + * @param index An unique integer key used to index the parameters + * @param weight weight ndarray + * @param grad grad ndarray + * @param state NDArray or other objects returned by initState + * The auxiliary state used in optimization. + */ + override def update(index: Int, weight: NDArray, grad: NDArray, state: AnyRef): Unit = { + val lr = + (if (lrScheduler != null) { + val scheduledLr = lrScheduler(numUpdate) + updateCount(index) + scheduledLr + } else { + this.learningRate + }) * lrScale.getOrElse(index, 1f) + + val wd = getWd(index, this.wd) + var resdGrad = grad * this.rescaleGrad + if (clipGradient != 0f) { + // to get rid of memory leak + val oldResdGrad = resdGrad + resdGrad = NDArray.clip(resdGrad, -clipGradient, clipGradient) + oldResdGrad.dispose() + } + + val adder = this.wd * weight + adder += resdGrad + adder *= -(lr / 2) + val norm = Random.normal(0f, Math.sqrt(lr).toFloat, weight.shape, weight.context) + adder += norm + weight += adder + adder.dispose() + norm.dispose() + } + + // Create additional optimizer state such as momentum. + override def createState(index: Int, weight: NDArray): AnyRef = { + null + } + + // Dispose the state it created + override def disposeState(state: AnyRef): Unit = {} +} diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/optimizer/ccSGD.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/optimizer/ccSGD.scala new file mode 100644 index 000000000000..fbc82a2efd9b --- /dev/null +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/optimizer/ccSGD.scala @@ -0,0 +1,76 @@ +package ml.dmlc.mxnet.optimizer + +import ml.dmlc.mxnet.{Optimizer, LRScheduler, NDArray} +import ml.dmlc.mxnet.NDArrayConversions._ +import ml.dmlc.mxnet.Base._ + + +/** + * A very simple SGD optimizer with momentum and weight regularization. + * Implemented in C++. + * + * @author Depeng Liang + * + * @param learningRate Float, Step size. + * @param momentum Float, momentum value. + * @param rescaleGradient Float, rescaling factor of gradient. 
+ * @param wd Float, L2 regularization coefficient add to all the weights + * @param clipGradient Float, clip gradient in range [-clip_gradient, clip_gradient] + * @param lrScheduler The learning rate scheduler + */ +class ccSGD(val learningRate: Float = 0.01f, val momentum: Float = 0.0f, + val wd: Float = 0.0001f, val rescaleGradient: Float = 1.0f, + val clipGradient: Float = -1f, val lrScheduler: LRScheduler = null + ) extends Optimizer { + + if (lrScheduler != null) { + lrScheduler.baseLR = learningRate + } + + private val optCreator = new OptimizerCreatorRef + private val optHandle = new OptimizerHandleRef + + checkCall(_LIB.mxOptimizerFindCreator("ccsgd", optCreator)) + private val paramKeys = Array("momentum", "rescale_grad", "clip_gradient") + private val paramvals = Array(s"$momentum", s"$rescaleGradient", s"$clipGradient") + checkCall(_LIB.mxOptimizerCreateOptimizer( + optCreator.value, paramKeys.length, paramKeys, paramvals, optHandle)) + + /** + * Update the parameters. + * @param index An unique integer key used to index the parameters + * @param weight weight ndarray + * @param grad grad ndarray + * @param state NDArray or other objects returned by initState + * The auxiliary state used in optimization. + */ + override def update(index: Int, weight: NDArray, grad: NDArray, state: AnyRef): Unit = { + val lr = + (if (lrScheduler != null) { + val scheduledLr = lrScheduler(numUpdate) + updateCount(index) + scheduledLr + } else { + this.learningRate + }) * lrScale.getOrElse(index, 1f) + + val wd = getWd(index, this.wd) + checkCall(_LIB.mxOptimizerUpdate(optHandle.value, index, weight.handle, grad.handle, lr, wd)) + } + + // Create additional optimizer state such as momentum. + override def createState(index: Int, weight: NDArray): AnyRef = { + null + } + + // Dispose the state it created + override def disposeState(state: AnyRef): Unit = {} + + /** + * Free the optimizer handle. + * The object shall never be used after it is disposed. 
+ */ + def dispose(): Unit = { + checkCall(_LIB.mxOptimizerFree(optHandle.value)) + } +} diff --git a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc index c3c38da445d3..de29bbe880d9 100644 --- a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc +++ b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc @@ -1532,3 +1532,87 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxRecordIOReaderSeek int ret = MXRecordIOReaderSeek(recordIOHandle, pos); return ret; } + +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxOptimizerFindCreator + (JNIEnv *env, jobject obj, jstring jkey, jobject out) { + OptimizerCreator creator; + const char *key = env->GetStringUTFChars(jkey, 0); + int ret = MXOptimizerFindCreator(key, &creator); + env->ReleaseStringUTFChars(jkey, key); + SetLongField(env, out, reinterpret_cast(creator)); + return ret; +} + +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxOptimizerCreateOptimizer + (JNIEnv *env, jobject obj, jlong jcreator, jint num_param, + jobjectArray jkeys, jobjectArray jvals, jobject out) { + OptimizerHandle handle; + OptimizerCreator creator = reinterpret_cast(jcreator); + int len = env->GetArrayLength(jkeys); + const char **keys = NULL; + if (jkeys != NULL) { + keys = new const char *[len]; + for (size_t i = 0; i < len; i++) { + jstring jkey = reinterpret_cast(env->GetObjectArrayElement(jkeys, i)); + const char *key = env->GetStringUTFChars(jkey, 0); + keys[i] = key; + env->DeleteLocalRef(jkey); + } + } + const char **vals = NULL; + if (jvals != NULL) { + vals = new const char *[len]; + for (size_t i = 0; i < len; i++) { + jstring jval = reinterpret_cast(env->GetObjectArrayElement(jvals, i)); + const char *val = env->GetStringUTFChars(jval, 0); + vals[i] = val; + env->DeleteLocalRef(jval); + } + } + int ret = MXOptimizerCreateOptimizer(creator, + num_param, + keys, + vals, + &handle); + SetLongField(env, out, reinterpret_cast(handle)); + // release allocated memory + if (jkeys != NULL) { + for (size_t i = 0; i < len; i++) { + jstring jkey = reinterpret_cast(env->GetObjectArrayElement(jkeys, i)); + env->ReleaseStringUTFChars(jkey, keys[i]); + env->DeleteLocalRef(jkey); + } + delete[] keys; + } + if (jvals != NULL) { + for (size_t i = 0; i < len; i++) { + jstring jval = reinterpret_cast(env->GetObjectArrayElement(jvals, i)); + env->ReleaseStringUTFChars(jval, vals[i]); + env->DeleteLocalRef(jval); + } + delete[] vals; + } + return ret; +} + +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxOptimizerFree + (JNIEnv *env, jobject obj, jlong jhandle) { + OptimizerHandle handle = reinterpret_cast(jhandle); + int ret = MXOptimizerFree(handle); + return ret; +} + +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxOptimizerUpdate + (JNIEnv *env, jobject obj, jlong jhandle, jint index, jlong jweight, + jlong jgrad, jfloat lr, jfloat wd) { + OptimizerHandle handle = reinterpret_cast(jhandle); + NDArrayHandle weight = reinterpret_cast(jweight); + NDArrayHandle grad = reinterpret_cast(jgrad); + int ret = MXOptimizerUpdate(handle, + index, + weight, + grad, + lr, + wd); + return ret; +}