Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 33 additions & 5 deletions benchmarks/benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,15 @@ function run(; to_json=false)
}[]

for (model_name, model, varinfo_choice, adbackend, islinked) in chosen_combinations
@info "Running benchmark for $model_name"
@info "Running benchmark for $model_name, $varinfo_choice, $adbackend, $islinked"
relative_eval_time, relative_ad_eval_time = try
results = benchmark(model, varinfo_choice, adbackend, islinked)
@info " t(eval) = $(results.primal_time)"
@info " t(grad) = $(results.grad_time)"
(results.primal_time / reference_time),
(results.grad_time / results.primal_time)
catch e
@info "benchmark errored: $e"
missing, missing
end
push!(
Expand Down Expand Up @@ -155,18 +158,33 @@ function combine(head_filename::String, base_filename::String)
all_testcases = union(Set(keys(head_testcases)), Set(keys(base_testcases)))
@info "$(length(all_testcases)) unique test cases found"
sorted_testcases = sort(
collect(all_testcases); by=(c -> (c.model_name, c.ad_backend, c.varinfo, c.linked))
collect(all_testcases); by=(c -> (c.model_name, c.linked, c.varinfo, c.ad_backend))
)
results_table = Tuple{
String,Int,String,String,Bool,String,String,String,String,String,String
String,
Int,
String,
String,
Bool,
String,
String,
String,
String,
String,
String,
String,
String,
String,
}[]
sublabels = ["base", "this PR", "speedup"]
results_colnames = [
[
EmptyCells(5),
MultiColumn(3, "t(eval) / t(ref)"),
MultiColumn(3, "t(grad) / t(eval)"),
MultiColumn(3, "t(grad) / t(ref)"),
],
[colnames[1:5]..., "base", "this PR", "speedup", "base", "this PR", "speedup"],
[colnames[1:5]..., sublabels..., sublabels..., sublabels...],
]
sprint_float(x::Float64) = @sprintf("%.2f", x)
sprint_float(m::Missing) = "err"
Expand All @@ -183,6 +201,10 @@ function combine(head_filename::String, base_filename::String)
# Finally that lets us do this division safely
speedup_eval = base_eval / head_eval
speedup_grad = base_grad / head_grad
# As well as this multiplication, which is t(grad) / t(ref)
head_grad_vs_ref = head_grad * head_eval
base_grad_vs_ref = base_grad * base_eval
speedup_grad_vs_ref = base_grad_vs_ref / head_grad_vs_ref
push!(
results_table,
(
Expand All @@ -197,6 +219,9 @@ function combine(head_filename::String, base_filename::String)
sprint_float(base_grad),
sprint_float(head_grad),
sprint_float(speedup_grad),
sprint_float(base_grad_vs_ref),
sprint_float(head_grad_vs_ref),
sprint_float(speedup_grad_vs_ref),
),
)
end
Expand All @@ -212,7 +237,10 @@ function combine(head_filename::String, base_filename::String)
backend=:text,
fit_table_in_display_horizontally=false,
fit_table_in_display_vertically=false,
table_format=TextTableFormat(; horizontal_line_at_merged_column_labels=true),
table_format=TextTableFormat(;
horizontal_line_at_merged_column_labels=true,
horizontal_lines_at_data_rows=collect(3:3:length(results_table)),
),
)
println("```")
end
Expand Down
10 changes: 9 additions & 1 deletion src/test_utils/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@ using Chairmarks: @be
import DifferentiationInterface as DI
using DocStringExtensions
using DynamicPPL:
Model, LogDensityFunction, VarInfo, AbstractVarInfo, getlogjoint_internal, link
DynamicPPL,
Model,
LogDensityFunction,
VarInfo,
AbstractVarInfo,
getlogjoint_internal,
link
using LogDensityProblems: logdensity, logdensity_and_gradient
using Random: AbstractRNG, default_rng
using Statistics: median
Expand Down Expand Up @@ -298,7 +304,9 @@ function run_ad(

# Benchmark
grad_time, primal_time = if benchmark
logdensity(ldf, params) # Warm-up
primal_benchmark = @be logdensity($ldf, $params)
logdensity_and_gradient(ldf, params) # Warm-up
grad_benchmark = @be logdensity_and_gradient($ldf, $params)
median_primal = median(primal_benchmark).time
median_grad = median(grad_benchmark).time
Expand Down