diff --git a/src/Interpreter.jl b/src/Interpreter.jl index 7065319d71..449f78de38 100644 --- a/src/Interpreter.jl +++ b/src/Interpreter.jl @@ -312,7 +312,9 @@ function overload_autodiff( act = act_from_type(A, reverse, needs_primal(CMode)) push!(ret_activity, act) if act == enzyme_out || act == enzyme_outnoneed - attr = fill(MLIR.IR.Attribute(unwrapped_eltype(a)(1)), Ops.mlir_type(a)) + attr = MLIR.IR.DenseElementsAttribute( + fill(one(unwrapped_eltype(a)), size(a)) + ) cst = MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1) push!(ad_inputs, cst) end diff --git a/test/autodiff.jl b/test/autodiff.jl index 0b759db0b4..21e5525eb8 100644 --- a/test/autodiff.jl +++ b/test/autodiff.jl @@ -131,3 +131,15 @@ end res2 = @jit ddf(x) @test res2 ≈ 4 * 3 * 3.1^2 end + +@testset "Seed initialization of Complex arrays on matmul: Issue #593" begin + a = ones(ComplexF64, 2, 2) + b = 2.0 * ones(ComplexF64, 2, 2) + a_re = Reactant.to_rarray(a) + b_re = Reactant.to_rarray(b) + df(x, y) = Enzyme.gradient(ReverseWithPrimal, *, x, y) + res = @jit df(a_re, b_re) # before, this segfaulted + @test res.val ≈ 4ones(2, 2) + @test res.derivs[1] ≈ 4ones(2, 2) + @test res.derivs[2] ≈ 2ones(2, 2) +end