
[SPARK-1543][MLlib] Add ADMM for solving Lasso (and elastic net) problem #458

Closed
wants to merge 11 commits into from

Conversation

coderxiang
Contributor

@coderxiang coderxiang commented Apr 20, 2014

This PR introduces the Alternating Direction Method of Multipliers (ADMM) for solving the Lasso (in fact, elastic net) problem in MLlib.

ADMM is capable of solving a class of composite minimization problems in a distributed way. Specifically, for Lasso (L1 regularization only) or elastic net (both L1 and L2 regularization), each iteration requires solving independent systems of linear equations on each partition, followed by a soft-thresholding operation on the driver machine. Unlike SGD, it is a deterministic algorithm (except for the random partitioning). Details can be found in S. Boyd's paper.

The linear algebra operations mainly rely on the Breeze library; in particular, breeze.linalg.cholesky is used to perform a Cholesky decomposition on each partition to solve the linear system.
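For illustration, here is a minimal Breeze-based sketch of the two building blocks described above: the partition-local solve that reuses a precomputed Cholesky factor, and the driver-side soft-thresholding. The helper names xUpdate and softThreshold are hypothetical, not identifiers from this PR, and follow Boyd's notation (z is the consensus variable, u the scaled dual, rho the penalty parameter).

import breeze.linalg.{DenseMatrix, DenseVector}

// Partition-local x-update: solve (A'A + rho*I) x = A'b + rho*(z - u),
// reusing a precomputed lower-triangular Cholesky factor `chol` of (A'A + rho*I).
def xUpdate(chol: DenseMatrix[Double], atb: DenseVector[Double],
            z: DenseVector[Double], u: DenseVector[Double],
            rho: Double): DenseVector[Double] = {
  val q = atb + (z - u) * rho
  chol.t \ (chol \ q) // solve with the factor, then with its transpose
}

// Driver-side soft-thresholding: the proximal operator of kappa * ||.||_1,
// applied to the averaged (x + u) to obtain the new consensus variable z.
def softThreshold(v: DenseVector[Double], kappa: Double): DenseVector[Double] =
  v.map(x => math.signum(x) * math.max(math.abs(x) - kappa, 0.0))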

I tried to follow the organization of the existing Lasso implementation. However, since ADMM is also a good fit for similar optimization problems, e.g., (sparse) logistic regression, it may be worth reorganizing and putting ADMM into a separate module.

@mengxr
Contributor

@mengxr mengxr commented Apr 21, 2014

Jenkins, test this please.

@mengxr
Contributor

@mengxr mengxr commented Apr 21, 2014

@coderxiang Did you compare ADMM and SGD/L-BFGS implemented in MLlib on some large datasets?


val updatedX = if (row >= col) {
chol.t \ (chol \ q)
}else {
Contributor

@yinxusen yinxusen Apr 22, 2014

add a space after }

@yinxusen
Contributor

@yinxusen yinxusen commented Apr 22, 2014

@coderxiang It would be better to have a detailed test, as @mengxr said. I can help with the testing if you need.

@coderxiang
Contributor Author

@coderxiang coderxiang commented Apr 22, 2014

@yinxusen Thanks for the comments. I'm running some comparisons between SGD and ADMM right now and will try to post them later. It would also be great if you could provide further testing.

@etrain
Contributor

@etrain etrain commented Apr 22, 2014

Hey, this looks awesome! One high-level issue I see is that the ADMM optimizer has knowledge of the loss function it's trying to minimize embedded in it. ADMM is much more general than that and is nicely scalable. Can we abstract out the general ADMM computation pattern, in a spirit similar to what we've done with GradientDescent, and have Lasso, SVM, etc. done with ADMM as subclasses that implement a specialized "compute" function (or something)?
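A minimal sketch of what such an abstraction might look like, in the spirit of GradientDescent and Updater; the names below (ADMMUpdater, localUpdate, globalUpdate) are hypothetical, not existing MLlib API:

import breeze.linalg.DenseVector

// Hypothetical interface: a loss-specific implementation only supplies the
// partition-local subproblem solver and the global (consensus) proximal step.
trait ADMMUpdater[LocalState] extends Serializable {
  // Solve the subproblem on one partition, given the current consensus z,
  // scaled dual u, and penalty rho; returns the updated local state and the new local x.
  def localUpdate(state: LocalState, z: DenseVector[Double], u: DenseVector[Double],
                  rho: Double): (LocalState, DenseVector[Double])
  // Turn the averaged (x + u) into the new consensus variable z,
  // e.g. soft-thresholding for Lasso, L2 shrinkage for ridge regression.
  def globalUpdate(xBarPlusUBar: DenseVector[Double], rho: Double): DenseVector[Double]
}

A Lasso subclass would then implement localUpdate with the Cholesky-based solve and globalUpdate with soft-thresholding, while a generic driver loop handles the aggregation and dual update.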

@coderxiang
Contributor Author

@coderxiang coderxiang commented Apr 22, 2014

@mengxr @yinxusen Here are some comparison results between ADMM and SGD. These results are only for these particular parameter settings and datasets and should not be over-generalized.

The experiments are carried out on a small dataset (200 by 200 design matrix) and a large one (10k by 10k design matrix), both randomly generated. Only L1 regularization is employed, with the regularization parameter set to 5 for the small dataset and 20 for the large one. At most 500 iterations are run.

Method | loss (small) | loss (large) | AUC (small) | AUC (large)
------ | ------------ | ------------ | ----------- | -----------
SGD | 96.42 | NaN | 0.8838 | NaN
ADMM | 93.54 | 4008.55 | 0.8771 | 0.9464
FISTA (local lasso solver) | 93.52 | 4009.88 | 0.8767 | 0.9481

Each aggregate step in SGD takes 11s on average, while each reduce step in ADMM takes 8s, as shown on the web UI. The ROC curves are available from here.

I tried two parameter settings (stepsize=0.05/0.01, iter=500/100) for SGD; both seem to encounter convergence problems. The results are shown as NaN. Maybe we can discuss this separately.

@coderxiang
Contributor Author

@coderxiang coderxiang commented Apr 22, 2014

@etrain that's a good point if ADMM implementations of other algorithms are going to be added to MLlib. Fortunately, for lasso, ridge regression, and sparse logistic regression, the computation on the driver is pretty similar; all that's needed is to write separate local optimization programs.

@yinxusen
Contributor

@yinxusen yinxusen commented Apr 22, 2014

Cool, could you share your data-generator code with me and let me take care of the NaN problem? Also, could you provide the total running time of SGD and ADMM when they reach a similar loss?

@coderxiang
Contributor Author

@coderxiang coderxiang commented Apr 22, 2014

@yinxusen Just sent the data/code to you. For the running time, from the web UI it appears to be roughly (nIter * average reduce/aggregate time). The aggregate time for SGD remains almost identical in each iteration, while ADMM's reduce time varies and often takes longer in later iterations. Since ADMM has an early termination criterion, it usually takes fewer than 500 iterations to converge. I may include these details in the following evaluation.

@etrain
Contributor

@etrain etrain commented Apr 22, 2014

Right - the pattern is virtually identical except for an update function call. Can we abstract this away so that we can deliver the first 3 ADMM algorithms with a few lines of code each, making it straightforward to add new ADMM-based algorithms?

@yinxusen
Contributor

@yinxusen yinxusen commented Apr 23, 2014

@coderxiang I did some experiments on your dataset.

  • For MLlib, you should first rewrite your labels from {+1, -1} to {+1, 0}. Reference here.
  • For Lasso, you need to preprocess your dataset so that it has zero mean and unit norm. Reference here. @mengxr just removed the former preprocessing because it was not elegant.

I opened a JIRA issue to explain why Infinity occurs. IMHO, I prefer rewriting this line as

brzAxpy(2.0 * diff / weights.size, brzData, cumGradient.toBreeze)

to average the gradient, since the gradient is used for updating each single element of weights. But I am not sure of that; maybe @mengxr and @etrain could give us some suggestions.

@coderxiang
Contributor Author

@coderxiang coderxiang commented Apr 23, 2014

@yinxusen thanks for the testing.

  • For lasso, the labels actually do not matter since it is a regression model. I'll take care of this in logistic regression.
  • The normalization is a good point. However, it may be better to provide a separate procedure for this, since normalization may be required by more than just Lasso.

For the convergence issue of SGD, I guess a more sophisticated SGD algorithm (like dual averaging) may solve this problem.

@yinxusen
Contributor

@yinxusen yinxusen commented Apr 24, 2014

I preprocessed your data to have zero mean and unit norm, but Lasso still performs poorly, with Infinity results or rising losses.

Since Lasso is a regression method, maybe we should use regression data to test it; the current classification dataset is not suitable. @coderxiang Do you have any regression dataset?

More sophisticated SGD variants are worth considering.

import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext._
import org.apache.spark.annotation.DeveloperApi

Contributor

@yinxusen yinxusen Apr 26, 2014

remove extra line

} else {
(q :/ penalty) - ((design.t *(chol.t\(chol\(design * q)))) :/ (penalty * penalty))
}
Iterator((localData._1, (updatedX, updatedU)))
Contributor

@yinxusen yinxusen Apr 26, 2014

localData._1 is not modified in the function. How about using a broadcast variable instead of closure serialization? The repeated, unnecessary ser/de slows down the procedure, especially when the local design matrix is large.
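For example, here is a rough sketch of the local step with the global iterates shipped as broadcast variables, assuming the per-partition state already holds the Cholesky factor and A'b; the names (localStep, parts) and the exact state layout are hypothetical:

import breeze.linalg.{DenseMatrix, DenseVector}
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD

// Broadcast the read-only iterates z and u once per iteration instead of
// serializing them into every task closure along with the partition data.
def localStep(sc: SparkContext,
              parts: RDD[(Int, (DenseMatrix[Double], DenseVector[Double]))], // (Cholesky factor, A'b) per partition
              z: DenseVector[Double], u: DenseVector[Double],
              rho: Double): RDD[(Int, DenseVector[Double])] = {
  val zB = sc.broadcast(z)
  val uB = sc.broadcast(u)
  parts.mapValues { case (chol, atb) =>
    val q = atb + (zB.value - uB.value) * rho
    chol.t \ (chol \ q)
  }
}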

@coderxiang
Contributor Author

@coderxiang coderxiang commented May 1, 2014

@yinxusen @mengxr I updated the local solver and the running time has improved substantially. On a cluster with 4 workers, the program converges in 2.3 min for 10k * 10k data and in 19 min for 50k * 10k. Most of the computation happens in the preprocessing step of factorizing each design matrix; the computational cost inside each ADMM iteration is now minor.
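As a rough sketch of the caching idea mentioned above (hypothetical names, assuming the penalty rho is fixed across iterations and rows >= cols on each partition): factor the Gram matrix once per partition and cache the factor, so later iterations only perform solves against it.

import breeze.linalg.{DenseMatrix, DenseVector, cholesky}
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD

// One-time preprocessing: factor (A'A + rho*I) on each partition and cache the result,
// so every subsequent ADMM iteration only performs cheap solves against the factor.
def precomputeFactors(data: RDD[(Int, (DenseMatrix[Double], DenseVector[Double]))], // (design matrix A, A'b)
                      rho: Double): RDD[(Int, (DenseMatrix[Double], DenseVector[Double]))] = {
  data.mapValues { case (design, atb) =>
    val gram = design.t * design + DenseMatrix.eye[Double](design.cols) * rho
    (cholesky(gram), atb)
  }.cache()
}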

@debasish83

@debasish83 debasish83 commented May 4, 2014

@coderxiang ADMM should be compared against the BFGS-based classification/regression that @dbtsai is working on... ADMM should reduce the network transfer from worker to master and thus improve the runtime without affecting the misclassification error... are you planning to do that comparison?

@dbtsai
Member

@dbtsai dbtsai commented May 5, 2014

L-BFGS is not good for L1 problems. I'm working on and preparing a benchmark with OWL-QN, the L-BFGS variant for L1, which is ideal to compare with ADMM.

@debasish83

@debasish83 debasish83 commented May 5, 2014

It depends on how you solve L1 with L-BFGS...

OWL-QN for L1 is definitely a solution...

You can also replace L1 with a soft-max approximation, but then you have to be careful with the schedule of the soft-max smoothness...

I think just picking OWL-QN for L1 (as it is implemented in Breeze) and comparing it against ADMM will be good...


@ajtulloch
Contributor

@ajtulloch ajtulloch commented May 11, 2014

I agree with @etrain; it's possible to abstract out the ADMM optimisation routine such that it's trivial to implement L1 logistic regression, lasso, SVMs, etc. with very few additional lines of code. I implemented that for Spark a few months ago (albeit a naive, unperformant, and untested implementation).

If you wanted to see an alternative way of structuring this diff, my code is available at ajtulloch/spark/SPARK-1794-GenericADMM.

pwendell added a commit to pwendell/spark that referenced this issue May 12, 2014
Updated java API docs for streaming, along with very minor changes in the code examples.

Docs updated for:
Scala: StreamingContext, DStream, PairDStreamFunctions
Java: JavaStreamingContext, JavaDStream, JavaPairDStream

Example updated:
JavaQueueStream: Not use deprecated method
ActorWordCount: Use the public interface the right way.
@coderxiang coderxiang closed this Jul 30, 2014
@kellrott
Contributor

@kellrott kellrott commented Sep 6, 2014

Is there any work still happening with ADMM in Spark? This patch was rejected and the JIRA issue was closed. Has everyone given up?

andrewor14 pushed a commit to andrewor14/spark that referenced this issue Jan 8, 2015
Updated java API docs for streaming, along with very minor changes in the code examples.

Docs updated for:
Scala: StreamingContext, DStream, PairDStreamFunctions
Java: JavaStreamingContext, JavaDStream, JavaPairDStream

Example updated:
JavaQueueStream: Not use deprecated method
ActorWordCount: Use the public interface the right way.
(cherry picked from commit 256a355)

Signed-off-by: Patrick Wendell <pwendell@gmail.com>
@luxx92

@luxx92 luxx92 commented Jun 29, 2016

Hello, could you share your data-generator code with me? I'm working on this project at school and hope to compare my results with your dataset. Thank you.

@debasish83

@debasish83 debasish83 commented Jun 29, 2016

ADMM is already implemented as part of Breeze's proximal NonlinearMinimizer, where the ADMM solver stays on the master and the gradient calculator is plugged in the same way as Breeze's LBFGS/OWLQN... I did not open a PR since OWLQN was chosen for L1 logistic regression...

markhamstra pushed a commit to markhamstra/spark that referenced this issue Nov 7, 2017
j-esse pushed a commit to j-esse/spark that referenced this issue Jan 24, 2019
arjunshroff pushed a commit to arjunshroff/spark that referenced this issue Nov 24, 2020