# PyTorch Distributed

# Introduction

The distributed features in `torch.distributed` can be grouped into three components:

[Distributed Data-Parallel Training](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html) (DDP) 

DDP is a widely adopted single-program multiple-data training paradigm. With DDP, the model is replicated on every process, and every model replica is fed a different set of input data samples. DDP takes care of gradient communication between processes to keep the model replicas synchronized, and overlaps it with gradient computation to hide the communication overhead and speed up training.

[RPC-Based Distributed Training](https://pytorch.org/docs/stable/rpc.html) (RPC) 

RPC supports general training structures that cannot fit into data-parallel training such as distributed pipeline parallelism, parameter server paradigm, and combinations of DDP with other training paradigms. It helps manage remote object lifetime and extends the [autograd engine](https://pytorch.org/docs/stable/autograd.html) beyond machine boundaries.

Beyond DDP, RPC offers a more general distributed approach, e.g., for distributed pipeline parallelism and the parameter server pattern.


[Collective Communication](https://pytorch.org/docs/stable/distributed.html) (c10d) 

The library supports sending tensors across processes within a group. It offers both collective communication APIs (e.g., `all_reduce` and `all_gather`) and P2P communication APIs (e.g., `send` and `isend`). DDP and RPC ([ProcessGroup Backend](https://pytorch.org/docs/stable/rpc.html#process-group-backend)) are built on c10d, where the former uses collective communications and the latter uses P2P communications. Usually, developers do not need to directly use this raw communication API, as the DDP and RPC APIs can serve many distributed training scenarios. However, there are use cases where this API is still helpful. One example would be distributed parameter averaging, where applications would like to compute the average values of all model parameters after the backward pass instead of using DDP to communicate gradients. This can decouple communications from computations and allow finer-grain control over what to communicate, but on the other hand, it also gives up the performance optimizations offered by DDP. [Writing Distributed Applications with PyTorch](https://pytorch.org/tutorials/intermediate/dist_tuto.html) shows examples of using c10d communication APIs.

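c10d provides the lower-level collective and P2P communication APIs, giving users a more flexible interface for requirements that DDP and RPC do not cover.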
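The parameter-averaging use case described above can be sketched with raw c10d calls. This is a minimal sketch, assuming the `gloo` backend on CPU and a toy `nn.Linear` model; `average_parameters`, `worker`, and `find_free_port` are illustrative names, not PyTorch APIs. Each rank runs backward and an optimizer step locally without DDP, then the parameters are averaged with `all_reduce`. `ReduceOp.SUM` followed by a division is used because `ReduceOp.AVG` is only available with the NCCL backend.

```python
import os
import socket

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn


def find_free_port():
    # Ask the OS for an unused TCP port for the rendezvous
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("localhost", 0))
        return s.getsockname()[1]


def average_parameters(model):
    # Sum every parameter across ranks, then divide by the number of ranks
    world_size = dist.get_world_size()
    for p in model.parameters():
        dist.all_reduce(p.data, op=dist.ReduceOp.SUM)
        p.data /= world_size


def worker(rank, world_size, port):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(port)
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    torch.manual_seed(0)  # same initial weights on every rank
    model = nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    torch.manual_seed(rank)  # different data on every rank
    x, y = torch.randn(16, 4), torch.randn(16, 1)
    nn.functional.mse_loss(model(x), y).backward()  # local gradients only
    optimizer.step()  # the replicas now differ
    average_parameters(model)  # the replicas are identical again

    # Check that every rank holds the same weights after averaging
    gathered = [torch.zeros_like(model.weight) for _ in range(world_size)]
    dist.all_gather(gathered, model.weight.detach())
    assert all(torch.allclose(g, model.weight) for g in gathered)
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    mp.spawn(worker, args=(world_size, find_free_port()), nprocs=world_size, join=True)
```

Unlike DDP, nothing here overlaps communication with computation, which is the performance trade-off mentioned above.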

# Data Parallel Training

PyTorch provides several options for data-parallel training. For applications that gradually grow from simple to complex and from prototype to production, the common development trajectory would be:

* Single machine, single GPU, no distribution: Use single-device training if the data and model can fit in one GPU, and training speed is not a concern.
* Single machine, multiple GPUs, multi-threaded: Use single-machine multi-GPU [DataParallel](https://pytorch.org/docs/stable/generated/torch.nn.DataParallel.html) to make use of multiple GPUs on a single machine to speed up training with minimal code changes.
* Single machine, multiple GPUs, multi-process: Use single-machine multi-GPU [DistributedDataParallel](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html), if you would like to further speed up training and are willing to write a little more code to set it up.
* Multiple machines, multiple GPUs, multi-process: Use multi-machine DistributedDataParallel and the [launching script](https://github.com/pytorch/examples/blob/master/distributed/ddp/README.md), if the application needs to scale across machine boundaries.
* Elastic multi-machine training: Use [torch.distributed.elastic](https://pytorch.org/docs/stable/distributed.elastic.html) to launch distributed training if errors (e.g., out-of-memory) are expected or if resources can join and leave dynamically during training.

## `torch.nn.DataParallel`

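The simplest parallelization scheme for a single machine with multiple GPUs, requiring only a few lines of code changes. Its performance, however, is usually not optimal: it replicates the model onto every GPU on each forward pass, and because `DataParallel` is multi-threaded, it is also constrained by the GIL.

Related resources:

* API reference: [DataParallel](https://pytorch.org/docs/stable/generated/torch.nn.DataParallel.html)
* Tutorial: [Optional: Data Parallelism](https://pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html)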

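```python
import torch
import torch.nn as nn

# Requires at least 2 GPUs
model = nn.Linear(5, 10)
devices = [torch.device("cuda", i) for i in range(2)]
# The wrapped module's parameters must live on device_ids[0]
model = nn.DataParallel(model.to(devices[0]), device_ids=devices)

# The batch of 8 is split along dim 0 across the GPUs, and the outputs
# are gathered back on devices[0]
data = torch.randn(8, 5, device=devices[0])
output = model(data)
print(output.shape)
```

```
torch.Size([8, 10])
```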
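To make the per-iteration cost of `DataParallel` concrete, its forward pass can be approximated with the public building blocks in `torch.nn.parallel` (`scatter`, `replicate`, `parallel_apply`, `gather`), which is roughly what the functional `torch.nn.parallel.data_parallel` does. This is a simplified sketch (no keyword arguments, no single-GPU shortcut) that needs at least two GPUs; `dp_forward` is a hypothetical name.

```python
import torch
import torch.nn as nn
from torch.nn.parallel import gather, parallel_apply, replicate, scatter


def dp_forward(module, inputs, device_ids, output_device=None):
    output_device = device_ids[0] if output_device is None else output_device
    # 1. Split the batch along dim 0, one chunk per GPU
    chunks = scatter(inputs, device_ids)
    # 2. Copy the module onto every GPU; DataParallel repeats this on every forward
    replicas = replicate(module, device_ids[: len(chunks)])
    # 3. Run each replica on its chunk in its own Python thread (hence the GIL)
    outputs = parallel_apply(replicas, chunks)
    # 4. Concatenate the per-GPU outputs on the output device
    return gather(outputs, output_device)


if torch.cuda.device_count() >= 2:
    model = nn.Linear(5, 10).to("cuda:0")
    out = dp_forward(model, torch.randn(8, 5, device="cuda:0"), device_ids=[0, 1])
    print(out.shape, out.device)
```

Steps 1, 2, and 4 are exactly the replication, scattering, and gathering overheads that DDP avoids.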


## `torch.nn.parallel.DistributedDataParallel`

Compared with `DataParallel`, `DistributedDataParallel` needs a few more setup steps, mainly calling `init_process_group`. In addition, DDP broadcasts the model only once, at construction time, instead of copying the model on every forward pass as DP does.

References for DDP:

* [Y] [DDP notes](https://pytorch.org/docs/stable/notes/ddp.html): walks through a simple example of the DDP workflow and the changes relative to single-GPU code, and also introduces some of the internal implementation.
* [Y] [Getting Started with Distributed Data Parallel](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html): explains DDP usage in more depth, including how to save checkpoints under DDP and how to combine DDP with model parallelism (MP), and finally introduces `torchrun` for initializing DDP.
* [Y] [Writing Distributed applications with pytorch](https://pytorch.org/tutorials/intermediate/dist_tuto.html): introduces the communication primitives in `torch.distributed`, including point-to-point `send/recv` and `isend/irecv` and the collective patterns (Scatter/Gather/Reduce/AllReduce/Broadcast/All-Gather), and finally implements a simple synchronous distributed SGD training loop.
* [Y] [Launching and configuring distributed data parallel applications](https://github.com/pytorch/examples/blob/main/distributed/ddp/README.md): describes initializing DDP with `torch.distributed.launch`. This approach appears to have been fully superseded by `torchrun`, since internally it also calls `torch.distributed.run`.
* [N] [Single Machine Model Parallel Best Practices](https://pytorch.org/tutorials/intermediate/model_parallel_tutorial.html)
* [N] [Shard Optimizer States With ZeroRedundancyOptimizer](https://pytorch.org/tutorials/recipes/zero_redundancy_optimizer.html): Distributed Optimizer
* [N] [Distributed Training with Uneven Inputs Using the Join Context Manager](https://pytorch.org/tutorials/advanced/generic_join.html)

A simple `DistributedDataParallel` example:

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP


def setup(rank, world_size):
    # Rendezvous address that every process uses to find rank 0
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29544"
    # "gloo" keeps the example simple; "nccl" is the usual choice for GPU training
    dist.init_process_group("gloo", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


def example(rank, world_size):
    setup(rank, world_size)

    # On a single machine, rank i uses GPU i
    model = nn.Linear(10, 10).to(rank)
    ddp_model = DDP(model, device_ids=[rank])
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    outputs = ddp_model(torch.randn(20, 10).to(rank))
    labels = torch.randn(20, 10).to(rank)
    # Gradients are all-reduced across processes during backward()
    loss_fn(outputs, labels).backward()
    optimizer.step()

    cleanup()


def main():
    world_size = 2  # requires 2 GPUs
    mp.spawn(example, args=(world_size,), nprocs=world_size, join=True)


if __name__ == "__main__":
    # Multi-process code cannot run directly in a notebook; run it with the Python interpreter
    main()
```

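As the example shows, DDP hides all the underlying distributed communication. When `backward()` is called, gradients are synchronized across processes while the backward computation is still running, and by the time `backward()` returns, each parameter's `param.grad` already holds the synchronized gradient.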
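This claim can be checked directly. The sketch below is a CPU variant of the example above, assuming the `gloo` backend (so `device_ids` is omitted); it deliberately seeds each rank differently so that both the initial weights and the input data differ per rank. `check_ddp_sync`, `gather_all`, and `find_free_port` are illustrative helpers.

```python
import os
import socket

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def find_free_port():
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("localhost", 0))
        return s.getsockname()[1]


def gather_all(t):
    # Collect one copy of `t` from every rank
    out = [torch.zeros_like(t) for _ in range(dist.get_world_size())]
    dist.all_gather(out, t)
    return out


def check_ddp_sync(rank, world_size, port):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(port)
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    torch.manual_seed(rank)  # different initial weights on purpose
    model = DDP(nn.Linear(4, 2))  # the constructor broadcasts rank 0's state_dict
    weight = model.module.weight
    assert all(torch.equal(w, weight.detach()) for w in gather_all(weight.detach()))

    x = torch.randn(8, 4)  # different data on each rank
    model(x).sum().backward()  # allreduce overlaps with the backward pass
    grad = weight.grad
    # After backward() returns, every rank holds the same averaged gradient
    assert all(torch.allclose(g, grad) for g in gather_all(grad))
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    mp.spawn(check_ddp_sync, args=(world_size, find_free_port()), nprocs=world_size, join=True)
```

The first assertion also shows the one-time broadcast at construction: despite different seeds, all ranks start from rank 0's weights.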

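### rank, world_size, device

As the code above shows, each process needs to know the following variables:

* `rank` is the unique identifier of the current process in the distributed environment. Each process has a distinct `rank` that distinguishes it from the others, ranging from `0` to `world_size - 1`.
* `world_size` is the total number of processes in the distributed environment, i.e., the scale of the job (how many processes take part in training).
* `device` is the ID of the GPU that this process uses on its own node.

On a single machine, `device` can be computed directly from `rank`, but across multiple nodes `rank` is global while GPU IDs are local to each node, so the corresponding `device` must be computed explicitly.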
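A minimal sketch of how these values are typically obtained. When a job is launched with `torchrun`, each worker receives `RANK`, `WORLD_SIZE`, and `LOCAL_RANK` environment variables, and it is `LOCAL_RANK`, not the global `RANK`, that selects the GPU on the node. The helper names and the single-process defaults are illustrative assumptions; `global_rank` shows the usual arithmetic when every node runs the same number of processes.

```python
import os

import torch


def global_rank(node_rank: int, local_rank: int, procs_per_node: int) -> int:
    # Ranks are numbered node by node: node 0 owns 0..n-1, node 1 owns n..2n-1, ...
    return node_rank * procs_per_node + local_rank


def dist_env():
    # torchrun sets these variables; fall back to a single-process run otherwise
    rank = int(os.environ.get("RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    # The GPU index is local to the node, so it comes from LOCAL_RANK
    if torch.cuda.is_available():
        device = torch.device("cuda", local_rank)
    else:
        device = torch.device("cpu")
    return rank, world_size, local_rank, device


# 2 nodes x 4 GPUs: the second process on node 1 has global rank 5 but uses cuda:1
print(global_rank(node_rank=1, local_rank=1, procs_per_node=4))  # 5
```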

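### Skewed Processing Speeds

In DDP, the constructor, the forward pass, and the backward pass are distributed synchronization points, so all processes are expected to reach each of them at roughly the same time. If some processes run much faster or slower than the others, the faster ones block at the synchronization point for a long time, and if the wait exceeds the timeout, an error is raised. Users are therefore responsible for balancing the workload across nodes and GPUs.

Some desynchronization is still unavoidable, due to network delays, resource contention, or other unexpected issues, so a reasonable `timeout` value should be passed when calling `init_process_group`.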

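### Saving and Loading Checkpoints with DDP

To avoid every process writing the same checkpoint, only the rank 0 process saves it; when loading, every process loads the model from that checkpoint.

Code to save a checkpoint: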

```python
import tempfile

import torch
import torch.distributed as dist

# `rank` and `ddp_model` come from the surrounding DDP worker
CHECKPOINT_PATH = tempfile.gettempdir() + "/model.checkpoint"
if rank == 0:
    # All processes should see the same parameters, as they all start from the
    # same random parameters and gradients are synchronized in backward passes.
    # Therefore, saving it in one process is sufficient.
    torch.save(ddp_model.state_dict(), CHECKPOINT_PATH)

# Use a barrier() to make sure that the other processes load the model only
# after process 0 has saved it.
dist.barrier()
```

Code to load a checkpoint:

```python
import torch

# Configure map_location properly: remap tensors saved from cuda:0 (rank 0)
# to this rank's GPU. This assumes each rank uses the GPU whose index equals its rank.
map_location = {"cuda:%d" % 0: "cuda:%d" % rank}
state_dict = torch.load(CHECKPOINT_PATH, map_location=map_location)
ddp_model.load_state_dict(state_dict)
```
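Putting the two fragments together, a minimal end-to-end sketch on CPU with `gloo` might look like the following. The file path, the helper names, and `map_location="cpu"` are assumptions for a CPU-only run; on GPUs, use the rank-specific `map_location` shown above. The keys of `ddp_model.state_dict()` carry a `module.` prefix, so this sketch saves `ddp_model.module.state_dict()` instead, which also loads into a plain, non-DDP model.

```python
import os
import socket
import tempfile

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def find_free_port():
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("localhost", 0))
        return s.getsockname()[1]


def save_then_load(rank, world_size, port, path):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(port)
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    ddp_model = DDP(nn.Linear(3, 3))

    if rank == 0:
        # A single writer; the inner state_dict has no "module." prefix
        torch.save(ddp_model.module.state_dict(), path)
    dist.barrier()  # nobody reads before rank 0 has finished writing

    state_dict = torch.load(path, map_location="cpu")
    ddp_model.module.load_state_dict(state_dict)
    assert "weight" in state_dict and "module.weight" in ddp_model.state_dict()

    dist.barrier()  # everyone has loaded before the file is removed
    if rank == 0:
        os.remove(path)
    dist.destroy_process_group()


if __name__ == "__main__":
    ckpt = os.path.join(tempfile.gettempdir(), "ddp_demo.checkpoint")
    world_size = 2
    mp.spawn(save_then_load, args=(world_size, find_free_port(), ckpt), nprocs=world_size, join=True)
```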

### Internal Design

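* **Prerequisite**: DDP relies on a c10d `ProcessGroup` for communication, so a process group must be created (by calling `init_process_group`) before a DDP instance is constructed.
* **Construction**:
    * A DDP instance is built from a local torch module, but it immediately broadcasts the `state_dict()` of the rank 0 process to all other processes, so that every process starts from identical model parameters.
    * Each DDP process creates a `Reducer`, which is responsible for synchronizing gradients later on. The `Reducer` organizes the parameters into buckets and performs each distributed reduction on one bucket at a time.
    * The `Reducer` registers autograd hooks, one per parameter; these hooks fire during the backward pass.
* **Forward Pass**: If `find_unused_parameters` is set to `True`, DDP traverses the autograd graph starting from the output, because the actual computation may involve only a subgraph of the model. Parameters that did not participate are marked as ready for reduction in advance.
* **Backward Pass**: `backward()` is called directly on the loss tensor, outside of DDP's control, so DDP relies on the hooks registered at construction time.
    * When a parameter's gradient has been computed, the corresponding DDP hook fires and marks that gradient as ready for reduction.
    * When all gradients in a bucket are ready, the `Reducer` launches an asynchronous `allreduce` on that bucket to average the gradients across all processes.
    * The `Reducer`s on all processes launch `allreduce` in the same bucket order, by bucket index, rather than in the order in which buckets become ready.
* **Optimizer Step**: From the optimizer's point of view, it is optimizing an ordinary local model; nothing is different.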

![](./images/ddp_reducer.png)
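The bucketing and ordering rules above can be illustrated with a toy, pure-Python model of the `Reducer`. This is not PyTorch's implementation: `assign_buckets` and `ToyReducer` are hypothetical names, the byte cap stands in for DDP's `bucket_cap_mb` argument, and the real `Reducer` handles many more details. Parameters are walked in reverse registration order because gradients tend to become ready roughly in reverse order during the backward pass.

```python
from typing import List


def assign_buckets(param_sizes: List[int], cap_bytes: int) -> List[List[int]]:
    # Walk parameters in reverse order, starting a new bucket when the cap would be exceeded
    buckets, current, current_bytes = [], [], 0
    for idx in reversed(range(len(param_sizes))):
        size = param_sizes[idx]
        if current and current_bytes + size > cap_bytes:
            buckets.append(current)
            current, current_bytes = [], 0
        current.append(idx)
        current_bytes += size
    if current:
        buckets.append(current)
    return buckets


class ToyReducer:
    def __init__(self, buckets: List[List[int]]):
        self.buckets = buckets
        self.bucket_of = {p: b for b, params in enumerate(buckets) for p in params}
        self.pending = [set(params) for params in buckets]
        self.next_bucket = 0
        self.launched: List[int] = []  # bucket indices in "allreduce" launch order

    def mark_ready(self, param_idx: int) -> None:
        # Plays the role of the per-parameter autograd hook
        self.pending[self.bucket_of[param_idx]].discard(param_idx)
        # Launch strictly by bucket index, never out of order
        while self.next_bucket < len(self.buckets) and not self.pending[self.next_bucket]:
            self.launched.append(self.next_bucket)
            self.next_bucket += 1


buckets = assign_buckets([4, 4, 4, 4], cap_bytes=8)
print(buckets)  # [[3, 2], [1, 0]]

reducer = ToyReducer(buckets)
reducer.mark_ready(1)
reducer.mark_ready(0)
print(reducer.launched)  # [] : bucket 1 is complete but must wait for bucket 0
reducer.mark_ready(3)
reducer.mark_ready(2)
print(reducer.launched)  # [0, 1]
```

Launching in a fixed bucket order matters because every rank must issue the same collectives in the same sequence; otherwise ranks could pair up mismatched allreduces.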

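## Differences Between `DataParallel` and `DistributedDataParallel`

1. `DataParallel` is single-process and multi-threaded, which restricts it to a single node, whereas `DistributedDataParallel` is multi-process and works both within a single node and across nodes. `DataParallel` is generally slower than `DistributedDataParallel`, because its threads contend for the GIL and because every iteration requires replicating the model, scattering the inputs, and gathering the outputs.
2. `DistributedDataParallel` works well with model parallelism, where each DDP process runs a model split across several GPUs; `DataParallel` does not support this.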

## `torch.distributed.elastic`

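`torchrun` (the command-line entry point of `torch.distributed.run`) is the launcher built on top of `torch.distributed.elastic`.

As models and datasets grow, fault tolerance becomes essential in distributed training: the more nodes there are, the more likely some processes are to hit problems such as OOM or I/O errors. DDP alone cannot handle such single-process failures, because it requires all processes in the group to stay almost fully synchronized; once one process dies, the remaining processes typically hang in `AllReduce`.

The `torch.distributed.elastic` module adds fault tolerance, so distributed training can run on a dynamic pool of nodes.

Module documentation: https://pytorch.org/docs/stable/distributed.elastic.html

An explanation of how PyTorch elastic training works (Zhihu, in Chinese): https://zhuanlan.zhihu.com/p/519410235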
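A training script meant for `torchrun` differs from the `mp.spawn` example in two ways: `init_process_group` reads `RANK`, `WORLD_SIZE`, `MASTER_ADDR`, and `MASTER_PORT` from the environment that `torchrun` sets, and, because an elastic restart relaunches every worker from scratch, the script should resume from a snapshot. This is a minimal sketch, assuming CPU and `gloo` (on GPUs, `LOCAL_RANK` would select the device), a hypothetical snapshot file `snapshot.pt`, and a fallback to a single local process when run without `torchrun`. The launch commands in the comments use standard `torchrun` flags; `HOST` and `JOB_ID` are placeholders.

```python
# Launch examples (HOST and JOB_ID are placeholders):
#   torchrun --standalone --nproc_per_node=2 train.py
#   torchrun --nnodes=1:4 --nproc_per_node=2 --max_restarts=3 \
#       --rdzv_id=JOB_ID --rdzv_backend=c10d --rdzv_endpoint=HOST:29400 train.py
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

SNAPSHOT = "snapshot.pt"  # hypothetical path shared by all restarts


def train(snapshot_path=SNAPSHOT, total_epochs=3):
    # These defaults only matter when the script is run without torchrun
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    os.environ.setdefault("RANK", "0")
    os.environ.setdefault("WORLD_SIZE", "1")
    dist.init_process_group("gloo")  # rank and world_size come from the env

    model = nn.Linear(4, 1)
    start_epoch = 0
    if os.path.exists(snapshot_path):
        # After a restart, every worker resumes from the last snapshot
        snap = torch.load(snapshot_path, map_location="cpu")
        model.load_state_dict(snap["model"])
        start_epoch = snap["epoch"] + 1
    ddp_model = DDP(model)
    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.01)

    for epoch in range(start_epoch, total_epochs):
        optimizer.zero_grad()
        loss = ddp_model(torch.randn(8, 4)).pow(2).mean()
        loss.backward()
        optimizer.step()
        if dist.get_rank() == 0:
            torch.save({"model": model.state_dict(), "epoch": epoch}, snapshot_path)
        dist.barrier()  # keep ranks in step with the snapshot writer

    dist.destroy_process_group()
    return start_epoch


if __name__ == "__main__":
    train()
```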

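# RPC-Based Distributed Training

RPC-based distributed training provides a more general distributed mechanism for scenarios that do not fit `data parallelism`, such as the parameter server and pipeline parallelism paradigms, and reinforcement learning applications with multiple observers and agents.

`torch.distributed.rpc` consists of four main components: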

* [RPC](https://pytorch.org/docs/stable/rpc.html#rpc) supports running a given function on a remote worker.
* [RRef](https://pytorch.org/docs/stable/rpc.html#rref) helps to manage the lifetime of a remote object. The reference counting protocol is presented in the [RRef notes](https://pytorch.org/docs/stable/rpc/rref.html#remote-reference-protocol).
* [Distributed Autograd](https://pytorch.org/docs/stable/rpc.html#distributed-autograd-framework) extends the autograd engine beyond machine boundaries. Please refer to [Distributed Autograd Design](https://pytorch.org/docs/stable/rpc/distributed_autograd.html#distributed-autograd-design) for more details.
* [Distributed Optimizer](https://pytorch.org/docs/stable/rpc.html#module-torch.distributed.optim) automatically reaches out to all participating workers to update parameters using gradients computed by the distributed autograd engine.
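A minimal sketch of the first two components, `RPC` and `RRef`, assuming the default TensorPipe backend on a single machine; `add_tensors`, `Counter`, `run_worker`, and `find_free_port` are hypothetical helpers. `rpc_sync` blocks until the result arrives, `rpc_async` returns a future, and `remote` creates an object on the callee and returns an `RRef` to it. Functions and classes sent over RPC must be importable on the callee, which is why they are defined at module level.

```python
import os
import socket

import torch
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp


def find_free_port():
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("localhost", 0))
        return s.getsockname()[1]


def add_tensors(a, b):
    # Executed on the callee
    return a + b


class Counter:
    # Lives on the callee; the caller only holds an RRef to it
    def __init__(self):
        self.value = 0

    def increment(self, n):
        self.value += n
        return self.value


def run_worker(rank, world_size, port):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(port)
    rpc.init_rpc(f"worker{rank}", rank=rank, world_size=world_size)
    if rank == 0:
        callee = f"worker{world_size - 1}"
        # Blocking call
        result = rpc.rpc_sync(callee, add_tensors, args=(torch.ones(2), torch.ones(2)))
        assert torch.equal(result, torch.full((2,), 2.0))
        # Non-blocking call that returns a future
        fut = rpc.rpc_async(callee, torch.add, args=(torch.ones(2), 3))
        assert torch.equal(fut.wait(), torch.full((2,), 4.0))
        # Remote object creation plus method calls through the RRef
        counter_rref = rpc.remote(callee, Counter)
        counter_rref.rpc_sync().increment(5)
        assert counter_rref.rpc_sync().increment(1) == 6
    # Graceful shutdown waits until all workers are done
    rpc.shutdown()


if __name__ == "__main__":
    world_size = 2
    mp.spawn(run_worker, args=(world_size, find_free_port()), nprocs=world_size, join=True)
```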

The main tutorials are:

1. The [Getting Started with Distributed RPC Framework](https://pytorch.org/tutorials/intermediate/rpc_tutorial.html) tutorial first uses a simple Reinforcement Learning (RL) example to demonstrate `RPC` and `RRef`. Then, it applies a basic distributed model parallelism to an RNN example to show how to use `distributed autograd` and `distributed optimizer`.
2. The [Implementing a Parameter Server Using Distributed RPC Framework](https://pytorch.org/tutorials/intermediate/rpc_param_server_tutorial.html) tutorial borrows the spirit of HogWild! training and applies it to an asynchronous parameter server (PS) training application.
3. The [Distributed Pipeline Parallelism Using RPC](https://pytorch.org/tutorials/intermediate/dist_pipeline_parallel_tutorial.html) tutorial extends the single-machine pipeline parallel example (presented in Single-Machine Model Parallel Best Practices) to a distributed environment and shows how to implement it using RPC.
4. The [Implementing Batch RPC Processing Using Asynchronous Executions](https://pytorch.org/tutorials/intermediate/rpc_async_execution.html) tutorial demonstrates how to implement RPC batch processing using the `@rpc.functions.async_execution` decorator, which can help speed up inference and training. It uses RL and PS examples similar to those in the above tutorials 1 and 2.
5. The [Combining Distributed DataParallel with Distributed RPC Framework](https://pytorch.org/tutorials/advanced/rpc_ddp_tutorial.html) tutorial demonstrates how to combine DDP with RPC to train a model using distributed data parallelism combined with distributed model parallelism.

# Communication Primitives

https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/DL2/High-performant_DL/Multi_GPU/hpdlmultigpu.html

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def all_reduce(rank):
    world_size = dist.get_world_size()
    t = torch.ones((5, 5), device=rank) * rank
    # Every rank ends up with t = t_0 + t_1 + ... + t_{n-1}
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    assert t.mean().item() == sum(range(world_size))  # 0+1+2+3 = 6 for 4 ranks


def reduce(rank):
    world_size = dist.get_world_size()
    t = torch.ones((5, 5), device=rank) * rank
    dist.reduce(t, dst=0, op=dist.ReduceOp.SUM)
    # Only dst is guaranteed to hold the result; the contents of t on the
    # other ranks should not be relied on (with gloo they may be overwritten)
    if rank == 0:
        assert t.mean().item() == sum(range(world_size))


def broadcast(rank):
    src = dist.get_world_size() - 1
    t = torch.ones((5, 5), device=rank) * rank
    # Copy t from rank `src` (rank 3 with 4 processes) to every other rank
    dist.broadcast(t, src=src)
    assert t.mean().item() == src


def all_gather(rank):
    world_size = dist.get_world_size()
    t = torch.ones((5, 5), device=rank) * rank
    # One pre-allocated output buffer per rank
    outputs = [torch.zeros((5, 5), device=rank) for _ in range(world_size)]
    dist.all_gather(outputs, t)
    gathered = torch.cat(outputs, dim=0)
    assert gathered.shape == torch.Size([world_size * 5, 5])
    assert gathered.mean().item() == sum(range(world_size)) / world_size


def reduce_scatter(rank):
    world_size = dist.get_world_size()
    t = torch.ones((world_size * 5, 5), device=rank) * rank
    # Chunk i from every rank is summed and delivered to rank i
    chunks = list(torch.split(t, 5, dim=0))
    out = torch.zeros((5, 5), device=rank)
    # Supported by NCCL; gloo may not support it, depending on the version
    dist.reduce_scatter(out, chunks, op=dist.ReduceOp.SUM)
    assert out.mean().item() == sum(range(world_size))


def main_process(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "25321"
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["RANK"] = str(rank)
    torch.cuda.set_device(rank)  # one GPU per process for NCCL
    dist.init_process_group(backend="nccl")
    all_reduce(rank)
    reduce(rank)
    broadcast(rank)
    all_gather(rank)
    reduce_scatter(rank)
    dist.destroy_process_group()


if __name__ == "__main__":
    # Needs one GPU per process. mp.spawn does not work in a notebook, so run this as a script.
    nprocs = torch.cuda.device_count()
    mp.spawn(main_process, nprocs=nprocs, args=(nprocs,), join=True)
```
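The collective examples above do not cover point-to-point communication. Below is a minimal sketch of the blocking `send`/`recv` and non-blocking `isend`/`irecv` pairs, assuming two CPU processes with the `gloo` backend; `p2p` and `find_free_port` are illustrative names. The non-blocking calls return a work handle, and `wait()` must be called before the buffers are reused or read.

```python
import os
import socket

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def find_free_port():
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("localhost", 0))
        return s.getsockname()[1]


def p2p(rank, world_size, port):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(port)
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    if rank == 0:
        t = torch.full((3,), 7.0)
        dist.send(t, dst=1)  # blocking
        doubled = t * 2  # must stay alive until wait() returns
        req = dist.isend(doubled, dst=1)  # returns immediately
        req.wait()
    elif rank == 1:
        t = torch.zeros(3)
        dist.recv(t, src=0)  # blocks until the data has arrived
        assert torch.equal(t, torch.full((3,), 7.0))
        buf = torch.zeros(3)
        req = dist.irecv(buf, src=0)
        req.wait()  # buf is only valid after wait()
        assert torch.equal(buf, torch.full((3,), 14.0))
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(p2p, args=(2, find_free_port()), nprocs=2, join=True)
```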

# In Practice: SyncBatchNorm

To be implemented later in needle.

```python
import torch.nn as nn

sync_bn = nn.SyncBatchNorm(num_features=256)
```
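In practice, `SyncBatchNorm` is rarely constructed by hand: an existing model's `BatchNorm*d` layers can be replaced with `nn.SyncBatchNorm.convert_sync_batchnorm` before the model is wrapped in DDP. The conversion runs without a process group, but the `SyncBatchNorm` forward pass synchronizes batch statistics across the process group and, depending on the PyTorch version, expects GPU inputs, so the DDP part is only indicated in comments. The small CNN is a hypothetical example.

```python
import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.BatchNorm2d(16),
    nn.ReLU(),
)
# Replace every BatchNorm*d with SyncBatchNorm (weights and running stats are copied over)
sync_model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
print(type(sync_model[1]).__name__)  # SyncBatchNorm

# Inside each DDP worker (after init_process_group, one GPU per rank):
#   sync_model = sync_model.to(local_rank)
#   ddp_model = DDP(sync_model, device_ids=[local_rank])
```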