In [1]:
import torch 
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l

In [3]:
# Parameters of a functional LeNet: two conv layers then two fully connected
# layers, stored in the flat order lenet() indexes them (params[0] .. params[7]).
SEED = 0
torch.manual_seed(SEED)  # reproducible initialization on every Restart & Run All

scale = 0.01  # weights are standard-normal draws multiplied by this factor
# conv1: 20 output channels, 1 input channel, 3x3 kernels
w1 = torch.randn(size=(20, 1, 3, 3)) * scale
b1 = torch.zeros(20)

# conv2: 50 output channels, 20 input channels, 5x5 kernels
w2 = torch.randn(size=(50, 20, 5, 5)) * scale
b2 = torch.zeros(50)

# fc1: 800 = 50 channels * 4 * 4, which is what a 28x28 input leaves after
# conv3x3 (26) -> pool/2 (13) -> conv5x5 (9) -> pool/2 (4)
w3 = torch.randn(size=(800, 128)) * scale
b3 = torch.zeros(128)

# fc2: 128 hidden units -> 10 outputs
w4 = torch.randn(size=(128, 10)) * scale
b4 = torch.zeros(10)

params = [w1, b1, w2, b2, w3, b3, w4, b4]

def lenet(X, params):
    """Functional LeNet forward pass.

    ``params`` is read positionally as
    [conv1_w, conv1_b, conv2_w, conv2_b, fc1_w, fc1_b, fc2_w, fc2_b].
    Each conv stage is conv2d -> ReLU -> 2x2 average pooling with stride 2.
    The pooled feature map is flattened per leading index and passed through
    fc1 -> ReLU -> fc2. The raw fc2 outputs are returned (no softmax).
    """
    def conv_stage(inputs, weight, bias):
        # conv -> ReLU -> non-overlapping 2x2 average pool
        activated = F.relu(F.conv2d(input=inputs, weight=weight, bias=bias))
        return F.avg_pool2d(input=activated, kernel_size=(2, 2), stride=(2, 2))

    features = conv_stage(X, params[0], params[1])
    features = conv_stage(features, params[2], params[3])
    # (N, C, H, W) -> (N, C*H*W), assuming a batched NCHW input
    flat = features.reshape(features.shape[0], -1)
    hidden = F.relu(flat.mm(params[4]) + params[5])
    return hidden.mm(params[6]) + params[7]

# Per-example cross-entropy: reduction='none' returns one loss value per sample,
# so aggregation is left to the caller (train_batch sums each shard's losses).
loss = nn.CrossEntropyLoss(reduction='none')


In [4]:
def get_params( params, device):
    """Return an independent, trainable copy of ``params`` placed on ``device``.

    Each tensor is detached and explicitly copied (``copy=True``). A plain
    ``Tensor.to(device)`` returns the *same* tensor object when it already lives
    on ``device``. Replicas built for several identical devices (e.g. when
    ``d2l.try_gpu`` falls back to the CPU) would then share storage and
    gradients, and ``requires_grad_()`` would be flipped on the caller's
    original tensors.

    Args:
        params: iterable of tensors to replicate.
        device: target device for the copies.

    Returns:
        list of leaf tensors on ``device`` with ``requires_grad=True``.
    """
    new_params = [param.detach().to(device, copy=True) for param in params]
    for param in new_params:
        param.requires_grad_()
    return new_params

In [5]:
# Copy the parameters to the first device and inspect the *copy*: its b1 now
# requires grad, and .grad stays None until a backward pass has run.
new_params = get_params(params, d2l.try_gpu(0))
print('b1 权重', new_params[1])
print('b1 梯度', new_params[1].grad)

b1 权重 tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       requires_grad=True)
b1 梯度 None


In [6]:
def allreduce(data):
    """Sum a list of same-shaped tensors in place and broadcast the total.

    The first entry accumulates every other entry (each moved to its device).
    The total is then copied back into every other entry on that entry's own
    device, so afterwards all entries hold the same values.
    """
    root = data[0]
    for other in data[1:]:
        root[:] += other.to(root.device)
    for other in data[1:]:
        other[:] = root.to(other.device)

In [7]:
# Toy check of allreduce: entry i is a (1, 2) tensor filled with i + 1 on d2l.try_gpu(i).
data = [(i + 1) * torch.ones((1, 2), device=d2l.try_gpu(i))
        for i in range(2)]
print('allreduce 之前：\n', data[0], '\n', data[1])
allreduce(data)
print(' allreduce 之后:    \n', data[0], '\n', data[1])

allreduce 之前：
 tensor([[1., 1.]]) 
 tensor([[2., 2.]])
 allreduce 之后:    
 tensor([[3., 3.]]) 
 tensor([[3., 3.]])


In [None]:
# Demo: how nn.parallel.scatter shards a (4, 5) tensor along dim 0 across GPUs.
# scatter only accepts CUDA targets, so skip (instead of crashing the cell)
# when fewer GPUs than requested are present.
data = torch.arange(20).reshape(4, 5)
devices = [torch.device('cuda:0'), torch.device('cuda:1')]
print('input : ', data)
if torch.cuda.device_count() >= len(devices):
    split = nn.parallel.scatter(data, devices)
    print('load into', devices)
    print('out', split)
else:
    print(f'scatter demo skipped: needs {len(devices)} CUDA devices, '
          f'found {torch.cuda.device_count()}')

In [10]:
def split_batch(X, y , devices):
    """Shard a minibatch ``(X, y)`` along dim 0 across ``devices``.

    Returns ``(X_shards, y_shards)``: two tuples with one shard per device.

    ``nn.parallel.scatter`` only accepts CUDA targets. When any target is not a
    CUDA device (e.g. ``d2l.try_gpu`` fell back to the CPU), the batch is split
    with ``Tensor.chunk``, the same even split scatter performs by default, and
    each shard is moved with ``.to(device)``.

    Raises:
        AssertionError: if ``X`` and ``y`` have different batch sizes.
    """
    assert X.shape[0] == y.shape[0], 'X and y must have the same batch size'
    if all(torch.device(d).type == 'cuda' for d in devices):
        return (nn.parallel.scatter(X, devices),
                nn.parallel.scatter(y, devices))
    num_shards = len(devices)
    X_shards = tuple(chunk.to(d) for chunk, d in zip(X.chunk(num_shards), devices))
    y_shards = tuple(chunk.to(d) for chunk, d in zip(y.chunk(num_shards), devices))
    return X_shards, y_shards

In [11]:
def train_batch(X, y, device_params, devices, lr):
    """Run one data-parallel SGD step on the minibatch ``(X, y)``.

    1. The batch is sharded across ``devices`` via split_batch.
    2. Every replica computes the summed loss of its shard. All losses are
       computed first, then back-propagated one after another.
    3. Each parameter's gradients are allreduced so every replica holds the
       gradient sum.
    4. Every replica takes a d2l.sgd step normalized by the full batch size.
    """
    X_shards, y_shards = split_batch(X, y, devices)

    shard_losses = []
    for X_shard, y_shard, shard_params in zip(X_shards, y_shards, device_params):
        shard_losses.append(loss(lenet(X_shard, shard_params), y_shard).sum())

    for shard_loss in shard_losses:
        shard_loss.backward()

    num_devices = len(devices)
    with torch.no_grad():
        for idx in range(len(device_params[0])):
            allreduce([device_params[c][idx].grad for c in range(num_devices)])

    for replica in device_params:
        d2l.sgd(replica, lr, X.shape[0])

    


In [None]:
def train(num_gpus, batch_size,lr, num_epochs=10):
    """Train the functional LeNet data-parallel on Fashion-MNIST.

    Args:
        num_gpus: number of replicas; replica i lives on ``d2l.try_gpu(i)``.
        batch_size: global minibatch size; each batch is sharded across devices.
        lr: SGD learning rate.
        num_epochs: passes over the training set (default 10, as before).

    After every epoch, the test accuracy of replica 0 is plotted. At the end,
    the final test accuracy and the average seconds per epoch are printed.
    """
    train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

    devices = [d2l.try_gpu(i) for i in range(num_gpus)]
    # One parameter replica per device.
    device_params = [get_params(params, d) for d in devices]
    animator = d2l.Animator('epoch', 'test_acc', xlim=[1, num_epochs])
    timer = d2l.Timer()

    for epoch in range(num_epochs):
        timer.start()
        for X, y in train_iter:
            train_batch(X, y, device_params, devices, lr)
            # CUDA kernels run asynchronously: wait so the timer measures real
            # work. Guarded because a CPU-only build cannot synchronize CUDA.
            if torch.cuda.is_available():
                torch.cuda.synchronize()
        timer.stop()
        animator.add(epoch + 1, (d2l.evaluate_accuracy_gpu(
            lambda x: lenet(x, device_params[0]), test_iter, devices[0]),))
    # Format specs must not contain trailing spaces ('{v:.2f }' raises ValueError).
    print(f'测试精度：{animator.Y[0][-1]:.2f}, {timer.avg():.1f} 秒/每轮，'
          f'在{devices}')


    

In [None]:
# Data-parallel run: 2 replicas, a global minibatch of 256 sharded across them, SGD lr 0.2.
train(num_gpus=2, batch_size=256, lr=0.2)

![20250118135356](https://raw.githubusercontent.com/Rainbow452/image/main/img/20250118135356.png)