-
Couldn't load subscription status.
- Fork 34
CPU backend #647
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
CPU backend #647
Changes from all commits
29c41a3
58bc103
e58c4b6
a2de150
575fc92
80e4e66
ca7d5b6
9818c8f
e1b1e8f
4ea9f49
7d906ae
6014df7
b51ad2d
264ec4b
f845940
b778228
75cb2ae
b625452
53a1b52
3c25856
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -429,7 +429,9 @@ end | |||||||||||||||||||||
|
|
||||||||||||||||||||||
| const DEBUG_KERNEL = Ref{Bool}(false) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::Bool=false) | ||||||||||||||||||||||
| function compile_mlir!( | ||||||||||||||||||||||
| mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::Bool=false, backend="gpu" | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| # Explicitly don't use block! to avoid creating a closure, which creates | ||||||||||||||||||||||
| # both compile-time and relocatability issues | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
@@ -456,7 +458,10 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan:: | |||||||||||||||||||||
| if isdefined(Reactant_jll, :ptxas_path) | ||||||||||||||||||||||
| toolkit = Reactant_jll.ptxas_path[1:(end - length("/bin/ptxas"))] | ||||||||||||||||||||||
| end | ||||||||||||||||||||||
| if DEBUG_KERNEL[] | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| if backend == "cpu" | ||||||||||||||||||||||
| kern = "lower-kernel{openmp=false backend=cpu},symbol-dce" | ||||||||||||||||||||||
| elseif DEBUG_KERNEL[] | ||||||||||||||||||||||
| curesulthandler = XLA.Libdl.dlsym( | ||||||||||||||||||||||
| Reactant_jll.libReactantExtra_handle, "ReactantHandleCuResult" | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
|
|
@@ -604,7 +609,9 @@ end | |||||||||||||||||||||
| @code_hlo [optimize = ...] [no_nan = <true/false>] f(args...) | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
| macro code_hlo(args...) | ||||||||||||||||||||||
| default_options = Dict{Symbol,Any}(:optimize => true, :no_nan => false) | ||||||||||||||||||||||
| default_options = Dict{Symbol,Any}( | ||||||||||||||||||||||
| :optimize => true, :no_nan => false, :backend => "gpu" | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| compile_expr, (; compiled) = compile_call_expr( | ||||||||||||||||||||||
| __module__, compile_mlir, default_options, args... | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
|
|
@@ -975,16 +982,26 @@ function compile_xla(f, args; client=nothing, optimize=true, no_nan=false, devic | |||||||||||||||||||||
| context_gc_vector[ctx] = Vector{TracedRArray}(undef, 0) | ||||||||||||||||||||||
| @ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| if client !== nothing | ||||||||||||||||||||||
| backend = XLA.ClientGetPlatformName(client) | ||||||||||||||||||||||
| else | ||||||||||||||||||||||
| backend = XLA.ClientGetPlatformName(XLA.default_backend[]) | ||||||||||||||||||||||
| end | ||||||||||||||||||||||
| if backend == "CUDA" | ||||||||||||||||||||||
| backend = "GPU" | ||||||||||||||||||||||
| elseif backend == "CPU" | ||||||||||||||||||||||
| backend = "cpu" | ||||||||||||||||||||||
| end | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| MLIR.IR.activate!(ctx) | ||||||||||||||||||||||
| results = try | ||||||||||||||||||||||
| # compile function to MLIR module | ||||||||||||||||||||||
| mod = MLIR.IR.Module(MLIR.IR.Location()) | ||||||||||||||||||||||
| linear_args, linear_results, preserved_args, seen_args, concrete_result, isclosure = compile_mlir!( | ||||||||||||||||||||||
| mod, f, args; optimize, no_nan | ||||||||||||||||||||||
| mod, f, args; optimize, no_nan, backend | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| # Resolve client and device | ||||||||||||||||||||||
| device_ordinal = -1 | ||||||||||||||||||||||
| if device === nothing | ||||||||||||||||||||||
| if length(linear_args) > 0 | ||||||||||||||||||||||
| devices_list = [ | ||||||||||||||||||||||
|
|
@@ -1002,32 +1019,20 @@ function compile_xla(f, args; client=nothing, optimize=true, no_nan=false, devic | |||||||||||||||||||||
| client = XLA.client(device) | ||||||||||||||||||||||
| else | ||||||||||||||||||||||
| client = XLA.default_backend[] | ||||||||||||||||||||||
| device = XLA.ClientGetDevice(client, XLA.default_device_idx[]) | ||||||||||||||||||||||
| device_ordinal = XLA.default_device_idx[] | ||||||||||||||||||||||
| end | ||||||||||||||||||||||
| else | ||||||||||||||||||||||
| if device !== nothing | ||||||||||||||||||||||
| @assert client == XLA.client(device) "client ($(client)) and XLA.client(device) ($(XLA.client(device))) must be the same" | ||||||||||||||||||||||
| else | ||||||||||||||||||||||
| device = XLA.ClientGetDevice(client, XLA.default_device_idx[]) | ||||||||||||||||||||||
| device_ordinal = XLA.default_device_idx[] | ||||||||||||||||||||||
| end | ||||||||||||||||||||||
| end | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| if device_ordinal < 0 | ||||||||||||||||||||||
| device_ordinal = XLA.DeviceToClientDeviceOrdinal(device) | ||||||||||||||||||||||
| end | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||||
| # compile MLIR module to XLA executable | ||||||||||||||||||||||
| exec = XLA.Compile( | ||||||||||||||||||||||
| client, | ||||||||||||||||||||||
| mod; | ||||||||||||||||||||||
| device_ordinal, | ||||||||||||||||||||||
| num_replicas=1, | ||||||||||||||||||||||
| num_partitions=1, | ||||||||||||||||||||||
| use_shardy_partitioner=false, | ||||||||||||||||||||||
| mod | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
|
Comment on lines
1031
to
1034
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||||
| return ( | ||||||||||||||||||||||
| ( | ||||||||||||||||||||||
| exec, | ||||||||||||||||||||||
| linear_args, | ||||||||||||||||||||||
| linear_results, | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -15,17 +15,51 @@ end | |||||||||||||||||||||
|
|
||||||||||||||||||||||
| mutable struct Client | ||||||||||||||||||||||
| client::Ptr{Cvoid} | ||||||||||||||||||||||
| global_ordinals::Vector{Cint} | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| function Client(client::Ptr{Cvoid}) | ||||||||||||||||||||||
| @assert client != C_NULL | ||||||||||||||||||||||
| return new(client) | ||||||||||||||||||||||
| global_ordinals = Cint[] | ||||||||||||||||||||||
| client = new(client, global_ordinals) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| # https://github.com/pytorch/xla/blob/8b2414094578e829b99a8383877c86d357eeb682/torch_xla/csrc/runtime/pjrt_computation_client.cc#L127 | ||||||||||||||||||||||
| devices = [ | ||||||||||||||||||||||
| ClientGetAddressableDevice(client, i - 1) for | ||||||||||||||||||||||
| i in 1:ClientNumAddressableDevices(client) | ||||||||||||||||||||||
| ] | ||||||||||||||||||||||
| sort!(devices; lt=(a, b) -> DeviceGetLocalDeviceId(a) < DeviceGetLocalDeviceId(b)) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| local_ids = [DeviceGetLocalDeviceId(device) + 1 for device in devices] | ||||||||||||||||||||||
| max_local_id = maximum(local_ids) | ||||||||||||||||||||||
| resize!(global_ordinals, max_local_id) | ||||||||||||||||||||||
| global_ordinals .= -1 | ||||||||||||||||||||||
| for (i, device) in enumerate(devices) | ||||||||||||||||||||||
| global_ordinals[local_ids[i]] = i - 1 | ||||||||||||||||||||||
| end | ||||||||||||||||||||||
| return client | ||||||||||||||||||||||
| end | ||||||||||||||||||||||
| end | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| Base.:(==)(a::Client, b::Client) = a.client == b.client | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| function Base.show(io::IO, ::MIME"text/plain", client::Client) | ||||||||||||||||||||||
| print(io, "Client($(client.client), platform_name=$(ClientGetPlatformName(client)))") | ||||||||||||||||||||||
| struct Device | ||||||||||||||||||||||
| device::Ptr{Cvoid} | ||||||||||||||||||||||
| end | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| function device_ordinal(client::Client, device::Device) | ||||||||||||||||||||||
| return client.global_ordinals[DeviceGetLocalDeviceId(device) + 1] | ||||||||||||||||||||||
| end | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| function DeviceToString(device::Device) | ||||||||||||||||||||||
| pjrtclient = client(device) | ||||||||||||||||||||||
| platform_name = ClientGetPlatformName(pjrtclient) | ||||||||||||||||||||||
| return "$(uppercase(platform_name)):$(device_ordinal(pjrtclient, device))" | ||||||||||||||||||||||
| end | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| function Base.show(io::IO, ::MIME"text/plain", device::Device) | ||||||||||||||||||||||
| pjrtclient = client(device) | ||||||||||||||||||||||
| platform_name = ClientGetPlatformName(pjrtclient) | ||||||||||||||||||||||
| print(io, "Device($(device.device), platform_name=$(platform_name))") | ||||||||||||||||||||||
| return nothing | ||||||||||||||||||||||
| end | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
@@ -178,6 +212,7 @@ function __init__() | |||||||||||||||||||||
| end | ||||||||||||||||||||||
| end | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| @ccall MLIR.API.mlir_c.RegisterEnzymeXLACPUHandler()::Cvoid | ||||||||||||||||||||||
| @ccall MLIR.API.mlir_c.RegisterEnzymeXLAGPUHandler()::Cvoid | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| # This wasn't properly exported on macos, we'll remove the try once macOS JLL | ||||||||||||||||||||||
|
|
@@ -227,17 +262,6 @@ mutable struct Buffer | |||||||||||||||||||||
| end | ||||||||||||||||||||||
| end | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| struct Device | ||||||||||||||||||||||
| device::Ptr{Cvoid} | ||||||||||||||||||||||
| end | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| function Base.show(io::IO, ::MIME"text/plain", device::Device) | ||||||||||||||||||||||
| pjrtclient = client(device) | ||||||||||||||||||||||
| platform_name = ClientGetPlatformName(pjrtclient) | ||||||||||||||||||||||
| print(io, "Device($(device.device), platform_name=$(platform_name))") | ||||||||||||||||||||||
| return nothing | ||||||||||||||||||||||
| end | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| function DeviceToClientDeviceOrdinal(device::Device) | ||||||||||||||||||||||
| pjrtclient = client(device) | ||||||||||||||||||||||
| naddressable_devices = ClientNumAddressableDevices(pjrtclient) | ||||||||||||||||||||||
|
|
@@ -336,7 +360,7 @@ Return an [`AllocatorStats`](@ref) instance with information about the device sp | |||||||||||||||||||||
| This method is currently not implemented for the CPU device. | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
| function allocatorstats( | ||||||||||||||||||||||
| device::Device=ClientGetDevice(default_backend[], default_device_idx[]) | ||||||||||||||||||||||
| device::Device=ClientGetAddressableDevice(default_backend[], default_device_idx[]) | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| ref = Ref{JLAllocatorStats}() | ||||||||||||||||||||||
| @ccall MLIR.API.mlir_c.PjRtDeviceGetAllocatorStats( | ||||||||||||||||||||||
|
|
@@ -539,21 +563,16 @@ end | |||||||||||||||||||||
|
|
||||||||||||||||||||||
| function Compile( | ||||||||||||||||||||||
| client::Client, | ||||||||||||||||||||||
| mod::MLIR.IR.Module; | ||||||||||||||||||||||
| device_ordinal::Int=-1, | ||||||||||||||||||||||
| num_replicas::Int=1, | ||||||||||||||||||||||
| num_partitions::Int=1, | ||||||||||||||||||||||
| use_shardy_partitioner::Bool=false, | ||||||||||||||||||||||
| mod::MLIR.IR.Module | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
|
Comment on lines
564
to
567
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||||
| max_local_id = length(client.global_ordinals) | ||||||||||||||||||||||
| GC.@preserve client mod begin | ||||||||||||||||||||||
| executable = LoadedExecutable( | ||||||||||||||||||||||
| @ccall MLIR.API.mlir_c.ClientCompile( | ||||||||||||||||||||||
| client.client::Ptr{Cvoid}, | ||||||||||||||||||||||
| mod.module_::MLIR.API.MlirModule, | ||||||||||||||||||||||
| device_ordinal::Cint, | ||||||||||||||||||||||
| num_replicas::Cint, | ||||||||||||||||||||||
| num_partitions::Cint, | ||||||||||||||||||||||
| use_shardy_partitioner::Bool, | ||||||||||||||||||||||
| client.global_ordinals::Ptr{Cint}, | ||||||||||||||||||||||
| max_local_id::Cint, | ||||||||||||||||||||||
| )::Ptr{Cvoid} | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| end | ||||||||||||||||||||||
|
|
@@ -608,6 +627,37 @@ function ClientGetPlatformName(client::Client) | |||||||||||||||||||||
| return unsafe_string(str) | ||||||||||||||||||||||
| end | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| function DeviceGetLocalDeviceId(device::Device) | ||||||||||||||||||||||
| GC.@preserve device begin | ||||||||||||||||||||||
| return @ccall MLIR.API.mlir_c.PjRtDeviceGetLocalDeviceId( | ||||||||||||||||||||||
| device.device::Ptr{Cvoid} | ||||||||||||||||||||||
| )::Cint | ||||||||||||||||||||||
| end | ||||||||||||||||||||||
| end | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| function PjRtLoadedExecutableGetClient(exec::LoadedExecutable) | ||||||||||||||||||||||
| GC.@preserve exec begin | ||||||||||||||||||||||
| return Client( | ||||||||||||||||||||||
| @ccall MLIR.API.mlir_c.PjRtLoadedExecutableGetClient( | ||||||||||||||||||||||
| exec.exec::Ptr{Cvoid} | ||||||||||||||||||||||
| )::Ptr{Cvoid} | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| end | ||||||||||||||||||||||
| end | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| function replicate_buffer_on_all_addressable_devices(buffer::Buffer) | ||||||||||||||||||||||
| pjrtclient = client(buffer) | ||||||||||||||||||||||
| devices = [ | ||||||||||||||||||||||
| ClientGetAddressableDevice(pjrtclient, i - 1) for | ||||||||||||||||||||||
| i in 1:ClientNumAddressableDevices(pjrtclient) | ||||||||||||||||||||||
| ] | ||||||||||||||||||||||
| orig_device = device(buffer) | ||||||||||||||||||||||
| return [ | ||||||||||||||||||||||
| device == orig_device ? buffer : CopyBufferToDevice(buffer, device) for | ||||||||||||||||||||||
| device in devices | ||||||||||||||||||||||
| ] | ||||||||||||||||||||||
| end | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| function is_ready(future::Future) | ||||||||||||||||||||||
| GC.@preserve future begin | ||||||||||||||||||||||
| return (@ccall MLIR.API.mlir_c.FutureIsReady(future.future::Ptr{Cvoid})::UInt8) != 0 | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
Uh oh!
There was an error while loading. Please reload this page.