Skip to content

Commit

Permalink
Merge d8398fa into 71c5ac0
Browse files Browse the repository at this point in the history
  • Loading branch information
mohamed82008 committed Nov 19, 2020
2 parents 71c5ac0 + d8398fa commit 7d5cbbb
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/derivatives/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,17 +85,17 @@ function get_implementation(bc, f, T, args)
end
function Base.copy(_bc::Broadcasted{TrackedStyle})
bc = remove_not_tracked(_bc)
flattened_bc = Broadcast.flatten(bc)
flattened_bc = Base.Broadcast.flatten(bc)
untracked_bc = broadcast_rebuild(bc)
flattened_untracked_bc = Broadcast.flatten(untracked_bc)
T = Core.Compiler.return_type(copy, Tuple{typeof(untracked_bc)})
f, args = flattened_untracked_bc.f, flattened_bc.args
f, args = flattened_bc.f, flattened_bc.args
implementation = get_implementation(_bc, f, T, args)
if implementation isa Val{:reversediff}
return ∇broadcast(f, args...)
elseif implementation isa Val{:tracker}
return tracker_∇broadcast(f, args...)
else
flattened_untracked_bc = Base.Broadcast.flatten(untracked_bc)
style, axes = getstyle(flattened_untracked_bc), flattened_bc.axes
return copy(Broadcasted{style, typeof(axes), typeof(f), typeof(args)}(f, args, axes))
end
Expand Down

0 comments on commit 7d5cbbb

Please sign in to comment.