# DP（Data Parallelism）数据并行

## 1. DP简介

数据并行（Data Parallelism，DP）是一种常用的并行训练方式：把一个批次的数据切分到多个 GPU 上同时计算，从而加速模型训练。PyTorch 通过 `nn.DataParallel` 提供了一种单进程、多线程的实现。
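DP（`nn.DataParallel`）的实现原理：
- 在单个进程中为每个 GPU 创建一个线程
- 每次前向传播时，将同样的模型复制（replicate）到每个 GPU 上
- 将输入批次沿第 0 维切分（scatter），每个线程处理其中一份数据
- 各 GPU 的输出被收集（gather）到主 GPU（`device_ids[0]`）上计算损失，反向传播时每个线程计算本地梯度
- 将各 GPU 的梯度汇总（求和）到主 GPU 上的原始模型，由主线程（Master）更新模型参数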

## 2. DP实现

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
# Seed the global RNG so the synthetic data below is reproducible.
torch.manual_seed(0)

# Use the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Synthetic regression data: 100 samples of 10 features, one target each.
# Generated left to right, so the input tensor is drawn from the RNG first.
input_data, output_data = torch.randn(100, 10), torch.randn(100, 1)

# Pair features with targets and serve them in shuffled mini-batches of 10.
dataset = TensorDataset(input_data, output_data)
dataloader = DataLoader(dataset=dataset, batch_size=10, shuffle=True)

# 定义一个简单的线性回归模型
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.linear = nn.Linear(10, 1)

    def forward(self, x):
        print(f"Processing batch on device: {x.device}")
        return self.linear(x)

# Build the model (parameters start on the CPU).
model = SimpleModel()

# With more than one visible GPU, wrap the model in nn.DataParallel so each
# forward pass replicates it onto every GPU and splits the batch along dim 0.
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs!")
    model = nn.DataParallel(model)

# Move the parameters (of the wrapped module, when wrapped) to the device.
# nn.DataParallel expects them on device_ids[0] before its first call.
model.to(device)

# Mean-squared-error objective optimized with plain SGD.
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Training loop.
num_epochs = 10
for epoch in range(num_epochs):
    # Sum of per-sample losses and number of samples seen this epoch; the
    # epoch loss is their ratio, so a smaller final batch is weighted
    # correctly instead of counting as much as a full batch.
    running_loss = 0.0
    num_samples = 0
    print(f"Epoch {epoch + 1} starts:")
    for idx, (inputs, labels) in enumerate(dataloader):
        # Move the batch to the same device as the model's parameters.
        inputs, labels = inputs.to(device), labels.to(device)

        # Clear gradients left over from the previous step.
        optimizer.zero_grad()

        # Forward pass (scattered across GPUs when wrapped in DataParallel).
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # Backward pass and parameter update.
        loss.backward()
        optimizer.step()

        # MSELoss uses reduction="mean" by default, so scale the batch mean
        # back to a batch sum before accumulating.
        batch_size = inputs.size(0)
        running_loss += loss.item() * batch_size
        num_samples += batch_size
        print(f"Batch {idx + 1} processed.")
        print("-" * 50)

    print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {running_loss / num_samples}')
    print("=" * 80)

print('Training finished.')

Using device: cuda
Epoch 1 starts:
Processing batch on device: cuda:0
Batch 1 processed.
--------------------------------------------------
Processing batch on device: cuda:0
Batch 2 processed.
--------------------------------------------------
Processing batch on device: cuda:0
Batch 3 processed.
--------------------------------------------------
Processing batch on device: cuda:0
Batch 4 processed.
--------------------------------------------------
Processing batch on device: cuda:0
Batch 5 processed.
--------------------------------------------------
Processing batch on device: cuda:0
Batch 6 processed.
--------------------------------------------------
Processing batch on device: cuda:0
Batch 7 processed.
--------------------------------------------------
Processing batch on device: cuda:0
Batch 8 processed.
--------------------------------------------------
Processing batch on device: cuda:0
Batch 9 processed.
--------------------------------------------------
Processing batch on 