Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve recursion performance #1439

Merged
merged 1 commit into from
May 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions src/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -844,6 +844,7 @@ end
function remove_readonly_unused_calls!(fn::LLVM.Function, next::Set{String})
calls = LLVM.CallInst[]

hasUser = false
for u in LLVM.uses(fn)
un = LLVM.user(u)

Expand All @@ -862,13 +863,11 @@ function remove_readonly_unused_calls!(fn::LLVM.Function, next::Set{String})

# Something with a user is not permitted
for u2 in LLVM.uses(un)
return false
hasUser = true
break
end
push!(calls, un)
end
if length(calls) == 0
return false
end

done = Set{LLVM.Function}()
todo = LLVM.Function[fn]
Expand Down Expand Up @@ -909,6 +908,22 @@ function remove_readonly_unused_calls!(fn::LLVM.Function, next::Set{String})
end
end
end

changed = false
attrs = collect(function_attributes(fn))
if !any(kind(attr) == kind(EnumAttribute("readonly")) for attr in attrs) && !any(kind(attr) == kind(EnumAttribute("readnone")) for attr in attrs)
if any(kind(attr) == kind(EnumAttribute("writeonly")) for attr in attrs)
delete!(function_attributes(fn), EnumAttribute("writeonly"))
push!(function_attributes(fn), EnumAttribute("readnone"))
else
push!(function_attributes(fn), EnumAttribute("readonly"))
end
changed = true
end

if length(calls) == 0 || hasUser
return changed
end

for c in calls
parentf = LLVM.parent(LLVM.parent(c))
Expand Down
14 changes: 14 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -196,13 +196,27 @@ end
end
end

sumsq2(x) = sum(abs2, x)
sumsin(x) = sum(sin, x)
@testset "Recursion optimization" begin
# Test that we can successfully optimize out the augmented primal from the recursive divide and conquer
fn = sprint() do io
Enzyme.Compiler.enzyme_code_llvm(io, sum, Active, Tuple{Duplicated{Vector{Float64}}})
end
@test occursin("diffe",fn)
@test !occursin("aug",fn)

fn = sprint() do io
Enzyme.Compiler.enzyme_code_llvm(io, sumsq2, Active, Tuple{Duplicated{Vector{Float64}}})
end
@test occursin("diffe",fn)
@test !occursin("aug",fn)

fn = sprint() do io
Enzyme.Compiler.enzyme_code_llvm(io, sumsin, Active, Tuple{Duplicated{Vector{Float64}}})
end
@test occursin("diffe",fn)
@test !occursin("aug",fn)
end

# @testset "Split Tape" begin
Expand Down
Loading