
[SPARK-1543][MLlib] Add ADMM for solving Lasso (and elastic net) problem #458

Closed
wants to merge 11 commits

Conversation

@coderxiang (Contributor) commented Apr 20, 2014

This PR introduces the Alternating Direction Method of Multipliers (ADMM) for solving Lasso (elastic net, in fact) in MLlib.

ADMM can solve a class of composite minimization problems in a distributed way. Specifically, for Lasso (L1 regularization only) or elastic net (both L1 and L2 regularization), each iteration requires solving an independent system of linear equations on each partition, followed by a soft-thresholding operation on the driver machine. Unlike SGD, it is a deterministic algorithm (except for the random partitioning). Details can be found in S. Boyd's paper.
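
For reference, the consensus-form updates for Lasso from Boyd et al., "Distributed Optimization and Statistical Learning via the Alternating Direction Method of Multipliers" (Sec. 8.2), are sketched below; A_i and b_i denote the design-matrix block and labels on partition i, N the number of partitions, \rho the ADMM penalty, and \lambda the L1 parameter. This is the generic recipe, not necessarily the exact parameterization used in this patch:

x_i^{k+1} = (A_i^T A_i + \rho I)^{-1} (A_i^T b_i + \rho (z^k - u_i^k))    (linear solve on partition i)
z^{k+1}   = S_{\lambda / (\rho N)}(\bar{x}^{k+1} + \bar{u}^k)             (soft-thresholding on the driver)
u_i^{k+1} = u_i^k + x_i^{k+1} - z^{k+1}                                   (dual update on partition i)

where S_\kappa(a) = sign(a) * max(|a| - \kappa, 0) is applied element-wise; in the elastic-net case the z-update additionally shrinks by a factor involving the L2 parameter.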

The linear algebra operations mainly rely on the Breeze library; in particular, breeze.linalg.cholesky is used to compute a Cholesky decomposition on each partition in order to solve the linear system.
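
A minimal sketch of that per-partition solve with Breeze, using illustrative names (the normal matrix A^T A + \rho I would normally be formed once and its Cholesky factor cached across iterations):

import breeze.linalg.{DenseMatrix, DenseVector, cholesky, diag}

// Solve (A'A + rho * I) x = q for one partition via a Cholesky factorization.
def localSolve(design: DenseMatrix[Double], q: DenseVector[Double], rho: Double): DenseVector[Double] = {
  val m = design.t * design + diag(DenseVector.fill(design.cols)(rho))  // symmetric positive definite
  val chol = cholesky(m)  // lower-triangular L with L * L.t == m
  chol.t \ (chol \ q)     // forward solve, then backward solve
}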

I tried to follow the organization of the existing Lasso implementation. However, since ADMM is also a good fit for similar optimization problems, e.g., (sparse) logistic regression, it may be worth reorganizing and putting ADMM into a separate section.

lebesgue added some commits Apr 20, 2014

@mengxr (Contributor) commented Apr 21, 2014

Jenkins, test this please.

@mengxr (Contributor) commented Apr 21, 2014

@coderxiang Did you compare ADMM and SGD/L-BFGS implemented in MLlib on some large datasets?


val updatedX = if (row >= col) {
chol.t \ (chol \ q)
}else {

@yinxusen (Contributor) commented Apr 22, 2014

add a space after }

while (iter <= numIterations && !minorChange) {
val zBroadcast = z
def localUpdate(
it: Iterator[((BDV[Double], BDM[Double], BDM[Double]), (BDV[Double], BDV[Double]))]):

@yinxusen (Contributor) commented Apr 22, 2014

It is better to put the ): on the next line, with the return type.

var l2RegParam: Double,
var penalty: Double)
extends GeneralizedLinearAlgorithm[LassoModel] with Serializable {

@yinxusen (Contributor) commented Apr 22, 2014

A single blank line is good enough. Same below.

* the number of features in the data.
*/
def train(
input: RDD[LabeledPoint],

@yinxusen (Contributor) commented Apr 22, 2014

Use a 4-character indent in the function definition.

* @param penalty ADMM penalty of the constraint
*/
def train(
input: RDD[LabeledPoint],

@yinxusen (Contributor) commented Apr 22, 2014

4-char indent

}

def main(args: Array[String]) {
if (args.length != 5) {

@yinxusen (Contributor) commented Apr 22, 2014

length of args is 7.

val testRDD = sc.parallelize(testData, 2).cache()


@yinxusen (Contributor) commented Apr 22, 2014

remove a blank line.


val testData = LinearDataGenerator.generateLinearInput(A, Array[Double](B, C), nPoints, 42)
.map { case LabeledPoint(label, features) =>
LabeledPoint(label, Vectors.dense(1.0 +: features.toArray))

@yinxusen (Contributor) commented Apr 22, 2014

another 2-char indent is better.

@yinxusen (Contributor) commented Apr 22, 2014

@coderxiang It would be better to have a detailed test, as @mengxr said. I can help with testing if you need.

@coderxiang (Contributor, Author) commented Apr 22, 2014

@yinxusen Thanks for the comments. I'm running some comparisons between SGD and ADMM right now and will try to post the results later. It would also be great if you could provide further testing.

@etrain (Contributor) commented Apr 22, 2014

Hey, this looks awesome! One high-level issue I see is that the ADMM optimizer has knowledge of the loss function it's trying to minimize embedded in it. ADMM is much more general than that and is nicely scalable - can we abstract the general ADMM computation pattern out, in a spirit similar to what we've done with GradientDescent, and have Lasso, SVM, etc. done with ADMM as subclasses that implement a specialized "compute" function (or something)?
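
Something along these lines might work; a rough sketch with illustrative names, not an API this PR defines:

import breeze.linalg.{DenseVector => BDV}

// A problem-agnostic ADMM skeleton: concrete problems (Lasso, SVM, sparse logistic
// regression, ...) supply only the partition-level x/u update and the driver-level
// z update, while iteration, aggregation and termination live in a shared optimizer.
trait ADMMProblem[LocalState] extends Serializable {
  def localUpdate(state: LocalState, z: BDV[Double]): LocalState       // x- and u-updates on a partition
  def contribution(state: LocalState): BDV[Double]                     // what the driver averages
  def globalUpdate(avg: BDV[Double], numPartitions: Int): BDV[Double]  // z-update, e.g. soft-thresholding for Lasso
}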

lebesgue added some commits Apr 22, 2014

@coderxiang (Contributor, Author) commented Apr 22, 2014

@mengxr @yinxusen Here are some comparison results between ADMM and SGD. These results are only for these particular parameter settings and data sets, so they should not be over-generalized.

The experiments are carried out on a small data set (200 by 200 design matrix) and a large one (10k by 10k design matrix), both randomly generated. Only L1 regularization is employed, with the regularization parameter set to 5 and 20, respectively. At most 500 iterations are run.

Method                     | loss (small) | loss (large) | AUC (small) | AUC (large)
SGD                        | 96.42        | NaN          | 0.8838      | NaN
ADMM                       | 93.54        | 4008.55      | 0.8771      | 0.9464
FISTA (local lasso solver) | 93.52        | 4009.88      | 0.8767      | 0.9481

Each aggregate step in SGD takes 11s on average, while each reduce step in ADMM requires 8s, as shown on the web UI. The ROC curves are available from here

I tried two parameter settings (stepsize=0.05/0.01, iter=500/100) for SGD; both seem to encounter convergence problems, and the results are shown as NaN. Maybe we can discuss this separately.

@coderxiang (Contributor, Author) commented Apr 22, 2014

@etrain that's a good point if ADMM implementations of other algorithms are going to be added to MLlib. Fortunately, for lasso, ridge regression and sparse logistic regression, the computation on the driver is pretty similar; all that is needed is to write separate local optimization programs.
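
For example, the Lasso-specific part of the driver step is essentially one soft-thresholding pass; a ridge or sparse-logistic variant would only swap out this step and the local subproblem. A minimal sketch with illustrative names:

import breeze.linalg.{DenseVector => BDV}

// Driver-side z-update for Lasso: element-wise soft-thresholding of xBar + uBar,
// with kappa = l1RegParam / (penalty * numPartitions).
def softThreshold(v: BDV[Double], kappa: Double): BDV[Double] =
  v.map(x => math.signum(x) * math.max(math.abs(x) - kappa, 0.0))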

@yinxusen (Contributor) commented Apr 22, 2014

Cool, could you share your data-generator code with me and let me take care of the NaN problem? Also, could you provide the total running time of SGD and ADMM when they reach a similar loss?

@coderxiang (Contributor, Author) commented Apr 22, 2014

@yinxusen Just sent the data/code to you. For the running time, from the web UI it appears to be just (nIter * average reduce/aggregate time). The aggregate time for SGD remains almost identical in each iteration, while ADMM's reduce time varies and often takes longer in later iterations. Since ADMM has an early termination criterion, it usually takes fewer than 500 iterations to converge. I may include these details in the following evaluation.

@etrain (Contributor) commented Apr 22, 2014

Right - the pattern is virtually identical except for an update function call. Can we abstract this away so that we can deliver the first three ADMM algorithms with a few lines of code, making it straightforward to add new versions of ADMM algorithms?

@yinxusen (Contributor) commented Apr 23, 2014

@coderxiang I did some experiments on your dataset.

  • For MLlib, you should first rewrite your labels from {+1, -1} to {+1, 0}. Reference here
  • For Lasso, you need to preprocess your dataset so that it has zero mean and unit norm. Reference here. @mengxr just removed the former preprocessing because it was not elegant.

I opened a JIRA issue to explain why Infinity occurs. IMHO, I prefer rewriting this line into

brzAxpy(2.0 * diff / weights.size, brzData, cumGradient.toBreeze)

to average the gradient, since the gradient is used to update each single element of weights. But I am not sure about that; maybe @mengxr and @etrain could give us some suggestions.
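
In isolation, the proposal amounts to scaling each example's least-squares gradient by 1 / weights.size before accumulating it. A standalone Breeze sketch of the idea (illustrative, not the actual MLlib gradient class):

import breeze.linalg.{DenseVector, axpy}

// Accumulate one example's least-squares gradient, scaled by 1 / numFeatures as suggested
// above, so a single example cannot inflate every coordinate of cumGradient at once.
def addScaledGradient(data: DenseVector[Double], label: Double,
    weights: DenseVector[Double], cumGradient: DenseVector[Double]): Double = {
  val diff = (data dot weights) - label
  axpy(2.0 * diff / weights.size, data, cumGradient)  // cumGradient += scale * data, in place
  diff * diff  // squared loss for this example
}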

@coderxiang (Contributor, Author) commented Apr 23, 2014

@yinxusen thanks for the testing.

  • For Lasso, the label actually does not matter since it is a regression model. I'll take care of this in logistic regression.
  • Normalization is a good point. However, it may be better to provide a separate procedure for it, since normalization may be required by more than just Lasso.

For the convergence issue of SGD, I guess a more sophisticated SGD algorithm (like dual averaging) may solve the problem.

@yinxusen (Contributor) commented Apr 24, 2014

I preprocessed your data to have zero mean and unit norm, but Lasso still performs poorly, with Infinity results or rising losses.

Since Lasso is a regression method, maybe we should use regression data to test it; the current classification dataset is not suitable. @coderxiang Do you have any regression dataset?

More sophisticated SGD variants are worth considering.

import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext._
import org.apache.spark.annotation.DeveloperApi

@yinxusen (Contributor) commented Apr 26, 2014

remove extra line

while (iter <= numIterations && !minorChange) {
val zBroadcast = z
def localUpdate(
it: Iterator[((BDV[Double], BDM[Double], BDM[Double]), (BDV[Double], BDV[Double]))]

@yinxusen (Contributor) commented Apr 26, 2014

improper indent

val zBroadcast = z
def localUpdate(
it: Iterator[((BDV[Double], BDM[Double], BDM[Double]), (BDV[Double], BDV[Double]))]
):Iterator[((BDV[Double], BDM[Double], BDM[Double]), (BDV[Double], BDV[Double]))] = {

@yinxusen (Contributor) commented Apr 26, 2014

same here

var minorChange: Boolean = false
while (iter <= numIterations && !minorChange) {
val zBroadcast = z
def localUpdate(

@yinxusen (Contributor) commented Apr 26, 2014

How about moving this function definition out of the while loop?

} else {
(q :/ penalty) - ((design.t *(chol.t\(chol\(design * q)))) :/ (penalty * penalty))
}
Iterator((localData._1, (updatedX, updatedU)))

@yinxusen (Contributor) commented Apr 26, 2014

localData._1 is not modified in the function. How about using a broadcast variable instead of serializing it repeatedly? The repeated, unnecessary ser/de slows down the procedure, especially when the local design matrix is large.
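
One way to cut the repeated serialization (a sketch under the assumption that the static per-partition data stays in a cached RDD and only z travels each iteration as a Spark broadcast variable; types and names are illustrative):

import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV}
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD

object BroadcastSketch {
  type Static = (BDV[Double], BDM[Double], BDM[Double])  // labels, design matrix, Cholesky factor
  type State = (BDV[Double], BDV[Double])                // (x, u) for the partition

  // One ADMM iteration: broadcast z once, then update only the mutable (x, u) part.
  def oneIteration(sc: SparkContext,
      data: RDD[(Static, State)],
      z: BDV[Double],
      localUpdate: (Static, State, BDV[Double]) => State): RDD[(Static, State)] = {
    val zBc = sc.broadcast(z)
    data.mapPartitions(_.map { case (s, st) => (s, localUpdate(s, st, zBc.value)) },
      preservesPartitioning = true)
  }
}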

it
}
}
dividedData = dividedData.mapPartitions(localUpdate).cache()

@yinxusen (Contributor) commented Apr 26, 2014

With broadcast instead of ser/de, the cache() can also be saved, since you will use reduce to aggregate the loss and zSum back to the driver.

private var l1RegParam: Double = 1.0
private var l2RegParam: Double = .0
private var penalty: Double = 10.0

@yinxusen (Contributor) commented Apr 26, 2014

remove extra line.

var iter = 1
var minorChange: Boolean = false
while (iter <= numIterations && !minorChange) {
val zBroadcast = z

@yinxusen (Contributor) commented Apr 26, 2014

Do you mean a real Spark broadcast variable? If not, zBroadcast could be removed, since = here is just a shallow copy.

val (lab, design, chol) = u
val residual = design * zBroadcast - lab
(0.5 * residual.dot(residual), (v._2 :/ penalty) + v._1)
}.reduce{case (x, y) => (x._1 + y._1, x._2 + y._2)}

@yinxusen (Contributor) commented Apr 26, 2014

Be careful with reduce; it will throw an exception if some of the partitions are empty. Using aggregate is better.
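
A minimal sketch of the aggregate-based alternative, assuming each record carries a (loss, vector) pair and numFeatures is known on the driver; names are illustrative:

import breeze.linalg.{DenseVector => BDV}
import org.apache.spark.rdd.RDD

// Sum (loss, vector) pairs with an explicit zero value, so empty partitions are harmless.
def sumUp(rdd: RDD[(Double, BDV[Double])], numFeatures: Int): (Double, BDV[Double]) =
  rdd.aggregate((0.0, BDV.zeros[Double](numFeatures)))(
    seqOp = { case ((l1, v1), (l2, v2)) => (l1 + l2, v1 + v2) },
    combOp = { case ((l1, v1), (l2, v2)) => (l1 + l2, v1 + v2) }
  )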

lebesgue added some commits Apr 26, 2014

@coderxiang (Contributor, Author) commented May 1, 2014

@yinxusen @mengxr I updated the local solver and the running time has improved substantially. On a cluster with 4 workers, the program converges in 2.3 min for the 10k * 10k data and 19 min for 50k * 10k. Most of the computation happens in the preprocessing step of factorizing each design matrix; the computational cost inside each ADMM iteration is now minor.

@debasish83 commented May 4, 2014

@coderxiang ADMM should be compared against the BFGS-based classification/regression that @dbtsai is working on... ADMM should reduce the network transfer from workers to the master and thus improve the runtime without affecting the misclassification error... are you planning to do that comparison?

@dbtsai (Member) commented May 5, 2014

L-BFGS is not good for L1 problems. I'm working on, and preparing to benchmark, the L-BFGS variant OWL-QN for L1, which is ideal to compare with ADMM.

@debasish83 commented May 5, 2014

It depends on how you solve L1 with L-BFGS...

OWL-QN for L1 is definitely a solution...

You can also replace L1 with a soft-max, but then you have to be careful with the schedule of the soft-max smoothness...

I think just picking OWL-QN for L1 (as it is implemented in Breeze) and comparing it against ADMM will be good...


@ajtulloch (Contributor) commented May 11, 2014

I agree with @etrain; it's possible to abstract out the ADMM optimisation routine such that it's trivial to implement L1-regularized logistic regression, lasso, SVMs, etc. with very few additional lines of code. I implemented that for Spark a few months ago (albeit as a naive, unperformant, and untested implementation).

If you wanted to see an alternative way of structuring this diff, my code is available at ajtulloch/spark/SPARK-1794-GenericADMM.

pwendell added a commit to pwendell/spark that referenced this pull request May 12, 2014

Merge pull request apache#458 from tdas/docs-update
Updated java API docs for streaming, along with very minor changes in the code examples.

Docs updated for:
Scala: StreamingContext, DStream, PairDStreamFunctions
Java: JavaStreamingContext, JavaDStream, JavaPairDStream

Example updated:
JavaQueueStream: Not use deprecated method
ActorWordCount: Use the public interface the right way.

@coderxiang closed this Jul 30, 2014

@kellrott (Contributor) commented Sep 6, 2014

Is there any work still happening with ADMM in Spark? This patch was rejected and the JIRA issue was closed. Has everyone given up?

@chouqin referenced this pull request Oct 28, 2014

andrewor14 pushed a commit to andrewor14/spark that referenced this pull request Jan 8, 2015

@luxx92 commented Jun 29, 2016

Hello, could you share your data-generator code with me? I'm working on this project at school and hope to compare the results with your dataset. Thank you.

@debasish83 commented Jun 29, 2016

ADMM is already implemented as part of Breeze's proximal NonlinearMinimizer, where the ADMM solver stays on the master and the gradient calculator is plugged in in a similar manner to Breeze's LBFGS/OWLQN... I did not open a PR since OWLQN was chosen for L1 logistic regression...

markhamstra pushed a commit to markhamstra/spark that referenced this pull request Nov 7, 2017

j-esse pushed a commit to j-esse/spark that referenced this pull request Jan 24, 2019
