-
Notifications
You must be signed in to change notification settings - Fork 9
/
eval.py
568 lines (479 loc) · 17.7 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import sys
import time
import argparse
import math
import json
import regex as re
import contextlib
import shutil
import itertools
import pandas as pd
import numpy as np
from pathlib import Path
from typing import Optional, List
from collections import defaultdict, Counter
from tqdm.auto import tqdm
import torch
import torch._dynamo.config
import torch._inductor.config
from cache import add_cache_arguments, cache_compatibility, get_cache_constructor
from model import Transformer
from generation_utils import (
add_generation_arguments,
compile_funcs,
compute_max_seq_length,
device_sync,
get_cache_stats,
merge_cache_config,
reset_caches,
setup_caches,
)
from tokenizer import encode, TokenizerInterface
# Global torch Inductor compiler settings, applied when this module is imported.
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future

# Flip to True to enable DEBUG-level Dynamo/Inductor logging and verbose Dynamo
# output while investigating compilation problems.
DEBUG_COMPILE = False
if DEBUG_COMPILE:
    import logging

    level = logging.DEBUG
    torch._logging.set_logs(dynamo=level, inductor=level)
    torch._dynamo.config.verbose = True

# Default `device` for this script: "cuda" when a CUDA device is available, else "cpu".
default_device = "cuda" if torch.cuda.is_available() else "cpu"

# support running without installing as a package
# `wd` is the parent of the directory containing this file; it is appended to
# sys.path so the imports that follow can resolve from there.
# NOTE(review): cache, model, generation_utils and tokenizer are already imported
# at the top of the file, before this path entry exists, so it only helps
# modules first imported after this point (e.g. `task`) — confirm the ordering
# is intended.
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))
from tokenizer import get_tokenizer
from generation_utils import load_model, generate
from task import TASK_MAPPING, AutoTask
def flatten_dict(in_dict: dict) -> dict:
    """Flatten one level of nesting in a (metrics) dictionary.

    Every top-level value that is itself a dict is expanded so that each inner
    entry becomes ``"{outer_key}_{inner_key}"``; all other values are copied
    through unchanged. Only one level is flattened: dicts nested deeper than
    that are kept as values.

    Args:
        in_dict: Mapping whose values may be nested dicts.

    Returns:
        A new flat dict; ``in_dict`` is not modified.
    """
    out_dict = {}
    for k, v in in_dict.items():
        # isinstance (rather than an exact type comparison) so dict subclasses
        # such as OrderedDict / defaultdict are flattened as well.
        if isinstance(v, dict):
            for kk, vv in v.items():
                out_dict[f"{k}_{kk}"] = vv
        else:
            out_dict[k] = v
    return out_dict
def compress_list(l):
    """Shorten a list for use in a results path.

    Lists with fewer than three items are returned unchanged (the same object).
    Longer lists collapse into one ``"{value}x{count}"`` string per distinct
    value, in first-occurrence order, e.g. ``["a", "a", "b"] -> ["ax2", "bx1"]``.
    """
    if len(l) >= 3:
        tally = Counter(l)
        return [f"{value}x{count}" for value, count in tally.items()]
    return l
def args_to_str(args):
    """Encode the cache-relevant CLI arguments as a deterministic string.

    Only argument names that ``get_cache_constructor`` reports as relevant for
    the requested cache strategies are included, plus ``cache_length_pattern``,
    ``cache_strategy_pattern`` and (when present and not 1.0) ``attn_top_k``.
    Each is rendered as ``key=value`` (list values are comma-joined after
    ``compress_list``), the pieces are sorted and joined with ``"__"``.

    If any strategy mentions ``"debug"``, ``debug_`` prefixes are stripped
    before looking up the constructor kwargs and ``"__debug"`` is appended.

    Args:
        args: Parsed ``argparse.Namespace``; ``args.cache_strategy`` is a list
            of strategy names.

    Returns:
        The encoded string (used as a results sub-directory name).
    """
    # Treat the run as a debug run if *any* strategy mentions "debug", matching
    # the check run_task uses when sizing caches. Previously only the first
    # entry was inspected, so a later debug_* entry was neither stripped nor
    # reflected in the suffix.
    if any("debug" in cs for cs in args.cache_strategy):
        debug_suffix = "__debug"
        cache_strategy = [
            re.sub(r"debug_+", "", cs).strip() for cs in args.cache_strategy
        ]
    else:
        cache_strategy = args.cache_strategy
        debug_suffix = ""

    # Sorted union of the kwargs each strategy's constructor declares relevant.
    relevant_kwargs = sorted(
        set(
            itertools.chain.from_iterable(
                get_cache_constructor(cs)[1] for cs in cache_strategy
            )
        )
    )
    relevant_kwargs.append("cache_length_pattern")
    relevant_kwargs.append("cache_strategy_pattern")
    if hasattr(args, "attn_top_k") and args.attn_top_k != 1.0:
        relevant_kwargs.append("attn_top_k")

    def process_num(n):
        # Render integral floats as ints ("1" rather than "1.0"); otherwise no-op.
        # float.is_integer() is False for inf/nan, where int(n) would raise.
        if isinstance(n, float) and n.is_integer():
            return int(n)
        return n

    args_dict = vars(args).copy()
    # Hybrid strategies would be too long for a file name, so keep only each
    # entry's "strategy" name.
    if args_dict.get("hybrid_strategies") is not None:
        args_dict["hybrid_strategies"] = [
            x["strategy"] for x in args_dict["hybrid_strategies"]
        ]

    parts = []
    for k, v in args_dict.items():
        if k not in relevant_kwargs:
            continue
        if isinstance(v, list):
            parts.append(
                f"{k}=" + ",".join(compress_list([str(process_num(m)) for m in v]))
            )
        else:
            parts.append(f"{k}={process_num(v)}")
    return "__".join(sorted(parts)) + debug_suffix
def _summarize_metrics(aggregate_metrics: dict) -> dict:
    """Reduce per-example metric lists to the values saved for a task.

    Every key is stored as its mean over examples, except:
      * keys containing ``"toks_per_sec"`` store the full per-example list (in
        evaluation order) plus ``{key}_top_10p``, the mean of the highest 10%
        (at least one) of values — a figure less affected by one-off
        compilation cost;
      * ``"total_seconds"`` additionally gets ``_min``, ``_max`` and ``_median``.
    Empty lists are skipped.
    """
    summary = {}
    for k, v in aggregate_metrics.items():
        if not v:
            continue
        summary[k] = sum(v) / len(v)
        if "toks_per_sec" in k:
            # Keep the per-example values in evaluation order so the effect of
            # compilation over time stays visible. Rank a *copy* for the top-10%
            # statistic: sorting `v` in place would also reorder the saved list.
            summary[k] = list(v)
            ranked = sorted(v)
            cutoff = math.ceil(len(ranked) / 10)
            summary[f"{k}_top_10p"] = sum(ranked[-cutoff:]) / cutoff
        if k == "total_seconds":
            summary[f"{k}_min"] = min(v)
            summary[f"{k}_max"] = max(v)
            summary[f"{k}_median"] = float(np.median(v))
    return summary


def run_task(
    args: argparse.Namespace,
    task: AutoTask,
    model: Transformer,
    prefill: callable,
    decode_one_token: callable,
    tokenizer: TokenizerInterface,
    is_chat: bool = False,
    profile: Optional[Path] = None,
    feed_long_prompts=False,
    decode_first_token=False,
    device=default_device,
    cache_kwargs: Optional[dict] = None,
    use_tp: bool = False,
    rank: Optional[int] = None,
    terminator_ids: Optional[List[int]] = None,
):
    """Run one evaluation task example-by-example and aggregate its metrics.

    Prompts (and, for perplexity tasks, the single reference label of each
    example) are encoded on the CPU. Caches are set up once via
    ``setup_caches``. They are sized to the maximum sequence length when any
    strategy is "full", "hybrid" or contains "debug", and to an estimated
    median length otherwise.

    Each example is then run through ``generate``, and per-example perf and
    cache stats are collected. Perplexity tasks compute perplexity from the
    reference tokens; other tasks decode the generated continuation into a
    prediction. Caches are reset after every example.

    Args:
        args: Parsed CLI args; reads ``cache_strategy``, ``attn_top_k`` and ``debug``.
        task: Task supplying the test data, ``max_tokens`` and scoring.
        model, prefill, decode_one_token: Forwarded to ``generate``.
        tokenizer: Used to encode prompts/labels and decode predictions.
        is_chat: Forwarded to ``encode`` for prompts.
        profile: If set, a ``torch.profiler`` trace is exported per example to
            ``{profile}.json`` (``{profile}_rank_{rank}.json`` under tensor
            parallelism, where only rank 0 profiles).
        cache_kwargs: Cache configuration; a copy is passed to ``setup_caches``.
        terminator_ids: Stop tokens passed to ``generate`` for non-perplexity
            decoding; a trailing one is stripped from decoded predictions.

    Returns:
        ``(task_metrics, pred_df, task_cache_kwargs)``, where ``pred_df`` is a
        DataFrame of prompts and predictions (``None`` for perplexity tasks),
        or ``(None, None, None)`` if the task has no test examples.
    """
    # Avoid a shared mutable default; only a copy is ever handed to setup_caches.
    if cache_kwargs is None:
        cache_kwargs = {}
    aggregate_metrics = defaultdict(list)
    predictions = []
    all_probs = []
    task_metrics = {}

    test = task.get_test()
    if len(test) == 0:
        print(
            f"No test data found for {task.__class__.__name__}. Skipping. Possibly all filtered out by tokenizer for being too long."
        )
        return None, None, None

    prompts = test["prompt"]
    inputs = [
        encode(tokenizer, prompt, device="cpu", is_chat=is_chat)
        for prompt in tqdm(prompts, desc="Encoding Prompts")
    ]
    if task.requires_perplexity:
        assert (
            len(test["labels"][0]) == 1
        ), "Only one label supported for perplexity tasks"
        label_ids = [
            encode(tokenizer, label[0], device="cpu", is_chat=False, bos=False)
            for label in tqdm(test["labels"], desc="Encoding Labels")
        ]
        _, max_seq_length = compute_max_seq_length(model, inputs, label_ids, 0)
    else:
        label_ids = None
        _, max_seq_length = compute_max_seq_length(model, inputs, None, task.max_tokens)

    # Estimate median sequence length
    median_seq_length = int(np.median([len(i) for i in inputs]) + task.max_tokens / 2)
    target_length = (
        max_seq_length
        if any(x in {"full", "hybrid"} or "debug" in x for x in args.cache_strategy)
        else median_seq_length
    )
    task_cache_kwargs = setup_caches(
        model, tokenizer, device, target_length, cache_kwargs.copy()
    )

    for i, prompt_ids in enumerate(tqdm(inputs)):
        input_ids = prompt_ids.to(device)
        next_tokens = None if label_ids is None else label_ids[i].to(device)
        prompt_length = input_ids.size(0)
        max_new_tokens = min(task.max_tokens, max_seq_length - prompt_length)
        assert max_new_tokens > 0, f"Prompt too long for model: {prompt_length}"

        device_sync(device=device)  # MKG
        if not profile or (use_tp and rank != 0):
            prof = contextlib.nullcontext()
        else:
            torch.profiler._utils._init_for_cuda_graphs()
            prof = torch.profiler.profile()
        with prof:
            y, probs, perf_stats = generate(
                model,
                input_ids,
                prefill,
                decode_one_token,
                max_new_tokens=max_new_tokens,
                next_tokens=next_tokens,
                terminator_ids=terminator_ids if next_tokens is None else None,
                attn_top_k=args.attn_top_k,
                feed_long_prompts=feed_long_prompts,
                decode_first_token=decode_first_token,
            )
        for k, v in perf_stats.items():
            aggregate_metrics[k].append(v)

        if next_tokens is not None:
            # Negative log-probability of each reference token.
            nll = -torch.tensor(
                [
                    torch.log(probs[j][next_tokens[j]])
                    for j in range(next_tokens.size(0))
                ]
            )
            # Perplexity over the first 500, 1000, ... reference tokens.
            for n_prefix in range(500, len(nll), 500):
                aggregate_metrics[f"ppl@{n_prefix}"].append(
                    float(torch.exp(torch.mean(nll[:n_prefix])).item())
                )
            aggregate_metrics["ppl"].append(float(torch.exp(torch.mean(nll)).item()))

        if hasattr(prof, "export_chrome_trace"):
            if use_tp:
                prof.export_chrome_trace(f"{profile}_rank_{rank}.json")
            else:
                prof.export_chrome_trace(f"{profile}.json")
        device_sync(device=device)  # MKG

        cache_stats = get_cache_stats(model, prompt_length, perf_stats["decode_tokens"])
        for k, v in cache_stats.items():
            aggregate_metrics[k].append(v)

        # Perplexity tasks don't decode from model so don't save predictions
        if not task.requires_perplexity:
            # Decode: remove EoT and prompt
            end = y.size(0)
            # Compare the Python int so membership works for lists and sets alike
            # (a tensor hashes by identity, so `tensor in some_set` is never True),
            # and tolerate terminator_ids=None instead of raising TypeError.
            if terminator_ids and y[-1].item() in terminator_ids:
                end = -1
            pred = tokenizer.decode(y[prompt_length:end].tolist())
            if args.debug:
                print(f"Prediction: {pred}")
            predictions.append(pred)
            if task.requires_logits:
                # Map each vocab entry to the final position's probability.
                all_probs.append(dict(zip(tokenizer.get_vocab(), probs[-1].tolist())))

        # Reset KV Cache state
        reset_caches(model)

    # .get() so a missing key is not inserted (as an empty list) into the
    # defaultdict, which would later break the averaging.
    toks_per_sec = aggregate_metrics.get("total_toks_per_sec", [])
    print(f"Average tokens/sec: {torch.mean(torch.tensor(toks_per_sec)).item():.2f}")
    max_mem_gb = torch.cuda.max_memory_reserved() / 1e9
    print(f"Memory used: {max_mem_gb} GB")
    task_metrics["max_memory_gb"] = max_mem_gb
    task_metrics.update(_summarize_metrics(aggregate_metrics))

    if task.requires_perplexity:
        pred_df = None
    else:
        pred_units = all_probs if task.requires_logits else predictions
        task_metrics.update(flatten_dict(task.test_metrics(pred_units)))
        pred_df = pd.DataFrame({"prompt": prompts, "prediction": predictions})
    return task_metrics, pred_df, task_cache_kwargs
def _json_ready(d: dict) -> dict:
    """Return a shallow copy of ``d`` with ``Path`` values converted to ``str`` for JSON."""
    return {k: str(v) if isinstance(v, Path) else v for k, v in d.items()}


def main(
    args: argparse.Namespace,
    tasks: List[str],
    debug: bool = False,
    checkpoint_path: Path = Path(
        "checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"
    ),
    profile: Optional[Path] = None,
    compile=True,
    feed_long_prompts=False,
    decode_first_token=False,
    device=default_device,
    cache_kwargs: Optional[dict] = None,
    out_dir: Optional[Path] = None,
) -> None:
    """Generates text samples based on a pre-trained Transformer model and tokenizer.

    Loads the model and tokenizer once, then for each task either reloads
    previously saved metrics (when ``<task>_metrics.json`` exists and
    ``cache_kwargs["overwrite"]`` is falsy) or runs ``run_task``. A fresh run
    writes ``<task>_metrics.json``, ``<task>_args.json`` and, when
    predictions exist, ``<task>_predictions.csv`` into ``out_dir``.

    ``args.json`` is written once (if absent), and ``all_metrics.json``
    collects every task's metrics at the end. Under tensor parallelism,
    ``print`` is silenced on non-zero ranks.
    """
    # Avoid a shared mutable default: this function may write into cache_kwargs.
    if cache_kwargs is None:
        cache_kwargs = {}
    assert checkpoint_path.is_file(), checkpoint_path

    tokenizer_path = checkpoint_path.parent / "tokenizer.model"
    if not tokenizer_path.is_file():
        # If there's no tokenizer.model, try to load the tokenizer from the parent directory
        # NOTE: We assume the tokenizer in the parent directory is compatible with huggingface transformers
        tokenizer_path = checkpoint_path.parent

    # Rebinds the module-level `print`, so the silencing below also applies to
    # run_task and every other function in this module.
    global print
    from tp import maybe_init_dist

    rank = maybe_init_dist()
    use_tp = rank is not None
    if use_tp and rank != 0:
        # only print on rank 0
        print = lambda *args, **kwargs: None

    print(f"Using device={device}")
    precision = torch.bfloat16
    checkpoint_name = str(checkpoint_path).lower()
    is_chat = "chat" in checkpoint_name or "instruct" in checkpoint_name

    print("Loading model ...")
    t0 = time.time()
    model = load_model(checkpoint_path, device, precision, use_tp)
    device_sync(device=device)  # MKG
    print(f"Time to load model: {time.time() - t0:.02f} seconds")

    tokenizer = get_tokenizer(tokenizer_path, checkpoint_path, is_chat=is_chat)

    # NOTE(review): elsewhere in this file cache_strategy is a *list* (args_to_str
    # indexes [0], run_task iterates it), so this equality with a str never
    # matches and token_ids is never set here. Confirm whether setup_caches (which
    # receives the tokenizer) already provides token_ids before changing this to
    # `"hybrid" in ...`.
    if cache_kwargs.get("cache_strategy") == "hybrid":
        # We need to pass the special and punctuation token ids to the cache via cache_kwargs
        cache_kwargs["token_ids"] = {
            "special": tokenizer.special_ids(),
            "punctuation": tokenizer.punctuation_ids(),
        }

    terminator_ids = tokenizer.get_terminator_ids()
    torch.manual_seed(1234)

    task_kwargs = {
        "model_max_length": model.config.max_length,
        "num_samples": args.num_samples,
        "tokenizer": tokenizer.encode_prompt if is_chat else tokenizer.encode,
        "seq_length": args.seq_length,
    }
    if tasks == ["all"]:
        # Evaluate all tasks
        tasks = list(TASK_MAPPING.keys())
    eval_tasks = {task: AutoTask.from_name(task, **task_kwargs) for task in tasks}

    task_metrics = defaultdict(dict)
    args_fn = out_dir / "args.json"
    all_out_fn = out_dir / "all_metrics.json"
    for task_name, task in eval_tasks.items():
        print(f"Running task {task_name} ...")
        task_out_fn = out_dir / f"{task_name}_metrics.json"
        task_args_out_fn = out_dir / f"{task_name}_args.json"
        pred_out_fn = out_dir / f"{task_name}_predictions.csv"

        if task_out_fn.exists() and not cache_kwargs.get("overwrite", False):
            print(f"Task {task_name} already evaluated. Skipping.")
            with open(task_out_fn, "r") as fd:
                task_metrics[task_name] = json.load(fd)
        else:
            prefill, decode_one_token = compile_funcs(compile)
            task_metrics[task_name], predictions, task_args = run_task(
                args,
                task,
                model,
                prefill,
                decode_one_token,
                tokenizer,
                is_chat,
                profile,
                feed_long_prompts,
                decode_first_token,
                device,
                cache_kwargs,
                use_tp,
                rank,
                terminator_ids,
            )
            # run_task returns None metrics when the task has no test data.
            if task_metrics[task_name] is None:
                continue
            if predictions is not None:
                predictions.to_csv(pred_out_fn, index=False)
            if debug:
                print(f"Results for {task_name}:")
                print(task_metrics[task_name])
            with open(task_out_fn, "w") as fd:
                print(f"Saving results for {task_name} to {task_out_fn}")
                json.dump(task_metrics[task_name], fd, indent=4)
            with open(task_args_out_fn, "w") as fd:
                print(f"Saving dynamic args for {task_name} to {task_args_out_fn}")
                json.dump(_json_ready(task_args), fd, indent=4)

        if not args_fn.exists():
            # Only save args once and only save if we've gotten through a full eval and are ready to dump metrics
            with open(args_fn, "w") as fd:
                json.dump(_json_ready(cache_kwargs), fd, indent=4)

    with open(all_out_fn, "w") as fd:
        json.dump(task_metrics, fd, indent=4)
def setup(args) -> Path:
    """Resolve, prepare and return the results directory for this run.

    The path is ``<this file's dir>/results/<checkpoint parent dir name>/
    <compressed cache strategies joined by "__">/<sub_dir>``, where ``sub_dir``
    is ``args.out_dir`` when given and ``args_to_str(args)`` otherwise. An
    existing directory is removed first if ``args.overwrite`` is set; the
    directory is then created. Also runs ``cache_compatibility(args)`` and
    prints every argument.
    """
    if args.out_dir is None:
        sub_dir = args_to_str(args)
    else:
        sub_dir = args.out_dir
    strategy_dir = "__".join(compress_list(args.cache_strategy))
    out_dir = Path(__file__).parent.joinpath(
        "results", args.checkpoint_path.parent.name, strategy_dir, sub_dir
    )
    print(f"Saving to {out_dir}")

    already_there = out_dir.exists()
    if already_there:
        print(f"Output directory {out_dir} already exists.")
    if already_there and args.overwrite:
        print(f"Removing {out_dir}.")
        shutil.rmtree(out_dir)
    # Create the directory (and parents); no error if it is still present.
    out_dir.mkdir(parents=True, exist_ok=True)

    cache_compatibility(args)
    for name, value in vars(args).items():
        print(f"{name} -> {value}")
    return out_dir
def add_eval_args(parser):
    """Register the evaluation-specific command-line options on ``parser``.

    Adds ``--tasks``, ``--out_dir``, ``--debug``, ``--num_samples``,
    ``--overwrite``, ``--seq_length``, ``--cache_config`` and
    ``--decode_first_token``. ``parser`` is modified in place; nothing is returned.
    """
    parser.add_argument(
        "--tasks",
        type=str,
        nargs="+",
        default=["truthfulqa"],
        choices=list(TASK_MAPPING.keys()) + ["all"],
        help="List of tasks to be evaluated.",
    )
    parser.add_argument(
        "--out_dir",
        type=Path,
        default=None,
        help="Output directory for results. If not specified, will be a concatenation of the program args.",
    )
    parser.add_argument(
        "--debug",
        default=False,
        action="store_true",
        help="Debug mode uses first 10 examples in dataset.",
    )
    parser.add_argument(
        "--num_samples",
        type=int,
        default=-1,
        # The default is -1 (not None); per the original help text, the default
        # uses the full dataset.
        help="Number of examples to sample for evaluation. Defaults to -1, which uses the full dataset.",
    )
    parser.add_argument(
        "--overwrite",
        default=False,
        action="store_true",
        help="Whether to over-write existing results if they exist.",
    )
    # Only for --tasks PG19
    parser.add_argument(
        "--seq_length",
        type=int,
        default=None,
        help="Specify the number of tokens for the dataset.",
    )
    parser.add_argument(
        "--cache_config",
        type=str,
        default=None,
        help="Name of YAML file in ./cache_configs.",
    )
    parser.add_argument(
        "--decode_first_token",
        default=False,
        action="store_true",
        help="If True will truncate cache after prefill and then decode the first token.",
    )
if __name__ == "__main__":
    # Build the CLI from the evaluation, generation and cache option groups,
    # then post-process the parsed args with merge_cache_config (presumably
    # folding in the YAML named by --cache_config — confirm in generation_utils).
    parser = argparse.ArgumentParser(
        description="Evaluation script for different KV-Cache Compression Algorithms."
    )
    for register_options in (add_eval_args, add_generation_arguments, add_cache_arguments):
        register_options(parser)
    args = merge_cache_config(parser.parse_args())

    # Expand the "all" shorthand into every registered task name.
    if args.tasks[0] == "all":
        args.tasks = list(TASK_MAPPING.keys())
        print(f"Running all tasks: {', '.join(args.tasks)}")

    out_dir = setup(args)
    main(
        args,
        tasks=args.tasks,
        debug=args.debug,
        checkpoint_path=args.checkpoint_path,
        profile=args.profile,
        compile=args.compile,
        feed_long_prompts=args.feed_long_prompts,
        decode_first_token=args.decode_first_token,
        device=args.device,
        cache_kwargs=vars(args),
        out_dir=out_dir,
    )