AD with Tracker.jl versus Zygote.jl and Mooncake.jl. For BNN. #2454

@deveshjawla

Can someone help me understand why Tracker.jl fails in the "linked" case? And how could Zygote match Tracker's performance in the "standard" case of the TuringBenchmarking run?
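
For readers unfamiliar with the two labels: as I understand it, "standard" evaluates the log density on the original (constrained) parameters, while "linked" first maps constrained parameters to an unconstrained space and adds a log-Jacobian term. The sketch below works this out by hand for the `sigma ~ Gamma(0.1, 1.0)` prior from the MWE, assuming a log transform for the positive variable. It is an illustration of the change of variables, not Turing's internal code, and `logp_standard` / `logp_linked` are hypothetical names.

```julia
using Distributions

# Prior from the MWE: sigma ~ Gamma(0.1, 1.0), supported on sigma > 0.
prior = Gamma(0.1, 1.0)

# "standard": log density evaluated directly in the constrained space.
logp_standard(sigma) = logpdf(prior, sigma)

# "linked" (assumed log transform): work with z = log(sigma), any real number.
# The change of variables adds log|d sigma / d z| = log(exp(z)) = z.
logp_linked(z) = logpdf(prior, exp(z)) + z

sigma = 0.7
z = log(sigma)

# The two values differ exactly by the log-Jacobian term.
@assert logp_linked(z) ≈ logp_standard(sigma) + z
```

In the linked case the gradient is taken with respect to `z`, so the AD backend also has to differentiate through the transform for the scalar `sigma`.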

I get the following output:

┌ Warning: Gradient computation (with linking) failed for AutoTracker(): MethodError(copyto!, (0.0, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}}(+, (0.0, -0.1371315395009085))), 0x0000000000006a89)
└ @ TuringBenchmarking ~/.julia/packages/TuringBenchmarking/fc6o7/src/TuringBenchmarking.jl:243
"gradient" => 3-element BenchmarkTools.BenchmarkGroup:
          tags: []
          "AutoMooncake{Mooncake.Config}(Mooncake.Config(false, false))" => 2-element BenchmarkTools.BenchmarkGroup:
                  tags: ["AutoMooncake{Mooncake.Config}(Mooncake.Config(false, false))"]
                  "linked" => Trial(531.125 μs)
                  "standard" => Trial(532.459 μs)
          "AutoZygote()" => 2-element BenchmarkTools.BenchmarkGroup:
                  tags: ["Zygote"]
                  "linked" => Trial(2.382 ms)
                  "standard" => Trial(2.252 ms)
          "AutoTracker()" => 2-element BenchmarkTools.BenchmarkGroup:
                  tags: ["Tracker"]
                  "linked" => 0-element BenchmarkTools.BenchmarkGroup:
                          tags: []
                  "standard" => Trial(99.125 μs)
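
The `MethodError` in the warning has the same shape Julia produces when an in-place broadcast (`.+=` or `.=`) targets a plain `Float64` instead of an array: the destination is `0.0` and the broadcast is `+` applied to `(0.0, -0.137...)`. The sketch below reproduces only that error shape in Base Julia. It does not show where in Tracker or in the linked code path such an update happens, which is still the open question, and `inplace_add_error` is a hypothetical helper.

```julia
# Base-only reproduction of the error shape; no AD package is involved.
function inplace_add_error(dest, increment)
    try
        # Lowers to materialize!(dest, broadcasted(+, dest, increment)),
        # which ends in copyto!(dest, ::Broadcasted).
        dest .+= increment
        return nothing
    catch err
        return err
    end
end

# A Float64 destination is immutable, so there is no copyto! method for it.
scalar_err = inplace_add_error(0.0, -0.1371315395009085)
@assert scalar_err isa MethodError
@assert scalar_err.f === copyto!

# The same update succeeds when the destination is a mutable array.
@assert inplace_add_error([0.0], -0.1371315395009085) === nothing
```

This would be consistent with something accumulating in place into a scalar, and the only scalar parameter in the model is `sigma`, but that is a guess rather than a diagnosis.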

MWE:

using Lux, Zygote, Tracker, Mooncake, Turing, Random, TuringBenchmarking, Functors
nn = Chain(Dense(10, 5, relu), Dense(5, 1, use_bias=false))
rng = Xoshiro(0)
ps, st = Lux.setup(rng, nn)
num_params = Lux.parameterlength(nn) # number of parameters in NN
const model = StatefulLuxLayer{true}(nn, nothing, st)

function vector_to_parameters(ps_new::AbstractVector, ps::NamedTuple)
    @assert length(ps_new) == Lux.parameterlength(ps)
    i = 1
    function get_ps(x)
        z = reshape(view(ps_new, i:(i+length(x)-1)), size(x))
        i += length(x)
        return z
    end
    return fmap(get_ps, ps)
end

@model function BNN(x, y, num_p)
    θ_p ~ MvNormal(zeros(num_p), ones(num_p))

    preds = Lux.apply(model, x, vector_to_parameters(θ_p, ps))

    sigma ~ Gamma(0.1, 1.0) # Prior for the noise standard deviation (Normal takes σ, not σ²)

    y[:] ~ Product(Normal.(vec(preds), sigma))
end

benchmark_result = benchmark_model(
    BNN(randn(10, 10), randn(1, 10), num_params);
    adbackends=[AutoZygote(), AutoTracker(), AutoMooncake(; config=Mooncake.Config(; debug_mode=false))],
)

  1. Mooncake also works well on other problems.
  2. Zygote is slow compared to both. It would be good to bring its speed up to Tracker's.
  3. Tracker is the fastest but fails in the "linked" case.
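
One thing that may be worth ruling out on the Zygote side: `get_ps` in `vector_to_parameters` reassigns the captured variable `i`, and Julia boxes captured variables that are reassigned, which can hurt type inference. The sketch below rebuilds the nested parameter structure while passing the offset through return values instead. `unflatten` is a hypothetical helper, not part of Lux or Functors, and the `template` is a hand-written stand-in for `ps` with the same 60 parameters. I have not benchmarked it with any AD backend, so whether it changes the timings above is untested.

```julia
# Hypothetical helper: rebuild a nested NamedTuple of arrays from a flat vector.
# Each method returns (rebuilt_value, next_offset).

# Leaf case: take the next length(x) entries and reshape them like x.
function unflatten(flat::AbstractVector, x::AbstractArray, start::Int)
    stop = start + length(x) - 1
    return reshape(view(flat, start:stop), size(x)), stop + 1
end

# NamedTuple case: process fields in order, threading the offset along.
function unflatten(flat::AbstractVector, nt::NamedTuple{names}, start::Int) where {names}
    function take_field((done, offset), field)
        value, after = unflatten(flat, field, offset)
        return (done..., value), after
    end
    vals, final = foldl(take_field, values(nt); init=((), start))
    return NamedTuple{names}(vals), final
end

# Stand-in for the Lux parameters of the MWE (shapes are assumptions).
template = (
    layer_1 = (weight = zeros(5, 10), bias = zeros(5)),
    layer_2 = (weight = zeros(1, 5),),
)
flat = collect(1.0:60.0)

ps_new, next_offset = unflatten(flat, template, 1)
@assert next_offset == 61
@assert size(ps_new.layer_1.weight) == (5, 10)
@assert ps_new.layer_2.weight[1, 5] == 60.0
```

Inside the model this would be called as `first(unflatten(θ_p, ps, 1))` in place of `vector_to_parameters(θ_p, ps)`.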

Any help is appreciated. Thank you.