diff --git a/Project.toml b/Project.toml index 8c0a746108..d5e57ef820 100644 --- a/Project.toml +++ b/Project.toml @@ -62,7 +62,7 @@ PythonCall = "0.9" Random = "1.10" Random123 = "1.7" ReactantCore = "0.1.3" -Reactant_jll = "0.0.34" +Reactant_jll = "0.0.36" Scratch = "1.2" SpecialFunctions = "2" Statistics = "1.10" diff --git a/src/Compiler.jl b/src/Compiler.jl index 3c4c3996d4..bc4e232f28 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -112,11 +112,9 @@ function create_result( return Meta.quot(tocopy) end -const opt_passes::String = join( +# Optimization passes via transform dialect +const transform_passes::String = join( [ - "inline{default-pipeline=canonicalize max-iterations=4}", - "canonicalize,cse", - "canonicalize", "enzyme-hlo-generate-td{" * join( [ @@ -273,9 +271,22 @@ const opt_passes::String = join( "transform-interpreter", "enzyme-hlo-remove-transform", ], - ',', + ",", +) + +# Optimization passes which apply to an individual function +const func_passes::String = join( + ["canonicalize,cse", "canonicalize", transform_passes], "," ) +const opt_passes::String = join( + ["inline{default-pipeline=canonicalize max-iterations=4}", func_passes], ',' +) + +# TODO we want to be able to run the more advanced passes via transform dialect as an enzyme intermediate +# However, this errs as we cannot attach the transform to the funcop itself [as we run a functionpass]. 
+const enzyme_pass::String = "enzyme{postpasses=\"arith-raise{stablehlo=true},canonicalize,cse,canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,canonicalize,cse,canonicalize\"}" + function run_pass_pipeline!(mod, pass_pipeline; enable_verifier=true) pm = MLIR.IR.PassManager() MLIR.IR.enable_verifier!(pm, enable_verifier) @@ -335,7 +346,9 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true) kern = "lower-kernel{run_init=true toolkitPath=$toolkit cuLaunchKernelPtr=$(cuLaunch[]) cuModuleLoadDataPtr=$(cuModule[]) cuModuleGetFunctionPtr=$(cuFunc[])}" if optimize === :all run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ",")) - run_pass_pipeline!(mod, "enzyme,arith-raise{stablehlo=true}"; enable_verifier=false) + run_pass_pipeline!( + mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false + ) run_pass_pipeline!( mod, join( @@ -351,7 +364,9 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true) ) elseif optimize === :before_kernel run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ",")) - run_pass_pipeline!(mod, "enzyme,arith-raise{stablehlo=true}"; enable_verifier=false) + run_pass_pipeline!( + mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false + ) run_pass_pipeline!( mod, join( @@ -381,7 +396,9 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true) ) elseif optimize === :only_enzyme run_pass_pipeline!(mod, "enzyme-batch") - run_pass_pipeline!(mod, "enzyme,arith-raise{stablehlo=true}"; enable_verifier=false) + run_pass_pipeline!( + mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false + ) run_pass_pipeline!( mod, join( @@ -391,7 +408,9 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true) ) elseif optimize === :after_enzyme run_pass_pipeline!(mod, "enzyme-batch") - run_pass_pipeline!(mod, "enzyme,arith-raise{stablehlo=true}"; enable_verifier=false) + run_pass_pipeline!( + mod, 
"$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false + ) run_pass_pipeline!( mod, join( @@ -407,7 +426,9 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true) ) elseif optimize === :before_enzyme run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ",")) - run_pass_pipeline!(mod, "enzyme,arith-raise{stablehlo=true}"; enable_verifier=false) + run_pass_pipeline!( + mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false + ) run_pass_pipeline!( mod, "canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math," * kern ) diff --git a/src/Reactant.jl b/src/Reactant.jl index 0919fd9aaf..039a717a93 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -213,7 +213,7 @@ include("Overlay.jl") function Enzyme.make_zero( ::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false) -)::RT where {copy_if_inactive,RT<:RArray} +)::RT where {copy_if_inactive,RT<:Union{RArray,RNumber}} if haskey(seen, prev) return seen[prev] end diff --git a/test/autodiff.jl b/test/autodiff.jl index 044799bcb1..0b759db0b4 100644 --- a/test/autodiff.jl +++ b/test/autodiff.jl @@ -120,3 +120,14 @@ end @test stret.st2 ≈ x .+ 1 @test stret.st1 === stret.st2 end + +@testset "Nested AD" begin + x = ConcreteRNumber(3.1) + f(x) = x * x * x * x + df(x) = Enzyme.gradient(Reverse, f, x)[1] + res1 = @jit df(x) + @test res1 ≈ 4 * 3.1^3 + ddf(x) = Enzyme.gradient(Reverse, df, x)[1] + res2 = @jit ddf(x) + @test res2 ≈ 4 * 3 * 3.1^2 +end