In [14]:
# -*- coding: utf-8 -*-
from gemm_tiling import gemm_tiling_input_stationary, gemm_tiling_weight_stationary
import transformer_block as tbk
import arch_execution as arch
from util import *
import math
from mapper import *

In [15]:
llm_config = load_config("./input/transformer/input0.json")
model = tbk.Llama_block(llm_config)
tx8_config = load_config('hardware_parameter.json')
hardware = arch.Tx8(tx8_config)
Layers = model.config['L']
ops = model.ops
mapping_result = {}

In [16]:
# =============== FNN Up  ===================
# FFN Gate 与 FFN Up是完全同样的size的计算，这里不列出了
M, K, N = 4096, 4096, 11008
details = False
B = 1
print(f"FFN Up analysis, B={B}")

tile_m = 128
tile_n = 43
utilization = gemm_tiling_input_stationary(
    B, M, K, N, tile_m, tile_n, print_details=False)
print(f"FFN Up, M={M}, K={K}, N={N}, B={B}, tile_m={tile_m}, tile_n={tile_n}, stationary: input, utilization={utilization:.2f}%")


# [N/tile_n, B*M/tile_m] -- SRAM exceed -- debug shape info
Tm_Tn = [int(tile_m), int(tile_n)]
input_stationary = True  # ishape <--> oshape, internal exchange shape

mapping_result['FFNup'] = gemm_auto_opt_mapper(
    ops['FFNup'], hardware, input_stationary=input_stationary, Tm_Tn=Tm_Tn, details=details)
utilization = mapping_result['FFNup']["utilization"]*100
stationary = "input" if input_stationary else "weight"
print(f"$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$")
print(f"{stationary} Stationary, Tm_Tn={Tm_Tn}, Hardware utilization: {utilization:.2f}%")
print(f"$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$")

B = 16
tile_m = 4
tile_n = 86
print(f"FFN Up analysis, B={B}")
utilization = gemm_tiling_weight_stationary(
    B, M, K, N, tile_m, tile_n, print_details=False)
print(f"FFN Up, M={M}, K={K}, N={N}, B={B}, tile_m={tile_m}, tile_n={tile_n}, stationary: input, utilization={utilization:.2f}%")


# [N/tile_n, B*M/tile_m] -- SRAM exceed -- debug shape info
Tm_Tn = [int(tile_m), int(tile_n)]
input_stationary = False  # ishape <--> oshape, internal exchange shape

mapping_result['FFNup'] = gemm_auto_opt_mapper(
    ops['FFNup'], hardware, input_stationary=input_stationary, Tm_Tn=Tm_Tn, details=details)
utilization = mapping_result['FFNup']["utilization"]*100
stationary = "input" if input_stationary else "weight"
print(f"$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$")
print(f"{stationary} Stationary, Tm_Tn={Tm_Tn}, Hardware utilization: {utilization:.2f}%")
print(f"$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$")

FFN Up analysis, B=1
FFN Up, M=4096, K=4096, N=11008, B=1, tile_m=128, tile_n=43, stationary: input, utilization=91.17%
输入不复用则SRAM满足要求
$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$
input Stationary, Tm_Tn=[128, 43], Hardware utilization: 90.92%
$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$
FFN Up analysis, B=16
FFN Up, M=4096, K=4096, N=11008, B=16, tile_m=4, tile_n=86, stationary: input, utilization=98.28%
输入不复用则SRAM满足要求
$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$
weight Stationary, Tm_Tn=[4, 86], Hardware utilization: 77.20%
$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$


In [25]:
# =============== FNN down  ===================
M, K, N = 4096, 11008, 4096
B = 1


tile_m, tile_n = 128, 4
utilization = gemm_tiling_input_stationary(
    B, M, K, N, tile_m, tile_n, print_details=False)
print(f"FFN Down, M={M}, K={K}, N={N}, B={B}, tile_m={tile_m}, tile_n={tile_n}, stationary: input, utilization={utilization:.2f}%")


# [N/tile_n, B*M/tile_m] -- SRAM exceed -- debug shape info
Tm_Tn = [int(tile_m), int(tile_n)]
input_stationary = True  # ishape <--> oshape, internal exchange shape

mapping_result['FFNup'] = gemm_auto_opt_mapper(
    ops['FFNup'], hardware, input_stationary=input_stationary, Tm_Tn=Tm_Tn, details=details)
utilization = mapping_result['FFNup']["utilization"]*100
stationary = "input" if input_stationary else "weight"
print(f"$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$")
print(f"{stationary} Stationary, Tm_Tn={Tm_Tn}, Hardware utilization: {utilization:.2f}%")
print(f"$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$")

B = 16
tile_m, tile_n = 4, 128
utilization = gemm_tiling_weight_stationary(
    B, M, K, N, tile_m, tile_n, print_details=False)
print(f"FFN Down, M={M}, K={K}, N={N}, B={B}, tile_m={tile_m}, tile_n={tile_n}, stationary: weight, utilization={utilization:.2f}%")

# [N/tile_n, B*M/tile_m] -- SRAM exceed -- debug shape info
Tm_Tn = [int(tile_m), int(tile_n)]
input_stationary = False  # ishape <--> oshape, internal exchange shape

mapping_result['FFNup'] = gemm_auto_opt_mapper(
    ops['FFNup'], hardware, input_stationary=input_stationary, Tm_Tn=Tm_Tn, details=details)
utilization = mapping_result['FFNup']["utilization"]*100
stationary = "input" if input_stationary else "weight"
print(f"$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$")
print(f"{stationary} Stationary, Tm_Tn={Tm_Tn}, Hardware utilization: {utilization:.2f}%")
print(f"$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$")

FFN Down, M=4096, K=11008, N=4096, B=1, tile_m=128, tile_n=4, stationary: input, utilization=77.63%
$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$
input Stationary, Tm_Tn=[128, 4], Hardware utilization: 90.10%
$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$
FFN Down, M=4096, K=11008, N=4096, B=16, tile_m=4, tile_n=128, stationary: weight, utilization=98.23%
输入不复用则SRAM满足要求
$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$
weight Stationary, Tm_Tn=[4, 128], Hardware utilization: 77.20%
$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$


In [17]:
# =============== FNN Up  ===================
# FFN Gate 与 FFN Up是完全同样的size的计算，这里不列出了
M, K, N = 4096, 4096, 11008
B = 1
details = False

print(f"FFN Up analysis, B={B}")

tile_m = 128
tile_n = 32
utilization = gemm_tiling_input_stationary(
    B, M, K, N, tile_m, tile_n, print_details=False)
print(f"FFN Up, M={M}, K={K}, N={N}, B={B}, tile_m={tile_m}, tile_n={tile_n}, stationary: input, utilization={utilization:.2f}%")


Tm_Tn = [int(tile_m), int(tile_n)]
print(f"Batch size = {ops['K_proj']['ishape'][0]}")

mapping_result['K_proj'] = gemm_auto_opt_mapper(
    ops['K_proj'], hardware, Tm_Tn=Tm_Tn, details=details)
utilization = mapping_result['K_proj']["utilization"]*100
print(f"$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$")
print(f"Tm_Tn={Tm_Tn}, Hardware utilization: {utilization:.2f}%")
print(f"$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$")

tile_m = 128
tile_n = 43
utilization = gemm_tiling_input_stationary(
    B, M, K, N, tile_m, tile_n, print_details=False)
print(f"FFN Up, M={M}, K={K}, N={N}, B={B}, tile_m={tile_m}, tile_n={tile_n}, stationary: input, utilization={utilization:.2f}%")


tile_m = 256
tile_n = 32
utilization = gemm_tiling_input_stationary(
    B, M, K, N, tile_m, tile_n, print_details=False)
print(f"FFN Up, M={M}, K={K}, N={N}, B={B}, tile_m={tile_m}, tile_n={tile_n}, stationary: input, utilization={utilization:.2f}%")

tile_m = 256
tile_n = 21
utilization = gemm_tiling_input_stationary(
    B, M, K, N, tile_m, tile_n, print_details=False)
print(f"FFN Up, M={M}, K={K}, N={N}, B={B}, tile_m={tile_m}, tile_n={tile_n}, stationary: input, utilization={utilization:.2f}%")

tile_m = 32
tile_n = 128
utilization = gemm_tiling_weight_stationary(
    B, M, K, N, tile_m, tile_n, print_details=False)
print(f"FFN Up, M={M}, K={K}, N={N}, B={B}, tile_m={tile_m}, tile_n={tile_n}, stationary: weight, utilization={utilization:.2f}%")

tile_m = 32
tile_n = 256
utilization = gemm_tiling_weight_stationary(
    B, M, K, N, tile_m, tile_n, print_details=False)
print(f"FFN Up, M={M}, K={K}, N={N}, B={B}, tile_m={tile_m}, tile_n={tile_n}, stationary: weight, utilization={utilization:.2f}%")

FFN Up analysis, B=1
FFN Up, M=4096, K=4096, N=11008, B=1, tile_m=128, tile_n=32, stationary: input, utilization=89.10%
Batch size = 1
$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$
Tm_Tn=[128, 32], Hardware utilization: 78.47%
$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$
FFN Up, M=4096, K=4096, N=11008, B=1, tile_m=128, tile_n=43, stationary: input, utilization=91.17%
FFN Up, M=4096, K=4096, N=11008, B=1, tile_m=256, tile_n=32, stationary: input, utilization=87.26%
FFN Up, M=4096, K=4096, N=11008, B=1, tile_m=256, tile_n=21, stationary: input, utilization=88.88%
FFN Up, M=4096, K=4096, N=11008, B=1, tile_m=32, tile_n=128, stationary: weight, utilization=74.68%
FFN Up, M=4096, K=4096, N=11008, B=1, tile_m=32, tile_n=256, stationary: weight, utilization=73.31%


In [None]:
# =============== FNN Up  ===================
# FFN Gate 与 FFN Up是完全同样的size的计算，这里不列出了
M, K, N = 4096, 4096, 11008
B = 1
details = False

print(f"FFN Up analysis, B={B}")

tile_m = 128
tile_n = 32
utilization = gemm_tiling_input_stationary(
    B, M, K, N, tile_m, tile_n, print_details=False)
print(f"FFN Up, M={M}, K={K}, N={N}, B={B}, tile_m={tile_m}, tile_n={tile_n}, stationary: input, utilization={utilization:.2f}%")


Tm_Tn = [int(tile_m), int(tile_n)]
print(f"Batch size = {ops['K_proj']['ishape'][0]}")

mapping_result['K_proj'] = gemm_auto_opt_mapper(
    ops['K_proj'], hardware, Tm_Tn=Tm_Tn, details=details)
utilization = mapping_result['K_proj']["utilization"]*100
print(f"$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$")
print(f"Tm_Tn={Tm_Tn}, Hardware utilization: {utilization:.2f}%")
print(f"$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$")

tile_m = 128
tile_n = 43
utilization = gemm_tiling_input_stationary(
    B, M, K, N, tile_m, tile_n, print_details=False)
print(f"FFN Up, M={M}, K={K}, N={N}, B={B}, tile_m={tile_m}, tile_n={tile_n}, stationary: input, utilization={utilization:.2f}%")


tile_m = 256
tile_n = 32
utilization = gemm_tiling_input_stationary(
    B, M, K, N, tile_m, tile_n, print_details=False)
print(f"FFN Up, M={M}, K={K}, N={N}, B={B}, tile_m={tile_m}, tile_n={tile_n}, stationary: input, utilization={utilization:.2f}%")

tile_m = 256
tile_n = 21
utilization = gemm_tiling_input_stationary(
    B, M, K, N, tile_m, tile_n, print_details=False)
print(f"FFN Up, M={M}, K={K}, N={N}, B={B}, tile_m={tile_m}, tile_n={tile_n}, stationary: input, utilization={utilization:.2f}%")

tile_m = 32
tile_n = 128
utilization = gemm_tiling_weight_stationary(
    B, M, K, N, tile_m, tile_n, print_details=False)
print(f"FFN Up, M={M}, K={K}, N={N}, B={B}, tile_m={tile_m}, tile_n={tile_n}, stationary: weight, utilization={utilization:.2f}%")

tile_m = 32
tile_n = 256
utilization = gemm_tiling_weight_stationary(
    B, M, K, N, tile_m, tile_n, print_details=False)
print(f"FFN Up, M={M}, K={K}, N={N}, B={B}, tile_m={tile_m}, tile_n={tile_n}, stationary: weight, utilization={utilization:.2f}%")

FFN Up analysis, B=1
FFN Up, M=4096, K=4096, N=11008, B=1, tile_m=128, tile_n=32, stationary: input, utilization=89.10%
Batch size = 1
$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$
Tm_Tn=[128, 32], Hardware utilization: 78.47%
$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$
FFN Up, M=4096, K=4096, N=11008, B=1, tile_m=128, tile_n=43, stationary: input, utilization=91.17%
FFN Up, M=4096, K=4096, N=11008, B=1, tile_m=256, tile_n=32, stationary: input, utilization=87.26%
FFN Up, M=4096, K=4096, N=11008, B=1, tile_m=256, tile_n=21, stationary: input, utilization=88.88%
FFN Up, M=4096, K=4096, N=11008, B=1, tile_m=32, tile_n=128, stationary: weight, utilization=74.68%
FFN Up, M=4096, K=4096, N=11008, B=1, tile_m=32, tile_n=256, stationary: weight, utilization=73.31%


因为 FFN up是升维，N的维度很高11008，所以input stationary的性能会比weight stationary好很多

下面分析batch size的影响

In [18]:
M, K, N = 4096, 4096, 11008
B = 16
print(f"FFN Up analysis, B={B}")


tile_m = 128
tile_n = 32
utilization = gemm_tiling_input_stationary(
    B, M, K, N, tile_m, tile_n, print_details=False)
print(f"FFN Up, M={M}, K={K}, N={N}, B={B}, tile_m={tile_m}, tile_n={tile_n}, stationary: input, utilization={utilization:.2f}%")


tile_m = 128
tile_n = 43
utilization = gemm_tiling_input_stationary(
    B, M, K, N, tile_m, tile_n, print_details=False)
print(f"FFN Up, M={M}, K={K}, N={N}, B={B}, tile_m={tile_m}, tile_n={tile_n}, stationary: input, utilization={utilization:.2f}%")


tile_m = 256
tile_n = 32
utilization = gemm_tiling_input_stationary(
    B, M, K, N, tile_m, tile_n, print_details=False)
print(f"FFN Up, M={M}, K={K}, N={N}, B={B}, tile_m={tile_m}, tile_n={tile_n}, stationary: input, utilization={utilization:.2f}%")

tile_m = 256
tile_n = 21
utilization = gemm_tiling_input_stationary(
    B, M, K, N, tile_m, tile_n, print_details=False)
print(f"FFN Up, M={M}, K={K}, N={N}, B={B}, tile_m={tile_m}, tile_n={tile_n}, stationary: input, utilization={utilization:.2f}%")

tile_m = 32
tile_n = 128
utilization = gemm_tiling_weight_stationary(
    B, M, K, N, tile_m, tile_n, print_details=False)
print(f"FFN Up, M={M}, K={K}, N={N}, B={B}, tile_m={tile_m}, tile_n={tile_n}, stationary: weight, utilization={utilization:.2f}%")

tile_m = 32
tile_n = 256
utilization = gemm_tiling_weight_stationary(
    B, M, K, N, tile_m, tile_n, print_details=False)
print(f"FFN Up, M={M}, K={K}, N={N}, B={B}, tile_m={tile_m}, tile_n={tile_n}, stationary: weight, utilization={utilization:.2f}%")

tile_m = 8
tile_n = 128
utilization = gemm_tiling_weight_stationary(
    B, M, K, N, tile_m, tile_n, print_details=False)
print(f"FFN Up, M={M}, K={K}, N={N}, B={B}, tile_m={tile_m}, tile_n={tile_n}, stationary: weight, utilization={utilization:.2f}%")

FFN Up analysis, B=16
FFN Up, M=4096, K=4096, N=11008, B=16, tile_m=128, tile_n=32, stationary: input, utilization=91.96%
FFN Up, M=4096, K=4096, N=11008, B=16, tile_m=128, tile_n=43, stationary: input, utilization=95.17%
FFN Up, M=4096, K=4096, N=11008, B=16, tile_m=256, tile_n=32, stationary: input, utilization=91.78%
FFN Up, M=4096, K=4096, N=11008, B=16, tile_m=256, tile_n=21, stationary: input, utilization=91.91%
FFN Up, M=4096, K=4096, N=11008, B=16, tile_m=32, tile_n=128, stationary: weight, utilization=88.48%
FFN Up, M=4096, K=4096, N=11008, B=16, tile_m=32, tile_n=256, stationary: weight, utilization=88.36%
FFN Up, M=4096, K=4096, N=11008, B=16, tile_m=8, tile_n=128, stationary: weight, utilization=88.11%


In [19]:
M, K, N = 4096, 4096, 11008
B = 1
tile_m, tile_n = 4, 86
utilization = gemm_tiling_weight_stationary(
    B, M, K, N, tile_m, tile_n, print_details=False)
print(f"FFN Up, M={M}, K={K}, N={N}, B={B}, tile_m={tile_m}, tile_n={tile_n}, stationary: weight, utilization={utilization:.2f}%")
B = 16
tile_m, tile_n = 4, 86
utilization = gemm_tiling_weight_stationary(
    B, M, K, N, tile_m, tile_n, print_details=False)
print(f"FFN Up, M={M}, K={K}, N={N}, B={B}, tile_m={tile_m}, tile_n={tile_n}, stationary: weight, utilization={utilization:.2f}%")

FFN Up, M=4096, K=4096, N=11008, B=1, tile_m=4, tile_n=86, stationary: weight, utilization=78.16%
FFN Up, M=4096, K=4096, N=11008, B=16, tile_m=4, tile_n=86, stationary: weight, utilization=98.28%


详细分析

In [20]:
M, K, N = 4096, 4096, 11008
B = 16
tile_m, tile_n = 4, 86
utilization = gemm_tiling_weight_stationary(
    B, M, K, N, tile_m, tile_n, print_details=True)
print(f"FFN Up, M={M}, K={K}, N={N}, B={B}, tile_m={tile_m}, tile_n={tile_n}, stationary: weight, utilization={utilization:.2f}%")

+-------------------+--------------------------+
|        var        |         mem (MB)         |
+-------------------+--------------------------+
|    input_size     | 0.671875 * 2 =  1.343750 |
|    weight_size    | 0.031250 * 3 =  0.093750 |
|    output_size    | 0.000656 * 2 =  0.001312 |
|    total_size     |         1.438812         |
| input_load_iters  |            8             |
| weight_load_iters |           1024           |
+-------------------+--------------------------+
+-----------------------+------------+
|         unit          | time (us)  |
+-----------------------+------------+
|    input_load_time    | 105.080469 |
|   weight_load_time    |  4.982813  |
|    weight_noc_time    |  0.248419  |
| compute_time_one_tile |  0.352256  |
|   output_save_time    |  0.202520  |
+-----------------------+------------+
+--------------------------+------------+
|           item           | time (us)  |
+--------------------------+------------+
|  time_one_noc_pipe_flow  |  5.6

In [21]:
# constants
GB = 1024*1024*1024
MB = 1024*1024
KB = 1024
ns = 1e-9
us = 1e-6
ms = 1e-3
# 1 TFLOPS = 1e12 FLOPS 还是 2**40 FLOPS （1.0995e12)? 这个影响还是有一些的
TFLOPS = 1e12
# hardware configuration
data_type = 2  # 2 bytes for FP16, 4 bytes for FP32
Tile_num = 4 * 4  # 4x4 tiles
Tile_SRAM = 3 * MB  # 3MB
Tile_compute = 128/16 * TFLOPS  # 8 TFLOPS
DDR_BW = 100 * GB  # 100GB/s
NOC_BW = 128 * GB  # 128GB/s
NOC_latency_hop = 10 * ns  # 10ns for 1 hop
DDR_latency = 0 * ns  # 100ns

In [22]:
print(f"tile_m need to be > {Tile_compute/NOC_BW}, to hide the NOC time")

print(f"tile_m need to be > {Tile_compute/DDR_BW}, to hide the DDR time")

tile_m need to be > 58.20766091346741, to hide the NOC time
tile_m need to be > 74.50580596923828, to hide the DDR time


In [23]:
11008/16

688.0

In [24]:
688.0/16

43.0