Closed
Labels
user issue (Issues or questions raised by, or actively affecting, users)
Description
Can someone help me understand why Tracker.jl fails when the gradient is computed with linking? And how could Zygote match Tracker's performance in the "standard" (unlinked) case when benchmarking with TuringBenchmarking.jl?
I get the following output:
```
┌ Warning: Gradient computation (with linking) failed for AutoTracker(): MethodError(copyto!, (0.0, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}}(+, (0.0, -0.1371315395009085))), 0x0000000000006a89)
└ @ TuringBenchmarking ~/.julia/packages/TuringBenchmarking/fc6o7/src/TuringBenchmarking.jl:243
  "gradient" => 3-element BenchmarkTools.BenchmarkGroup:
    tags: []
    "AutoMooncake{Mooncake.Config}(Mooncake.Config(false, false))" => 2-element BenchmarkTools.BenchmarkGroup:
      tags: ["AutoMooncake{Mooncake.Config}(Mooncake.Config(false, false))"]
      "linked" => Trial(531.125 μs)
      "standard" => Trial(532.459 μs)
    "AutoZygote()" => 2-element BenchmarkTools.BenchmarkGroup:
      tags: ["Zygote"]
      "linked" => Trial(2.382 ms)
      "standard" => Trial(2.252 ms)
    "AutoTracker()" => 2-element BenchmarkTools.BenchmarkGroup:
      tags: ["Tracker"]
      "linked" => 0-element BenchmarkTools.BenchmarkGroup:
        tags: []
      "standard" => Trial(99.125 μs)
```
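The `MethodError` in the warning has the same shape as the error Julia raises when an in-place broadcast such as `acc .+= x` targets a plain `Float64`: numbers are immutable, so there is no `copyto!` method with a scalar destination. The sketch below reproduces that shape with hypothetical helpers (`accumulate_inplace!`, `accumulate_scalar`, `capture_error` are illustrative names, not part of Turing, Bijectors, or Tracker). That the linked code path under Tracker performs such an update on a scalar is an assumption drawn from the error text, not a confirmed diagnosis.

```julia
# Hypothetical helpers, for illustration only (not part of Turing, Bijectors, or Tracker).

# In-place update: only valid when `acc` is a mutable container, e.g. an Array
# (including a 0-dimensional one created with `fill(0.0)`).
function accumulate_inplace!(acc, x)
    acc .+= x   # lowers to Base.materialize!, which ends in copyto!(acc, ::Broadcasted)
    return acc
end

# Out-of-place update: also works for immutable scalars such as Float64.
accumulate_scalar(acc, x) = acc + x

# Call `f(args...)` and return the exception instead of throwing it.
function capture_error(f, args...)
    try
        f(args...)
        return nothing
    catch e
        return e
    end
end

x = -0.1371315395009085

# Scalar destination: same error shape as in the benchmark warning.
err = capture_error(accumulate_inplace!, 0.0, x)
println(typeof(err))   # MethodError
println(err.f)         # copyto!

# A 0-dimensional array is mutable, so the same in-place update succeeds.
acc0 = fill(0.0)
accumulate_inplace!(acc0, x)

# Rebinding with `+` instead of broadcasting in place also works for scalars.
acc1 = accumulate_scalar(0.0, x)
```

If this is indeed the cause, the usual remedy is to replace the `.+=` on a scalar with `+=` (or to keep the accumulator in an array) wherever the linked path accumulates into a number.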
MWE:
```julia
using Lux, Zygote, Tracker, Mooncake, Turing, Random, TuringBenchmarking, Functors

nn = Chain(Dense(10, 5, relu), Dense(5, 1, use_bias=false))
rng = Xoshiro(0)
ps, st = Lux.setup(rng, nn)
num_params = Lux.parameterlength(nn) # number of parameters in NN
const model = StatefulLuxLayer{true}(nn, nothing, st)

# Rebuild the Lux parameter NamedTuple from a flat parameter vector.
function vector_to_parameters(ps_new::AbstractVector, ps::NamedTuple)
    @assert length(ps_new) == Lux.parameterlength(ps)
    i = 1
    function get_ps(x)
        z = reshape(view(ps_new, i:(i+length(x)-1)), size(x))
        i += length(x)
        return z
    end
    return fmap(get_ps, ps)
end

@model function BNN(x, y, num_p)
    θ_p ~ MvNormal(zeros(num_p), ones(num_p))
    preds = Lux.apply(model, x, vector_to_parameters(θ_p, ps))
    sigma ~ Gamma(0.1, 1.0) # Prior for the noise standard deviation (used as the Normal's σ)
    y[:] ~ Product(Normal.(vec(preds), sigma))
end

benchmark_result = benchmark_model(
    BNN(randn(10, 10), randn(1, 10), num_params),
    adbackends=[AutoZygote(), AutoTracker(), AutoMooncake(; config=Mooncake.Config(; debug_mode=false))],
)
```
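On the Zygote side, one thing that may be worth ruling out (an untested assumption, not a measured cause) is the closure in `vector_to_parameters`: because `get_ps` reassigns the captured variable `i`, Julia stores `i` in a `Core.Box`, which makes the closure type-unstable. A common workaround is to hold the counter in a `Ref`, so the captured variable itself is never reassigned. The sketch keeps the rest of the function as in the MWE and drops the `Lux.parameterlength` assertion only so that it needs nothing beyond Functors.jl; the small NamedTuple in the usage lines is assumed to mirror what `Lux.setup` returns for the `Chain` above.

```julia
using Functors

# Variant of `vector_to_parameters` that avoids a boxed closure variable:
# `offset` is bound once to a Ref, and only its contents change.
function vector_to_parameters(ps_new::AbstractVector, ps::NamedTuple)
    offset = Ref(1)
    function get_ps(x)
        n = length(x)
        # View the next n entries of the flat vector as an array shaped like `x`.
        z = reshape(view(ps_new, offset[]:(offset[] + n - 1)), size(x))
        offset[] += n
        return z
    end
    # fmap visits the array leaves depth-first, in field order.
    return fmap(get_ps, ps)
end

# Assumed layout for Chain(Dense(10, 5, relu), Dense(5, 1, use_bias=false)): 50 + 5 + 5 = 60 parameters.
ps = (layer_1 = (weight = zeros(5, 10), bias = zeros(5)), layer_2 = (weight = zeros(1, 5),))
θ = collect(1.0:60.0)
p = vector_to_parameters(θ, ps)
```

Whether this changes Zygote's timing here would need to be measured with the same `benchmark_model` call.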
- Mooncake works well here, and on other problems too.
- Zygote is slow compared to both. It would be good to bring its speed up to Tracker's.
- Tracker is the fastest, but it fails in the "linked" case.
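For context on the two benchmark rows: in the "linked" case the model is evaluated after constrained parameters are transformed to unconstrained space, so for `sigma ~ Gamma(0.1, 1.0)` the gradient is taken with respect to y = log(sigma), and the log density gains a log-Jacobian term. Tracker fails only on that path, which runs transform code the "standard" path does not. The sketch below, using only Distributions.jl, writes out that change of variables by hand and checks numerically that the transformed density still integrates to one; it illustrates the idea and is not how Bijectors.jl implements it.

```julia
using Distributions

prior = Gamma(0.1, 1.0)   # prior on sigma from the MWE; support is (0, ∞)

# "Standard": log density in the constrained parameter sigma.
standard_logpdf(sigma) = logpdf(prior, sigma)

# "Linked": log density in y = log(sigma), which ranges over all of ℝ.
# Change of variables: p_Y(y) = p_σ(exp(y)) * |d exp(y)/dy| = p_σ(exp(y)) * exp(y),
# so the log-Jacobian contributes +y.
linked_logpdf(y) = standard_logpdf(exp(y)) + y

# Riemann-sum check that the linked density still integrates to about 1.
# The density decays like exp(0.1y) as y → -∞ and like exp(-exp(y)) as y → +∞,
# so [-300, 5] captures essentially all of the mass.
h = 0.01
ys = -300.0:h:5.0
mass = h * sum(y -> exp(linked_logpdf(y)), ys)
```

A gradient taken in the linked case therefore also differentiates through `exp` and the Jacobian term, which is where a backend-specific failure can appear even when the standard case works.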
Any help is appreciated. Thank you.
patrickm663