# Multithreading in Julia

_Part of this notebook is inspired by the material of the [Julia for HPC Course @ UCL ARC](https://github.com/carstenbauer/JuliaUCL24) by Carsten Bauer._

## Setup

In [None]:
# Running this cell is important to make sure we install all the necessary packages.
using Pkg
Pkg.activate(@__DIR__)
Pkg.instantiate()

## Thread pinning

In [None]:
using ThreadPinning
pinthreads(:cores)
threadinfo(; slurm=ThreadPinning.SLURM.isslurmjob())

## Spawning parallel tasks

In [None]:
using Base.Threads

@show nthreads();

In [None]:
@time t = @spawn begin # `@spawn` returns right away
    sleep(2)
    3+3
end

@time fetch(t) # `fetch` waits for the task to finish

## Exercise: task-based parallelised `map`

In [None]:
using LinearAlgebra, BenchmarkTools

BLAS.set_num_threads(1) # Fix number of BLAS threads

# Exercise: define a task-based parallel `map` function
function tmap(fn, itr)
    # Define a vector called `tasks` which, for each element `x` of `itr`, holds the task running `fn(x)`.
    # Hint #1: you can use `map(f, itr)` to apply the function `f` to each element of `itr`.
    # Hint #2: anonymous functions `x -> f(x)` are a convenient syntax for defining one-line functions inside other functions.
    tasks = map(x -> ..., itr)
    # Call `fetch` on all elements of `tasks` to collect the results of all spawned tasks, and return them.
    # Hint: you can use broadcasting to apply a function element-wise to a collection.
    return # ...
end
end

M = [rand(100,100) for i in 1:(8 * nthreads())];

@btime  map(svdvals, $M) samples=10 evals=3;
@btime tmap(svdvals, $M) samples=10 evals=3;

***Bonus***: do you see any difference if you increase the number of BLAS threads?
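To try the bonus, it helps to query the current number of BLAS threads before changing it and to restore it afterwards. A minimal sketch using `BLAS.get_num_threads` and `BLAS.set_num_threads` from the `LinearAlgebra` standard library; the value `4` is an arbitrary choice for illustration, pick one that suits your machine:

```julia
using LinearAlgebra

old_nthreads = BLAS.get_num_threads()  # remember the current setting
BLAS.set_num_threads(4)                # arbitrary value, for illustration only
# ... rerun the `map`/`tmap` benchmarks here ...
BLAS.set_num_threads(old_nthreads)     # restore the previous setting
```

Keep in mind that BLAS threads and Julia threads share the same cores, so increasing both may oversubscribe the machine.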

## Exercise: multi-threaded `for` loop (reduction)

<div class="alert alert-info">
  <strong><tt>ChunkSplitters.jl</tt></strong> <br />The simple package <a href="https://juliafolds2.github.io/ChunkSplitters.jl/stable/" class="alert-link"><tt>ChunkSplitters.jl</tt></a> provides the function <a href="https://juliafolds2.github.io/ChunkSplitters.jl/stable/references/#ChunkSplitters.chunks" class="alert-link"><tt>chunks</tt></a>, which returns an iterator over chunks of the input data that you can then use in a threaded <tt>for</tt> loop, or to spawn tasks in parallel.
</div>
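To see what chunking means without depending on the package, here is a rough Base-only sketch based on `Iterators.partition`. This is a swapped-in technique, not the `ChunkSplitters.jl` API: `chunks` takes the desired number of chunks `n`, whereas `partition` takes the length of each chunk, so we derive that length with `cld`.

```julia
data_example = collect(1:10)
nchunks_example = 3

# Chunk length chosen so that we get at most `nchunks_example` chunks
len = cld(length(data_example), nchunks_example)

# Each chunk holds consecutive elements of `data_example`
parts = collect(Iterators.partition(data_example, len))

for (c, elements) in enumerate(parts)
    println("chunk $c: ", collect(elements))
end
```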

<div class="alert alert-info">
  <strong>Computing <tt>sum</tt> after applying a function elementwise</strong> <br />The <a href="https://docs.julialang.org/en/v1/base/collections/#Base.sum" class="alert-link"><tt>sum</tt></a> function optionally takes as its first argument a function to apply elementwise to all elements of the input iterator before computing the sum.
</div>
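A tiny illustration of `sum(f, itr)` on a toy vector (the names `w` and `sumsq` are just for this example):

```julia
w = [1.0, -2.0, 3.0]

# `sum(f, itr)` applies `f` to every element and adds up the results,
# without allocating the temporary array that `sum(abs2.(w))` would create
sumsq = sum(abs2, w)   # 1 + 4 + 9
```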

In [None]:
using ChunkSplitters, Base.Threads, BenchmarkTools

# Define a function which computes the sum of the elements of an iterator `data` in parallel,
# after applying a user-supplied function `fn` element-wise
function sum_threads(fn, data; nchunks=nthreads())
    psums = zeros(eltype(data), nchunks)
    # Hint: `@threads` in front of the `for` loop runs its iterations on multiple threads.
    @threads for (c, elements) in enumerate(chunks(data; n=nchunks))
        # Hint: each element of `psums` should be the sum of `fn` applied
        # element-wise to all the items of the current chunk.
        psums[c] = # .....
    end
    return # .....
end

v = randn(20_000_000);

@btime sum(sin, $v);

@btime sum_threads(sin, $v);

***Bonus***: `Threads.@threads` lets you choose the scheduler by placing a symbol between `@threads` and `for`; see [its docstring](https://docs.julialang.org/en/v1/base/multi-threading/#Base.Threads.@threads) for more details. Do you see differences if you change the scheduler type? You can choose between `:dynamic` (currently the default when omitted), `:greedy` (only available in Julia v1.11+), and `:static`.
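For reference, this is what the scheduler syntax looks like on a toy loop unrelated to the exercise (the array `squares` is just for illustration):

```julia
using Base.Threads

# Fill `squares` in parallel; the symbol between `@threads` and `for`
# selects the scheduler (`:dynamic` is used when it is omitted)
squares = zeros(Int, 16)
@threads :static for i in eachindex(squares)
    squares[i] = i^2
end
```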

<div class="alert alert-info">
  <strong>Syntax tip</strong> <br />The <a href="https://docs.julialang.org/en/v1/base/base/#do" class="alert-link"><tt>do</tt>-block syntax</a> allows you to write (possibly long) anonymous functions and automatically pass them as first argument to functions which take another function as first argument. This syntax is often used in conjunction with <a href="https://docs.julialang.org/en/v1/base/collections/#Base.map" class="alert-link"><tt>map</tt></a>, which takes a function as first argument.
</div>
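A minimal comparison of the two syntaxes: both calls below pass the same function as the first argument of `map`.

```julia
# Anonymous function passed explicitly as the first argument
squares_lambda = map(x -> x^2, 1:4)

# Same thing with a `do` block: its body becomes the first argument of `map`
squares_do = map(1:4) do x
    # The body can span several lines
    y = x^2
    y
end
```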

In [None]:
# Define a function which does the parallel sum using `map` + `@spawn`
# as we've done above, instead of `@threads for`.
function sum_map_spawn(fn, data; nchunks=nthreads())
    ts = map(chunks(data, n=nchunks)) do elements
        # ....
    end
    return # ....
end

@btime sum_map_spawn(sin, $v);

### Bonus: using OhMyThreads.jl

<div class="alert alert-info">
  <strong><tt>OhMyThreads.jl</tt></strong> <br />The package <a href="https://juliafolds2.github.io/OhMyThreads.jl/" class="alert-link"><tt>OhMyThreads.jl</tt></a> provides user-friendly constructs for task-based multithreaded computing.
</div>

In [None]:
using OhMyThreads: @tasks

# Define a function which uses `OhMyThreads`' `@tasks` instead of `Threads.@threads`.
function sum_tasks(fn, data; nchunks=nthreads())
    psums = zeros(eltype(data), nchunks)
    # Hint: this function will look a lot like `sum_threads` above,
    # but with `@tasks` instead of `@threads`.
    @tasks for (c, elements) in enumerate(chunks(data; n=nchunks))
        psums[c] = # ....
    end
    return # ....
end

@btime sum_tasks(sin, $v);

<div class="alert alert-info">
  <strong>Parallel <tt>mapreduce</tt> with <tt>OhMyThreads.jl</tt></strong> <br /><tt>OhMyThreads.jl</tt> also provides a function called <a href="https://juliafolds2.github.io/OhMyThreads.jl/stable/refs/api/#OhMyThreads.tmapreduce" class="alert-link"><tt>tmapreduce</tt></a>, which performs a task-based multi-threaded <a href="https://docs.julialang.org/en/v1/base/collections/#Base.mapreduce-Tuple%7BAny,%20Any,%20Any%7D" class="alert-link"><tt>mapreduce</tt></a> operation: essentially what we implemented by hand above.
</div>
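As a reminder of the argument order, here is the serial `Base.mapreduce` on a toy input: the first argument is applied to each element and the second combines the results. The `tmapreduce` docstring describes it as a multithreaded function like `Base.mapreduce`, so we assume the same argument order; only the Base function is used below.

```julia
# Apply `abs2` to each element, then combine the results with `+`
mr = mapreduce(abs2, +, 1:4)   # 1 + 4 + 9 + 16
```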

In [None]:
using OhMyThreads: tmapreduce

# Call `tmapreduce` applying the function `sin` to all elements of `v`
# and then take the sum of all these elements.
@btime tmapreduce(..., ..., $v);

## Multi-threading: is it always worth it?

In [None]:
using BenchmarkTools

function overhead!(v)
    for idx in eachindex(v)
        v[idx] = idx
    end
end

# Write a function equivalent to the one above, but with a `@threads`-ed `for` loop.
function overhead_threads!(v)
    # ....
end

N = 10

@btime overhead!(v) setup=(v = Vector{Int}(undef, N))
@btime overhead_threads!(v) setup=(v = Vector{Int}(undef, N))

<div class="alert alert-warning">
  <strong>Multi-threading overhead</strong> <br />Since each iteration of the <tt>for</tt> loops above is very fast (a single memory write), what you are measuring when running <tt>overhead_threads!</tt> is essentially the overhead of spawning <tt>Threads.nthreads()</tt> tasks.
    The cost of spawning a task is on the order of microseconds, so for multi-threading to pay off each iteration/task must be computationally intensive enough to amortise that cost.
    This is why in the examples above we iterated over larger chunks of data and performed some work on each of them: we wanted each iteration/task to do a substantial amount of computation, to make the parallelisation efficient.
</div>
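To get a feeling for this overhead on your own machine, you could time spawning and waiting on many tasks that do no work at all. This is only a rough sketch with a hypothetical helper `spawn_overhead`: it measures spawn plus wait amortised over many tasks, which is not exactly what `overhead_threads!` does, and the numbers depend on hardware, Julia version and number of threads.

```julia
using Base.Threads

# Spawn `ntasks` tasks that do nothing, then wait for all of them
function spawn_overhead(ntasks)
    tasks = [@spawn nothing for _ in 1:ntasks]
    foreach(wait, tasks)
    return nothing
end

spawn_overhead(10)              # warm-up call, so we don't time compilation
ntasks_probe = 10_000
t_spawn = @elapsed spawn_overhead(ntasks_probe)
println("≈ ", round(t_spawn / ntasks_probe * 1e6; digits=2), " µs per empty task")
```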

***Bonus***: do you see any improvement in the parallel efficiency if you change the size of the problem (here: `N`)? Can you think of a better strategy for this problem?

## Unbalanced workload: computing hexadecimal $\pi$

_This section is inspired by the blogpost [Computing the hexadecimal value of pi](https://giordano.github.io/blog/2017-11-21-hexadecimal-pi/) by Mosè Giordano._

The [Bailey–Borwein–Plouffe formula](https://en.wikipedia.org/wiki/Bailey%E2%80%93Borwein%E2%80%93Plouffe_formula) is one of the [several algorithms to compute $\pi$](https://en.wikipedia.org/wiki/Approximations_of_%CF%80):

$$
\pi = \sum_{k = 0}^{\infty}\left[ \frac{1}{16^k} \left( \frac{4}{8k + 1} -
\frac{2}{8k + 4} - \frac{1}{8k + 5} - \frac{1}{8k + 6} \right) \right]
$$
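As a quick sanity check of the formula itself, here is a minimal sketch that sums its first terms in double precision. The function name `bbp_pi` and the number of terms are our own choices for illustration; since each term is roughly 16 times smaller than the previous one, about a dozen terms already reach `Float64` accuracy.

```julia
# Partial sum of the Bailey–Borwein–Plouffe series, terms k = 0, ..., nterms - 1
function bbp_pi(nterms)
    s = 0.0
    for k in 0:(nterms - 1)
        s += (4 / (8k + 1) - 2 / (8k + 4) - 1 / (8k + 5) - 1 / (8k + 6)) / 16.0^k
    end
    return s
end

bbp_pi(12)
```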

What makes this formula stand out among other approximations of $\pi$ is that it allows one to directly extract the $n$-th fractional digit of the hexadecimal value of $\pi$ without computing the preceding ones.

The Wikipedia article about the Bailey–Borwein–Plouffe formula explains that the $(n + 1)$-th fractional digit $d_n$ is given by

$$
d_{n} = 16 \left[ 4 \Sigma(n, 1) - 2 \Sigma(n, 4) - \Sigma(n, 5) - \Sigma(n,
6) \right]
$$

where

$$
\Sigma(n, j) = \sum_{k = 0}^{n} \frac{16^{n-k} \bmod (8k+j)}{8k+j} + \sum_{k
= n+1}^{\infty} \frac{16^{n-k}}{8k+j}
$$

Only the fractional part of the expression in square brackets on the right-hand side of $d_n$ is relevant (the digit is the integer part of 16 times that fractional part). Thus, to avoid rounding errors, when we compute each term of the finite sum above we can keep only its fractional part. This allows us to always use ordinary double-precision floating-point arithmetic, without resorting to arbitrary-precision numbers. In addition, note that the terms of the infinite sum quickly become very small, so we can stop the summation when they become negligible.
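Both tricks can be checked directly in the REPL. `powermod(x, p, m)` computes $x^p \bmod m$ without ever forming the (possibly huge) power $x^p$, whereas computing `16^p` first silently overflows `Int64` for large `p`. And `mod(x, 1.0)` keeps the fractional part non-negative even when `x` is negative, which matters when the $\Sigma$ terms are combined with negative coefficients. The specific numbers below are only illustrative.

```julia
# 16^20 == 2^80 does not fit in an Int64: integer arithmetic wraps around to 0
naive   = mod(16^20, 13)
correct = powermod(16, 20, 13)   # modular exponentiation, no overflow

# Fractional part modulo 1, always in [0, 1)
frac_neg = mod(-0.25, 1.0)       # 0.75, not -0.25
```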

### Serial implementation

In [None]:
# Return the fractional part of x (i.e. x modulo 1), always non-negative
fpart(x) = mod(x, one(x))

function Σ(n, j)
    # Compute the finite sum
    s = 0.0
    denom = j
    for k in 0:n
        s = fpart(s + powermod(16, n - k, denom) / denom)
        denom += 8
    end
    # Compute the infinite sum
    num = 1 / 16
    while (frac = num / denom) > eps(s)
        s     += frac
        num   /= 16
        denom += 8
    end
    return fpart(s)
end

pi_digit(n) =
    floor(Int, 16 * fpart(4Σ(n-1, 1) - 2Σ(n-1, 4) - Σ(n-1, 5) - Σ(n-1, 6)))

pi_string(n) = "0x3." * join(string.(pi_digit.(1:n); base = 16)) * "p0"

Let's make sure this works:

In [None]:
pi_string(13)

In [None]:
# Parse the string as a double-precision floating point number
parse(Float64, pi_string(13))

In [None]:
Float64(π) == parse(Float64, pi_string(13))

In [None]:
N_pi = 1_000

setprecision(BigFloat, 4 * N_pi) do
    BigFloat(π) == parse(BigFloat, pi_string(N_pi))
end

In [None]:
using BenchmarkTools

b = @benchmark pi_string(N_pi)

pi_serial_t = minimum(b.times)

b

### Multi-threaded implementation

Since the Bailey–Borwein–Plouffe formula extracts the $n$-th digit of $\pi$ without computing the other ones, we can write a multi-threaded version of `pi_string`, taking advantage of Julia's native support for [multi-threading](https://docs.julialang.org/en/v1/manual/multi-threading/). However, note that the computational cost of `pi_digit(n)` is $O(n\log(n))$, so the larger the value of $n$, the longer the function takes, which makes this workload very unbalanced. ***Question***: what do you expect to be the best and worst performing schedulers?
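To quantify how unbalanced the workload is, here is a back-of-the-envelope sketch. It assumes (our simplification) that computing digit $n$ costs about $n\log(n)$ units of work, splits the indices `1:1_000` (the value of `N_pi` above) into 4 contiguous blocks of equal length, and compares the total work in each block. The helper names `digit_cost` and `block_work` are hypothetical.

```julia
# Hypothetical cost model: computing digit `n` takes about n*log(n) units of work
digit_cost(n) = n * log(n)

# Total work of each of `nblocks` contiguous, equal-length blocks of 1:N
function block_work(N, nblocks)
    len = cld(N, nblocks)
    return [sum(digit_cost, ((b - 1) * len + 1):min(b * len, N)) for b in 1:nblocks]
end

work = block_work(1_000, 4)
println("work per block: ", round.(Int, work))
println("last/first block ratio: ", round(work[end] / work[1]; digits=1))
```

Under this model, whoever processes the last block has several times more work than whoever processes the first one.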

#### For-loop: static scheduler

In [None]:
# Write a function which returns the same string as `pi_string(N_pi)`, using a `:static` multi-threaded `for` loop
function pi_string_threads_static(N)
    digits = Vector{Int}(undef, N)
    @threads ... for n in eachindex(digits)
        digits[n] = # ...
    end
    return "0x3." * ... * "p0"
end

@assert pi_string_threads_static(N_pi) == pi_string(N_pi)

b = @benchmark pi_string_threads_static(N_pi)

pi_threads_static_t = minimum(b.times)

display(b)

@info "parallel efficiency: $(round(pi_serial_t / pi_threads_static_t / nthreads() * 100; digits=2))%"

#### For-loop: dynamic scheduler

In [None]:
# Write a function which returns the same string as `pi_string(N_pi)`, using a `:dynamic` multi-threaded `for` loop
function pi_string_threads_dynamic(N)
    digits = Vector{Int}(undef, N)
    @threads ... for n in eachindex(digits)
        digits[n] = # ...
    end
    return "0x3." * .... * "p0"
end

@assert pi_string_threads_dynamic(N_pi) == pi_string(N_pi)

b = @benchmark pi_string_threads_dynamic(N_pi)

pi_threads_dynamic_t = minimum(b.times)

display(b)

@info "parallel efficiency: $(round(pi_serial_t / pi_threads_dynamic_t / nthreads() * 100; digits=2))%"

#### For-loop: greedy scheduler (only Julia v1.11+)

In [None]:
@static if VERSION >= v"1.11"

# Write a function which returns the same string as `pi_string(N_pi)`, using a `:greedy` multi-threaded `for` loop
function pi_string_threads_greedy(N)
    digits = Vector{Int}(undef, N)
    @threads ... for n in eachindex(digits)
        digits[n] = # ..
    end
    return "0x3." * .... * "p0"
end

@assert pi_string_threads_greedy(N_pi) == pi_string(N_pi)

b = @benchmark pi_string_threads_greedy(N_pi)

pi_threads_greedy_t = minimum(b.times)

display(b)

@info "parallel efficiency: $(round(pi_serial_t / pi_threads_greedy_t / nthreads() * 100; digits=2))%"

end

#### Tasks

In [None]:
# Write a function which returns the same string as `pi_string(N_pi)`, spawning tasks
function pi_string_tasks(N)
    tasks = map(n -> ..., 1:N)
    return "0x3." * ... * "p0"
end

@assert pi_string_tasks(N_pi) == pi_string(N_pi)

b = @benchmark pi_string_tasks(N_pi)

pi_tasks_t = minimum(b.times)

display(b)

@info "parallel efficiency: $(round(pi_serial_t / pi_tasks_t / nthreads() * 100; digits=2))%"

#### Bonus: using OhMyThreads.jl

<div class="alert alert-info">
  <strong>Tip</strong> <br />
  For this exercise you may want to have a look at the <a href="https://juliafolds2.github.io/OhMyThreads.jl/stable/refs/api/#OhMyThreads.@tasks" class="alert-link"><tt>OhMyThreads.@tasks</tt></a> macro, to be used in conjunction with <a href="https://juliafolds2.github.io/OhMyThreads.jl/stable/refs/api/#OhMyThreads.@set" class="alert-link"><tt>OhMyThreads.@set</tt></a>.
  You can take a look at the <a href="https://juliafolds2.github.io/OhMyThreads.jl/stable/translation/" class="alert-link">translation guide</a> for inspiration.
</div>

In [None]:
using OhMyThreads: @tasks, @set

# Write a function which returns the same string as `pi_string(N_pi)`, using `OhMyThreads.@tasks` for parallelisation.
function pi_string_omt(N; ntasks::Int=8 * nthreads(), scheduler::Symbol=:dynamic)
    digits = Vector{Int}(undef, N)
    @tasks for n in eachindex(digits)
        @set ntasks=ntasks
        @set scheduler=scheduler
        # ...
    end
    return "0x3." * .... * "p0"
end

@assert pi_string_omt(N_pi) == pi_string(N_pi)

b = @benchmark pi_string_omt(N_pi; ntasks=32 * nthreads())

pi_omt_t = minimum(b.times)

display(b)

@info "parallel efficiency: $(round(pi_serial_t / pi_omt_t / nthreads() * 100; digits=2))%"