Merge pull request #203 from Julia-Tempering/fix-load
Stop exporting load
miguelbiron committed Feb 14, 2024
2 parents 7b4c0ba + b213531 · commit a4ce7bf
Showing 13 changed files with 33 additions and 32 deletions.
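In downstream code, the change reads as follows — a minimal before/after sketch, reusing `toy_mvn_target` and `ChildProcess` exactly as they appear in the files changed below (any checkpointed run returning a `Result` behaves the same way):

```julia
using Pigeons

# a checkpointed run on a child process returns a Result (a handle to
# samples saved on disk) rather than an in-memory PT struct
result = pigeons(
    target = toy_mvn_target(100),
    checkpoint = true,
    on = ChildProcess(n_local_mpi_processes = 2))

# before this commit: pt = load(result)
pt = Pigeons.load(result) # now qualified, since `load` is no longer exported
```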
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "Pigeons"
uuid = "0eb8d820-af6a-4919-95ae-11206f830c31"
authors = ["Alexandre Bouchard-Côté <bouchard@stat.ubc.ca>, Nikola Surjanovic <nikola.surjanovic@stat.ubc.ca>, Paul Tiede <ptiede91@gmail.com>, Trevor Campbell, Miguel Biron-Lattes, Saifuddin Syed"]
version = "0.3.0"
version = "0.4.0"

[deps]
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
12 changes: 6 additions & 6 deletions docs/src/mpi.md
@@ -54,7 +54,7 @@ To analyze the output, see the documentation page on [post-processing for MPI ru
back to your interactive chain via:

```@example local
-pt = load(result) # possible thanks to 'pigeons(..., checkpoint = true)' used above
+pt = Pigeons.load(result) # possible thanks to 'pigeons(..., checkpoint = true)' used above
```

## Running MPI on a cluster
@@ -86,12 +86,12 @@ mpi_run = pigeons(
target = toy_mvn_target(1000000),
n_chains = 1000,
checkpoint = true,
-on = MPI(
+on = MPIProcesses(
n_mpi_processes = 1000,
n_threads = 1))
```

-This will start a distributed PT algorithm with 1000 chains on 1000 MPI processes, each using one thread, targeting a one million
+This will start a distributed PT algorithm with 1000 chains on 1000 MPIProcesses processes, each using one thread, targeting a one million
dimensional target distribution. On the UBC Sockeye cluster, the last
round of this run (i.e. the last 1024 iterations) takes 10 seconds to complete, versus more than
2 hours if run serially, i.e. a >700x speed-up.
@@ -115,7 +115,7 @@ To analyze the output, see the documentation page on [post-processing for MPI ru
back to your interactive chain via:

```
-pt = load(mpi_run) # possible thanks to 'pigeons(..., checkpoint = true)' used above
+pt = Pigeons.load(mpi_run) # possible thanks to 'pigeons(..., checkpoint = true)' used above
```


@@ -126,11 +126,11 @@ are built-in inside the Pigeons module.
However in typical use cases,
some user-provided code needs to be provided to
[`ChildProcess`](@ref)
-and [`MPI`](@ref) so that the other participating Julia
+and [`MPIProcesses`](@ref) so that the other participating Julia
processes have access to it.
This is done with the argument `dependencies` (of type `Vector`; present in
both [`ChildProcess`](@ref)
-and [`MPI`](@ref)).
+and [`MPIProcesses`](@ref)).
Two types of elements can be used in the vector of dependencies, and they can be mixed:

- elements of type `Module`: for each of those, an `using` statement will be generated in the script used by the child process;
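To make the `dependencies` argument concrete, here is a sketch mixing both element types, modeled on `test/test_checkpoint.jl` and `test/test_lazy_target.jl` further down in this diff (the file path is illustrative):

```julia
using Pigeons, DynamicPPL

result = pigeons(
    target = toy_mvn_target(1),
    checkpoint = true,
    on = ChildProcess(
        n_local_mpi_processes = 2,
        # Module element: a `using DynamicPPL` statement is generated in the
        # child's launch script; String element: a file of user code made
        # available to the child process (as in test/test_lazy_target.jl)
        dependencies = [DynamicPPL, "$(@__DIR__)/supporting/lazy.jl"]))
pt = Pigeons.load(result)
```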
6 changes: 3 additions & 3 deletions docs/src/output-mpi-postprocessing.md
@@ -17,11 +17,11 @@ Option (1) is more convenient than (2) but it uses more RAM.

Many of Pigeons' post-processing tools take as input a [`PT`](@ref) struct.
When running locally, [`pigeons()`](@ref) returns a [`PT`](@ref) struct,
-however, when running a job via [`MPI`](@ref) or [`ChildProcess`](@ref),
+however, when running a job via [`MPIProcesses`](@ref) or [`ChildProcess`](@ref),
[`pigeons()`](@ref) returns a [`Result`](@ref) struct (which only holds the
directory where samples are stored).

-Use [`load()`](@ref) to convert a [`Result`](@ref) into a
+Use `Pigeons.load(..)` to convert a [`Result`](@ref) into a
[`PT`](@ref) struct.
This will load the information distributed across several machines
into the interactive node.
@@ -59,7 +59,7 @@ pt_result = pigeons(target = an_unidentifiable_model,
record = [traces; round_trip; record_default()])
# (*) load the result across all machines into this interactive node
-pt = load(pt_result)
+pt = Pigeons.load(pt_result)
# collect the statistics and convert to MCMCChains' Chains
# to have axes labels matching variable names in Turing and Stan
5 changes: 3 additions & 2 deletions src/Pigeons.jl
@@ -63,7 +63,7 @@ include("includes.jl")

export pigeons, Inputs, PT,
# for running jobs:
-ChildProcess, MPI,
+ChildProcess, MPIProcesses,
# references:
DistributionLogPotential,
# targets:
@@ -76,7 +76,8 @@ export pigeons, Inputs, PT,
index_process, swap_acceptance_pr, log_sum_ratio, online, round_trip, energy_ac1, traces, disk,
record_online, record_default,
# utils to run on scheduler:
-Result, load, setup_mpi, queue_status, queue_ncpus_free, kill_job, watch,
+Result, setup_mpi, queue_status, queue_ncpus_free, kill_job, watch,
+# load, <- removed to avoid clash - see https://github.com/Julia-Tempering/Pigeons.jl/issues/200
# getting information out of an execution:
stepping_stone, n_tempered_restarts, n_round_trips, process_sample, get_sample,
# variational references:
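Downstream code that prefers the short name can still opt back in through a standard selective import — a sketch, noting that this reintroduces the very clash the commit avoids if another package in scope also exports a `load`:

```julia
using Pigeons
using Pigeons: load # explicitly bring the now-unexported binding into scope

pt = load(result)   # `result` returned by a checkpointed run, as above
```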
2 changes: 1 addition & 1 deletion src/includes.jl
@@ -34,7 +34,7 @@ include("tempering/NonReversiblePT.jl")
include("tempering/StabilizedPT.jl")
include("tempering/tempering.jl")
include("swap/swap_graph.jl")
include("submission/MPI.jl")
include("submission/MPIProcesses.jl")
include("targets/BlangTarget.jl")
include("submission/MPISettings.jl")
include("submission/ChildProcess.jl")
4 changes: 2 additions & 2 deletions src/pt/checks.jl
@@ -6,7 +6,7 @@ function preflight_checks(inputs::Inputs)
"""
end
if mpi_active() && !inputs.checkpoint
@warn "To be able to call load() to retrieve samples in-memory, use pigeons(target = ..., checkpoint = true)"
@warn "To be able to call Pigeons.load() to retrieve samples in-memory, use pigeons(target = ..., checkpoint = true)"
end
if Threads.nthreads() > 1 && !inputs.multithreaded
@warn "More than one threads are available, but explore!() loop is not parallelized as inputs.multithreaded == false"
@@ -56,7 +56,7 @@ function check_against_serial(pt)
# run a serial copy
dependencies =
if isfile("$(pt.exec_folder)/.dependencies.jls")
-# this process was itself spawn using ChildProcess/MPI
+# this process was itself spawn using ChildProcess/MPIProcesses
# so use the same dependencies as this process
deserialize("$(pt.exec_folder)/.dependencies.jls")
else
16 changes: 8 additions & 8 deletions src/submission/MPI.jl → src/submission/MPIProcesses.jl
@@ -14,7 +14,7 @@ Fields:
$FIELDS
"""
-@kwdef struct MPI <: Submission
+@kwdef struct MPIProcesses <: Submission
"""
The number of threads per MPI process, 1 by default.
"""
@@ -51,7 +51,7 @@ end
"""
$SIGNATURES
"""
-function pigeons(pt_arguments, mpi_submission::MPI)
+function pigeons(pt_arguments, mpi_submission::MPIProcesses)
if !is_mpi_setup()
error("call setup_mpi(..) first")
end
@@ -77,13 +77,13 @@ function pigeons(pt_arguments, mpi_submission::MPI)
end

# todo: abstract out to other submission systems
-function mpi_submission_cmd(exec_folder, mpi_submission::MPI, julia_cmd)
+function mpi_submission_cmd(exec_folder, mpi_submission::MPIProcesses, julia_cmd)
r = rosetta()
submission_script = mpi_submission_script(exec_folder, mpi_submission, julia_cmd)
return `$(r.submit) $submission_script`
end

-function mpi_submission_script(exec_folder, mpi_submission::MPI, julia_cmd)
+function mpi_submission_script(exec_folder, mpi_submission::MPIProcesses, julia_cmd)
# TODO: generalize to other submission systems
# TODO: move some more things over from mpi-run
info_folder = "$exec_folder/info"
@@ -133,23 +133,23 @@ const _rosetta = (;

supported_submission_systems() = filter(x -> x != :queue_concept && x != :custom, keys(_rosetta))

-resource_string(m::MPI, symbol) = resource_string(m, Val(symbol))
+resource_string(m::MPIProcesses, symbol) = resource_string(m, Val(symbol))

-resource_string(m::MPI, ::Val{:pbs}) =
+resource_string(m::MPIProcesses, ::Val{:pbs}) =
# +-- each chunks should request as many cpus as threads,
# +-- number of "chunks"... | +-- NB: if mpiprocs were set to more than 1 this would give a number of mpi processes equal to select*mpiprocs
# v v v
"#PBS -l walltime=$(m.walltime),select=$(m.n_mpi_processes):ncpus=$(m.n_threads):mpiprocs=1:mem=$(m.memory)"

-resource_string(m::MPI, ::Val{:slurm}) =
+resource_string(m::MPIProcesses, ::Val{:slurm}) =
"""
#SBATCH -t $(m.walltime)
#SBATCH --ntasks=$(m.n_mpi_processes)
#SBATCH --cpus-per-task=$(m.n_threads)
#SBATCH --mem-per-cpu=$(m.memory)
"""

-function resource_string(m::MPI, ::Val{:lsf})
+function resource_string(m::MPIProcesses, ::Val{:lsf})
@assert m.n_threads == 1 "TODO: find how to specify number of threads per node with LSF"
"""
#BSUB -W $(m.walltime)
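For concreteness, with hypothetical settings `n_mpi_processes = 1000`, `n_threads = 1`, `walltime = "01:00:00"`, and `memory = "8gb"`, the slurm template above interpolates to:

```
#SBATCH -t 01:00:00
#SBATCH --ntasks=1000
#SBATCH --cpus-per-task=1
#SBATCH --mem-per-cpu=8gb
```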
4 changes: 2 additions & 2 deletions src/submission/submission_utils.jl
@@ -114,7 +114,7 @@ function find_rank_file(folder, machine::Int)
return nothing
end

-# construct launch cmd and script for MPI and ChildProcess
+# construct launch cmd and script for MPIProcesses and ChildProcess

function launch_cmd(pt_arguments, exec_folder, dependencies, n_threads::Int, on_mpi::Bool)
script_path = launch_script(pt_arguments, exec_folder, dependencies, on_mpi)
@@ -200,7 +200,7 @@ function launch_code(
Pigeons.deserialize_immutables!(raw"$path_to_serialized_immutables")
deserialize(raw"$path_to_serialized_pt_arguments")
catch e
println("Hint: probably missing dependencies, use the dependencies argument in MPI() or ChildProcess()")
println("Hint: probably missing dependencies, use the dependencies argument in MPIProcesses() or ChildProcess()")
rethrow(e)
end
2 changes: 1 addition & 1 deletion test/activate_test_env.jl
@@ -4,7 +4,7 @@
# If we do not do this, we will end up testing the
# latest released version instead of the one checked out.

-# We have to do this because ChildProcess/MPI depend on
+# We have to do this because ChildProcess/MPIProcesses depend on
# a single toml file to know how to load Pigeons and other
# dependencies.

2 changes: 1 addition & 1 deletion test/test_checkpoint.jl
@@ -11,7 +11,7 @@ end
compare_pts(p1, p2)

r = pigeons(;target, checkpoint = true, on = ChildProcess(n_local_mpi_processes = 2, dependencies=[DynamicPPL,]))
-p3 = load(r)
+p3 = Pigeons.load(r)
compare_pts(p1, p3)
end
end
2 changes: 1 addition & 1 deletion test/test_lazy_target.jl
@@ -10,7 +10,7 @@ include("supporting/lazy.jl")
n_threads = 2,
dependencies = ["$(@__DIR__)/supporting/lazy.jl"]
))
-pt1 = load(r)
+pt1 = Pigeons.load(r)
pt2 = pigeons(target = toy_mvn_target(1))

@test Pigeons.recursive_equal(pt1.replicas, pt2.replicas)
6 changes: 3 additions & 3 deletions test/test_resume.jl
@@ -19,14 +19,14 @@ end
pt = pigeons(; target = toy_mvn_target(1), checkpoint = true, on = ChildProcess(n_local_mpi_processes = 2))
exec = increment_n_rounds!(pt.exec_folder, 2)
r = pigeons(exec, ChildProcess(n_local_mpi_processes = 2))
-Pigeons.check_against_serial(load(r))
+Pigeons.check_against_serial(Pigeons.load(r))
end

@testset "Extend number of rounds with PT object, on ChildProcess" begin
pt = pigeons(; target = toy_mvn_target(1), checkpoint = true)
pt = increment_n_rounds!(pt, 2)
r = pigeons(pt.exec_folder, ChildProcess(n_local_mpi_processes = 2))
-Pigeons.check_against_serial(load(r))
+Pigeons.check_against_serial(Pigeons.load(r))
end

@testset "Complex example of increasing number of rounds many times" begin
@@ -59,5 +59,5 @@ end
result = pigeons(new_exec_folder, ChildProcess(n_local_mpi_processes = 2))

# make sure it is equivalent to doing it in one shot
-Pigeons.check_against_serial(load(result))
+Pigeons.check_against_serial(Pigeons.load(result))
end
2 changes: 1 addition & 1 deletion test/test_traces.jl
@@ -48,7 +48,7 @@ end
multithreaded = false, # setting to true puts too much pressure on CI instances? https://github.com/Julia-Tempering/Pigeons.jl/actions/runs/5627897144/job/15251121621?pr=90
checkpoint = true,
on = ChildProcess(n_local_mpi_processes = 2, n_threads = 1, dependencies=[DynamicPPL, BridgeStan])) # setting to more than 1 puts too much pressure on CI instances?
-pt = load(r)
+pt = Pigeons.load(r)
@test length(pt.reduced_recorders.traces) == 1024 * (extended_traces ? 10 : 1)
for chain in Pigeons.chains_with_samples(pt)
marginal = [get_sample(pt, chain, i)[1] for i in 1:1024]
