估算WSE的计算延迟。
Core allocation：对于每个operator，假设其输入输出同时需要放在分配到的核上，把给定的wafer吃满，按照core数去分配。我们提前给一个合理的wafer setup。

In [None]:
from math import sqrt

In [None]:
# 
# Hardware description consumed by the estimators in this notebook:
#   num_mac            - per-core divisor applied to an operator's compute cost
#   noc_bandwidth      - per-core divisor applied to an operator's traffic
#   memory_size        - per-core memory capacity; compared against tensor
#                        sizes that are expressed in bytes
#   core_array_size    - presumably the number of cores per reticle (TODO confirm)
#   reticle_array_size - presumably the number of reticles per wafer (TODO confirm)
arch_config = dict(
    num_mac=4,
    noc_bandwidth=4,
    memory_size=48 * 1024,
    core_array_size=66 * 154,
    reticle_array_size=12 * 8,
)
# Total number of cores: core array size times reticle array size.
TOTAL_WSE_CORES = arch_config["core_array_size"] * arch_config["reticle_array_size"]

In [None]:
import onnx
from itertools import chain
from functools import reduce

# Bytes per tensor element; 2 corresponds to BF16.
PRECISION = 2
# Load the model graph from the working directory.
onnx_model = onnx.load('gpt2.onnx')


In [None]:
# Estimate the model's parameter footprint in bytes.
# Graph inputs whose name contains "onnx::" are skipped; every other graph
# input is counted as a weight at PRECISION bytes per element.
# NOTE(review): this presumably relies on the model having been exported with
# its weights listed as graph inputs - confirm against the export settings.
total_weight_size = 0
for val in onnx_model.graph.input:
    if "onnx::" in val.name:
        continue
    shape = [d.dim_value for d in val.type.tensor_type.shape.dim]
    # The initial value 1 makes a rank-0 (scalar) input count as one element
    # instead of crashing reduce() on an empty sequence.
    total_weight_size += reduce(lambda x, y: x * y, shape, 1) * PRECISION
print(f"Total weight size: {total_weight_size / (1024 * 1024 * 1024)} GB")

In [None]:
# Estimate the storage each operator needs.
# An operator's footprint could be taken as the sum of all its inputs and
# outputs; assuming fine-grained intra-layer streaming, it is taken here to be
# the size of its outputs only.
# Every tensor is costed at PRECISION bytes per element, whatever data type it
# declares.
tensor_name_2_shape = {}
tensor_name_2_size = {}
for val in chain(onnx_model.graph.input, onnx_model.graph.value_info, onnx_model.graph.output):
    shape = [d.dim_value for d in val.type.tensor_type.shape.dim]
    tensor_name_2_shape[val.name] = shape
    # A rank-0 shape counts as a single element (reduce's initial value).
    tensor_name_2_size[val.name] = reduce(lambda x, y: x * y, shape, 1) * PRECISION
for val in onnx_model.graph.initializer:
    # Stored as a plain list so every entry of tensor_name_2_shape has the
    # same type.
    shape = list(val.dims)
    tensor_name_2_shape[val.name] = shape
    tensor_name_2_size[val.name] = reduce(lambda x, y: x * y, shape, 1) * PRECISION

# NOTE(review): results are keyed by node name, so this assumes every node has
# a unique, non-empty name - confirm for the model being analysed.
op_2_memory_consumption = {}
for op_proto in onnx_model.graph.node:
    # Only one copy of each output is counted; inputs are deliberately left
    # out. An empty name marks an omitted optional output and owns no storage.
    op_2_memory_consumption[op_proto.name] = sum(
        tensor_name_2_size[tensor] for tensor in op_proto.output if tensor
    )

# sum() also copes with a graph that has no nodes.
total_memory = sum(op_2_memory_consumption.values())
print(f"Total memory: {total_memory / (1024 * 1024 * 1024)} GB")

In [None]:
# Estimate the communication cost of every operator.
# Only inputs listed in value_info are counted; these are regarded as the
# tensors that need inter-layer transfer.
# The result is a traffic volume in bytes; it has not been divided by any
# bandwidth yet.
INTERMEDIATE_TENSOR_TYPE = 1
WEIGHT_TENSOR_TYPE = 2
# Later updates win: a name that is also a graph input, graph output or
# initializer ends up classified as a weight.
tensor_name_2_type = {val.name: INTERMEDIATE_TENSOR_TYPE for val in onnx_model.graph.value_info}
tensor_name_2_type.update({val.name: WEIGHT_TENSOR_TYPE for val in onnx_model.graph.input})
tensor_name_2_type.update({val.name: WEIGHT_TENSOR_TYPE for val in onnx_model.graph.output})
tensor_name_2_type.update({val.name: WEIGHT_TENSOR_TYPE for val in onnx_model.graph.initializer})

op_2_comm_cost = {}
for op_proto in onnx_model.graph.node:
    # An empty input name marks an omitted optional input; it refers to no
    # tensor, so it is skipped instead of being looked up.
    op_2_comm_cost[op_proto.name] = sum(
        tensor_name_2_size[tensor]
        for tensor in op_proto.input
        if tensor and tensor_name_2_type[tensor] == INTERMEDIATE_TENSOR_TYPE
    )


In [None]:
# Estimate the compute cost of every operator.
# Only the total amount of computation is considered: look at which operator
# types occur and count the compute-heavy ones; everything else costs 0.
def _num_elements(shape):
    """Return the number of elements described by ``shape`` (1 for rank 0)."""
    return reduce(lambda x, y: x * y, shape, 1)


def _broadcast_shape(a_shape, b_shape):
    """Return the right-aligned (NumPy-style) broadcast of two shapes."""
    rank = max(len(a_shape), len(b_shape))
    # Left-pad the shorter shape with 1s so trailing dimensions line up.
    a_padded = [1] * (rank - len(a_shape)) + list(a_shape)
    b_padded = [1] * (rank - len(b_shape)) + list(b_shape)
    return [max(i, j) for i, j in zip(a_padded, b_padded)]


def get_compute_cost(op_proto):
    """Return the estimated operation count of one graph node.

    Shapes are read from the module-level ``tensor_name_2_shape`` mapping.

    Args:
        op_proto: a node exposing ``op_type``, ``input``, ``output`` and
            ``attribute`` (each attribute having ``name`` and ``i``).

    Returns:
        * ``Conv``: ``prod(x_shape) * prod(w_shape) // x_shape[1]`` MACs, plus
          one add per output element when a bias input is present.
        * ``Gemm``: ``M * K * N`` MACs, plus ``M * N`` adds when the optional
          ``C`` input is present; ``transA`` / ``transB`` are honoured.
        * ``Add`` / ``Sub`` / ``Mul`` / ``Div``: the element count of the
          broadcast result.
        * any other operator type: 0.

    Raises:
        KeyError: if a referenced tensor has no entry in
            ``tensor_name_2_shape``.
    """
    # NOTE(review): MatMul is not modelled and therefore costs 0 - confirm
    # that this is intended for the model being analysed.
    op_type = op_proto.op_type
    if op_type == "Conv":
        inputs = list(op_proto.input)
        x_shape = tensor_name_2_shape[inputs[0]]
        w_shape = tensor_name_2_shape[inputs[1]]
        # The Conv MAC formula uses the input extent, not the output extent.
        total_macs = _num_elements(x_shape) * _num_elements(w_shape) // x_shape[1]
        # The bias is an optional third input; an empty name means "omitted".
        if len(inputs) > 2 and inputs[2]:
            # output is a repeated field, so take its first entry; the bias
            # costs one add per output element.
            y_shape = tensor_name_2_shape[op_proto.output[0]]
            total_macs += _num_elements(y_shape)
        return total_macs

    elif op_type == "Gemm":
        inputs = list(op_proto.input)
        a_shape = tensor_name_2_shape[inputs[0]]
        b_shape = tensor_name_2_shape[inputs[1]]
        flags = {attr.name: attr.i for attr in op_proto.attribute}
        # A is (M, K), or (K, M) when transA is set.
        if flags.get("transA", 0):
            m, k = a_shape[1], a_shape[0]
        else:
            m, k = a_shape[0], a_shape[1]
        # B is (K, N), or (N, K) when transB is set.
        n = b_shape[0] if flags.get("transB", 0) else b_shape[1]
        has_bias = len(inputs) > 2 and bool(inputs[2])
        # m * (k + 1) * n == m * k * n MACs + m * n bias adds.
        return m * (k + 1) * n if has_bias else m * k * n

    elif op_type in ["Add", "Sub", "Mul", "Div"]:
        a, b = op_proto.input
        a_shape, b_shape = tensor_name_2_shape[a], tensor_name_2_shape[b]
        # One operation per element of the broadcast result; scalars (rank 0)
        # broadcast against anything, and scalar-with-scalar yields 1.
        return _num_elements(_broadcast_shape(a_shape, b_shape))

    else:
        return 0

# Map every node name to its estimated compute cost.
op_2_compute_cost = {
    node.name: get_compute_cost(node)
    for node in onnx_model.graph.node
}
# Debug aid: print(op_2_compute_cost)

In [None]:
# Core allocation:
# 1. Operators that are not compute-bound get the fewest cores whose memory
#    can hold their footprint.
# 2. Compute-bound operators then share all remaining cores in proportion to
#    their compute cost.

compute_bounded_op_type = [
    'Gemm',
]

op_2_core_alloc = {
    # Ceiling division: a partially filled core is still a whole core, so any
    # non-zero footprint needs at least one core.
    op.name: -(-op_2_memory_consumption[op.name] // arch_config['memory_size'])
    for op in onnx_model.graph.node
    if op.op_type not in compute_bounded_op_type
}
# sum() also copes with the case where no operator falls in a category.
mem_bounded_total_core = sum(op_2_core_alloc.values())
compute_bounded_ops = [op for op in onnx_model.graph.node if op.op_type in compute_bounded_op_type]
total_compute_bounded_cost = sum(op_2_compute_cost[op.name] for op in compute_bounded_ops)
cur_total_core = TOTAL_WSE_CORES - mem_bounded_total_core
if cur_total_core <= 0:
    # Raised explicitly so the check also runs when assertions are disabled.
    raise ValueError(
        f"No cores left for compute-bound operators: {cur_total_core} remaining "
        f"of {TOTAL_WSE_CORES} after the memory-bound allocation"
    )
op_2_core_alloc.update({
    # Proportional share, truncated to whole cores; if the compute-bound
    # operators have no cost at all there is nothing to apportion.
    op.name: (
        int((op_2_compute_cost[op.name] / total_compute_bounded_cost) * cur_total_core)
        if total_compute_bounded_cost else 0
    )
    for op in compute_bounded_ops
})
print(op_2_core_alloc)

In [None]:
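# Alternative core allocation (disabled):
# allocate cores purely in proportion to compute cost, ignoring whether an
# operator's data fits in the memory of the cores it receives.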

# total_compute_cost = reduce(lambda x, y: x + y, list(op_2_compute_cost.values()))

# op_2_core_alloc = {
#     op.name: int(op_2_compute_cost[op.name] / total_compute_cost * TOTAL_WSE_CORES)
#     for op in onnx_model.graph.node
# }
# print(op_2_core_alloc)

In [None]:
def get_comm_latency(op):
    """Return the estimated communication latency of ``op``.

    The operator's traffic (``op_2_comm_cost``) is floor-divided by the
    aggregate bandwidth of its cores, i.e. ``noc_bandwidth`` times the number
    of cores in ``op_2_core_alloc``. An operator with no cores yields 0.
    """
    cores = op_2_core_alloc[op.name]
    if not cores:
        return 0
    traffic = op_2_comm_cost[op.name]
    # Disabled alternative model: scale the bandwidth with int(sqrt(cores))
    # instead of with the full core count.
    aggregate_bandwidth = arch_config['noc_bandwidth'] * cores
    return traffic // aggregate_bandwidth

def get_compute_latency(op):
    """Return the estimated compute latency of ``op``.

    The operator's compute cost (``op_2_compute_cost``) is floor-divided by
    ``num_mac`` times the number of cores in ``op_2_core_alloc``. An operator
    with no cores yields 0.
    """
    cores = op_2_core_alloc[op.name]
    if not cores:
        return 0
    work = op_2_compute_cost[op.name]
    aggregate_throughput = arch_config['num_mac'] * cores
    return work // aggregate_throughput

# Per-operator latency tables, keyed by node name.
op_2_comm_latency = {node.name: get_comm_latency(node) for node in onnx_model.graph.node}
op_2_compute_latency = {node.name: get_compute_latency(node) for node in onnx_model.graph.node}

# Report the worst case over all operators.
print(max(op_2_comm_latency.values()))  # largest communication latency
print(max(op_2_compute_latency.values()))  # largest compute latency