From a88d844d9ff70f5d32b79725f9ac722758e0b506 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Wed, 1 Jan 2025 17:17:55 -0500 Subject: [PATCH 01/11] Pipeline for enzyme --- src/Compiler.jl | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index 3c4c3996d4..18ae00e3fa 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -112,9 +112,9 @@ function create_result( return Meta.quot(tocopy) end -const opt_passes::String = join( +# Optimization passes which apply to an individual function +const func_passes::String = join( [ - "inline{default-pipeline=canonicalize max-iterations=4}", "canonicalize,cse", "canonicalize", "enzyme-hlo-generate-td{" * @@ -273,9 +273,19 @@ const opt_passes::String = join( "transform-interpreter", "enzyme-hlo-remove-transform", ], + "," +) + +const opt_passes::String = join( + [ + "inline{default-pipeline=canonicalize max-iterations=4}", + func_passes + ], ',', ) +const enzyme_pass::String = "enzyme{postpasses=\"canonicalize,$func_passes,remove-unnecessary-enzyme-ops,enzyme-simplify-math,$func_passes\"}" + function run_pass_pipeline!(mod, pass_pipeline; enable_verifier=true) pm = MLIR.IR.PassManager() MLIR.IR.enable_verifier!(pm, enable_verifier) @@ -335,7 +345,7 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true) kern = "lower-kernel{run_init=true toolkitPath=$toolkit cuLaunchKernelPtr=$(cuLaunch[]) cuModuleLoadDataPtr=$(cuModule[]) cuModuleGetFunctionPtr=$(cuFunc[])}" if optimize === :all run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ",")) - run_pass_pipeline!(mod, "enzyme,arith-raise{stablehlo=true}"; enable_verifier=false) + run_pass_pipeline!(mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false) run_pass_pipeline!( mod, join( @@ -351,7 +361,7 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true) ) elseif optimize === :before_kernel run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ",")) - run_pass_pipeline!(mod, "enzyme,arith-raise{stablehlo=true}"; enable_verifier=false) + run_pass_pipeline!(mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false) run_pass_pipeline!( mod, join( @@ -381,7 +391,7 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true) ) elseif optimize === :only_enzyme run_pass_pipeline!(mod, "enzyme-batch") - run_pass_pipeline!(mod, "enzyme,arith-raise{stablehlo=true}"; enable_verifier=false) + run_pass_pipeline!(mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false) run_pass_pipeline!( mod, join( @@ -391,7 +401,7 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true) ) elseif optimize === :after_enzyme run_pass_pipeline!(mod, "enzyme-batch") - run_pass_pipeline!(mod, "enzyme,arith-raise{stablehlo=true}"; enable_verifier=false) + run_pass_pipeline!(mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false) run_pass_pipeline!( mod, join( @@ -407,7 +417,7 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true) ) elseif optimize === :before_enzyme run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ",")) - run_pass_pipeline!(mod, "enzyme,arith-raise{stablehlo=true}"; enable_verifier=false) + run_pass_pipeline!(mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false) run_pass_pipeline!( mod, "canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math," * kern ) From e21ae2f8220853325cfe12c29d912b4346fd554d Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 1 Jan 2025 20:02:05 -0500 Subject: [PATCH 02/11] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/Compiler.jl | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index 18ae00e3fa..44ee96ba9e 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -273,15 +273,11 @@ const func_passes::String = join( "transform-interpreter", "enzyme-hlo-remove-transform", ], - "," + ",", ) const opt_passes::String = join( - [ - "inline{default-pipeline=canonicalize max-iterations=4}", - func_passes - ], - ',', + ["inline{default-pipeline=canonicalize max-iterations=4}", func_passes], ',' ) const enzyme_pass::String = "enzyme{postpasses=\"canonicalize,$func_passes,remove-unnecessary-enzyme-ops,enzyme-simplify-math,$func_passes\"}" @@ -345,7 +341,9 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true) kern = "lower-kernel{run_init=true toolkitPath=$toolkit cuLaunchKernelPtr=$(cuLaunch[]) cuModuleLoadDataPtr=$(cuModule[]) cuModuleGetFunctionPtr=$(cuFunc[])}" if optimize === :all run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ",")) - run_pass_pipeline!(mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false) + run_pass_pipeline!( + mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false + ) run_pass_pipeline!( mod, join( @@ -361,7 +359,9 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true) ) elseif optimize === :before_kernel run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ",")) - run_pass_pipeline!(mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false) + run_pass_pipeline!( + mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false + ) run_pass_pipeline!( mod, join( @@ -391,7 +391,9 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true) ) elseif optimize === :only_enzyme run_pass_pipeline!(mod, "enzyme-batch") - run_pass_pipeline!(mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false) + run_pass_pipeline!( + mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false + ) run_pass_pipeline!( mod, join( @@ -401,7 +403,9 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true) ) elseif optimize === :after_enzyme run_pass_pipeline!(mod, "enzyme-batch") - run_pass_pipeline!(mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false) + run_pass_pipeline!( + mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false + ) run_pass_pipeline!( mod, join( @@ -417,7 +421,9 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true) ) elseif optimize === :before_enzyme run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ",")) - run_pass_pipeline!(mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false) + run_pass_pipeline!( + mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false + ) run_pass_pipeline!( mod, "canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math," * kern ) From 67777f9fd118a8b06988db7dad96e43e5f413e66 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Wed, 1 Jan 2025 20:16:22 -0500 Subject: [PATCH 03/11] Nested AD --- Project.toml | 2 +- src/Compiler.jl | 20 +++++++++++++++----- test/autodiff.jl | 10 ++++++++++ 3 files changed, 26 insertions(+), 6 deletions(-) diff --git a/Project.toml b/Project.toml index 8c0a746108..c7b25d4c33 100644 --- a/Project.toml +++ b/Project.toml @@ -62,7 +62,7 @@ PythonCall = "0.9" Random = "1.10" Random123 = "1.7" ReactantCore = "0.1.3" -Reactant_jll = "0.0.34" +Reactant_jll = "0.0.35" Scratch = "1.2" SpecialFunctions = "2" Statistics = "1.10" diff --git a/src/Compiler.jl b/src/Compiler.jl index 44ee96ba9e..633aafdcb7 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -112,11 +112,9 @@ function create_result( return Meta.quot(tocopy) end -# Optimization passes which apply to an individual function -const func_passes::String = join( +# Optimization passes via transform dialect +const transform_passes::String = join( [ - "canonicalize,cse", - "canonicalize", "enzyme-hlo-generate-td{" * join( [ @@ -272,6 +270,16 @@ const func_passes::String = join( "}", "transform-interpreter", "enzyme-hlo-remove-transform", + ], + "," +) + +# Optimization passes which apply to an individual function +const func_passes::String = join( + [ + "canonicalize,cse", + "canonicalize", + transform_passes ], ",", ) @@ -280,7 +288,9 @@ const opt_passes::String = join( ["inline{default-pipeline=canonicalize max-iterations=4}", func_passes], ',' ) -const enzyme_pass::String = "enzyme{postpasses=\"canonicalize,$func_passes,remove-unnecessary-enzyme-ops,enzyme-simplify-math,$func_passes\"}" +# TODO we want to be able to run the more advanced passes via transform dialect as an enzyme intermediate +# However, this errs as we cannot attach the transform with to the funcop itself [as we run a functionpass]. +const enzyme_pass::String = "enzyme{postpasses=\"canonicalize,cse,canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,canonicalize,cse,canonicalize\"}" function run_pass_pipeline!(mod, pass_pipeline; enable_verifier=true) pm = MLIR.IR.PassManager() diff --git a/test/autodiff.jl b/test/autodiff.jl index 044799bcb1..fe7501332f 100644 --- a/test/autodiff.jl +++ b/test/autodiff.jl @@ -120,3 +120,13 @@ end @test stret.st2 ≈ x .+ 1 @test stret.st1 === stret.st2 end + +@testset "Nested AD" begin + x = ConcreteRNumber(3.1) + f(x) = x^4 + df(x) = Enzyme.gradient(Reverse, f, x)[1] + @test @jit df(x) ≈ 4 * 3.1 ^ 3 + ddf(x) = Enzyme.gradient(Reverse, df, x)[1] + ddf(x) = Enzyme.gradient(Reverse, df, x)[1] + @test @jit ddf(x) ≈ 4 * 3 * 3.1 ^ 2 +end From 34dc7dfa92a886c682fd9835beb5923226deb71f Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 1 Jan 2025 20:18:03 -0500 Subject: [PATCH 04/11] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/Compiler.jl | 4 ++-- test/autodiff.jl | 14 +++++++------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index 633aafdcb7..dd7fd19536 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -270,8 +270,8 @@ const transform_passes::String = join( "}", "transform-interpreter", "enzyme-hlo-remove-transform", - ], - "," + ], + ",", ) # Optimization passes which apply to an individual function diff --git a/test/autodiff.jl b/test/autodiff.jl index fe7501332f..7154229926 100644 --- a/test/autodiff.jl +++ b/test/autodiff.jl @@ -122,11 +122,11 @@ end end @testset "Nested AD" begin - x = ConcreteRNumber(3.1) - f(x) = x^4 - df(x) = Enzyme.gradient(Reverse, f, x)[1] - @test @jit df(x) ≈ 4 * 3.1 ^ 3 - ddf(x) = Enzyme.gradient(Reverse, df, x)[1] - ddf(x) = Enzyme.gradient(Reverse, df, x)[1] - @test @jit ddf(x) ≈ 4 * 3 * 3.1 ^ 2 + x = ConcreteRNumber(3.1) + f(x) = x^4 + df(x) = Enzyme.gradient(Reverse, f, x)[1] + @test @jit df(x) ≈ 4 * 3.1^3 + ddf(x) = Enzyme.gradient(Reverse, df, x)[1] + ddf(x) = Enzyme.gradient(Reverse, df, x)[1] + @test @jit ddf(x) ≈ 4 * 3 * 3.1^2 end From 8109df5cfe409ef80bf5763319234be9490fbf97 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 1 Jan 2025 20:22:09 -0500 Subject: [PATCH 05/11] Update Compiler.jl --- src/Compiler.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index dd7fd19536..69738c5b8b 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -290,7 +290,7 @@ const opt_passes::String = join( # TODO we want to be able to run the more advanced passes via transform dialect as an enzyme intermediate # However, this errs as we cannot attach the transform with to the funcop itself [as we run a functionpass]. -const enzyme_pass::String = "enzyme{postpasses=\"canonicalize,cse,canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,canonicalize,cse,canonicalize\"}" +const enzyme_pass::String = "enzyme{postpasses=\"arith-raise{stablehlo=true},canonicalize,cse,canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,canonicalize,cse,canonicalize\"}" function run_pass_pipeline!(mod, pass_pipeline; enable_verifier=true) pm = MLIR.IR.PassManager() From a54ef86e06d06e08d357d025de7cde7d80e15cb3 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 1 Jan 2025 20:36:19 -0500 Subject: [PATCH 06/11] Update src/Compiler.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/Compiler.jl | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index 69738c5b8b..bc4e232f28 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -276,12 +276,7 @@ const transform_passes::String = join( # Optimization passes which apply to an individual function const func_passes::String = join( - [ - "canonicalize,cse", - "canonicalize", - transform_passes - ], - ",", + ["canonicalize,cse", "canonicalize", transform_passes], "," ) const opt_passes::String = join( From 42f2e2b7c9504118ca7bb1d0da6359fee00adbb9 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 1 Jan 2025 23:57:01 -0500 Subject: [PATCH 07/11] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index c7b25d4c33..d5e57ef820 100644 --- a/Project.toml +++ b/Project.toml @@ -62,7 +62,7 @@ PythonCall = "0.9" Random = "1.10" Random123 = "1.7" ReactantCore = "0.1.3" -Reactant_jll = "0.0.35" +Reactant_jll = "0.0.36" Scratch = "1.2" SpecialFunctions = "2" Statistics = "1.10" From 976f083c26d0059aaa1f5a36abbfb7e93f966085 Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 2 Jan 2025 00:22:54 -0500 Subject: [PATCH 08/11] Update autodiff.jl --- test/autodiff.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/autodiff.jl b/test/autodiff.jl index 7154229926..758a6caf46 100644 --- a/test/autodiff.jl +++ b/test/autodiff.jl @@ -123,10 +123,9 @@ end @testset "Nested AD" begin x = ConcreteRNumber(3.1) - f(x) = x^4 + f(x) = x*x*x*x df(x) = Enzyme.gradient(Reverse, f, x)[1] @test @jit df(x) ≈ 4 * 3.1^3 ddf(x) = Enzyme.gradient(Reverse, df, x)[1] - ddf(x) = Enzyme.gradient(Reverse, df, x)[1] @test @jit ddf(x) ≈ 4 * 3 * 3.1^2 end From 44f6f9dc39ce66e718d2aa9ea17c416082b4b3ee Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 2 Jan 2025 00:24:15 -0500 Subject: [PATCH 09/11] Update test/autodiff.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/autodiff.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/autodiff.jl b/test/autodiff.jl index 758a6caf46..93e92caa0b 100644 --- a/test/autodiff.jl +++ b/test/autodiff.jl @@ -123,7 +123,7 @@ end @testset "Nested AD" begin x = ConcreteRNumber(3.1) - f(x) = x*x*x*x + f(x) = x * x * x * x df(x) = Enzyme.gradient(Reverse, f, x)[1] @test @jit df(x) ≈ 4 * 3.1^3 ddf(x) = Enzyme.gradient(Reverse, df, x)[1] From 481064cb1a69e35b86026fad5a4b244b37dbd271 Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 2 Jan 2025 00:36:13 -0500 Subject: [PATCH 10/11] Update autodiff.jl --- test/autodiff.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/autodiff.jl b/test/autodiff.jl index 93e92caa0b..0b759db0b4 100644 --- a/test/autodiff.jl +++ b/test/autodiff.jl @@ -125,7 +125,9 @@ end x = ConcreteRNumber(3.1) f(x) = x * x * x * x df(x) = Enzyme.gradient(Reverse, f, x)[1] - @test @jit df(x) ≈ 4 * 3.1^3 + res1 = @jit df(x) + @test res1 ≈ 4 * 3.1^3 ddf(x) = Enzyme.gradient(Reverse, df, x)[1] - @test @jit ddf(x) ≈ 4 * 3 * 3.1^2 + res2 = @jit ddf(x) + @test res2 ≈ 4 * 3 * 3.1^2 end From 64e2adfce7efd1ebfa1ecc7f79d2c54534db0939 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Thu, 2 Jan 2025 01:26:56 -0500 Subject: [PATCH 11/11] fixbug --- src/Reactant.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Reactant.jl b/src/Reactant.jl index 0919fd9aaf..039a717a93 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -213,7 +213,7 @@ include("Overlay.jl") function Enzyme.make_zero( ::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false) -)::RT where {copy_if_inactive,RT<:RArray} +)::RT where {copy_if_inactive,RT<:Union{RArray,RNumber}} if haskey(seen, prev) return seen[prev] end