Skip to content

Commit

Permalink
generalize cache container types to allow matrix calculus without req…
Browse files Browse the repository at this point in the history
…uiring the user to reshape input/output
  • Loading branch information
jrevels committed Aug 17, 2016
1 parent 2f712dc commit 897e6a8
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 54 deletions.
71 changes: 37 additions & 34 deletions src/cache.jl
Expand Up @@ -4,33 +4,35 @@

const JACOBIAN_CACHE = Dict{Tuple{Int,Int,DataType,Bool},Any}()

immutable JacobianCache{N,T}
dualvec::Vector{Dual{N,T}}
immutable JacobianCache{N,T,D}
duals::D
seeds::NTuple{N,Partials{N,T}}
end

function JacobianCache{T,N}(::Type{T}, xlen, chunk::Chunk{N})
dualvec = Vector{Dual{N,T}}(xlen)
function JacobianCache{N}(x, chunk::Chunk{N})
T = eltype(x)
duals = Array{Dual{N,T}}(size(x))
seeds = construct_seeds(T, chunk)
return JacobianCache{N,T}(dualvec, seeds)
return JacobianCache{N,T,typeof(duals)}(duals, seeds)
end

Base.copy(cache::JacobianCache) = JacobianCache(copy(cache.dualvec), cache.seeds)
@inline jacobian_dual_type{T,M,N}(::Array{T,M}, ::Chunk{N}) = Array{Dual{N,T},M}

@eval function multithread_jacobian_cachefetch!{T,N}(::Type{T}, xlen, chunk::Chunk{N},
usecache::Bool, alt::Bool = false)
Base.copy(cache::JacobianCache) = JacobianCache(copy(cache.duals), cache.seeds)

@eval function multithread_jacobian_cachefetch!{N}(x, chunk::Chunk{N}, usecache::Bool,
alt::Bool = false)
T, xlen = eltype(x), length(x)
if usecache
result = get!(JACOBIAN_CACHE, (xlen, N, T, alt)) do
construct_jacobian_caches(T, xlen, chunk)
construct_jacobian_caches(x, chunk)
end
else
result = construct_jacobian_caches(T, xlen, chunk)
result = construct_jacobian_caches(x, chunk)
end
return result::NTuple{$NTHREADS,JacobianCache{N,T}}
return result::NTuple{$NTHREADS,JacobianCache{N,T,jacobian_dual_type(x, chunk)}}
end

multithread_jacobian_cachefetch!(x, args...) = multithread_jacobian_cachefetch!(eltype(x), length(x), args...)

jacobian_cachefetch!(args...) = multithread_jacobian_cachefetch!(args...)[compat_threadid()]

########################
Expand All @@ -40,48 +42,50 @@ jacobian_cachefetch!(args...) = multithread_jacobian_cachefetch!(args...)[compat
# only used for vector mode, so we can assume that N == length(x)
const HESSIAN_CACHE = Dict{Tuple{Int,DataType},Any}()

immutable HessianCache{N,T}
dualvec::Vector{Dual{N,Dual{N,T}}}
immutable HessianCache{N,T,D}
duals::D
inseeds::NTuple{N,Partials{N,T}}
outseeds::NTuple{N,Partials{N,Dual{N,T}}}
end

function HessianCache{T,N}(::Type{T}, chunk::Chunk{N})
dualvec = Vector{Dual{N,Dual{N,T}}}(N)
function HessianCache{N}(x, chunk::Chunk{N})
T = eltype(x)
duals = Array{Dual{N,Dual{N,T}}}(size(x))
inseeds = construct_seeds(T, chunk)
outseeds = construct_seeds(Dual{N,T}, chunk)
return HessianCache{N,T}(dualvec, inseeds, outseeds)
return HessianCache{N,T,typeof(duals)}(duals, inseeds, outseeds)
end

Base.copy(cache::HessianCache) = HessianCache(copy(cache.dualvec), cache.inseeds, cache.outseeds)
@inline hessian_dual_type{T,M,N}(::Array{T,M}, ::Chunk{N}) = Array{Dual{N,Dual{N,T}},M}

@eval function multithread_hessian_cachefetch!{T,N}(::Type{T}, chunk::Chunk{N}, usecache::Bool)
Base.copy(cache::HessianCache) = HessianCache(copy(cache.duals), cache.inseeds, cache.outseeds)

@eval function multithread_hessian_cachefetch!{N}(x, chunk::Chunk{N}, usecache::Bool)
T = eltype(x)
if usecache
result = get!(HESSIAN_CACHE, (N, T)) do
construct_hessian_caches(T, chunk)
construct_hessian_caches(x, chunk)
end
else
result = construct_hessian_caches(T, chunk)
result = construct_hessian_caches(x, chunk)
end
return result::NTuple{$NTHREADS,HessianCache{N,T}}
return result::NTuple{$NTHREADS,HessianCache{N,T,hessian_dual_type(x, chunk)}}
end

multithread_hessian_cachefetch!(x, args...) = multithread_hessian_cachefetch!(eltype(x), args...)

hessian_cachefetch!(args...) = multithread_hessian_cachefetch!(args...)[compat_threadid()]

#################
# Partial seeds #
#################

function seedall!{N,T}(xdual::Vector{Dual{N,T}}, x, seed::Partials{N,T})
function seedall!{N,T}(xdual, x, seed::Partials{N,T})
for i in eachindex(xdual)
xdual[i] = Dual{N,T}(x[i], seed)
end
return xdual
end

function seed!{N,T}(xdual::Vector{Dual{N,T}}, x, seed::Partials{N,T}, index)
function seed!{N,T}(xdual, x, seed::Partials{N,T}, index)
offset = index - 1
for i in 1:N
j = i + offset
Expand All @@ -90,7 +94,7 @@ function seed!{N,T}(xdual::Vector{Dual{N,T}}, x, seed::Partials{N,T}, index)
return xdual
end

function seed!{N,T}(xdual::Vector{Dual{N,T}}, x,seeds::NTuple{N,Partials{N,T}}, index, chunksize = N)
function seed!{N,T}(xdual, x,seeds::NTuple{N,Partials{N,T}}, index, chunksize = N)
offset = index - 1
for i in 1:chunksize
j = i + offset
Expand All @@ -99,8 +103,7 @@ function seed!{N,T}(xdual::Vector{Dual{N,T}}, x,seeds::NTuple{N,Partials{N,T}},
return xdual
end

function seedhess!{N,T}(xdual::Vector{Dual{N,Dual{N,T}}}, x,
inseeds::NTuple{N,Partials{N,T}},
function seedhess!{N,T}(xdual, x, inseeds::NTuple{N,Partials{N,T}},
outseeds::NTuple{N,Partials{N,Dual{N,T}}})
for i in 1:N
xdual[i] = Dual{N,Dual{N,T}}(Dual{N,T}(x[i], inseeds[i]), outseeds[i])
Expand All @@ -112,13 +115,13 @@ end
# @eval'd functions #
#####################

@eval function construct_jacobian_caches{T,N}(::Type{T}, xlen, chunk::Chunk{N})
result = JacobianCache(T, xlen, chunk)
@eval function construct_jacobian_caches{N}(x, chunk::Chunk{N})
result = JacobianCache(x, chunk)
return $(Expr(:tuple, :result, [:(copy(result)) for i in 2:NTHREADS]...))
end

@eval function construct_hessian_caches{T,N}(::Type{T}, chunk::Chunk{N})
result = HessianCache(T, chunk)
@eval function construct_hessian_caches{N}(x, chunk::Chunk{N})
result = HessianCache(x, chunk)
return $(Expr(:tuple, :result, [:(copy(result)) for i in 2:NTHREADS]...))
end

Expand Down
10 changes: 5 additions & 5 deletions src/gradient.jl
Expand Up @@ -85,7 +85,7 @@ end

function compute_vector_mode_gradient(f, x, chunk, usecache)
cache = jacobian_cachefetch!(x, chunk, usecache)
xdual = cache.dualvec
xdual = cache.duals
seed!(xdual, x, cache.seeds, 1)
return f(xdual)
end
Expand Down Expand Up @@ -119,7 +119,7 @@ function chunk_mode_gradient_expr(out_definition::Expr)

# fetch and seed work vectors
cache = jacobian_cachefetch!(x, chunk, usecache)
xdual = cache.dualvec
xdual = cache.duals
seeds = cache.seeds
zeroseed = zero(eltype(seeds))
seedall!(xdual, x, zeroseed)
Expand Down Expand Up @@ -180,12 +180,12 @@ if IS_MULTITHREADED_JULIA
zeroseed = zero(eltype(seeds))

Base.Threads.@threads for t in 1:NTHREADS
seedall!(caches[t].dualvec, x, zeroseed)
seedall!(caches[t].duals, x, zeroseed)
end

# do first chunk manually to calculate output type
current_cache = caches[compat_threadid()]
current_xdual = current_cache.dualvec
current_xdual = current_cache.duals
current_seeds = current_cache.seeds
seed!(current_xdual, x, current_seeds, 1)
current_dual = f(current_xdual)
Expand All @@ -197,7 +197,7 @@ if IS_MULTITHREADED_JULIA
Base.Threads.@threads for c in middlechunks
# see https://github.com/JuliaLang/julia/issues/14948
local chunk_cache = caches[compat_threadid()]
local chunk_xdual = chunk_cache.dualvec
local chunk_xdual = chunk_cache.duals
local chunk_seeds = chunk_cache.seeds
local chunk_index = ((c - 1) * N + 1)
seed!(chunk_xdual, x, seeds, chunk_index)
Expand Down
2 changes: 1 addition & 1 deletion src/hessian.jl
Expand Up @@ -81,7 +81,7 @@ end

function compute_vector_mode_hessian(f, x, chunk, usecache)
cache = hessian_cachefetch!(x, chunk, usecache)
xdual = cache.dualvec
xdual = cache.duals
seedhess!(xdual, x, cache.inseeds, cache.outseeds)
return f(xdual)
end
Expand Down
31 changes: 17 additions & 14 deletions src/jacobian.jl
Expand Up @@ -17,8 +17,8 @@ jacobian(result::JacobianResult) = result.jacobian
###############

function jacobian{N}(f, x, chunk::Chunk{N} = pickchunk(x);
multithread::Bool = false,
usecache::Bool = true)
multithread::Bool = false,
usecache::Bool = true)
if N == length(x)
return vector_mode_jacobian(f, x, chunk, usecache)
elseif multithread
Expand Down Expand Up @@ -124,16 +124,16 @@ end

function compute_vector_mode_jacobian(f, x, chunk, usecache)
cache = jacobian_cachefetch!(x, chunk, usecache)
xdual = cache.dualvec
xdual = cache.duals
seed!(xdual, x, cache.seeds, 1)
return f(xdual)
end

function compute_vector_mode_jacobian(f!, y, x, chunk, usecache)
cache = jacobian_cachefetch!(x, chunk, usecache)
ycache = jacobian_cachefetch!(y, chunk, usecache, true)
xdual = cache.dualvec
ydual = ycache.dualvec
xdual = cache.duals
ydual = ycache.duals
seed!(xdual, x, cache.seeds, 1)
seedall!(ydual, y, zero(eltype(ycache.seeds)))
f!(ydual, xdual)
Expand All @@ -155,13 +155,15 @@ end

function vector_mode_jacobian!(out, f, x, chunk, usecache)
ydual = compute_vector_mode_jacobian(f, x, chunk, usecache)
return load_jacobian!(out, ydual)
load_jacobian!(reshape(out, length(ydual), length(x)), ydual)
return out
end

function vector_mode_jacobian!(out, f!, y, x, chunk, usecache)
ydual = compute_vector_mode_jacobian(f!, y, x, chunk, usecache)
load_jacobian_value!(y, ydual)
return load_jacobian!(out, ydual)
load_jacobian!(reshape(out, length(y), length(x)), ydual)
return out
end

# chunk mode #
Expand All @@ -179,9 +181,9 @@ function jacobian_chunk_mode_expr(out_definition::Expr, cache_definition::Expr,
lastchunkindex = xlen - lastchunksize + 1
middlechunks = 2:div(xlen - lastchunksize, N)

# fetch and seed work vectors
# fetch and seed work arrays
$(cache_definition)
xdual = cache.dualvec
xdual = cache.duals
seeds = cache.seeds
zeroseed = zero(eltype(seeds))
seedall!(xdual, x, zeroseed)
Expand All @@ -191,21 +193,22 @@ function jacobian_chunk_mode_expr(out_definition::Expr, cache_definition::Expr,
$(ydual_compute)
seed!(xdual, x, zeroseed, 1)
$(out_definition)
load_jacobian_chunk!(out, ydual, 1, N)
out_reshaped = reshape(out, length(ydual), length(xdual))
load_jacobian_chunk!(out_reshaped, ydual, 1, N)

# do middle chunks
for c in middlechunks
i = ((c - 1) * N + 1)
seed!(xdual, x, seeds, i)
$(ydual_compute)
seed!(xdual, x, zeroseed, i)
load_jacobian_chunk!(out, ydual, i, N)
load_jacobian_chunk!(out_reshaped, ydual, i, N)
end

# do final chunk
seed!(xdual, x, seeds, lastchunkindex, lastchunksize)
$(ydual_compute)
load_jacobian_chunk!(out, ydual, lastchunkindex, lastchunksize)
load_jacobian_chunk!(out_reshaped, ydual, lastchunkindex, lastchunksize)

$(y_definition)

Expand All @@ -225,7 +228,7 @@ end
quote
cache = jacobian_cachefetch!(x, chunk, usecache)
ycache = jacobian_cachefetch!(y, chunk, usecache, true)
ydual = ycache.dualvec
ydual = ycache.duals
yzeroseed = zero(eltype(ycache.seeds))
end,
:(f!(seedall!(ydual, y, yzeroseed), xdual)),
Expand All @@ -244,7 +247,7 @@ end
quote
cache = jacobian_cachefetch!(x, chunk, usecache)
ycache = jacobian_cachefetch!(y, chunk, usecache, true)
ydual = ycache.dualvec
ydual = ycache.duals
yzeroseed = zero(eltype(ycache.seeds))
end,
:(f!(seedall!(ydual, y, yzeroseed), xdual)),
Expand Down
8 changes: 8 additions & 0 deletions test/MiscTest.jl
Expand Up @@ -56,6 +56,14 @@ testf2 = x -> testdf(x[1]) * f(x[2])

@test_approx_eq ForwardDiff.gradient(f2, x) ForwardDiff.gradient(testf2, x)

######################################
# Higher-Dimensional Differentiation #
######################################

x = rand(3, 3)

@test_approx_eq ForwardDiff.jacobian(inv, x) -kron(inv(x'), inv(x))

########################
# Conversion/Promotion #
########################
Expand Down

0 comments on commit 897e6a8

Please sign in to comment.