In [1]:
# Adapted from https://le.qun.ch/en/blog/2023/05/13/transformer-batching/
import gzip
import itertools
import os
import pickle

import numpy as np
import pandas as pd
# Render matplotlib figures inline in the notebook, at high ("retina") resolution.
%matplotlib inline
%config InlineBackend.figure_format='retina'

In [2]:
import torch
# Nothing in this notebook trains: turn autograd off globally so the
# benchmarks and generation below do not record computation graphs.
torch.set_grad_enabled(mode=False)
import transformers
from tqdm import tqdm
# Project modules. One of these imports presumably emits the
# "force cublas initialization" / "reduced precision" messages seen in the
# cell output — TODO confirm which; keep the import order as-is.
import intrasm_engine
from benchmark.utils_bench_transformer import (
    benchmark_dense,
    benchmark_qk_ar,
    benchmark_qk_init,
    gen_opt_cfg,
    time_greedy_generate,
)

Print handle to force cublas initialization (otherwise first matmul captured in the graph may fail): 190660736
Setting float16 and bf16 using reduced precision in reduction


In [3]:
# Model configurations from the project benchmark helpers; indexed by a size
# key (e.g. "6.7b") when the end-to-end OPT model is built further down.
optcfg = gen_opt_cfg()

In [4]:
# Microbenchmark sweep grid.
#   nd_list     : (n, d) pairs; the analysis below derives hidden size h = n * d
#                 (n presumably = number of heads, d = head dim).
#   seqlen_list : sequence lengths.
#   bs_list     : batch sizes (dense below 8, progressively sparser up to 128).
# The full grid (larger n, longer sequences) does not fit in RTX 3090 device
# memory, so the reduced grid is the default. Set FULL_SWEEP = True on a GPU
# with more memory instead of editing the lists by hand.
FULL_SWEEP = False

if FULL_SWEEP:
    nd_list = list(itertools.chain(itertools.product([12, 16, 32], [64]), itertools.product([32, 40, 56, 72, 96], [128])))
    seqlen_list = [10, 20, 50, 100, 200, 500, 1000, 2000, 5000]
else:
    nd_list = list(itertools.chain(itertools.product([12, 16, 32], [64]), itertools.product([32, 40, 56], [128])))
    seqlen_list = [10, 20, 50, 100, 200, 500, 1000]
bs_list = list(itertools.chain(range(1, 8), range(8, 16, 2), range(16, 32, 4), range(32, 64, 8), range(64, 128, 16), [128]))

In [5]:
# Echo the sweep grid so the executed notebook records exactly which
# configurations were benchmarked (one list per output line).
print(nd_list, seqlen_list, bs_list, sep="\n")

[(12, 64), (16, 64), (32, 64), (32, 128), (40, 128), (56, 128)]
[10, 20, 50, 100, 200, 500, 1000]
[1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 32, 40, 48, 56, 64, 80, 96, 112, 128]


In [6]:
# Benchmark name -> list of measurement records; pickled to disk further down.
data = dict()

In [7]:
# benchmark_qk_init presumably fills `db` in place with one record per
# configuration (its return value is not used) — stored under "qk_init".
db = list()
benchmark_qk_init(db, nd_list, seqlen_list, bs_list)
data.update(qk_init=db)

  0%|          | 0/1008 [00:00<?, ?it/s, bs=128, d=128, h=7168, n=56, seqlen=1000]

Working on 56 128 1000


  2%|▏         | 24/1008 [00:48<32:51,  2.00s/it, bs=128, d=128, h=7168, n=56, seqlen=500]

Working on 56 128 500


  5%|▍         | 48/1008 [01:36<32:01,  2.00s/it, bs=128, d=128, h=7168, n=56, seqlen=200]

Working on 56 128 200


  7%|▋         | 72/1008 [02:24<31:13,  2.00s/it, bs=128, d=128, h=7168, n=56, seqlen=100]

Working on 56 128 100


 10%|▉         | 96/1008 [03:13<30:31,  2.01s/it, bs=128, d=128, h=7168, n=56, seqlen=50] 

Working on 56 128 50


 12%|█▏        | 120/1008 [04:01<29:37,  2.00s/it, bs=128, d=128, h=7168, n=56, seqlen=20]

Working on 56 128 20


 14%|█▍        | 144/1008 [04:49<28:49,  2.00s/it, bs=128, d=128, h=7168, n=56, seqlen=10]

Working on 56 128 10


 17%|█▋        | 168/1008 [05:37<28:01,  2.00s/it, bs=128, d=128, h=5120, n=40, seqlen=1000]

Working on 40 128 1000


 19%|█▉        | 192/1008 [06:25<27:14,  2.00s/it, bs=128, d=128, h=5120, n=40, seqlen=500] 

Working on 40 128 500


 21%|██▏       | 216/1008 [07:13<26:25,  2.00s/it, bs=128, d=128, h=5120, n=40, seqlen=200]

Working on 40 128 200


 24%|██▍       | 240/1008 [08:01<25:37,  2.00s/it, bs=128, d=128, h=5120, n=40, seqlen=100]

Working on 40 128 100


 26%|██▌       | 264/1008 [08:49<24:49,  2.00s/it, bs=128, d=128, h=5120, n=40, seqlen=50] 

Working on 40 128 50


 29%|██▊       | 288/1008 [09:37<24:01,  2.00s/it, bs=128, d=128, h=5120, n=40, seqlen=20]

Working on 40 128 20


 31%|███       | 312/1008 [10:25<23:13,  2.00s/it, bs=128, d=128, h=5120, n=40, seqlen=10]

Working on 40 128 10


 33%|███▎      | 336/1008 [11:13<22:25,  2.00s/it, bs=128, d=128, h=4096, n=32, seqlen=1000]

Working on 32 128 1000


 36%|███▌      | 360/1008 [12:02<21:37,  2.00s/it, bs=128, d=128, h=4096, n=32, seqlen=500] 

Working on 32 128 500


 38%|███▊      | 384/1008 [12:50<20:49,  2.00s/it, bs=128, d=128, h=4096, n=32, seqlen=200]

Working on 32 128 200


 40%|████      | 408/1008 [13:38<20:01,  2.00s/it, bs=128, d=128, h=4096, n=32, seqlen=100]

Working on 32 128 100


 43%|████▎     | 432/1008 [14:26<19:13,  2.00s/it, bs=128, d=128, h=4096, n=32, seqlen=50] 

Working on 32 128 50


 45%|████▌     | 456/1008 [15:14<18:24,  2.00s/it, bs=128, d=128, h=4096, n=32, seqlen=20]

Working on 32 128 20


 48%|████▊     | 480/1008 [16:02<17:36,  2.00s/it, bs=128, d=128, h=4096, n=32, seqlen=10]

Working on 32 128 10


 50%|█████     | 504/1008 [16:50<16:48,  2.00s/it, bs=128, d=64, h=2048, n=32, seqlen=1000]

Working on 32 64 1000


 52%|█████▏    | 528/1008 [17:38<16:01,  2.00s/it, bs=128, d=64, h=2048, n=32, seqlen=500] 

Working on 32 64 500


 55%|█████▍    | 552/1008 [18:26<15:12,  2.00s/it, bs=128, d=64, h=2048, n=32, seqlen=200]

Working on 32 64 200


 57%|█████▋    | 576/1008 [19:14<14:24,  2.00s/it, bs=128, d=64, h=2048, n=32, seqlen=100]

Working on 32 64 100


 60%|█████▉    | 600/1008 [20:02<13:36,  2.00s/it, bs=128, d=64, h=2048, n=32, seqlen=50] 

Working on 32 64 50


 62%|██████▏   | 624/1008 [20:51<12:48,  2.00s/it, bs=128, d=64, h=2048, n=32, seqlen=20]

Working on 32 64 20


 64%|██████▍   | 648/1008 [21:39<12:00,  2.00s/it, bs=128, d=64, h=2048, n=32, seqlen=10]

Working on 32 64 10


 67%|██████▋   | 672/1008 [22:27<11:12,  2.00s/it, bs=128, d=64, h=1024, n=16, seqlen=1000]

Working on 16 64 1000


 69%|██████▉   | 696/1008 [23:15<10:24,  2.00s/it, bs=128, d=64, h=1024, n=16, seqlen=500] 

Working on 16 64 500


 71%|███████▏  | 720/1008 [24:03<09:36,  2.00s/it, bs=128, d=64, h=1024, n=16, seqlen=200]

Working on 16 64 200


 74%|███████▍  | 744/1008 [24:51<08:48,  2.00s/it, bs=128, d=64, h=1024, n=16, seqlen=100]

Working on 16 64 100


 76%|███████▌  | 768/1008 [25:39<08:00,  2.00s/it, bs=128, d=64, h=1024, n=16, seqlen=50] 

Working on 16 64 50


 79%|███████▊  | 792/1008 [26:27<07:12,  2.00s/it, bs=128, d=64, h=1024, n=16, seqlen=20]

Working on 16 64 20


 81%|████████  | 816/1008 [27:15<06:24,  2.00s/it, bs=128, d=64, h=1024, n=16, seqlen=10]

Working on 16 64 10


 83%|████████▎ | 840/1008 [28:03<05:36,  2.00s/it, bs=128, d=64, h=768, n=12, seqlen=1000]

Working on 12 64 1000


 86%|████████▌ | 864/1008 [28:51<04:48,  2.00s/it, bs=128, d=64, h=768, n=12, seqlen=500] 

Working on 12 64 500


 88%|████████▊ | 888/1008 [29:39<04:00,  2.00s/it, bs=128, d=64, h=768, n=12, seqlen=200]

Working on 12 64 200


 90%|█████████ | 912/1008 [30:27<03:12,  2.00s/it, bs=128, d=64, h=768, n=12, seqlen=100]

Working on 12 64 100


 93%|█████████▎| 936/1008 [31:15<02:24,  2.00s/it, bs=128, d=64, h=768, n=12, seqlen=50] 

Working on 12 64 50


 95%|█████████▌| 960/1008 [32:03<01:36,  2.00s/it, bs=128, d=64, h=768, n=12, seqlen=20]

Working on 12 64 20


 98%|█████████▊| 984/1008 [32:51<00:48,  2.00s/it, bs=128, d=64, h=768, n=12, seqlen=10]

Working on 12 64 10


100%|██████████| 1008/1008 [33:39<00:00,  2.00s/it, bs=1, d=64, h=768, n=12, seqlen=10] 


In [None]:
# benchmark_qk_ar presumably fills `db` in place with one record per
# configuration (its return value is not used) — stored under "qk_ar".
db = list()
benchmark_qk_ar(db, nd_list, seqlen_list, bs_list)
data.update(qk_ar=db)

In [None]:
# benchmark_dense presumably fills `db` in place with one record per
# configuration (its return value is not used) — stored under "dense".
db = list()
benchmark_dense(db, nd_list, seqlen_list, bs_list)
data.update(dense=db)

In [None]:
# Persist the raw microbenchmark records. The output directory is created
# first so that a missing data/ folder cannot discard results that took a long
# time to collect (the qk_init sweep alone ran ~34 minutes above).
microbench_path = "data/20230516-transformer-batching.pkl.gz"
os.makedirs(os.path.dirname(microbench_path), exist_ok=True)
with gzip.open(microbench_path, "wb") as f:
    pickle.dump(data, f)

In [None]:
def _roofline_frame(records, series, flop, io, throughput):
    """Build a DataFrame from raw benchmark records and add roofline columns.

    records: list of dicts from one benchmark_* helper; each record must
        provide at least "n", "d", "bs", "seqlen" and "latency".
    series: label written to the "series" column.
    flop, io, throughput: callables mapping the frame (which already contains
        the derived hidden size h = n * d) to the per-row FLOP count, bytes
        moved, and throughput respectively.

    Columns are appended in the order h, flop, io, intensity, throughput,
    series, where intensity = flop / io.
    """
    return (
        # pd.DataFrame is the documented constructor for a list of records;
        # DataFrame.from_dict is specified for dict input only.
        pd.DataFrame(records)
        .assign(h=lambda x: x["n"] * x["d"])
        .assign(flop=flop)
        .assign(io=io)
        .assign(intensity=lambda x: x["flop"] / x["io"])
        .assign(throughput=throughput)
        .assign(series=series)
    )


# In every formula the trailing "* 2" on flop counts a multiply-add as two
# FLOPs; on io it presumably is bytes per 16-bit element — TODO confirm
# against the dtype used by the benchmark helpers. Latency units are whatever
# the benchmark helpers record.
df_dense = _roofline_frame(
    data["dense"], "dense",
    # consistent with a [bs*seqlen, h] x [h, h] matmul (presumed shape)
    flop=lambda x: (x["bs"] * x["seqlen"] * x["h"]**2) * 2,
    # input + output activations plus one h x h weight
    io=lambda x: (x["bs"] * x["seqlen"] * x["h"] * 2 + x["h"]**2) * 2,
    throughput=lambda x: x["bs"] * x["seqlen"] / x["latency"],
)
df_qk_init = _roofline_frame(
    data["qk_init"], "qk_init",
    # per (batch, n): consistent with [seqlen, d] x [d, seqlen] (presumed shape)
    flop=lambda x: (x["bs"] * x["n"] * x["d"] * x["seqlen"]**2) * 2,
    # two seqlen x d operands read, one seqlen x seqlen result written
    io=lambda x: (x["bs"] * x["n"] * (x["seqlen"] * x["d"] * 2 + x["seqlen"]**2)) * 2,
    throughput=lambda x: x["bs"] * x["seqlen"] / x["latency"],
)
df_qk_ar = _roofline_frame(
    data["qk_ar"], "qk_ar",
    # per (batch, n): consistent with [1, d] x [d, seqlen] (presumed shape)
    flop=lambda x: (x["bs"] * x["n"] * x["d"] * x["seqlen"]) * 2,
    # one d-vector + seqlen x d operand read, seqlen results written
    io=lambda x: (x["bs"] * x["n"] * (x["d"] + x["seqlen"] * x["d"] + x["seqlen"])) * 2,
    # one output position per sequence per call
    throughput=lambda x: x["bs"] / x["latency"],
)
pd.concat([df_dense, df_qk_init, df_qk_ar]).to_csv("data/transformer-batching-microbenchmarks.csv", index=False)

In [None]:
opt_config = optcfg["6.7b"]

# Build the model with bf16 as the global default dtype so parameters are
# created in bf16. Weight initialization is skipped via no_init_weights. The
# model is only used for latency measurement below, so the weight values do
# not matter.
# The previous default dtype is restored in `finally`, so it is reset even if
# construction or the move to CUDA fails (e.g. out of device memory).
# Without this, later cells would silently keep running with a bf16 default.
_prev_default_dtype = torch.get_default_dtype()
torch.set_default_dtype(torch.bfloat16)
try:
    with transformers.modeling_utils.no_init_weights():
        model = transformers.models.opt.OPTForCausalLM(opt_config).to("cuda")
finally:
    torch.set_default_dtype(_prev_default_dtype)

In [None]:
# End-to-end greedy generation timing vs. batch size.
# For each batch size:
#   - time `n_repeats` generations of `new_tokens` tokens from a random
#     `input_tokens`-token prompt;
#   - store the element-wise median over repeats in db[bs].
# time_greedy_generate's return value is presumably a per-step latency
# sequence — TODO confirm.
db = {}
input_tokens = 200
new_tokens = 500
n_repeats = 10  # timing repetitions per batch size; the median is kept
for bs in tqdm(list(itertools.chain(range(1, 8), range(8, 16, 2), [16]))):
    # Random token ids in [1000, 10000); presumably chosen to stay inside the
    # model vocabulary. The values themselves don't matter for timing.
    x = torch.randint(1000, 10000, (bs, input_tokens), device=model.device)
    latencies = [time_greedy_generate(model, x, new_tokens=new_tokens) for _ in range(n_repeats)]
    db[bs] = np.median(np.stack(latencies), axis=0)
    # Free the prompt and cached allocator blocks before the next batch size.
    del x
    torch.cuda.empty_cache()
del model
torch.cuda.empty_cache()

# Ensure the output directory exists so the timings are not lost on write.
e2e_path = "data/20230516-e2e-text-generation-batch.pkl.gz"
os.makedirs(os.path.dirname(e2e_path), exist_ok=True)
with gzip.open(e2e_path, "wb") as f:
    pickle.dump(db, f)