---
title: Threadsafe Evaluation
engine: julia
julia:
  exeflags:
    - "--threads=4"
---

A common technique to speed up Julia code is to use multiple threads to run computations in parallel.
The Julia manual [has a section on multithreading](https://docs.julialang.org/en/v1/manual/multi-threading), which is a good introduction to the topic.

We assume that the reader is familiar with some threading constructs in Julia, and with the general concept of data races.
This page specifically discusses Turing's support for threadsafe model evaluation.
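
This page is rendered with four threads, as set by the `exeflags` in the front matter above.
You can check how many threads are available in your own session with:

```{julia}
Threads.nthreads()
```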

:::{.callout-note}
Please note that this is a rapidly-moving topic, and things may change in future releases of Turing.
If you are ever unsure about what does and doesn't work, please don't hesitate to ask on Slack or Discourse (links can be found in the footer of this site)!
:::

## MCMC sampling

To be completely clear, this page has nothing to do with sampling multiple MCMC chains in parallel using

```julia
sample(model, sampler, MCMCThreads(), N, nchains)
```

That parallelisation happens outside of model evaluation, and is therefore independent of the model's contents.
This page only discusses threading _inside_ Turing models.

## Threading in Turing models

Given that Turing models mostly contain 'plain' Julia code, one might expect that all threading constructs, such as `Threads.@threads` or `Threads.@spawn`, can be used inside Turing models.

This is true to some extent: you can use threading constructs to speed up deterministic computations.
For example, here we use parallelism to speed up a transformation of `x`:

```julia
# `dist`, `some_expensive_function`, and `some_likelihood` are placeholders
# for your own distribution, function, and likelihood.
@model function f(y)
    x ~ dist
    x_transformed = similar(x)
    Threads.@threads for i in eachindex(x)
        x_transformed[i] = some_expensive_function(x[i])
    end
    y ~ some_likelihood(x_transformed)
end
```

In general, for code that does not involve tilde-statements (`x ~ dist`), threading works exactly as it does in regular Julia code.

**However, extra care must be taken when using tilde-statements (`x ~ dist`) inside threaded blocks.**
The reason is that tilde-statements modify the internal VarInfo object used for model evaluation.
Essentially, `x ~ dist` expands to something like

```julia
x, __varinfo__ = DynamicPPL.tilde_assume!!(..., __varinfo__)
```

and writing into `__varinfo__` is, _in general_, not threadsafe.
Thus, parallelising tilde-statements can lead to data races [as described in the Julia manual](https://docs.julialang.org/en/v1/manual/multi-threading/#Using-@threads-without-data-races).
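
To illustrate, here is a plain Julia sketch (not specific to Turing) of the kind of unsynchronised read-modify-write that constitutes a data race:

```julia
# Each task reads the current value, adds to it, and writes it back.
# Without synchronisation these steps interleave across threads, so
# updates are lost and the final value is unpredictable.
total = Ref(0.0)
Threads.@threads for i in 1:1_000
    total[] += 1.0
end
total[]  # very likely less than 1000.0
```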

## Threaded tilde-statements

**As of version 0.41, Turing only supports the use of tilde-statements inside threaded blocks when these are observations (i.e., likelihood terms).**

This means that the following code is safe to use:

```{julia}
using Turing

@model function threaded_obs(N)
    x ~ Normal()
    y = Vector{Float64}(undef, N)
    Threads.@threads for i in 1:N
        y[i] ~ Normal(x)
    end
end

N = 100
y = randn(N)
model = threaded_obs(N) | (; y = y)
```

Evaluating this model is threadsafe, in the sense that Turing guarantees correct results from functions such as:

```{julia}
logjoint(model, (; x = 0.0))
```

We can compare this with the true value, computed by hand:

```{julia}
logpdf(Normal(), 0.0) + sum(logpdf.(Normal(0.0), y))
```

When sampling, you must disable model checking, but otherwise the results will be correct:

```{julia}
sample(model, NUTS(), 100; check_model=false, progress=false)
```

::: {.callout-warning}
## Upcoming changes

In the next release of Turing, if you use tilde-observations inside threaded blocks, you will have to declare this upfront using:

```julia
model = threaded_obs(N) | (; y = randn(N))
threadsafe_model = setthreadsafe(model, true)
```

Then you can sample from `threadsafe_model` as before.

The reason for this change is that threadsafe evaluation comes with a performance cost, which can sometimes be substantial.
In the past, threadsafe evaluation was always enabled, i.e., this cost was *always* incurred whenever Julia was launched with more than one thread.
However, the number of threads Julia was launched with is not an appropriate way to determine whether threadsafe evaluation is actually needed!
:::

**On the other hand, parallelising the sampling of latent values is not supported.**
Attempting to do this will either error or give wrong results.

```{julia}
#| error: true
@model function threaded_assume_bad(N)
    x = Vector{Float64}(undef, N)
    Threads.@threads for i in 1:N
        x[i] ~ Normal()
    end
    return x
end

model = threaded_assume_bad(100)

# This will throw an error (and probably a different error
# each time it's run...)
model()
```

**Note, in particular, that this means you cannot use `predict` to sample new data in parallel.**
That is, even for `threaded_obs`, where `y` was originally an observed term, you _cannot_ do:

```{julia}
#| error: true
model = threaded_obs(N) | (; y = y)
chn = sample(model, NUTS(), 100; check_model=false, progress=false)

pmodel = threaded_obs(N) # don't condition on data
predict(pmodel, chn)
```

:::{.callout-note}
## Threaded `predict`

Support for the above call to `predict` may land in the near future, via [this pull request](https://github.com/TuringLang/DynamicPPL.jl/pull/1130).
:::

## Alternatives to threaded observation

An alternative to using threaded observations is to manually calculate the log-likelihood term (which can be parallelised using any of Julia's standard mechanisms), and then _outside_ of the threaded block, [add it to the model using `@addlogprob!`]({{< meta usage-modifying-logprob >}}).

For example:

```{julia}
# Note that `y` has to be passed as an argument; you can't
# condition on it because otherwise `y[i]` won't be defined.
@model function threaded_obs_addlogprob(N, y)
    x ~ Normal()

    # Instead of this:
    # Threads.@threads for i in 1:N
    #     y[i] ~ Normal(x)
    # end

    # Do this instead:
    lls = map(1:N) do i
        Threads.@spawn begin
            logpdf(Normal(x), y[i])
        end
    end
    @addlogprob! sum(fetch.(lls))
end
```
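
As a quick check, this gives the same log-joint as the tilde-based `threaded_obs` model above (reusing the `N` and `y` defined earlier):

```{julia}
logjoint(threaded_obs_addlogprob(N, y), (; x = 0.0))
```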

In a similar way, you can also use your favourite parallelism package, such as `FLoops.jl` or `OhMyThreads.jl`.
See [this Discourse post](https://discourse.julialang.org/t/parallelism-within-turing-jl-model/54064/9) for some examples.

We make no promises about using tilde-statements _with_ these packages (indeed, it will most likely error), but as long as you only use them to parallelise regular Julia code, they will work as intended.
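
For example, here is a minimal (untested) sketch of the same model using `OhMyThreads.tmapreduce` to parallelise the log-likelihood sum; the model name is purely illustrative, and only regular Julia code runs inside the threaded region:

```julia
using Turing
using OhMyThreads: tmapreduce

@model function threaded_obs_ohmythreads(N, y)
    x ~ Normal()
    # The tilde-statement above stays outside the threaded region; only the
    # per-observation log-likelihoods are computed and summed in parallel.
    @addlogprob! tmapreduce(i -> logpdf(Normal(x), y[i]), +, 1:N)
end
```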

One benefit of rewriting the model this way is that sampling from it with `MCMCThreads()` will always be reproducible; we demonstrate this with `threaded_obs_addlogprob`.

```{julia}
using Random
N = 100
y = randn(N)
model = threaded_obs_addlogprob(N, y)
nuts_kwargs = (check_model=false, progress=false, verbose=false)

chain1 = sample(Xoshiro(468), model, NUTS(), MCMCThreads(), 100, 4; nuts_kwargs...)
chain2 = sample(Xoshiro(468), model, NUTS(), MCMCThreads(), 100, 4; nuts_kwargs...)
mean(chain1[:x]), mean(chain2[:x]) # should be identical
```

In contrast, the original `threaded_obs` (which used a tilde-statement inside `Threads.@threads`) is not reproducible when sampled with `MCMCThreads()`.

```{julia}
model = threaded_obs(N) | (; y = y)
chain1 = sample(Xoshiro(468), model, NUTS(), MCMCThreads(), 100, 4; nuts_kwargs...)
chain2 = sample(Xoshiro(468), model, NUTS(), MCMCThreads(), 100, 4; nuts_kwargs...)
mean(chain1[:x]), mean(chain2[:x]) # oops!
```

## AD support

Finally, if you are [using Turing with automatic differentiation]({{< meta usage-automatic-differentiation >}}), you also need to keep track of which AD backends support threadsafe evaluation.

ForwardDiff is the only AD backend that we find to work reliably with threaded model evaluation.

In particular:

 - ReverseDiff sometimes gives correct results, but quite often produces incorrect gradients.
 - Mooncake [currently does not support multithreading at all](https://github.com/chalk-lab/Mooncake.jl/issues/570).
 - Enzyme [mostly gives the right result, but sometimes gives incorrect gradients](https://github.com/TuringLang/DynamicPPL.jl/issues/1131).
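
For example, to make sure gradients for the threaded observation model are computed with ForwardDiff, you can pass the AD type to the sampler explicitly (a minimal sketch; `AutoForwardDiff` comes from ADTypes.jl and is re-exported by Turing):

```julia
# Explicitly select ForwardDiff for gradient computation, the backend that
# we currently find reliable with threaded model evaluation.
sample(model, NUTS(; adtype=AutoForwardDiff()), 100; check_model=false, progress=false)
```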

## Under the hood

:::{.callout-note}
This part will likely only be of interest to DynamicPPL developers and the very curious user.
:::

TODO: Something about metadata, accumulators, and TSVI.

TODO: Say how OnlyAccsVarInfo and FastLDF change this.

Essentially, `predict(model, chn)` should work after #1130, because that code path uses OnlyAccsVarInfo, which does not carry a Metadata. It uses VAIMAcc to accumulate the values, but that is threadsafe as long as TSVI is used.

FastLDF, _once constructed_, also works with threaded assume. The only problem is that, to get the ranges and linked status, it first has to generate a VarInfo, which cannot be done with threaded assume. But if there were a way to either manually provide the ranges, or use an accumulator to obtain the ranges and linked status, that would directly enable threaded assume with NUTS (or any other sampler that only uses FastLDF).