forked from AI-Hypercomputer/jetstream-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathprefill_offline.py
135 lines (112 loc) · 3.31 KB
/
prefill_offline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import os
import time
import humanize
import jax
import numpy as np
# pylint: disable-next=all
from absl import app, flags
from jetstream_pt.config import FLAGS, create_engine_from_config_flags
def delete_pytree(p):
  """Explicitly delete every ``jax.Array`` leaf of the pytree ``p``.

  Leaves that are not ``jax.Array`` instances (Python scalars, None, other
  objects) are left untouched. Deleted arrays must not be used afterwards.

  Args:
    p: any JAX pytree (dict, list, tuple, registered dataclass, ...).
  """
  # tree_leaves + a plain loop: we only want the side effect, so there is no
  # need to build (and discard) a mapped tree as jax.tree_map did. It also
  # avoids the deprecated/removed top-level ``jax.tree_map`` alias.
  for leaf in jax.tree_util.tree_leaves(p):
    if isinstance(leaf, jax.Array):
      leaf.delete()
def print_mem_usage():
  """Print current memory usage for every local JAX device.

  For each device in ``jax.local_devices()`` prints the bytes in use, and,
  when the device reports a limit, the limit and the usage percentage. Sizes
  are formatted with ``humanize.naturalsize(binary=True)``.

  Devices whose ``memory_stats()`` returns nothing (the API is not
  implemented on every backend) are reported instead of raising.
  """
  fmt_size = functools.partial(humanize.naturalsize, binary=True)
  for d in jax.local_devices():
    stats = d.memory_stats()
    # memory_stats() may return None or omit keys depending on the backend;
    # indexing it unconditionally used to crash the whole benchmark.
    if not stats or "bytes_in_use" not in stats:
      print(f"memory stats unavailable on {d}")
      continue
    used = stats["bytes_in_use"]
    limit = stats.get("bytes_limit")
    if limit:
      print(
          f"memory using {fmt_size(used)} / {fmt_size(limit)} "
          f"({used/limit:%}) on {d}"
      )
    else:
      # No (or zero) limit reported: avoid a division by zero.
      print(f"memory using {fmt_size(used)} on {d}")
# Prompt lengths benchmarked by default: powers of two from 16 to 32768.
# (65536 and 131072 were listed but disabled in the original script.)
DEFAULT_PREFILL_LENGTHS = (
    16,
    32,
    64,
    128,
    256,
    512,
    1024,
    2048,
    4096,
    8192,
    16384,
    32768,
)


def create_prefill_tokens(prefill_lengths=None, vocab_size=32000, seed=None):
  """Create one random token-id array per requested prefill length.

  Args:
    prefill_lengths: iterable of sequence lengths; defaults to
      ``DEFAULT_PREFILL_LENGTHS``.
    vocab_size: exclusive upper bound of the generated token ids
      (default 32000). Ids are drawn from ``[1, vocab_size)``.
    seed: optional integer seed for reproducible tokens. When None the
      global NumPy RNG is used, so an earlier ``np.random.seed(...)`` call
      still controls the output exactly as before.

  Returns:
    A list of 1-D integer NumPy arrays, one per length, in the given order.
  """
  if prefill_lengths is None:
    prefill_lengths = DEFAULT_PREFILL_LENGTHS
  rng = np.random if seed is None else np.random.RandomState(seed)
  # low=1 means id 0 is never produced (presumably reserved, e.g. padding —
  # confirm against the tokenizer).
  return [rng.randint(1, vocab_size, length) for length in prefill_lengths]
def prefill_benchmark(tokens_list, engine, params, warmup):
  """Run and time one ``engine.prefill`` call per token sequence.

  Args:
    tokens_list: iterable of token sequences; each is passed as
      ``padded_tokens`` with ``true_length`` equal to its own length.
    engine: object exposing ``prefill(params=..., padded_tokens=...,
      true_length=...)`` whose result has a ``.token`` attribute.
    params: model parameters forwarded to ``engine.prefill``.
    warmup: when True, output lines are labelled "warmup" and device memory
      usage is printed after every prefill; otherwise they are labelled
      "execute".
  """
  warmup_text = "warmup" if warmup else "execute"
  for prefill_tokens in tokens_list:
    token_len = len(prefill_tokens)
    start = time.perf_counter()
    prefill_result = engine.prefill(
        params=params,
        padded_tokens=prefill_tokens,
        true_length=token_len,
    )
    # JAX dispatches work asynchronously: wait for the result to be ready so
    # the measured time covers the computation, and stop the clock before
    # the host-side printing below. Assumes prefill_result is a JAX pytree
    # (delete_pytree below relies on the same).
    jax.block_until_ready(prefill_result)
    elapsed = time.perf_counter() - start
    print(f"---- {warmup_text} First Token: {prefill_result.token}")
    print(f"---- {warmup_text} time: {elapsed} for token_len: {token_len}")
    if warmup:
      print_mem_usage()
    # Release device buffers before the next (longer) prefill.
    delete_pytree(prefill_result)
    print("\n\n")
# pylint: disable-next=all
def main(argv):
  """Load the engine and params, then run warmup and timed prefill passes.

  Runs 3 warmup passes followed by 5 timed ("execute") passes over the same
  token list. If ``FLAGS.profiling_output`` is set, a JAX profiler trace of
  all passes is written there; the trace is stopped even if a pass raises.
  """
  del argv  # Unused.
  num_warmup_passes = 3
  num_execute_passes = 5

  engine = create_engine_from_config_flags()
  start = time.perf_counter()
  params = engine.load_params()
  print("Load params ", time.perf_counter() - start)

  profiling_output = FLAGS.profiling_output
  if profiling_output:
    jax.profiler.start_trace(profiling_output)
  try:
    print_mem_usage()
    tokens_list = create_prefill_tokens()
    # One call per iteration: the original followed each loop with an extra
    # copy-pasted call, so the real pass counts didn't match range(N).
    for _ in range(num_warmup_passes):
      prefill_benchmark(
          tokens_list=tokens_list, engine=engine, params=params, warmup=True
      )
    for _ in range(num_execute_passes):
      prefill_benchmark(
          tokens_list=tokens_list, engine=engine, params=params, warmup=False
      )
  finally:
    # Always flush/close the trace so a failure doesn't lose the profile.
    if profiling_output:
      jax.profiler.stop_trace()
if __name__ == "__main__":
  # Configure TF/XLA C++ logging through the environment before handing
  # control to absl, which parses flags and then calls main().
  os.environ.update({"TF_CPP_MIN_LOG_LEVEL": "0"})
  app.run(main)