diff --git a/src/Ops.jl b/src/Ops.jl index 82888608b4..204ac53d77 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -5,7 +5,13 @@ module Ops using ..MLIR: MLIR using ..MLIR.Dialects: stablehlo, chlo, enzyme using ..Reactant: - Reactant, ConcreteRArray, ConcreteRNumber, TracedRArray, TracedRNumber, mlir_type + Reactant, + ConcreteRArray, + ConcreteRNumber, + TracedRArray, + TracedRNumber, + mlir_type, + mlir_stacktrace struct Token mlir_data::MLIR.IR.Value @@ -13,10 +19,7 @@ end # constant ops function constant( - x::DenseArray{T,N}; - location=MLIR.IR.Location( - "stablehlo.constant", MLIR.IR.Location(@__FILE__, @__LINE__, 0) - ), + x::DenseArray{T,N}; location=mlir_stacktrace("constant", @__FILE__, @__LINE__) ) where {T,N} value = MLIR.IR.DenseElementsAttribute(x) output = mlir_type(TracedRArray{T,N}, size(x)) @@ -29,20 +32,14 @@ function constant(x::ConcreteRArray; kwargs...) end function constant( - x::T; - location=MLIR.IR.Location( - "stablehlo.constant", MLIR.IR.Location(@__FILE__, @__LINE__, 0) - ), + x::T; location=mlir_stacktrace("constant", @__FILE__, @__LINE__) ) where {T<:Number} res = constant(fill(x); location) return TracedRNumber{T}((), res.mlir_data) end function constant( - x::ConcreteRNumber{T}; - location=MLIR.IR.Location( - "stablehlo.constant", MLIR.IR.Location(@__FILE__, @__LINE__, 0) - ), + x::ConcreteRNumber{T}; location=mlir_stacktrace("constant", @__FILE__, @__LINE__) ) where {T} output = mlir_type(TracedRArray{T,0}, ()) value = MLIR.IR.DenseElementsAttribute(fill(MLIR.IR.Attribute(convert(T, x)), output)) @@ -93,10 +90,7 @@ for (dialect, op) in [ @eval begin function $op( x::TracedRArray{T,N}; - location=MLIR.IR.Location( - $(string(Symbol(dialect, :., op))), - MLIR.IR.Location(@__FILE__, @__LINE__, 0), - ), + location=mlir_stacktrace($(string(op)), @__FILE__, @__LINE__), ) where {T,N} res = MLIR.IR.result( $(:($dialect.$op))( @@ -108,10 +102,7 @@ for (dialect, op) in [ function $op( x::TracedRNumber{T}; - location=MLIR.IR.Location( - $(string(Symbol(dialect, :., op))), - MLIR.IR.Location(@__FILE__, @__LINE__, 0), - ), + location=mlir_stacktrace($(string(op)), @__FILE__, @__LINE__), ) where {T} res = MLIR.IR.result( $(:($dialect.$op))( @@ -148,10 +139,7 @@ for (dialect, op) in [ function $op( a::TracedRArray{T,N}, b::TracedRArray{T,N}; - location=MLIR.IR.Location( - $(string(Symbol(dialect, :., op))), - MLIR.IR.Location(@__FILE__, @__LINE__, 0), - ), + location=mlir_stacktrace($(string(op)), @__FILE__, @__LINE__), ) where {T,N} res = MLIR.IR.result( $(:($dialect.$op))( @@ -167,10 +155,7 @@ for (dialect, op) in [ function $op( a::TracedRNumber{T}, b::TracedRNumber{T}; - location=MLIR.IR.Location( - $(string(Symbol(dialect, :., op))), - MLIR.IR.Location(@__FILE__, @__LINE__, 0), - ), + location=mlir_stacktrace($(string(op)), @__FILE__, @__LINE__), ) where {T} res = MLIR.IR.result( $(:($dialect.$op))( @@ -195,10 +180,7 @@ for (dialect, op) in [ @eval begin function $op( x::TracedRArray{T,N}; - location=MLIR.IR.Location( - $(string(Symbol(dialect, :., op))), - MLIR.IR.Location(@__FILE__, @__LINE__, 0), - ), + location=mlir_stacktrace($(string(op)), @__FILE__, @__LINE__), ) where {T,N} res = MLIR.IR.result( $(:($dialect.$op))( @@ -210,10 +192,7 @@ for (dialect, op) in [ function $op( x::TracedRNumber{T}; - location=MLIR.IR.Location( - $(string(Symbol(dialect, :., op))), - MLIR.IR.Location(@__FILE__, @__LINE__, 0), - ), + location=mlir_stacktrace($(string(op)), @__FILE__, @__LINE__), ) where {T} res = MLIR.IR.result( $(:($dialect.$op))( @@ -226,10 +205,7 @@ for (dialect, op) in [ end function is_finite( - x::TracedRArray{T,N}; - location=MLIR.IR.Location( - "stablehlo.is_finite", MLIR.IR.Location(@__FILE__, @__LINE__, 0) - ), + x::TracedRArray{T,N}; location=mlir_stacktrace("is_finite", @__FILE__, @__LINE__) ) where {T,N} res = MLIR.IR.result( stablehlo.is_finite( @@ -240,10 +216,7 @@ function is_finite( end function is_finite( - x::TracedRNumber{T}; - location=MLIR.IR.Location( - "stablehlo.is_finite", MLIR.IR.Location(@__FILE__, @__LINE__, 0) - ), + x::TracedRNumber{T}; location=mlir_stacktrace("is_finite", @__FILE__, @__LINE__) ) where {T} res = MLIR.IR.result( stablehlo.is_finite(x.mlir_data; y=mlir_type(TracedRArray{Bool,0}, ()), location) @@ -253,8 +226,7 @@ end # fixes to default automated implementations function abs( - x::TracedRArray{Complex{T},N}; - location=MLIR.IR.Location("stablehlo.abs", MLIR.IR.Location(@__FILE__, @__LINE__, 0)), + x::TracedRArray{Complex{T},N}; location=mlir_stacktrace("abs", @__FILE__, @__LINE__) ) where {T,N} res = MLIR.IR.result( stablehlo.abs(x.mlir_data; result=mlir_type(TracedRArray{T,N}, size(x)), location) @@ -263,8 +235,7 @@ function abs( end function abs( - x::TracedRNumber{Complex{T}}; - location=MLIR.IR.Location("stablehlo.abs", MLIR.IR.Location(@__FILE__, @__LINE__, 0)), + x::TracedRNumber{Complex{T}}; location=mlir_stacktrace("abs", @__FILE__, @__LINE__) ) where {T} res = MLIR.IR.result( stablehlo.abs(x.mlir_data; result=mlir_type(TracedRArray{T,0}, ()), location) @@ -280,9 +251,7 @@ end function reshape( x::TracedRArray{T,N}, dims::Vector{Int}; - location=MLIR.IR.Location( - "stablehlo.reshape", MLIR.IR.Location(@__FILE__, @__LINE__, 0) - ), + location=mlir_stacktrace("reshape", @__FILE__, @__LINE__), ) where {T,N} restype = mlir_type(TracedRArray{T,length(dims)}, dims) res = MLIR.IR.result(stablehlo.reshape(x.mlir_data; result_0=restype, location)) @@ -295,9 +264,7 @@ end function get_dimension_size( x::TracedRArray{T,N}, dim; - location=MLIR.IR.Location( - "stablehlo.get_dimension_size", MLIR.IR.Location(@__FILE__, @__LINE__, 0) - ), + location=mlir_stacktrace("get_dimension_size", @__FILE__, @__LINE__), ) where {T,N} dimension = MLIR.IR.Attribute(dim - 1) res = MLIR.IR.result( @@ -312,9 +279,7 @@ function set_dimension_size( x::TracedRArray{T,N}, size::TracedRNumber{Int}, dim::Int; - location=MLIR.IR.Location( - "stablehlo.set_dimension_size", MLIR.IR.Location(@__FILE__, @__LINE__, 0) - ), + location=mlir_stacktrace("set_dimension_size", @__FILE__, @__LINE__), ) where {T,N} dimension = MLIR.IR.Attribute(dim - 1) res = MLIR.IR.result( @@ -332,9 +297,7 @@ end function transpose( x::TracedRArray{T,N}, permutation; - location=MLIR.IR.Location( - "stablehlo.transpose", MLIR.IR.Location(@__FILE__, @__LINE__, 0) - ), + location=mlir_stacktrace("transpose", @__FILE__, @__LINE__), ) where {T,N} rsize = permute!(collect(size(x)), permutation) permutation = permutation .- 1 @@ -351,7 +314,7 @@ function pad( low=fill(0, N), high=fill(0, N), interior=fill(0, N), - location=MLIR.IR.Location("stablehlo.pad", MLIR.IR.Location(@__FILE__, @__LINE__, 0)), + location=mlir_stacktrace("pad", @__FILE__, @__LINE__), ) where {T,N} rsize = size(x) .+ low .+ high .+ max.(size(x) .- 1, 0) .* interior res = MLIR.IR.result( @@ -372,7 +335,7 @@ function slice( start_indices, limit_indices; strides=nothing, - location=MLIR.IR.Location("stablehlo.slice", MLIR.IR.Location(@__FILE__, @__LINE__, 0)), + location=mlir_stacktrace("slice", @__FILE__, @__LINE__), ) where {T,N} start_indices = start_indices .- 1 limit_indices = limit_indices @@ -396,9 +359,7 @@ end function complex( real::TracedRArray{T,N}, imag::TracedRArray{T,N}; - location=MLIR.IR.Location( - "stablehlo.complex", MLIR.IR.Location(@__FILE__, @__LINE__, 0) - ), + location=mlir_stacktrace("complex", @__FILE__, @__LINE__), ) where {T,N} res = MLIR.IR.result( stablehlo.complex( @@ -414,9 +375,7 @@ end function complex( real::TracedRNumber{T}, imag::TracedRNumber{T}; - location=MLIR.IR.Location( - "stablehlo.complex", MLIR.IR.Location(@__FILE__, @__LINE__, 0) - ), + location=mlir_stacktrace("complex", @__FILE__, @__LINE__), ) where {T} res = MLIR.IR.result( stablehlo.complex( @@ -430,8 +389,7 @@ function complex( end function real( - x::TracedRArray{Complex{T},N}; - location=MLIR.IR.Location("stablehlo.real", MLIR.IR.Location(@__FILE__, @__LINE__, 0)), + x::TracedRArray{Complex{T},N}; location=mlir_stacktrace("real", @__FILE__, @__LINE__) ) where {T,N} res = MLIR.IR.result( stablehlo.real(x.mlir_data; result=mlir_type(TracedRArray{T,N}, size(x)), location) @@ -440,8 +398,7 @@ function real( end function real( - x::TracedRNumber{Complex{T}}; - location=MLIR.IR.Location("stablehlo.real", MLIR.IR.Location(@__FILE__, @__LINE__, 0)), + x::TracedRNumber{Complex{T}}; location=mlir_stacktrace("real", @__FILE__, @__LINE__) ) where {T} res = MLIR.IR.result( stablehlo.real(x.mlir_data; result=mlir_type(TracedRArray{T,0}, ()), location) @@ -450,8 +407,7 @@ function real( end function imag( - x::TracedRArray{Complex{T},N}; - location=MLIR.IR.Location("stablehlo.imag", MLIR.IR.Location(@__FILE__, @__LINE__, 0)), + x::TracedRArray{Complex{T},N}; location=mlir_stacktrace("imag", @__FILE__, @__LINE__) ) where {T,N} res = MLIR.IR.result( stablehlo.imag(x.mlir_data; result=mlir_type(TracedRArray{T,N}, size(x)), location) @@ -460,8 +416,7 @@ function imag( end function imag( - x::TracedRNumber{Complex{T}}; - location=MLIR.IR.Location("stablehlo.imag", MLIR.IR.Location(@__FILE__, @__LINE__, 0)), + x::TracedRNumber{Complex{T}}; location=mlir_stacktrace("imag", @__FILE__, @__LINE__) ) where {T} res = MLIR.IR.result( stablehlo.imag(x.mlir_data; result=mlir_type(TracedRArray{T,0}, ()), location) @@ -472,8 +427,8 @@ end # function bitcast_convert( # ::Type{TracedRArray{U,N}}, # x::TracedRArray{T,N}; -# location=MLIR.IR.Location( -# "stablehlo.bitcast_convert", MLIR.IR.Location(@__FILE__, @__LINE__, 0) +# location=mlir_stacktrace( +# "bitcast_convert", @__FILE__, @__LINE__ # ), # ) where {T,N} # res = MLIR.IR.result( @@ -488,9 +443,9 @@ function fft( x::TracedRArray{T,N}; type::String, length, - location=MLIR.IR.Location("stablehlo.fft", MLIR.IR.Location(@__FILE__, @__LINE__, 0)), + location=mlir_stacktrace("fft", @__FILE__, @__LINE__), ) where {T,N} - @assert 1 <= Base.length(length) <= 3 "stablehlo.fft only supports up to rank 3" + @assert 1 <= Base.length(length) <= 3 "fft only supports up to rank 3" if type ∈ ("FFT", "IFFT") @assert T <: Complex @@ -528,9 +483,7 @@ end function cholesky( x::TracedRArray{T,N}; lower::Bool=false, - location=MLIR.IR.Location( - "stablehlo.cholesky", MLIR.IR.Location(@__FILE__, @__LINE__, 0) - ), + location=mlir_stacktrace("cholesky", @__FILE__, @__LINE__), ) where {T,N} lower = MLIR.IR.Attribute(lower) res = MLIR.IR.result( @@ -545,7 +498,7 @@ function clamp( min::Union{TracedRNumber{T},TracedRArray{T,N}}, x::TracedRArray{T,N}, max::Union{TracedRNumber{T},TracedRArray{T,N}}; - location=MLIR.IR.Location("stablehlo.clamp", MLIR.IR.Location(@__FILE__, @__LINE__, 0)), + location=mlir_stacktrace("clamp", @__FILE__, @__LINE__), ) where {T,N} res = MLIR.IR.result( stablehlo.clamp( @@ -573,8 +526,8 @@ end # padding=nothing, # lhs_dilation=nothing, # rhs_dilation=nothing, -# location=MLIR.IR.Location( -# "stablehlo.convolution", MLIR.IR.Location(@__FILE__, @__LINE__, 0) +# location=mlir_stacktrace( +# "convolution", @__FILE__, @__LINE__ # ), # ) where {T,N} # res = MLIR.IR.result( @@ -600,8 +553,8 @@ end # lhs_contracting_dimensions, # rhs_contracting_dimensions, # result_permutation, -# location=MLIR.IR.Location( -# "stablehlo.dot_general", MLIR.IR.Location(@__FILE__, @__LINE__, 0) +# location=mlir_stacktrace( +# "dot_general", @__FILE__, @__LINE__ # ), # ) where {T,N} # res = MLIR.IR.result( @@ -623,9 +576,7 @@ function einsum( lhs::TracedRArray{T}, rhs::TracedRArray{T}; equation::String, - location=MLIR.IR.Location( - "stablehlo.einsum", MLIR.IR.Location(@__FILE__, @__LINE__, 0) - ), + location=mlir_stacktrace("einsum", @__FILE__, @__LINE__), ) where {T} ins, ic = split(equation, "->") ia, ib = split(ins, ",") @@ -654,8 +605,8 @@ end # function unary_einsum( # x::TracedRArray{T}; # equation::String, -# location=MLIR.IR.Location( -# "stablehlo.unary_einsum", MLIR.IR.Location(@__FILE__, @__LINE__, 0) +# location=mlir_stacktrace( +# "unary_einsum", @__FILE__, @__LINE__ # ), # ) where {T} # ia, ic = split(equation, "->") @@ -676,30 +627,17 @@ end # end # paralell ops -function partition_id(; - location=MLIR.IR.Location( - "stablehlo.partition_id", MLIR.IR.Location(@__FILE__, @__LINE__, 0) - ), -) +function partition_id(; location=mlir_stacktrace("partition_id", @__FILE__, @__LINE__)) res = MLIR.IR.result(stablehlo.partition_id(; location)) return TracedRNumber{UInt32}((), res) end -function replica_id(; - location=MLIR.IR.Location( - "stablehlo.replica_id", MLIR.IR.Location(@__FILE__, @__LINE__, 0) - ), -) +function replica_id(; location=mlir_stacktrace("replica_id", @__FILE__, @__LINE__)) res = MLIR.IR.result(stablehlo.replica_id(; location)) return TracedRNumber{UInt32}((), res) end -function after_all( - tokens...; - location=MLIR.IR.Location( - "stablehlo.after_all", MLIR.IR.Location(@__FILE__, @__LINE__, 0) - ), -) +function after_all(tokens...; location=mlir_stacktrace("after_all", @__FILE__, @__LINE__)) tokens = [token.mlir_data for token in tokens] res = MLIR.IR.result(stablehlo.after_all(tokens; location)) return Token(res) @@ -707,9 +645,7 @@ end function optimization_barrier( operands::Union{TracedRNumber,TracedRArray}...; - location=MLIR.IR.Location( - "stablehlo.optimization_barrier", MLIR.IR.Location(@__FILE__, @__LINE__, 0) - ), + location=mlir_stacktrace("optimization_barrier", @__FILE__, @__LINE__), ) values = [operand.mlir_data for operand in operands] op = stablehlo.optimization_barrier(values; location) @@ -732,9 +668,7 @@ function outfeed( operands::Union{TracedRNumber,TracedRArray}...; token, config="", - location=MLIR.IR.Location( - "stablehlo.outfeed", MLIR.IR.Location(@__FILE__, @__LINE__, 0) - ), + location=mlir_stacktrace("outfeed", @__FILE__, @__LINE__), ) values = [operand.mlir_data for operand in operands] outfeed_config = MLIR.IR.Attribute(config) @@ -750,7 +684,7 @@ function send( channel_id::Int, channel_type::Int, is_host_transfer=nothing, - location=MLIR.IR.Location("stablehlo.send", MLIR.IR.Location(@__FILE__, @__LINE__, 0)), + location=mlir_stacktrace("send", @__FILE__, @__LINE__), ) values = [operand.mlir_data for operand in operands] channel_handle = MLIR.API.stablehloChannelHandleGet( @@ -773,7 +707,7 @@ function recv( channel_id::Int, channel_type::Int, is_host_transfer=nothing, - location=MLIR.IR.Location("stablehlo.recv", MLIR.IR.Location(@__FILE__, @__LINE__, 0)), + location=mlir_stacktrace("recv", @__FILE__, @__LINE__), ) channel_handle = MLIR.API.stablehloChannelHandleGet( MLIR.IR.context(), channel_id, channel_type @@ -807,8 +741,8 @@ end # function broadcast_in_dim( # x::TracedRArray{T,N}, # dims::Vector{Int}; -# location=MLIR.IR.Location( -# "stablehlo.broadcast_in_dim", MLIR.IR.Location(@__FILE__, @__LINE__, 0) +# location=mlir_stacktrace( +# "broadcast_in_dim", @__FILE__, @__LINE__ # ), # ) where {T,N} # rsize = restype = MLIR.IR.TensorType([...], mlir_type(T)) # mlir_type(TracedRArray{T,N}, size(x)) @@ -830,7 +764,7 @@ end # comparator, # dimension=-1, # is_stable=false, -# location=MLIR.IR.Location("stablehlo.sort", MLIR.IR.Location(@__FILE__, @__LINE__, 0)), +# location=mlir_stacktrace("sort", @__FILE__, @__LINE__), # ) where {T,N} # dimension = MLIR.IR.Attribute(dimension) # is_stable = MLIR.IR.Attribute(is_stable) @@ -847,9 +781,7 @@ end # end function top_k( - x::TracedRArray{T,N}, - k; - location=MLIR.IR.Location("chlo.top_k", MLIR.IR.Location(@__FILE__, @__LINE__, 0)), + x::TracedRArray{T,N}, k; location=mlir_stacktrace("top_k", @__FILE__, @__LINE__) ) where {T,N} rsize = [size(x)[1:(end - 1)]..., k] values = mlir_type(TracedRArray{T,N}, rsize) @@ -865,7 +797,7 @@ function iota( T::Type, shape::Vector{Int}; iota_dimension, - location=MLIR.IR.Location("stablehlo.iota", MLIR.IR.Location(@__FILE__, @__LINE__, 0)), + location=mlir_stacktrace("iota", @__FILE__, @__LINE__), ) N = length(shape) output = mlir_type(TracedRArray{T,N}, shape) @@ -877,9 +809,7 @@ end function reverse( x::TracedRArray{T,N}; dimensions, - location=MLIR.IR.Location( - "stablehlo.reverse", MLIR.IR.Location(@__FILE__, @__LINE__, 0) - ), + location=mlir_stacktrace("reverse", @__FILE__, @__LINE__), ) where {T,N} res = MLIR.IR.result( stablehlo.reverse( @@ -897,9 +827,7 @@ function rng_bit_generator( seed::TracedRArray{UInt64,1}, shape; algorithm::String="DEFAULT", - location=MLIR.IR.Location( - "stablehlo.rng_bit_generator", MLIR.IR.Location(@__FILE__, @__LINE__, 0) - ), + location=mlir_stacktrace("rng_bit_generator", @__FILE__, @__LINE__), ) output = MLIR.IR.TensorType(TracedRArray{UInt64,1}, shape) rng_algorithm = MLIR.API.stablehloRngAlgorithmAttrGet(MLIR.IR.context(), algorithm) @@ -913,9 +841,7 @@ end # functional ops function return_( results::Union{TracedRArray,TracedRNumber}...; - location=MLIR.IR.Location( - "stablehlo.return_", MLIR.IR.Location(@__FILE__, @__LINE__, 0) - ), + location=mlir_stacktrace("return_", @__FILE__, @__LINE__), ) return stablehlo.return_([x.mlir_data for x in results]; location) end diff --git a/src/mlir/IR/Location.jl b/src/mlir/IR/Location.jl index d98e4fa1c7..f7d1b12e7b 100644 --- a/src/mlir/IR/Location.jl +++ b/src/mlir/IR/Location.jl @@ -14,7 +14,7 @@ function Location(filename, line, column; context::Context=context()) end function Location(callee::Location, caller::Location; context::Context=context()) - return Location(API.mlirLocationCallSiteGet(context, callee, caller)) + return Location(API.mlirLocationCallSiteGet(callee, caller)) end function Location(name::String, location::Location; context::Context=context()) diff --git a/src/utils.jl b/src/utils.jl index 18d93d8d07..b37e00fd19 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -247,3 +247,32 @@ function make_mlir_fn( linear_results, ) end + +const DEBUG_MODE::Ref{Bool} = Ref(false) + +function with_debug(f) + old = DEBUG_MODE[] + DEBUG_MODE[] = true + try + return f() + finally + DEBUG_MODE[] = old + end +end + +function mlir_stacktrace(name, file, line)::MLIR.IR.Location + # calling `stacktrace` can add a lot of time overhead, so let's avoid adding debug info if not used + if DEBUG_MODE[] + return MLIR.IR.Location(name, MLIR.IR.Location(file, line, 0)) + end + + # retrieve current stacktrace, remove this function's frame and translate to MLIR Location + st = stacktrace() + deleteat!(st, 1) + return mapfoldl(MLIR.IR.Location, st) do stackframe + name = string(stackframe.func) + file = stackframe.file + line = stackframe.line + return MLIR.IR.Location(name, MLIR.IR.Location(file, line, 0)) + end +end