diff --git a/src/Compiler.jl b/src/Compiler.jl
index 278eab497c..7ecbacda92 100644
--- a/src/Compiler.jl
+++ b/src/Compiler.jl
@@ -1572,6 +1572,9 @@ function compile_mlir!(
     f,
     args,
     compile_options::CompileOptions,
+    # NOTE(review): placed ahead of `callcache`, so the existing positional cache
+    # arguments shift by one; confirm no caller passes them positionally.
+    elem_apply_cache=default_elem_apply_cache(),
     callcache=default_callcache(),
     sdycache=default_sdycache(),
     sdygroupidcache=default_sdygroupidcache();
@@ -1592,6 +1595,7 @@ function compile_mlir!(
     activate_callcache!(callcache)
     activate_sdycache!(sdycache)
     activate_sdygroupidcache!(sdygroupidcache)
+    activate_elem_apply_cache!(elem_apply_cache)
 
     # Save in the TLS whether we are raising. We identify that condition by
     # checking whether the user set an explicit list of passes, or chose
@@ -1620,6 +1624,7 @@ function compile_mlir!(
         deactivate_sdycache!(sdycache)
         deactivate_sdygroupidcache!(sdygroupidcache)
         deactivate_callcache!(callcache)
+        deactivate_elem_apply_cache!(elem_apply_cache)
         MLIR.IR.deactivate!(MLIR.IR.body(mod))
         MLIR.IR.deactivate!(mod)
     end
@@ -3832,7 +3837,7 @@ function register_thunk(
     )
 end
 
-for cache_type in (:callcache, :sdycache, :sdygroupidcache)
+for cache_type in (:callcache, :sdycache, :sdygroupidcache, :elem_apply_cache)
     activate_fn = Symbol(:activate_, cache_type, :!)
     deactivate_fn = Symbol(:deactivate_, cache_type, :!)
     has_fn = Symbol(:_has_, cache_type)
@@ -3904,4 +3909,32 @@ function default_callcache()
     }()
 end
 
+"""
+    default_elem_apply_cache()
+
+Return an empty cache for the scalar functions traced by `TracedUtils.elem_apply`.
+
+Keys are the vectors that `TracedUtils.elem_apply` fills from `(f, args...)` by calling
+`Reactant.make_tracer` with `Reactant.TracedToTypes`. Each value stores the symbol name
+of the generated MLIR function (`f_name`), the tracing outputs (`result`, `seen_args`,
+`linear_args`, `linear_results`, `fnwrapped`) and the `argprefix`/`resprefix` symbols of
+that trace, so a later call with an equal key can emit the batch op without tracing `f`
+again.
+"""
+function default_elem_apply_cache()
+    return Dict{
+        Vector,
+        @NamedTuple{
+            f_name::String,
+            result::Any,
+            seen_args::OrderedIdDict,
+            linear_args::Vector,
+            linear_results::Vector{Reactant.TracedType},
+            fnwrapped::Bool,
+            argprefix::Symbol,
+            resprefix::Symbol,
+        }
+    }()
+end
+
 end
diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl
index 260e81e714..2aacc17d53 100644
--- a/src/TracedUtils.jl
+++ b/src/TracedUtils.jl
@@ -1145,24 +1145,59 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs}
         return elem_apply_via_while_loop(f, args...)
     end
 
-    argprefix::Symbol = gensym("broadcastarg")
-    resprefix::Symbol = gensym("broadcastresult")
-    resargprefix::Symbol = gensym("broadcastresarg")
+    # Key the cache on what `make_tracer` collects from `(f, args...)` when passed
+    # `Reactant.TracedToTypes`; an equal key reuses the scalar function traced earlier.
+    seen = Reactant.OrderedIdDict()
+    cache_key = []
+    Reactant.make_tracer(seen, (f, args...), cache_key, Reactant.TracedToTypes)
+    cache = Reactant.Compiler.elem_apply_cache()
+    if !haskey(cache, cache_key)
+        argprefix::Symbol = gensym("broadcastarg")
+        resprefix::Symbol = gensym("broadcastresult")
+        resargprefix::Symbol = gensym("broadcastresarg")
+
+        mlir_fn_res = make_mlir_fn(
+            f,
+            args,
+            (),
+            string(f) * "_broadcast_scalar",
+            false;
+            toscalar=true,
+            argprefix,
+            resprefix,
+            resargprefix,
+        )
+        (; fnwrapped, result, seen_args, linear_args, linear_results) = mlir_fn_res
 
-    mlir_fn_res = make_mlir_fn(
-        f,
-        args,
-        (),
-        string(f) * "_broadcast_scalar",
-        false;
-        toscalar=true,
-        argprefix,
-        resprefix,
-        resargprefix,
-    )
-    fnwrap = mlir_fn_res.fnwrapped
-    func2 = mlir_fn_res.f
-    (; result, seen_args, linear_args, linear_results) = mlir_fn_res
+        # Cache only the symbol name; the operation handle on `func2` is nulled below.
+        func2 = mlir_fn_res.f
+        f_name = Base.String(get_attribute_by_name(func2, "sym_name"))
+        func2.operation = MLIR.API.MlirOperation(C_NULL)
+
+        cache[cache_key] = (;
+            f_name,
+            result,
+            seen_args,
+            linear_args,
+            linear_results,
+            fnwrapped,
+            argprefix,
+            resprefix,
+        )
+    else
+        (;
+            f_name,
+            result,
+            seen_args,
+            linear_args,
+            linear_results,
+            fnwrapped,
+            argprefix,
+            resprefix,
+        ) = cache[cache_key]
+    end
+
+    fn_attr = MLIR.IR.FlatSymbolRefAttribute(f_name)
 
     invmap = IdDict()
     for (k, v) in seen_args
@@ -1182,17 +1217,14 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs}
         ) for arg in linear_results
     ]
 
-    fname = get_attribute_by_name(func2, "sym_name")
-    fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname))
-
     batch_inputs = MLIR.IR.Value[]
 
     for a in linear_args
         idx, path = get_argidx(a, argprefix)
-        if idx == 1 && fnwrap
+        if idx == 1 && fnwrapped
             push_val!(batch_inputs, f, path[3:end])
         else
-            if fnwrap
+            if fnwrapped
                 idx -= 1
             end
             push_val!(batch_inputs, args[idx], path[3:end])
@@ -1202,7 +1234,7 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs}
     res = MLIR.Dialects.enzyme.batch(
         batch_inputs;
         outputs=out_tys2,
-        fn=fname,
+        fn=fn_attr,
         batch_shape=MLIR.IR.DenseArrayAttribute([Int64(i) for i in OutShape]),
     )
 
@@ -1219,10 +1251,10 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs}
                 set!(result, path[2:end], resv)
             elseif path[1] == argprefix
                 idx = path[2]::Int
-                if idx == 1 && fnwrap
+                if idx == 1 && fnwrapped
                     set!(f, path[3:end], resv)
                 else
-                    if fnwrap
+                    if fnwrapped
                         idx -= 1
                     end
                     set!(args[idx], path[3:end], resv)
@@ -1236,7 +1268,5 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs}
         seen_results, result, (), Reactant.TracedSetPath; tobatch=OutShape
     )
 
-    func2.operation = MLIR.API.MlirOperation(C_NULL)
-
     return traced2_result
 end