Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ PythonCall = "0.9"
Random = "1.10"
Random123 = "1.7"
ReactantCore = "0.1.4"
Reactant_jll = "0.0.52"
Reactant_jll = "0.0.58"
Scratch = "1.2"
Sockets = "1.10"
SpecialFunctions = "2.4"
Expand Down
4 changes: 3 additions & 1 deletion ext/ReactantCUDAExt.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
module ReactantCUDAExt

using CUDA
using Reactant: Reactant, TracedRArray, AnyTracedRArray, MLIR, TracedRNumber
using Reactant:
Reactant, TracedRArray, AnyTracedRArray, AnyConcreteRArray, MLIR, TracedRNumber
using ReactantCore: @trace
using KernelAbstractions: KernelAbstractions
using Libdl

using Adapt

KernelAbstractions.get_backend(::AnyTracedRArray) = CUDABackend()
KernelAbstractions.get_backend(::AnyConcreteRArray) = CUDABackend()

struct CuTracedArray{T,N,A,Size} <: DenseArray{T,N}
ptr::Core.LLVMPtr{T,A}
Expand Down
43 changes: 24 additions & 19 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,9 @@ end

const DEBUG_KERNEL = Ref{Bool}(false)

function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::Bool=false)
function compile_mlir!(
mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::Bool=false, backend="gpu"
)
# Explicitly don't use block! to avoid creating a closure, which creates
# both compile-time and relocatability issues

Expand All @@ -456,7 +458,10 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::
if isdefined(Reactant_jll, :ptxas_path)
toolkit = Reactant_jll.ptxas_path[1:(end - length("/bin/ptxas"))]
end
if DEBUG_KERNEL[]

if backend == "cpu"
kern = "lower-kernel{openmp=false backend=cpu},symbol-dce"
elseif DEBUG_KERNEL[]
curesulthandler = XLA.Libdl.dlsym(
Reactant_jll.libReactantExtra_handle, "ReactantHandleCuResult"
)
Expand Down Expand Up @@ -604,7 +609,9 @@ end
@code_hlo [optimize = ...] [no_nan = <true/false>] f(args...)
"""
macro code_hlo(args...)
default_options = Dict{Symbol,Any}(:optimize => true, :no_nan => false)
default_options = Dict{Symbol,Any}(
:optimize => true, :no_nan => false, :backend => "gpu"
)
compile_expr, (; compiled) = compile_call_expr(
__module__, compile_mlir, default_options, args...
)
Expand Down Expand Up @@ -975,16 +982,26 @@ function compile_xla(f, args; client=nothing, optimize=true, no_nan=false, devic
context_gc_vector[ctx] = Vector{TracedRArray}(undef, 0)
@ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid

if client !== nothing
backend = XLA.ClientGetPlatformName(client)
else
backend = XLA.ClientGetPlatformName(XLA.default_backend[])
end
if backend == "CUDA"
backend = "GPU"
elseif backend == "CPU"
backend = "cpu"
end

MLIR.IR.activate!(ctx)
results = try
# compile function to MLIR module
mod = MLIR.IR.Module(MLIR.IR.Location())
linear_args, linear_results, preserved_args, seen_args, concrete_result, isclosure = compile_mlir!(
mod, f, args; optimize, no_nan
mod, f, args; optimize, no_nan, backend
)

# Resolve client and device
device_ordinal = -1
if device === nothing
if length(linear_args) > 0
devices_list = [
Expand All @@ -1002,32 +1019,20 @@ function compile_xla(f, args; client=nothing, optimize=true, no_nan=false, devic
client = XLA.client(device)
else
client = XLA.default_backend[]
device = XLA.ClientGetDevice(client, XLA.default_device_idx[])
device_ordinal = XLA.default_device_idx[]
end
else
if device !== nothing
@assert client == XLA.client(device) "client ($(client)) and XLA.client(device) ($(XLA.client(device))) must be the same"
else
device = XLA.ClientGetDevice(client, XLA.default_device_idx[])
device_ordinal = XLA.default_device_idx[]
end
end

if device_ordinal < 0
device_ordinal = XLA.DeviceToClientDeviceOrdinal(device)
end

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change

# compile MLIR module to XLA executable
exec = XLA.Compile(
client,
mod;
device_ordinal,
num_replicas=1,
num_partitions=1,
use_shardy_partitioner=false,
mod
)
Comment on lines 1031 to 1034
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
exec = XLA.Compile(
client,
mod;
device_ordinal,
num_replicas=1,
num_partitions=1,
use_shardy_partitioner=false,
mod
)
exec = XLA.Compile(client, mod)

return (
(
exec,
linear_args,
linear_results,
Expand Down
98 changes: 74 additions & 24 deletions src/XLA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,51 @@ end

mutable struct Client
client::Ptr{Cvoid}
global_ordinals::Vector{Cint}

function Client(client::Ptr{Cvoid})
@assert client != C_NULL
return new(client)
global_ordinals = Cint[]
client = new(client, global_ordinals)

# https://github.com/pytorch/xla/blob/8b2414094578e829b99a8383877c86d357eeb682/torch_xla/csrc/runtime/pjrt_computation_client.cc#L127
devices = [
ClientGetAddressableDevice(client, i - 1) for
i in 1:ClientNumAddressableDevices(client)
]
sort!(devices; lt=(a, b) -> DeviceGetLocalDeviceId(a) < DeviceGetLocalDeviceId(b))

local_ids = [DeviceGetLocalDeviceId(device) + 1 for device in devices]
max_local_id = maximum(local_ids)
resize!(global_ordinals, max_local_id)
global_ordinals .= -1
for (i, device) in enumerate(devices)
global_ordinals[local_ids[i]] = i - 1
end
return client
end
end

Base.:(==)(a::Client, b::Client) = a.client == b.client

function Base.show(io::IO, ::MIME"text/plain", client::Client)
print(io, "Client($(client.client), platform_name=$(ClientGetPlatformName(client)))")
struct Device
device::Ptr{Cvoid}
end

"""
    device_ordinal(client::Client, device::Device)

Return the client-level device ordinal recorded for `device`, looked up via
its 0-based local device id in `client.global_ordinals` (1-based indexing).
"""
function device_ordinal(client::Client, device::Device)
    local_id = DeviceGetLocalDeviceId(device)
    return client.global_ordinals[local_id + 1]
end

# Render a device as "PLATFORM:ordinal" (platform name upper-cased),
# using the ordinal assigned by the device's owning client.
function DeviceToString(device::Device)
    owner = client(device)
    platform = ClientGetPlatformName(owner)
    return string(uppercase(platform), ":", device_ordinal(owner, device))
end

# Multi-line REPL display for a Device: raw pointer plus the owning
# client's platform name.
function Base.show(io::IO, ::MIME"text/plain", device::Device)
    platform = ClientGetPlatformName(client(device))
    print(io, "Device(", device.device, ", platform_name=", platform, ")")
    return nothing
end

Expand Down Expand Up @@ -178,6 +212,7 @@ function __init__()
end
end

@ccall MLIR.API.mlir_c.RegisterEnzymeXLACPUHandler()::Cvoid
@ccall MLIR.API.mlir_c.RegisterEnzymeXLAGPUHandler()::Cvoid

# This wasn't properly exported on macos, we'll remove the try once macOS JLL
Expand Down Expand Up @@ -227,17 +262,6 @@ mutable struct Buffer
end
end

struct Device
device::Ptr{Cvoid}
end

function Base.show(io::IO, ::MIME"text/plain", device::Device)
pjrtclient = client(device)
platform_name = ClientGetPlatformName(pjrtclient)
print(io, "Device($(device.device), platform_name=$(platform_name))")
return nothing
end

function DeviceToClientDeviceOrdinal(device::Device)
pjrtclient = client(device)
naddressable_devices = ClientNumAddressableDevices(pjrtclient)
Expand Down Expand Up @@ -336,7 +360,7 @@ Return an [`AllocatorStats`](@ref) instance with information about the device sp
This method is currently not implemented for the CPU device.
"""
function allocatorstats(
device::Device=ClientGetDevice(default_backend[], default_device_idx[])
device::Device=ClientGetAddressableDevice(default_backend[], default_device_idx[])
)
ref = Ref{JLAllocatorStats}()
@ccall MLIR.API.mlir_c.PjRtDeviceGetAllocatorStats(
Expand Down Expand Up @@ -539,21 +563,16 @@ end

function Compile(
client::Client,
mod::MLIR.IR.Module;
device_ordinal::Int=-1,
num_replicas::Int=1,
num_partitions::Int=1,
use_shardy_partitioner::Bool=false,
mod::MLIR.IR.Module
)
Comment on lines 564 to 567
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
function Compile(
client::Client,
mod::MLIR.IR.Module;
device_ordinal::Int=-1,
num_replicas::Int=1,
num_partitions::Int=1,
use_shardy_partitioner::Bool=false,
mod::MLIR.IR.Module
)
function Compile(client::Client, mod::MLIR.IR.Module)

max_local_id = length(client.global_ordinals)
GC.@preserve client mod begin
executable = LoadedExecutable(
@ccall MLIR.API.mlir_c.ClientCompile(
client.client::Ptr{Cvoid},
mod.module_::MLIR.API.MlirModule,
device_ordinal::Cint,
num_replicas::Cint,
num_partitions::Cint,
use_shardy_partitioner::Bool,
client.global_ordinals::Ptr{Cint},
max_local_id::Cint,
)::Ptr{Cvoid}
)
end
Expand Down Expand Up @@ -608,6 +627,37 @@ function ClientGetPlatformName(client::Client)
return unsafe_string(str)
end

# Query the PJRT runtime for the 0-based local device id of `device`.
# GC.@preserve keeps `device` (and hence its raw pointer) alive across
# the foreign call.
function DeviceGetLocalDeviceId(device::Device)
    local_id = GC.@preserve device @ccall MLIR.API.mlir_c.PjRtDeviceGetLocalDeviceId(
        device.device::Ptr{Cvoid}
    )::Cint
    return local_id
end

# Retrieve the Client that owns a loaded executable from the PJRT runtime.
# The whole lookup-and-wrap runs under GC.@preserve so `exec` stays alive
# while its raw pointer is in use, matching the original's preserve scope.
function PjRtLoadedExecutableGetClient(exec::LoadedExecutable)
    GC.@preserve exec begin
        raw_client = @ccall MLIR.API.mlir_c.PjRtLoadedExecutableGetClient(
            exec.exec::Ptr{Cvoid}
        )::Ptr{Cvoid}
        return Client(raw_client)
    end
end

# Return one Buffer per addressable device of `buffer`'s client: the
# original buffer for the device it already lives on, and a copy made by
# `CopyBufferToDevice` for every other device.
function replicate_buffer_on_all_addressable_devices(buffer::Buffer)
    owner = client(buffer)
    home = device(buffer)
    all_devices = map(
        i -> ClientGetAddressableDevice(owner, i - 1),
        1:ClientNumAddressableDevices(owner),
    )
    return map(
        dev -> dev == home ? buffer : CopyBufferToDevice(buffer, dev), all_devices
    )
end

function is_ready(future::Future)
GC.@preserve future begin
return (@ccall MLIR.API.mlir_c.FutureIsReady(future.future::Ptr{Cvoid})::UInt8) != 0
Expand Down
Loading