## 1. Background

This lab studies pipeline-parallel training on a single machine with multiple GPUs using PyTorch's distributed training stack, with a focus on how to use torch.distributed.pipelining.

## 2. Objective
Implement pipeline-parallel training with torch.distributed.pipelining and understand how it works.


## 3. Hardware requirements

Two GPUs (for example RTX 4090, V100, or A100).


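## 4. Technical principles

### Pipeline parallelism

Pipeline parallelism places different parts of a model on different GPUs, which reduces the memory each GPU needs.

For example, take a model with the following structure: [embedding -> pos_encoding -> layer.0 -> layer.1 -> layer.2 -> layer.3 -> out_linear]

We can place [embedding -> pos_encoding -> layer.0 -> layer.1] on GPU 0 and the rest ([layer.2 -> layer.3 -> out_linear]) on GPU 1.

During the forward pass, the input goes through GPU 0, which computes the output of layer.1. GPU 0 then sends that output tensor to GPU 1, and GPU 1 runs the remaining part of the model.

Likewise, during the backward pass the GPUs exchange gradients: GPU 1 sends the gradient with respect to its input back to GPU 0.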
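As a minimal sketch of the partitioning described above, the following plain-Python snippet cuts an ordered list of layer names into stages at chosen split points. The helper name `split_into_stages` is hypothetical and only mirrors the idea; it is not part of PyTorch.

```python
def split_into_stages(layer_names, split_points):
    """Cut an ordered list of layers into consecutive stages.

    A new stage begins at every name listed in `split_points`.
    Hypothetical helper for illustration; not a PyTorch API.
    """
    stages = [[]]
    for name in layer_names:
        # Open a new stage when we reach a split point (unless the
        # current stage is still empty, e.g. a split at the very first layer).
        if name in split_points and stages[-1]:
            stages.append([])
        stages[-1].append(name)
    return stages


layers = ["embedding", "pos_encoding", "layer.0", "layer.1",
          "layer.2", "layer.3", "out_linear"]
stages = split_into_stages(layers, split_points={"layer.2"})
for gpu_id, stage in enumerate(stages):
    print(f"GPU {gpu_id}: {' -> '.join(stage)}")
```

Each stage would live on its own GPU; only the tensor crossing a stage boundary (here the output of layer.1) has to be communicated.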

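### Micro-batch optimization

Naive pipeline parallelism is inefficient because only one GPU is computing at any given time. To improve utilization, the micro-batch technique splits each input batch into several micro-batches, so that GPU 1 can process one micro-batch while GPU 0 is already working on the next.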
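The effect of micro-batching can be illustrated with a small simulation of the forward pass. This is a simplified sketch under stated assumptions: every stage takes exactly one time step per micro-batch, and communication cost and the backward pass are ignored. The function names are made up for this example.

```python
def forward_timeline(num_stages, num_microbatches):
    """Return timeline[t][s]: the micro-batch run by stage s at step t, or None if idle.

    Assumes one time step per stage per micro-batch; ignores communication
    and the backward pass.
    """
    num_steps = num_stages + num_microbatches - 1
    timeline = []
    for t in range(num_steps):
        row = []
        for s in range(num_stages):
            mb = t - s  # stage s receives micro-batch mb with a delay of s steps
            row.append(mb if 0 <= mb < num_microbatches else None)
        timeline.append(row)
    return timeline


def idle_fraction(timeline):
    """Share of (step, stage) slots in which a stage has nothing to do."""
    slots = [slot for row in timeline for slot in row]
    return slots.count(None) / len(slots)


for m in (1, 2, 4):
    tl = forward_timeline(num_stages=2, num_microbatches=m)
    print(f"{m} micro-batch(es): {len(tl)} steps, idle fraction {idle_fraction(tl):.2f}")
```

With two stages, the idle fraction in this model drops from 0.50 (one micro-batch, the naive case) to 0.33 with two micro-batches and 0.20 with four. A step gets shorter as micro-batches get smaller, so the step counts are not directly comparable in wall-clock time; the idle fraction is the quantity to compare.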

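## 5. Procedure

### Environment setup

torch.distributed.pipelining is available from PyTorch 2.4 onward, so install a recent release.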



In [20]:
!pip install torch



### 5.1. Transformer model definition

In [11]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import DistributedSampler
import math


from torch.distributed.pipelining import pipeline, SplitPoint
from torch.distributed.pipelining import ScheduleGPipe


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x):
        return x + self.pe[:, : x.size(1)]


class MultiHeadSelfAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadSelfAttention, self).__init__()
        assert d_model % num_heads == 0
        self.d_k = d_model // num_heads
        self.num_heads = num_heads

        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.out_linear = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        batch_size = x.shape[0]

        q = (
            self.q_linear(x)
            .view(batch_size, -1, self.num_heads, self.d_k)
            .transpose(1, 2)
        )
        k = (
            self.k_linear(x)
            .view(batch_size, -1, self.num_heads, self.d_k)
            .transpose(1, 2)
        )
        v = (
            self.v_linear(x)
            .view(batch_size, -1, self.num_heads, self.d_k)
            .transpose(1, 2)
        )

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # Masked positions get a large negative score so softmax drives them to ~0.
            scores = scores.masked_fill(mask == 0, -1e9)

        attn_weights = F.softmax(scores, dim=-1)
        output = torch.matmul(attn_weights, v)

        output = (
            output.transpose(1, 2)
            .contiguous()
            .view(batch_size, -1, self.num_heads * self.d_k)
        )
        return self.out_linear(output)


class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super(FeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        return self.fc2(F.relu(self.fc1(x)))


class TransformerBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super(TransformerBlock, self).__init__()
        self.attention = MultiHeadSelfAttention(d_model, num_heads)
        self.norm1 = nn.LayerNorm(d_model)
        self.ffn = FeedForward(d_model, d_ff)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        attn_out = self.attention(x, mask)
        x = self.norm1(x + self.dropout(attn_out))
        ffn_out = self.ffn(x)
        x = self.norm2(x + self.dropout(ffn_out))
        return x


class Transformer(nn.Module):
    def __init__(
        self,
        vocab_size,
        d_model=512,
        num_heads=8,
        d_ff=2048,
        num_layers=6,
        max_len=5000,
    ):
        super(Transformer, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len)
        self.layers = nn.ModuleList(
            [TransformerBlock(d_model, num_heads, d_ff) for _ in range(num_layers)]
        )
        self.out_linear = nn.Linear(d_model, vocab_size)

    def forward(self, x, mask=None):
        x = self.embedding(x)
        x = self.pos_encoding(x)
        for layer in self.layers:
            # Pass the optional attention mask through to every block.
            x = layer(x, mask)
        return self.out_linear(x)


Transformer(
  (embedding): Embedding(1000, 32)
  (pos_encoding): PositionalEncoding()
  (layers): ModuleList(
    (0-1): 2 x TransformerBlock(
      (attention): MultiHeadSelfAttention(
        (q_linear): Linear(in_features=32, out_features=32, bias=True)
        (k_linear): Linear(in_features=32, out_features=32, bias=True)
        (v_linear): Linear(in_features=32, out_features=32, bias=True)
        (out_linear): Linear(in_features=32, out_features=32, bias=True)
      )
      (norm1): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
      (ffn): FeedForward(
        (fc1): Linear(in_features=32, out_features=64, bias=True)
        (fc2): Linear(in_features=64, out_features=32, bias=True)
      )
      (norm2): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (out_linear): Linear(in_features=32, out_features=1000, bias=True)
)
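The printed module tree is enough to estimate how many parameters each pipeline stage would hold. The sketch below does that arithmetic for a two-stage split placed in front of one of the Transformer blocks; the helper names are hypothetical, and the count covers trainable parameters only (the positional-encoding table is a registered buffer, not a parameter).

```python
def linear_params(n_in, n_out):
    return n_in * n_out + n_out  # weight matrix + bias vector


def block_params(d_model, d_ff):
    attention = 4 * linear_params(d_model, d_model)  # q, k, v and output projections
    ffn = linear_params(d_model, d_ff) + linear_params(d_ff, d_model)
    norms = 2 * (2 * d_model)  # two LayerNorms, each with weight + bias
    return attention + ffn + norms


def stage_params(vocab_size, d_model, d_ff, num_layers, first_layer_of_stage1):
    """Parameter counts (stage0, stage1) for a split in front of layers[first_layer_of_stage1]."""
    block = block_params(d_model, d_ff)
    # Stage 0: embedding table plus the blocks before the split.
    stage0 = vocab_size * d_model + first_layer_of_stage1 * block
    # Stage 1: remaining blocks plus the output projection.
    stage1 = (num_layers - first_layer_of_stage1) * block + linear_params(d_model, vocab_size)
    return stage0, stage1


# Configuration of the model printed above, split in front of layers.1.
print(stage_params(vocab_size=1000, d_model=32, d_ff=64, num_layers=2,
                   first_layer_of_stage1=1))
```

For this configuration the two stages come out at 40544 and 41544 parameters, so a split in front of layers.1 balances the parameter memory almost evenly across the two GPUs.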


### 5.2. Dataset definition

In [13]:
class NLPDataset(Dataset):
    """Toy dataset: sample i is a sequence of `length` tokens, all equal to i."""

    def __init__(self, size, length):
        self.data = []
        for i in range(size):
            self.data.append(torch.full((length,), i))

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx]

dataset = NLPDataset(10, 10)  # 10 samples, each of length 10
for data in dataset:
    print(data)

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2])
tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3])
tensor([4, 4, 4, 4, 4, 4, 4, 4, 4, 4])
tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5])
tensor([6, 6, 6, 6, 6, 6, 6, 6, 6, 6])
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7])
tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8])
tensor([9, 9, 9, 9, 9, 9, 9, 9, 9, 9])


### 5.3. Core training code

In [17]:
def train(rank, world_size):
    VOCAB_SIZE = 100
    D_MODEL = 12
    NUM_HEADS = 4
    D_FF = 24
    NUM_LAYERS = 2
    MAX_LEN = 100

    DATASET_SIZE = 12
    DATASET_LENGTH = 10
    BATCH_SIZE = 4

    NUM_MICROBATCHES = 2

    def compute_loss(output, target):
        criterion = nn.CrossEntropyLoss()
        loss = criterion(output.view(-1, VOCAB_SIZE), target.view(-1))
        return loss

    dataset = NLPDataset(size=DATASET_SIZE, length=DATASET_LENGTH)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE)

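    # Example micro-batch input, used only to trace the model: one micro-batch
    # holds BATCH_SIZE // NUM_MICROBATCHES sequences of DATASET_LENGTH - 1 tokens.
    x = torch.zeros((BATCH_SIZE // NUM_MICROBATCHES, DATASET_LENGTH - 1), dtype=torch.long)

    pipe = pipeline(
        module=Transformer(
            vocab_size=VOCAB_SIZE,
            d_model=D_MODEL,
            num_heads=NUM_HEADS,
            d_ff=D_FF,
            num_layers=NUM_LAYERS,
            max_len=MAX_LEN,
        ),
        mb_args=(x,),
        # Cut the model in front of layers.1:
        #   stage 0 = embedding, pos_encoding, layers.0
        #   stage 1 = layers.1, out_linear
        split_spec={
            "layers.1": SplitPoint.BEGINNING,
        },
    )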

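    # torchrun exports RANK and WORLD_SIZE for every process; these values
    # take precedence over the function arguments.
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)

    # One process per GPU: rank r drives cuda:r and owns pipeline stage r.
    device = f"cuda:{rank}"
    torch.cuda.set_device(rank)

    stage_mod = pipe.get_stage_module(rank)
    optimizer = optim.SGD(stage_mod.parameters(), lr=0.01)
    stage = pipe.build_stage(rank, device, None)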

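    # GPipe schedule: every batch is split into NUM_MICROBATCHES micro-batches,
    # and the loss function is applied on the last stage.
    schedule = ScheduleGPipe(stage, NUM_MICROBATCHES, compute_loss)
    for epoch in range(200):
        for batch, data in enumerate(dataloader):
            # Next-token prediction: the input drops the last token,
            # the label drops the first one.
            label = data[:, 1:].to(device)
            x = data[:, :-1].to(device)
            optimizer.zero_grad()
            if rank == 0:
                # First stage: feed the input; the schedule handles sending
                # activations to rank 1 and this stage's backward pass.
                schedule.step(x)
            else:
                # Last stage: receives activations and collects one loss per micro-batch.
                losses = []
                output = schedule.step(target=label, losses=losses)
                print(
                    f"Epoch {epoch}, Batch {batch}, Loss: {torch.stack(losses).mean()}"
                )
            # Each rank updates only the parameters of its own stage.
            optimizer.step()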

    dist.destroy_process_group()
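To make the tensor shapes in the training loop concrete, here is a plain-Python sketch of what happens to one batch: the next-token shift (`data[:, :-1]` and `data[:, 1:]`) followed by the split into micro-batches. It uses nested lists instead of tensors and assumes the schedule splits along the batch dimension into equal parts; both helper functions are hypothetical.

```python
def shift_for_next_token(batch):
    """Inputs drop the last token and labels drop the first one."""
    inputs = [seq[:-1] for seq in batch]
    labels = [seq[1:] for seq in batch]
    return inputs, labels


def split_microbatches(batch, num_microbatches):
    """Split a batch into equally sized, contiguous micro-batches."""
    assert len(batch) % num_microbatches == 0, "batch size must divide evenly"
    size = len(batch) // num_microbatches
    return [batch[i * size:(i + 1) * size] for i in range(num_microbatches)]


# First batch of the toy dataset: 4 sequences of length 10, sequence i filled with i.
batch = [[i] * 10 for i in range(4)]
inputs, labels = shift_for_next_token(batch)
micro_inputs = split_microbatches(inputs, num_microbatches=2)

print(len(inputs[0]), len(labels[0]))      # sequence length after the shift
print([len(mb) for mb in micro_inputs])    # sequences per micro-batch
```

Each micro-batch therefore holds 2 sequences of 9 tokens, which matches the shape of the example input `x` passed to `pipeline` as `mb_args`.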


### 5.4. Launching training

In [3]:
!bash run.sh

W0325 14:32:16.307000 221057 site-packages/torch/distributed/run.py:793] 
W0325 14:32:16.307000 221057 site-packages/torch/distributed/run.py:793] *****************************************
W0325 14:32:16.307000 221057 site-packages/torch/distributed/run.py:793] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W0325 14:32:16.307000 221057 site-packages/torch/distributed/run.py:793] *****************************************
Epoch 0, Batch 0, Loss: 4.920873641967773
Epoch 0, Batch 1, Loss: 4.386817932128906
Epoch 0, Batch 2, Loss: 5.070835113525391
Epoch 1, Batch 0, Loss: 4.84718656539917
Epoch 1, Batch 1, Loss: 4.296722412109375
Epoch 1, Batch 2, Loss: 4.999993324279785
Epoch 2, Batch 0, Loss: 4.698904037475586
Epoch 2, Batch 1, Loss: 4.194802284240723
Epoch 2, Batch 2, Loss: 4.897258758544922
Epoch 3, Batch 0, Loss: 4.5966205596