In [1]:
import torch
from torch import nn
import tools as tl

In [None]:
# Devices reported by tl.try_all_gpus(). In the cells visible here, train()
# builds its own device list from num_gpus and does not read this global.
devices = tl.try_all_gpus()
# Model object shared by the training cells below.
# NOTE(review): 10 is presumably the number of output classes — confirm
# against tl.resnet18's signature.
net = tl.resnet18(10)

In [None]:
def train(net, num_gpus, batch_size, lr):
    """Train ``net`` with ``nn.DataParallel`` across ``num_gpus`` devices.

    Data comes from ``tl.load_data_fashion_mnist``. After each of the 10
    epochs the test accuracy is added to a ``tl.Animator`` plot, and a summary
    line (final test accuracy, average seconds per epoch, devices used) is
    printed at the end.

    Args:
        net: The ``nn.Module`` to train. It is modified in place: the weights
            of its ``nn.Linear`` / ``nn.Conv2d`` layers are re-drawn from
            N(0, 0.01^2) and the module is moved to the first device.
            All other parameters and buffers keep whatever values they
            currently hold (e.g. from an earlier ``train`` call on the same
            object).
        num_gpus: Number of devices to request via ``tl.try_gpu``; must be
            at least 1.
        batch_size: Batch size handed to ``tl.load_data_fashion_mnist``.
            ``nn.DataParallel`` splits each batch across the devices.
        lr: Learning rate for the SGD optimizer.

    Raises:
        ValueError: If ``num_gpus`` is smaller than 1.
    """
    # Fail fast: an empty device list would otherwise surface much later as
    # an IndexError on devices[0], after the dataset has been loaded.
    if num_gpus < 1:
        raise ValueError(f'num_gpus must be at least 1, got {num_gpus}')

    train_iter, test_iter = tl.load_data_fashion_mnist(batch_size)
    devices = [tl.try_gpu(i) for i in range(num_gpus)]

    def init_weights(m):
        # Exact type match, so subclasses of Linear/Conv2d are not touched.
        if type(m) in [nn.Linear, nn.Conv2d]:
            nn.init.normal_(m.weight, std=0.01)

    net.apply(init_weights)

    # nn.DataParallel moves the wrapped module to device_ids[0] on its own
    # only when it is given a single device. With several devices it expects
    # the parameters to already live there and raises in forward() otherwise.
    # Moving explicitly makes a multi-GPU run work on a fresh kernel instead
    # of relying on an earlier single-GPU call having moved `net`.
    net = nn.DataParallel(net.to(devices[0]), device_ids=devices)
    trainer = torch.optim.SGD(net.parameters(), lr)
    loss = nn.CrossEntropyLoss()
    timer, num_epochs = tl.Timer(), 10
    animator = tl.Animator('epoch', 'test acc', xlim=[1, num_epochs])
    for epoch in range(num_epochs):
        net.train()
        timer.start()
        for X, y in train_iter:
            trainer.zero_grad()
            # Inputs go to the first device; DataParallel scatters them from
            # there to the remaining replicas.
            X, y = X.to(devices[0]), y.to(devices[0])
            batch_loss = loss(net(X), y)
            batch_loss.backward()
            trainer.step()
        # Only the training loop is timed; evaluation below is excluded.
        timer.stop()
        animator.add(epoch + 1, (tl.evaluate_accuracy_gpu(net, test_iter),))
    print(f'测试精度：{animator.Y[0][-1]:.2f}，{timer.avg():.1f}秒/轮，'
        f'在{str(devices)}')


In [None]:
# Run training on a single device.
train(
    net,
    num_gpus=1,
    batch_size=256,
    lr=0.1,
)

In [None]:
# Two-device run; batch size and learning rate are both doubled relative to
# the num_gpus=1 call above.
# NOTE(review): this passes the same `net` object as the previous cell.
# train() re-initializes only nn.Linear / nn.Conv2d weights, so any other
# state carried by `net` persists from that earlier run.
train(
    net,
    num_gpus=2,
    batch_size=512,
    lr=0.2,
)