Skip to content

Commit

Permalink
Added the coordinate descent method from scikit-learn, supporting spa…
Browse files Browse the repository at this point in the history
…rsity (#23)

* Added coordinate descent translation from scikit-learn; tests still to be written

* Completed the implementation, gives same results as the Python version (however with different convergence criterion)

* Fixed a small typing bug, added a working test

* Added reference

* Changed default alpha

* Fixed a small mistake in test

* Fixed a typing mistake in test

* Updated coorddesc.jl header

* Updates for Julia 1.1

* Removed Manifest, uses lazy transpose, removed unused lines

* Fixes to Project.toml

* Fixed interface for coordinate descent. Added testing of interf.jl
  • Loading branch information
vilim authored and ararslan committed Jul 24, 2019
1 parent 63bafcf commit affbab8
Show file tree
Hide file tree
Showing 9 changed files with 191 additions and 6 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Manifest.toml
16 changes: 16 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
name = "NMF"
uuid = "6ef6ca0d-6ad7-5ff6-b225-e928bfa0a386"
license = "MIT"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test"]
1 change: 1 addition & 0 deletions src/NMF.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ module NMF
include("multupd.jl")
include("projals.jl")
include("alspgrad.jl")
include("coorddesc.jl")

include("interf.jl")

Expand Down
146 changes: 146 additions & 0 deletions src/coorddesc.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
# Coordinate descent method, translated from the Python/Cython implementation
# in scikit-learn and modified to comply with the interfaces of the NMF package

# Original files
# https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/decomposition/nmf.py
# https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/decomposition/cdnmf_fast.pyx

# Original implementation authors:
# Vlad Niculae
# Lars Buitinck
# Mathieu Blondel <mathieu@mblondel.org>
# Tom Dupre la Tour

# Original license: BSD 3 clause

# Julia translation: Vilim Štih

# Reference: Cichocki, Andrzej, and Anh-Huy Phan. "Fast local algorithms for
# large scale nonnegative matrix and tensor factorizations."
# IEICE Transactions on Fundamentals of Electronics, Communications and
# Computer Sciences 92.3: 708-721, 2009.


"""
    CoordinateDescent{T}(; maxiter=100, verbose=false, tol=cbrt(eps(T)), α=0.0,
                         regularization=:both, l₁ratio=0.0, shuffle=false)

Settings for the coordinate descent NMF solver.

# Keywords
- `maxiter`, `verbose`, `tol`: handed unchanged to `nmf_skeleton!` by `solve!`.
- `α`: overall regularization strength.
- `l₁ratio`: share of `α` used as the L1 penalty; the remainder `1 - l₁ratio`
  is used as the L2 penalty.
- `regularization`: which factor is penalized. `:both` penalizes `W` and `H`,
  `:components` only `H`, `:transformation` only `W`.
- `shuffle`: visit the components in a random order on every sweep.

All numeric keywords are converted to `T` when the fields are stored.
"""
mutable struct CoordinateDescent{T}
    maxiter::Int
    verbose::Bool
    tol::T
    α::T
    l₁ratio::T
    regularization::Symbol
    shuffle::Bool

    function CoordinateDescent{T}(;
        maxiter::Integer=100,
        verbose::Bool=false,
        tol::Real=cbrt(eps(T)),
        α::Real=T(0.0),
        l₁ratio::Real=zero(T),
        regularization=:both,
        shuffle::Bool=false,
    ) where T
        return new{T}(maxiter, verbose, tol, α, l₁ratio, regularization, shuffle)
    end
end


"""
    solve!(alg::CoordinateDescent{T}, X, W, H)

Factorize `X` with coordinate descent, starting from the supplied `W` and `H`.

Builds the updater from the regularization and shuffle settings in `alg`, then
delegates to `nmf_skeleton!` with `alg.maxiter`, `alg.verbose` and `alg.tol`.
Returns whatever `nmf_skeleton!` returns.
"""
function solve!(alg::CoordinateDescent{T}, X, W, H) where {T}
    updater = CoordinateDescentUpd{T}(alg.α, alg.l₁ratio, alg.regularization, alg.shuffle)
    return nmf_skeleton!(updater, X, W, H, alg.maxiter, alg.verbose, alg.tol)
end


"""
    CoordinateDescentUpd{T}(α, l₁ratio, regularization, shuffle)

Updater for the coordinate descent solver. It stores the L1 and L2 penalty
weights for each factor, derived from the overall strength `α` and the mixing
ratio `l₁ratio`.

# Arguments
- `α::T`: regularization strength.
- `l₁ratio::T`: share of `α` assigned to the L1 penalty; `1 - l₁ratio` goes
  to the L2 penalty.
- `regularization::Symbol`: `:both` penalizes `W` and `H`, `:components` only
  `H`, `:transformation` only `W`. Any other symbol leaves both factors
  unpenalized.
- `shuffle::Bool`: whether the components are visited in random order.
"""
struct CoordinateDescentUpd{T} <: NMFUpdater{T}
    l₁W::T
    l₂W::T
    l₁H::T
    l₂H::T
    shuffle::Bool

    function CoordinateDescentUpd{T}(α::T, l₁ratio::T, regularization::Symbol,
                                     shuffle::Bool) where {T}
        # Per-factor strength; stays zero for a factor that is not regularized.
        αW = zero(T)
        αH = zero(T)

        if (regularization == :both) || (regularization == :components)
            αH = α
        end

        if (regularization == :both) || (regularization == :transformation)
            αW = α
        end

        return new{T}(αW * l₁ratio,
                      αW * (1 - l₁ratio),
                      αH * l₁ratio,
                      αH * (1 - l₁ratio),
                      shuffle)
    end
end

"""
    CoordinateDescentState{T}

Mutable bookkeeping for the coordinate descent iterations.

- `violation`: the violation value most recently stored by `update_wh!`.
- `violation_init`: reference violation used to normalize the reported
  objective, or `nothing` while no reference has been stored.
"""
mutable struct CoordinateDescentState{T}
    violation::T
    violation_init::Union{T, Nothing}
end

"""
    prepare_state(::CoordinateDescentUpd{T}, X, W, H)

Create the initial solver state: zero violation and no reference violation.
`X`, `W` and `H` are accepted for interface compatibility and are not read.
"""
function prepare_state(::CoordinateDescentUpd{T}, X, W, H) where {T}
    return CoordinateDescentState{T}(zero(T), nothing)
end
"""
    evaluate_objv(::CoordinateDescentUpd{T}, s::CoordinateDescentState, X, W, H)

Return the current violation divided by the stored reference violation. While
no reference has been stored (`violation_init === nothing`) the violation is
divided by `oneunit(T)`, i.e. reported unscaled. `X`, `W` and `H` are not read.
"""
function evaluate_objv(::CoordinateDescentUpd{T}, s::CoordinateDescentState, X, W, H) where {T}
    reference = s.violation_init
    if reference === nothing
        reference = oneunit(T)
    end
    return s.violation / reference
end

"Updates W only"
function _update_coord_descent!(X, W, H, l1_reg, l2_reg, shuffle)
    # One coordinate descent sweep over every entry of `W`, with `H` held fixed.
    # Each entry takes a Newton step along its own coordinate and is clipped at
    # zero. Returns the sum of absolute projected gradients seen during the
    # sweep, measured before each entry is updated.
    k = size(H, 1)
    nrows = size(W, 1)

    gram = H * H'
    cross = X * H'

    # The L2 penalty adds to the curvature, i.e. the diagonal of the Gram matrix.
    if l2_reg > 0.
        for d in diagind(gram)
            gram[d] += l2_reg
        end
    end
    # The L1 penalty shifts every gradient by a constant.
    if l1_reg > 0.
        cross .-= l1_reg
    end

    order = shuffle ? randperm(k) : 1:k

    total = zero(eltype(X))

    for t in order
        for i in 1:nrows
            # Gradient of entry (i, t): (W * gram - cross)[i, t].
            g = -cross[i, t]
            for r in 1:k
                g += gram[t, r] * W[i, r]
            end

            # At the boundary W[i, t] == 0 only a negative gradient counts,
            # since that is the only direction the entry can move.
            pg = g
            if W[i, t] == 0
                pg = min(zero(g), g)
            end
            total += abs(pg)

            # Second derivative along this coordinate; skip the step if it vanishes.
            curvature = gram[t, t]
            if curvature != 0
                W[i, t] = max(W[i, t] - g / curvature, zero(g))
            end
        end
    end
    return total
end


"""
    update_wh!(upd, s, X, W, H)

Perform one coordinate descent iteration: a sweep over `W` with `H` fixed,
then a sweep over `H` with the updated `W` fixed. Both factors are modified
in place.

The summed violation of the two sweeps is stored in `s.violation`. On the
first call, while `s.violation_init` is still `nothing`, the same value is
also stored as the reference violation; later calls leave the reference
untouched. Returns `nothing`.
"""
function update_wh!(upd::CoordinateDescentUpd{T}, s::CoordinateDescentState{T},
                    X::AbstractArray{T}, W::AbstractArray{T}, H::AbstractArray{T}) where T
    # Lazy transpose: writes through `Ht` land in `H`.
    Ht = transpose(H)

    violation = zero(T)
    violation += _update_coord_descent!(X, W, H, upd.l₁W, upd.l₂W, upd.shuffle)
    # The H sweep is the W sweep applied to the transposed problem Xᵀ ≈ Hᵀ Wᵀ.
    Wt = transpose(W)
    violation += _update_coord_descent!(PermutedDimsArray(X, (2, 1)), Ht, Wt,
                                        upd.l₁H, upd.l₂H, upd.shuffle)

    s.violation = violation
    # Store the reference only once. The previous `!== nothing` test could never
    # succeed because the field starts as `nothing`, so it was never set.
    if s.violation_init === nothing
        s.violation_init = violation
    end
    return nothing
end
2 changes: 2 additions & 0 deletions src/interf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ function nnmf(X::AbstractMatrix{T}, k::Integer;
alginst = MultUpdate{T}(obj=:mse, maxiter=maxiter, tol=tol, verbose=verbose)
elseif alg == :multdiv
alginst = MultUpdate{T}(obj=:div, maxiter=maxiter, tol=tol, verbose=verbose)
elseif alg == :cd
alginst = CoordinateDescent{T}(maxiter=maxiter, tol=tol, verbose=verbose)
else
throw(ArgumentError("Invalid algorithm."))
end
Expand Down
4 changes: 2 additions & 2 deletions src/projals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,10 @@ prepare_state(::ProjectedALSUpd{T}, X, W, H) where {T} = ProjectedALSUpd_State{T
# Objective for projected ALS: half the squared L2 distance between `X` and the
# cached product `s.WH`, plus half of `lambda_w * ‖W‖²` and `lambda_h * ‖H‖²`
# when the respective weight is positive.
# Each penalty is added once, using `norm`; the stale `vecnorm` lines (a
# function removed in Julia 1.0) doubled the penalty terms.
function evaluate_objv(u::ProjectedALSUpd{T}, s::ProjectedALSUpd_State{T}, X, W, H) where T
    r = convert(T, 0.5) * sqL2dist(X, s.WH)
    if u.lambda_w > 0
        r += (convert(T, 0.5) * u.lambda_w) * abs2(norm(W))
    end
    if u.lambda_h > 0
        r += (convert(T, 0.5) * u.lambda_h) * abs2(norm(H))
    end
    return r
end
Expand Down
17 changes: 17 additions & 0 deletions test/coorddesc.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
using NMF
using Test

# Problem size: X is p × n, factorized with k components.
p = 5
n = 8
k = 3

for T in (Float64, Float32)
    # Sparse non-negative ground-truth factors and the exact product they generate.
    Wg = max.(rand(T, p, k) .- T(0.5), zero(T))
    Hg = max.(rand(T, k, n) .- T(0.5), zero(T))
    X = Wg * Hg
    # Start W from a perturbed copy of the ground truth.
    W = Wg .+ rand(T, p, k) * T(0.1)

    # solve! updates W and Hg in place.
    NMF.solve!(NMF.CoordinateDescent{T}(α=0.0, maxiter=1000, tol=1e-9), X, W, Hg)

    @test X ≈ W * Hg atol=1e-4
end
6 changes: 3 additions & 3 deletions test/interf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@ n = 8
k = 3

for T in (Float64, Float32)
Wg = max(rand(T, p, k) .- 0.3, 0)
Hg = max(rand(T, k, n) .- 0.3, 0)
Wg = max.(rand(T, p, k) .- 0.3, 0)
Hg = max.(rand(T, k, n) .- 0.3, 0)
X = Wg * Hg

for alg in (:multmse, :multdiv, :projals, :alspgrad)
for alg in (:multmse, :multdiv, :projals, :alspgrad, :cd)
for init in (:random, :nndsvd, :nndsvda, :nndsvdar)
ret = NMF.nnmf(X, k, alg=alg, init=init)
end
Expand Down
4 changes: 3 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ using LinearAlgebra

# Test files to run, in order; each entry names a file in the test directory.
tests = ["utils",
         "initialization",
         "alspgrad",
         "coorddesc",
         "interf"]

println("Running tests:")
for t in tests
Expand Down

0 comments on commit affbab8

Please sign in to comment.