Skip to content

Commit

Permalink
auto-gen all univariate HessianNum functions using Calculus.jl, added…
Browse files Browse the repository at this point in the history
… corresponding tests
  • Loading branch information
jrevels committed Jul 28, 2015
1 parent c189ca1 commit 548be29
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 84 deletions.
43 changes: 15 additions & 28 deletions src/HessianNum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -173,42 +173,29 @@ end
#-------------------------------------#
-(h::HessianNum) = HessianNum(-gradnum(h), -hess(h))

# The Tuples in `h_univar_funcs` have the following format:
#
# (:function_name,
# :(expression defining the kth entry of the hessian vector, using any available variables))
const h_univar_funcs = Tuple{Symbol, Expr}[
(:sqrt, :((-grad(h,i)*grad(h,j)+2*value(h)*hess(h,i)) / (4*(value(h)^(1.5))))),
(:cbrt, :((-2*grad(h,i)*grad(h,j)+3*value(h)*hess(h,k)) / (9*cbrt(value(h)^5)))),
(:exp, :(exp(value(h))*(grad(h,i)*grad(h,j)+hess(h,k)))),
(:log, :((value(h)*hess(h,k)-grad(h,i)*grad(h,j))/(value(h)^2))),
(:log2, :((value(h)*hess(h,k)-grad(h,i)*grad(h,j)) / ((value(h)^2)*0.6931471805599453))),
(:log10, :((value(h)*hess(h,k)-grad(h,i)*grad(h,j)) / ((value(h)^2)*2.302585092994046))),
(:sin, :(-sin(value(h))*grad(h,i)*grad(h,j)+cos(value(h))*hess(h,k))),
(:cos, :(-cos(value(h))*grad(h,i)*grad(h,j)-sin(value(h))*hess(h,k))),
(:tan, :((sec(value(h))^2)*(2*tan(value(h))*grad(h,i)*grad(h,j)+hess(h,k)))),
(:asin, :((value(h)*grad(h,i)*grad(h,j)-((value(h)^2)-1)*new_hess[k]) / ((1-(value(h)^2))^1.5))),
(:acos, :((-value(h)*grad(h,i)*grad(h,j)+((value(h)^2)-1)*new_hess[k]) / ((1-(value(h)^2))^1.5))),
(:atan, :((-2*value(h)*grad(h,i)*grad(h,j)+((value(h)^2)+1)*new_hess[k]) / ((1+(value(h)^2))^2))),
(:sinh, :(sinh(value(h))*grad(h,i)*grad(h,j)+cosh(value(h))*hess(h,k))),
(:cosh, :(cosh(value(h))*grad(h,i)*grad(h,j)+sinh(value(h))*hess(h,k))),
(:tanh, :((sech(value(h))^2)*(-2*tanh(value(h))*grad(h,i)*grad(h,j)+hess(h,k)))),
(:asinh, :((-value(h)*grad(h,i)*grad(h,j)+((1+(value(h)^2))*hess(h,k))) / ((1+(value(h)^2))^1.5))),
(:acosh, :((-value(h)*grad(h,i)*grad(h,j)+(((value(h)^2)-1)*hess(h,k))) / (((1+value(h))^1.5)*((value(h)-1)^1.5)))),
(:atanh, :((2*value(h)*grad(h,i)*grad(h,j)-(((value(h)^2)-1)*hess(h,k)))/(((value(h)^2)-1)^2))),
(:lgamma, :(digamma(value(h))*hess(h,k)+trigamma(value(h))*grad(h,i)*grad(h,j))),
(:digamma, :(trigamma(value(h))*hess(h,k)+polygamma(2,value(h))*grad(h,i)*grad(h,j)))
]
# the second derivative of functions in unsupported_univar_hess_funcs involves differentiating
# elementary functions that are unsupported by Calculus.jl, e.g. abs(x) and polygamma(x)
const unsupported_univar_hess_funcs = [:asec, :acsc, :asecd, :acscd, :acsch, :trigamma]
const univar_hess_funcs = filter!(sym -> !in(sym, unsupported_univar_hess_funcs), map(first, Calculus.symbolic_derivatives_1arg()))

# Univariate function construction loop
for (fsym, term) in h_univar_funcs
for fsym in univar_hess_funcs
loadfsym = symbol(string("loadhess_", fsym, "!"))

hval = :hval
call_expr = :($(fsym)($hval))
deriv1 = differentiate(call_expr, hval)
deriv2 = differentiate(deriv1, hval)

@eval begin
function $(loadfsym){N}(h::HessianNum{N}, output)
hval = value(h)
deriv1 = $deriv1
deriv2 = $deriv2
k = 1
for i in 1:N
for j in 1:i
output[k] = $(term)
output[k] = deriv1*hess(h, k) + deriv2*grad(h, i)*grad(h, j)
k += 1
end
end
Expand Down
56 changes: 0 additions & 56 deletions test/test_fad_api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,62 +87,6 @@ fill!(testout, zero(eltype(testout)))
jacf = jacobian_func(jac_testf, P, M, mutates=false)
@test jacf(testx) == testresult

########################
# Test Hessian methods #
########################
# hess_testf: R⁴ -> R
function hess_testf(x::Vector)
@assert length(x) == N
return prod(i->i^2, x)
end

function hess_deriv(i, j)
wrt = [:a, :b, :c, :d]

diff = differentiate(:(a^2 * b^2 * c^2 * d^2), wrt[j])
diff = differentiate(diff, wrt[i])

str = string(diff)
str = replace(str, 'a', "x[1]")
str = replace(str, 'b', "x[2]")
str = replace(str, 'c', "x[3]")
str = replace(str, 'd', "x[4]")

return parse(str)
end

function hess_deriv(x::Vector, i, j)
ex = hess_deriv(i,j)
@eval begin
x = $x
return $ex
end
end

# hard code the correct hessian for
# hess_testf at the given vector x
function hess_test_result(x::Vector)
@assert length(x) == N
return [hess_deriv(x, i, j) for i in 1:N, j in 1:N]
end

testout = Array(Float64, N, N)
testresult = hess_test_result(testx)

hessian!(hess_testf, testx, testout, P)
@test testout == testresult
fill!(testout, zero(eltype(testout)))

@test hessian(hess_testf, testx, P) == testresult

hessf! = hessian_func(hess_testf, P, mutates=true)
hessf!(testx, testout)
@test testout == testresult
fill!(testout, zero(eltype(testout)))

hessf = hessian_func(hess_testf, P, mutates=false)
@test hessf(testx) == testresult

#######################
# Test Tensor methods #
#######################
Expand Down
60 changes: 60 additions & 0 deletions test/test_hessnum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,63 @@ seekstart(io)
@test read(io, typeof(test_hess)) == test_hess

close(io)

####################################
# Math tests (including API usage) #
####################################

# Univariate functions #
#----------------------#
N = 4
P = Partials{N,Float64}
testout = Array(Float64, N, N)

function hess_deriv_ij(f_expr, x::Vector, i, j)
var_syms = [:a, :b, :c, :d]
diff_expr = differentiate(f_expr, var_syms[j])
diff_expr = differentiate(diff_expr, var_syms[i])
@eval begin
a,b,c,d = $x
return $diff_expr
end
end

function hess_test_result(f_expr, x::Vector)
return [hess_deriv_ij(f_expr, x, i, j) for i in 1:N, j in 1:N]
end

function hess_test_x(fsym, N)
randrange = 0.01:.01:.99

if fsym == :acosh
randrange += 1
elseif fsym == :acoth
randrange += 2
end

return rand(randrange, N)
end

for fsym in ForwardDiff.univar_hess_funcs
testexpr = :($(fsym)(a) + $(fsym)(b) - $(fsym)(c) * $(fsym)(d))

@eval function testf(x::Vector)
a,b,c,d = x
return $testexpr
end

testx = hess_test_x(fsym, N)
testresult = hess_test_result(testexpr, testx)

hessian!(testf, testx, testout, P)
@test_approx_eq testout testresult

@test_approx_eq hessian(testf, testx, P) testresult

hessf! = hessian_func(testf, P, mutates=true)
hessf!(testx, testout)
@test_approx_eq testout testresult

hessf = hessian_func(testf, P, mutates=false)
@test_approx_eq hessf(testx) testresult
end

0 comments on commit 548be29

Please sign in to comment.