# Unifying Distributed Training Strategies Under One Scheme

This notebook is a companion to [this paper](parallelsim.html.pdf). It shows how
various forms of parallelism can be rewritten as variants of a single scheduling
algorithm that is parameterized by just two functions: one that maps each unit of
work to a compute worker, and one that maps weights to a storage worker. Changing
only these two functions recovers Pipeline Parallel, Data Parallel, Fully Sharded
Data Parallel, Looped Pipeline Parallel, their variants, and some new forms of
parallelism.
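To make the "two functions" idea concrete, here is a minimal, self-contained sketch. `Work` and `ComputeAndWeightWorkers` below are hypothetical stand-ins for the `parallelsim` types: only the `stage` and `batch` fields appear in this notebook, and the field names `compute` and `weight` are assumptions. The helper `remote_weight_reads` (also hypothetical) counts how many work items run on a worker that does not store the weights they need, which is one way the choice of mapping shows up as communication.

```python
from typing import Callable, NamedTuple


# Hypothetical stand-ins for parallelsim.Work and parallelsim.ComputeAndWeightWorkers.
class Work(NamedTuple):
    stage: int  # pipeline stage being processed
    batch: int  # id of the batch being processed


class ComputeAndWeightWorkers(NamedTuple):
    compute: int  # worker that runs this piece of work (assumed field name)
    weight: int  # worker that stores the weights this work needs (assumed field name)


Strategy = Callable[[Work], ComputeAndWeightWorkers]


def remote_weight_reads(strategy: Strategy, num_stages: int, num_batches: int) -> int:
    """Count work items whose weights live on a different worker than the one computing them."""
    count = 0
    for stage in range(num_stages):
        for batch in range(num_batches):
            placement = strategy(Work(stage=stage, batch=batch))
            if placement.compute != placement.weight:
                count += 1
    return count


def data_parallel(w: Work) -> ComputeAndWeightWorkers:
    # Compute and weights both on the batch's worker.
    return ComputeAndWeightWorkers(w.batch, w.batch)


def fully_sharded(w: Work) -> ComputeAndWeightWorkers:
    # Compute on the batch's worker, weights sharded by stage.
    return ComputeAndWeightWorkers(w.batch, w.stage)


print(remote_weight_reads(data_parallel, 6, 6))
print(remote_weight_reads(fully_sharded, 6, 6))
```

With 6 stages and 6 batches, the data-parallel mapping never reads weights remotely, while the sharded mapping reads them remotely for every item except the 6 where `batch == stage`.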


To run the notebook, execute

```
bazel run //cruise/mlp/robotorch2/lightning/strategies/notebooks:parallelsim_notebook
```

The pipeline diagrams below illustrate various training schedules. In these
diagrams, columns are time steps and each row corresponds to a worker. Each cell
describes what that worker is doing at that time step: its color identifies the
batch being processed, and its value indicates whether the work belongs to the
forward or backward pass and which pipeline stage is being processed.
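Since colors do not survive in plain text, a rough text analogue of these diagrams can help when reading schedules outside the notebook. The sketch below is hypothetical (it is not how `parallelsim` draws its diagrams): it writes the batch id into each cell next to the pass (`F`/`B`) and stage, and marks idle cells with `.`.

```python
from typing import Dict, Tuple

# (worker, time) -> (batch, pass_, stage); pass_ is "F" (forward) or "B" (backward).
Cells = Dict[Tuple[int, int], Tuple[int, str, int]]


def render(cells: Cells, num_workers: int, num_steps: int, width: int = 6) -> str:
    """Render a schedule as text: one row per worker, one column per time step.

    The batch id is written into each cell (e.g. "F1:b0") in place of a color;
    idle cells are shown as ".".
    """
    rows = []
    for worker in range(num_workers):
        cols = []
        for t in range(num_steps):
            entry = cells.get((worker, t))
            if entry is None:
                text = "."
            else:
                batch, pass_, stage = entry
                text = f"{pass_}{stage}:b{batch}"
            cols.append(text.ljust(width))
        rows.append(f"w{worker} " + "".join(cols).rstrip())
    return "\n".join(rows)


# One batch flowing through a two-stage pipeline: forward 0 -> 1, then backward 1 -> 0.
example = {
    (0, 0): (0, "F", 0),
    (1, 1): (0, "F", 1),
    (1, 2): (0, "B", 1),
    (0, 3): (0, "B", 0),
}
print(render(example, num_workers=2, num_steps=4))
```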

```python
from importlib import reload

import parallelsim

# Reload so that edits to the parallelsim module are picked up without restarting the kernel.
reload(parallelsim)

from parallelsim import Work, ComputeAndWeightWorkers


num_workers = 6
num_stages = 6
num_batches = 6
```

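```python
# Data parallelism: batch b runs every stage on worker b, and worker b also holds
# the weights for every stage (a full replica).
def DistributedDataParallel(work: Work):
    return ComputeAndWeightWorkers(work.batch, work.batch)


# One batch per worker.
assert num_batches == num_workers


parallelsim.simulate(num_workers, num_stages, num_batches, DistributedDataParallel)
```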
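One way to see why this mapping is data parallelism is to ask, for each stage, which workers must hold that stage's weights. The sketch below restates the `DistributedDataParallel` rule as a plain function of `(stage, batch)`; `ddp_placement` and `weight_holders` are hypothetical helpers, not part of `parallelsim`.

```python
from collections import defaultdict


def ddp_placement(stage: int, batch: int) -> tuple:
    # Same rule as DistributedDataParallel above: compute and weights on worker `batch`.
    return batch, batch


def weight_holders(placement, num_stages: int, num_batches: int) -> dict:
    """For each stage, the set of workers that must hold that stage's weights."""
    holders = defaultdict(set)
    for stage in range(num_stages):
        for batch in range(num_batches):
            _, weight_worker = placement(stage, batch)
            holders[stage].add(weight_worker)
    return dict(holders)


holders = weight_holders(ddp_placement, num_stages=6, num_batches=6)
print({stage: len(workers) for stage, workers in holders.items()})
```

Every stage ends up on all 6 workers, i.e. each worker keeps a full copy of the model, which is the memory cost that the fully sharded variant removes.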

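```python
# Fully sharded data parallelism: computation is placed as in DDP, but the weights
# for stage s are stored only on worker s % num_workers.
def FullyShardedDistributedDataParallel(work: Work):
    return ComputeAndWeightWorkers(work.batch, work.stage % num_workers)


assert num_batches == num_workers


parallelsim.simulate(
    num_workers, num_stages, num_batches, FullyShardedDistributedDataParallel
)
```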
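Under this mapping each worker stores one stage's weights, so a worker processing its batch must obtain the other stages' weights from their owners (in FSDP implementations this is typically done with an all-gather of parameters before each stage runs). The hypothetical helper below lists, for each compute worker, which stages' weights it has to fetch.

```python
def fsdp_placement(stage: int, batch: int, num_workers: int = 6) -> tuple:
    # Same rule as FullyShardedDistributedDataParallel above.
    return batch, stage % num_workers


def remote_shards(placement, num_stages: int, num_batches: int) -> dict:
    """For each compute worker, the stages whose weights live on another worker."""
    fetch = {}
    for stage in range(num_stages):
        for batch in range(num_batches):
            compute, weight = placement(stage, batch)
            if compute != weight:
                fetch.setdefault(compute, set()).add(stage)
    return fetch


fetch = remote_shards(fsdp_placement, num_stages=6, num_batches=6)
print(fetch[2])
```

Each of the 6 workers needs 5 of the 6 weight shards from its peers; only the shard it owns is local.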

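```python
# GPipe: stage s runs on worker s for every batch, and worker s stores that stage's weights.
def GPipe(work: Work):
    return ComputeAndWeightWorkers(work.stage, work.stage)


# One stage per worker.
assert num_stages == num_workers

parallelsim.simulate(num_workers, num_stages, num_batches, GPipe)
```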
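GPipe never moves weights, since each worker computes only the stage it stores. The cost moves to activations instead: whenever consecutive stages of a batch run on different workers, the forward activations must be sent across. The hypothetical helper below counts those forward hand-offs and contrasts GPipe with the data-parallel rule.

```python
def gpipe_placement(stage: int, batch: int) -> tuple:
    # Same rule as GPipe above: compute and weights on worker `stage`.
    return stage, stage


def activation_handoffs(placement, num_stages: int, num_batches: int) -> int:
    """Count forward-pass boundaries where stage s+1 runs on a different worker than stage s."""
    count = 0
    for batch in range(num_batches):
        for stage in range(num_stages - 1):
            here, _ = placement(stage, batch)
            nxt, _ = placement(stage + 1, batch)
            if here != nxt:
                count += 1
    return count


print(activation_handoffs(gpipe_placement, 6, 6))
print(activation_handoffs(lambda stage, batch: (batch, batch), 6, 6))  # data-parallel rule
```

With 6 stages and 6 batches, GPipe crosses a worker boundary at all 5 stage boundaries of every batch, while the data-parallel rule keeps each batch on one worker.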

```python
# Same GPipe placement, scheduled with the OldestBatchFirst policy.
assert num_stages == num_workers

parallelsim.simulate(
    num_workers,
    num_stages,
    num_batches,
    GPipe,
    parallelsim.OldestBatchFirst(num_stages, num_batches),
)
```
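The internals of `parallelsim.OldestBatchFirst` are not shown in this notebook. As an assumption, the sketch below reads the name as a priority rule: whenever a worker has several runnable items, it runs the one with the lowest (oldest) batch id. It is a toy unit-time list scheduler with the GPipe placement, written from scratch (`deps`, `greedy_schedule`, `forward_first` and `oldest_batch_first` are all hypothetical), that compares this rule against a forward-first rule (every forward before any backward, as in the GPipe schedule) and tracks how many batches' activations each worker keeps live.

```python
from typing import Callable, List, Tuple

# A work item is (pass_, batch, stage), with pass_ either "F" (forward) or "B" (backward).
Item = Tuple[str, int, int]


def deps(item: Item, num_stages: int) -> List[Item]:
    """Items that must finish before `item` can start."""
    pass_, batch, stage = item
    if pass_ == "F":
        return [("F", batch, stage - 1)] if stage > 0 else []
    if stage == num_stages - 1:
        # The last stage's backward starts from its own forward output.
        return [("F", batch, stage)]
    return [("B", batch, stage + 1)]


def greedy_schedule(num_stages: int, num_batches: int, priority: Callable[[Item], tuple]):
    """Unit-time list scheduling with a GPipe placement (stage s runs on worker s).

    At every step each worker runs the ready item with the smallest priority key.
    Returns (number of steps, peak number of batches whose activations each worker holds).
    """
    pending = {(p, b, s) for p in "FB" for b in range(num_batches) for s in range(num_stages)}
    done = set()
    peak = [0] * num_stages
    steps = 0
    while pending:
        ready = [i for i in pending if all(d in done for d in deps(i, num_stages))]
        finished = []
        for worker in range(num_stages):
            mine = [i for i in ready if i[2] == worker]
            if mine:
                finished.append(min(mine, key=priority))
        done.update(finished)
        pending.difference_update(finished)
        for s in range(num_stages):
            # Activations stay live from a stage's forward until its backward finishes.
            live = sum(
                1 for b in range(num_batches) if ("F", b, s) in done and ("B", b, s) not in done
            )
            peak[s] = max(peak[s], live)
        steps += 1
    return steps, peak


def forward_first(item: Item) -> tuple:
    return (item[0] == "B", item[1])  # all forwards before any backward, then by batch


def oldest_batch_first(item: Item) -> tuple:
    return (item[1],)  # lowest (oldest) batch id first


print(greedy_schedule(2, 4, forward_first))
print(greedy_schedule(2, 4, oldest_batch_first))
```

On this toy pipeline (2 stages, 4 batches) both rules finish in the same number of steps, but the oldest-first rule lets backward passes of older batches run as soon as they are ready, so fewer batches' activations are held at once on each worker.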

```python
# Workers are split into groups of group_size; batch b runs on the group starting at
# worker (group_size * b) % num_workers and loops over that group's workers stage by stage.
group_size = 3
num_batches = 12


def LoopedPipelineParallelism(work: Work):
    def h(x, y):
        return (group_size * y % num_workers) + (x % group_size)

    worker = h(work.stage, work.batch)
    return ComputeAndWeightWorkers(worker, worker)


parallelsim.simulate(
    num_workers,
    num_stages,
    num_batches,
    LoopedPipelineParallelism,
    parallelsim.OldestBatchFirst(num_stages, num_batches),
)
```
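Tracing the formula by hand shows the "looped" structure. The sketch below restates `h(stage, batch)` as a standalone function (`looped_worker` and `route` are hypothetical names) and prints the sequence of workers each batch visits, plus how much work lands on each worker with 12 batches. Note that the result stays within `range(num_workers)` when `group_size` divides `num_workers`, as it does here.

```python
from collections import Counter


def looped_worker(stage: int, batch: int, group_size: int = 3, num_workers: int = 6) -> int:
    # Same formula as h(stage, batch) in LoopedPipelineParallelism.
    return (group_size * batch % num_workers) + (stage % group_size)


def route(worker_fn, batch: int, num_stages: int = 6) -> list:
    """Sequence of workers that a batch visits, stage by stage."""
    return [worker_fn(stage, batch) for stage in range(num_stages)]


print(route(looped_worker, 0))
print(route(looped_worker, 1))

# Number of (stage, batch) work items assigned to each worker.
load = Counter(looped_worker(s, b) for s in range(6) for b in range(12))
print(load)
```

Even-numbered batches loop twice through workers 0 to 2 and odd-numbered batches loop twice through workers 3 to 5, so each worker ends up with the same number of work items.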

```python
# Same compute placement as LoopedPipelineParallelism, but the weights for stage s are
# stored on a single worker, h(s, s), instead of on every worker that computes stage s.
group_size = 3


def FullyShardedLoopedPipelineParallelism(work: Work):
    def h(x, y):
        return (group_size * y % num_workers) + (x % group_size)

    return ComputeAndWeightWorkers(h(work.stage, work.batch), h(work.stage, work.stage))


parallelsim.simulate(
    num_workers,
    num_stages,
    num_batches,
    FullyShardedLoopedPipelineParallelism,
    parallelsim.OldestBatchFirst(num_stages, num_batches),
)
```
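With `group_size = 3` and 6 workers, the weight placement `h(s, s)` happens to be a permutation of the workers, so each worker stores exactly one stage. The compute worker `h(s, b)` and the weight worker `h(s, s)` share the `s % group_size` term and, for these values, differ only through the parity of `b` versus `s`; the sketch below (re-deriving `h` standalone) checks both facts and counts how many of the 72 work items read their weights remotely.

```python
def h(x: int, y: int, group_size: int = 3, num_workers: int = 6) -> int:
    # Same formula as h in the looped pipeline cells above.
    return (group_size * y % num_workers) + (x % group_size)


num_stages, num_batches = 6, 12

# Worker that stores each stage's weights.
weight_home = [h(stage, stage) for stage in range(num_stages)]
print(weight_home)

# Work items whose compute worker differs from the worker storing their weights.
remote = sum(
    1
    for stage in range(num_stages)
    for batch in range(num_batches)
    if h(stage, batch) != h(stage, stage)
)
print(remote)
```

Half of the work items (those where the batch and stage have different parity) fetch weights from another worker.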

```python
# Like LoopedPipelineParallelism, but a batch's stages cycle through all num_workers
# workers, starting at worker (group_size * b) % num_workers.
group_size = 3


def CyclicLoopedPipelineParallelism(work: Work):
    def h(x, y):
        return ((group_size * y % num_workers) + x) % num_workers

    worker = h(work.stage, work.batch)
    return ComputeAndWeightWorkers(worker, worker)


parallelsim.simulate(
    num_workers,
    num_stages,
    num_batches,
    CyclicLoopedPipelineParallelism,
    parallelsim.OldestBatchFirst(num_stages, num_batches),
)
```
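Compared with the looped variant, the cyclic formula wraps around all workers instead of staying inside a group. The sketch below restates it standalone (`cyclic_worker` is a hypothetical name) and checks that every batch visits each of the 6 workers exactly once, starting at an offset that depends on the batch.

```python
from collections import Counter


def cyclic_worker(stage: int, batch: int, group_size: int = 3, num_workers: int = 6) -> int:
    # Same formula as h(stage, batch) in CyclicLoopedPipelineParallelism.
    return ((group_size * batch % num_workers) + stage) % num_workers


num_stages, num_batches = 6, 12

routes = {b: [cyclic_worker(s, b) for s in range(num_stages)] for b in range(num_batches)}
print(routes[0], routes[1])

# Number of (stage, batch) work items assigned to each worker.
load = Counter(w for route in routes.values() for w in route)
print(load)
```

Because each batch's route is a rotation of `[0, 1, ..., 5]`, every worker receives one stage of every batch, giving an even load of 12 work items per worker.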