Commit c09813a

Threadsafe evaluation (#667)
1 parent c04bd63 commit c09813a

File tree

2 files changed: +272 -0 lines changed

_quarto.yml

Lines changed: 2 additions & 0 deletions
@@ -82,6 +82,7 @@ website:
 - usage/tracking-extra-quantities/index.qmd
 - usage/predictive-distributions/index.qmd
 - usage/mode-estimation/index.qmd
+- usage/threadsafe-evaluation/index.qmd
 - usage/performance-tips/index.qmd
 - usage/sampler-visualisation/index.qmd
 - usage/dynamichmc/index.qmd
@@ -215,6 +216,7 @@ usage-probability-interface: usage/probability-interface
 usage-sampler-visualisation: usage/sampler-visualisation
 usage-sampling-options: usage/sampling-options
 usage-submodels: usage/submodels
+usage-threadsafe-evaluation: usage/threadsafe-evaluation
 usage-tracking-extra-quantities: usage/tracking-extra-quantities
 usage-troubleshooting: usage/troubleshooting

Lines changed: 270 additions & 0 deletions
@@ -0,0 +1,270 @@
---
title: Threadsafe Evaluation
engine: julia
julia:
  exeflags:
    - "--threads=4"
---

A common technique to speed up Julia code is to use multiple threads to run computations in parallel.
The Julia manual [has a section on multithreading](https://docs.julialang.org/en/v1/manual/multi-threading), which is a good introduction to the topic.

We assume that the reader is familiar with some threading constructs in Julia, and with the general concept of data races.
This page specifically discusses Turing's support for threadsafe model evaluation.

:::{.callout-note}
Please note that this is a rapidly-moving topic, and things may change in future releases of Turing.
If you are ever unsure about what works and what doesn't, please don't hesitate to ask on [Slack](https://julialang.slack.com/archives/CCYDC34A0) or [Discourse](https://discourse.julialang.org/c/domain/probprog/48).
:::
## MCMC sampling

For complete clarity, this page has nothing to do with parallel sampling of MCMC chains using

```julia
sample(model, sampler, MCMCThreads(), N, nchains)
```

That parallelisation exists outside of the model evaluation, and is thus independent of the model contents.
This page only discusses threading _inside_ Turing models.
## Threading in Turing models

Given that Turing models mostly contain 'plain' Julia code, one might expect that all threading constructs, such as `Threads.@threads` or `Threads.@spawn`, can be used inside Turing models.

This is, to some extent, true: you can use threading constructs to speed up deterministic computations.
For example, here we use parallelism to speed up a transformation of `x`:

```julia
@model function f(y)
    x ~ dist
    x_transformed = similar(x)
    Threads.@threads for i in eachindex(x)
        x_transformed[i] = some_expensive_function(x[i])
    end
    y ~ some_likelihood(x_transformed)
end
```
In general, for code that does not involve tilde-statements (`x ~ dist`), threading works exactly as it does in regular Julia code.

**However, extra care must be taken when using tilde-statements (`x ~ dist`) inside threaded blocks.**
The reason is that tilde-statements modify the internal VarInfo object used for model evaluation.
Essentially, `x ~ dist` expands to something like

```julia
x, __varinfo__ = DynamicPPL.tilde_assume!!(..., __varinfo__)
```

and writing into `__varinfo__` is, _in general_, not threadsafe.
Thus, parallelising tilde-statements can lead to data races [as described in the Julia manual](https://docs.julialang.org/en/v1/manual/multi-threading/#Using-@threads-without-data-races).
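To see the kind of failure this can cause, here is a self-contained sketch in plain Julia (no Turing involved) of the same pattern: many iterations performing an unsynchronised read-modify-write on shared state, versus computing per-task partial results and combining them afterwards. The function names here are ours, purely for illustration.

```julia
using Base.Threads

# Unsynchronised read-modify-write on shared state, analogous to many
# tilde-statements all writing into the same `__varinfo__`:
function racy_sum(n)
    total = Ref(0.0)
    @threads for _ in 1:n
        total[] += 1.0  # data race: read-modify-write is not atomic
    end
    return total[]  # often less than n when run with multiple threads
end

# Safe alternative: each task computes its own partial result, and the
# partials are combined afterwards. The same idea underlies DynamicPPL's
# threadsafe evaluation.
function safe_sum(n; ntasks = nthreads())
    chunks = Iterators.partition(1:n, cld(n, ntasks))
    tasks = [Threads.@spawn sum(1.0 for _ in chunk) for chunk in chunks]
    return sum(fetch.(tasks))
end
```

Only the second version is guaranteed to return `n` regardless of how many threads Julia was started with.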
## Threaded tilde-statements

**As of version 0.41, Turing only supports the use of tilde-statements inside threaded blocks when these are observations (i.e., likelihood terms).**

This means that the following code is safe to use:

```{julia}
using Turing

@model function threaded_obs(N)
    x ~ Normal()
    y = Vector{Float64}(undef, N)
    Threads.@threads for i in 1:N
        y[i] ~ Normal(x)
    end
end

N = 100
y = randn(N)
model = threaded_obs(N) | (; y = y)
```
Evaluating this model is threadsafe, in that Turing guarantees to provide the correct result in functions such as:

```{julia}
logjoint(model, (; x = 0.0))
```

(We can compare this with the true value:)

```{julia}
logpdf(Normal(), 0.0) + sum(logpdf.(Normal(0.0), y))
```

When sampling, you must disable model checking, but otherwise the results will be correct:

```{julia}
sample(model, NUTS(), 100; check_model=false, progress=false)
```
::: {.callout-warning}
## Upcoming changes

Starting from DynamicPPL 0.39, if you use tilde-statements or `@addlogprob!` inside threaded blocks, you will have to declare this upfront using:

```julia
model = threaded_obs() | (; y = randn(N))
threadsafe_model = setthreadsafe(model, true)
```

You can then sample from `threadsafe_model` as before.

The reason for this change is that threadsafe evaluation comes with a performance cost, which can sometimes be substantial.
In the past, threadsafe evaluation was always enabled, i.e., this cost was *always* incurred whenever Julia was launched with more than one thread.
However, the number of threads alone is not an appropriate way to determine whether threadsafe evaluation is needed!
:::
**On the other hand, parallelising the sampling of latent values is not supported.**
Attempting to do this will either error or give wrong results.

```{julia}
#| error: true
@model function threaded_assume_bad(N)
    x = Vector{Float64}(undef, N)
    Threads.@threads for i in 1:N
        x[i] ~ Normal()
    end
    return x
end

model = threaded_assume_bad(100)

# This will throw an error (and probably a different error
# each time it's run...)
model()
```
**Note, in particular, that this means that you cannot currently use `predict` to sample new data in parallel.**

:::{.callout-note}
## Threaded `predict`

Support for threaded `predict` will be added in DynamicPPL 0.39 (see [this pull request](https://github.com/TuringLang/DynamicPPL.jl/pull/1130)).
:::
That is, even for `threaded_obs`, where `y` was originally an observed term, you _cannot_ do:

```{julia}
#| error: true
model = threaded_obs(N) | (; y = y)
chn = sample(model, NUTS(), 100; check_model=false, progress=false)

pmodel = threaded_obs(N) # don't condition on data
predict(pmodel, chn)
```
## Alternatives to threaded observation

An alternative to using threaded observations is to manually calculate the log-likelihood term (which can be parallelised using any of Julia's standard mechanisms), and then, _outside_ of the threaded block, [add it to the model using `@addlogprob!`]({{< meta usage-modifying-logprob >}}).

For example:

```{julia}
# Note that `y` has to be passed as an argument; you can't
# condition on it because otherwise `y[i]` won't be defined.
@model function threaded_obs_addlogprob(N, y)
    x ~ Normal()

    # Instead of this:
    # Threads.@threads for i in 1:N
    #     y[i] ~ Normal(x)
    # end

    # Do this instead:
    lls = map(1:N) do i
        Threads.@spawn begin
            logpdf(Normal(x), y[i])
        end
    end
    @addlogprob! sum(fetch.(lls))
end
```
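One practical note on the example above: it spawns one task per observation, which can add noticeable scheduling overhead when `N` is large. A common refinement is to spawn one task per chunk of observations instead. Here is a minimal, self-contained sketch of that pattern; we hand-write the normal logpdf so that the snippet runs without Turing, and the function names are ours, not part of any library.

```julia
# One task per chunk instead of one task per observation.
# `normal_logpdf` is a hand-written stand-in for `logpdf(Normal(mu), y)`.
normal_logpdf(y, mu) = -0.5 * (y - mu)^2 - 0.5 * log(2pi)

function chunked_loglik(ys, mu; ntasks = Threads.nthreads())
    chunks = Iterators.partition(ys, cld(length(ys), ntasks))
    # Each task sums the log-likelihood of its own chunk:
    tasks = [Threads.@spawn sum(normal_logpdf(y, mu) for y in chunk)
             for chunk in chunks]
    return sum(fetch.(tasks))
end
```

Inside a model, you would then write something like `@addlogprob! chunked_loglik(y, x)` in place of the `map`/`@spawn` loop.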
In a similar way, you can also use your favourite parallelism package, such as `FLoops.jl` or `OhMyThreads.jl`.
See [this Discourse post](https://discourse.julialang.org/t/parallelism-within-turing-jl-model/54064/9) for some examples.

We make no promises about the use of tilde-statements _with_ these packages (indeed, it will most likely error), but as long as you use them only to parallelise regular Julia code (i.e., not tilde-statements), they will work as intended.

The main downsides of this approach are:

1. You can't use conditioning syntax to provide data; it has to be passed as an argument or otherwise included inside the model.
2. You can't use `predict` to sample new data.
On the other hand, one benefit of rewriting the model this way is that sampling from it with `MCMCThreads()` will always be reproducible.

```{julia}
using Random
N = 100
y = randn(N)
model = threaded_obs_addlogprob(N, y)
nuts_kwargs = (check_model=false, progress=false, verbose=false)

chain1 = sample(Xoshiro(468), model, NUTS(), MCMCThreads(), 1000, 4; nuts_kwargs...)
chain2 = sample(Xoshiro(468), model, NUTS(), MCMCThreads(), 1000, 4; nuts_kwargs...)
mean(chain1[:x]), mean(chain2[:x]) # should be identical
```
In contrast, the original `threaded_obs` (which used tilde-statements inside `Threads.@threads`) is not reproducible when using `MCMCThreads()`.
(In principle, we would like to fix this bug, but we haven't yet investigated where it stems from.)

```{julia}
model = threaded_obs(N) | (; y = y)
nuts_kwargs = (check_model=false, progress=false, verbose=false)
chain1 = sample(Xoshiro(468), model, NUTS(), MCMCThreads(), 1000, 4; nuts_kwargs...)
chain2 = sample(Xoshiro(468), model, NUTS(), MCMCThreads(), 1000, 4; nuts_kwargs...)
mean(chain1[:x]), mean(chain2[:x]) # oops!
```
## AD support

Finally, if you are [using Turing with automatic differentiation]({{< meta usage-automatic-differentiation >}}), you also need to keep track of which AD backends support threadsafe evaluation.

ForwardDiff is the only AD backend that we find to work reliably with threaded model evaluation.

In particular:

- ReverseDiff sometimes gives the right results, but quite often gives incorrect gradients.
- Mooncake [currently does not support multithreading at all](https://github.com/chalk-lab/Mooncake.jl/issues/570).
- Enzyme [mostly gives the right result, but sometimes gives incorrect gradients](https://github.com/TuringLang/DynamicPPL.jl/issues/1131).
## Under the hood

:::{.callout-note}
This part will likely only be of interest to DynamicPPL developers and the very curious user.
:::
### Why is VarInfo not threadsafe?

As alluded to above, the issue with threaded tilde-statements stems from the fact that these tilde-statements modify the VarInfo object used for model evaluation, leading to potential data races.

Traditionally, VarInfo objects contain both *metadata* and *accumulators*.
Metadata is where information about the random variables' values is stored.
It is a Dict-like structure, and pushing to it from multiple threads is therefore not threadsafe (Julia's `Dict` has similar limitations).

Accumulators, on the other hand, are used to store outputs of the model, such as log-probabilities.
The way DynamicPPL's threadsafe evaluation works is to create one set of accumulators per thread, and then combine the results at the end of model evaluation.

In this way, any function call that _solely_ involves accumulators can be made threadsafe.
For example, this is why observations are supported: there is no need to modify metadata, and only the log-likelihood accumulator needs to be updated.

However, `assume` tilde-statements always modify the metadata, and thus cannot currently be made threadsafe.
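The split-and-combine strategy described above can be sketched in a few lines of plain Julia. This is only an illustration of the idea; the types and function names below are ours, not DynamicPPL's actual API.

```julia
# A toy "accumulator" holding a running log-probability, plus a way to
# merge two of them: the two operations that per-task accumulation needs.
struct ToyLogProbAcc
    logp::Float64
end
combine(a::ToyLogProbAcc, b::ToyLogProbAcc) = ToyLogProbAcc(a.logp + b.logp)

# Evaluate `loglike(i)` for i in 1:n, giving each task its own
# accumulator, then combine the per-task results at the end:
function accumulate_threaded(loglike, n; ntasks = Threads.nthreads())
    chunks = Iterators.partition(1:n, cld(n, ntasks))
    tasks = [Threads.@spawn ToyLogProbAcc(sum(loglike, chunk)) for chunk in chunks]
    return reduce(combine, fetch.(tasks)).logp
end
```

Because each task only ever touches its own `ToyLogProbAcc`, no synchronisation is needed until the final `reduce`, which runs on a single task.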
### OnlyAccsVarInfo

As it happens, much of what DynamicPPL needs can be constructed so that it relies *only* on accumulators.

For example, as long as there is no need to *sample* new values of random variables, it is actually fine to omit the metadata object completely.
This is the case for `LogDensityFunction`: since the variables' values are provided in the input vector, there is no need to store them in metadata.
We need only calculate the associated log-prior probability, which is stored in an accumulator.
Thus, starting from DynamicPPL v0.39, `LogDensityFunction` itself will in fact be completely threadsafe.

Technically speaking, this is achieved using `OnlyAccsVarInfo`, which is a subtype of `VarInfo` that contains only accumulators, and no metadata at all.
It implements enough of the `VarInfo` interface to be used in model evaluation, but will error if any function attempts to modify or read its metadata.

There is currently an ongoing push to use `OnlyAccsVarInfo` in as many settings as we possibly can.
For example, this is why `predict` will be threadsafe in DynamicPPL v0.39: instead of modifying metadata to store the predicted values, we store them inside a `ValuesAsInModelAccumulator` and combine them at the end of evaluation.

However, propagating these changes up to Turing will require a substantial amount of additional work, since there are many places in Turing which currently rely on a full VarInfo (with metadata).
See, e.g., [this PR](https://github.com/TuringLang/DynamicPPL.jl/pull/1154) for more information.
