根据A10G和L40S上的实测结果，GPT2-6B在sequence length为4096时，对数据进行不均匀切分可以得到比均匀切分更好的结果。

In [1]:
import torch 
import json
from typing import Dict, List

In [2]:
# Load the per-layer compute times for every (device, batch size) pair into
# model_compute[device][bs]; each entry is the list stored under
# data["execution_time"]["layer_compute_total_ms"] in the profile JSON.
model_compute = {}
for device in ["A10G", "L40S"]:
    model_compute.setdefault(device, {})
    # NOTE(review): only batch size 1 is loaded here, but a later cell calls
    # get_iteration_time_per_mbs with batch sizes 2 and 3, which look up the
    # batch-size-2 entry. Confirm whether 2 should be added to this list.
    for bs in [1]:
        file = f"../metis/GPT_2-6B/4096/DeviceType.{device}_tp2_bs{bs}.json"
        try:
            with open(file, 'r', encoding="utf-8") as f:
                data = json.load(f)
        except FileNotFoundError:
            # Name the missing file so the failure is actionable, and stop via
            # SystemExit rather than the interactive-only exit() helper.
            print(f"The file was not found: {file}")
            raise SystemExit(1)
        except json.JSONDecodeError as err:
            print(f"The file does not contain valid JSON: {file} ({err})")
            raise SystemExit(1)
        model_compute[device][bs] = data["execution_time"]["layer_compute_total_ms"]

print(model_compute)
    

{'A10G': {1: [36.14472167968749, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 64.28416015625]}, 'L40S': {1: [11.615805053710938, 34.540380859375, 34.540380859375, 34.540380859375, 34.540380859375, 34.540380859375, 34.540380859375, 34.540380859375, 34.540380859375, 34.540380859375, 34.540380859375, 34.540380859375, 34.540380859375, 34.540380859375, 34.540380859375, 34.540380859375, 34.540380859375, 34.540380859375, 34.540380859375, 34.540380859375, 34.540380859375, 

In [9]:
# Per-layer compute times stored under batch size 1 for each device.
a10g_layer_ms = model_compute['A10G'][1]
print(f"A10G: {a10g_layer_ms}")
num_layers = len(a10g_layer_ms)
l40s_layer_ms = model_compute['L40S'][1]
print(f"L40S: {l40s_layer_ms}")
print(f"num_layers: {num_layers}")

# Layer index at which the model is cut into two pipeline stages: the L40S
# takes the layers before the cut, the A10G takes the rest.
even_cut = num_layers // 2
skewed_cut = num_layers // 4 * 3

# Even split: each device gets half of the layers.
balance_stage_time = [
    sum(l40s_layer_ms[:even_cut]),
    sum(a10g_layer_ms[even_cut:]),
]

# Skewed split: the L40S gets roughly three quarters of the layers.
imbalance_stage_time = [
    sum(l40s_layer_ms[:skewed_cut]),
    sum(a10g_layer_ms[skewed_cut:]),
]
print(f"balance: {balance_stage_time}, imbalance: {imbalance_stage_time}")


A10G: [36.14472167968749, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 64.28416015625]
L40S: [11.615805053710938, 34.540380859375, 34.540380859375, 34.540380859375, 34.540380859375, 34.540380859375, 34.540380859375, 34.540380859375, 34.540380859375, 34.540380859375, 34.540380859375, 34.540380859375, 34.540380859375, 34.540380859375, 34.540380859375, 34.540380859375, 34.540380859375, 34.540380859375, 34.540380859375, 34.540380859375, 34.540380859375, 34.540380859375

In [10]:
def pp_iteration_time(stage_time_ms:List[float], nums_of_mbs: int) -> float:
    """Estimate the time of one pipeline-parallel iteration, in seconds.

    Args:
        stage_time_ms: Per-stage time in milliseconds, in pipeline order.
        nums_of_mbs: Number of micro-batches in the iteration.

    Returns:
        The estimated iteration time converted from milliseconds to seconds.
    """
    stage_count = len(stage_time_ms)

    # Index of the slowest stage; on ties the earliest stage wins.
    slowest_idx = max(range(stage_count), key=stage_time_ms.__getitem__)
    slowest_ms = stage_time_ms[slowest_idx]

    # Stages located after the bottleneck stage.
    trailing_stages = stage_count - slowest_idx - 1

    # Micro-batches paced by the bottleneck stage.
    bottleneck_term = (nums_of_mbs - stage_count + slowest_idx) * slowest_ms
    # Each trailing stage contributes two thirds of the bottleneck time.
    # NOTE(review): the 2/3 weight is a modelling constant; its derivation is
    # not visible here -- confirm against the pipeline schedule being modelled.
    trailing_term = trailing_stages * (2 * slowest_ms / 3)

    steady_ms = bottleneck_term + trailing_term
    overall_ms = steady_ms + sum(stage_time_ms)
    return overall_ms / 1000
    
    

In [11]:
# Global batch size and micro-batch size; their ratio is the number of
# micro-batches pushed through the pipeline per iteration.
gbs = 1024
mbs = 1
nums_of_mbs = gbs / mbs

# Estimate the iteration time for the even and the skewed layer split.
balance_time, imbalance_time = (
    pp_iteration_time(stage_times, nums_of_mbs)
    for stage_times in (balance_stage_time, imbalance_stage_time)
)
print(f"balance_time: {balance_time}, imbalance_time: {imbalance_time}")

balance_time: 1976.606121898803, imbalance_time: 1141.1288945648191


读取数据后，计算不同配置下的运行时间

In [23]:

def split_to_powers_of_two(n:int) -> List[int]:
    """Decompose *n* into distinct powers of two, largest first.

    The returned values sum to *n* (its binary expansion). A non-positive
    *n* yields an empty list.
    """
    if n <= 0:
        return []
    # Walk the bits from the most significant one down and keep each set bit.
    top_bit = n.bit_length() - 1
    return [
        1 << shift
        for shift in range(top_bit, -1, -1)
        if (n >> shift) & 1
    ]

def get_iteration_time_per_mbs(model_compute: Dict[str, Dict[int, List[float]]], device: str, bs: int) -> float:
    """Return the time one device needs to process a batch of size *bs*, in seconds.

    The batch is split into power-of-two chunks (via split_to_powers_of_two)
    and the summed per-layer times of each chunk size are added up.

    Args:
        model_compute: Mapping ``device -> batch size -> per-layer times``;
            the per-layer times are in milliseconds.
        device: Device name used as the first-level key of *model_compute*.
        bs: Batch size assigned to the device.

    Returns:
        Total time in seconds (0.0 when *bs* is not positive).

    Raises:
        KeyError: If no data exists for *device* or for one of the
            power-of-two chunk sizes.
    """
    total_time = 0.0
    # Use a distinct loop name so the `bs` parameter is not overwritten.
    for chunk_bs in split_to_powers_of_two(bs):
        try:
            layer_times_ms = model_compute[device][chunk_bs]
        except KeyError as err:
            raise KeyError(
                f"no profiled data for device {device!r} at batch size {chunk_bs}"
            ) from err
        # Milliseconds -> seconds.
        total_time += sum(layer_times_ms) / 1000

    return total_time
    




In [24]:
gbs = 1024
mbs = 4

# Even data split: both devices process 2 samples of each micro-batch, and the
# slower device sets the time per micro-batch.
balance_time_per_mbs = max(
    get_iteration_time_per_mbs(model_compute, device, 2)
    for device in ("A10G", "L40S")
)
# Skewed data split: the A10G processes 1 sample and the L40S processes 3.
imbalance_time_per_mbs = max(
    get_iteration_time_per_mbs(model_compute, device, device_bs)
    for device, device_bs in (("A10G", 1), ("L40S", 3))
)

# Scale the per-micro-batch time by the number of micro-batches (gbs / mbs).
balance_time = balance_time_per_mbs * gbs / mbs
imbalance_time = imbalance_time_per_mbs * gbs / mbs
print(f"balance: {balance_time}, imbalance: {imbalance_time}")

balance: 1969.4212168749982, imbalance: 980.8172337500004
