Commit b439fcf

Threadsafe, draft 1
1 parent 1f808d5 commit b439fcf

File tree

2 files changed

+239
-0
lines changed


_quarto.yml

Lines changed: 2 additions & 0 deletions

@@ -82,6 +82,7 @@ website:
         - usage/tracking-extra-quantities/index.qmd
         - usage/predictive-distributions/index.qmd
         - usage/mode-estimation/index.qmd
+        - usage/threadsafe-evaluation/index.qmd
         - usage/performance-tips/index.qmd
         - usage/sampler-visualisation/index.qmd
         - usage/dynamichmc/index.qmd

@@ -215,6 +216,7 @@ usage-probability-interface: usage/probability-interface
 usage-sampler-visualisation: usage/sampler-visualisation
 usage-sampling-options: usage/sampling-options
 usage-submodels: usage/submodels
+usage-threadsafe-evaluation: usage/threadsafe-evaluation
 usage-tracking-extra-quantities: usage/tracking-extra-quantities
 usage-troubleshooting: usage/troubleshooting

Lines changed: 237 additions & 0 deletions

---
title: Threadsafe Evaluation
engine: julia
julia:
  exeflags:
    - "--threads=4"
---

A common technique to speed up Julia code is to use multiple threads to run computations in parallel.
The Julia manual [has a section on multithreading](https://docs.julialang.org/en/v1/manual/multi-threading), which is a good introduction to the topic.

We assume that the reader is familiar with threading constructs in Julia, as well as the general concept of data races.
This page specifically discusses Turing's support for threadsafe model evaluation.

:::{.callout-note}
Please note that this is a rapidly moving topic, and things may change in future releases of Turing.
If you are ever unsure about what does and doesn't work, please don't hesitate to ask on Slack or Discourse (links can be found in the footer of this site)!
:::

## MCMC sampling

For complete clarity, this page has nothing to do with parallel sampling of MCMC chains using

```julia
sample(model, sampler, MCMCThreads(), N, nchains)
```

That parallelisation exists outside of the model evaluation, and thus is independent of the model contents.
This page only discusses threading _inside_ Turing models.

## Threading in Turing models

Given that Turing models mostly contain 'plain' Julia code, one might expect that all threading constructs, such as `Threads.@threads` or `Threads.@spawn`, can be used inside Turing models.

This is, to some extent, true: you can use threading constructs to speed up deterministic computations.
For example, here we use parallelism to speed up a transformation of `x`:

```julia
@model function f(y)
    x ~ dist
    x_transformed = similar(x)
    # Safe: each iteration writes to a distinct index of
    # `x_transformed`, so there is no data race.
    Threads.@threads for i in eachindex(x)
        x_transformed[i] = some_expensive_function(x[i])
    end
    y ~ some_likelihood(x_transformed)
end
```

In general, for code that does not involve tilde-statements (`x ~ dist`), threading works exactly as it does in regular Julia code.

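This safe pattern can be checked outside of a model, too. The following is a minimal, Turing-free sketch (`expensive` is a hypothetical stand-in for any deterministic transformation): because each iteration writes only to its own index, the threaded loop matches a serial `map` exactly.

```julia
# A stand-in for an expensive deterministic transformation.
expensive(x) = sin(x)^2 + sqrt(abs(x))

x = collect(0.0:0.1:10.0)

# Serial version, for reference.
serial = map(expensive, x)

# Threaded version: each iteration writes only to its own index,
# so there is no data race even with multiple threads.
threaded = similar(x)
Threads.@threads for i in eachindex(x)
    threaded[i] = expensive(x[i])
end

@assert threaded == serial
```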
**However, extra care must be taken when using tilde-statements (`x ~ dist`) inside threaded blocks.**
The reason for this is that tilde-statements modify the internal VarInfo object used for model evaluation.
Essentially, `x ~ dist` expands to something like

```julia
x, __varinfo__ = DynamicPPL.tilde_assume!!(..., __varinfo__)
```

and writing into `__varinfo__` is, _in general_, not threadsafe.
Thus, parallelising tilde-statements can lead to data races, [as described in the Julia manual](https://docs.julialang.org/en/v1/manual/multi-threading/#Using-@threads-without-data-races).

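To see what can go wrong, here is a minimal plain-Julia illustration of such a data race (nothing Turing-specific; the shared counter plays the role of the shared `__varinfo__`): an unprotected read-modify-write can silently lose updates, while an atomic update cannot.

```julia
# Unsafe: many iterations increment a shared, non-atomic counter.
# With more than one thread, this read-modify-write is a data race,
# and increments can be silently lost.
unsafe = Ref(0)
Threads.@threads for _ in 1:100_000
    unsafe[] += 1
end

# Safe: an atomic counter serialises the updates.
safe = Threads.Atomic{Int}(0)
Threads.@threads for _ in 1:100_000
    Threads.atomic_add!(safe, 1)
end

@assert safe[] == 100_000
# With multiple threads, `unsafe[]` is often (nondeterministically)
# well below 100_000.
```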
## Threaded tilde-statements

**As of version 0.41, Turing only supports the use of tilde-statements inside threaded blocks when these are observations (i.e., likelihood terms).**

This means that the following code is safe to use:

```{julia}
using Turing

@model function threaded_obs(N)
    x ~ Normal()
    y = Vector{Float64}(undef, N)
    Threads.@threads for i in 1:N
        y[i] ~ Normal(x)
    end
end

N = 100
y = randn(N)
model = threaded_obs(N) | (; y = y)
```

Evaluating this model is threadsafe, in that Turing guarantees the correct result from functions such as:

```{julia}
logjoint(model, (; x = 0.0))
```

We can compare this with the true value:

```{julia}
logpdf(Normal(), 0.0) + sum(logpdf.(Normal(0.0), y))
```

When sampling, you must disable model checking, but otherwise the results will be correct:

```{julia}
sample(model, NUTS(), 100; check_model=false, progress=false)
```

::: {.callout-warning}
## Upcoming changes

In the next release of Turing, if you use tilde-observations inside threaded blocks, you will have to declare this upfront using:

```julia
model = threaded_obs() | (; y = randn(N))
threadsafe_model = setthreadsafe(model, true)
```

You can then sample from `threadsafe_model` as before.

The reason for this change is that threadsafe evaluation comes with a performance cost, which can sometimes be substantial.
In the past, threadsafe evaluation was always enabled, i.e., this cost was *always* incurred whenever Julia was launched with more than one thread.
However, the number of threads alone is not an appropriate way to determine whether threadsafe evaluation is needed!
:::

**On the other hand, parallelising the sampling of latent values is not supported.**
Attempting to do this will either error or give wrong results.

```{julia}
#| error: true
@model function threaded_assume_bad(N)
    x = Vector{Float64}(undef, N)
    Threads.@threads for i in 1:N
        x[i] ~ Normal()
    end
    return x
end

model = threaded_assume_bad(100)

# This will throw an error (and probably a different error
# each time it's run...)
model()
```

**Note, in particular, that this means that you cannot use `predict` to sample new data in parallel.**
That is, even for `threaded_obs`, where `y` was originally an observed term, you _cannot_ do:

```{julia}
#| error: true
model = threaded_obs(N) | (; y = y)
chn = sample(model, NUTS(), 100; check_model=false, progress=false)

pmodel = threaded_obs(N) # don't condition on data
predict(pmodel, chn)
```

:::{.callout-note}
## Threaded `predict`

Support for the above call to `predict` may land in the near future, via [this pull request](https://github.com/TuringLang/DynamicPPL.jl/pull/1130).
:::

## Alternatives to threaded observation

An alternative to using threaded observations is to manually calculate the log-likelihood term (which can be parallelised using any of Julia's standard mechanisms), and then, _outside_ of the threaded block, [add it to the model using `@addlogprob!`]({{< meta usage-modifying-logprob >}}).

For example:

```{julia}
# Note that `y` has to be passed as an argument; you can't
# condition on it, because otherwise `y[i]` won't be defined.
@model function threaded_obs_addlogprob(N, y)
    x ~ Normal()

    # Instead of this:
    # Threads.@threads for i in 1:N
    #     y[i] ~ Normal(x)
    # end

    # Do this instead:
    lls = map(1:N) do i
        Threads.@spawn begin
            logpdf(Normal(x), y[i])
        end
    end
    @addlogprob! sum(fetch.(lls))
end
```

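The core pattern here, spawning independent per-observation computations and reducing _outside_ the threaded region, can be sketched without Turing at all. In the sketch below, `lognormal` is a hand-written standard-normal log-density, so no packages are needed:

```julia
# Log-density of a standard normal, written out by hand so the
# sketch needs no packages.
lognormal(z) = -0.5 * z^2 - 0.5 * log(2pi)

y = randn(100)

# Serial reference value.
serial_ll = sum(lognormal, y)

# Spawn one task per observation, then reduce OUTSIDE the
# threaded region. The reduction itself runs on a single task,
# so there is no shared mutable state to race on.
tasks = map(eachindex(y)) do i
    Threads.@spawn lognormal(y[i])
end
threaded_ll = sum(fetch.(tasks))

@assert threaded_ll ≈ serial_ll
```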
In a similar way, you can also use your favourite parallelism package, such as `FLoops.jl` or `OhMyThreads.jl`.
See [this Discourse post](https://discourse.julialang.org/t/parallelism-within-turing-jl-model/54064/9) for some examples.

We make no promises about the use of tilde-statements _with_ these packages (indeed, it will most likely error), but as long as you only use them to parallelise regular Julia code (i.e., not tilde-statements), they will work as intended.

One benefit of rewriting the model this way is that sampling from it with `MCMCThreads()` will always be reproducible.

```{julia}
using Random
N = 100
y = randn(N)
model = threaded_obs_addlogprob(N, y)
nuts_kwargs = (check_model=false, progress=false, verbose=false)

chain1 = sample(Xoshiro(468), model, NUTS(), MCMCThreads(), 100, 4; nuts_kwargs...)
chain2 = sample(Xoshiro(468), model, NUTS(), MCMCThreads(), 100, 4; nuts_kwargs...)
mean(chain1[:x]), mean(chain2[:x]) # should be identical
```

In contrast, the original `threaded_obs` (which used tilde-statements inside `Threads.@threads`) is not reproducible when sampled with `MCMCThreads()`.

```{julia}
model = threaded_obs(N) | (; y = y)
chain1 = sample(Xoshiro(468), model, NUTS(), MCMCThreads(), 100, 4; nuts_kwargs...)
chain2 = sample(Xoshiro(468), model, NUTS(), MCMCThreads(), 100, 4; nuts_kwargs...)
mean(chain1[:x]), mean(chain2[:x]) # oops!
```

## AD support

Finally, if you are [using Turing with automatic differentiation]({{< meta usage-automatic-differentiation >}}), you also need to keep track of which AD backends support threadsafe evaluation.

ForwardDiff is the only AD backend that we find to work reliably with threaded model evaluation.

In particular:

- ReverseDiff sometimes gives the right results, but quite often gives incorrect gradients.
- Mooncake [currently does not support multithreading at all](https://github.com/chalk-lab/Mooncake.jl/issues/570).
- Enzyme [mostly gives the right result, but sometimes gives incorrect gradients](https://github.com/TuringLang/DynamicPPL.jl/issues/1131).

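One intuition for why forward-mode AD (such as ForwardDiff) composes well with task-based threading is that each task simply computes with ordinary values that happen to carry their own derivative along. The following toy sketch makes this concrete with a hand-rolled dual number; this is an illustration only, not ForwardDiff's actual implementation:

```julia
# A minimal forward-mode dual number: a value plus its derivative.
struct Dual
    val::Float64
    der::Float64
end
Base.:-(a::Float64, b::Dual) = Dual(a - b.val, -b.der)
Base.:*(a::Float64, b::Dual) = Dual(a * b.val, a * b.der)
Base.:*(a::Dual, b::Dual) = Dual(a.val * b.val, a.der * b.val + a.val * b.der)
Base.:+(a::Dual, b::Dual) = Dual(a.val + b.val, a.der + b.der)

# Threaded log-density: one task per observation, serial reduction.
# Each task operates on its own dual values; no state is shared.
function threaded_loglik(x::Dual, y)
    tasks = map(eachindex(y)) do i
        Threads.@spawn -0.5 * (y[i] - x) * (y[i] - x)
    end
    return sum(fetch.(tasks))
end

y = randn(10)
result = threaded_loglik(Dual(0.0, 1.0), y)  # seed the derivative of x

# Analytically, d/dx of sum(-0.5 * (y_i - x)^2) is sum(y_i - x),
# which at x = 0 equals sum(y).
@assert result.der ≈ sum(y)
@assert result.val ≈ sum(-0.5 .* y .^ 2)
```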
## Under the hood

:::{.callout-note}
This part will likely only be of interest to DynamicPPL developers and the very curious user.
:::

TODO: Something about metadata, accumulators, and ThreadSafeVarInfo.

TODO: Say how OnlyAccsVarInfo and FastLDF change this.

Essentially, `predict(model, chn)` SHOULD work after #1130, because that uses OnlyAccsVarInfo, which doesn't have Metadata. It uses VAIMAcc to accumulate the values, but that is threadsafe as long as ThreadSafeVarInfo is used.

FastLDF, _once constructed_, also works with threaded assume. The only problem is that, to get the ranges and linked status, it first has to generate a VarInfo, which cannot be done for a model with threaded assume statements. But if there were a way to either manually provide the ranges, or to use an accumulator to obtain the ranges and linked status, then it would directly enable threaded assume with NUTS, or any other sampler that only uses FastLDF.
