根据A10G和L40S的结果，GPT2-6B在sequence length为4096时，对数据进行不均匀切分能得到比均匀切分更好的结果。

In [21]:
import torch 
import json
from typing import Dict, List

In [22]:
# Read the profiling results for every (device, batch size) pair.
# Resulting layout: model_compute[device][bs] holds the value stored under
# data["execution_time"]["layer_compute_total_ms"] in that pair's JSON file.
model_compute = {}
for device in ["A10G", "L40S"]:
    model_compute.setdefault(device, {})
    for bs in [1, 2, 4]:
        file = f"../metis/GPT_2-6B/4096/DeviceType.{device}_tp2_bs{bs}.json"
        try:
            with open(file, 'r') as f:
                data = json.load(f)
        except FileNotFoundError:
            # SystemExit with a message stops with a non-zero status and
            # tells the reader which profile is missing; the interactive
            # exit() helper reported success and gave no file name.
            raise SystemExit(f"The file was not found: {file}")
        except json.JSONDecodeError:
            raise SystemExit(f"The file does not contain valid JSON: {file}")
        model_compute[device][bs] = data["execution_time"]["layer_compute_total_ms"]

print(model_compute)
    

{'A10G': {1: [36.14472167968749, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 116.590263671875, 64.28416015625], 2: [74.28806396484374, 234.081806640625, 234.081806640625, 234.081806640625, 234.081806640625, 234.081806640625, 234.081806640625, 234.081806640625, 234.081806640625, 234.081806640625, 234.081806640625, 234.081806640625, 234.081806640625, 234.081806640625, 234.081806640625, 234.081806640625, 234.081806640625, 234.081806640625, 234.081806640625, 234.081806640625, 234.08180

读取数据后，计算不同配置下的运行时间

In [23]:

def split_to_powers_of_two(n:int) -> List[int]:
    """Decompose ``n`` into distinct powers of two, largest first.

    The returned values sum to ``n`` (e.g. 13 -> [8, 4, 1]).
    Zero and negative inputs yield an empty list.
    """
    if n <= 0:
        return []
    # Each set bit of n is one power of two; walk from the top bit downwards
    # so the result is ordered from the largest power to the smallest.
    top_bit = n.bit_length() - 1
    return [1 << bit for bit in range(top_bit, -1, -1) if (n >> bit) & 1]

def get_iteration_time_per_mbs(model_compute: Dict[str, Dict[int, List[float]]], device: str, bs: int) -> float:
    """Return the compute time for a batch of size ``bs`` on ``device``, in seconds.

    ``bs`` is split into powers of two with ``split_to_powers_of_two``; for
    each piece the values in ``model_compute[device][piece]`` are summed and
    divided by 1000, and the per-piece results are added up.

    Args:
        model_compute: Mapping ``device -> batch size -> list of times``.
            The values are presumably per-layer times in milliseconds
            (hence the division by 1000) -- confirm against the data source.
        device: Key into ``model_compute``.
        bs: Batch size to estimate; need not be a power of two.

    Returns:
        The accumulated time as a float (``0`` when ``bs`` splits into no
        pieces, i.e. ``bs <= 0``).

    Raises:
        KeyError: If ``device`` or one of the power-of-two pieces is not
            present in ``model_compute``.
    """
    total_time = 0
    # Use a distinct loop name so the ``bs`` argument is not overwritten.
    for sub_bs in split_to_powers_of_two(bs):
        total_time += sum(model_compute[device][sub_bs]) / 1000

    return total_time
    




In [24]:
# Balanced split: both devices get a batch of 2. The larger (slower) of the
# two per-device times is taken.
balance_time_per_mbs = max(
    get_iteration_time_per_mbs(model_compute, name, size)
    for name, size in (("A10G", 2), ("L40S", 2))
)
# Imbalanced split: A10G gets 1 and L40S gets 3 (same total of 4).
imbalance_time_per_mbs = max(
    get_iteration_time_per_mbs(model_compute, name, size)
    for name, size in (("A10G", 1), ("L40S", 3))
)

# NOTE(review): gbs / mbs presumably stand for global / micro batch size, so
# gbs / mbs is the number of micro-batches per iteration -- confirm.
gbs = 1024
mbs = 4
balance_time = balance_time_per_mbs * gbs / mbs
imbalance_time = imbalance_time_per_mbs * gbs / mbs
print(f"balance: {balance_time}, imbalance: {imbalance_time}")

balance: 1969.4212168749982, imbalance: 980.8172337500004
