Weighted loss function #37
Comments
Hi David, What is the equation for the whole objective function you want to optimize? If you want to use the usual regularizers for fused lasso and are only changing the loss function, then you can solve that fairly easily. You can use a proximal gradient or FISTA algorithm: you provide a way of computing the gradient of your loss, and proxTV provides the tools for the Total-Variation + L1 regularizer.
Hi, First I should say that I'm not an expert with these algorithms. The full function has the usual regularizers and looks like: min ||Wi(Xi-Yi)||^2 + lambda Wi(Xi-Xj) in two dimensions. We are currently using a graph fused lasso package from https://github.com/tansey/gfl, which works well and can be used with different loss functions. The problem is that the algorithm used there, ADMM, seems to converge slowly, and we are struggling to make it faster because we don't see how to parallelize it. Thanks David
ADMM is indeed quite slow unless you find good stepsize parameters, which is no easy task in general. But this problem should be solvable through a combination of FISTA and the 2D Total Variation proximity solver provided in proxTV.

FISTA is a general algorithm for minimizing the sum of a differentiable function (your loss) and a non-differentiable function. For the differentiable part you need to provide a way to compute its gradient, which at first glance should be simple for your problem. For the non-differentiable part you need to provide the prox operator, but you already have that coded in proxTV!

About parallelization: FISTA essentially alternates between calls to the gradient function and the prox operator you provide. The 2D Total Variation solver here already admits parallelization, so half of your problem is already solved. FISTA is not implemented in this library, though I created some scripts using it for the related paper. I'm attaching a general implementation of FISTA in MATLAB. As you will see, you need to provide gradient and prox functions appropriate for your problem, as discussed above. Hope this helps!
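To make the gradient/prox alternation concrete, here is a minimal Python sketch of FISTA for the weighted loss. It is not the MATLAB attachment, only an illustration under stated assumptions: the weights are taken to act elementwise, so f(X) = sum_i (W_i (X_i - Y_i))^2 with gradient 2 W_i^2 (X_i - Y_i) and Lipschitz constant 2 max_i W_i^2; the penalty is assumed unweighted; and the names `fista`, `weighted_loss_grad` and `soft_threshold` are hypothetical. To keep the example runnable without proxTV, a plain L1 soft-threshold stands in for the prox. In the real problem it would be replaced by the 2D Total Variation prox from proxTV, called with weight `step * lam`.

```python
import numpy as np


def fista(grad, prox, x0, step, n_iter=500):
    """Minimize f(x) + g(x) with FISTA.

    grad(x)       -> gradient of the smooth part f
    prox(v, step) -> prox operator of step * g, evaluated at v
    step          -> 1 / L, with L a Lipschitz constant of grad
    """
    x = x0.copy()
    z = x0.copy()  # extrapolated (momentum) point
    t = 1.0
    for _ in range(n_iter):
        x_new = prox(z - step * grad(z), step)           # gradient step, then prox
        t_new = (1.0 + np.sqrt(1.0 + 4.0 * t * t)) / 2.0
        z = x_new + ((t - 1.0) / t_new) * (x_new - x)    # momentum update
        x, t = x_new, t_new
    return x


def weighted_loss_grad(W, Y):
    """Gradient of f(X) = sum_i (W_i * (X_i - Y_i))**2 (elementwise weights, assumed)."""
    W2 = W ** 2
    return lambda X: 2.0 * W2 * (X - Y)


def soft_threshold(lam):
    """Stand-in prox for lam * ||X||_1. Swap in a 2D TV prox for the fused lasso,
    e.g. prox_tv.tv1_2d(V, step * lam), assuming proxTV's Python bindings."""
    return lambda V, step: np.sign(V) * np.maximum(np.abs(V) - step * lam, 0.0)


rng = np.random.default_rng(0)
Y = rng.normal(size=(8, 8))              # toy 2D signal
W = rng.uniform(0.5, 2.0, size=(8, 8))   # toy per-pixel weights
lam = 0.3

step = 1.0 / (2.0 * np.max(W ** 2))      # 1 / L with L = 2 * max(W_i^2)
X_hat = fista(weighted_loss_grad(W, Y), soft_threshold(lam), np.zeros_like(Y), step)
```

With the stand-in prox the problem separates per pixel, so the result can be checked against the closed form sign(Y) * max(|Y| - lam / (2 W^2), 0). Only the `prox` argument changes when moving to the Total Variation penalty; the weighted loss enters solely through `grad` and `step`.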
It definitely helps a lot.
Hi,
We are trying to solve a fused lasso problem with a slightly different loss function. In our case the loss function is weighted:
min ||Wi(Xi-Yi)||^2 + lambda ....
To adapt proxTV to this case, do you think it would be feasible and simple for us to build a new proximity operator based on the function tautString_TV1_Weighted?
Regards
David