[Eagle Offline] multinode support for hidden states dumper (#422)
Conversation
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.
**Walkthrough**

Replaces file-based JSONL ingestion with HuggingFace Datasets loading (file or directory), adds DP sharding and filtering of already-dumped conversations, caps tokenization at 256 tokens, updates progress/success reporting, simplifies the local DP launcher, and adds a new SLURM submission script using the containerized trtllm-llmapi-launch.
**Sequence Diagram(s)**

```mermaid
sequenceDiagram
    autonumber
    participant User
    participant Compute as compute_hidden_states_trtllm.py
    participant HFDS as HuggingFace Datasets
    participant Tokenizer
    participant TRTLLM as TRT-LLM Runtime
    participant FS as Filesystem
    User->>Compute: launch with --input (file/dir), --output, --dp_rank/world_size
    Compute->>HFDS: load_dataset(input)
    Compute->>HFDS: shard(dp_rank, dp_world_size)
    Compute->>HFDS: filter(existing .pt by conversation_id/uuid)
    loop per conversation (dataset)
        Compute->>Tokenizer: tokenize(prompt, max_length=256)
        Tokenizer-->>Compute: input_ids
        Compute->>TRTLLM: run forward to collect hidden states
        TRTLLM-->>Compute: hidden states
        Compute->>FS: save hidden states as `.pt`
    end
    Compute-->>User: print completion with processed count (len(dataset))
```
```mermaid
sequenceDiagram
    autonumber
    participant SLURM as SLURM Array Task
    participant srun
    participant Container
    participant Launcher as trtllm-llmapi-launch
    participant Compute as compute_hidden_states_trtllm.py
    SLURM->>SLURM: derive TP/DP from ARRAY_TASK_ID/COUNT
    SLURM->>srun: invoke with container image, mounts, and env
    srun->>Container: start shell
    Container->>Launcher: start TRT-LLM server/context
    Launcher->>Compute: execute with model/input/output + parallel opts
    Compute-->>Launcher: exit status
    Launcher-->>Container: stop
    Container-->>srun: exit
```
**Estimated code review effort**: 🎯 3 (Moderate) | ⏱️ ~25 minutes
**Pre-merge checks and finishing touches**: ✅ Passed checks (3 passed)
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py (1)
**261-272: Fix prompt tokenization before handing to TRT-LLM.**

Line 261 currently slices the string returned by `tokenizer.apply_chat_template`, so we end up truncating by characters, computing "token" length on characters, and then passing a string into `generate_async`. That breaks token limits and can crash depending on backend expectations. Convert to real token IDs (and cap after tokenization) before dispatch.

```diff
-input_ids = tokenizer.apply_chat_template(conversations, add_generation_template=False)[
-    :256
-]
-num_input_tokens = (
-    input_ids.shape[1] if isinstance(input_ids, torch.Tensor) else len(input_ids)
-)
+encoded = tokenizer.apply_chat_template(
+    conversations,
+    add_generation_template=False,
+    tokenize=True,
+    return_tensors="pt",
+)
+input_ids = encoded["input_ids"][0][:256]
+num_input_tokens = input_ids.numel()
@@
-tasks.append(dump_hidden_states(idx, conversation_id, input_ids))
+tasks.append(dump_hidden_states(idx, conversation_id, input_ids.tolist()))
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- `examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py` (8 hunks)
- `examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens_dp.sh` (0 hunks)
- `examples/speculative_decoding/collect_hidden_states/slurm_dump.sh` (1 hunk)
💤 Files with no reviewable changes (1)
- examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens_dp.sh
🧰 Additional context used
🪛 Shellcheck (0.11.0)
examples/speculative_decoding/collect_hidden_states/slurm_dump.sh
[warning] 29-29: Use "${var:?}" to ensure this never expands to /* .
(SC2115)
[warning] 35-35: Word is of the form "A"B"C" (B indicated). Did you mean "ABC" or "A"B"C"?
(SC2140)
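To see why Shellcheck raises SC2140 here, a minimal sketch (the variable name is illustrative, not from the script): adjacent quoted and unquoted segments merge into one word, so the inner quotes silently disappear rather than nesting.

```shell
# SC2140 in miniature: "A"B"C" is three adjacent segments, not nested quotes.
# Bash concatenates them into the single word ABC; the inner quotes vanish.
VAL="A"B"C"
echo "$VAL"
```

This is exactly what happens to the quotes around `error` in the script's `RUN_DUMPER` assignment flagged below.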
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: linux
- GitHub Check: code-quality
- GitHub Check: build-docs
```python
if args.debug_max_num_conversations is not None:
    dataset = dataset.select(range(args.debug_max_num_conversations))
```
Guard the debug cap against short datasets.
If --debug-max-num-conversations exceeds len(dataset), dataset.select(range(...)) raises an IndexError. Clamp the range to the dataset length so the debug flag remains safe to use.
```diff
- dataset = dataset.select(range(args.debug_max_num_conversations))
+ limit = min(args.debug_max_num_conversations, len(dataset))
+ dataset = dataset.select(range(limit))
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
if args.debug_max_num_conversations is not None:
    limit = min(args.debug_max_num_conversations, len(dataset))
    dataset = dataset.select(range(limit))
```
🤖 Prompt for AI Agents
In
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py
around lines 166 to 168, the debug cap uses
dataset.select(range(args.debug_max_num_conversations)) which will raise
IndexError if the arg exceeds len(dataset); clamp the requested count to the
dataset length before selecting (e.g., compute count =
min(args.debug_max_num_conversations, len(dataset)) and call
dataset.select(range(count)) or skip selection if count <= 0) so the debug flag
is safe for shorter datasets.
```sh
fi

MOUNTS=$INPUT_DIR:/input,$DUMP_DIR:/output,$MODELOPT_DIR:/modelopt,$TEACHER_MODEL:/model
rm -rf $DUMP_DIR/*
```
Harden the cleanup step before launching.
rm -rf $DUMP_DIR/* is dangerous—if $DUMP_DIR is unset/empty the glob resolves to / and nukes the host. Add the usual ${var:?} guard (or bail out explicitly) before wiping the directory.
```diff
-rm -rf $DUMP_DIR/*
+rm -rf "${DUMP_DIR:?}/"*
```

🧰 Tools
🪛 Shellcheck (0.11.0)
[warning] 29-29: Use "${var:?}" to ensure this never expands to /* .
(SC2115)
🤖 Prompt for AI Agents
In examples/speculative_decoding/collect_hidden_states/slurm_dump.sh around line
29, the cleanup command uses an unguarded rm -rf $DUMP_DIR/* which can delete
root if DUMP_DIR is empty; update the script to validate or bail if DUMP_DIR is
unset/empty (e.g., use the parameter expansion guard ${DUMP_DIR:?} before using
it) or add an explicit check that DUMP_DIR is non-empty and points to the
expected path, then perform a safe removal only after the guard/check passes.
```sh
RUN_DUMPER="export TLLM_LOG_LEVEL="error";
trtllm-llmapi-launch python3 /modelopt/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py \
    --model /model \
    --input-data /input/ \
    --output-dir /output \
    $PARALLEL_ARGS \
"
```
Repair RUN_DUMPER quoting so the script actually runs.
The current assignment ends the string at the inner "error"; the remainder is interpreted as separate tokens, so the script fails before launching anything. Use a heredoc (or escape the inner quotes) to produce a well-formed command string.
```diff
-RUN_DUMPER="export TLLM_LOG_LEVEL="error";
-trtllm-llmapi-launch python3 /modelopt/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py \
-    --model /model \
-    --input-data /input/ \
-    --output-dir /output \
-    $PARALLEL_ARGS \
-    "
+read -r -d '' RUN_DUMPER <<EOF
+export TLLM_LOG_LEVEL="error"
+trtllm-llmapi-launch python3 /modelopt/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py \
+    --model /model \
+    --input-data /input/ \
+    --output-dir /output \
+    $PARALLEL_ARGS
+EOF
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```sh
read -r -d '' RUN_DUMPER <<EOF
export TLLM_LOG_LEVEL="error"
trtllm-llmapi-launch python3 /modelopt/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py \
    --model /model \
    --input-data /input/ \
    --output-dir /output \
    $PARALLEL_ARGS
EOF
```
🧰 Tools
🪛 Shellcheck (0.11.0)
[warning] 35-35: Word is of the form "A"B"C" (B indicated). Did you mean "ABC" or "A"B"C"?
(SC2140)
🤖 Prompt for AI Agents
In examples/speculative_decoding/collect_hidden_states/slurm_dump.sh around
lines 35 to 41, the RUN_DUMPER string is prematurely terminated by the inner
"error" quotes, breaking the command; fix it by constructing a single
well-formed string that contains the TLLM_LOG_LEVEL assignment and the
subsequent command — either escape the inner quotes (e.g., \"error\") or use a
heredoc or single-quoted wrapper so the whole export and trtllm-llmapi-launch
python invocation are part of the RUN_DUMPER value, preserving $PARALLEL_ARGS
and newlines as needed.
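A runnable sketch of the heredoc approach (with `echo` standing in for `trtllm-llmapi-launch`, which isn't available outside the container): the command string is built with no nested double quotes, then handed to `bash -c` exactly as the script would.

```shell
# PARALLEL_ARGS mirrors the script's variable; the launched command is a stub.
PARALLEL_ARGS="--tp 4 --dp-rank 0 --dp-world-size 1"

# $(cat <<EOF ... EOF) captures the block as one string. $PARALLEL_ARGS is
# expanded while the string is built; \$TLLM_LOG_LEVEL is kept literal so it
# expands later, inside bash -c, after the export has run.
RUN_DUMPER=$(cat <<EOF
export TLLM_LOG_LEVEL=error
echo "would launch dumper with: \$TLLM_LOG_LEVEL $PARALLEL_ARGS"
EOF
)

bash -c "$RUN_DUMPER"
```

Because the heredoc body is never re-quoted, no inner quote can terminate the assignment early, which is the failure mode under discussion.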
Actionable comments posted: 5
🧹 Nitpick comments (5)
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py (2)
**82-92: Clarify SLURM variable names in help text.**

The help text mentions "TASK_ID on SLURM" and "Number of tasks on SLURM" but could be more precise. Consider updating to reference the specific SLURM environment variables (`SLURM_ARRAY_TASK_ID` and `SLURM_ARRAY_TASK_COUNT`) for clarity. Apply this diff:

```diff
 parser.add_argument(
     "--dp-rank",
     type=int,
     default=0,
-    help="""Data parallel rank. TASK_ID on SLURM.""",
+    help="""Data parallel rank. Set to SLURM_ARRAY_TASK_ID when using SLURM arrays.""",
 )
 parser.add_argument(
     "--dp-world-size",
     type=int,
     default=1,
-    help="""Data parallel world size. Number of tasks on SLURM.""",
+    help="""Data parallel world size. Set to SLURM_ARRAY_TASK_COUNT when using SLURM arrays.""",
 )
```
**283-286: Success reporting may be misleading with skipped conversations.**

The success message compares `num_success` against `len(dataset)`, but this doesn't account for conversations skipped due to invalid data or length constraints. This could be confusing when debugging failed runs. Consider more accurate reporting:

```diff
+expected_success = len(dataset) - num_invalid - num_skipped_too_long
-if num_success == len(dataset):
-    print(f"Successfully processed all {num_success} conversations.")
+if num_success == expected_success:
+    print(f"Successfully processed all {num_success} valid conversations.")
 else:
-    print(f"Successfully processed {num_success} out of {len(dataset)} conversations.")
+    print(f"Successfully processed {num_success} out of {expected_success} valid conversations "
+          f"({num_invalid} invalid, {num_skipped_too_long} skipped due to length).")
```

`examples/speculative_decoding/collect_hidden_states/slurm_dump.sh` (3)
**8-12: Make account/job-name more configurable.**

The SBATCH account and job-name are hardcoded to NVIDIA-internal values. Consider either:

- Using placeholder variables like `<YOUR_ACCOUNT>` and `<YOUR_JOB_NAME>` to match the pattern used for INPUT_DIR, DUMP_DIR, etc.
- Adding a comment explicitly stating these must be updated by users.

This would make the script more clearly a template that requires customization. Apply this diff:

```diff
-#SBATCH -A coreai_dlalgo_modelopt
-#SBATCH --job-name=coreai_dlalgo_modelopt-mcore.modelopt
+#SBATCH -A <YOUR_ACCOUNT>
+#SBATCH --job-name=<YOUR_JOB_NAME>
 #SBATCH --nodes=1 --ntasks-per-node=4 --gpus-per-node=4
 #SBATCH -p batch
 #SBATCH -t 04:00:00
```
**17-17: Consider making the container version configurable.**

The container version is hardcoded to `1.2.0rc0`, which is a release candidate. Consider:

- Moving this to a variable at the top of the script for easy updates
- Updating to a stable release if available
- Adding a comment about compatible versions

Example:

```diff
+# Container version - update as needed
+CONTAINER_VERSION="1.2.0rc0"
+CONTAINER="nvcr.io#nvidia/tensorrt-llm/release:${CONTAINER_VERSION}"
-CONTAINER="nvcr.io#nvidia/tensorrt-llm/release:1.2.0rc0"
```

**17-17: Consider using a stable container release instead of RC.**

The script uses a release candidate container version (`1.2.0rc0`). For production use, consider switching to a stable release once available.

```diff
-CONTAINER="nvcr.io#nvidia/tensorrt-llm/release:1.2.0rc0"
+# TODO: Update to stable release when 1.2.0 is GA
+CONTAINER="nvcr.io#nvidia/tensorrt-llm/release:1.2.0"
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- `examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py` (8 hunks)
- `examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens_dp.sh` (0 hunks)
- `examples/speculative_decoding/collect_hidden_states/slurm_dump.sh` (1 hunk)
💤 Files with no reviewable changes (1)
- examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens_dp.sh
🧰 Additional context used
🪛 Shellcheck (0.11.0)
examples/speculative_decoding/collect_hidden_states/slurm_dump.sh
[warning] 34-34: Word is of the form "A"B"C" (B indicated). Did you mean "ABC" or "A"B"C"?
(SC2140)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: linux
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (9)
`examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py` (6)

**131-168: LGTM! Robust data loading with resumability.**

The refactored data loading logic properly:

- Supports both single files and directories of JSONL files
- Implements DP sharding for distributed processing
- Filters out already-processed conversations for resumability
- Provides debug capability to limit conversation count

The conversation_id extraction handles both "conversation_id" and "uuid" fields with proper fallback and validation.
**261-263: Verify the 256-token tokenization cap.**

The tokenization input is hard-capped at 256 tokens, which could truncate longer conversations before they're processed. This seems inconsistent with the `--max-seq-len` parameter (default 3072). Please clarify:

- Why is 256 tokens chosen as the cap?
- Should this be configurable via a command-line argument?
- How does this interact with the `--max-seq-len` validation on line 267?

If this cap is intentional and specific to the dumper's requirements, consider adding an inline comment explaining the rationale:

```diff
 input_ids = tokenizer.apply_chat_template(conversations, add_generation_template=False)[
-    :256
+    :256  # Cap prompt to 256 tokens for hidden state collection
 ]
```
**26-26: LGTM!**

The addition of `load_dataset` from the datasets library is appropriate for the refactoring to HuggingFace datasets-based loading.

**144-148: LGTM!**

The dataset sharding logic correctly distributes data across DP ranks using the datasets library's built-in `shard` method, which ensures even distribution.

**165-167: LGTM!**

The debug cap feature is well-implemented and will be useful for testing without processing the entire dataset.

**131-141: Verify load_dataset handles empty input directory.**

Ensure that passing an empty directory (with no `.jsonl` files) to `load_dataset("json", data_files={"train": ".../*.jsonl"}, split="train")` raises a clear, descriptive error rather than returning an empty dataset.

`examples/speculative_decoding/collect_hidden_states/slurm_dump.sh` (3)
**42-47: LGTM! Appropriate timeout margin.**

The srun configuration is well-structured:

- The 235-minute timeout provides a 5-minute buffer before the 4-hour SBATCH limit, allowing graceful cleanup
- PMIx MPI and overlap flags are appropriate for distributed execution
- Container configuration correctly references the defined variables
- Log file naming includes job metadata for tracking

**8-12: LGTM!**

The SBATCH directives are correctly configured for single-node, 4-GPU execution with appropriate time limits and partition settings.

**19-22: LGTM!**

The placeholder values are clearly marked and provide helpful guidance to users on what needs to be configured.
```diff
 # Remove already dumped conversations
 def keep_conversation(entry):
     conversation_id = entry.get("conversation_id", entry.get("uuid", None))
     assert conversation_id is not None, "conversation_id is required"
     output_file = args.output_dir / f"{conversation_id}.pt"
-    if output_file.exists():
-        continue
-    filtered_conversations.append(entry)
+    return not output_file.exists()

+original_num = len(dataset)
+dataset = dataset.filter(keep_conversation)
 print(
     "Removed",
-    len(all_conversations) - len(filtered_conversations),
+    original_num - len(dataset),
     "conversations due to existing output files",
 )
```
Replace assertion with proper error handling.
The keep_conversation filter function uses an assertion to validate that conversation_id exists, which will cause the entire process to crash if any conversation lacks this field. In a distributed DP setting, this would fail the entire SLURM job.
Replace the assertion with proper error handling:
```diff
 def keep_conversation(entry):
     conversation_id = entry.get("conversation_id", entry.get("uuid", None))
-    assert conversation_id is not None, "conversation_id is required"
+    if conversation_id is None:
+        return False  # Skip conversations without valid ID
     output_file = args.output_dir / f"{conversation_id}.pt"
     return not output_file.exists()
```

Additionally, consider logging a warning when conversations are skipped due to missing IDs, similar to how other invalid conversations are tracked.
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
# Remove already dumped conversations
def keep_conversation(entry):
    conversation_id = entry.get("conversation_id", entry.get("uuid", None))
    if conversation_id is None:
        # Skip conversations without valid ID (consider logging a warning here)
        return False
    output_file = args.output_dir / f"{conversation_id}.pt"
    return not output_file.exists()


original_num = len(dataset)
dataset = dataset.filter(keep_conversation)
print(
    "Removed",
    original_num - len(dataset),
    "conversations due to existing output files",
)
```
🤖 Prompt for AI Agents
In
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py
around lines 150-163, replace the assertion that enforces a conversation_id with
non-crashing error handling: if conversation_id is missing, log a warning (or
print) indicating the skipped entry and return False from keep_conversation so
the entry is filtered out instead of crashing the job; optionally increment or
track a skipped counter for reporting, then continue to check for existing
output files and return not output_file.exists() for valid IDs.
```python
input_ids = tokenizer.apply_chat_template(conversations, add_generation_template=False)[
    :256
]
```
🛠️ Refactor suggestion | 🟠 Major
Clarify the magic number 256 for tokenization truncation.
The hardcoded limit of 256 tokens appears arbitrary and may not align with the model's actual context window or the user's --max-seq-len parameter. This could lead to confusion when conversations are unexpectedly truncated.
Consider one of the following approaches:
- Make this configurable via a CLI argument:

```diff
+parser.add_argument(
+    "--max-input-tokens",
+    type=int,
+    default=256,
+    help="Maximum number of tokens to use from conversation input for context."
+)
```

- Or at minimum, add a constant and a comment explaining the rationale:

```diff
+# Limit input tokens to reduce memory usage during hidden state collection
+MAX_INPUT_TOKENS = 256
+
 input_ids = tokenizer.apply_chat_template(conversations, add_generation_template=False)[
-    :256
+    :MAX_INPUT_TOKENS
 ]
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
# Limit input tokens to reduce memory usage during hidden state collection
MAX_INPUT_TOKENS = 256

input_ids = tokenizer.apply_chat_template(conversations, add_generation_template=False)[
    :MAX_INPUT_TOKENS
]
```
🤖 Prompt for AI Agents
In
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py
around lines 261 to 263, replace the hardcoded token truncation slice [:256]
with a named constant or CLI-configurable value tied to the model/context window
(e.g., use the --max-seq-len argument if available, or
model.config.max_position_embeddings minus reserved tokens) and add a comment
explaining why that limit was chosen; update the tokenizer.apply_chat_template
call to slice by that variable so truncation respects the model's actual context
length or the user-provided parameter.
```sh
#EP optionally available by setting --moe-ep-size and --moe-tp-size. See compute_hidden_states_trtllm.py.
PARALLEL_ARGS="--tp 4 --dp-rank $SLURM_ARRAY_TASK_ID --dp-world-size $SLURM_ARRAY_TASK_COUNT"

RUN_DUMPER="export TLLM_LOG_LEVEL="error";
```
Fix shell quoting issue.
The line has incorrect nested quoting that will result in the string "error" (including quotes) being assigned to TLLM_LOG_LEVEL, rather than just error.
Apply this diff:
```diff
-RUN_DUMPER="export TLLM_LOG_LEVEL="error";
+RUN_DUMPER="export TLLM_LOG_LEVEL=error;
 trtllm-llmapi-launch python3 /modelopt/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py \
     --model /model \
     --input-data /input/ \
     --output-dir /output \
     $PARALLEL_ARGS \
 "
```

🧰 Tools
🪛 Shellcheck (0.11.0)
[warning] 34-34: Word is of the form "A"B"C" (B indicated). Did you mean "ABC" or "A"B"C"?
(SC2140)
🤖 Prompt for AI Agents
In examples/speculative_decoding/collect_hidden_states/slurm_dump.sh around line
34, the RUN_DUMPER assignment uses nested double quotes causing the literal
string "\"error\"" to be assigned; change it so the environment variable is
assigned the plain word error (e.g., RUN_DUMPER="export TLLM_LOG_LEVEL=error;"
or RUN_DUMPER='export TLLM_LOG_LEVEL=error;') by removing the inner quotes or
using single outer quotes.
@ChenhanYu, thank you for confirming! The fix will prevent the environment variable from receiving the literal string "error" (with quotes) and ensure it gets the plain value error as intended.
```sh
RUN_DUMPER="export TLLM_LOG_LEVEL="error";
trtllm-llmapi-launch python3 /modelopt/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py \
    --model /model \
    --input-data /input/ \
    --output-dir /output \
    $PARALLEL_ARGS \
"
```
Fix shell quoting syntax error.
The export statement has incorrect nested quotes that will cause a shell syntax error.
As per static analysis
Apply this diff to fix the quoting:
```diff
-RUN_DUMPER="export TLLM_LOG_LEVEL="error";
+RUN_DUMPER="export TLLM_LOG_LEVEL=error;
 trtllm-llmapi-launch python3 /modelopt/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py \
     --model /model \
     --input-data /input/ \
     --output-dir /output \
     $PARALLEL_ARGS \
 "
```

🧰 Tools
🪛 Shellcheck (0.11.0)
[warning] 34-34: Word is of the form "A"B"C" (B indicated). Did you mean "ABC" or "A"B"C"?
(SC2140)
🤖 Prompt for AI Agents
In examples/speculative_decoding/collect_hidden_states/slurm_dump.sh around
lines 34 to 40, the RUN_DUMPER assignment contains incorrectly nested double
quotes causing a shell syntax error; fix by wrapping the whole command in single
quotes (or escape the inner quotes) so the export uses TLLM_LOG_LEVEL="error"
correctly and the rest of the command string remains intact, e.g., assign
RUN_DUMPER='export TLLM_LOG_LEVEL="error"; trtllm-llmapi-launch python3
/modelopt/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py
--model /model --input-data /input/ --output-dir /output $PARALLEL_ARGS' (ensure
no trailing unmatched quotes).
```sh
timeout 235m srun -l \
    --mpi=pmix --overlap \
    --output=%x_%j_$DATETIME.log \
    --container-image ${CONTAINER} \
    --container-mounts ${MOUNTS} \
    bash -c "$RUN_DUMPER"
```
Fix undefined variable in log filename.
Line 44 references $DATETIME which is never defined, causing the log filename to be incomplete or use a literal string.
Define the variable before use or remove it from the log filename:
```diff
+DATETIME=$(date +%Y%m%d_%H%M%S)
+
 timeout 235m srun -l \
     --mpi=pmix --overlap \
     --output=%x_%j_$DATETIME.log \
     --container-image ${CONTAINER} \
     --container-mounts ${MOUNTS} \
     bash -c "$RUN_DUMPER"
```
bash -c "$RUN_DUMPER"Alternatively, if timestamps aren't needed, simplify to:
```diff
 timeout 235m srun -l \
     --mpi=pmix --overlap \
-    --output=%x_%j_$DATETIME.log \
+    --output=%x_%j.log \
     --container-image ${CONTAINER} \
     --container-mounts ${MOUNTS} \
     bash -c "$RUN_DUMPER"
```
In examples/speculative_decoding/collect_hidden_states/slurm_dump.sh around
lines 42 to 47 the --output pattern references $DATETIME which is never defined;
either define DATETIME before the srun call (e.g. set DATETIME to a timestamp
string) or remove $DATETIME from the --output argument so the filename does not
include an undefined variable; update the script to export or assign DATETIME
prior to using it, or simplify the --output value to rely on existing SLURM
placeholders only.
**Codecov Report**

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files:

```
@@           Coverage Diff           @@
##             main     #422   +/-   ##
=======================================
  Coverage   73.36%   73.36%
=======================================
  Files         180      180
  Lines       17919    17919
=======================================
  Hits        13147    13147
  Misses       4772     4772
```

☔ View full report in Codecov by Sentry.
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Actionable comments posted: 2
♻️ Duplicate comments (2)
examples/speculative_decoding/collect_hidden_states/slurm_dump.sh (2)
**34-40: Fix RUN_DUMPER quoting; current form is syntactically broken.**

The nested quotes terminate the string at "error", so the script fails before launch. Build the command with a heredoc to avoid quoting pitfalls.

```diff
-RUN_DUMPER="export TLLM_LOG_LEVEL="error";
-trtllm-llmapi-launch python3 /modelopt/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py \
-    --model /model \
-    --input-data /input/ \
-    --output-dir /output \
-    $PARALLEL_ARGS \
-    "
+RUN_DUMPER=$(cat <<EOF
+export TLLM_LOG_LEVEL=error
+trtllm-llmapi-launch python3 /modelopt/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py \
+    --model /model \
+    --input-data /input/ \
+    --output-dir /output \
+    $PARALLEL_ARGS
+EOF
+)
```
**42-47: Define DATETIME (or remove it) in the srun --output pattern.**

`$DATETIME` is undefined, producing odd filenames.

```diff
+DATETIME=$(date +%Y%m%d_%H%M%S)
 timeout 235m srun -l \
     --mpi=pmix --overlap \
     --output=%x_%j_$DATETIME.log \
     --container-image ${CONTAINER} \
-    --container-mounts ${MOUNTS} \
+    --container-mounts "${MOUNTS}" \
     bash -c "$RUN_DUMPER"
```
🧹 Nitpick comments (2)
examples/speculative_decoding/collect_hidden_states/slurm_dump.sh (2)
**19-23: Validate required paths before running.**

Placeholders/defaults will cause failures or bad mounts. Add guards and existence checks.

```diff
 INPUT_DIR="<Can be directory containing the .jsonl files, or path to single .jsonl file>"
 DUMP_DIR="<Directory for output hidden states>"
 MODELOPT_DIR="<Path to Modelopt repo>"
 TEACHER_MODEL="<Path to teacher model>"
+# Basic validation
+for var in INPUT_DIR DUMP_DIR MODELOPT_DIR TEACHER_MODEL; do
+  val="${!var}"
+  if [[ -z "$val" || "$val" == \<*\> ]]; then
+    echo "ERROR: $var is not set. Please edit slurm_dump.sh."
+    exit 1
+  fi
+done
+
+[[ -e "$INPUT_DIR" ]] || { echo "ERROR: INPUT_DIR not found: $INPUT_DIR"; exit 1; }
+[[ -d "$MODELOPT_DIR" ]] || { echo "ERROR: MODELOPT_DIR not found: $MODELOPT_DIR"; exit 1; }
+[[ -d "$TEACHER_MODEL" ]] || { echo "ERROR: TEACHER_MODEL not found: $TEACHER_MODEL"; exit 1; }
```
**28-28: Quote mounts to avoid word splitting.**

Protect paths with spaces/commas by quoting the value and usage.

```diff
-MOUNTS=$INPUT_DIR:/input,$DUMP_DIR:/output,$MODELOPT_DIR:/modelopt,$TEACHER_MODEL:/model
+MOUNTS="${INPUT_DIR}:/input,${DUMP_DIR}:/output,${MODELOPT_DIR}:/modelopt,${TEACHER_MODEL}:/model"
```

And in srun:

```diff
-    --container-mounts ${MOUNTS} \
+    --container-mounts "${MOUNTS}" \
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- `examples/speculative_decoding/collect_hidden_states/slurm_dump.sh` (1 hunk)
🧰 Additional context used
🪛 Shellcheck (0.11.0)
examples/speculative_decoding/collect_hidden_states/slurm_dump.sh
[warning] 34-34: Word is of the form "A"B"C" (B indicated). Did you mean "ABC" or "A"B"C"?
(SC2140)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: linux
- GitHub Check: build-docs
- GitHub Check: code-quality
```sh
#SBATCH -A coreai_dlalgo_modelopt
#SBATCH --job-name=coreai_dlalgo_modelopt-generate_eagle_hidden_states
#SBATCH --nodes=1 --ntasks-per-node=4 --gpus-per-node=4
```
Align allocation with intended topology (1 proc using 4 GPUs).
You set ntasks-per-node=4 but run a single command that does TP=4. Allocate 1 task and give it 4 GPUs to avoid idle tasks and binding ambiguity.
```diff
-#SBATCH --nodes=1 --ntasks-per-node=4 --gpus-per-node=4
+#SBATCH --nodes=1 --ntasks-per-node=1 --gpus-per-task=4
```

Optionally, also make the step explicit:

```diff
-timeout 235m srun -l \
+timeout 235m srun -l -n 1 \
```
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```sh
# In examples/speculative_decoding/collect_hidden_states/slurm_dump.sh
#SBATCH --nodes=1 --ntasks-per-node=1 --gpus-per-task=4

timeout 235m srun -l -n 1 \
    --mpi=pmix --overlap \
    --output=%x_%j_$DATETIME.log \
    --container-image ${CONTAINER} \
    --container-mounts ${MOUNTS} \
    bash -c "$RUN_DUMPER"
```
🤖 Prompt for AI Agents
In examples/speculative_decoding/collect_hidden_states/slurm_dump.sh around line
10, the SBATCH line currently specifies --ntasks-per-node=4 which creates four
tasks but you run a single process that needs 4 GPUs; change it to
--ntasks-per-node=1 --gpus-per-node=4 (and optionally add --cpus-per-task=<num>
if you need CPU binding) so one task owns all 4 GPUs, and if desired make the
launch step explicit by using srun --ntasks=1 (or sbatch step) to run the one
process.
```sh
#EP optionally available by setting --moe-ep-size and --moe-tp-size. See compute_hidden_states_trtllm.py.
PARALLEL_ARGS="--tp 4 --dp-rank $SLURM_ARRAY_TASK_ID --dp-world-size $SLURM_ARRAY_TASK_COUNT"
```
Provide safe defaults when not using an array job.
If SLURM_ARRAY_TASK_ID/COUNT are unset, dp args become empty and the Python launcher may fail. Default to a single‑rank DP.
```diff
-#By default: TP inside node, and DP across slurm array
-#EP optionally available by setting --moe-ep-size and --moe-tp-size. See compute_hidden_states_trtllm.py.
-PARALLEL_ARGS="--tp 4 --dp-rank $SLURM_ARRAY_TASK_ID --dp-world-size $SLURM_ARRAY_TASK_COUNT"
+# By default: TP inside node, and DP across slurm array
+# EP optionally available by setting --moe-ep-size and --moe-tp-size. See compute_hidden_states_trtllm.py.
+DP_RANK="${SLURM_ARRAY_TASK_ID:-0}"
+DP_WORLD="${SLURM_ARRAY_TASK_COUNT:-1}"
+PARALLEL_ARGS="--tp 4 --dp-rank ${DP_RANK} --dp-world-size ${DP_WORLD}"
```
+PARALLEL_ARGS="--tp 4 --dp-rank ${DP_RANK} --dp-world-size ${DP_WORLD}"🤖 Prompt for AI Agents
In examples/speculative_decoding/collect_hidden_states/slurm_dump.sh around
lines 31 to 33, SLURM_ARRAY_TASK_ID and SLURM_ARRAY_TASK_COUNT may be unset
causing PARALLEL_ARGS to be empty and the Python launcher to fail; add safe
defaults (e.g., SLURM_ARRAY_TASK_ID=${SLURM_ARRAY_TASK_ID:-0} and
SLURM_ARRAY_TASK_COUNT=${SLURM_ARRAY_TASK_COUNT:-1}) and construct PARALLEL_ARGS
to include "--tp 4 --dp-rank $SLURM_ARRAY_TASK_ID --dp-world-size
$SLURM_ARRAY_TASK_COUNT" only after applying those defaults so a non-array run
defaults to single-rank DP.
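The fallback behaviour is easy to check outside SLURM (the values are illustrative): `${var:-default}` substitutes the default when the variable is unset or empty, so a plain non-array run degrades to a single-rank DP configuration instead of producing empty arguments.

```shell
# Simulate a non-array run: the SLURM array variables are absent.
unset SLURM_ARRAY_TASK_ID SLURM_ARRAY_TASK_COUNT

DP_RANK="${SLURM_ARRAY_TASK_ID:-0}"        # falls back to rank 0
DP_WORLD="${SLURM_ARRAY_TASK_COUNT:-1}"    # falls back to a world size of 1
PARALLEL_ARGS="--tp 4 --dp-rank ${DP_RANK} --dp-world-size ${DP_WORLD}"
echo "$PARALLEL_ARGS"
```

Inside an `sbatch --array=...` job the SLURM variables are set, so the same expansions pass the real rank and count through unchanged.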
What does this PR do?
Type of change: New feature
Overview:
Added a slurm multi-node script for TRTLLM eagle hidden states dumper.
Usage
```sh
sbatch --array=0-n slurm_dumper.sh
```

Testing

Tested on HSG cluster.
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit
New Features
Refactor
Chores