Skip to content

Commit

Permalink
fixed tests and removed verbose=true as the default value
Browse files Browse the repository at this point in the history
  • Loading branch information
torfjelde committed Sep 16, 2023
1 parent 3d97e77 commit 9b2e59e
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 14 deletions.
9 changes: 4 additions & 5 deletions src/TuringBenchmarking.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ export benchmark_model, make_turing_suite, @tagged

# Don't include `TrackerAD` because it's never going to win.
const DEFAULT_ADBACKENDS = [
ForwardDiffAD{}(Turing.Essential.CHUNKSIZE[]), # chunksize=40
ZygoteAD(),
ForwardDiffAD{Turing.Essential.CHUNKSIZE[]}(), # chunksize=40
ReverseDiffAD{false}(), # rdcache=false
ReverseDiffAD{true}() # rdcache=false
ReverseDiffAD{true}(), # rdcache=false
ZygoteAD(),
]

backend_label(::ForwardDiffAD) = "ForwardDiff"
Expand Down Expand Up @@ -67,7 +67,6 @@ function benchmark_model(
varinfo::DynamicPPL.AbstractVarInfo = DynamicPPL.VarInfo(model),
sampler::Union{AbstractMCMC.AbstractSampler,Nothing} = nothing,
context::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext(),
verbose=true,
kwargs...
)
suite = make_turing_suite(
Expand All @@ -80,7 +79,7 @@ function benchmark_model(
context,
kwargs...
)
return run(suite; verbose=verbose, kwargs...)
return run(suite; kwargs...)
end

"""
Expand Down
13 changes: 4 additions & 9 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,7 @@ BenchmarkTools.DEFAULT_PARAMETERS.evals = 1
BenchmarkTools.DEFAULT_PARAMETERS.samples = 2

# These should be ordered (ascendingly) by runtime.
ADBACKENDS = [
TuringBenchmarking.ForwardDiffAD{40}(),
TuringBenchmarking.ReverseDiffAD{true}(),
TuringBenchmarking.ReverseDiffAD{false}(),
TuringBenchmarking.ZygoteAD(),
]
ADBACKENDS = TuringBenchmarking.DEFAULT_ADBACKENDS

@testset "TuringBenchmarking.jl" begin
@testset "Item-Response model" begin
Expand Down Expand Up @@ -68,7 +63,7 @@ ADBACKENDS = [
)
results = run(suite, verbose=true)

for (i, adbackend) in enumerate(ADBACKENDS)
@testset "$adbackend" for (i, adbackend) in enumerate(ADBACKENDS)
adbackend_string = "$(adbackend)"
results_backend = results[@tagged adbackend_string]
# Each AD backend should have two results.
Expand All @@ -90,7 +85,7 @@ ADBACKENDS = [
)
results = run(suite, verbose=true)

for (i, adbackend) in enumerate(ADBACKENDS)
@testset "$adbackend" for (i, adbackend) in enumerate(ADBACKENDS)
adbackend_string = "$(adbackend)"
results_backend = results[@tagged adbackend_string]
# Each AD backend should have two results.
Expand Down Expand Up @@ -127,7 +122,7 @@ ADBACKENDS = [
)
results = run(suite, verbose=true)

for (i, adbackend) in enumerate(ADBACKENDS)
@testset "$adbackend" for (i, adbackend) in enumerate(ADBACKENDS)
adbackend_string = "$(adbackend)"
results_backend = results[@tagged adbackend_string]
if adbackend isa TuringBenchmarking.ZygoteAD
Expand Down

0 comments on commit 9b2e59e

Please sign in to comment.