# DeepSpeed and Accelerate

Video 1: https://www.bilibili.com/video/BV1ZZ421T7XJ/?spm_id_from=333.337.search-card.all.click&vd_source=071b23b9c7175dbaf674c65294124341  
Video 2: https://www.bilibili.com/video/BV1uK421a7HG?spm_id_from=333.788.videopod.episodes&vd_source=071b23b9c7175dbaf674c65294124341&p=4  
Official docs: https://huggingface.co/docs/accelerate/quicktour

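### Differences  
DeepSpeed is a deep learning optimization library developed by Microsoft that aims to make training large-scale models more efficient. It provides a range of optimization techniques, including mixed-precision training, distributed training (data parallelism and model parallelism), and gradient accumulation.  
Key features:  
1. Distributed training: supports data parallelism and model parallelism (including pipeline parallelism and tensor parallelism), which makes it possible to train very large models.  
2. Memory optimization: provides ZeRO (Zero Redundancy Optimizer), which partitions optimizer states, gradients, and parameters across data-parallel processes to reduce per-GPU memory use, and can also speed up training.  
3. Mixed-precision training: supports FP16 and BF16 training to reduce compute cost and memory use.  
4. Gradient accumulation: supports gradient accumulation for training with very large effective batch sizes.  
5. Elastic training: allows compute resources to be added or removed dynamically.  
It is powerful and feature-rich, but configuration is relatively complex and requires some deep learning background.  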
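To make the ZeRO memory savings concrete, here is a minimal sketch that estimates per-GPU memory for model states under mixed-precision Adam. It follows the accounting used in the ZeRO paper (2 bytes per parameter for FP16 weights, 2 for FP16 gradients, 12 for FP32 optimizer states). The function name `zero_memory_gb` and the 7.5B-parameter, 64-GPU example are illustrative assumptions, and activations and temporary buffers are ignored.

```python
def zero_memory_gb(num_params: float, num_gpus: int, stage: int) -> float:
    """Approximate per-GPU memory in GB (1e9 bytes) for model states only.

    stage 0 = plain data parallelism, 1 = partition optimizer states,
    2 = also partition gradients, 3 = also partition parameters.
    """
    if stage not in (0, 1, 2, 3):
        raise ValueError("stage must be 0, 1, 2 or 3")
    params = 2.0 * num_params   # FP16 weights
    grads = 2.0 * num_params    # FP16 gradients
    optim = 12.0 * num_params   # FP32 weights + Adam momentum + Adam variance
    if stage >= 1:
        optim /= num_gpus
    if stage >= 2:
        grads /= num_gpus
    if stage >= 3:
        params /= num_gpus
    return (params + grads + optim) / 1e9


if __name__ == "__main__":
    # Hypothetical setup: 7.5B parameters trained on 64 GPUs.
    for stage in range(4):
        gb = zero_memory_gb(7.5e9, 64, stage)
        print(f"ZeRO stage {stage}: {gb:.1f} GB per GPU")
```

The estimate only shows the trend: each stage partitions one more kind of state, so per-GPU memory shrinks as the stage number and the GPU count grow.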

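Accelerate is a library developed by Hugging Face that aims to simplify setting up distributed training for PyTorch. It offers a concise way to configure and manage multi-GPU and TPU environments. It handles data parallelism directly and can delegate sharded or model-parallel training to backends such as DeepSpeed and FSDP.  
Key features:  
1. Simple configuration: provides a unified interface for configuring multi-GPU and TPU environments, which hides much of the complexity of distributed training.  
2. Multiple hardware setups: Accelerate is built on PyTorch (it does not support TensorFlow) and runs the same training script on CPU, a single GPU, multiple GPUs, or TPUs.  
3. DeepSpeed integration: can be combined with DeepSpeed to use its advanced features during training.  
4. Simplified distributed training: handles distributed setup and synchronization automatically, so users can focus on the model and the data.  

It is feature-rich and very easy to use, but offers less fine-grained control over configuration.  

Relationship  
1. Integration: Accelerate can be combined with DeepSpeed to use its advanced optimizations. With Accelerate's `DeepSpeedPlugin`, DeepSpeed training runs inside the Accelerate framework.  
2. Shared goal: both aim to make large-scale training more efficient and easier to configure. DeepSpeed provides more optimization features, while Accelerate focuses on simple configuration and on running the same script across different hardware setups.  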

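## Why can Accelerate speed up training?  

1. Data parallelism: every GPU holds a full copy of the model and processes a different shard of the data. Because the shards are processed in parallel, each epoch needs fewer steps, which speeds up training.  
2. Gradient synchronization: on every backward pass, the gradients on all GPUs are synchronized (in multi-GPU PyTorch runs, Accelerate wraps the model in `DistributedDataParallel`, which performs this step). The procedure is: (1) each GPU computes gradients on its own shard; (2) the gradients are averaged across GPUs with an all_reduce operation, so every GPU ends up with the same gradients; (3) each GPU updates its copy of the model with the synchronized gradients.  
3. Model parallelism: different layers (or slices of layers) of one model are placed on different GPUs. This mainly allows models that do not fit on a single GPU to be trained, rather than making a small model faster.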
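The data-parallel steps above can be sketched without any GPU. The following pure-Python simulation is only an illustration, not how Accelerate or PyTorch implement it: two "workers" hold the same weight, each computes a gradient on its own shard, a simulated all-reduce averages the gradients, and both workers apply the same update. All function names here are hypothetical.

```python
def local_gradient(w, shard):
    """Gradient of the mean squared error of y_hat = w * x on one shard."""
    return sum(2.0 * (w * x - y) * x for x, y in shard) / len(shard)


def all_reduce_mean(values):
    """Simulated all-reduce: every worker receives the same average."""
    mean = sum(values) / len(values)
    return [mean] * len(values)


def data_parallel_step(weights, shards, lr):
    """One training step: local gradients, synchronization, identical update."""
    grads = [local_gradient(w, shard) for w, shard in zip(weights, shards)]
    synced = all_reduce_mean(grads)
    return [w - lr * g for w, g in zip(weights, synced)]


if __name__ == "__main__":
    data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]  # y = 2 * x
    shards = [data[:2], data[2:]]   # one equally sized shard per worker
    weights = [0.0, 0.0]            # both replicas start from the same weight
    for _ in range(50):
        weights = data_parallel_step(weights, shards, lr=0.05)
    print(weights)  # both replicas hold the same value, close to 2.0
```

Because the shards have equal size, the averaged gradient equals the gradient a single worker would compute on the full batch, which is why the replicas never drift apart.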

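## Basic Accelerate functions

1. `accelerator.reduce()`: aggregates a tensor across all processes in a distributed run, for example by summing or averaging it.  
Example: `accelerator.reduce(loss, reduction="mean")`

2. `accelerator.gather()`: collects a tensor from every process and concatenates the pieces along the first dimension. Every process receives the gathered result, not just one device.  
For a scalar loss, `accelerator.gather(loss).mean()` is equivalent to `accelerator.reduce(loss, reduction="mean")`.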
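To see why the two calls agree for a scalar loss, the following sketch imitates their semantics with plain Python lists. It does not call Accelerate: `simulated_gather` and `simulated_reduce_mean` are hypothetical stand-ins, and each inner list plays the role of the tensor held by one process.

```python
def simulated_gather(per_process):
    """Concatenate the values from all processes into one flat list."""
    return [value for values in per_process for value in values]


def simulated_reduce_mean(per_process):
    """Element-wise mean across processes (all must hold the same shape)."""
    n = len(per_process)
    return [sum(column) / n for column in zip(*per_process)]


def mean(values):
    return sum(values) / len(values)


if __name__ == "__main__":
    # One scalar loss per process, as in a two-GPU run.
    losses = [[1.25], [0.75]]
    via_gather = mean(simulated_gather(losses))
    via_reduce = simulated_reduce_mean(losses)[0]
    print(via_gather, via_reduce)
```

The equivalence relies on every process contributing the same number of elements. For non-scalar tensors the two calls differ in shape: gathering concatenates, while reducing keeps the per-process shape.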

In [2]:
# Version 1: basic training loop with Accelerate

'''
Training times:
Single GPU, plain PyTorch: 35.76 seconds
Single GPU, Accelerate: 48.82 seconds
Two GPUs, Accelerate: 29.51 seconds
'''

from accelerate import Accelerator, DeepSpeedPlugin, notebook_launcher
import torch
from torch.utils.data import DataLoader, TensorDataset

import time

class SimpleNet(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # The commented-out .to(...) calls sketch manual model parallelism
        # (one layer per GPU); they are not used here.
        self.fc1 = torch.nn.Linear(input_dim, hidden_dim)  #.to("cuda:0")
        self.fc2 = torch.nn.Linear(hidden_dim, output_dim) #.to("cuda:1")

    def forward(self, x):
        # x = x.to("cuda:0")
        x = torch.relu(self.fc1(x))
        # x = x.to("cuda:1")
        x = self.fc2(x)
        return x

# if __name__ == "__main__":
def main():
    input_dim = 10
    hidden_dim = 20
    output_dim = 2
    batch_size = 64
    data_size = 10000

    # Only needed for the manual (non-Accelerate) path in the commented lines below.
    device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")

    # Random inputs and random targets: the loss is not expected to decrease much.
    input_data = torch.randn(data_size, input_dim)
    labels = torch.randn(data_size, output_dim)

    dataset = TensorDataset(input_data, labels)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    model = SimpleNet(input_dim, hidden_dim, output_dim)
    # model.to(device)

    # deepspeed_plugin = DeepSpeedPlugin(zero_stage=2, gradient_clipping=1.0)
    # accelerator = Accelerator(deepspeed_plugin=deepspeed_plugin)
    accelerator = Accelerator()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.00015)
    criterion = torch.nn.MSELoss()

    # prepare() moves everything to the right device and shards the dataloader.
    model, dataloader, optimizer = accelerator.prepare(model, dataloader, optimizer)

    start_time = time.time()
    for epoch in range(100):
        model.train()
        for batch in dataloader:
            inputs, labels = batch
            # inputs = inputs.to(device)
            # labels = labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            # loss.backward()
            accelerator.backward(loss)
            optimizer.step()
        # Prints the loss of the last batch only, once per process.
        print(f"Epoch {epoch} loss: {loss.item()}")

    end_time = time.time()  # record when training ends
    training_time = end_time - start_time  # total training time
    print(f"Training time: {training_time:.2f} seconds")

    # accelerator.save(model.state_dict(), "model.pth")

notebook_launcher(main, num_processes=1)

Launching training on one GPU.
Epoch 0 loss: 1.1368621587753296
Epoch 1 loss: 1.6605136394500732
Epoch 2 loss: 1.13680100440979
Epoch 3 loss: 1.0083215236663818
Epoch 4 loss: 0.5586246252059937
Epoch 5 loss: 0.5666735172271729
Epoch 6 loss: 0.5323970317840576
Epoch 7 loss: 0.7227011919021606
Epoch 8 loss: 0.9263297319412231
Epoch 9 loss: 0.9665447473526001
Epoch 10 loss: 1.2193164825439453
Epoch 11 loss: 0.8520088195800781
Epoch 12 loss: 1.3043696880340576
Epoch 13 loss: 1.1880245208740234
Epoch 14 loss: 0.989682674407959
Epoch 15 loss: 1.0800554752349854
Epoch 16 loss: 1.2249712944030762
Epoch 17 loss: 1.0276449918746948
Epoch 18 loss: 1.5330866575241089
Epoch 19 loss: 0.9961369037628174
Epoch 20 loss: 1.3767263889312744
Epoch 21 loss: 0.8045156002044678
Epoch 22 loss: 0.7873939871788025
Epoch 23 loss: 0.8252222537994385
Epoch 24 loss: 1.0691642761230469
Epoch 25 loss: 0.761885404586792
Epoch 26 loss: 0.841905951499939
Epoch 27 loss: 0.6713285446166992
Epoch 28 loss: 1.427389621734619

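To run on multiple GPUs, see the shell script in the same directory. The output below comes from a two-GPU run. Each GPU computes its own loss on its own shard of the data, but the gradients are synchronized across all GPUs through collective communication (such as all-reduce), which keeps the model parameters identical on every GPU.

Known issue:  
1. Every loss is printed twice, once per process, and because both processes write to stdout concurrently, two messages sometimes end up on the same line. To fix this, either average the losses across processes before printing, or print only from the main process.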


Epoch 0 loss: 1.222244143486023
Epoch 0 loss: 1.1932578086853027
Epoch 1 loss: 0.9584339261054993
Epoch 1 loss: 1.3205852508544922
Epoch 2 loss: 1.0940974950790405
Epoch 2 loss: 0.9358994960784912
Epoch 3 loss: 1.112207055091858
Epoch 3 loss: 1.2235267162322998
Epoch 4 loss: 1.1149256229400635
Epoch 4 loss: 1.2418652772903442
Epoch 5 loss: 0.9817483425140381
Epoch 5 loss: 1.0360157489776611
Epoch 6 loss: 1.0470519065856934
Epoch 6 loss: 1.1236121654510498
Epoch 7 loss: 0.9477889537811279
Epoch 7 loss: 1.0988385677337646
Epoch 8 loss: 0.9655306935310364
Epoch 8 loss: 0.9430689215660095
Epoch 9 loss: 1.139449119567871
Epoch 9 loss: 0.7671375274658203
Epoch 10 loss: 1.2914409637451172Epoch 10 loss: 1.020603060722351

Epoch 11 loss: 0.9398835897445679
Epoch 11 loss: 1.086294412612915
Epoch 12 loss: 0.941448986530304
Epoch 12 loss: 0.9452714920043945
Epoch 13 loss: 1.0518507957458496Epoch 13 loss: 1.0442299842834473

Epoch 14 loss: 0.8157416582107544Epoch 14 loss: 0.8653008341789246

Epoch 15 loss: 0.93047696352005
Epoch 15 loss: 1.0393348932266235
Epoch 16 loss: 1.037017822265625
Epoch 16 loss: 1.0681655406951904
Epoch 17 loss: 0.7890644669532776Epoch 17 loss: 1.1021982431411743

Epoch 18 loss: 0.959057629108429
Epoch 18 loss: 0.9632008671760559
Epoch 19 loss: 1.2127506732940674
Epoch 19 loss: 1.1701936721801758
Epoch 20 loss: 1.087782859802246
Epoch 20 loss: 1.262906789779663
Epoch 21 loss: 1.0912950038909912Epoch 21 loss: 1.0063831806182861

Epoch 22 loss: 0.8336697816848755Epoch 22 loss: 0.8239030838012695

Epoch 23 loss: 1.1087322235107422
Epoch 23 loss: 1.1554553508758545
Epoch 24 loss: 0.9161040186882019
Epoch 24 loss: 0.9664233922958374
Epoch 25 loss: 1.050710916519165
Epoch 25 loss: 1.0629756450653076
Epoch 26 loss: 0.8758817911148071Epoch 26 loss: 0.9303441047668457

Epoch 27 loss: 0.9692226052284241
Epoch 27 loss: 0.9761572480201721
Epoch 28 loss: 1.2455298900604248
Epoch 28 loss: 1.2066376209259033
Epoch 29 loss: 1.150545597076416
Epoch 29 loss: 1.0455403327941895
Epoch 30 loss: 1.1949223279953003
Epoch 30 loss: 0.8285470008850098
Epoch 31 loss: 0.7168357372283936Epoch 31 loss: 1.0651519298553467

Epoch 32 loss: 0.9221818447113037
Epoch 32 loss: 0.9815448522567749
Epoch 33 loss: 0.8430807590484619
Epoch 33 loss: 0.9851400852203369
Epoch 34 loss: 1.1626412868499756Epoch 34 loss: 0.9436689615249634

Epoch 35 loss: 0.9057785868644714
Epoch 35 loss: 1.016579031944275
Epoch 36 loss: 0.8286615610122681
Epoch 36 loss: 1.0037178993225098
Epoch 37 loss: 0.7945321798324585
Epoch 37 loss: 0.9148077368736267
Epoch 38 loss: 1.1119959354400635
Epoch 38 loss: 0.9336367845535278
Epoch 39 loss: 0.9502816200256348Epoch 39 loss: 0.8942989706993103

Epoch 40 loss: 0.8680436015129089
Epoch 40 loss: 0.9411708116531372
Epoch 41 loss: 0.9502044320106506Epoch 41 loss: 1.1237208843231201

Epoch 42 loss: 1.048546552658081
Epoch 42 loss: 1.1162359714508057
Epoch 43 loss: 1.0711779594421387
Epoch 43 loss: 1.075994849205017
Epoch 44 loss: 0.9065122008323669Epoch 44 loss: 0.9709441661834717

Epoch 45 loss: 0.9690513610839844
Epoch 45 loss: 1.080939531326294
Epoch 46 loss: 0.9084521532058716
Epoch 46 loss: 1.0406299829483032
Epoch 47 loss: 0.9158217906951904Epoch 47 loss: 0.9350024461746216

Epoch 48 loss: 1.0812666416168213
Epoch 48 loss: 1.1399707794189453
Epoch 49 loss: 0.9048585295677185
Epoch 49 loss: 1.142801284790039
Epoch 50 loss: 0.9551778435707092
Epoch 50 loss: 1.0305641889572144
Epoch 51 loss: 1.110548496246338
Epoch 51 loss: 0.8209130764007568
Epoch 52 loss: 0.9793894290924072
Epoch 52 loss: 1.168177843093872
Epoch 53 loss: 1.0132465362548828
Epoch 53 loss: 0.8407228589057922
Epoch 54 loss: 1.002813696861267
Epoch 54 loss: 1.0878620147705078
Epoch 55 loss: 1.015448808670044
Epoch 55 loss: 1.064071774482727
Epoch 56 loss: 1.166907787322998
Epoch 56 loss: 1.0066497325897217
Epoch 57 loss: 1.1030890941619873
Epoch 57 loss: 0.9938492774963379
Epoch 58 loss: 0.9750868082046509Epoch 58 loss: 0.9958294630050659

Epoch 59 loss: 0.9356696009635925
Epoch 59 loss: 1.2875062227249146
Epoch 60 loss: 1.1039798259735107
Epoch 60 loss: 1.227920413017273
Epoch 61 loss: 0.9704052209854126
Epoch 61 loss: 0.9694112539291382
Epoch 62 loss: 1.0148450136184692
Epoch 62 loss: 1.0889440774917603
Epoch 63 loss: 1.0066773891448975
Epoch 63 loss: 1.2679240703582764
Epoch 64 loss: 1.0210902690887451
Epoch 64 loss: 1.22023344039917
Epoch 65 loss: 0.906128466129303
Epoch 65 loss: 1.1522164344787598
Epoch 66 loss: 0.8859699964523315Epoch 66 loss: 0.9274510145187378

Epoch 67 loss: 1.098000168800354
Epoch 67 loss: 1.0197765827178955
Epoch 68 loss: 1.0629165172576904
Epoch 68 loss: 0.9598166346549988
Epoch 69 loss: 1.1767737865447998
Epoch 69 loss: 1.03900945186615
Epoch 70 loss: 0.9126209020614624Epoch 70 loss: 0.8794825673103333

Epoch 71 loss: 0.989602267742157
Epoch 71 loss: 1.0115382671356201
Epoch 72 loss: 0.9187991619110107
Epoch 72 loss: 1.0713469982147217
Epoch 73 loss: 0.9748698472976685Epoch 73 loss: 1.0427227020263672

Epoch 74 loss: 1.2203468084335327
Epoch 74 loss: 0.9615665078163147
Epoch 75 loss: 0.7661041021347046
Epoch 75 loss: 0.9157580137252808
Epoch 76 loss: 0.9729267954826355
Epoch 76 loss: 0.9903103113174438
Epoch 77 loss: 1.032605767250061
Epoch 77 loss: 1.1825000047683716
Epoch 78 loss: 1.0940577983856201
Epoch 78 loss: 0.9936052560806274
Epoch 79 loss: 0.9351733922958374
Epoch 79 loss: 1.0008231401443481
Epoch 80 loss: 0.9143375158309937
Epoch 80 loss: 1.066443920135498
Epoch 81 loss: 1.0062150955200195Epoch 81 loss: 1.0447568893432617

Epoch 82 loss: 0.9024689197540283
Epoch 82 loss: 1.1246912479400635
Epoch 83 loss: 1.2535226345062256
Epoch 83 loss: 0.933487594127655
Epoch 84 loss: 0.99973464012146
Epoch 84 loss: 0.9624868631362915
Epoch 85 loss: 1.0391117334365845Epoch 85 loss: 1.2220211029052734

Epoch 86 loss: 0.8270946741104126Epoch 86 loss: 1.037883996963501

Epoch 87 loss: 1.0317413806915283
Epoch 87 loss: 1.0438865423202515
Epoch 88 loss: 0.8812784552574158Epoch 88 loss: 1.0950701236724854

Epoch 89 loss: 0.8820447325706482
Epoch 89 loss: 0.9369880557060242
Epoch 90 loss: 0.9372445344924927Epoch 90 loss: 1.0238240957260132

Epoch 91 loss: 1.058774709701538Epoch 91 loss: 0.9712309837341309

Epoch 92 loss: 1.00765061378479
Epoch 92 loss: 1.0069276094436646
Epoch 93 loss: 1.2989996671676636Epoch 93 loss: 1.1104843616485596

Epoch 94 loss: 0.867250382900238Epoch 94 loss: 0.8564057350158691

Epoch 95 loss: 1.0129553079605103
Epoch 95 loss: 1.141326665878296
Epoch 96 loss: 1.204880714416504Epoch 96 loss: 0.9655894637107849

Epoch 97 loss: 0.9894803166389465Epoch 97 loss: 1.0592126846313477

Epoch 98 loss: 0.9539648294448853Epoch 98 loss: 0.9008721113204956

Epoch 99 loss: 1.0199759006500244
Training time: 29.51 seconds
Epoch 99 loss: 1.078111171722412
Training time: 29.51 seconds
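The timings reported in these notes (35.76 s for plain single-GPU PyTorch, 48.82 s for single-GPU Accelerate, 29.51 s for two-GPU Accelerate) can be turned into speedup figures with a little arithmetic. The sketch below also estimates how many optimizer steps each process takes per epoch, assuming each process draws whole batches of 64 samples and a short final round is padded so that all processes take the same number of steps. The helper names are hypothetical.

```python
import math


def speedup(baseline_seconds, new_seconds):
    """How many times faster the new run is than the baseline."""
    return baseline_seconds / new_seconds


def steps_per_epoch(data_size, batch_size, num_processes):
    """Optimizer steps per process per epoch under the stated assumption."""
    total_batches = math.ceil(data_size / batch_size)
    return math.ceil(total_batches / num_processes)


if __name__ == "__main__":
    plain, accel_1gpu, accel_2gpu = 35.76, 48.82, 29.51
    print(f"2 GPUs vs plain PyTorch:    {speedup(plain, accel_2gpu):.2f}x")
    print(f"2 GPUs vs 1-GPU Accelerate: {speedup(accel_1gpu, accel_2gpu):.2f}x")
    print(f"steps per epoch, 1 process:   {steps_per_epoch(10000, 64, 1)}")
    print(f"steps per epoch, 2 processes: {steps_per_epoch(10000, 64, 2)}")
```

Halving the number of steps does not halve the wall-clock time here. One plausible reason is that the model is tiny, so per-step overhead and gradient communication account for a large share of each step.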

In [None]:
# Version 2: aggregate the loss across processes and print only from the main process

'''
Training times:
Single GPU, plain PyTorch: 35.76 seconds
Single GPU, Accelerate: 48.82 seconds
Two GPUs, Accelerate: 29.51 seconds
'''

from accelerate import Accelerator, DeepSpeedPlugin
import torch
from torch.utils.data import DataLoader, TensorDataset

import time

class SimpleNet(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = torch.nn.Linear(input_dim, hidden_dim)  #.to("cuda:0")
        self.fc2 = torch.nn.Linear(hidden_dim, output_dim) #.to("cuda:1")

    def forward(self, x):
        # x = x.to("cuda:0")
        x = torch.relu(self.fc1(x))
        # x = x.to("cuda:1")
        x = self.fc2(x)
        return x

if __name__ == "__main__":
    input_dim = 10
    hidden_dim = 20
    output_dim = 2
    batch_size = 64
    data_size = 10000

    # Only needed for the manual (non-Accelerate) path in the commented lines below.
    device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")

    input_data = torch.randn(data_size, input_dim)
    labels = torch.randn(data_size, output_dim)

    dataset = TensorDataset(input_data, labels)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    model = SimpleNet(input_dim, hidden_dim, output_dim)
    # model.to(device)

    # deepspeed_plugin = DeepSpeedPlugin(zero_stage=2, gradient_clipping=1.0)
    # accelerator = Accelerator(deepspeed_plugin=deepspeed_plugin)
    accelerator = Accelerator()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.00015)
    criterion = torch.nn.MSELoss()
    # print(f'len(dataloader):{len(dataloader)}')
    model, dataloader, optimizer = accelerator.prepare(model, dataloader, optimizer)
    # print(f'len(dataloader):{len(dataloader)}')
    start_time = time.time()
    for epoch in range(100):
        model.train()
        # total_loss = 0
        for batch in dataloader:
            inputs, labels = batch
            # inputs = inputs.to(device)
            # labels = labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            # total_loss += loss.item()
            optimizer.zero_grad()
            # loss.backward()
            accelerator.backward(loss)
            optimizer.step()

        # Note: `loss` is only the last batch's loss, not the average over the epoch.

        # The two approaches below give the same result.
        gather_loss = accelerator.gather(loss).mean()  # collect the loss from every GPU, then average
        if accelerator.is_main_process:
            print(f"Epoch {epoch} loss: {gather_loss.item()}")

        # avg_loss = accelerator.reduce(loss, reduction="mean")
        # if accelerator.is_main_process:
        #     print(f"Epoch {epoch} loss: {avg_loss.item()}")

    end_time = time.time()  # record when training ends
    training_time = end_time - start_time  # total training time
    if accelerator.is_main_process:
        print(f"Training time: {training_time:.2f} seconds")

    # accelerator.save(model.state_dict(), "model.pth")
            
    
    

Epoch 0 loss: 0.9432018399238586
Epoch 1 loss: 0.9678932428359985
Epoch 2 loss: 1.0469319820404053
Epoch 3 loss: 1.0713804960250854
Epoch 4 loss: 0.9084330797195435
Epoch 5 loss: 0.916047215461731
Epoch 6 loss: 1.0488253831863403
Epoch 7 loss: 1.066176176071167
Epoch 8 loss: 1.2106516361236572
Epoch 9 loss: 0.8363816738128662
Epoch 10 loss: 1.0502277612686157
Epoch 11 loss: 1.1638139486312866
Epoch 12 loss: 0.9493512511253357
Epoch 13 loss: 1.017587661743164
Epoch 14 loss: 0.9472472667694092
Epoch 15 loss: 0.9615411162376404
Epoch 16 loss: 0.990830659866333
Epoch 17 loss: 0.8506342172622681
Epoch 18 loss: 1.0821845531463623
Epoch 19 loss: 1.1170412302017212
Epoch 20 loss: 1.044736623764038
Epoch 21 loss: 0.858221709728241
Epoch 22 loss: 0.8285454511642456
Epoch 23 loss: 0.8669906258583069
Epoch 24 loss: 1.1118721961975098
Epoch 25 loss: 1.1339545249938965
Epoch 26 loss: 1.152949571609497
Epoch 27 loss: 0.9197186827659607
Epoch 28 loss: 1.0091302394866943
Epoch 29 loss: 1.1065753698349
Epoch 30 loss: 0.9354009032249451
Epoch 31 loss: 0.9681801795959473
Epoch 32 loss: 0.934136152267456
Epoch 33 loss: 1.1814976930618286
Epoch 34 loss: 1.1219456195831299
Epoch 35 loss: 0.9535113573074341
Epoch 36 loss: 1.0602713823318481
Epoch 37 loss: 0.8352175951004028
Epoch 38 loss: 0.837505578994751
Epoch 39 loss: 0.9966219067573547
Epoch 40 loss: 0.9990330338478088
Epoch 41 loss: 0.902934193611145
Epoch 42 loss: 0.9071310758590698
Epoch 43 loss: 0.9718709588050842
Epoch 44 loss: 1.1083896160125732
Epoch 45 loss: 1.0436538457870483
Epoch 46 loss: 0.9891265630722046
Epoch 47 loss: 1.0118757486343384
Epoch 48 loss: 1.0493839979171753
Epoch 49 loss: 0.9719559550285339
Epoch 50 loss: 0.9746475219726562
Epoch 51 loss: 0.8756150603294373
Epoch 52 loss: 1.1143317222595215
Epoch 53 loss: 1.0076954364776611
Epoch 54 loss: 1.0678539276123047
Epoch 55 loss: 1.098946452140808
Epoch 56 loss: 1.0479623079299927
Epoch 57 loss: 0.8356451988220215
Epoch 58 loss: 1.066871166229248
Epoch 59 loss: 1.000270128250122
Epoch 60 loss: 1.0957651138305664
Epoch 61 loss: 0.9438618421554565
Epoch 62 loss: 1.0931410789489746
Epoch 63 loss: 0.8018264770507812
Epoch 64 loss: 0.8660523891448975
Epoch 65 loss: 1.01133131980896
Epoch 66 loss: 0.9543184041976929
Epoch 67 loss: 0.9269616603851318
Epoch 68 loss: 0.9933525323867798
Epoch 69 loss: 1.0068784952163696
Epoch 70 loss: 1.072657823562622
Epoch 71 loss: 1.026125192642212
Epoch 72 loss: 0.9402647614479065
Epoch 73 loss: 1.0038683414459229
Epoch 74 loss: 0.8630919456481934
Epoch 75 loss: 1.0004982948303223
Epoch 76 loss: 0.9028060436248779
Epoch 77 loss: 1.135162353515625
Epoch 78 loss: 1.0471564531326294
Epoch 79 loss: 1.036384105682373
Epoch 80 loss: 0.9773803949356079
Epoch 81 loss: 0.8492220640182495
Epoch 82 loss: 1.0323593616485596
Epoch 83 loss: 1.0219027996063232
Epoch 84 loss: 1.177842140197754
Epoch 85 loss: 1.1461409330368042
Epoch 86 loss: 0.9765742421150208
Epoch 87 loss: 0.9133744239807129
Epoch 88 loss: 0.8916490077972412
Epoch 89 loss: 0.9485730528831482
Epoch 90 loss: 1.1236549615859985
Epoch 91 loss: 0.9537147283554077
Epoch 92 loss: 0.9323829412460327
Epoch 93 loss: 0.8480007648468018
Epoch 94 loss: 0.922979474067688
Epoch 95 loss: 0.8562111854553223
Epoch 96 loss: 0.8881111145019531
Epoch 97 loss: 1.0885865688323975
Epoch 98 loss: 1.0496211051940918
Epoch 99 loss: 1.0826056003570557
Training time: 33.32 seconds
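The printed value above is the loss of the last batch in each epoch, which is noisy. One possible refinement is to track a sample-weighted mean over all batches in the epoch. The `EpochLossMeter` class below is a hypothetical helper written for this note, not part of Accelerate. The weighting matters because the last batch is smaller (10000 samples with batch size 64 leave a final batch of 16).

```python
class EpochLossMeter:
    """Accumulates a sample-weighted mean of per-batch losses."""

    def __init__(self):
        self.total = 0.0   # sum of (batch loss * batch size)
        self.count = 0     # number of samples seen so far

    def update(self, batch_loss, batch_size):
        """Add one batch; batch_loss is the mean loss over that batch."""
        self.total += batch_loss * batch_size
        self.count += batch_size

    @property
    def mean(self):
        """Mean loss per sample, or NaN if nothing was recorded."""
        return self.total / self.count if self.count else float("nan")


if __name__ == "__main__":
    meter = EpochLossMeter()
    meter.update(1.0, 64)   # a full batch
    meter.update(2.0, 16)   # a smaller final batch
    print(f"epoch mean loss: {meter.mean:.4f}")
```

In the training loop this would be called as `meter.update(loss.item(), inputs.size(0))` after each batch. In a multi-process run, `total` and `count` would additionally need to be summed across processes (for example with `accelerator.reduce(..., reduction="sum")` on tensors holding them) before dividing.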

In [None]:
# Version 3: using DeepSpeed through Accelerate

'''
Training times:
Single GPU, plain PyTorch: 35.76 seconds
Single GPU, Accelerate: 48.82 seconds
Two GPUs, Accelerate: 29.51 seconds
'''

from accelerate import Accelerator, DeepSpeedPlugin
import torch
from torch.utils.data import DataLoader, TensorDataset

import time

class SimpleNet(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = torch.nn.Linear(input_dim, hidden_dim)  #.to("cuda:0")
        self.fc2 = torch.nn.Linear(hidden_dim, output_dim) #.to("cuda:1")

    def forward(self, x):
        # x = x.to("cuda:0")
        x = torch.relu(self.fc1(x))
        # x = x.to("cuda:1")
        x = self.fc2(x)
        return x

if __name__ == "__main__":
    input_dim = 10
    hidden_dim = 20
    output_dim = 2
    batch_size = 64
    data_size = 10000

    # Only needed for the manual (non-Accelerate) path in the commented lines below.
    device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")

    input_data = torch.randn(data_size, input_dim)
    labels = torch.randn(data_size, output_dim)

    dataset = TensorDataset(input_data, labels)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    model = SimpleNet(input_dim, hidden_dim, output_dim)
    # model.to(device)

    # Enable DeepSpeed ZeRO stage 2 with gradient clipping at 1.0.
    deepspeed_plugin = DeepSpeedPlugin(zero_stage=2, gradient_clipping=1.0)
    accelerator = Accelerator(deepspeed_plugin=deepspeed_plugin)
    # accelerator = Accelerator()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.00015)
    criterion = torch.nn.MSELoss()
    # print(f'len(dataloader):{len(dataloader)}')
    model, dataloader, optimizer = accelerator.prepare(model, dataloader, optimizer)
    # print(f'len(dataloader):{len(dataloader)}')
    start_time = time.time()
    for epoch in range(100):
        model.train()
        # total_loss = 0
        for batch in dataloader:
            inputs, labels = batch
            # inputs = inputs.to(device)
            # labels = labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            # total_loss += loss.item()
            optimizer.zero_grad()
            # loss.backward()
            accelerator.backward(loss)
            optimizer.step()

        # Note: `loss` is only the last batch's loss, not the average over the epoch.

        # The two approaches below give the same result.
        gather_loss = accelerator.gather(loss).mean()  # collect the loss from every GPU, then average
        if accelerator.is_main_process:
            print(f"Epoch {epoch} loss: {gather_loss.item()}")

        # avg_loss = accelerator.reduce(loss, reduction="mean")
        # if accelerator.is_main_process:
        #     print(f"Epoch {epoch} loss: {avg_loss.item()}")

    end_time = time.time()  # record when training ends
    training_time = end_time - start_time  # total training time
    if accelerator.is_main_process:
        print(f"Training time: {training_time:.2f} seconds")

    # accelerator.save(model.state_dict(), "model.pth")
            
    
    