diff --git a/src/utils.jl b/src/utils.jl index 83c9d51b64..aa00b8b20d 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -95,13 +95,16 @@ function should_rewrite_ft(@nospecialize(ft)) return false end if ft <: Core.Function - mod = ft.name.module - # Don't rewrite primitive ops, tracing utilities, or any MLIR-based functions - if has_ancestor(mod, Reactant.Ops) || - has_ancestor(mod, Reactant.TracedUtils) || - has_ancestor(mod, Reactant.MLIR) || - has_ancestor(mod, Reactant.TracedRandom) - return false + # We need this for closures to work + if hasfield(typeof(ft), :name) && hasfield(typeof(ft.name), :module) + mod = ft.name.module + # Don't rewrite primitive ops, tracing utilities, or any MLIR-based functions + if has_ancestor(mod, Reactant.Ops) || + has_ancestor(mod, Reactant.TracedUtils) || + has_ancestor(mod, Reactant.MLIR) || + has_ancestor(mod, Reactant.TracedRandom) + return false + end end end # Don't rewrite Val diff --git a/test/compile.jl b/test/compile.jl index f50211b011..1f780a4ef8 100644 --- a/test/compile.jl +++ b/test/compile.jl @@ -127,3 +127,22 @@ end @test !occursin("subtract", repr(hlo)) @test !occursin("add", repr(hlo)) end + +# While a bit specific, the following is used to check for a bug in `should_rewrite_ft` +function sinusoidal_embedding( + x::AbstractArray{T,4}, min_freq, max_freq, embedding_dims::Int +) where {T} + if size(x)[1:3] != (1, 1, 1) + throw(DimensionMismatch("Input shape must be (1, 1, 1, batch)")) + end + + lower, upper = log(T(min_freq)), log(T(max_freq)) + n = embedding_dims รท 2 + x_ = 2 .* x .* exp.(reshape(range(lower, upper; length=n), 1, 1, n, 1)) + return cat(sinpi.(x_), cospi.(x_); dims=Val(3)) +end + +@testset "sinusoidal_embedding" begin + x_ra = Reactant.to_rarray(rand(Float32, 1, 1, 1, 4)) + hlo = @code_hlo sinusoidal_embedding(x_ra, 0.1, 10.0, 4) +end