Skip to content

Commit

Permalink
annotate function types to encourage specialization
Browse files Browse the repository at this point in the history
  • Loading branch information
jrevels committed Nov 3, 2016
1 parent e534cba commit b2d1150
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 32 deletions.
8 changes: 4 additions & 4 deletions src/api_utils.jl
Expand Up @@ -33,16 +33,16 @@ end
# vector mode function evaluation #
###################################

vector_mode_dual_eval(f, x, opts::Multithread) = vector_mode_dual_eval(f, x, gradient_options(opts))
vector_mode_dual_eval(f, x, opts::Tuple) = vector_mode_dual_eval(f, x, first(opts))
vector_mode_dual_eval{F}(f::F, x, opts::Multithread) = vector_mode_dual_eval(f, x, gradient_options(opts))
vector_mode_dual_eval{F}(f::F, x, opts::Tuple) = vector_mode_dual_eval(f, x, first(opts))

function vector_mode_dual_eval(f, x, opts)
function vector_mode_dual_eval{F}(f::F, x, opts)
xdual = opts.duals
seed!(xdual, x, opts.seeds)
return f(xdual)
end

function vector_mode_dual_eval(f!, y, x, opts)
function vector_mode_dual_eval{F}(f!::F, y, x, opts)
ydual, xdual = opts.duals
seed!(xdual, x, opts.seeds)
seed!(ydual, y)
Expand Down
4 changes: 2 additions & 2 deletions src/derivative.jl
Expand Up @@ -2,9 +2,9 @@
# API methods #
###############

derivative(f, x) = extract_derivative(f(Dual(x, one(x))))
derivative{F}(f::F, x) = extract_derivative(f(Dual(x, one(x))))

function derivative!(out, f, x)
function derivative!{F}(out, f::F, x)
y = f(Dual(x, one(x)))
extract_derivative!(out, y)
return out
Expand Down
16 changes: 8 additions & 8 deletions src/gradient.jl
Expand Up @@ -2,15 +2,15 @@
# API methods #
###############

function gradient(f, x, opts::AbstractOptions = Options(x))
function gradient{F}(f::F, x, opts::AbstractOptions = Options(x))
if chunksize(opts) == length(x)
return vector_mode_gradient(f, x, opts)
else
return chunk_mode_gradient(f, x, opts)
end
end

function gradient!(out, f, x, opts::AbstractOptions = Options(x))
function gradient!{F}(out, f::F, x, opts::AbstractOptions = Options(x))
if chunksize(opts) == length(x)
vector_mode_gradient!(out, f, x, opts)
else
Expand Down Expand Up @@ -56,13 +56,13 @@ end
# vector mode #
###############

function vector_mode_gradient(f, x, opts)
function vector_mode_gradient{F}(f::F, x, opts)
ydual = vector_mode_dual_eval(f, x, opts)
out = similar(x, valtype(ydual))
return extract_gradient!(out, ydual)
end

function vector_mode_gradient!(out, f, x, opts)
function vector_mode_gradient!{F}(out, f::F, x, opts)
ydual = vector_mode_dual_eval(f, x, opts)
extract_gradient!(out, ydual)
return out
Expand Down Expand Up @@ -119,11 +119,11 @@ function chunk_mode_gradient_expr(out_definition::Expr)
end
end

@eval function chunk_mode_gradient{N}(f, x, opts::Options{N})
@eval function chunk_mode_gradient{F,N}(f::F, x, opts::Options{N})
$(chunk_mode_gradient_expr(:(out = similar(x, valtype(ydual)))))
end

@eval function chunk_mode_gradient!{N}(out, f, x, opts::Options{N})
@eval function chunk_mode_gradient!{F,N}(out, f::F, x, opts::Options{N})
$(chunk_mode_gradient_expr(:()))
end

Expand Down Expand Up @@ -185,11 +185,11 @@ if IS_MULTITHREADED_JULIA
end
end

@eval function chunk_mode_gradient(f, x, multi_opts::Multithread)
@eval function chunk_mode_gradient{F}(f::F, x, multi_opts::Multithread)
$(multithread_chunk_mode_expr(:(out = similar(x, valtype(current_ydual)))))
end

@eval function chunk_mode_gradient!(out, f, x, multi_opts::Multithread)
@eval function chunk_mode_gradient!{F}(out, f::F, x, multi_opts::Multithread)
$(multithread_chunk_mode_expr(:()))
end
else
Expand Down
12 changes: 6 additions & 6 deletions src/hessian.jl
Expand Up @@ -2,18 +2,18 @@
# API methods #
###############

function hessian(f, x, opts::AbstractOptions = HessianOptions(x))
function hessian{F}(f::F, x, opts::AbstractOptions = HessianOptions(x))
∇f = y -> gradient(f, y, gradient_options(opts))
return jacobian(∇f, x, jacobian_options(opts))
end

function hessian!(out, f, x, opts::AbstractOptions = HessianOptions(x))
function hessian!{F}(out, f::F, x, opts::AbstractOptions = HessianOptions(x))
∇f = y -> gradient(f, y, gradient_options(opts))
jacobian!(out, ∇f, x, jacobian_options(opts))
return out
end

function hessian!(out::DiffResult, f, x, opts::AbstractOptions = HessianOptions(out, x))
function hessian!{F}(out::DiffResult, f::F, x, opts::AbstractOptions = HessianOptions(out, x))
∇f! = (y, z) -> begin
result = DiffResult(zero(eltype(y)), y)
gradient!(result, f, z, gradient_options(opts))
Expand All @@ -30,6 +30,6 @@ end

const HESS_OPTIONS_ERR_MSG = "To use `hessian`/`hessian!` with options, use `HessianOptions` or `Multithread(::HessianOptions)` instead of `Options`."

hessian(f, x, ::Options) = error(HESS_OPTIONS_ERR_MSG)
hessian!(out, f, x, ::Options) = error(HESS_OPTIONS_ERR_MSG)
hessian!(::DiffResult, f, x, ::Options) = error(HESS_OPTIONS_ERR_MSG)
hessian{F}(f::F, x, ::Options) = error(HESS_OPTIONS_ERR_MSG)
hessian!{F}(out, f::F, x, ::Options) = error(HESS_OPTIONS_ERR_MSG)
hessian!{F}(::DiffResult, f::F, x, ::Options) = error(HESS_OPTIONS_ERR_MSG)
24 changes: 12 additions & 12 deletions src/jacobian.jl
Expand Up @@ -2,23 +2,23 @@
# API methods #
###############

function jacobian(f, x, opts::Options = Options(x))
function jacobian{F}(f::F, x, opts::Options = Options(x))
if chunksize(opts) == length(x)
return vector_mode_jacobian(f, x, opts)
else
return chunk_mode_jacobian(f, x, opts)
end
end

function jacobian(f!, y, x, opts::Options = Options(y, x))
function jacobian{F}(f!::F, y, x, opts::Options = Options(y, x))
if chunksize(opts) == length(x)
return vector_mode_jacobian(f!, y, x, opts)
else
return chunk_mode_jacobian(f!, y, x, opts)
end
end

function jacobian!(out, f, x, opts::Options = Options(x))
function jacobian!{F}(out, f::F, x, opts::Options = Options(x))
if chunksize(opts) == length(x)
vector_mode_jacobian!(out, f, x, opts)
else
Expand All @@ -27,7 +27,7 @@ function jacobian!(out, f, x, opts::Options = Options(x))
return out
end

function jacobian!(out, f!, y, x, opts::Options = Options(y, x))
function jacobian!{F}(out, f!::F, y, x, opts::Options = Options(y, x))
if chunksize(opts) == length(x)
vector_mode_jacobian!(out, f!, y, x, opts)
else
Expand Down Expand Up @@ -71,15 +71,15 @@ reshape_jacobian(out::DiffResult, ydual, xdual) = reshape_jacobian(DiffBase.jaco
# vector mode #
###############

function vector_mode_jacobian{N}(f, x, opts::Options{N})
function vector_mode_jacobian{F,N}(f::F, x, opts::Options{N})
ydual = vector_mode_dual_eval(f, x, opts)
out = similar(ydual, valtype(eltype(ydual)), length(ydual), N)
extract_jacobian!(out, ydual, N)
extract_value!(out, ydual)
return out
end

function vector_mode_jacobian{N}(f!, y, x, opts::Options{N})
function vector_mode_jacobian{F,N}(f!::F, y, x, opts::Options{N})
ydual = vector_mode_dual_eval(f!, y, x, opts)
map!(value, y, ydual)
out = similar(y, length(y), N)
Expand All @@ -88,14 +88,14 @@ function vector_mode_jacobian{N}(f!, y, x, opts::Options{N})
return out
end

function vector_mode_jacobian!{N}(out, f, x, opts::Options{N})
function vector_mode_jacobian!{F,N}(out, f::F, x, opts::Options{N})
ydual = vector_mode_dual_eval(f, x, opts)
extract_jacobian!(out, ydual, N)
extract_value!(out, ydual)
return out
end

function vector_mode_jacobian!{N}(out, f!, y, x, opts::Options{N})
function vector_mode_jacobian!{F,N}(out, f!::F, y, x, opts::Options{N})
ydual = vector_mode_dual_eval(f!, y, x, opts)
map!(value, y, ydual)
extract_jacobian!(out, ydual, N)
Expand Down Expand Up @@ -150,7 +150,7 @@ function jacobian_chunk_mode_expr(work_array_definition::Expr, compute_ydual::Ex
end
end

@eval function chunk_mode_jacobian{N}(f, x, opts::Options{N})
@eval function chunk_mode_jacobian{F,N}(f::F, x, opts::Options{N})
$(jacobian_chunk_mode_expr(quote
xdual = opts.duals
seed!(xdual, x)
Expand All @@ -160,7 +160,7 @@ end
:()))
end

@eval function chunk_mode_jacobian{N}(f!, y, x, opts::Options{N})
@eval function chunk_mode_jacobian{F,N}(f!::F, y, x, opts::Options{N})
$(jacobian_chunk_mode_expr(quote
ydual, xdual = opts.duals
seed!(xdual, x)
Expand All @@ -170,7 +170,7 @@ end
:(map!(value, y, ydual))))
end

@eval function chunk_mode_jacobian!{N}(out, f, x, opts::Options{N})
@eval function chunk_mode_jacobian!{F,N}(out, f::F, x, opts::Options{N})
$(jacobian_chunk_mode_expr(quote
xdual = opts.duals
seed!(xdual, x)
Expand All @@ -180,7 +180,7 @@ end
:(extract_value!(out, ydual))))
end

@eval function chunk_mode_jacobian!{N}(out, f!, y, x, opts::Options{N})
@eval function chunk_mode_jacobian!{F,N}(out, f!::F, y, x, opts::Options{N})
$(jacobian_chunk_mode_expr(quote
ydual, xdual = opts.duals
seed!(xdual, x)
Expand Down

0 comments on commit b2d1150

Please sign in to comment.