Skip to content

Commit

Permalink
correct the seeding order for seeded_reverse_pass! functions (#60)
Browse files Browse the repository at this point in the history
  • Loading branch information
jrevels committed Mar 21, 2017
1 parent 09e7471 commit b1496d0
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
10 changes: 6 additions & 4 deletions src/api/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,20 @@ function seeded_reverse_pass!(result::AbstractArray, output::AbstractArray, inpu
end

function seeded_reverse_pass!(output::TrackedReal, input::TrackedReal, tape)
seed!(output)
pull_value!(output)
unseed!(input)
seed!(output)
reverse_pass!(tape)
unseed!(output)
return deriv(input)
end

# gradient (input is an array, output is a scalar) #
#--------------------------------------------------#

function seeded_reverse_pass!(result, output::TrackedReal, input, tape)
seed!(output)
pull_value!(output)
unseed!(input)
seed!(output)
reverse_pass!(tape)
extract_result!(result, output, input)
return result
Expand All @@ -43,9 +44,10 @@ end
function seeded_reverse_pass!(result::AbstractArray, output::AbstractArray, input::TrackedArray, tape)
result_matrix = reshape(result, length(output), length(input))
input_deriv = deriv(input)
pull_value!(output)
for i in eachindex(output)
seed!(output, i)
unseed!(input)
seed!(output, i)
reverse_pass!(tape)
for j in eachindex(input)
result_matrix[i, j] = input_deriv[j]
Expand Down
2 changes: 1 addition & 1 deletion test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ srand(1)

test_println(kind, f, pad = " ") = println(pad, "testing $(kind): `$(f)`...")

test_approx(A, B) = @test isapprox(A, B, atol = 1e-5)
@inline test_approx(A, B) = @test isapprox(A, B, atol = 1e-5)

tracked_is(a, b) = value(a) === value(b) && deriv(a) === deriv(b) && tape(a) === tape(b)
tracked_is(a::AbstractArray, b::AbstractArray) = all(map(tracked_is, a, b))
Expand Down

0 comments on commit b1496d0

Please sign in to comment.