diff --git a/docs/make.jl b/docs/make.jl
index fc1ac82e0f..80b6404cc7 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -32,8 +32,11 @@ pages = [
         "Getting Started" => "introduction/index.md",
         "Configuration" => "introduction/configuration.md",
     ],
-    "Tutorials" =>
-        ["Overview" => "tutorials/index.md", "Profiling" => "tutorials/profiling.md"],
+    "Tutorials" => [
+        "Overview" => "tutorials/index.md",
+        "Profiling" => "tutorials/profiling.md",
+        "Distributed" => "tutorials/multihost.md",
+    ],
     "API Reference" => [
         "Reactant API" => "api/api.md",
         "Ops" => "api/ops.md",
diff --git a/docs/src/.vitepress/config.mts b/docs/src/.vitepress/config.mts
index 74e6d4ec59..4657bffc01 100644
--- a/docs/src/.vitepress/config.mts
+++ b/docs/src/.vitepress/config.mts
@@ -63,6 +63,7 @@ export default defineConfig({
         items: [
           {text: "Overview", link: "/tutorials/"},
           {text: "Profiling", link: "/tutorials/profiling"},
+          {text: "Distributed", link: "/tutorials/multihost"},
         ],
       },
       {
@@ -122,6 +123,7 @@ export default defineConfig({
         items: [
           { text: "Overview", link: "/tutorials/" },
           { text: "Profiling", link: "/tutorials/profiling" },
+          { text: "Distributed", link: "/tutorials/multihost" },
         ],
       },
       "/api/": {
diff --git a/docs/src/api/sharding.md b/docs/src/api/sharding.md
index 037f8dbac0..a32b0183a8 100644
--- a/docs/src/api/sharding.md
+++ b/docs/src/api/sharding.md
@@ -2,7 +2,7 @@ CollapsedDocStrings = true
 ```
 
-# Sharding API
+# [Sharding API](@id sharding-api)
 
 `Reactant.Sharding` module provides a high-level API to construct MLIR operations with support for sharding.
diff --git a/docs/src/tutorials/index.md b/docs/src/tutorials/index.md
index 87c2c8ddd3..dadc887bee 100644
--- a/docs/src/tutorials/index.md
+++ b/docs/src/tutorials/index.md
@@ -1,5 +1,6 @@
 # Tutorials
 
   - [Profiling](@ref profiling).
+  - [Multi-Host Environments](@ref distributed).
 
 We are currently working on adding more tutorials to Reactant!! Please check back soon!
diff --git a/docs/src/tutorials/multihost.md b/docs/src/tutorials/multihost.md
new file mode 100644
index 0000000000..535824c423
--- /dev/null
+++ b/docs/src/tutorials/multihost.md
@@ -0,0 +1,82 @@
+# [Multi-Host Environments](@ref distributed)
+
+!!! tip "Use XLA IFRT Runtime"
+
+    While PJRT does support some minimal distributed capabilities on CUDA GPUs, distributed
+    support in Reactant is primarily provided via IFRT. Before loading Reactant, set the
+    "xla_runtime" preference to "IFRT". This can be done with:
+
+    ```julia
+    using Preferences, UUIDs
+
+    Preferences.set_preferences!(
+        UUID("3c362404-f566-11ee-1572-e11a4b42c853"),
+        "xla_runtime" => "IFRT"
+    )
+    ```
+
+At the top of your code, just after loading Reactant and before running any Reactant-related
+operations, run `Reactant.Distributed.initialize()`.
+
+!!! tip "Enable debug logging"
+
+    Reactant emits a lot of useful debugging information when setting up the Distributed
+    Runtime. This can be printed by setting the environment variable `JULIA_DEBUG` to
+    contain `Reactant`.
+
+After this, simply set up your code with [`Reactant.Sharding`](@ref sharding-api) and the
+code will run on multiple devices across multiple nodes.
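+
+Putting these pieces together, a minimal per-process script follows the sketch below. This
+is purely illustrative of the ordering described above; the full Slurm walkthrough is in
+the next section.
+
+```julia
+# Optional: surface Reactant's distributed-setup debug output.
+ENV["JULIA_DEBUG"] = "Reactant"
+
+using Reactant
+
+# Must be called on every process, before any other Reactant operation.
+Reactant.Distributed.initialize()
+
+# The devices reported here span all processes in the job; they are what you
+# pass to `Sharding.Mesh` when constructing sharded arrays (see the example below).
+@show length(Reactant.devices())
+```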
+ +## Example Slurm Script for Multi-Host Matrix Multiplication + +::: code-group + +```bash [main.sbatch] +#!/bin/bash -l +# +#SBATCH --job-name=matmul-sharding-reactant +#SBATCH --time=00:20:00 +#SBATCH --nodes=2 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=72 +#SBATCH --account= +#SBATCH --constraint=gpu + +export JULIA_DEBUG="Reactant,Reactant_jll" + +srun --preserve-env bash ./matmul.sh +``` + +```bash [matmul.sh] +#!/bin/bash -l + +# Important else XLA might hang indefinitely +unset no_proxy http_proxy https_proxy NO_PROXY HTTP_PROXY HTTPS_PROXY + +julia --project=. --threads=auto matmul_sharded.jl +``` + +```julia [matmul_sharded.jl] +using Reactant + +Reactant.Distributed.initialize(; single_gpu_per_process=false) + +@assert length(Reactant.devices()) >= 2 + +N = min((length(Reactant.devices()) ÷ 2) * 2, 8) + +mesh = Sharding.Mesh(reshape(Reactant.devices()[1:N], 2, :), (:x, :y)) +sharding = Sharding.NamedSharding(mesh, (:x, :y)) + +x = reshape(collect(Float32, 1:64), 8, 8) +y = reshape(collect(Float32, 1:64), 8, 8) + +x_ra = Reactant.to_rarray(x; sharding) +y_ra = Reactant.to_rarray(y; sharding) + +res = @jit x_ra * y_ra + +display(res) +``` + +::: diff --git a/src/Distributed.jl b/src/Distributed.jl index 5e7c498777..f369880b2e 100644 --- a/src/Distributed.jl +++ b/src/Distributed.jl @@ -8,10 +8,16 @@ function initialize(; coordinator_address::Union{Nothing,String}=nothing, num_processes::Union{Nothing,Integer}=nothing, process_id::Union{Nothing,Integer}=nothing, + single_gpu_per_process::Bool=true, local_gpu_device_ids::Union{Nothing,Vector{Int}}=nothing, initialization_timeout_in_seconds::Integer=300, kwargs..., ) + if isinteractive() + @warn "Reactant.Distributed.initialize() should not be called in interactive mode. \ + Use Reactant.Distributed.initialize() in a script instead." 
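+        # Launchers such as Slurm or Open MPI provide the environment variables used by the
+        # auto-detection below; an interactive REPL session typically does not have them.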
+ end + @assert !initialized[] "`Distributed.initialize` has already been called" (coordinator_address, num_processes, process_id, local_gpu_device_ids) = auto_detect_unset_distributed_params(; @@ -20,6 +26,7 @@ function initialize(; process_id, local_gpu_device_ids, initialization_timeout_in_seconds, + single_gpu_per_process, ) @debug "Detected Reactant distributed params" coordinator_address num_processes process_id local_gpu_device_ids @@ -43,6 +50,8 @@ struct OpenMPIPMIXEnvDetector <: AbstractOMPIClusterEnvDetector end struct MPIEnvDetector <: AbstractClusterEnvDetector end +struct SlurmEnvDetector <: AbstractClusterEnvDetector end + # Based on https://github.com/jax-ml/jax/blob/b0117366686ab084d38ad2657d9a2ae3a581ca7e/jax/_src/clusters/cluster.py is_env_present(::AbstractClusterEnvDetector) = false @@ -53,12 +62,19 @@ function get_process_id end function get_local_process_id end function auto_detect_unset_distributed_params(; - detector_list=[OpenMPIORTEEnvDetector(), OpenMPIPMIXEnvDetector(), MPIEnvDetector()], + detector_list=[ + SlurmEnvDetector(), + OpenMPIORTEEnvDetector(), + MPIEnvDetector(), + # Keep this at the end since parsing for this is a bit flaky + OpenMPIPMIXEnvDetector(), + ], coordinator_address::Union{Nothing,String}=nothing, num_processes::Union{Nothing,Integer}=nothing, process_id::Union{Nothing,Integer}=nothing, local_gpu_device_ids::Union{Nothing,Vector{Int}}=nothing, initialization_timeout_in_seconds::Integer=300, + single_gpu_per_process::Bool=true, ) if all( Base.Fix2(!==, nothing), @@ -91,7 +107,7 @@ function auto_detect_unset_distributed_params(; process_id = get_process_id(detector) end - if local_gpu_device_ids === nothing + if local_gpu_device_ids === nothing && single_gpu_per_process local_gpu_device_ids = [get_local_process_id(detector)] end @@ -108,16 +124,18 @@ const _PMIX_SERVER_URI = ( "PMIX_SERVER_URI41", "PMIX_SERVER_URI21", ) +const _PMIX_NAMESPACE = "PMIX_NAMESPACE" +const _PRTERUN = "PRTE_LAUNCHED" +const _PMIX_VERSION = "PMIX_VERSION" const _OMPI_PROCESS_COUNT = "OMPI_COMM_WORLD_SIZE" const _OMPI_PROCESS_ID = "OMPI_COMM_WORLD_RANK" const _OMPI_LOCAL_PROCESS_ID = "OMPI_COMM_WORLD_LOCAL_RANK" is_env_present(::OpenMPIORTEEnvDetector) = haskey(ENV, _ORTE_URI) -is_env_present(::OpenMPIPMIXEnvDetector) = any(Base.Fix1(haskey, ENV), _PMIX_SERVER_URI) +is_env_present(::OpenMPIPMIXEnvDetector) = haskey(ENV, _PMIX_NAMESPACE) function get_coordinator_address(::OpenMPIORTEEnvDetector, ::Integer) orte_uri = ENV[_ORTE_URI] - job_id = parse(Int, split(orte_uri, '.'; limit=2)[1]) port = job_id % 2^12 + (65535 - 2^12 + 1) @@ -132,11 +150,48 @@ function get_coordinator_address(::OpenMPIORTEEnvDetector, ::Integer) return "$(launcher_ip):$(port)" end +function _throw_pmix_env_error(msg) + msg = msg * " Open an issue on Reactant with the relevant PMIX Enviroment Variables \ + (you might want to obfuscate identifiable variables from this log \ + before opening an issue)\n\n" + for (var, val) in [var => val for (var, val) in ENV if startswith(var, "PMIX")] + msg *= " * $var => $val.\n" + end + return error(msg) +end + function get_coordinator_address(::OpenMPIPMIXEnvDetector, ::Integer) - varname = findfirst(Base.Fix1(haskey, ENV), _PMIX_SERVER_URI) - pmix_uri = ENV[_PMIX_SERVER_URI[varname]] + pmix_version = parse(VersionNumber, ENV[_PMIX_VERSION]) + pmix_uri = ENV[_PMIX_SERVER_URI[findfirst(Base.Fix1(haskey, ENV), _PMIX_SERVER_URI)]] + @debug "PMIX VERSION: $(pmix_version)" + if v"5" ≤ pmix_version < v"6" + return get_coordinator_address_pmixv5(pmix_uri) + elseif 
v"2" ≤ pmix_version < v"4" + return get_coordinator_address_pmixv2_or_3(pmix_uri) + else + _throw_pmix_env_error("Unsupported PMIX version: $(pmix_version).") + end +end + +function get_coordinator_address_pmixv2_or_3(pmix_uri) + pre_semicolon = first(split(pmix_uri, ";")) + if startswith(pre_semicolon, "pmix-server.") + job_id = parse(Int, first(split(last(split(pre_semicolon, '.'; limit=2))))) + elseif contains(pre_semicolon, ".") + job_id = parse(Int, first(split(pre_semicolon, '.'))) + else + _throw_pmix_env_error("Could not parse coordinator address from Open MPI \ + environment.") + end + return get_coordinator_address_from_pmix_uri(pmix_uri, job_id) +end - job_id = parse(Int, split(split(pmix_uri, '-'; limit=3)[3], "@"; limit=2)[1]) +function get_coordinator_address_pmixv5(pmix_uri) + job_id = parse(Int, first(split(last(split(pmix_uri, '-'; limit=3)), "@"; limit=2))) + return get_coordinator_address_from_pmix_uri(pmix_uri, job_id) +end + +function get_coordinator_address_from_pmix_uri(pmix_uri, job_id) port = job_id % 2^12 + (65535 - 2^12 + 1) launcher_ip_match = match(r"tcp4://(.+?):|tcp6://\[(.+?)\]", pmix_uri) @@ -159,4 +214,45 @@ function get_local_process_id(::AbstractOMPIClusterEnvDetector) return parse(Int, ENV[_OMPI_LOCAL_PROCESS_ID]) end +# SlurmEnvDetector +# Based on https://github.com/jax-ml/jax/blob/d89835acbacec938971400d6fa54ea6dd5efe76c/jax/_src/clusters/slurm_cluster.py#L3 +const _SLURM_JOB_ID = "SLURM_JOB_ID" +const _SLURM_NODELIST = "SLURM_STEP_NODELIST" +const _SLURM_PROCESS_COUNT = "SLURM_NTASKS" +const _SLURM_PROCESS_ID = "SLURM_PROCID" +const _SLURM_LOCAL_PROCESS_ID = "SLURM_LOCALID" +const _SLURM_NUM_NODES = "SLURM_STEP_NUM_NODES" + +is_env_present(::SlurmEnvDetector) = haskey(ENV, _SLURM_JOB_ID) + +function get_coordinator_address(::SlurmEnvDetector, ::Integer) + port = parse(Int, ENV[_SLURM_JOB_ID]) % 2^12 + (65535 - 2^12 + 1) + + # Parse the first hostname of the job + # If we are looking for 'node001', + # node_list potential formats are 'node001', 'node001,host2', + # 'node[001-0015],host2', and 'node[001,007-015],host2'. + node_list = ENV[_SLURM_NODELIST] + ind = findfirst(Base.Fix2(in, (',', '[')), node_list) + ind = isnothing(ind) ? length(node_list) + 1 : ind + + if ind == length(node_list) + 1 || node_list[ind] == ',' + # 'node001' or 'node001,host2' + return "$(node_list[1:ind-1]):$(port)" + else + # 'node[001-0015],host2' or 'node[001,007-015],host2' + prefix = node_list[1:(ind - 1)] + suffix = node_list[(ind + 1):end] + ind2 = findfirst(Base.Fix2(in, (',', '-')), suffix) + ind2 = isnothing(ind2) ? 
length(suffix) : ind2 + return "$(prefix)$(suffix[1:ind2-1]):$(port)" + end +end + +get_process_count(::SlurmEnvDetector) = parse(Int, ENV[_SLURM_PROCESS_COUNT]) + +get_process_id(::SlurmEnvDetector) = parse(Int, ENV[_SLURM_PROCESS_ID]) + +get_local_process_id(::SlurmEnvDetector) = parse(Int, ENV[_SLURM_LOCAL_PROCESS_ID]) + end diff --git a/src/xla/Distributed.jl b/src/xla/Distributed.jl index 791d3cdc11..d645691381 100644 --- a/src/xla/Distributed.jl +++ b/src/xla/Distributed.jl @@ -129,7 +129,7 @@ function update!( coordinator_address::String, num_processes::Int, process_id::Int, - local_gpu_device_ids::Vector{Int}, + local_gpu_device_ids::Union{Nothing,Vector{Int}}, coordinator_bind_address::Union{Nothing,String}=nothing, cluster_register_timeout_in_minutes::Integer=60, rpc_timeout_in_seconds::Integer=120, @@ -141,7 +141,9 @@ function update!( @assert 0 ≤ process_id < num_processes state.coordinator_address = coordinator_address - state.local_gpu_device_ids = local_gpu_device_ids + if local_gpu_device_ids !== nothing + state.local_gpu_device_ids = local_gpu_device_ids + end state.process_id = process_id state.num_processes = num_processes diff --git a/src/xla/IFRT/Array.jl b/src/xla/IFRT/Array.jl index 2dea365a55..0c3792116b 100644 --- a/src/xla/IFRT/Array.jl +++ b/src/xla/IFRT/Array.jl @@ -138,7 +138,9 @@ function XLA.buffer_on_cpu(::Array) end function XLA.to_host(buffer::Array, data, reactant_sharding) - if length(XLA.devices(XLA.sharding(buffer))) == 1 + reactant_sharding = Reactant.Sharding.unwrap_shardinfo(reactant_sharding) + + if reactant_sharding isa Reactant.Sharding.NoSharding GC.@preserve buffer data begin @ccall MLIR.API.mlir_c.ifrt_array_copy_to_host_buffer( buffer.buffer::Ptr{Cvoid}, data::Ptr{Cvoid} @@ -147,7 +149,6 @@ function XLA.to_host(buffer::Array, data, reactant_sharding) return data end - reactant_sharding = Reactant.Sharding.unwrap_shardinfo(reactant_sharding) @assert reactant_sharding isa Reactant.Sharding.HloSharding client = XLA.client(buffer) all_devices = XLA.get_device.((client,), reactant_sharding.mesh.device_ids) diff --git a/test/cluster_detector.jl b/test/cluster_detector.jl new file mode 100644 index 0000000000..13d39dcf2d --- /dev/null +++ b/test/cluster_detector.jl @@ -0,0 +1,111 @@ +using Reactant, Test + +@testset "ORTE_URI parsing" begin + addr = withenv( + "OMPI_MCA_orte_hnp_uri" => "1531576320.0;tcp://10.96.0.1,10.148.0.1,10.108.0.1:34911", + ) do + Reactant.Distributed.get_coordinator_address( + Reactant.Distributed.OpenMPIORTEEnvDetector(), -1 + ) + end + @test startswith(addr, "10.96.0.1:") + + addr = withenv( + "OMPI_MCA_orte_hnp_uri" => "1314521088.0;tcp6://[fe80::b9b:ac5d:9cf0:b858,2620:10d:c083:150e::3000:2]:43370", + ) do + Reactant.Distributed.get_coordinator_address( + Reactant.Distributed.OpenMPIORTEEnvDetector(), -1 + ) + end + @test startswith(addr, "fe80::b9b:ac5d:9cf0:b858:") +end + +@testset "PMIX_SERVER_URI parsing" begin + @test_throws ErrorException withenv( + "PMIX_SERVER_URI21" => "961478656.0;tcp4://127.0.0.1:35625", + "PMIX_NAMESPACE" => "961478657", + "PMIX_VERSION" => "4.1.5", + ) do + Reactant.Distributed.get_coordinator_address( + Reactant.Distributed.OpenMPIPMIXEnvDetector(), -1 + ) + end + + addr = withenv( + "PMIX_SERVER_URI21" => "961478656.0;tcp4://127.0.0.1:35625", + "PMIX_NAMESPACE" => "961478657", + "PMIX_VERSION" => "3.1.5", + ) do + Reactant.Distributed.get_coordinator_address( + Reactant.Distributed.OpenMPIPMIXEnvDetector(), -1 + ) + end + @test startswith(addr, "127.0.0.1:") + + addr = withenv( + 
"PMIX_SERVER_URI21" => "pmix-server.40985;tcp4://127.0.0.1:48103", + "PMIX_NAMESPACE" => "slurm.pmix.1591154.6", + "PMIX_VERSION" => "3.1.5rc4", + ) do + Reactant.Distributed.get_coordinator_address( + Reactant.Distributed.OpenMPIPMIXEnvDetector(), -1 + ) + end + @test startswith(addr, "127.0.0.1:") + + addr = withenv( + "PMIX_SERVER_URI3" => "pmix-server.41512;tcp4://127.0.0.1:60120", + "PMIX_NAMESPACE" => "slurm.pmix.1591154.7", + "PMIX_VERSION" => "2.2.2", + ) do + Reactant.Distributed.get_coordinator_address( + Reactant.Distributed.OpenMPIPMIXEnvDetector(), -1 + ) + end + @test startswith(addr, "127.0.0.1:") + + addr = withenv( + "PMIX_SERVER_URI2" => "prterun-hydra-3874047@0.0;tcp4://118.143.212.23:49157", + "PMIX_NAMESPACE" => "prterun-hydra-3874047@1", + "PMIX_VERSION" => "5.0.5rc10", + ) do + Reactant.Distributed.get_coordinator_address( + Reactant.Distributed.OpenMPIPMIXEnvDetector(), -1 + ) + end + @test startswith(addr, "118.143.212.23:") +end + +@testset "Slurm parsing" begin + addr = withenv("SLURM_STEP_NODELIST" => "node001", "SLURM_JOB_ID" => "12345") do + Reactant.Distributed.get_coordinator_address( + Reactant.Distributed.SlurmEnvDetector(), -1 + ) + end + @test startswith(addr, "node001:") + + addr = withenv("SLURM_STEP_NODELIST" => "node001,host2", "SLURM_JOB_ID" => "12345") do + Reactant.Distributed.get_coordinator_address( + Reactant.Distributed.SlurmEnvDetector(), -1 + ) + end + @test startswith(addr, "node001:") + + addr = withenv( + "SLURM_STEP_NODELIST" => "node[001-015],host2", "SLURM_JOB_ID" => "12345" + ) do + Reactant.Distributed.get_coordinator_address( + Reactant.Distributed.SlurmEnvDetector(), -1 + ) + end + @test startswith(addr, "node001:") + + addr = withenv( + "SLURM_STEP_NODELIST" => "node[001,007-015],host2", "SLURM_JOB_ID" => "12345" + ) do + Reactant.Distributed.get_coordinator_address( + Reactant.Distributed.SlurmEnvDetector(), -1 + ) + end + @test startswith(addr, "node001:") +end diff --git a/test/runtests.jl b/test/runtests.jl index 4e51efe428..58253e5231 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -62,6 +62,7 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) @safetestset "Custom Number Types" include("custom_number_types.jl") end @safetestset "Sharding" include("sharding.jl") + @safetestset "Cluster Detection" include("cluster_detector.jl") end if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration" diff --git a/test/sharding.jl b/test/sharding.jl index 5acd9c82c9..6072a21fd8 100644 --- a/test/sharding.jl +++ b/test/sharding.jl @@ -25,11 +25,10 @@ end cdata_sharded2 = Reactant.to_rarray(data; sharding=data_sharding2) cdata_sharded3 = Reactant.to_rarray(data; sharding=data_sharding3) - @test data ≈ - Array(cdata) ≈ - Array(cdata_sharded) ≈ - Array(cdata_sharded2) ≈ - Array(cdata_sharded3) + @test data ≈ Array(cdata) + @test data ≈ Array(cdata_sharded) + @test data ≈ Array(cdata_sharded2) + @test data ≈ Array(cdata_sharded3) @test cdata_sharded.sharding isa Sharding.ShardInfo{<:Sharding.HloSharding} @test cdata_sharded2.sharding isa Sharding.ShardInfo{<:Sharding.HloSharding}