-
-
Notifications
You must be signed in to change notification settings - Fork 57
/
admm.jl
68 lines (47 loc) · 1.58 KB
/
admm.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# Optimizer state for the ADMM-based Lasso solver. Mutable because
# `set_threshold!` rescales λ in place. Both parameters share the numeric type U.
mutable struct ADMM{U} <: AbstractOptimizer
    λ::U  # sparsification (ℓ1) parameter; stored as threshold * ρ
    ρ::U  # augmented Lagrangian penalty parameter
end
"""
    ADMM()
    ADMM(λ, ρ)

`ADMM` is an implementation of Lasso using the alternating direction methods of multipliers and loosely based on [this implementation](https://web.stanford.edu/~boyd/papers/admm/lasso/lasso.html).
`λ` is the sparsification parameter, `ρ` the augmented Lagrangian parameter.

# Example
```julia
opt = ADMM()
opt = ADMM(1e-1, 2.0)
```
"""
function ADMM()
    # Default: λ = 0.1, ρ = 1.0 (yields an ADMM{Float64}).
    return ADMM(0.1, 1.0)
end
"""
    set_threshold!(opt::ADMM, threshold)

Set the sparsification threshold of `opt` by storing `λ = threshold * ρ`,
so that `get_threshold(opt)` returns `threshold` again.
"""
set_threshold!(opt::ADMM, threshold) = (opt.λ = threshold * opt.ρ)
get_threshold(opt::ADMM) = opt.λ/opt.ρ
init(o::ADMM, A::AbstractArray, Y::AbstractArray) = A \ Y
init!(X::AbstractArray, o::ADMM, A::AbstractArray, Y::AbstractArray) = ldiv!(X, qr(A, Val(true)), Y)
#soft_thresholding(x::AbstractArray, t::T) where T <: Real = sign.(x) .* max.(abs.(x) .- t, zero(eltype(x)))
"""
    fit!(X, A, Y, opt::ADMM; maxiter = 1, convergence_error = eps())

Solve the ℓ1-regularized least-squares problem `Y ≈ A * X` in place using the
alternating direction method of multipliers, starting from the current `X`.
Iterates until the change between successive iterates drops below
`convergence_error` (2-norm) or `maxiter` iterations are reached. Entries of
`X` smaller in magnitude than the threshold `λ/ρ` are zeroed at the end.
Returns the number of iterations performed.
"""
function fit!(X::AbstractArray, A::AbstractArray, Y::AbstractArray, opt::ADMM; maxiter::Int64 = 1, convergence_error::T = eps()) where T <: Real
    n, m = size(A)
    # ℓ1 proximal operator (soft thresholding) with threshold λ/ρ.
    g = NormL1(get_threshold(opt))
    x̂ = copy(X)   # auxiliary primal variable (plain copy suffices; no nested structure)
    ŷ = zero(X)   # scaled dual variable
    # P = (ρI + A'A)⁻¹ computed via the matrix inversion lemma so only an
    # n×n (not m×m) pseudo-inverse is needed; c is the constant part of the x-update.
    P = I(m)/opt.ρ - (A' * pinv(opt.ρ*I(n) + A*A') * A)/opt.ρ
    c = P*(A'*Y)
    x_prev = copy(X)  # previous iterate for the convergence test
    iters = 0
    for _ in 1:maxiter
        iters += 1
        x̂ .= P*(opt.ρ .* X .- ŷ) .+ c   # x-update: ridge-regularized least squares
        prox!(X, g, x̂ .+ ŷ ./ opt.ρ)    # z-update: soft thresholding
        ŷ .= ŷ .+ opt.ρ .* (x̂ .- X)     # dual ascent on the scaled multiplier
        if norm(x_prev - X, 2) < convergence_error
            break
        else
            x_prev .= X
        end
    end
    # Clip near-zero coefficients to exact zeros for a sparse result.
    X[abs.(X) .< get_threshold(opt)] .= zero(eltype(X))
    return iters
end