Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -336,12 +336,12 @@ const cuModule = Ref{UInt}(0)
function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::Bool=false)
# Explicitly don't use block! to avoid creating a closure, which creates
# both compile-time and relocatability issues

MLIR.IR.activate!(mod)
MLIR.IR.activate!(MLIR.IR.body(mod))
fnwrapped, func2, traced_result, result, seen_args, ret, linear_args, in_tys,
linear_results =
try
fnwrapped,
func2, traced_result, result, seen_args, ret, linear_args, in_tys,
linear_results = try
Reactant.TracedUtils.make_mlir_fn(f, args, (), "main", true)
finally
MLIR.IR.deactivate!(MLIR.IR.body(mod))
Expand Down
17 changes: 13 additions & 4 deletions src/ConcreteRArray.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
function ConcreteRNumber{T}(
data::T2; client::XLA.Client=XLA.default_backend[], idx::Int=XLA.default_device_idx[], device::Union{Nothing, XLA.Device}=nothing
data::T2;
client::XLA.Client=XLA.default_backend[],
idx::Int=XLA.default_device_idx[],
device::Union{Nothing,XLA.Device}=nothing,
) where {T<:Number,T2<:Number}
data = convert(T, data)
crarray = ConcreteRArray(fill(data); client, idx, device)
return ConcreteRNumber{T}(crarray.data)
end
function ConcreteRNumber(
data::T; client::XLA.Client=XLA.default_backend[], idx::Int=XLA.default_device_idx[], device::Union{Nothing, XLA.Device}=nothing
data::T;
client::XLA.Client=XLA.default_backend[],
idx::Int=XLA.default_device_idx[],
device::Union{Nothing,XLA.Device}=nothing,
) where {T<:Number}
crarray = ConcreteRArray(fill(data); client, idx, device)
return ConcreteRNumber{T}(crarray.data)
Expand Down Expand Up @@ -37,7 +43,10 @@ end
Base.convert(::Type{T}, x::ConcreteRNumber) where {T<:Number} = convert(T, to_number(x))

function ConcreteRArray(
data::T; client::XLA.Client=XLA.default_backend[], idx::Int=XLA.default_device_idx[], device::Union{Nothing, XLA.Device}=nothing
data::T;
client::XLA.Client=XLA.default_backend[],
idx::Int=XLA.default_device_idx[],
device::Union{Nothing,XLA.Device}=nothing,
) where {T<:Number}
Base.depwarn(
"ConcreteRArray(data::Number) is deprecated, use ConcreteRNumber(data) instead",
Expand All @@ -54,7 +63,7 @@ function ConcreteRArray(
data::Array{T,N};
client::XLA.Client=XLA.default_backend[],
idx::Int=XLA.default_device_idx[],
device::Union{Nothing, XLA.Device}=nothing,
device::Union{Nothing,XLA.Device}=nothing,
) where {T,N}
device = device === nothing ? XLA.ClientGetDevice(client, idx) : device
return ConcreteRArray{T,N}(
Expand Down
2 changes: 1 addition & 1 deletion src/Precompile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ end
# infer_sig(Tuple{typeof(Base.sin), Reactant.TracedRNumber{Float64}})
x = ConcreteRNumber(2.0; client)
Reactant.compile(sin, (x,); client)

y = ConcreteRArray([2.0]; client)
Reactant.compile(Base.sum, (y,); client)
end
Expand Down
2 changes: 1 addition & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,7 @@ function call_with_reactant_generator(
# octup = Tuple{method.sig.parameters[2:end]...}
octup = Tuple{tys[2:end]...}
ocva = false

# jl_new_opaque_closure forcibly executes in the current world... This means that we won't get the right
# inner code during compilation without special handling (i.e. call_in_world_total).
# Opaque closures also require taking the function argument. We can work around the latter
Expand Down