-
Notifications
You must be signed in to change notification settings - Fork 568
pd: support different label_dict in CINN #4795
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR enhances the CINN jit.to_static
wrapper by dynamically constructing the input_spec
based on the keys present in label_dict
, thereby removing unnecessary specs when certain labels aren’t used.
- Fetches a sample
label_dict
viaself.get_data
to drive spec construction - Defines a full
label_dict_spec
and pops entries not present inlabel_dict
- Applies the filtered
label_dict_spec
tojit.to_static
📝 WalkthroughWalkthroughThe update modifies the way the input specification for model compilation is constructed in the training module. Instead of using a hardcoded label dictionary, the code now dynamically builds the input specification based on the actual keys present in sampled training data label dictionaries, ensuring alignment with runtime label structures. Changes
Sequence Diagram(s)sequenceDiagram
participant Trainer
participant DataSource
participant ModelCompiler
Trainer->>DataSource: get_data(is_train=True, task_key="Default" or first model key)
DataSource-->>Trainer: sample label_dict
Trainer->>Trainer: build label_dict_spec based on sample keys
Trainer->>ModelCompiler: compile model with input_spec (using label_dict_spec)
Possibly related PRs
Warning There were issues while running some tools. Please review the errors and either fix the tool's configuration or disable the tool if it's a critical failure. 🔧 Pylint (3.3.7)deepmd/pd/train/training.pyNo files to lint: exiting. ✨ Finishing Touches
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. 🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
SupportNeed help? Create a ticket on our support page for assistance with any issues or questions. Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments. CodeRabbit Commands (Invoked using PR comments)
Other keywords and placeholders
CodeRabbit Configuration File (
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
deepmd/pd/train/training.py
(2 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (28)
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
- GitHub Check: Build C library (2.14, >=2.5.0rc0,<2.15, libdeepmd_c_cu11.tar.gz)
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Test Python (3, 3.12)
- GitHub Check: Test Python (4, 3.12)
- GitHub Check: Analyze (c-cpp)
- GitHub Check: Test Python (5, 3.9)
- GitHub Check: Test Python (3, 3.9)
- GitHub Check: Test Python (1, 3.9)
- GitHub Check: Test Python (5, 3.12)
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Test Python (1, 3.12)
- GitHub Check: Test Python (2, 3.12)
- GitHub Check: Test Python (2, 3.9)
- GitHub Check: Test Python (6, 3.12)
- GitHub Check: Test Python (6, 3.9)
- GitHub Check: Test Python (4, 3.9)
- GitHub Check: Analyze (python)
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Build C++ (cpu, cpu)
- GitHub Check: Build C++ (cuda, cuda)
- GitHub Check: Test C++ (false)
- GitHub Check: Test C++ (true)
🔇 Additional comments (1)
deepmd/pd/train/training.py (1)
647-647
: LGTM! Correct usage of the dynamically constructed specification.The usage of
label_dict_spec
in theinput_spec
list correctly replaces the previously hardcoded label dictionary specification, ensuring alignment between static compilation and runtime data structures.
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## devel #4795 +/- ##
==========================================
- Coverage 84.79% 84.78% -0.01%
==========================================
Files 698 698
Lines 67816 67820 +4
Branches 3540 3540
==========================================
- Hits 57505 57504 -1
- Misses 9177 9181 +4
- Partials 1134 1135 +1 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. @HydrogenSulfate The suggested changes by AI might help.
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Signed-off-by: HydrogenSulfate <490868991@qq.com>
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
🧹 Nitpick comments (1)
deepmd/pd/train/training.py (1)
630-632
: Use membership on the dict directly & drop the redundant.keys()
dict.keys()
materialises a view; membership works directly on the dict and is clearer:-label_dict_spec = { - k: spec_templates[k] for k in label_dict.keys() if k in spec_templates -} +label_dict_spec = {k: spec_templates[k] for k in label_dict if k in spec_templates}Minor, but it avoids the SIM118 lint warning and trims a few cycles.
🧰 Tools
🪛 Ruff (0.11.9)
631-631: Use
key in dict
instead ofkey in dict.keys()
Remove
.keys()
(SIM118)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
deepmd/pd/train/training.py
(2 hunks)
🧰 Additional context used
🪛 Ruff (0.11.9)
deepmd/pd/train/training.py
631-631: Use key in dict
instead of key in dict.keys()
Remove .keys()
(SIM118)
⏰ Context from checks skipped due to timeout of 90000ms (23)
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Build C++ (cuda, cuda)
- GitHub Check: Test Python (2, 3.9)
- GitHub Check: Test Python (4, 3.9)
- GitHub Check: Test Python (6, 3.9)
- GitHub Check: Test Python (6, 3.12)
- GitHub Check: Test Python (5, 3.9)
- GitHub Check: Test Python (5, 3.12)
- GitHub Check: Test Python (4, 3.12)
- GitHub Check: Test Python (1, 3.9)
- GitHub Check: Test Python (2, 3.12)
- GitHub Check: Test Python (1, 3.12)
- GitHub Check: Test Python (3, 3.9)
- GitHub Check: Test Python (3, 3.12)
- GitHub Check: Test C++ (true)
- GitHub Check: Test C++ (false)
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
- GitHub Check: Build C library (2.14, >=2.5.0rc0,<2.15, libdeepmd_c_cu11.tar.gz)
- GitHub Check: Analyze (c-cpp)
11c11fb
to
3a6438e
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
♻️ Duplicate comments (1)
deepmd/pd/train/training.py (1)
610-613
:⚠️ Potential issueAvoid consuming a real training batch when inferring
input_spec
Calling
self.get_data(is_train=True)
here advances the primary training iterator and permanently drops one batch before training even starts.
This silently skews epoch statistics and is especially painful when the dataset is small.A side-effect-free “peek” is safer:
- # NOTE: This is a trick to decide the right input_spec for wrapper.forward - _, label_dict, _ = self.get_data(is_train=True) + # NOTE: do **not** advance the real iterator; just peek one sample + loader = self.training_dataloader + if self.multi_task: + # fall back to the first branch – adjust if you need something smarter + loader = self.training_dataloader[self.model_keys[0]] + raw_sample = next(iter(loader)) + _, label_dict, _ = self._split_inputs_labels(raw_sample) # reuse existing helper(The helper
_split_inputs_labels
/similar already exists further down in the file; reuse it to keep the parsing logic in one place.)
🧹 Nitpick comments (1)
deepmd/pd/train/training.py (1)
628-630
: Drop the redundant.keys()
call
label_dict
is already iterable; using.keys()
is unnecessary and flagged by Ruff (SIM118).- label_dict_spec = { - k: spec_templates[k] for k in label_dict.keys() if k in spec_templates - } + label_dict_spec = { + k: spec_templates[k] for k in label_dict if k in spec_templates + }🧰 Tools
🪛 Ruff (0.11.9)
629-629: Use
key in dict
instead ofkey in dict.keys()
Remove
.keys()
(SIM118)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
deepmd/pd/train/training.py
(2 hunks)
🧰 Additional context used
🪛 Ruff (0.11.9)
deepmd/pd/train/training.py
629-629: Use key in dict
instead of key in dict.keys()
Remove .keys()
(SIM118)
⏰ Context from checks skipped due to timeout of 90000ms (25)
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build C++ (cuda, cuda)
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Build C library (2.14, >=2.5.0rc0,<2.15, libdeepmd_c_cu11.tar.gz)
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
- GitHub Check: Analyze (c-cpp)
- GitHub Check: Test Python (6, 3.9)
- GitHub Check: Test Python (4, 3.9)
- GitHub Check: Test Python (6, 3.12)
- GitHub Check: Test Python (3, 3.12)
- GitHub Check: Test Python (5, 3.12)
- GitHub Check: Test Python (5, 3.9)
- GitHub Check: Test Python (2, 3.9)
- GitHub Check: Test Python (4, 3.12)
- GitHub Check: Test Python (3, 3.9)
- GitHub Check: Test Python (2, 3.12)
- GitHub Check: Test Python (1, 3.12)
- GitHub Check: Test Python (1, 3.9)
- GitHub Check: Test C++ (true)
- GitHub Check: Test C++ (false)
@coderabbitai resolve |
✅ Actions performedComments resolved. |
pop unnecessary item when wrapping model with
jit.to_static
, so we can support se_e2_a/dpa2/dpa3 without extra modification.@njzjz can you give some suggestions for better code improvements? The current approach of fetching data via
self.get_data
isn't very concise.Summary by CodeRabbit