
Commit

alspgrad algorithm tested
lindahua committed Feb 12, 2014
1 parent a75f90d commit d0c745e
Showing 3 changed files with 118 additions and 39 deletions.
1 change: 1 addition & 0 deletions examples/densenmf.jl
@@ -44,6 +44,7 @@ function print_help()
println(" multmse: Multiplicative update (minimize MSE)")
println(" multdiv: Multiplicative update (minimize divergence)")
println(" projals: Projected ALS")
println(" alspgrad: ALS Projected Gradient Descent")
println()
end
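# Note (added for illustration, not part of the commit): judging from this help
# text, the example script appears to select the solver from a method-name
# argument, so the new code path would be exercised with something like
#
#     julia densenmf.jl alspgrad
#
# (the exact invocation depends on the argument handling elsewhere in this file).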

87 changes: 83 additions & 4 deletions src/alspgrad.jl
@@ -34,12 +34,17 @@ immutable ALSGradUpdH_State
Array(Float64, k, n),
Array(Float64, k, n),
Array(Float64, k, n),
At_mul_B(W, W),
At_mul_B(W, X),
Array(Float64, k, k),
Array(Float64, k, n),
Array(Float64, k, n))
end
end

function set_w!(s::ALSGradUpdH_State, X::ContiguousMatrix, W::ContiguousMatrix)
At_mul_B!(s.WtW, W, W)
At_mul_B!(s.WtX, W, X)
end
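# Illustration (not part of the commit): caching WtW = W'W and WtX = W'X here is
# what lets the H sub-problem avoid touching X again, since for fixed W the
# objective f(H) = 0.5*||X - W*H||^2 has gradient W'W*H - W'X, i.e.
#
#     gradh(s::ALSGradUpdH_State, H) = s.WtW * H - s.WtX    # hypothetical helper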

function alspgrad_updateh!(X::Matrix{Float64},
W::Matrix{Float64},
H::Matrix{Float64};
@@ -51,6 +56,7 @@ function alspgrad_updateh!(X::Matrix{Float64},
verbose::Bool = false)

s = ALSGradUpdH_State(X, W, H)
set_w!(s, X, W)
_alspgrad_updateh!(X, W, H, s,
maxiter, traceiter, tolg,
beta, sigma, verbose)
@@ -178,12 +184,18 @@ immutable ALSGradUpdW_State
Array(Float64, p, k),
Array(Float64, p, k),
Array(Float64, p, k),
A_mul_Bt(H, H),
A_mul_Bt(X, H),
Array(Float64, k, k),
Array(Float64, p, k),
Array(Float64, p, k))
end
end

function set_h!(s::ALSGradUpdW_State, X::ContiguousMatrix, H::ContiguousMatrix)
A_mul_Bt!(s.HHt, H, H)
A_mul_Bt!(s.XHt, X, H)
end
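# Illustration (not part of the commit): the W sub-problem mirrors the one above;
# with HHt = H*H' and XHt = X*H' cached, the gradient of 0.5*||X - W*H||^2 with
# respect to W is W*HHt - XHt, i.e.
#
#     gradw(s::ALSGradUpdW_State, W) = W * s.HHt - s.XHt    # hypothetical helper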


function alspgrad_updatew!(X::Matrix{Float64},
W::Matrix{Float64},
H::Matrix{Float64};
@@ -195,6 +207,7 @@ function alspgrad_updatew!(X::Matrix{Float64},
verbose::Bool = false)

s = ALSGradUpdW_State(X, W, H)
set_h!(s, X, H)
_alspgrad_updatew!(X, W, H, s,
maxiter, traceiter, tolg,
beta, sigma, verbose)
@@ -305,3 +318,69 @@ function _alspgrad_updatew!(X::Matrix{Float64}, # size (p, n)
end


## main algorithm

type ALSPGrad
maxiter::Int # maximum number of main iterations
maxsubiter::Int # maximum number of iterations within a sub-routine
tol::Float64 # tolerance of changes on W & H (main)
tolg::Float64 # tolerance of grad norm in sub-routine
verbose::Bool # whether to show procedural information (main)

function ALSPGrad(;maxiter::Integer=100,
maxsubiter::Integer=200,
tol::Real=1.0e-6,
tolg::Real=1.0e-4,
verbose::Bool=false)
new(int(maxiter),
int(maxsubiter),
float64(tol),
float64(tolg),
verbose)
end
end

immutable ALSPGradUpd <: NMFUpdater
maxsubiter::Int
tolg::Float64
end

solve!(alg::ALSPGrad, X::Matrix{Float64}, W::Matrix{Float64}, H::Matrix{Float64}) =
nmf_skeleton!(ALSPGradUpd(alg.maxsubiter, alg.tolg),
X, W, H, alg.maxiter, alg.verbose, alg.tol)
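# Usage sketch (added for illustration, not part of the commit): with W and H
# already initialized (e.g. by nndsvd), the main driver can be run directly as
#
#     alg = ALSPGrad(maxiter=200, tol=1.0e-6, verbose=true)
#     solve!(alg, X, W, H)     # delegates to nmf_skeleton! and returns its result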


immutable ALSPGradUpd_State
WH::Matrix{Float64}
uhstate::ALSGradUpdH_State
uwstate::ALSGradUpdW_State

ALSPGradUpd_State(X::ContiguousMatrix, W::ContiguousMatrix, H::ContiguousMatrix) =
new(W * H,
ALSGradUpdH_State(X, W, H),
ALSGradUpdW_State(X, W, H))
end

prepare_state(::ALSPGradUpd, X, W, H) = ALSPGradUpd_State(X, W, H)
evaluate_objv(u::ALSPGradUpd, s::ALSPGradUpd_State, X, W, H) = sqL2dist(X, s.WH)

function update_wh!(upd::ALSPGradUpd, s::ALSPGradUpd_State,
X::Matrix{Float64},
W::Matrix{Float64},
H::Matrix{Float64})

# update H
set_w!(s.uhstate, X, W)
_alspgrad_updateh!(X, W, H, s.uhstate,
upd.maxsubiter, 20, upd.tolg, 0.2, 0.01, false)

# update W
set_h!(s.uwstate, X, H)
_alspgrad_updatew!(X, W, H, s.uwstate,
upd.maxsubiter, 20, upd.tolg, 0.2, 0.01, false)

# update WH
A_mul_B!(s.WH, W, H)
end
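# Note (added for illustration): each outer iteration refreshes the cached Gram
# matrices via set_w!/set_h! and then runs at most `maxsubiter` projected-gradient
# steps on H and on W; the hard-coded 0.2 and 0.01 are the backtracking shrink
# factor (beta) and the sufficient-decrease constant (sigma) of the inner line
# search. A simplified, generic projected-gradient step of that flavor (names and
# condition are a sketch, not the exact inner routine):
#
#     function pg_step(H, G, f, beta, sigma)
#         alpha, f0 = 1.0, f(H)
#         while true
#             Hn = max(H - alpha .* G, 0.0)     # elementwise projection onto H >= 0
#             D = Hn - H
#             f(Hn) - f0 <= sigma * sum(G .* D) && return Hn   # Armijo-type test
#             alpha *= beta                      # shrink the step and retry
#         end
#     end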


69 changes: 34 additions & 35 deletions src/interf.jl
@@ -1,45 +1,44 @@
# Interface function: nnmf

function nnmf(X::Matrix{Float64}, k::Integer;
init::Symbol=:nndsvdar,
alg::Symbol=:projals,
maxiter::Integer=100,
tol::Real=1.0e-6,
verbose::Bool=false)
init::Symbol=:nndsvdar,
alg::Symbol=:alspgrad,
maxiter::Integer=100,
tol::Real=1.0e-6,
verbose::Bool=false)

p, n = size(X)
k <= min(p, n) || error("The value of k should not exceed min(size(X)).")
p, n = size(X)
k <= min(p, n) || error("The value of k should not exceed min(size(X)).")

# determine whether H needs to be initialized
if alg == :projals
initH = false
elseif alg == :multmse || alg == :multdiv
initH = true
else
error("Invalid value for alg.")
end
# determine whether H needs to be initialized
if alg == :projals
initH = false
else
initH = true
end

# perform initialization
if init == :random
W, H = randinit(X, k; zeroh=!initH, normalize=true)
elseif init == :nndsvd
W, H = nndsvd(X, k; zeroh=!initH)
elseif init == :nndsvda
W, H = nndsvd(X, k; variant=:a, zeroh=!initH)
elseif init == :nndsvdar
W, H = nndsvd(X, k; variant=:ar, zeroh=!initH)
else
error("Invalid value for init.")
end
# perform initialization
if init == :random
W, H = randinit(X, k; zeroh=!initH, normalize=true)
elseif init == :nndsvd
W, H = nndsvd(X, k; zeroh=!initH)
elseif init == :nndsvda
W, H = nndsvd(X, k; variant=:a, zeroh=!initH)
elseif init == :nndsvdar
W, H = nndsvd(X, k; variant=:ar, zeroh=!initH)
else
error("Invalid value for init.")
end

# choose algorithm
alginst =
alg == :projals ? ProjectedALS(maxiter=maxiter, tol=tol, verbose=verbose) :
alg == :multmse ? MultUpdate(obj=:mse, maxiter=maxiter, tol=tol, verbose=verbose) :
alg == :multdiv ? MultUpdate(obj=:div, maxiter=maxiter, tol=tol, verbose=verbose) :
error("Invalid algorithm.")
# choose algorithm
alginst =
alg == :projals ? ProjectedALS(maxiter=maxiter, tol=tol, verbose=verbose) :
alg == :alspgrad ? ALSPGrad(maxiter=maxiter, tol=tol, verbose=verbose) :
alg == :multmse ? MultUpdate(obj=:mse, maxiter=maxiter, tol=tol, verbose=verbose) :
alg == :multdiv ? MultUpdate(obj=:div, maxiter=maxiter, tol=tol, verbose=verbose) :
error("Invalid algorithm.")

# run optimization
solve!(alginst, X, W, H)
# run optimization
solve!(alginst, X, W, H)
end
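# Usage sketch (added for illustration): with :alspgrad now the default algorithm,
#
#     r = nnmf(X, 5)                    # ALS projected gradient, nndsvdar init
#     r = nnmf(X, 5, alg=:multdiv)      # the other solvers remain selectable
#
# where r is whatever solve! returns for the chosen algorithm (the fitted factors
# plus convergence information).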

