Skip to content

Commit

Permalink
Improve recursion performance (#1439)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed May 13, 2024
1 parent c65a2f2 commit f4acb2b
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 4 deletions.
23 changes: 19 additions & 4 deletions src/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -844,6 +844,7 @@ end
function remove_readonly_unused_calls!(fn::LLVM.Function, next::Set{String})
calls = LLVM.CallInst[]

hasUser = false
for u in LLVM.uses(fn)
un = LLVM.user(u)

Expand All @@ -862,13 +863,11 @@ function remove_readonly_unused_calls!(fn::LLVM.Function, next::Set{String})

# Something with a user is not permitted
for u2 in LLVM.uses(un)
return false
hasUser = true
break
end
push!(calls, un)
end
if length(calls) == 0
return false
end

done = Set{LLVM.Function}()
todo = LLVM.Function[fn]
Expand Down Expand Up @@ -909,6 +908,22 @@ function remove_readonly_unused_calls!(fn::LLVM.Function, next::Set{String})
end
end
end

changed = false
attrs = collect(function_attributes(fn))
if !any(kind(attr) == kind(EnumAttribute("readonly")) for attr in attrs) && !any(kind(attr) == kind(EnumAttribute("readnone")) for attr in attrs)
if any(kind(attr) == kind(EnumAttribute("writeonly")) for attr in attrs)
delete!(function_attributes(fn), EnumAttribute("writeonly"))
push!(function_attributes(fn), EnumAttribute("readnone"))
else
push!(function_attributes(fn), EnumAttribute("readonly"))
end
changed = true
end

if length(calls) == 0 || hasUser
return changed
end

for c in calls
parentf = LLVM.parent(LLVM.parent(c))
Expand Down
14 changes: 14 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -196,13 +196,27 @@ end
end
end

# Reduce `x` by summing `abs2` of every element.
function sumsq2(x)
    return sum(abs2, x)
end
# Reduce `x` by summing `sin` of every element.
function sumsin(x)
    return sum(sin, x)
end
@testset "Recursion optimization" begin
    # For each reduction, print the code Enzyme generates and check that the
    # augmented primal from the recursive divide and conquer was optimized out:
    # the output must mention "diffe" and must not mention "aug".
    argtypes = Tuple{Duplicated{Vector{Float64}}}
    for f in (sum, sumsq2, sumsin)
        ir = sprint() do io
            Enzyme.Compiler.enzyme_code_llvm(io, f, Active, argtypes)
        end
        @test occursin("diffe", ir)
        @test !occursin("aug", ir)
    end
end

# @testset "Split Tape" begin
Expand Down

0 comments on commit f4acb2b

Please sign in to comment.