# Julia Multithreading Tutorial

This tutorial demonstrates the basics of multithreading in Julia, including:

1. **Serial computation** (single-threaded baseline)
2. **Thread-unsafe multithreading** (common pitfall demonstration)
3. **Thread-safe multithreading** (correct approach)
4. **Performance comparison** using benchmarking
5. **Best practices** for multithreading

## Key Concepts

**Characteristics of multithreading (compared with serial processing):**
- **Threads** share the same memory space (shared memory parallelism)
- **Communication** happens through shared variables
- **Scalability** is limited to CPU cores on a single machine
- **Lower overhead** but requires careful synchronization
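
As a minimal sketch of what "communication through shared variables" and "careful synchronization" can look like, the snippet below lets every thread update one shared counter. The counter name `hits` and the loop bound are hypothetical, chosen only for illustration. `Atomic` and `atomic_add!` come from `Base.Threads`.

```julia
using Base.Threads

# Shared counter that every thread can see; Atomic makes each update indivisible.
hits = Atomic{Int}(0)

@threads for i in 1:10_000
    atomic_add!(hits, 1)   # synchronized read-modify-write on shared memory
end

println("hits = ", hits[])
```

Because every increment is atomic, the final count should not depend on how many threads Julia was started with.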

## Learning Objectives

By the end of this tutorial, you will understand:
- When and how to use multithreading effectively
- Common pitfalls like race conditions and how to avoid them
- Performance benefits and trade-offs of multithreading
- Best practices for thread-safe programming in Julia

## Pre-requisites

To use multiple threads in Julia, we need to start Julia with the `--threads` option. For example:

```bash

julia --threads 8 # OR
julia --threads auto

```

The default number of threads is `1`.
However, since we are calling Julia from within a notebook, the process is slightly different, as described below.

### VS Code

If you are using VS Code, set the number of threads in a `settings.json` file. This file usually lives in the `.vscode` folder of the project that is opened in VS Code. For example, if you opened `InteractiveParallelComputing` in VS Code, the file will be at `InteractiveParallelComputing/.vscode/settings.json`.
We provide the settings file needed for this tutorial; it contains the following:

```json

{
    "julia.NumThreads": "auto"
}

```

## Imports

In [None]:
using Pkg
println("Active project: ", Pkg.project().path)

We will load the appropriate Julia environment and the necessary modules for this tutorial.

In [None]:
using Statistics
using BenchmarkTools  # For performance benchmarking
using Base.Threads   # For multithreading support

# Display threading information
println("Julia Multithreading Tutorial")
println("=" ^ 40)
println("Available threads: $(nthreads())")
println("Thread IDs: $(1:nthreads())")
println()

if nthreads() == 1
    println("⚠️  WARNING: Only 1 thread available!")
    println("For better demonstration, start Julia with multiple threads:")
    println("JULIA_NUM_THREADS=4 julia or set in VS Code settings")
else
    println("✓ Multiple threads available for demonstration")
end

Let's check the number of available threads

In [None]:
@show nthreads()

The number of threads should be greater than 1. We can now proceed with the rest of our examples.

## Example Functions

We define `qoi` as a function that can be evaluated in an embarrassingly parallel manner. It comprises two function calls:

1. `simulate_heavy_compute`: mimics a function that spends a long time computing
2. `create_line`: mimics a function that computes the value of our quantity of interest (QoI)

In [None]:
"""
    simulate_heavy_compute(t::Float64)::Nothing

Simulates a long-running operation by sleeping for `t` seconds.
This stands in for an expensive task such as a numerical computation or an I/O operation.
Note that `sleep` only blocks the calling task and does not keep a CPU core busy.

# Arguments
- `t::Float64`: Sleep duration in seconds

# Example
```julia
simulate_heavy_compute(0.1)  # Simulates 0.1 seconds of work
```
"""
function simulate_heavy_compute(t::Float64)::Nothing
    sleep(t)
    return nothing
end

"""
    create_line(x::Float64)::Float64

A simple linear function that computes `1.0 * x + 2.0`.
This represents a deterministic computation that we can verify for correctness.

# Arguments
- `x::Float64`: Input value

# Returns
- `Float64`: Linear transformation of input

# Example
```julia
result = create_line(5.0)  # Returns 7.0
```
"""
function create_line(x::Float64)::Float64
    return 1.0 * x + 2.0
end

"""
    qoi(x::Float64)::Float64

Computes the "Quantity of Interest" (QoI) for a given input.
Combines simulated heavy computation with a deterministic calculation.
This represents a typical scientific computing workload.

# Arguments
- `x::Float64`: Input value for computation

# Returns
- `Float64`: Computed quantity of interest

# Example
```julia
result = qoi(10.0)  # Computes heavy work + linear transformation
```
"""
function qoi(x::Float64)::Float64
    simulate_heavy_compute(0.01)  # Reduced time for faster demo
    return create_line(x)
end

println("✓ Core functions defined")



The following functions compute the sum of the quantity of interest (QoI) in three different ways:

1. Serially
2. Using a thread-unsafe method
3. Using a thread-safe method

In [None]:
"""
    sum_serial(xarr::AbstractVector)::Float64

**SERIAL COMPUTATION (Single-threaded baseline)**

Computes the sum of QoI for all elements using a single thread.
This serves as our baseline for performance comparison.

# Arguments
- `xarr::AbstractVector`: Input array to process

# Returns
- `Float64`: Sum of QoI for all elements

# Performance Characteristics
- Uses only one thread
- Predictable, deterministic execution
- No thread synchronization overhead

# Example
```julia
xarr = 1.0:10.0
result = sum_serial(xarr)
```
"""
function sum_serial(xarr::AbstractVector)::Float64
    y_serial = 0.0
    for x in xarr
        y_serial += qoi(x)
    end
    return y_serial
end

println("✓ Serial function defined")

**Warning**: The function `sum_unsafe_threads` defined below is thread-unsafe and susceptible to race conditions. Read more [here](https://docs.julialang.org/en/v1/manual/multi-threading/#Data-race-freedom).
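
The core hazard can be sketched in isolation. In the illustration here, which is not part of the tutorial's functions, ten tasks each read a shared value, yield inside `sleep`, and then write back. `@async` is used on purpose so that all tasks stay on one thread, which suggests that a read-modify-write spanning a yield point can lose updates even without true parallelism. The names `racy_total` and `safe_total` are hypothetical.

```julia
# Racy pattern: read the shared value, yield (here via sleep), then write back.
racy_total = Ref(0)
@sync for _ in 1:10
    @async begin
        old = racy_total[]      # read
        sleep(0.01)             # yield point: other tasks run and read the same value
        racy_total[] = old + 1  # write back, possibly overwriting other tasks' updates
    end
end

# Safer pattern: each task returns its contribution and the parent combines them.
tasks = map(1:10) do _
    Threads.@spawn begin
        sleep(0.01)
        1
    end
end
safe_total = sum(fetch, tasks)

println("racy total: ", racy_total[], " (10 increments were attempted)")
println("safe total: ", safe_total)
```

The racy total usually ends up well below 10, because several tasks read the same stale value before any of them writes. The second pattern avoids shared mutable state entirely, so no update can be lost.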

In [None]:
"""
    sum_unsafe_threads(xarr::AbstractVector)::Float64

**THREAD-UNSAFE MULTITHREADING (Demonstrates race conditions)**

⚠️  WARNING: This implementation has potential race conditions!

This function demonstrates what NOT to do with multithreading.
Each iteration accumulates into the slot indexed by `threadid()`,
but `threadid()` is not a stable identifier for a running task.

# Race Condition Issues
- Tasks created by `@threads` can yield (here, inside `sleep`) and, in recent Julia versions, may migrate between threads
- Two tasks can therefore interleave their read-modify-write on the same slot and lose updates
- Results may vary between runs due to non-deterministic scheduling

# Arguments
- `xarr::AbstractVector`: Input array to process

# Returns
- `Float64`: Sum of QoI (may be incorrect due to race conditions)

# Example
```julia
# DON'T use this in production code!
result = sum_unsafe_threads(xarr)
```
"""
function sum_unsafe_threads(xarr::AbstractVector)::Float64
    # Thread unsafe example - demonstrates potential race conditions
    y_threaded_arr = zeros(Float64, nthreads())
    @threads for x in xarr
        y_threaded_arr[threadid()] += qoi(x)
    end
    return sum(y_threaded_arr)
end

println("✓ Thread-unsafe function defined (⚠️  for demonstration only)")

The following is a thread-safe implementation.

In [None]:
"""
    sum_safe_threads(xarr::AbstractVector)::Float64

**THREAD-SAFE MULTITHREADING (Correct approach)**

✅ This is a correct way to implement the multithreaded sum.

Uses explicit work distribution by:
1. Dividing input into roughly equal chunks
2. Spawning tasks for each chunk
3. Collecting results safely using `fetch`

# Key Concepts Demonstrated
- **Work partitioning**: Explicit division of labor
- **Task spawning**: Using `Threads.@spawn` for async execution
- **Result collection**: Safe aggregation with `fetch.(tasks)`
- **Load balancing**: Roughly equal work distribution across tasks

# Arguments
- `xarr::AbstractVector`: Input array to process

# Returns
- `Float64`: Correct sum of QoI for all elements

# Example
```julia
# This is the recommended approach
result = sum_safe_threads(xarr)
```
"""
function sum_safe_threads(xarr::AbstractVector)::Float64
    # Handle edge case: if fewer elements than threads, use serial computation
    if length(xarr) < nthreads()
        return sum_serial(xarr)
    end
    
    # Calculate chunk size with ceiling division so that at most nthreads() chunks are created
    chunk_size = cld(length(xarr), nthreads())
    
    # Divide work into chunks
    chunks = Iterators.partition(xarr, chunk_size)
    
    # Spawn a task for each chunk
    tasks = map(chunks) do chunk
        Threads.@spawn begin
            local_sum = 0.0
            for x in chunk
                local_sum += qoi(x)
            end
            local_sum
        end
    end
    
    # Wait for all tasks to complete and collect results
    partial_results = fetch.(tasks)
    
    # Sum the partial results
    return sum(partial_results)
end

println("✓ Thread-safe function defined (✅ recommended approach)")

In [None]:
# =============================================================================
# PERFORMANCE COMPARISON DEMONSTRATION
# =============================================================================

# Create test data
println("\n" * "=" ^ 50)
println("PERFORMANCE COMPARISON DEMONSTRATION")
println("=" ^ 50)

println("Setting up test data...")
xarr = range(1.0, 50.0, length=50)  # Smaller array for notebook demo
println("Array size: $(length(xarr)) elements")
println("Expected result: $(sum(create_line.(xarr)))")

# Pre-compile functions (important for accurate benchmarking)
println("\nPre-compiling functions...")
_ = sum_serial(xarr[1:5])
_ = sum_unsafe_threads(xarr[1:5])
_ = sum_safe_threads(xarr[1:5])
println("✓ Functions compiled")

In [None]:
# =============================================================================
# 1. SERIAL COMPUTATION (Baseline)
# =============================================================================
println("\n1. SERIAL COMPUTATION")
println("-" ^ 30)
print("Running serial computation... ")
serial_result = @timed sum_serial(xarr)
println("✓ Complete")
println("Result: $(serial_result.value)")
println("Time: $(round(serial_result.time, digits=3)) seconds")
println("Memory: $(serial_result.bytes) bytes")
println("Threads used: 1 (main thread only)")

In [None]:
# =============================================================================
# 2. THREAD-UNSAFE MULTITHREADING
# =============================================================================
println("\n2. THREAD-UNSAFE MULTITHREADING")
println("-" ^ 35)
println("⚠️  Warning: This may produce incorrect results!")

# Run multiple times to show inconsistency
unsafe_results = Float64[]
unsafe_times = Float64[]

for i in 1:3
    print("Run $i... ")
    result = @timed sum_unsafe_threads(xarr)
    push!(unsafe_results, result.value)
    push!(unsafe_times, result.time)
    println("Result: $(result.value), Time: $(round(result.time, digits=3))s")
end

println("Results consistency: $(length(unique(unsafe_results)) == 1 ? "✓ Consistent" : "✗ Inconsistent")")
println("Average time: $(round(mean(unsafe_times), digits=3)) seconds")
println("Threads used: $(nthreads()) (all available)")

In [None]:
# =============================================================================
# 3. THREAD-SAFE MULTITHREADING
# =============================================================================
println("\n3. THREAD-SAFE MULTITHREADING")
println("-" ^ 32)
println("✅ This is the correct approach!")

# Run multiple times to show consistency
safe_results = Float64[]
safe_times = Float64[]

for i in 1:3
    print("Run $i... ")
    result = @timed sum_safe_threads(xarr)
    push!(safe_results, result.value)
    push!(safe_times, result.time)
    println("Result: $(result.value), Time: $(round(result.time, digits=3))s")
end

println("Results consistency: $(length(unique(safe_results)) == 1 ? "✓ Consistent" : "✗ Inconsistent")")
println("Average time: $(round(mean(safe_times), digits=3)) seconds")
println("Threads used: $(nthreads()) (all available)")

## Conclusion

In [None]:
# =============================================================================
# PERFORMANCE SUMMARY & ANALYSIS
# =============================================================================
println("\n" * "=" ^ 40)
println("PERFORMANCE SUMMARY")
println("=" ^ 40)
println("Serial time:      $(round(serial_result.time, digits=3)) seconds")
println("Unsafe avg time:  $(round(mean(unsafe_times), digits=3)) seconds") 
println("Safe avg time:    $(round(mean(safe_times), digits=3)) seconds")
println()
println("Speedup (unsafe): $(round(serial_result.time / mean(unsafe_times), digits=2))x")
println("Speedup (safe):   $(round(serial_result.time / mean(safe_times), digits=2))x")
println("Efficiency (safe): $(round(100 * serial_result.time / (mean(safe_times) * nthreads()), digits=1))%")

# Accuracy check
println("\nACCURACY CHECK")
println("-" ^ 20)
expected = serial_result.value
println("Expected result:  $expected")
println("Serial result:    $(serial_result.value) ✓")
println("Unsafe results:   $(unsafe_results) $(all(r ≈ expected for r in unsafe_results) ? "✓" : "✗")")
println("Safe results:     $(safe_results) $(all(r ≈ expected for r in safe_results) ? "✓" : "✗")")

println("\n" * "=" ^ 40)
println("KEY OBSERVATIONS")
println("=" ^ 40)
println("1. Serial computation provides the baseline performance")
println("2. Thread-unsafe code may appear to work but can produce incorrect results")
println("3. Thread-safe implementation provides both correctness and performance")
println("4. Speedup depends on problem size, thread count, and computation intensity")
println("5. Always verify correctness before optimizing for performance")
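
The speedup and efficiency figures printed in the summary follow two simple ratios. As a worked example, the sketch below applies them to hypothetical timings (the numbers are made up for illustration and are not measured results from this notebook).

```julia
# Hypothetical timings in seconds, chosen only for illustration (not measured results).
t_serial   = 0.5
t_parallel = 0.125
n_threads  = 8

speedup    = t_serial / t_parallel        # how many times faster than the serial run
efficiency = 100 * speedup / n_threads    # percent of the ideal n_threads-fold speedup

println("Speedup:    ", speedup, "x")     # prints "Speedup:    4.0x"
println("Efficiency: ", efficiency, "%")  # prints "Efficiency: 50.0%"
```

An efficiency well below 100% suggests that overhead or uneven work distribution is absorbing part of the available parallelism.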

## Multithreading Best Practices

### ✅ Do's
- **Use explicit work partitioning** with `Threads.@spawn` for control
- **Avoid shared mutable state** when possible
- **Pre-compile functions** before benchmarking for accurate results
- **Verify correctness** across multiple runs
- **Consider problem size** - multithreading has overhead
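
The first two points can be combined into a small reusable pattern. The sketch below uses a hypothetical helper, `chunked_sum`, which is not defined elsewhere in this notebook: it splits the input into at most `ntasks` chunks, gives each task a private partial sum, and combines the results in the parent task.

```julia
using Base.Threads

# Hypothetical helper (not part of this notebook): sum f(x) over xs, one task per chunk.
function chunked_sum(f, xs::AbstractVector; ntasks::Int = nthreads())
    isempty(xs) && return 0.0
    chunk_size = cld(length(xs), ntasks)      # ceiling division: at most ntasks chunks
    tasks = map(Iterators.partition(xs, chunk_size)) do chunk
        Threads.@spawn sum(f, chunk)          # each task owns a private partial sum
    end
    return sum(fetch, tasks)                  # combine partial sums in the parent task
end

total = chunked_sum(x -> 1.0 * x + 2.0, 1.0:50.0)
println(total)
```

Since no task writes to memory that another task reads, this design needs no locks or atomics; the only synchronization point is `fetch`.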

### ❌ Don'ts  
- **Don't rely on `threadid()` distribution** for load balancing
- **Don't access shared memory** without proper synchronization
- **Don't assume multithreading always helps** for small problems
- **Don't ignore race conditions** even if code "seems to work"
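
When shared mutable state cannot be avoided, one common form of "proper synchronization" is a lock. The following is a minimal sketch under that assumption; the names `results` and `results_lock` are hypothetical.

```julia
using Base.Threads

results = Float64[]               # shared, mutable, and not thread-safe on its own
results_lock = ReentrantLock()

@threads for x in 1:100
    y = sqrt(x)                   # independent work happens outside the lock
    lock(results_lock) do
        push!(results, y)         # only one task at a time may mutate the vector
    end
end

println(length(results))
```

The order of the elements in `results` depends on scheduling, so sort them (or store by index) if the order matters. Keeping the expensive work outside the lock limits how long tasks wait on each other.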

### When to Use Multithreading
- **Embarrassingly parallel problems** (independent computations)
- **CPU-bound tasks** that can be divided
- **Problems larger than thread setup overhead**
- **Single-machine computations** with shared memory needs
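
To get a feel for the overhead point, the sketch below compares a plain serial sum with a deliberately wasteful version that spawns one task per element. The functions `cheap`, `serial_sum`, and `spawned_sum` are hypothetical, and the timings will vary by machine, so no particular numbers are claimed.

```julia
using Base.Threads

cheap(x) = 2.0 * x                       # far too little work to justify a task
xs = collect(1.0:1000.0)

serial_sum(xs) = sum(cheap, xs)

# Deliberately wasteful: one task per element, so task creation dominates the work.
function spawned_sum(xs)
    tasks = map(xs) do x
        Threads.@spawn cheap(x)
    end
    return sum(fetch, tasks)
end

serial_sum(xs); spawned_sum(xs)          # warm-up calls so compilation is not timed
t_serial  = @elapsed serial_sum(xs)
t_spawned = @elapsed spawned_sum(xs)

println("serial:  ", t_serial, " s")
println("spawned: ", t_spawned, " s")
```

Both versions compute the same value; on most machines the spawned version is expected to be slower here, which is the reason to batch small work items into chunks.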

In [None]:
# =============================================================================
# KEY TAKEAWAYS
# =============================================================================
println("\n" * "=" ^ 40)
println("KEY TAKEAWAYS")
println("=" ^ 40)
println("1. Always verify correctness before optimizing for performance")
println("2. Thread-unsafe code may appear to work but can produce wrong results")
println("3. Proper work distribution is essential for thread safety")
println("4. Use explicit task spawning and result collection for reliability")
println("5. Benchmark with @timed or @btime for accurate performance measurement")
println()
println("For more threads, start Julia with: JULIA_NUM_THREADS=N julia script.jl")
println("Or configure in VS Code settings: julia.NumThreads")
println()
println("✅ Tutorial completed successfully!")

if nthreads() > 1
    println("🚀 Great! You ran this with $(nthreads()) threads")
else
    println("💡 Tip: Try running with more threads to see better speedups!")
end