# 多GPU训练

In [None]:
import torch
from torch import nn
from torch.nn import functional as F
import time
import numpy as np
import matplotlib.pyplot as plt
from IPython import display
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [None]:
class Timer:
    """Record multiple running times.

    Every ``start()``/``stop()`` pair appends one elapsed interval, in
    seconds, to ``self.times``.
    """

    def __init__(self):
        """Create an empty list of intervals and start timing immediately."""
        self.times = []
        self.start()

    def start(self):
        """Start (or restart) the timer."""
        # perf_counter is monotonic, so intervals cannot come out negative or
        # skewed when the system wall clock is adjusted while timing.
        self.tik = time.perf_counter()

    def stop(self):
        """Stop the timer, record the elapsed time and return it."""
        self.times.append(time.perf_counter() - self.tik)
        return self.times[-1]

    def avg(self):
        """Return the average recorded time.

        Raises ZeroDivisionError if no interval has been recorded yet.
        """
        return sum(self.times) / len(self.times)

    def sum(self):
        """Return the sum of all recorded times."""
        return sum(self.times)

    def cumsum(self):
        """Return the running total of the recorded times as a list."""
        return np.array(self.times).cumsum().tolist()

def try_gpu(i=0):
    """Return the i-th CUDA device if it exists, otherwise the CPU device."""
    gpu_exists = i + 1 <= torch.cuda.device_count()
    if not gpu_exists:
        return torch.device('cpu')
    return torch.device(f'cuda:{i}')

def load_data_fashion_mnist(batch_size, resize=None, num_workers=4,
                            root='./data'):
    """Download Fashion-MNIST if needed and return (train, test) loaders.

    Args:
        batch_size: Number of samples per mini-batch for both loaders.
        resize: Optional target size passed to ``transforms.Resize``; when
            falsy the images keep their original size.
        num_workers: Number of worker processes used by each ``DataLoader``.
            Pass 0 to load in the main process.
        root: Directory where the dataset is stored or downloaded to.

    Returns:
        A ``(train_loader, test_loader)`` tuple. Only the training loader
        shuffles its samples.
    """
    steps = [transforms.ToTensor()]
    if resize:
        # Resize has to run on the image before it is turned into a tensor.
        steps.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(steps)
    mnist_train = datasets.FashionMNIST(
        root=root, train=True, transform=trans, download=True)
    mnist_test = datasets.FashionMNIST(
        root=root, train=False, transform=trans, download=True)
    return (DataLoader(mnist_train, batch_size, shuffle=True,
                       num_workers=num_workers),
            DataLoader(mnist_test, batch_size, shuffle=False,
                       num_workers=num_workers))

def evaluate_accuracy_gpu(net, data_iter, device=None):
    """Compute the classification accuracy of ``net`` over ``data_iter``.

    Args:
        net: Callable mapping an input batch to per-class scores; may be an
            ``nn.Module`` or a plain function/lambda.
        data_iter: Iterable yielding ``(X, y)`` batches. ``X`` may be a tensor
            or a list of tensors.
        device: Device the batches are moved to. When ``None`` it is taken
            from the first parameter of ``net`` if there is one, otherwise
            the CPU is used.

    Returns:
        The fraction of samples whose arg-max prediction equals the label.

    Raises:
        ZeroDivisionError: If ``data_iter`` yields no samples.
    """
    if device is None:
        device = torch.device('cpu')
        if hasattr(net, 'parameters'):
            try:
                device = next(iter(net.parameters())).device
            except (StopIteration, TypeError, AttributeError):
                # No parameters, or `parameters` is not the nn.Module API:
                # stay on the CPU. Any other error is a real failure and is
                # deliberately not swallowed.
                device = torch.device('cpu')
    num_correct = 0.0
    num_samples = 0.0
    with torch.no_grad():
        for X, y in data_iter:
            if isinstance(X, list):
                # Inputs made of several tensors (e.g. for BERT fine-tuning).
                X = [x.to(device) for x in X]
            else:
                X = X.to(device)
            y = y.to(device)
            num_correct += (net(X).argmax(dim=1) == y).sum().item()
            num_samples += y.numel()
    return num_correct / num_samples

class Animator:
    """Plot data incrementally so repeated calls look like an animation."""

    def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,
                 ylim=None, xscale='linear', yscale='linear',
                 fmts=('-', 'm--', 'g-.', 'r:'), nrows=1, ncols=1,
                 figsize=(3.5, 2.5)):
        """Create the figure and remember how its first axes are configured."""
        legend = [] if legend is None else legend
        fig, axes = plt.subplots(nrows, ncols, figsize=figsize)
        self.fig = fig
        # A single subplot comes back as a bare Axes; wrap it so that
        # self.axes can always be indexed.
        self.axes = [axes] if nrows * ncols == 1 else axes

        def configure():
            # Closure capturing the axis settings given to the constructor.
            self._set_axes(self.axes[0], xlabel, ylabel, xlim, ylim,
                           xscale, yscale, legend)

        self.config_axes = configure
        self.X = self.Y = None
        self.fmts = fmts

    def _set_axes(self, axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend):
        """Apply labels, scales, limits, legend and grid to ``axes``."""
        axes.set_xlabel(xlabel)
        axes.set_ylabel(ylabel)
        axes.set_xscale(xscale)
        axes.set_yscale(yscale)
        axes.set_xlim(xlim)
        axes.set_ylim(ylim)
        if legend:
            axes.legend(legend)
        axes.grid()

    def add(self, x, y):
        """Append one point per curve and redraw the figure.

        ``y`` may be a scalar or a sequence with one value per curve. A
        scalar ``x`` is shared by every curve.
        """
        if not hasattr(y, "__len__"):
            y = [y]
        n = len(y)
        if not hasattr(x, "__len__"):
            x = [x] * n
        # Allocate the per-curve buffers on first use.
        if not self.X:
            self.X = [[] for _ in range(n)]
        if not self.Y:
            self.Y = [[] for _ in range(n)]
        for idx, (x_val, y_val) in enumerate(zip(x, y)):
            if x_val is None or y_val is None:
                continue
            self.X[idx].append(x_val)
            self.Y[idx].append(y_val)
        # Redraw every curve from scratch on the first axes.
        ax = self.axes[0]
        ax.cla()
        for xs, ys, fmt in zip(self.X, self.Y, self.fmts):
            ax.plot(xs, ys, fmt)
        self.config_axes()
        display.display(self.fig)
        display.clear_output(wait=True)

In [None]:
# Initialise the LeNet parameters: small random weights and zero biases.
scale = 0.01
# Convolution 1: 20 kernels of size 3x3 over 1 input channel.
W1 = scale * torch.randn(20, 1, 3, 3)
b1 = torch.zeros(20)
# Convolution 2: 50 kernels of size 5x5 over 20 input channels.
W2 = scale * torch.randn(50, 20, 5, 5)
b2 = torch.zeros(50)
# Fully connected layer: 800 flattened features -> 128 hidden units.
W3 = scale * torch.randn(800, 128)
b3 = torch.zeros(128)
# Output layer: 128 hidden units -> 10 classes.
W4 = scale * torch.randn(128, 10)
b4 = torch.zeros(10)
params = [W1, b1, W2, b2, W3, b3, W4, b4]

def lenet(X, params):
    """Run a LeNet-style forward pass using an explicit parameter list.

    ``params`` is indexed as ``[W1, b1, W2, b2, W3, b3, W4, b4]``: two
    convolution layers followed by two fully connected layers. Returns the
    unnormalised class scores, one row per sample in ``X``.
    """
    def conv_block(inputs, weight, bias):
        # Convolution -> ReLU -> 2x2 average pooling with stride 2.
        activated = F.relu(F.conv2d(input=inputs, weight=weight, bias=bias))
        return F.avg_pool2d(input=activated, kernel_size=(2, 2),
                            stride=(2, 2))

    features = conv_block(X, params[0], params[1])
    features = conv_block(features, params[2], params[3])
    # Flatten everything except the batch dimension for the dense layers.
    flat = features.reshape(features.shape[0], -1)
    hidden = F.relu(torch.mm(flat, params[4]) + params[5])
    return torch.mm(hidden, params[6]) + params[7]

# The `reduction` argument selects how the loss is aggregated:
# 'none' returns the loss of every sample (no reduction at all),
# 'mean' returns the average of all sample losses,
# 'sum' returns the total of all sample losses.
# 'none' is chosen here so that each sample's individual loss is kept.
loss = nn.CrossEntropyLoss(reduction='none')

In [None]:
def get_params(params, device):
    """Return independent, trainable copies of ``params`` on ``device``.

    Args:
        params: Iterable of tensors to replicate.
        device: Target device for the copies.

    Returns:
        A list of new leaf tensors with ``requires_grad`` enabled. They never
        share storage with ``params``, even when ``params`` already live on
        ``device``.
    """
    # A plain `p.to(device)` hands back the very same tensor when it is
    # already on `device`; training would then mutate the source parameters
    # and replicas on one device would alias each other. `detach()` plus
    # `copy=True` guarantees a fresh leaf tensor in every case.
    return [p.detach().to(device, copy=True).requires_grad_()
            for p in params]

In [None]:
# Place the parameters on the device returned by try_gpu(0), then inspect the
# first bias: its values and its gradient (expected to be None as long as no
# backward pass has run on it), followed by the device that was selected.
new_params = get_params(params, try_gpu(0))
print('b1 权重:', new_params[1])
print('b1 梯度:', new_params[1].grad)
print(try_gpu(0))

In [None]:
def allreduce(data):
    """Sum the tensors in ``data`` and write the total back into each one.

    The reduction is accumulated in place into ``data[0]`` (on its device)
    and then copied in place into every remaining tensor, each on its own
    device.
    """
    peers = data[1:]
    # Reduce: gather every peer onto the first tensor's device and add it.
    for tensor in peers:
        data[0][:] += tensor.to(data[0].device)
    # Broadcast: overwrite every peer with the accumulated total.
    for tensor in peers:
        tensor[:] = data[0].to(tensor.device)

In [None]:
# Build one (1, 2) tensor per device index 0 and 1, filled with 1 and 2
# respectively, and print them before and after allreduce.
data = [torch.ones((1,2), device=try_gpu(i)) * (i+1) for i in range(2)]
print('allreduce之前：\n', data[0], '\n', data[1])
allreduce(data)
print('allreduce之后：\n', data[0], '\n', data[1])

In [None]:
# Show which device try_gpu selects for index 1.
print(try_gpu(1))

In [None]:
# NOTE: nn.parallel.scatter requires CUDA devices and fails without a GPU,
# so a hand-written scatter function is provided here as a replacement.
def scatter(data, devices):
    """Split ``data`` along dimension 0 into one shard per device.

    Args:
        data: Tensor whose first dimension is the batch dimension.
        devices: Non-empty sequence of target devices.

    Returns:
        A list with exactly ``len(devices)`` tensors; shard ``i`` lives on
        ``devices[i]``. Shard sizes differ by at most one row, and any extra
        rows go to the leading shards.

    Raises:
        ValueError: If ``devices`` is empty.
    """
    if len(devices) == 0:
        raise ValueError('devices must contain at least one device')
    if len(devices) == 1:
        return [data.to(devices[0])]
    # tensor_split always yields len(devices) pieces and spreads a remainder
    # evenly instead of piling it onto the last shard. Each piece is moved
    # straight to its own device, with no detour through devices[0].
    shards = torch.tensor_split(data, len(devices))
    return [shard.to(device) for shard, device in zip(shards, devices)]

# Demo of scatter: use up to two CUDA devices when available, else the CPU.
data = torch.arange(20).reshape(4, 5)
if torch.cuda.is_available():
    # Only list GPUs that really exist; hard-coding cuda:1 would fail on a
    # machine with a single GPU.
    devices = [torch.device(f'cuda:{i}')
               for i in range(min(2, torch.cuda.device_count()))]
else:
    print('CUDA不可用，使用CPU设备进行演示')
    devices = [torch.device('cpu')]
split_data = scatter(data, devices)
print('input: ', data)
print('load into', devices)
print('output:', split_data)

In [None]:
def split_batch(X, y, devices):
    """Shard a batch of inputs ``X`` and labels ``y`` across ``devices``.

    Returns a pair ``(X_shards, y_shards)`` as produced by ``scatter``; both
    tensors must have the same size along dimension 0.
    """
    same_batch = X.shape[0] == y.shape[0]
    assert same_batch, "X和y的批次大小必须相同"
    return scatter(X, devices), scatter(y, devices)

In [None]:
def train_batch(X, y, device_params, devices, lr, batch_size):
    """Run one data-parallel SGD step on a single mini-batch.

    The batch is sharded across ``devices``. Every replica in
    ``device_params`` computes the summed loss of its own shard, gradients
    are summed across replicas with ``allreduce``, and every replica then
    applies the same update ``param -= lr * grad / batch_size``.
    """
    X_shards, y_shards = split_batch(X, y, devices)
    # Forward pass: one summed loss per device replica.
    shard_losses = []
    for X_part, y_part, replica in zip(X_shards, y_shards, device_params):
        shard_losses.append(loss(lenet(X_part, replica), y_part).sum())
    # Backward pass runs separately for each device's loss.
    for shard_loss in shard_losses:
        shard_loss.backward()
    with torch.no_grad():
        # Sum each parameter's gradient over all devices and broadcast it.
        num_params = len(device_params[0])
        num_devices = len(devices)
        for idx in range(num_params):
            allreduce([device_params[dev][idx].grad
                       for dev in range(num_devices)])
        # Apply the identical SGD update on every replica, then reset grads.
        for replica in device_params:
            for param in replica:
                param -= lr * param.grad / batch_size
                param.grad.zero_()

In [None]:
def train(num_gpus, batch_size, lr, num_epochs=10):
    """Train LeNet on Fashion-MNIST with manual data parallelism.

    Args:
        num_gpus: Number of devices requested. Indices without a matching
            GPU fall back to the CPU via ``try_gpu``.
        batch_size: Mini-batch size used by the data loaders.
        lr: Learning rate of the SGD update.
        num_epochs: Number of passes over the training set.

    After every epoch the test accuracy of the first replica is plotted; at
    the end the final accuracy and the average epoch time are printed.
    """
    train_iter, test_iter = load_data_fashion_mnist(batch_size)
    # try_gpu falls back to the CPU when a GPU index does not exist, which
    # can list the same device several times. Keep each device once (in
    # order): several replicas on one device add no parallelism.
    devices = list(dict.fromkeys(try_gpu(i) for i in range(num_gpus)))
    device_params = [get_params(params, d) for d in devices]
    cuda_devices = [d for d in devices if d.type == 'cuda']
    animator = Animator('epoch', 'test acc', xlim=[1, num_epochs])
    timer = Timer()
    for epoch in range(num_epochs):
        timer.start()
        for X, y in train_iter:
            # Normalise by the actual number of samples: the last batch of an
            # epoch can be smaller than the nominal batch_size.
            train_batch(X, y, device_params, devices, lr, X.shape[0])
            # Wait for every GPU in use, not only the current one, so the
            # timing covers all of the work.
            for d in cuda_devices:
                torch.cuda.synchronize(d)
        timer.stop()
        # Evaluate the accuracy using the replica on the first device.
        test_acc = evaluate_accuracy_gpu(
            lambda x: lenet(x, device_params[0]), test_iter, devices[0])
        animator.add(epoch + 1, (test_acc,))
    print(f'测试精度：{animator.Y[0][-1]:.2f}，{timer.avg():.1f}秒/轮，'
          f'在{str(devices)}')

In [None]:
# Baseline run: one device, batch size 256, learning rate 0.2.
train(num_gpus=1, batch_size=256, lr=0.2)

In [None]:
# Same hyperparameters as the single-device run, but requesting two devices.
train(num_gpus=2, batch_size=256, lr=0.2)