Forward mode AD fails on vector-output functions #1749

@jacobleft

Description

When compiling with Reactant's @jit, forward-mode Enzyme.autodiff succeeds for a reduction (sum(abs2, x), scalar output) but fails for an elementwise square (x.^2, vector output). Without @jit, both calls succeed.

Minimal Repro

using Lux, Enzyme, Reactant, Random

sum_func(x) = sum(abs2, x)
square_func(x) = x.^2

x = collect(Float32, 1:4)
x_onehot = Enzyme.onehot(x)

x_ra = Reactant.to_rarray(x)
x_ra_onehot = @jit Enzyme.onehot(x_ra)

# Uncompiled: both succeed
Enzyme.autodiff(Forward, sum_func, BatchDuplicated(x, x_onehot))
Enzyme.autodiff(Forward, square_func, BatchDuplicated(x, x_onehot))

# Compiled:
@jit Enzyme.autodiff(Forward, sum_func, BatchDuplicated(x_ra, x_ra_onehot))   # works
@jit Enzyme.autodiff(Forward, square_func, BatchDuplicated(x_ra, x_ra_onehot)) # fails

Actual Behavior

Fails at compile time with:

ERROR: AssertionError: Invalid start indices: [0, -1]
Stacktrace:
  [1] slice(x::Reactant.TracedRArray{…}, start_indices::Vector{…}, limit_indices::Vector{…}; strides::Nothing, location::Reactant.MLIR.IR.Location)
    @ Reactant.Ops ~/.julia/packages/Reactant/OPcO3/src/Ops.jl:632
  [2] overload_autodiff(::ForwardMode{…}, f::Const{…}, ::Type{…}, args::BatchDuplicated{…})
    @ Reactant ~/.julia/packages/Reactant/OPcO3/src/Enzyme.jl:481
  [3] autodiff(rmode::ForwardMode{…}, f::Const{…}, rt::Type{…}, args::BatchDuplicated{…})
    @ Reactant ~/.julia/packages/Reactant/OPcO3/src/Overlay.jl:21
  [4] autodiff
    @ ~/.julia/packages/Enzyme/LJjsP/src/Enzyme.jl:562 [inlined]
  [5] autodiff
    @ ~/.julia/packages/Enzyme/LJjsP/src/Enzyme.jl:534 [inlined]
  [6] (::Nothing)(none::typeof(autodiff), none::ForwardMode{…}, none::typeof(square_func), none::Tuple{…})
    @ Reactant ./<missing>:0
  [7] autodiff
    @ ~/.julia/packages/Enzyme/LJjsP/src/Enzyme.jl:534 [inlined]
  [8] call_with_reactant(::typeof(autodiff), ::ForwardMode{…}, ::typeof(square_func), ::BatchDuplicated{…})
    @ Reactant ~/.julia/packages/Reactant/OPcO3/src/utils.jl:0
  [9] make_mlir_fn(f::typeof(autodiff), args::Tuple{…}, kwargs::@NamedTuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Val{…}, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol, num_replicas::Int64, optimize_then_pad::Bool)
    @ Reactant.TracedUtils ~/.julia/packages/Reactant/OPcO3/src/TracedUtils.jl:348
 [10] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{…}, compile_options::CompileOptions, callcache::Dict{…}, sdycache::Dict{…}; fn_kwargs::@NamedTuple{}, backend::String, runtime::Val{…}, legalize_stablehlo_to_mhlo::Bool, kwargs::@Kwargs{})
    @ Reactant.Compiler ~/.julia/packages/Reactant/OPcO3/src/Compiler.jl:1603
 [11] compile_mlir! (repeats 2 times)
    @ ~/.julia/packages/Reactant/OPcO3/src/Compiler.jl:1570 [inlined]
 [12] compile_xla(f::Function, args::Tuple{…}; before_xla_optimizations::Bool, client::Nothing, serializable::Bool, kwargs::@Kwargs{…})
    @ Reactant.Compiler ~/.julia/packages/Reactant/OPcO3/src/Compiler.jl:3492
 [13] compile_xla
    @ ~/.julia/packages/Reactant/OPcO3/src/Compiler.jl:3465 [inlined]
 [14] compile(f::Function, args::Tuple{…}; kwargs::@Kwargs{…})
    @ Reactant.Compiler ~/.julia/packages/Reactant/OPcO3/src/Compiler.jl:3567
 [15] top-level scope
    @ ~/.julia/packages/Reactant/OPcO3/src/Compiler.jl:2642
Some type information was truncated. Use `show(err)` to see complete types.
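For reference, the compiled calls should match the uncompiled results. The sketch below computes the expected forward-mode tangents by hand in plain Julia, without calling Enzyme or Reactant. It assumes that each one-hot seed v yields the directional derivative of the function along v. The names seeds, sum_tangents and square_tangents are illustrative only. The sketch shows how the two cases differ: sum_func produces one scalar tangent per seed, while square_func produces one length-4 vector per seed. The vector case appears to be the one the compiled path does not handle.

```julia
# Illustrative only: analytic forward-mode tangents for the repro functions.
# No Enzyme/Reactant calls; this does not reproduce Enzyme's return container.
x = collect(Float32, 1:4)
n = length(x)

# One-hot seed vectors, analogous to what Enzyme.onehot(x) provides
seeds = [Float32.((1:n) .== i) for i in 1:n]

# sum_func(x) = sum(abs2, x): the tangent along v is 2 * dot(x, v), a scalar per seed
sum_tangents = [2 * sum(x .* v) for v in seeds]

# square_func(x) = x.^2: the tangent along v is 2 .* x .* v, a vector per seed
square_tangents = [2 .* x .* v for v in seeds]
```

For example, sum_tangents evaluates to Float32[2, 4, 6, 8], and square_tangents[3] evaluates to Float32[0, 0, 6, 0].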

Environment

julia> versioninfo()
Julia Version 1.11.7
Commit f2b3dbda30a (2025-09-08 12:10 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 32 × 13th Gen Intel(R) Core(TM) i9-13900K
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, alderlake)
Threads: 1 default, 0 interactive, 1 GC (on 32 virtual cores)
Environment:
  JULIA_EDITOR = code
  JULIA_VSCODE_REPL = 1
(jl_S3gBRa) pkg> st
Status `/tmp/jl_S3gBRa/Project.toml`
  [7da242da] Enzyme v0.13.86
  [b2108857] Lux v1.23.0
  [3c362404] Reactant v0.2.170
  [9a3f8284] Random v1.11.0