
Commit 70d4eee

fix: eagle3 offline
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
1 parent d19f4f5 commit 70d4eee

Showing 5 changed files with 41 additions and 26 deletions.


examples/speculative_decoding/README.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -312,7 +312,7 @@ trainer.save_model("<path to the output directory>")
 | LLAMA 3, 3.1 | ✅ | ✅ | ✅ |
 | Mistral | ✅ | ✅ | ✅ |
 | Phi 3 | ✅ | ✅ | ✅ |
-| QWen 1.5,2,2.5 | ✅ | ✅ | ✅ |
+| QWen 1.5,2,2.5,3 | ✅ | ✅ | ✅ |
 
 ## Speculation Module Checkpoints
 
```

examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py

Lines changed: 14 additions & 11 deletions
```diff
@@ -208,13 +208,16 @@ def keep_conversation(entry):
     num_success = 0
     pbar = tqdm(total=len(dataset), desc=f"DP#{args.dp_rank} Processing conversations")
 
-    def _post_process_trtllm_dumped(trtllm_dumped_file: str, conversation_id: int):
-        """Post-process the TRTLLM dumped file to same format as HF dumped:
+    async def _post_process_trtllm_dumped(trtllm_dumped_file: str, conversation_id: int):
+        """
+        Post-process the TRTLLM dumped file to same format as HF dumped:
         1. Remove id field, replace it with conversation_id
         2. Rename hidden_state field to hidden_states
         3. From list of length 1 to dict
         4. Rename file to conversation_id.pt
         """
+        if not trtllm_dumped_file.exists():
+            return False
         with open(trtllm_dumped_file, "rb") as f:
             trtllm_dumped = torch.load(f)
         assert isinstance(trtllm_dumped, list) and len(trtllm_dumped) == 1, (
@@ -232,35 +235,33 @@ def _post_process_trtllm_dumped(trtllm_dumped_file: str, conversation_id: int):
         output_file = args.output_dir / f"{conversation_id}.pt"
         with open(output_file, "wb") as f:
             torch.save(trtllm_dumped, f)
-
-        if trtllm_dumped_file.exists():
-            trtllm_dumped_file.unlink()
+        trtllm_dumped_file.unlink()
+        return True
 
     async def dump_hidden_states(idx: int, conversation_id: int, input_ids: list[int]):
         nonlocal num_success
         await llm_spec.generate_async(input_ids, sampling_params)
         # TRTLLM API name files starts from 1
         # ref:https://github.com/NVIDIA/TensorRT-LLM/pull/7012
         trtllm_dumped_file = args.output_dir / f"{spec_config['file_prefix']}_{idx + 1}.pt"
-        _post_process_trtllm_dumped(trtllm_dumped_file, conversation_id)
-        num_success += 1
+        dump_success = await _post_process_trtllm_dumped(trtllm_dumped_file, conversation_id)
+        num_success += int(dump_success)
         pbar.update(1)
 
     async def submit_generates():
         nonlocal num_skipped_too_long
         nonlocal num_invalid
         tasks = []
-        for idx, entry in enumerate(dataset):
+        idx = 0
+        for entry in dataset:
             conversation_id = entry.get("conversation_id", entry.get("uuid"))
 
             conversations = entry["conversations"]
             if not conversations or not isinstance(conversations, list):
                 num_invalid += 1
                 continue
 
-            input_ids = tokenizer.apply_chat_template(conversations, add_generation_template=False)[
-                :256
-            ]
+            input_ids = tokenizer.apply_chat_template(conversations, add_generation_template=False)
             num_input_tokens = (
                 input_ids.shape[1] if isinstance(input_ids, torch.Tensor) else len(input_ids)
             )
@@ -269,6 +270,8 @@ async def submit_generates():
                 continue
 
             tasks.append(dump_hidden_states(idx, conversation_id, input_ids))
+            # Increment only for valid conversations to match dump file index
+            idx += 1
         await asyncio.gather(*tasks)
 
     asyncio.run(submit_generates())
```
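
Besides making `_post_process_trtllm_dumped` async and awaited, the fix handles two failure modes: a dump file that TRT-LLM never wrote is now counted as a failure instead of crashing on `open` (while `num_success` was incremented unconditionally), and `idx` now advances only for conversations that are actually submitted, since TRT-LLM numbers its dump files by submission order starting from 1. The `[:256]` truncation of `input_ids` is also gone, so full-length conversations are dumped. The format conversion the docstring describes can be sketched as a standalone function; the name and signature below are illustrative, not the repository's exact code:

```python
from pathlib import Path

import torch


def post_process_dump(trtllm_dumped_file: Path, conversation_id: int, output_dir: Path) -> bool:
    """Sketch of the TRTLLM -> HF dump conversion described in the docstring above."""
    if not trtllm_dumped_file.exists():
        # TRT-LLM produced no dump for this request; report failure so the
        # caller can keep num_success accurate instead of raising.
        return False
    with open(trtllm_dumped_file, "rb") as f:
        dumped = torch.load(f)
    assert isinstance(dumped, list) and len(dumped) == 1
    record = dumped[0]                                     # 3. from list of length 1 to dict
    record.pop("id", None)                                 # 1. remove id field ...
    record["conversation_id"] = conversation_id            #    ... and replace with conversation_id
    record["hidden_states"] = record.pop("hidden_state")   # 2. rename hidden_state -> hidden_states
    with open(output_dir / f"{conversation_id}.pt", "wb") as f:
        torch.save(record, f)                              # 4. save as <conversation_id>.pt
    trtllm_dumped_file.unlink()                            # drop the original <prefix>_<idx+1>.pt
    return True
```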

examples/speculative_decoding/eagle_utils.py

Lines changed: 19 additions & 14 deletions
```diff
@@ -47,7 +47,6 @@ def preprocess(examples, tokenizer):
         "loss_mask": [],
         "labels": [],
     }
-    roles = ["user", "assistant"]
     for i in range(len(examples)):
         messages = []
         source = examples[i]["conversations"]
@@ -61,13 +60,8 @@ def get_role_content(item):
             else:
                 raise ValueError(f"Unknown conversation format: {item}")
 
-        first_role, _ = get_role_content(source[0])
-        if first_role.lower() != "user":
-            # Skip the first one if it is not from human
-            source = source[1:]
-        for j, sentence in enumerate(source):
+        for sentence in source:
             role, content = get_role_content(sentence)
-            assert role.lower() == roles[j % 2], f"{i}"
             messages.append({"role": role.lower(), "content": content})
         conversation = tokenizer.apply_chat_template(
             messages,
@@ -259,11 +253,20 @@ def make_eagle_supervised_data_module(
         dict: A dictionary containing train and eval datasets.
     """
     # Load the conversations from the source file
-    with open(data_args.data_path) as f:
-        if data_args.data_path.endswith("jsonl"):
-            data_json = [json.loads(line) for line in f]
-        else:
-            data_json = json.load(f)
+    print_rank_0("Loading input conversations...")
+    data_json = []
+    data_path_p = Path(data_args.data_path)
+    if data_path_p.is_dir():
+        # Load all .jsonl files in the directory and combine them
+        for jsonl_file in sorted(data_path_p.glob("*.jsonl")):
+            with open(jsonl_file) as f:
+                data_json.extend(json.loads(line) for line in f)
+    else:
+        with open(data_args.data_path) as f:
+            if data_args.data_path.endswith("jsonl"):
+                data_json = [json.loads(line) for line in f]
+            else:
+                data_json = json.load(f)
 
     if use_offline_training:
         print_rank_0("Loading pre-processed data for offline training...")
@@ -280,12 +283,14 @@ def make_eagle_supervised_data_module(
 
         # Filter to conversations that exist in the offline data and in the provided json
         valid_entries = []
-        for idx, entry in enumerate(data_json):
+        for entry in data_json:
             conv_id = entry.get("conversation_id")
+            if conv_id is None:
+                conv_id = entry.get("uuid")
             if conv_id is None:
                 conv_id = entry.get("id")
             if conv_id is None:
-                conv_id = "{:08d}".format(idx)
+                raise ValueError(f"Conversation ID required but not found for entry {entry}")
             file_path = str(offline_data_path / f"{conv_id}.pt")
             if file_path in all_files:
                 valid_entries.append((entry, file_path))
```
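
Two behavioral changes sit alongside the loader update: `preprocess` no longer forces strict user/assistant alternation (the role-order assert and the drop-the-first-turn logic are gone, so conversations are passed to the chat template as-is), and `make_eagle_supervised_data_module` now accepts either a single JSON/JSONL file or a directory of `.jsonl` shards. For offline training a real conversation ID is mandatory: the fallback order is `conversation_id`, then `uuid`, then `id`, with an error instead of the old index-derived ID (which presumably could pair entries with the wrong `.pt` files once the dataset was filtered or reordered). The lookup chain is equivalent to this small helper, a hypothetical refactor rather than code from the commit:

```python
def resolve_conversation_id(entry: dict):
    """Return the first available ID field, mirroring the lookup order in the diff."""
    for key in ("conversation_id", "uuid", "id"):
        if entry.get(key) is not None:
            return entry[key]
    raise ValueError(f"Conversation ID required but not found for entry {entry}")
```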

examples/speculative_decoding/launch_train.sh

Lines changed: 6 additions & 0 deletions
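
The new flag threads through three spots, shown in the diff below: the argument-parsing loop, a default value, and the forwarded `accelerate launch` command line. Usage is an extra argument on an otherwise unchanged invocation, e.g. `./launch_train.sh <usual arguments> --disable_tqdm True`; the parsing loop accepts both `--disable_tqdm True` and `--disable_tqdm=True`, and the value defaults to `False` and is passed verbatim to `main.py`.
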
```diff
@@ -78,6 +78,10 @@ while [ $# -gt 0 ]; do
             if [[ "$1" != *=* ]]; then shift; fi
             NUM_GPU="${1#*=}"
             ;;
+        --disable_tqdm*)
+            if [[ "$1" != *=* ]]; then shift; fi
+            DISABLE_TQDM="${1#*=}"
+            ;;
         *)
             >&2 printf "Error: Invalid argument ${1#*=}\n"
             exit 1
@@ -110,6 +114,7 @@ FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP=${FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP:-"LlamaD
 NUM_GPU=${NUM_GPU:-1}
 TRAINING_SEQ_LEN=${TRAINING_SEQ_LEN:-2048}
 OFFLINE_DATA_PATH=${OFFLINE_DATA_PATH:-""}
+DISABLE_TQDM=${DISABLE_TQDM:-False}
 
 if [[ "$MODE" == "medusa" ]]; then
     SPECULATIVE_ARGS="--medusa_num_heads $MEDUSA_NUM_HEADS --medusa_num_layers $MEDUSA_NUM_LAYERS"
@@ -165,6 +170,7 @@ CMD="accelerate launch $MULTI_GPU --mixed_precision bf16 main.py \
     --logging_steps 100 \
     --tf32 True \
     --data_path $DATA \
+    --disable_tqdm $DISABLE_TQDM \
     $OFFLINE_TRAINING_ARGS \
     $SPECULATIVE_ARGS
 "
```

examples/speculative_decoding/main.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -92,6 +92,7 @@ class TrainingArguments(transformers.TrainingArguments):
     bf16: bool = field(default=True)
     mode: Literal["eagle1", "eagle3", "medusa"] = "eagle3"
     ar_validate_steps: int = field(default=1000, metadata={"help": "Steps between AR validation."})
+    disable_tqdm: bool = field(default=False, metadata={"help": "Disable tqdm progress bar."})
 
 
 @dataclass
```
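
For reference, `transformers.TrainingArguments` already defines a `disable_tqdm` field (default `None`, which the HF `Trainer` resolves from the effective logging level); the override here pins the default to `False` and documents the flag. A minimal sketch of the parsing path, assuming the standard `HfArgumentParser` flow used by HF example scripts rather than the exact wiring in `main.py`:

```python
from dataclasses import dataclass, field

import transformers


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    # Same override as the diff above: pin the default and add help text.
    disable_tqdm: bool = field(default=False, metadata={"help": "Disable tqdm progress bar."})


if __name__ == "__main__":
    # e.g. python main.py --output_dir /tmp/out --disable_tqdm True
    parser = transformers.HfArgumentParser(TrainingArguments)
    (training_args,) = parser.parse_args_into_dataclasses()
    print(training_args.disable_tqdm)  # Trainer reads this to suppress its progress bars
```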
