Closed
Description
When compiling with Reactant @jit, Enzyme.autodiff succeeds for a reduction (sum(abs2, x)) but fails for an elementwise square (x.^2). The non-compiled path works for both.
Minimal Repro
using Lux, Enzyme, Reactant, Random
sum_func(x) = sum(abs2, x)
square_func(x) = x.^2
x = collect(Float32, 1:4)
x_onehot = Enzyme.onehot(x)
x_ra = Reactant.to_rarray(x)
x_ra_onehot = @jit Enzyme.onehot(x_ra)
# Uncompiled: both succeed
Enzyme.autodiff(Forward, sum_func, BatchDuplicated(x, x_onehot))
Enzyme.autodiff(Forward, square_func, BatchDuplicated(x, x_onehot))
# Compiled:
@jit Enzyme.autodiff(Forward, sum_func, BatchDuplicated(x_ra, x_ra_onehot)) # works
@jit Enzyme.autodiff(Forward, square_func, BatchDuplicated(x_ra, x_ra_onehot)) # fails
Actual Behavior
Fails at compile time with:
ERROR: AssertionError: Invalid start indices: [0, -1]
Stacktrace:
[1] slice(x::Reactant.TracedRArray{…}, start_indices::Vector{…}, limit_indices::Vector{…}; strides::Nothing, location::Reactant.MLIR.IR.Location)
@ Reactant.Ops ~/.julia/packages/Reactant/OPcO3/src/Ops.jl:632
[2] overload_autodiff(::ForwardMode{…}, f::Const{…}, ::Type{…}, args::BatchDuplicated{…})
@ Reactant ~/.julia/packages/Reactant/OPcO3/src/Enzyme.jl:481
[3] autodiff(rmode::ForwardMode{…}, f::Const{…}, rt::Type{…}, args::BatchDuplicated{…})
@ Reactant ~/.julia/packages/Reactant/OPcO3/src/Overlay.jl:21
[4] autodiff
@ ~/.julia/packages/Enzyme/LJjsP/src/Enzyme.jl:562 [inlined]
[5] autodiff
@ ~/.julia/packages/Enzyme/LJjsP/src/Enzyme.jl:534 [inlined]
[6] (::Nothing)(none::typeof(autodiff), none::ForwardMode{…}, none::typeof(square_func), none::Tuple{…})
@ Reactant ./<missing>:0
[7] autodiff
@ ~/.julia/packages/Enzyme/LJjsP/src/Enzyme.jl:534 [inlined]
[8] call_with_reactant(::typeof(autodiff), ::ForwardMode{…}, ::typeof(square_func), ::BatchDuplicated{…})
@ Reactant ~/.julia/packages/Reactant/OPcO3/src/utils.jl:0
[9] make_mlir_fn(f::typeof(autodiff), args::Tuple{…}, kwargs::@NamedTuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Val{…}, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol, num_replicas::Int64, optimize_then_pad::Bool)
@ Reactant.TracedUtils ~/.julia/packages/Reactant/OPcO3/src/TracedUtils.jl:348
[10] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{…}, compile_options::CompileOptions, callcache::Dict{…}, sdycache::Dict{…}; fn_kwargs::@NamedTuple{}, backend::String, runtime::Val{…}, legalize_stablehlo_to_mhlo::Bool, kwargs::@Kwargs{})
@ Reactant.Compiler ~/.julia/packages/Reactant/OPcO3/src/Compiler.jl:1603
[11] compile_mlir! (repeats 2 times)
@ ~/.julia/packages/Reactant/OPcO3/src/Compiler.jl:1570 [inlined]
[12] compile_xla(f::Function, args::Tuple{…}; before_xla_optimizations::Bool, client::Nothing, serializable::Bool, kwargs::@Kwargs{…})
@ Reactant.Compiler ~/.julia/packages/Reactant/OPcO3/src/Compiler.jl:3492
[13] compile_xla
@ ~/.julia/packages/Reactant/OPcO3/src/Compiler.jl:3465 [inlined]
[14] compile(f::Function, args::Tuple{…}; kwargs::@Kwargs{…})
@ Reactant.Compiler ~/.julia/packages/Reactant/OPcO3/src/Compiler.jl:3567
[15] top-level scope
@ ~/.julia/packages/Reactant/OPcO3/src/Compiler.jl:2642
Some type information was truncated. Use `show(err)` to see complete types.
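For comparison, the tangents that the failing compiled call would be expected to produce can be worked out by hand. The following is a minimal plain-Julia sketch that uses neither Enzyme nor Reactant. The names `seeds`, `expected_tangents`, and `fd_tangents` are illustrative and not part of either library, and the one-hot seeds are built manually on the assumption that they match what `Enzyme.onehot(x)` supplies in the repro.

```julia
# Hand-computed forward-mode tangents for square_func(x) = x.^2,
# seeded with one-hot directions. No Enzyme or Reactant involved.
square_func(x) = x .^ 2

x = collect(Float32, 1:4)

# One-hot seed vectors e_1..e_4, built by hand (assumed to mirror Enzyme.onehot(x)).
seeds = ntuple(i -> Float32[j == i ? 1 : 0 for j in eachindex(x)], length(x))

# The Jacobian of x -> x.^2 is diagonal with entries 2x[i],
# so the tangent along seed e_i is 2 .* x .* e_i.
expected_tangents = map(e -> 2 .* x .* e, seeds)

# Cross-check with a central finite difference of square_func.
h = 1f-2
fd_tangents = map(e -> (square_func(x .+ h .* e) .- square_func(x .- h .* e)) ./ (2h), seeds)
```

Each tangent has a single nonzero entry, `2x[i]` at position `i`, which is the value the uncompiled `Enzyme.autodiff` call above should also report for the corresponding seed.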
Environment
julia> versioninfo()
Julia Version 1.11.7
Commit f2b3dbda30a (2025-09-08 12:10 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 32 × 13th Gen Intel(R) Core(TM) i9-13900K
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, alderlake)
Threads: 1 default, 0 interactive, 1 GC (on 32 virtual cores)
Environment:
JULIA_EDITOR = code
JULIA_VSCODE_REPL = 1
(jl_S3gBRa) pkg> st
Status `/tmp/jl_S3gBRa/Project.toml`
[7da242da] Enzyme v0.13.86
[b2108857] Lux v1.23.0
[3c362404] Reactant v0.2.170
[9a3f8284] Random v1.11.0