Skip to content

Commit

Permalink
fix typos/bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
jrevels committed Mar 7, 2017
1 parent a7623a6 commit d3e4b63
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 47 deletions.
37 changes: 20 additions & 17 deletions benchmark/benchmark.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,23 +30,26 @@ function grad_benchmark_driver!(out, f, x)
gc()

if length(tp.tape) <= 10000
ctp = ReverseDiff.compile(tp)

# warmup
ReverseDiff.gradient!(out, ctp, x)
ReverseDiff.forward_pass!(ctp)
ReverseDiff.reverse_pass!(ctp)

# actual
print(" gradient! (compiled): ")
@time ReverseDiff.gradient!(out, ctp, x)
gc()
print(" forward pass (compiled): ")
@time ReverseDiff.forward_pass!(ctp)
gc()
print(" reverse pass (compiled): ")
@time ReverseDiff.reverse_pass!(ctp)
gc()
@eval begin
out, x = $out, $x
ctp = ReverseDiff.compile($tp)

# warmup
ReverseDiff.gradient!(out, ctp, x)
ReverseDiff.forward_pass!(ctp)
ReverseDiff.reverse_pass!(ctp)

# actual
print(" gradient! (compiled): ")
@time ReverseDiff.gradient!(out, ctp, x)
gc()
print(" forward pass (compiled): ")
@time ReverseDiff.forward_pass!(ctp)
gc()
print(" reverse pass (compiled): ")
@time ReverseDiff.reverse_pass!(ctp)
gc()
end
else
println("skipping compiled GradientTape benchmark because the tape is too long ($(length(tp.tape)) elements)")
end
Expand Down
2 changes: 1 addition & 1 deletion src/api/hessians.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ end
function hessian!(result::DiffResult, tape::Union{HessianTape,CompiledHessian}, input::AbstractArray)
seeded_forward_pass!(tape, input)
seeded_reverse_pass!(DiffResult(DiffBase.gradient(result), DiffBase.hessian(result)), tape)
DiffBase.value!(result, tape.func(input))
DiffBase.value!(result, func_hook(tape)(input))
return result
end

Expand Down
40 changes: 33 additions & 7 deletions src/api/tape.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,11 @@ for T in (:GradientTape, :JacobianTape, :HessianTape)

Base.length(t::$T) = length(t.tape)

input_hook(t::$T) = t.input
@inline func_hook(t::$T) = t.func

output_hook(t::$T) = t.output
@inline input_hook(t::$T) = t.input

@inline output_hook(t::$T) = t.output

forward_pass!(t::$T) = forward_pass!(t.tape)

Expand All @@ -54,17 +56,41 @@ immutable CompiledTape{S,T<:AbstractTape} <: AbstractTape
tape::T
end

(::Type{CompiledTape{S}}{S,T<:AbstractTape}(t::T) = CompiledTape{S,T}(t)
(::Type{CompiledTape{S}}){S,T<:AbstractTape}(t::T) = CompiledTape{S,T}(t)

Base.show{S}(io::IO, t::CompiledTape{S}) = print(io, typeof(t).name, "{$S}($(t.tape.func))")

typealias CompiledGradient{S,T<:GradientTape} CompiledTape{S,T}
typealias CompiledJacobian{S,T<:JacobianTape} CompiledTape{S,T}
typealias CompiledHessian{S,T<:HessianTape} CompiledTape{S,T}

Base.length(ct::CompiledTape) = length(ct.tape)

input_hook(ct::CompiledTape) = input_hook(ct.tape)
@inline func_hook(ct::CompiledTape) = func_hook(ct.tape)

@inline input_hook(ct::CompiledTape) = input_hook(ct.tape)

output_hook(ct::CompiledTape) = output_hook(ct.tape)
@inline output_hook(ct::CompiledTape) = output_hook(ct.tape)

function generate_forward_pass_method{T}(::Type{T}, tape::RawTape)
body = Expr(:block)
push!(body.args, :(tape = compiled_tape.tape.tape))
for i in 1:length(tape)
push!(body.args, :(ReverseDiff.forward_exec!(tape[$i]::$(typeof(tape[i])))))
end
push!(body.args, :(return nothing))
return :(ReverseDiff.forward_pass!(compiled_tape::$T) = $body)
end

function generate_reverse_pass_method{T}(::Type{T}, tape::RawTape)
body = Expr(:block)
push!(body.args, :(tape = compiled_tape.tape.tape))
for i in length(tape):-1:1
push!(body.args, :(ReverseDiff.reverse_exec!(tape[$i]::$(typeof(tape[i])))))
end
push!(body.args, :(return nothing))
return :(ReverseDiff.reverse_pass!(compiled_tape::$T) = $body)
end

"""
ReverseDiff.compile(t::AbstractTape)
Expand All @@ -88,8 +114,8 @@ function compile(t::AbstractTape)
end

function compile(ct::CompiledTape)
eval(ReverseDiff, generate_forward_pass_method(typeof(ct), ct.tape))
eval(ReverseDiff, generate_reverse_pass_method(typeof(ct), ct.tape))
eval(ReverseDiff, generate_forward_pass_method(typeof(ct), ct.tape.tape))
eval(ReverseDiff, generate_reverse_pass_method(typeof(ct), ct.tape.tape))
return ct
end

Expand Down
22 changes: 0 additions & 22 deletions src/tape.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,28 +73,6 @@ end
@noinline reverse_exec!(instruction::ScalarInstruction) = scalar_reverse_exec!(instruction)
@noinline reverse_exec!(instruction::SpecialInstruction) = special_reverse_exec!(instruction)

####################
# pass compilation #
####################

function generate_forward_pass_method{T}(::Type{T}, tape::RawTape)
body = Expr(:block)
for i in 1:length(tape)
push!(body.args, :(ReverseDiff.forward_exec!(tape[$i]::$(typeof(tape[i])))))
end
push!(body.args, :(return nothing))
return :(ReverseDiff.forward_pass!(tape::$T) = $body)
end

function generate_reverse_pass_method{T}(::Type{T}, tape::RawTape)
body = Expr(:block)
for i in length(tape):-1:1
push!(body.args, :(ReverseDiff.reverse_exec!(tape[$i]::$(typeof(tape[i])))))
end
push!(body.args, :(return nothing))
return :(ReverseDiff.reverse_pass!(tape::$T) = $body)
end

###################
# Pretty Printing #
###################
Expand Down

0 comments on commit d3e4b63

Please sign in to comment.