In [1]:
import torch

https://kimi.moonshot.cn/chat/co3af54udu62gncaf9fg

分布式数据并行（Distributed Data Parallel，简称DDP）是PyTorch中用于分布式训练的一种方法，它允许在多个进程或多个机器上并行地训练模型。每个进程拥有模型的一个副本，并且每个进程使用不同的数据子集进行训练。在训练过程中，各个进程会异步更新模型参数，并通过某种通信机制（如NVIDIA的NCCL库）同步梯度或参数更新。

以下是一个使用PyTorch的torch.nn.parallel.DistributedDataParallel（DDP）的简单测试例子：

In [None]:
import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import Dataset, DataLoader, DistributedSampler

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.linear = nn.Linear(10, 10)

    def forward(self, x):
        return self.linear(x)

class ToyDataset(Dataset):
    def __len__(self):
        return 100

    def __getitem__(self, idx):
        return torch.rand(10), torch.rand(10)

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = '172.17.0.2'
    os.environ['MASTER_PORT'] = '50574'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

def train(rank, world_size):

    setup(rank, world_size)

    model = ToyModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    dataset = ToyDataset()
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    dataloader = DataLoader(dataset, sampler=sampler, batch_size=10)

    criterion = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    for epoch in range(2):  # loop over the dataset multiple times
        for i, (inputs, labels) in enumerat(dataloader):
            inputs = inputs.to(rank)
            labels = labels.to(rank)

            optimizer.zero_grad()
            outputs = ddp_model(inputs)
            loss = criterion(outputs, labels)
            print(f'rank: {}, epoch: {epoch}, iteration:{i}, loss: {loss.item()}')
            loss.backward()
            optimizer.step()

    cleanup()

def main():
    rank = int(sys.argv[1])
    world_size = int(sys.argv[2])
    train(rank, world_size)

if __name__ == "__main__":
    main()


In [1]:
!torchrun --nproc_per_node=2 --nnodes=1 --node_rank=0 --master_addr="172.17.0.2" --master_port=50574 tst_ddp.py

/bin/bash: line 1: torchrun: command not found
