# How to make things fast!

- Basics of shared-memory and per-process parallelism

&nbsp;

- How to use that in python

&nbsp;

- Multiprocessing pools

&nbsp;

- Multiprocess Queues & Pipes, which are the recommended way to do parallelism in python

&nbsp;

- The Global Interpreter Lock and why shared memory parallelism is limited in Python

&nbsp;

- Floating point representations, what they really mean and how to think about them, eg Float16 vs bfloat16 vs float32

&nbsp;

- How to do multi-node training on the Mila Cluster

# Concurrency

&nbsp;

Concurrency is the idea of running multiple chains of program execution in parallel. Concurrency allows for 

&nbsp;

- Faster execution (in the context of parallel computing)

&nbsp;

- Better resource utilization (eg., not waste time waiting for things that are slow)

&nbsp;

The idea is that you split your program in multiple lines of execution that work at the time.

&nbsp;

### Threads

&nbsp;

In *threads*, the memory of the different lines of execution is shared.

&nbsp;

As such, it is easy for threads to work on shared resources and shared variables and objects, 
and there is no need to copy data between threads.

&nbsp;

A disadvantage of threads is that it's very easy to not be careful and have them interfere with each other.

&nbsp;

### Processes

&nbsp;

In *processes*, the memory of the different lines of execution is not shared. 

&nbsp;

The memory is completely separate. The only way to communicate between processes is to use
special objects called *queues* and *pipes*. 

&nbsp;

It is costly to communicate between processes.

&nbsp;

### Threads in Python are Limited: The Global Interpreter Lock

&nbsp;

Python has false multi-threading. In order to make memory management (garbage collection) 
fast enough for the single-thread case, they made it so the interpreter can only run one "instruction" at a time.

&nbsp;

This is called the Global Interpreter Lock (GIL). There is literally a lock (something stopping other threads from executing) that is held by the interpreter
when it is running a line of code. This heavily li



In [1]:
# Example of Threading / Shared memory Parallelism

%reset -f
import threading

shared_thing_lock = threading.Lock()
shared_thing = []

def increment(th_id, shared_thing):
    with shared_thing_lock:
        shared_thing.append(th_id)

threads = [threading.Thread(target=increment, args=(i, shared_thing)) for i in range(100)]
[thread.start() for thread in threads]

print(shared_thing)


[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99]


In [2]:
# Example of Isolated Memory Parallelism / Multiprocessing

%reset -f
import multiprocess

shared_thing_lock = multiprocess.Lock()
shared_thing = []

def increment(proc_id, shared_thing):
    with shared_thing_lock:
        shared_thing.append(proc_id)

processes = [multiprocess.Process(target=increment, args=(i, shared_thing)) for i in range(100)]
[processe.start() for processe in processes]

print(shared_thing)

[]


# Threading and Processing pools

&nbsp;

Threading and processing pools are a way to make parallelism easier in python.

&nbsp;

They also reduce the overhead of creating and destroying processes and threads by reusing them.

&nbsp;

They also make sure the exceptions propagate to the main process in the case of multiprocessing.

&nbsp;

In the case of multiprocessing, they also return the results without having to use queues explicitely.

&nbsp;

### Multiprocessing pools

&nbsp;

```python
import multiprocessing as mp
import os

num_processes = os.cpu_count()
with mp.Pool(num_processes) as pool:
    results = pool.map(func, iterable)

print(results)
```

&nbsp;

### Multithreading pools

&nbsp;

```python
import multiprocessing.dummy as mp
import os

num_threads = os.cpu_count()
with mp.Pool(num_threads) as pool:
    results = pool.map(func, iterable)

print(results)
```



In [7]:
# Demonstration of multiprocessing Pools:

%reset -f
import os
import math
import time
import multiprocess


# Function to be parallelized
def expensive_thing(proc_id):
    return math.factorial(50000) + proc_id


# Call the function once as a baseline
start = time.time()
expensive_thing(0)
duration_once = time.time() - start
n_processes = 10


# Call the function in parallel
with multiprocess.Pool(n_processes) as pool:
    start = time.time()
    results = pool.map(expensive_thing, range(n_processes))
    duration_10x_proc = time.time() - start


# Print the results
print(f"Just once: {duration_once: 0.2f} seconds")
print(f"{10 * duration_once = :0.2f} seconds")
print(f"{duration_10x_proc  = :0.2f} seconds")
print(f"Speedup of {(10 * duration_once) / duration_10x_proc : 0.2f} times")


Just once:  0.07 seconds
10 * duration_once = 0.65 seconds
duration_10x_proc  = 0.24 seconds
Speedup of  2.76 times


In [10]:
# Communication without a Pool
%reset -f
import os
import threading
import time
import queue
import math

# The queue to communicate the results
results = queue.Queue()


# The expensive function to be parallelized
def expensive_thing(th_id):
    results.put(math.factorial(50000) + th_id)


# Call the function once as a baseline
start = time.perf_counter()
expensive_thing(0)
duration_once = time.perf_counter() - start


# Create the threads
n_threads = 10
threads = [threading.Thread(target=expensive_thing, args=(i,)) for i in range(n_threads)]


# Start the threads
start = time.perf_counter()
[thread.start() for thread in threads]


# Wait for the threads to finish
[thread.join() for thread in threads]


# Get the results
results = [results.get() for _ in range(n_threads)]
duration_threading = time.perf_counter() - start


# Print the results
print(f"10x Duration once: {10 * duration_once:0.5f} seconds")
print(f"Duration 10 threads: {duration_threading:0.5f} seconds")
print(f"Speedup of {10 * duration_once / duration_threading: 0.2} times")

10x Duration once: 0.66534 seconds
Duration 10 threads: 0.77780 seconds
Speedup of  0.86 times


In [16]:
# Same as above, but with a Pool
%reset -f
import os
import threading
import time
import queue
import math
from concurrent.futures import ThreadPoolExecutor


# Expensive function to be parallelized
def expensive_thing(proc_id):
    return math.factorial(50000) + proc_id


# Call the function once as a baseline
start = time.perf_counter()
expensive_thing(0)
duration_once = time.perf_counter() - start


print(f"{10 = }")


# Create the pool & call the function in parallel
with ThreadPoolExecutor(10) as pool:
    start = time.perf_counter()
    results = pool.map(expensive_thing, range(os.cpu_count()))
    duration_pool = time.perf_counter() - start


# Print the results
print(f"10x Duration once: {10 * duration_once:0.5f} seconds")
print(f"Duration 10 threads: {duration_pool:0.5f} seconds")
print(f"Speedup of {10 * duration_once / duration_pool: 0.2} times")

10 = 10
10x Duration once: 0.83267 seconds
Duration 10 threads: 0.46481 seconds
Speedup of  1.8 times


In [3]:
# Example of Threads In "Regular" Languages

%reset -f
import os
import time

import copperhead as cpp


code = """

#include <iostream>
#include <thread>
#include <vector>
#include <cmath>
#include <random>
using namespace std;

// Function to be parallelized
void busy_work(int n, double* result) {
    *result = 1;
    for (int i = 1; i <= n; i++) {
        *result = fmod(*result + rand(), 10000000);
    }
}

// Start the threads from inside C++
vector<double> run(int n_threads, int fn_arg) {

    // This will hold the results
    vector<double> results(n_threads, 0);
    
    // Create pointers for the threads.
    std::thread* threads[n_threads];

    // Create and start each thread.
    // Each thread will write its result to the corresponding element of `results`.
    for (int i = 0; i < n_threads; i++) {
        threads[i] = new std::thread(busy_work, fn_arg, &results[i]);
    }

    // Wait for the threads to finish
    for (int i = 0; i < n_threads; i++) {
        threads[i]->join();
    }

    return results;
}


"""

# Compile the code
run_threads = cpp.generate("run", "std::vector<double> (int, int)", code, rebuild=True)


# Call the function once as a baseline
start = time.perf_counter()
once = run_threads(1, 50000000)
duration_single = time.perf_counter() - start
print(once)


# Call the function in parallel
start = time.perf_counter()
outputs = run_threads(os.cpu_count(), 50000000)
duration_threading = time.perf_counter() - start


# Print the results
print("\n\n")
print(f"10 * duration single: {10 * duration_single:0.5f} seconds")
print(f"Duration {10} threads: {duration_threading:0.5f} seconds")
print(f"Speedup of {(10 * duration_single) / duration_threading: 0.3f} times")



running install
running bdist_egg
running egg_info
writing run.egg-info/PKG-INFO
writing dependency_links to run.egg-info/dependency_links.txt
writing top-level names to run.egg-info/top_level.txt
reading manifest file 'run.egg-info/SOURCES.txt'
writing manifest file 'run.egg-info/SOURCES.txt'
installing library code to build/bdist.macosx-10.9-x86_64/egg
running install_lib
running build_ext
building 'run' extension
gcc -Wno-unused-result -Wsign-compare -Wunreachable-code -DNDEBUG -g -fwrapv -O3 -Wall -Wstrict-prototypes -I/Users/jules/.anaconda3/include -arch x86_64 -I/Users/jules/.anaconda3/include -arch x86_64 -I/Users/jules/.anaconda3/include/python3.8 -c /Users/jules/Documents/Work/talk/.copperhead_cache/run/run_block.cpp -o build/temp.macosx-10.9-x86_64-3.8/Users/jules/Documents/Work/talk/.copperhead_cache/run/run_block.o




g++ -bundle -undefined dynamic_lookup -L/Users/jules/.anaconda3/lib -arch x86_64 -L/Users/jules/.anaconda3/lib -arch x86_64 -arch x86_64 build/temp.macosx-10.9-x86_64-3.8/Users/jules/Documents/Work/talk/.copperhead_cache/run/run_block.o -o build/lib.macosx-10.9-x86_64-3.8/run.cpython-38-darwin.so
creating build/bdist.macosx-10.9-x86_64/egg
copying build/lib.macosx-10.9-x86_64-3.8/run.cpython-38-darwin.so -> build/bdist.macosx-10.9-x86_64/egg
creating stub loader for run.cpython-38-darwin.so
byte-compiling build/bdist.macosx-10.9-x86_64/egg/run.py to run.cpython-38.pyc
creating build/bdist.macosx-10.9-x86_64/egg/EGG-INFO
copying run.egg-info/PKG-INFO -> build/bdist.macosx-10.9-x86_64/egg/EGG-INFO
copying run.egg-info/SOURCES.txt -> build/bdist.macosx-10.9-x86_64/egg/EGG-INFO
copying run.egg-info/dependency_links.txt -> build/bdist.macosx-10.9-x86_64/egg/EGG-INFO
copying run.egg-info/top_level.txt -> build/bdist.macosx-10.9-x86_64/egg/EGG-INFO
writing build/bdist.macosx-10.9-x86_64/egg/E

zip_safe flag not set; analyzing archive contents...
__pycache__.run.cpython-38: module references __file__


creating 'dist/run-0.0.0-py3.8-macosx-10.9-x86_64.egg' and adding 'build/bdist.macosx-10.9-x86_64/egg' to it
removing 'build/bdist.macosx-10.9-x86_64/egg' (and everything under it)
Processing run-0.0.0-py3.8-macosx-10.9-x86_64.egg
removing '/Users/jules/Documents/Work/talk/.copperhead_cache/run/run-0.0.0-py3.8-macosx-10.9-x86_64.egg' (and everything under it)
creating /Users/jules/Documents/Work/talk/.copperhead_cache/run/run-0.0.0-py3.8-macosx-10.9-x86_64.egg
Extracting run-0.0.0-py3.8-macosx-10.9-x86_64.egg to /Users/jules/Documents/Work/talk/.copperhead_cache/run
run 0.0.0 is already the active version in easy-install.pth

Installed /Users/jules/Documents/Work/talk/.copperhead_cache/run/run-0.0.0-py3.8-macosx-10.9-x86_64.egg
Processing dependencies for run==0.0.0
Finished processing dependencies for run==0.0.0
[4457177.0]



10 * duration single: 15.98842 seconds
Duration 10 threads: 7.10033 seconds
Speedup of  2.252 times


# Important Points to Remember:

&nbsp;

### Processes

&nbsp;

- Multiprocessing is the recommended way to do parallelism in Python. Use it. It's fast and easy with `multiprocessing.Pool` or `concurrent.futures.ProcessPoolExecutor`.

&nbsp;

- Communication between processes is costly. Everything needs to be pickled, sent through a pipe, then unpickled.

&nbsp;

- If you need communication between processes, read up on `multiprocessing.Manager()` and `multiprocessing.Queue()`.

&nbsp;

### Queues

&nbsp;

- *If you are doing things that are not CPU-bound, you should use threads instead of processes.*

&nbsp;

- For example, if you are reading heavy files. use `multiprocessing.dummy.Pool` or `concurrent.futures.ThreadPoolExecutor`.

&nbsp;

- Also you are downloading multiple things at the same time, use threads.

&nbsp;

- If you need to communicate between threads, use `queue.Queue` or `queue.Pipe`. They are thread-safe.


# Floating Point Numbers & Representations


### Why this is important:

Torch and Pytorch Lightning (through torch) allow the use of different floating point representations. 

&nbsp;

These representations allow very large speedups, but there is a tradeoff in precision.

&nbsp;

Here are illustrations of the differences in computation power as a function of precision:

&nbsp;

![why_precision](why_precision.png)

&nbsp;

Understanding the meaning of these representations is important to understand the tradeoff.

&nbsp;


### How to use mixed precision training in Pytorch and Pytorch Lightning:

&nbsp;


First, how to use lower precision. With Pytorch Lightning, it's as simple as setting `precision=16` or `precision=bfloat16` in the trainer:

&nbsp;


![precision](precision.png)

&nbsp;


For further information, see https://pytorch-lightning.readthedocs.io/en/latest/guides/speed.html#mixed-precision-16-bit-training

&nbsp;

You can do something similar in raw Pytorch of course, but it's slightly more complicated.

&nbsp;

![precision2](precision2.png)

&nbsp;

For further information, see https://pytorch.org/docs/stable/amp.html#autocasting

&nbsp;


### What are floating point numbers & how do they work ?

&nbsp;

Floating point numbers are what CPUs and GPUs use to represent numbers.

&nbsp;

They are scary, but they are in fact identical scientific notation, except in base 2.

&nbsp;

Assuming we have 5 significant digits:

&nbsp;

Scientific notation:   - 3.5484 x 10^25

&nbsp;

This number has three parts: the sign, the mantissa, and the exponent.

&nbsp;

Here, the sign is -1.

&nbsp;

The mantissa is 3.5484. It's a number between 1 and 10 representing the signifigant digits.

&nbsp;

The exponent is 25.

&nbsp;

### Floating point numbers work the same way, except in base 2:

&nbsp;

- The sign is either 0 or 1.

&nbsp;

- The mantissa is an integer in base 2, represented by a string of bits. 

&nbsp;

- The exponent is an integer in base 2, represented by a string of bits, and the number that is exponentiated is 2. It's value is always between 1 and 2, just as the mantissa is always between 1 and 10 in scientific notation.

&nbsp;

- This is all you really need to know, the rest are details to get the details right.

&nbsp;

### Technical details:

&nbsp;

- We always divide the mantissa by its maximum int value, then add 1. This is to constrain it to its role of representing a number between 1 and 2. You can also guess the number of significant digits from this.

&nbsp;

- We substract half of the number of bits of the exponent from the exponent, making it so that the exponent is always between -2^(number_of_bits_of_exponent - 1) and 2^(number_of_bits_of_exponent - 1). This representation apparently makes comparison betweeb floats faster.

&nbsp;

### Important point:

&nbsp;

- In *float16*, there are 5 bits for the exponent, and 10 bits for the mantissa. This means that the exponent is between -16 and 15 

&nbsp;

![float16](float16.png)

&nbsp;

- *float32* has 8 bits of exponent, and 23 bits of mantissa. This means that the exponent is between -128 and 127. 

&nbsp;

![float32](single.png)

&nbsp;

- *bfloat16* also has 16 bits, but it has 8 bits of exponent like the float32, and 7 bits of mantissa. This means that the exponent is also between -128 and 127. 

&nbsp;

![bfloat16](bfloat16.png)



In [277]:
%reset -f
import torch
import sys
import bitstring
import math

def get_raw_bits(tensor, num_bits):
    if num_bits == 16:
        tensor = tensor.view(torch.int16)
    elif num_bits == 32:
        tensor = tensor.view(torch.int32)
    elif num_bits == 64:
        tensor = tensor.view(torch.int64)
    else:
        raise ValueError("num_bits must be 16, 32, or 64")
    
    array = bitstring.BitArray(bytes=tensor.numpy().tobytes())
    array.byteswap()
    return array.bin

16
16
32


In [301]:
import rich
import rich.table
console = rich.console.Console(width=240)

def to_int(bits):
    return sum([int(bit) * 2 ** (i) for i, bit in enumerate(bits[::-1])])

def float32(bits):
    sign_bit = int(bits[0])
    exponent = to_int(bits[1:9])
    mantissa = to_int(bits[9:32])


    # Compute Sign
    sign = (-1) ** sign_bit


    # Compute the exponentiation factor
    resulting_power = exponent - 127
    exponentiation_factor = 2 ** resulting_power


    # Compute the significant factor
    mantissa_multiplier = 1 + mantissa / 2 ** 23   # 23 is the number of bits in the mantissa. 



    # Put things together
    result = sign * exponentiation_factor * mantissa_multiplier
    result_direct = (-1) ** sign_bit * 2 ** (exponent - 127) * (1 + mantissa / 2 ** 23)
    assert result == result_direct


    # Print the details
    table = rich.table.Table("Descript",   "Eqn",                                                       "Value", title="Float32")
    table.add_row("Sign bit",              "bits[0]",                                                   str(sign_bit),            end_section=True)
    table.add_row("Sign",                  "(-1) ** (bits[0])",                                         str(sign))
    table.add_row("Mantissa bits",         "bits[9:32]",                                                str(bits[9:32]))
    table.add_row("Mantissa",              "to_int(bits[9:32])",                                        str(mantissa),            )
    table.add_row("Mantissa Multiplier",   "1 + mantissa / 2 ** 23",                                    str(mantissa_multiplier), end_section=True)
    table.add_row("Exponent bits",         "bits[1:9]",                                                 str(bits[1:9]))
    table.add_row("Exponent",              "to_int(bits[1:9])",                                         str(exponent))
    table.add_row("Exponentiation factor", "2 ** resulting_power",                                      str(exponentiation_factor))
    table.add_row("Resulting Power",       "exponent - 127",                                            str(resulting_power),     end_section=True)
    table.add_row("Result",                f"{sign} * {mantissa_multiplier} * {exponentiation_factor}", f"[bold blue]{result}")
    console.print(table)
    print()
    return result

def float16(bits):
    sign_bit = int(bits[0])
    exponent = to_int(bits[1:6])
    mantissa = to_int(bits[6:16])



    # Compute Sign
    sign = (-1) ** sign_bit



    # Compute the exponentiation factor
    resulting_power = exponent - 15
    exponentiation_factor = 2 ** resulting_power



    # Compute the significant factor
    mantissa_multiplier = 1 + mantissa / 2 ** 10


    # Put things together
    result = sign * exponentiation_factor * mantissa_multiplier
    result_direct = (-1) ** sign_bit * 2 ** (exponent - 15) * (1 + mantissa / 2 ** 10)


    # Print the details
    table = rich.table.Table("Descript",   "Eqn",                                                       "Value", title="Float16")
    table.add_row("Sign bit",              "bits[0]",                                                   str(sign_bit),              end_section=True)
    table.add_row("Sign",                  "(-1) ** to_int(bits[0])",                                   str(sign),                  end_section=True)
    table.add_row("Mantissa bits",         "bits[6:16]",                                                str(bits[6:16]))
    table.add_row("Mantissa",              "to_int(bits[6:16])",                                        str(mantissa))
    table.add_row("Mantissa Factor",       "1 + mantissa / 2 ** 10",                                    str(mantissa_multiplier),   end_section=True)
    table.add_row("Exponent bits",         "bits[1:6]",                                                 str(bits[1:6]))
    table.add_row("Exponent",              "to_int(bits[1:6])",                                         str(exponent))
    table.add_row("Resulting Power",       "exponent - 15",                                             str(resulting_power),)
    table.add_row("Exponent Factor",       "2 ** resulting_power",                                      str(exponentiation_factor), end_section=True)
    table.add_row("Result",                f"{sign} * {mantissa_multiplier} * {exponentiation_factor}", f"[bold blue]{result}")
    console.print(table)
    print()

    return result

def bfloat16(bits):
    sign_bits = int(bits[0])
    exponent = to_int(bits[1:9])
    mantissa = to_int(bits[9:16])


    # Compute Sign
    sign = (-1) ** sign_bits


    # Compute the exponentiation factor
    resulting_power = exponent - 127
    exponent_factor = 2 ** resulting_power


    # Compute the significant factor
    mantissa_multiplier = 1 + mantissa / 2 ** 7


    # Put things together
    result = sign * exponent_factor * mantissa_multiplier
    result_direct = (-1) ** sign_bits * 2 ** (exponent - 127) * (1 + mantissa / 2 ** 7)



    # Print the details
    table = rich.table.Table("Descript",   "Eqn",                                                    "Value", title="BFloat16")
    table.add_row("Sign bits",             "bits[0]",                                                str(bits[0]), end_section=True)
    table.add_row("Sign",                  "(-1) ** to_int(bits[0])",                                str(sign), end_section=True)
    table.add_row("Mantissa bits",         "bits[9:16]",                                             str(bits[9:16]),)
    table.add_row("Mantissa",              "to_int(bits[9:16])",                                     str(mantissa),)
    table.add_row("Mantissa Multiplier",   "1 + mantissa / 2 ** 7",                                  str(mantissa_multiplier), end_section=True)
    table.add_row("Exponent bits",         "bits[1:9]",                                              str(bits[1:9]),)
    table.add_row("Exponent",              "to_int(bits[1:9])",                                      str(exponent))
    table.add_row("Resulting Power",       "exponent - 127",                                         str(resulting_power))
    table.add_row("Exponent Factor",       "2 ** resulting_power",                                   str(exponent_factor), end_section=True)
    table.add_row("Result",                f"{sign} * {mantissa_multiplier} * {exponent_factor}",    f"[bold blue]{result}")
    console.print(table)
    print()
    return result

In [305]:
target = math.e

float16bits  = get_raw_bits(torch.tensor(target, dtype=torch.float16),  16)
float32bits  = get_raw_bits(torch.tensor(target, dtype=torch.float32),  32)
bfloat16bits = get_raw_bits(torch.tensor(target, dtype=torch.bfloat16), 16)

float32(float32bits)
bfloat16(bfloat16bits)
float16(float16bits)











2.71875

In [306]:
target = math.pi

float16bits  = get_raw_bits(torch.tensor(target, dtype=torch.float16),  16)
float32bits  = get_raw_bits(torch.tensor(target, dtype=torch.float32),  32)
bfloat16bits = get_raw_bits(torch.tensor(target, dtype=torch.bfloat16), 16)

float32(float32bits)
bfloat16(bfloat16bits)
float16(float16bits)










3.140625

# How to launch a job with multiple nodes on the Mila cluster

The main idea is that `srun` launches the same process multiple times, scaling on a single node or on multiple nodes.

&nbsp;

Substituting `$NUMBER_OF_TASKS_PER_NODE` with the number of GPUs on per node, the following flag dictates how many parallel processes will be launched on each node:

&nbsp;

`--ntasks-per-node=$NUMBER_OF_TASKS_PER_NODE`

&nbsp;

As such, to lauch jobs, do one of the following, substituting `$NUMBER_OF_GPUS_PER_NODE` with the number of GPUs you want to use per node.

&nbsp;

To launch jobs for later, 

&nbsp;

```bash
sbatch script.sh --gres=gpu:$NUMBER_OF_GPUS_PER_NODE --mem=16G --cpus-per-task=6 --ntasks-per-node=$NUMBER_OF_GPUS_PER_NODE
```

&nbsp;


to launch jobs right now:

&nbsp;

```bash
salloc --gres=gpu:$NUMBER_OF_GPUS_PER_NODE --mem=16G --cpus-per-task=6 --ntasks-per-node=$NUMBER_OF_GPUS_PER_NODE
```

&nbsp;


Then, in your `script.sh` or in your interactive shell, launch your job with `srun`:

&nbsp;

```bash
srun python script.py
```

&nbsp;


`srun` will launch the same process multiple times, multiple times per node, and scaling to multiple nodes.

&nbsp;


Each process will be have multiple environment variables set. This allows you to code different behavior for each process.

&nbsp;

- `SLURM_PROCID`, for example, is the global id of the process.

&nbsp;

- `SLURM_LOCALID` is the id of the process on the node.

&nbsp;

- `SLURM_NTASKS_PER_NODE` is the number of processes per node.

&nbsp;

The following command will let you explore what environment variables are set. `env` prints all the environment variables:

&nbsp;

```bash
srun bash -c 'env | grep SLURM'
```

&nbsp;


### Multi Node Training With Pytorch Lightning

&nbsp;

After launching your script with `srun`, you just have to have the following in your trainer code in Pytorch Lightning:

&nbsp;

![Getting Started](lightning_multi_node.png)

&nbsp;

Behind the scenes, Pytorch Lightning will automatically use the SLURM environment variables to launch a server on the first node. It will also use the environment variables to connect to the server on that node, from the other servers, to synchronize the training.

Pytorch Lightning will also automatically use a distributed sampler, so each process will only train on a subset of the data.

&nbsp;

More information here: https://pytorch-lightning.readthedocs.io/en/stable/clouds/cluster_advanced.html#run-on-a-slurm-managed-cluster

&nbsp;

You can set-up something similar on raw Pytorch with `torch.distributed.launch`, but it is more complicated. See an example here https://gist.github.com/TengdaHan/1dd10d335c7ca6f13810fff41e809904

