From 3bd93fe986c2249657c7ac9f3cada09d49a0f1d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Fri, 24 Jan 2025 07:46:54 +0100 Subject: [PATCH 1/2] Fix #593 --- src/Interpreter.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/Interpreter.jl b/src/Interpreter.jl index 7065319d71..449f78de38 100644 --- a/src/Interpreter.jl +++ b/src/Interpreter.jl @@ -312,7 +312,9 @@ function overload_autodiff( act = act_from_type(A, reverse, needs_primal(CMode)) push!(ret_activity, act) if act == enzyme_out || act == enzyme_outnoneed - attr = fill(MLIR.IR.Attribute(unwrapped_eltype(a)(1)), Ops.mlir_type(a)) + attr = MLIR.IR.DenseElementsAttribute( + fill(one(unwrapped_eltype(a)), size(a)) + ) cst = MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1) push!(ad_inputs, cst) end From bc0a683eb5cafa37189bb22a974a7fe7ca77429e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Fri, 24 Jan 2025 11:18:01 +0100 Subject: [PATCH 2/2] Add test --- test/autodiff.jl | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/test/autodiff.jl b/test/autodiff.jl index 0b759db0b4..21e5525eb8 100644 --- a/test/autodiff.jl +++ b/test/autodiff.jl @@ -131,3 +131,15 @@ end res2 = @jit ddf(x) @test res2 ≈ 4 * 3 * 3.1^2 end + +@testset "Seed initialization of Complex arrays on matmul: Issue #593" begin + a = ones(ComplexF64, 2, 2) + b = 2.0 * ones(ComplexF64, 2, 2) + a_re = Reactant.to_rarray(a) + b_re = Reactant.to_rarray(b) + df(x, y) = Enzyme.gradient(ReverseWithPrimal, *, x, y) + res = @jit df(a_re, b_re) # before, this segfaulted + @test res.val ≈ 4ones(2, 2) + @test res.derivs[1] ≈ 4ones(2, 2) + @test res.derivs[2] ≈ 2ones(2, 2) +end