In [None]:
%matplotlib inline
import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l

## 多卡训练-从零实现

In [None]:
# Distribute a parameter tensor to multiple devices
def get_params(params, devices):
    """Create one independent, trainable copy of ``params`` per device.

    Args:
        params: Tensor to replicate (it is copied with ``Tensor.clone``).
        devices: Iterable of targets accepted by ``Tensor.to``.

    Returns:
        A list with one tensor per entry of ``devices``. Each tensor has
        ``requires_grad`` enabled and is a leaf, so ``backward()`` fills in
        its own ``.grad``.
    """
    # detach() cuts the copies off from any graph `params` belongs to; without
    # it the clones would be non-leaf tensors whose .grad is never populated.
    return [
        params.detach().clone().to(device).requires_grad_()
        for device in devices
    ]

In [None]:
# Sum all tensors and broadcast the result to every device
def allreduce(data):
    """Replace every tensor in ``data`` with the element-wise sum of all of them.

    The sum is accumulated on the device of ``data[0]`` and then copied back
    into each of the other tensors. All writes are in place, so any other
    reference to the same tensors (for example a parameter's ``.grad``) sees
    the reduced value as well.

    Args:
        data: Sequence of same-shaped tensors, possibly on different devices.

    Returns:
        None.
    """
    if not data:
        return
    total = data[0]
    # add_/copy_ rather than `t[:] = ...` so that 0-dim tensors work too.
    for other in data[1:]:
        total.add_(other.to(total.device))
    for other in data[1:]:
        # copy_ handles the cross-device transfer itself.
        other.copy_(total)

In [None]:
# Split one mini-batch evenly across multiple devices
def split_batch(X, y, devices):
    """Scatter the features ``X`` and labels ``y`` of one batch over ``devices``.

    Args:
        X: Batch of inputs; split along its first dimension.
        y: Batch of labels with the same length as ``X``.
        devices: Target devices, passed to ``nn.parallel.scatter``.

    Returns:
        A pair ``(X_shards, y_shards)`` holding the results of
        ``nn.parallel.scatter`` for ``X`` and ``y`` respectively.

    Raises:
        AssertionError: If ``X`` and ``y`` differ in length.
    """
    assert len(X) == len(y)
    X_shards = nn.parallel.scatter(X, devices)
    y_shards = nn.parallel.scatter(y, devices)
    return X_shards, y_shards

In [None]:
# Evaluate the classification accuracy of a model
def evaluate_accuracy(model, test_iter, device):
    """Return the accuracy of ``model`` over every batch of ``test_iter``.

    The model is moved to ``device`` and switched to evaluation mode. It is
    left in evaluation mode on return, so a caller that keeps training must
    call ``model.train()`` again.

    Args:
        model: Module to evaluate.
        test_iter: Iterable yielding ``(X, y)`` batches.
        device: Device on which the model and each batch are placed.

    Returns:
        The accumulated ``d2l.accuracy`` values divided by the accumulated
        number of label elements.
    """
    metrics = d2l.Accumulator(2)  # (accuracy sum, number of labels)
    model.to(device)
    model.eval()
    # Inference only: do not record an autograd graph for the forward passes.
    with torch.no_grad():
        for X, y in test_iter:
            X, y = X.to(device), y.to(device)
            y_hat = model(X)
            metrics.add(d2l.accuracy(y_hat, y), y.numel())
    return metrics[0] / metrics[1]

In [None]:
# Training function
def train(model, train_iter, test_iter, loss_fn, optimizer, num_epochs, devices):
    """Train ``model`` with hand-written data parallelism across ``devices``.

    ``model`` is the master copy and lives on ``devices[0]``; one deep copy is
    kept on every other device. For each batch the shards produced by
    ``split_batch`` are run through the replicas, the per-replica gradients
    are summed with ``allreduce``, ``optimizer`` updates the master, and the
    master's state is copied back into the other replicas.

    Args:
        model: Module to train. ``optimizer`` must have been built from this
            module's parameters, because only the master is stepped.
        train_iter: Iterable yielding ``(X, y)`` training batches.
        test_iter: Iterable of evaluation batches for ``evaluate_accuracy``.
        loss_fn: Loss callable. Shard losses are weighted by shard size, which
            reproduces the full-batch gradient when ``loss_fn`` averages over
            the batch (as ``nn.CrossEntropyLoss()`` does by default).
        optimizer: Optimizer over ``model.parameters()``.
        num_epochs: Number of passes over ``train_iter``.
        devices: A sequence of devices, or a single device.

    Raises:
        ValueError: If ``devices`` is empty.
    """
    import copy

    # Accept a single device as well as a sequence of devices.
    if isinstance(devices, (torch.device, str, int)):
        devices = [devices]
    devices = list(devices)
    if not devices:
        raise ValueError('devices must contain at least one device')

    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0, 1],
                            legend=['train_loss', 'train_acc', 'test_acc'])
    metrics = d2l.Accumulator(3)  # (loss sum, accuracy sum, number of labels)
    timer = d2l.Timer()

    # nn.Module has no clone(); deep-copy the module for the extra devices.
    # The master keeps its Parameter objects, so `optimizer` stays valid.
    extra_replicas = [copy.deepcopy(model).to(device) for device in devices[1:]]
    model.to(devices[0])
    replicas = [model] + extra_replicas

    for epoch in range(num_epochs):
        metrics.reset()
        # evaluate_accuracy() puts the model in eval mode at the end of every
        # epoch, so training mode has to be restored here.
        for replica in replicas:
            replica.train()
        timer.start()
        for X, y in train_iter:
            X_shards, y_shards = split_batch(X, y, devices)
            batch_size = y.numel()
            # A small batch may produce fewer shards than there are devices.
            active = replicas[:len(X_shards)]
            for replica in replicas:
                replica.zero_grad()

            batch_loss, batch_correct = 0.0, 0.0
            for replica, X_shard, y_shard in zip(active, X_shards, y_shards):
                y_hat = replica(X_shard)
                # Weight by shard size so that the summed gradients equal the
                # gradient of the mean loss over the whole batch.
                loss = loss_fn(y_hat, y_shard) * (y_shard.numel() / batch_size)
                loss.backward()
                with torch.no_grad():
                    batch_loss += loss.item()
                    batch_correct += d2l.accuracy(y_hat, y_shard)

            # Sum each parameter's gradient across replicas; the total ends up
            # in the master's .grad (the first entry of each group).
            per_replica_grads = [[p.grad for p in replica.parameters()]
                                 for replica in active]
            for grads in zip(*per_replica_grads):
                if all(grad is not None for grad in grads):
                    allreduce(list(grads))
            optimizer.step()

            # Push the updated master state (parameters and buffers) to the
            # other replicas so that all copies stay identical.
            if extra_replicas:
                master_state = model.state_dict()
                for replica in extra_replicas:
                    replica.load_state_dict(master_state)

            metrics.add(batch_loss * batch_size, batch_correct, batch_size)
        timer.stop()
        print(f'speed: {metrics[2] / timer.times[-1]:.1f} samples/sec')
        test_acc = evaluate_accuracy(model, test_iter, devices[0])
        print(f'train_acc: {metrics[1] / metrics[2]:.3f}, test_acc: {test_acc}')
        animator.add(epoch + 1,
                     (metrics[0] / metrics[2], metrics[1] / metrics[2], test_acc))
    print(f'speed: {metrics[2] / timer.avg():.1f} samples/sec on {devices}.')
    print(f'train_acc: {metrics[1] / metrics[2]:.3f}, test_acc: {test_acc}.')

In [None]:
# Hyperparameters for the from-scratch multi-GPU run
lr = 0.001
num_epochs = 10

# NOTE(review): ResNet, train_iter and val_iter are not defined in the cells
# shown here -- presumably they come from earlier cells; confirm before
# running on a fresh kernel.
model = ResNet()
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr)

# train() iterates over its `devices` argument, so it needs a list of devices
# rather than a single one. GPUs 2 and 3 match the DataParallel run below.
train(model, train_iter, val_iter, loss_fn, optimizer, num_epochs,
      [d2l.try_gpu(i) for i in (2, 3)])

## 多卡训练-简洁实现

In [None]:
# Evaluate the classification accuracy of a model
# (same helper as in the from-scratch section; this definition replaces it)
def evaluate_accuracy(model, test_iter, device):
    """Return the accuracy of ``model`` over every batch of ``test_iter``.

    The model is moved to ``device`` and switched to evaluation mode. It is
    left in evaluation mode on return, so a caller that keeps training must
    call ``model.train()`` again.

    Args:
        model: Module to evaluate.
        test_iter: Iterable yielding ``(X, y)`` batches.
        device: Device on which the model and each batch are placed.

    Returns:
        The accumulated ``d2l.accuracy`` values divided by the accumulated
        number of label elements.
    """
    metrics = d2l.Accumulator(2)  # (accuracy sum, number of labels)
    model.to(device)
    model.eval()
    # Inference only: do not record an autograd graph for the forward passes.
    with torch.no_grad():
        for X, y in test_iter:
            X, y = X.to(device), y.to(device)
            y_hat = model(X)
            metrics.add(d2l.accuracy(y_hat, y), y.numel())
    return metrics[0] / metrics[1]

In [None]:
# Training function
def train(model, train_iter, test_iter, loss_fn, optimizer, num_epochs, devices):
    """Train ``model`` on several GPUs using ``nn.parallel.DataParallel``.

    The model is moved to the first device and wrapped in ``DataParallel``;
    every batch is placed on the first device before the forward pass. After
    each epoch the throughput and the train/test accuracy are printed and
    added to the animator.

    Args:
        model: Module to train.
        train_iter: Iterable yielding ``(X, y)`` training batches.
        test_iter: Iterable of evaluation batches for ``evaluate_accuracy``.
        loss_fn: Loss callable applied to ``(y_hat, y)``.
        optimizer: Optimizer over the model's parameters.
        num_epochs: Number of passes over ``train_iter``.
        devices: GPU indices; each one is converted with ``d2l.try_gpu``.
    """
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0, 1],
                            legend=['train_loss', 'train_acc', 'test_acc'])
    metrics = d2l.Accumulator(3)  # (loss sum, accuracy sum, number of labels)
    devices = [d2l.try_gpu(i) for i in devices]
    timer = d2l.Timer()
    model.to(devices[0])
    model = nn.parallel.DataParallel(model, devices)
    for epoch in range(num_epochs):
        # evaluate_accuracy() switches the model to eval mode at the end of
        # every epoch, so training mode must be restored each time; setting it
        # once before the loop would leave epochs 2+ running in eval mode.
        model.train()
        metrics.reset()
        timer.start()
        for X, y in train_iter:
            X, y = X.to(devices[0]), y.to(devices[0])
            y_hat = model(X)
            loss = loss_fn(y_hat, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            with torch.no_grad():
                metrics.add(loss * y.numel(), d2l.accuracy(y_hat, y), y.numel())
        timer.stop()
        print(f'speed: {metrics[2] / timer.times[-1]:.1f} samples/sec')
        test_acc = evaluate_accuracy(model, test_iter, devices[0])
        print(f'train_acc: {metrics[1] / metrics[2]:.3f}, test_acc: {test_acc}')
        animator.add(epoch + 1,
                     (metrics[0] / metrics[2], metrics[1] / metrics[2], test_acc))
    print(f'speed: {metrics[2] / timer.avg():.1f} samples/sec on {devices}.')
    print(f'train_acc: {metrics[1] / metrics[2]:.3f}, test_acc: {test_acc}.')

In [None]:
# Hyperparameters for the DataParallel run
lr = 0.001
num_epochs = 10
# GPU indices; train() converts each one to a device with d2l.try_gpu
devices = [2, 3]

# NOTE(review): ResNet, train_iter and val_iter are not defined in the cells
# shown here -- presumably they come from earlier cells; confirm before
# running on a fresh kernel.
model = ResNet()
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr)

train(model, train_iter, val_iter, loss_fn, optimizer, num_epochs, devices)