From 94362a279e9d3bd1d1855939bd2eaa12353e8db6 Mon Sep 17 00:00:00 2001
From: rawkintrevo
Date: Mon, 5 Oct 2015 11:27:58 -0500
Subject: [PATCH 1/3] FLINK-1994: Added 4 new effective learning rates
Added SGD gain calculation schemes
fixed optimal SGD calculation scheme
FLINK-1994: Added 4 new effective learning rates
[FLINK-1994] [ml] Add different gain calculation schemes to SGD
fixed long lines in GradientDescent.scala
[FLINK-1994][ml]Add different gain calculation schemes to SGD
[FLINK-1994][ml] Add different gain calculation schemes to SGD
[FLINK-1994][ml] Add different gain calculation schemes to SGD
Added SGD gain calculation schemes
fixed optimal SGD calculation scheme
[FLINK-1994] [ml] Add different gain calculation schemes to SGD
FLINK-1994: Added 3 new effective learning rates
Added SGD gain calculation schemes
fixed optimal SGD calculation scheme
[FLINK-1994] [ml] Add different gain calculation schemes to SGD
fixed long lines in GradientDescent.scala
[FLINK-1994][ml] Add different gain calculation schemes to SGD
[Flink-1994][ml] Add different gain calculation schemes to SGD
[FLINK-1994][ml] Updated docs, refactored optimizationMethod from Int to String
[FLINK-1994][ml] Added test and example to docs
[FLINK-1994][ml] Fixed Int Artifacts in LinearRegression.scala
Added LearningRateMethod to IterativeSolver
The learning rate method defines how the effective learning step is calculated for
each iteration step of the IterativeSolver.
Fixed docs, merged enumeration from Till, fixed typo in Wus method
---
docs/libs/ml/optimization.md | 111 ++++++++++++++++--
.../ml/optimization/GradientDescent.scala | 49 ++++++--
.../apache/flink/ml/optimization/Solver.scala | 91 +++++++++++++-
.../optimization/GradientDescentITSuite.scala | 34 +++++-
4 files changed, 264 insertions(+), 21 deletions(-)
diff --git a/docs/libs/ml/optimization.md b/docs/libs/ml/optimization.md
index 110383d6802e3..d011fbd4ae776 100644
--- a/docs/libs/ml/optimization.md
+++ b/docs/libs/ml/optimization.md
@@ -76,7 +76,7 @@ few large ones.
The $L_1$ penalty can be used to drive a number of the solution coefficients to 0, thereby
producing sparse solutions.
The regularization constant $\lambda$ in $\eqref{eq:objectiveFunc}$ determines the amount of regularization applied to the model,
-and is usually determined through model cross-validation.
+and is usually determined through model cross-validation.
A good comparison of regularization types can be found in [this](http://www.robotics.stanford.edu/~ang/papers/icml04-l1l2.pdf) paper by Andrew Ng.
Which regularization type is supported depends on the actually used optimization algorithm.
@@ -94,9 +94,7 @@ In mini-batch SGD we instead sample random subsets of the dataset, and compute t
over each batch. At each iteration of the algorithm we update the weights once, based on
the average of the gradients computed from each mini-batch.
-An important parameter is the learning rate $\eta$, or step size, which is currently determined as
-$\eta = \eta_0/\sqrt{j}$, where $\eta_0$ is the initial step size and $j$ is the iteration
-number. The setting of the initial step size can significantly affect the performance of the
+An important parameter is the learning rate $\eta$, or step size, which can be determined by one of five methods, listed below. The setting of the initial step size can significantly affect the performance of the
algorithm. For some practical tips on tuning SGD see Leon Botou's
"[Stochastic Gradient Descent Tricks](http://research.microsoft.com/pubs/192769/tricks-2012.pdf)".
@@ -156,7 +154,7 @@ The following list contains a mapping between the implementing classes and the r
| RegularizationConstant |
- The amount of regularization to apply. (Default value: 0.0)
+ The amount of regularization to apply. (Default value: 0.1)
|
@@ -189,9 +187,26 @@ The following list contains a mapping between the implementing classes and the r
+
+ | OptimizationMethod |
+
+
+ (Default value: "default")
+
+ |
+
+
+ | Decay |
+
+
+
+ (Default value: 0.0)
+
+ |
+
-
+
### Loss Function
The loss function which is minimized has to implement the `LossFunction` interface, which defines methods to compute the loss and the gradient of it.
@@ -199,12 +214,12 @@ Either one defines ones own `LossFunction` or one uses the `GenericLossFunction`
An example can be seen here
```Scala
-val lossFunction = GenericLossFunction(SquaredLoss, LinearPrediction)
+val lossFunction = GenericLossFunction(SquaredLoss, LinearPrediction)
```
The full list of supported outer loss functions can be found [here](#partial-loss-function-values).
The full list of supported prediction functions can be found [here](#prediction-function-values).
-
+
#### Partial Loss Function Values ##
@@ -256,6 +271,84 @@ The full list of supported prediction functions can be found [here](#prediction-
+#### Effective Learning Rate ##
+
+Where:
+
+- $j$ is the iteration number
+
+- $\eta_j$ is the step size on step $j$
+
+- $\eta_0$ is the initial step size
+
+- $\lambda$ is the regularization constant
+
+- $\tau$ is the decay constant, which causes the learning rate to be a decreasing function of $j$, that is to say as iterations increase, learning rate decreases. The exact rate of decay is function specific, see **Inverse Scaling** and **Wei Xu's Method** (which is an extension of the **Inverse Scaling** method).
+
+
+
+
+ | Function Name |
+ Description |
+ Function |
+ Called As |
+
+
+
+
+ | Default |
+
+
+ The function default method used for determining the step size. This is equivalent to the inverse scaling method for $\tau$ = 0.5. This special case is kept as the default to maintain backwards compatibility.
+
+ |
+ $\eta_j = \eta_0/\sqrt{j}$ |
+ LearningRateMethod.Default |
+
+
+ | Constant |
+
+
+ The step size is constant throughout the learning task.
+
+ |
+ $\eta_j = \eta_0$ |
+ LearningRateMethod.Constant |
+
+
+ | Leon Bottou's Method |
+
+
+ This is the 'optimal' method of sklearn. Chooses optimal initial $t_0 = \lambda \cdot eta_0$, based on Leon Bottou's Learning with Large Data Sets
+
+ |
+ $\eta_j = 1 / (\lambda \cdot (\frac{1}{\lambda \cdot eta_0) } +j -1) $ |
+ LearningRateMethod.Bottou |
+
+
+ | Inverse Scaling |
+
+
+ A very common method for determining the step size.
+
+ |
+ $\eta_j = \lambda / j^{\tau}$ |
+ LearningRateMethod.InvScaling |
+
+
+ | Wei Xu's Method |
+
+
+ Method proposed by Wei Xu in Towards Optimal One Pass Large Scale Learning with
+ Averaged Stochastic Gradient Descent
+
+ |
+ $\eta_j = \lambda \cdot (1+ \lambda \cdot \eta_0 \cdot j)^{-\tau} $ |
+ LearningRateMethod.Xu |
+
+
+
+
### Examples
In the Flink implementation of SGD, given a set of examples in a `DataSet[LabeledVector]` and
@@ -276,6 +369,8 @@ val sgd = GradientDescentL1()
.setRegularizationConstant(0.2)
.setIterations(100)
.setLearningRate(0.01)
+ .setOptimizationMethod("xu")
+ .setDecay(-0.75)
// Obtain data
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/optimization/GradientDescent.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/optimization/GradientDescent.scala
index 78bad708ef3f0..309b9a62ede10 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/optimization/GradientDescent.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/optimization/GradientDescent.scala
@@ -22,7 +22,8 @@ package org.apache.flink.ml.optimization
import org.apache.flink.api.scala._
import org.apache.flink.ml.common._
import org.apache.flink.ml.math._
-import org.apache.flink.ml.optimization.IterativeSolver.{ConvergenceThreshold, Iterations, LearningRate}
+import org.apache.flink.ml.optimization.IterativeSolver._
+import org.apache.flink.ml.optimization.LearningRateMethod.LearningRateMethodTrait
import org.apache.flink.ml.optimization.Solver._
import org.apache.flink.ml._
@@ -43,6 +44,10 @@ import org.apache.flink.ml._
* [[IterativeSolver.ConvergenceThreshold]] when provided the algorithm will
* stop the iterations if the relative change in the value of the objective
* function between successive iterations is is smaller than this value.
+ * [[IterativeSolver.LearningRateMethodValue]] determines functional form of
+ * effective learning rate.
+ * [[IterativeSolver.Decay]] Used in some functional forms for determining
+ * effective learning rate.
*/
abstract class GradientDescent extends IterativeSolver {
@@ -61,7 +66,8 @@ abstract class GradientDescent extends IterativeSolver {
val lossFunction = parameters(LossFunction)
val learningRate = parameters(LearningRate)
val regularizationConstant = parameters(RegularizationConstant)
-
+ val learningRateMethod = parameters(LearningRateMethodValue)
+ val decay = parameters(Decay)
// Initialize weights
val initialWeightsDS: DataSet[WeightVector] = createInitialWeightsDS(initialWeights, data)
@@ -75,7 +81,9 @@ abstract class GradientDescent extends IterativeSolver {
numberOfIterations,
regularizationConstant,
learningRate,
- lossFunction)
+ lossFunction,
+ learningRateMethod,
+ decay)
case Some(convergence) =>
optimizeWithConvergenceCriterion(
data,
@@ -84,7 +92,9 @@ abstract class GradientDescent extends IterativeSolver {
regularizationConstant,
learningRate,
convergence,
- lossFunction
+ lossFunction,
+ learningRateMethod,
+ decay
)
}
}
@@ -96,7 +106,9 @@ abstract class GradientDescent extends IterativeSolver {
regularizationConstant: Double,
learningRate: Double,
convergenceThreshold: Double,
- lossFunction: LossFunction)
+ lossFunction: LossFunction,
+ learningRateMethod: LearningRateMethodTrait,
+ decay: Double)
: DataSet[WeightVector] = {
// We have to calculate for each weight vector the sum of squared residuals,
// and then sum them and apply regularization
@@ -119,7 +131,9 @@ abstract class GradientDescent extends IterativeSolver {
previousWeightsDS,
lossFunction,
regularizationConstant,
- learningRate)
+ learningRate,
+ learningRateMethod,
+ decay)
val currentLossSumDS = calculateLoss(dataPoints, currentWeightsDS, lossFunction)
@@ -148,11 +162,19 @@ abstract class GradientDescent extends IterativeSolver {
numberOfIterations: Int,
regularizationConstant: Double,
learningRate: Double,
- lossFunction: LossFunction)
+ lossFunction: LossFunction,
+ optimizationMethod: LearningRateMethodTrait,
+ decay: Double)
: DataSet[WeightVector] = {
initialWeightsDS.iterate(numberOfIterations) {
weightVectorDS => {
- SGDStep(data, weightVectorDS, lossFunction, regularizationConstant, learningRate)
+ SGDStep(data,
+ weightVectorDS,
+ lossFunction,
+ regularizationConstant,
+ learningRate,
+ optimizationMethod,
+ decay)
}
}
}
@@ -168,7 +190,9 @@ abstract class GradientDescent extends IterativeSolver {
currentWeights: DataSet[WeightVector],
lossFunction: LossFunction,
regularizationConstant: Double,
- learningRate: Double)
+ learningRate: Double,
+ learningRateMethod: LearningRateMethodTrait,
+ decay: Double)
: DataSet[WeightVector] = {
data.mapWithBcVariable(currentWeights){
@@ -190,8 +214,11 @@ abstract class GradientDescent extends IterativeSolver {
BLAS.scal(1.0/count, weights)
val gradient = WeightVector(weights, intercept/count)
-
- val effectiveLearningRate = learningRate/Math.sqrt(iteration)
+ val effectiveLearningRate = learningRateMethod.calculateLearningRate(
+ learningRate,
+ iteration,
+ regularizationConstant,
+ decay)
val newWeights = takeStep(
weightVector.weights,
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/optimization/Solver.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/optimization/Solver.scala
index 39a031f8a5806..c234b27368c99 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/optimization/Solver.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/optimization/Solver.scala
@@ -23,6 +23,7 @@ import org.apache.flink.ml.common._
import org.apache.flink.ml.math.{SparseVector, DenseVector}
import org.apache.flink.api.scala._
import org.apache.flink.ml.optimization.IterativeSolver._
+import org.apache.flink.ml.optimization.LearningRateMethod.LearningRateMethodTrait
/** Base class for optimization algorithms
*
@@ -105,7 +106,7 @@ object Solver {
}
case object RegularizationConstant extends Parameter[Double] {
- val defaultValue = Some(0.0) // TODO(tvas): Properly initialize this, ensure Parameter > 0!
+ val defaultValue = Some(0.0001) // TODO(tvas): Properly initialize this, ensure Parameter > 0!
}
}
@@ -131,6 +132,16 @@ abstract class IterativeSolver() extends Solver {
parameters.add(ConvergenceThreshold, convergenceThreshold)
this
}
+
+ def setLearningRateMethod(learningRateMethod: LearningRateMethodTrait): this.type = {
+ parameters.add(LearningRateMethodValue, learningRateMethod)
+ this
+ }
+
+ def setDecay(decay: Double): this.type = {
+ parameters.add(Decay, decay)
+ this
+ }
}
object IterativeSolver {
@@ -149,4 +160,82 @@ object IterativeSolver {
case object ConvergenceThreshold extends Parameter[Double] {
val defaultValue = None
}
+
+ case object LearningRateMethodValue extends Parameter[LearningRateMethodTrait] {
+ val defaultValue = Some(LearningRateMethod.Default)
+ }
+
+ case object Decay extends Parameter[Double] {
+ val defaultValue = Some(0.0)
+ }
+}
+
+object LearningRateMethod {
+
+ sealed trait LearningRateMethodTrait extends Serializable {
+ def calculateLearningRate(
+ initialLearningRate: Double,
+ iteration: Int,
+ regularizationConstant: Double,
+ decay: Double)
+ : Double
+ }
+
+ object Default extends LearningRateMethodTrait {
+ override def calculateLearningRate(
+ initialLearningRate: Double,
+ iteration: Int,
+ regularizationConstant: Double,
+ decay: Double)
+ : Double = {
+ initialLearningRate / Math.sqrt(iteration)
+ }
+ }
+
+ object Constant extends LearningRateMethodTrait {
+ override def calculateLearningRate(
+ initialLearningRate: Double,
+ iteration: Int,
+ regularizationConstant: Double,
+ decay: Double)
+ : Double = {
+ initialLearningRate
+ }
+ }
+
+ object Bottou extends LearningRateMethodTrait {
+ override def calculateLearningRate(
+ initialLearningRate: Double,
+ iteration: Int,
+ regularizationConstant: Double,
+ decay: Double)
+ : Double = {
+ 1 /
+ (regularizationConstant *
+ (1 / (initialLearningRate * regularizationConstant) + iteration - 1))
+ }
+ }
+
+ object InvScaling extends LearningRateMethodTrait {
+ override def calculateLearningRate(
+ initialLearningRate: Double,
+ iteration: Int,
+ regularizationConstant: Double,
+ decay: Double)
+ : Double = {
+ 1 / Math.pow(iteration, decay)
+ }
+ }
+
+ object Xu extends LearningRateMethodTrait {
+ override def calculateLearningRate(
+ initialLearningRate: Double,
+ iteration: Int,
+ regularizationConstant: Double,
+ decay: Double)
+ : Double = {
+ initialLearningRate *
+ Math.pow(1 + regularizationConstant * initialLearningRate * iteration, -decay)
+ }
+ }
}
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/optimization/GradientDescentITSuite.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/optimization/GradientDescentITSuite.scala
index d84d017a720a2..8ed1c6e2a10be 100644
--- a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/optimization/GradientDescentITSuite.scala
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/optimization/GradientDescentITSuite.scala
@@ -45,7 +45,7 @@ class GradientDescentITSuite extends FlatSpec with Matchers with FlinkTestBase {
.setIterations(2000)
.setLossFunction(lossFunction)
.setRegularizationConstant(0.3)
-
+
val inputDS: DataSet[LabeledVector] = env.fromCollection(regularizationData)
val weightDS = sgd.optimize(inputDS, None)
@@ -240,6 +240,38 @@ class GradientDescentITSuite extends FlatSpec with Matchers with FlinkTestBase {
weight0NoConvergence should not be (weight0Early +- 0.1)
}
+ it should "come up with similar parameter estimates with xu step-size strategy" in {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+
+ env.setParallelism(2)
+
+ val lossFunction = GenericLossFunction(SquaredLoss, LinearPrediction)
+
+ val sgd = SimpleGradientDescent()
+ .setStepsize(1.0)
+ .setIterations(800)
+ .setLossFunction(lossFunction)
+ .setLearningRateMethod(LearningRateMethod.Xu)
+ .setDecay(-0.75)
+
+ val inputDS = env.fromCollection(data)
+ val weightDS = sgd.optimize(inputDS, None)
+
+ val weightList: Seq[WeightVector] = weightDS.collect()
+
+ weightList.size should equal(1)
+
+ val weightVector: WeightVector = weightList.head
+
+ val weights = weightVector.weights.asInstanceOf[DenseVector].data
+ val weight0 = weightVector.intercept
+
+ expectedWeights zip weights foreach {
+ case (expectedWeight, weight) =>
+ weight should be (expectedWeight +- 0.1)
+ }
+ weight0 should be (expectedWeight0 +- 0.1)
+ }
// TODO: Need more corner cases, see sklearn tests for SGD linear model
}
From 26ee2437fa6872d3d5ec92b1271865692650a8ef Mon Sep 17 00:00:00 2001
From: Trevor Grant
Date: Thu, 21 Jan 2016 21:57:51 -0600
Subject: [PATCH 2/3] [FLINK-1994][ml] Added 4 new effective learning rate
methods
---
docs/libs/ml/optimization.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/docs/libs/ml/optimization.md b/docs/libs/ml/optimization.md
index d011fbd4ae776..69b566d1e28af 100644
--- a/docs/libs/ml/optimization.md
+++ b/docs/libs/ml/optimization.md
@@ -310,7 +310,7 @@ Where:
The step size is constant throughout the learning task.
-
+
|
$\eta_j = \eta_0$ |
LearningRateMethod.Constant |
From 721943abcf5cf9f2aaadd74b5bb474e19881ee76 Mon Sep 17 00:00:00 2001
From: Trevor Grant
Date: Thu, 21 Jan 2016 22:03:59 -0600
Subject: [PATCH 3/3] [FLINK-1994][ml] Add different gain calulation schemes to
SGD
---
docs/libs/ml/optimization.md | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/docs/libs/ml/optimization.md b/docs/libs/ml/optimization.md
index 69b566d1e28af..8a4909659d359 100644
--- a/docs/libs/ml/optimization.md
+++ b/docs/libs/ml/optimization.md
@@ -308,9 +308,9 @@ Where:
| Constant |
-
+
The step size is constant throughout the learning task.
-
+
|
$\eta_j = \eta_0$ |
LearningRateMethod.Constant |