Skip to content

Commit

Permalink
Merge pull request #39 from tsano430/fix-multupd
Browse files Browse the repository at this point in the history
Fix#15: Avoidance of zero division
  • Loading branch information
mschauer committed Oct 29, 2020
2 parents 9c2ab61 + 6b72312 commit 9ab7f92
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 27 deletions.
11 changes: 6 additions & 5 deletions README.md
Expand Up @@ -143,11 +143,12 @@ The matrices ``W`` and ``H`` are updated in place.
This algorithm has two different kind of objectives: minimizing mean-squared-error (``:mse``) and minimizing divergence (``:div``). Both ``W`` and ``H`` need to be initialized.

```julia
MultUpdate(obj=:mse, # objective, either :mse or :div
maxiter=100, # maximum number of iterations
verbose=false, # whether to show procedural information
tol=1.0e-6, # tolerance of changes on W and H upon convergence
lambda=1.0e-9) # regularization coefficients (added to the denominator)
MultUpdate(obj::Symbol=:mse, # objective, either :mse or :div
maxiter::Integer=100, # maximum number of iterations
verbose::Bool=false, # whether to show procedural information
tol::Real=1.0e-6, # tolerance of changes on W and H upon convergence
lambda_w::Real=0.0, # L1 regularization coefficient for W
lambda_h::Real=0.0) # L1 regularization coefficient for H
```

**Note:** the values above are default values for the keyword arguments. One can override part (or all) of them.
Expand Down
64 changes: 42 additions & 22 deletions src/multupd.jl
Expand Up @@ -7,40 +7,54 @@
#

mutable struct MultUpdate{T}
obj::Symbol # objective :mse or :div
maxiter::Int # maximum number of iterations
verbose::Bool # whether to show procedural information
tol::T # change tolerance upon convergence
lambda::T # regularization coefficient
obj::Symbol # objective :mse or :div
maxiter::Int # maximum number of iterations
verbose::Bool # whether to show procedural information
tol::T # change tolerance upon convergence
lambda_w::T # L1 regularization coefficient for W
lambda_h::T # L1 regularization coefficient for H

function MultUpdate{T}(;obj::Symbol=:mse,
maxiter::Integer=100,
verbose::Bool=false,
tol::Real=cbrt(eps(T)),
lambda::Real=sqrt(eps(T))) where T
lambda_w::Real=zero(T),
lambda_h::Real=zero(T),
lambda::Union{Real, Nothing}=nothing) where T

obj == :mse || obj == :div || throw(ArgumentError("Invalid value for obj."))
maxiter > 1 || throw(ArgumentError("maxiter must be greater than 1."))
tol > 0 || throw(ArgumentError("tol must be positive."))
lambda >= 0 || throw(ArgumentError("lambda must be non-negative."))

new{T}(obj, maxiter, verbose, tol, lambda)
lambda_w >= 0 || throw(ArgumentError("lambda_w must be non-negative."))
lambda_h >= 0 || throw(ArgumentError("lambda_h must be non-negative."))
if lambda !== nothing && lambda >= 0
@warn "lambda is deprecated, use lambda_w and lambda_h instead."
lambda_w = iszero(lambda_w) ? lambda : lambda_w
lambda_h = iszero(lambda_h) ? lambda : lambda_h
end
if obj == :div
lambda_w = max(lambda_w, sqrt(eps(T)))
lambda_h = max(lambda_h, sqrt(eps(T)))
end
new{T}(obj, maxiter, verbose, tol, lambda_w, lambda_h)
end
end

function solve!(alg::MultUpdate, X, W, H)
function solve!(alg::MultUpdate{T}, X, W, H) where T

if alg.obj == :mse
nmf_skeleton!(MultUpdMSE(alg.lambda), X, W, H, alg.maxiter, alg.verbose, alg.tol)
nmf_skeleton!(MultUpdMSE(alg.lambda_w, alg.lambda_h, sqrt(eps(T))), X, W, H, alg.maxiter, alg.verbose, alg.tol)
else # alg == :div
nmf_skeleton!(MultUpdDiv(alg.lambda), X, W, H, alg.maxiter, alg.verbose, alg.tol)
nmf_skeleton!(MultUpdDiv(alg.lambda_w, alg.lambda_h, sqrt(eps(T))), X, W, H, alg.maxiter, alg.verbose, alg.tol)
end
end

# the multiplicative updating algorithm for MSE objective

struct MultUpdMSE{T} <: NMFUpdater{T}
lambda::T
lambda_w::T
lambda_h::T
delta::T
end

struct MultUpdMSE_State{T}
Expand All @@ -66,7 +80,9 @@ evaluate_objv(::MultUpdMSE, s::MultUpdMSE_State, X, W, H) = sqL2dist(X, s.WH)
function update_wh!(upd::MultUpdMSE{T}, s::MultUpdMSE_State{T}, X, W::Matrix{T}, H::Matrix{T}) where T

# fields
lambda = upd.lambda
lambda_w = upd.lambda_w
lambda_h = upd.lambda_h
delta = upd.delta
WH = s.WH
WtX = s.WtX
WtWH = s.WtWH
Expand All @@ -79,7 +95,7 @@ function update_wh!(upd::MultUpdMSE{T}, s::MultUpdMSE_State{T}, X, W::Matrix{T},
mul!(WtWH, Wt, WH)

@inbounds for i = 1:length(H)
H[i] *= (WtX[i] / (WtWH[i] + lambda))
H[i] *= (max(zero(T), WtX[i] - lambda_h) / (WtWH[i] + delta))
end
mul!(WH, W, H)

Expand All @@ -89,7 +105,7 @@ function update_wh!(upd::MultUpdMSE{T}, s::MultUpdMSE_State{T}, X, W::Matrix{T},
mul!(WHHt, WH, Ht)

@inbounds for i = 1:length(W)
W[i] *= (XHt[i] / (WHHt[i] + lambda))
W[i] *= (max(zero(T), XHt[i] - lambda_w) / (WHHt[i] + delta))
end
mul!(WH, W, H)
end
Expand All @@ -98,7 +114,9 @@ end
# the multiplicative updating algorithm for divergence objective

struct MultUpdDiv{T} <: NMFUpdater{T}
lambda::T
lambda_w::T
lambda_h::T
delta::T
end

struct MultUpdDiv_State{T}
Expand Down Expand Up @@ -131,7 +149,9 @@ function update_wh!(upd::MultUpdDiv{T}, s::MultUpdDiv_State{T}, X, W::Matrix{T},
pn = p * n

# fields
lambda = upd.lambda
lambda_w = upd.lambda_w
lambda_h = upd.lambda_h
delta = upd.delta
sW = s.sW
sH = s.sH
WH = s.WH
Expand All @@ -143,23 +163,23 @@ function update_wh!(upd::MultUpdDiv{T}, s::MultUpdDiv_State{T}, X, W::Matrix{T},

# update H
@inbounds for i = 1:length(X)
Q[i] = X[i] / (WH[i] + lambda)
Q[i] = X[i] / (WH[i] + delta)
end
mul!(WtQ, transpose(W), Q)
sum!(fill!(sW, 0), W)
@inbounds for j = 1:n, i = 1:k
H[i,j] *= (WtQ[i,j] / sW[i])
H[i,j] *= (WtQ[i,j] / (sW[i] + lambda_h))
end
mul!(WH, W, H)

# update W
@inbounds for i = 1:length(X)
Q[i] = X[i] / (WH[i] + lambda)
Q[i] = X[i] / (WH[i] + delta)
end
mul!(QHt, Q, transpose(H))
sum!(fill!(sH, 0), H)
@inbounds for j = 1:k, i = 1:p
W[i,j] *= (QHt[i,j] / sH[j])
W[i,j] *= (QHt[i,j] / (sH[j] + lambda_w))
end
mul!(WH, W, H)
end
29 changes: 29 additions & 0 deletions test/multupd.jl
@@ -0,0 +1,29 @@
using NMF
using Test

p = 5
n = 8
k = 3

Random.seed!(5678)

for T in (Float64, Float32)
for alg in (:mse, :div)
for lambda_w in (0.0, 1e-4)
for lambda_h in (0.0, 1e-4)
Wg = max.(rand(T, p, k) .- T(0.5), zero(T))
Hg = max.(rand(T, k, n) .- T(0.5), zero(T))
X = Wg * Hg
W = Wg .+ rand(T, p, k)*T(0.1)

NMF.solve!(NMF.MultUpdate{T}(obj=alg, maxiter=5000, tol=1e-9, lambda_w=lambda_w, lambda_h=lambda_h), X, W, Hg)

@test all(W .>= 0.0)
@test all(Hg .>= 0.0)
@test !any(isnan.(W))
@test !any(isnan.(Hg))
@test X W * Hg atol=1e-2
end
end
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Expand Up @@ -5,6 +5,7 @@ using LinearAlgebra

tests = ["utils",
"initialization",
"multupd",
"alspgrad",
"coorddesc",
"interf"]
Expand Down

0 comments on commit 9ab7f92

Please sign in to comment.