Skip to content

Commit

Permalink
Merge 81f9fb3 into 09e7471
Browse files Browse the repository at this point in the history
  • Loading branch information
jrevels committed Mar 15, 2017
2 parents 09e7471 + 81f9fb3 commit 4241f08
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions src/api/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,18 @@ function seeded_reverse_pass!(result::AbstractArray, output::AbstractArray, inpu
end

function seeded_reverse_pass!(output::TrackedReal, input::TrackedReal, tape)
seed!(output)
unseed!(input)
seed!(output)
reverse_pass!(tape)
unseed!(output)
return deriv(input)
end

# gradient (input is an array, output is a scalar) #
#--------------------------------------------------#

function seeded_reverse_pass!(result, output::TrackedReal, input, tape)
seed!(output)
unseed!(input)
seed!(output)
reverse_pass!(tape)
extract_result!(result, output, input)
return result
Expand All @@ -44,8 +43,8 @@ function seeded_reverse_pass!(result::AbstractArray, output::AbstractArray, inpu
result_matrix = reshape(result, length(output), length(input))
input_deriv = deriv(input)
for i in eachindex(output)
seed!(output, i)
unseed!(input)
seed!(output, i)
reverse_pass!(tape)
for j in eachindex(input)
result_matrix[i, j] = input_deriv[j]
Expand Down

0 comments on commit 4241f08

Please sign in to comment.