diff --git a/src/Compiler.jl b/src/Compiler.jl index 8c081cce2f..900c360be7 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -336,12 +336,12 @@ const cuModule = Ref{UInt}(0) function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::Bool=false) # Explicitly don't use block! to avoid creating a closure, which creates # both compile-time and relocatability issues - + MLIR.IR.activate!(mod) MLIR.IR.activate!(MLIR.IR.body(mod)) - fnwrapped, func2, traced_result, result, seen_args, ret, linear_args, in_tys, - linear_results = - try + fnwrapped, + func2, traced_result, result, seen_args, ret, linear_args, in_tys, + linear_results = try Reactant.TracedUtils.make_mlir_fn(f, args, (), "main", true) finally MLIR.IR.deactivate!(MLIR.IR.body(mod)) diff --git a/src/ConcreteRArray.jl b/src/ConcreteRArray.jl index 73979b4ca3..20f4d6a83c 100644 --- a/src/ConcreteRArray.jl +++ b/src/ConcreteRArray.jl @@ -1,12 +1,18 @@ function ConcreteRNumber{T}( - data::T2; client::XLA.Client=XLA.default_backend[], idx::Int=XLA.default_device_idx[], device::Union{Nothing, XLA.Device}=nothing + data::T2; + client::XLA.Client=XLA.default_backend[], + idx::Int=XLA.default_device_idx[], + device::Union{Nothing,XLA.Device}=nothing, ) where {T<:Number,T2<:Number} data = convert(T, data) crarray = ConcreteRArray(fill(data); client, idx, device) return ConcreteRNumber{T}(crarray.data) end function ConcreteRNumber( - data::T; client::XLA.Client=XLA.default_backend[], idx::Int=XLA.default_device_idx[], device::Union{Nothing, XLA.Device}=nothing + data::T; + client::XLA.Client=XLA.default_backend[], + idx::Int=XLA.default_device_idx[], + device::Union{Nothing,XLA.Device}=nothing, ) where {T<:Number} crarray = ConcreteRArray(fill(data); client, idx, device) return ConcreteRNumber{T}(crarray.data) @@ -37,7 +43,10 @@ end Base.convert(::Type{T}, x::ConcreteRNumber) where {T<:Number} = convert(T, to_number(x)) function ConcreteRArray( - data::T; client::XLA.Client=XLA.default_backend[], idx::Int=XLA.default_device_idx[], device::Union{Nothing, XLA.Device}=nothing + data::T; + client::XLA.Client=XLA.default_backend[], + idx::Int=XLA.default_device_idx[], + device::Union{Nothing,XLA.Device}=nothing, ) where {T<:Number} Base.depwarn( "ConcreteRArray(data::Number) is deprecated, use ConcreteRNumber(data) instead", @@ -54,7 +63,7 @@ function ConcreteRArray( data::Array{T,N}; client::XLA.Client=XLA.default_backend[], idx::Int=XLA.default_device_idx[], - device::Union{Nothing, XLA.Device}=nothing, + device::Union{Nothing,XLA.Device}=nothing, ) where {T,N} device = device === nothing ? XLA.ClientGetDevice(client, idx) : device return ConcreteRArray{T,N}( diff --git a/src/Precompile.jl b/src/Precompile.jl index 4684287b74..98c60dee5a 100644 --- a/src/Precompile.jl +++ b/src/Precompile.jl @@ -45,7 +45,7 @@ end # infer_sig(Tuple{typeof(Base.sin), Reactant.TracedRNumber{Float64}}) x = ConcreteRNumber(2.0; client) Reactant.compile(sin, (x,); client) - + y = ConcreteRArray([2.0]; client) Reactant.compile(Base.sum, (y,); client) end diff --git a/src/utils.jl b/src/utils.jl index 8ea2591036..9b751d14f5 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -528,7 +528,7 @@ function call_with_reactant_generator( # octup = Tuple{method.sig.parameters[2:end]...} octup = Tuple{tys[2:end]...} ocva = false - + # jl_new_opaque_closure forcibly executes in the current world... This means that we won't get the right # inner code during compilation without special handling (i.e. call_in_world_total). # Opaque closures also require taking the function argument. We can work around the latter