feat(pt): add hook to last fitting layer output #4789
base: devel
📝 Walkthrough
This update adds support for evaluating and retrieving the output of the last hidden layer (before the final layer) of the fitting network in deep potential models. New methods and hooks are introduced across the neural network, inference, and model classes to enable, cache, and access these intermediate outputs, with API extensions for both standard and PyTorch-based implementations.
Sequence Diagram(s)
sequenceDiagram
participant User
participant DeepEval
participant DeepEvalBackend
participant DPModelCommon
participant DPAtomicModel
participant GeneralFitting
User->>DeepEval: eval_fitting_last_layer(...)
DeepEval->>DeepEvalBackend: eval_fitting_last_layer(...)
DeepEvalBackend->>DPModelCommon: set_eval_fitting_last_layer_hook(True)
DeepEvalBackend->>DPModelCommon: eval(...)
DPModelCommon->>DPAtomicModel: set_eval_fitting_last_layer_hook(True)
DPAtomicModel->>GeneralFitting: set_return_middle_output(True)
DPModelCommon->>DPAtomicModel: forward_atomic(...)
DPAtomicModel->>GeneralFitting: _forward_common(...)
GeneralFitting->>GeneralFitting: call_until_last(...)
GeneralFitting-->>DPAtomicModel: return {"middle_output": ...}
DPAtomicModel->>DPAtomicModel: Cache middle_output
DPAtomicModel->>DPAtomicModel: set_eval_fitting_last_layer_hook(False)
DPModelCommon->>DeepEvalBackend: return eval_fitting_last_layer()
DeepEvalBackend->>DeepEval: return result
DeepEval->>User: return result
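As a rough illustration of the flow in the diagram above, the following pure-Python sketch walks through enable, evaluate, retrieve, and disable. The classes FittingSketch and AtomicModelSketch and the free function eval_fitting_last_layer are hypothetical stand-ins written for this example; they borrow method names from the PR but are not the deepmd implementation, and plain lists take the place of tensors.

```python
class FittingSketch:
    """Hypothetical stand-in for a fitting network (not the deepmd class)."""

    def __init__(self):
        self.return_middle_output = False

    def set_return_middle_output(self, flag):
        self.return_middle_output = flag

    def forward(self, x):
        hidden = [2.0 * v for v in x]    # pretend: every layer except the last
        out = {"energy": sum(hidden)}    # pretend: the final layer
        if self.return_middle_output:
            out["middle_output"] = hidden
        return out


class AtomicModelSketch:
    """Hypothetical stand-in for an atomic model that owns the hook state."""

    def __init__(self):
        self.fitting = FittingSketch()
        self.hook_enabled = False
        self.cache = None

    def set_eval_fitting_last_layer_hook(self, enable):
        self.hook_enabled = enable
        self.fitting.set_return_middle_output(enable)
        self.cache = None                # drop any stale cached output

    def forward_atomic(self, x):
        ret = self.fitting.forward(x)
        if self.hook_enabled:
            # Remove the intermediate output from the result and cache it.
            self.cache = ret.pop("middle_output")
        return ret

    def eval_fitting_last_layer(self):
        return self.cache


def eval_fitting_last_layer(model, x):
    """Enable the hook, run the model, read the cache, then disable the hook."""
    model.set_eval_fitting_last_layer_hook(True)
    model.forward_atomic(x)
    result = model.eval_fitting_last_layer()   # read before the cache is cleared
    model.set_eval_fitting_last_layer_hook(False)
    return result
```

In this sketch, disabling the hook also clears the cache, so the result has to be read before the hook is switched off.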
Actionable comments posted: 0
🧹 Nitpick comments (3)
deepmd/pt/model/atomic_model/dp_atomic_model.py (2)
82-92: Well-designed hook management methods.
The implementation correctly:
- Manages the hook enable/disable state
- Integrates with the fitting network's set_return_middle_output method
- Clears the cache to prevent stale data
Consider potential thread safety issues if multiple threads access these methods concurrently.
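One possible way to address the thread-safety concern is to serialize hook usage behind a lock and always reset the hook in a finally block. The sketch below is an assumption-laden illustration, not code from the PR: HookedModelSketch, set_hook, and last_layer_hook are hypothetical names.

```python
import threading
from contextlib import contextmanager


class HookedModelSketch:
    """Hypothetical model wrapper; not part of deepmd."""

    def __init__(self):
        self._lock = threading.Lock()
        self.hook_enabled = False
        self.cache = None

    def set_hook(self, enable):
        self.hook_enabled = enable
        self.cache = None  # clear stale data on every state change

    @contextmanager
    def last_layer_hook(self):
        # Only one thread at a time may hold the hook; the hook is
        # switched off again even if the body raises.
        with self._lock:
            self.set_hook(True)
            try:
                yield self
            finally:
                self.set_hook(False)
```

A caller would wrap the evaluation in `with model.last_layer_hook(): ...` and copy the cached value out before leaving the block. The lock used here is not reentrant, so the context manager must not be nested within the same thread.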
272-278: Correct implementation of middle output caching.
The logic properly checks for the presence of middle_output, removes it from the result dictionary, detaches it from the computation graph, and caches it. The assertion ensures the feature is only used with compatible fitting networks.
Consider making the error message more descriptive to help users understand which fitting network types support this feature.
-    assert "middle_output" in fit_ret, (
-        f"eval_fitting_last_layer not supported for fitting net {type(self.fitting_net.__class__)}!"
-    )
+    assert "middle_output" in fit_ret, (
+        f"eval_fitting_last_layer not supported for fitting net {type(self.fitting_net)}! "
+        f"Only mixed_types fitting networks support this feature."
+    )
deepmd/infer/deep_eval.py (1)
504-569: Well-implemented high-level interface method.
The implementation correctly follows the established pattern of input standardization and delegation to the backend. The parameter handling is consistent with other evaluation methods.
Minor documentation inconsistency: The docstring mentions an efield parameter that's not in the method signature.
-    efield
-        The external field on atoms.
-        The array should be of size nframes x natoms x 3
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (7)
- deepmd/dpmodel/utils/network.py (1 hunks)
- deepmd/infer/deep_eval.py (2 hunks)
- deepmd/pt/infer/deep_eval.py (2 hunks)
- deepmd/pt/model/atomic_model/dp_atomic_model.py (3 hunks)
- deepmd/pt/model/model/dp_model.py (1 hunks)
- deepmd/pt/model/task/fitting.py (4 hunks)
- deepmd/pt/model/task/invar_fitting.py (1 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (4)
deepmd/pt/model/model/dp_model.py (2)
- deepmd/pt/model/atomic_model/dp_atomic_model.py (2)
  - set_eval_fitting_last_layer_hook (82-87)
  - eval_fitting_last_layer (89-91)
- deepmd/pt/infer/deep_eval.py (1)
  - eval_fitting_last_layer (683-736)
deepmd/pt/model/task/invar_fitting.py (1)
- deepmd/pt/model/task/fitting.py (1)
  - _forward_common (505-645)
deepmd/pt/model/task/fitting.py (4)
- deepmd/pt/model/atomic_model/dp_atomic_model.py (1)
  - mixed_types (118-128)
- deepmd/pt/model/descriptor/se_a.py (2)
  - mixed_types (171-175)
  - mixed_types (587-597)
- deepmd/pt/model/descriptor/hybrid.py (1)
  - mixed_types (143-147)
- deepmd/dpmodel/utils/network.py (1)
  - call_until_last (636-651)
deepmd/infer/deep_eval.py (1)
- deepmd/pt/infer/deep_eval.py (1)
  - eval_fitting_last_layer (683-736)
⏰ Context from checks skipped due to timeout of 90000ms (29)
- GitHub Check: Test Python (6, 3.12)
- GitHub Check: Test Python (6, 3.9)
- GitHub Check: Test Python (5, 3.12)
- GitHub Check: Test Python (5, 3.9)
- GitHub Check: Test Python (4, 3.12)
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Test Python (4, 3.9)
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Build C library (2.14, >=2.5.0rc0,<2.15, libdeepmd_c_cu11.tar.gz)
- GitHub Check: Test Python (3, 3.12)
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
- GitHub Check: Test Python (3, 3.9)
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Test Python (2, 3.12)
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Analyze (python)
- GitHub Check: Test Python (2, 3.9)
- GitHub Check: Build C++ (cuda, cuda)
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Test C++ (false)
- GitHub Check: Test Python (1, 3.12)
- GitHub Check: Build C++ (cpu, cpu)
- GitHub Check: Test C++ (true)
- GitHub Check: Analyze (c-cpp)
- GitHub Check: Test Python (1, 3.9)
🔇 Additional comments (10)
deepmd/dpmodel/utils/network.py (1)
636-651: LGTM! Well-implemented method for intermediate output extraction.
The call_until_last method correctly implements forward pass through all layers except the last one. The implementation properly handles edge cases (empty layers or single layer) and follows the existing code patterns with clear documentation.
deepmd/pt/model/model/dp_model.py (1)
68-76: LGTM! New methods follow established patterns correctly.
The new fitting last layer hook methods are well-implemented:
- Consistent naming and documentation with existing descriptor methods
- Proper delegation to atomic_model maintains the architecture
- @torch.jit.export decorators ensure TorchScript compatibility
- Clear documentation following existing patterns
deepmd/pt/model/task/invar_fitting.py (1)
184-194: LGTM! Safe and backward-compatible implementation.
The modified forward method correctly handles the conditional inclusion of "middle_output":
- Properly captures output from _forward_common
- Safely checks for "middle_output" existence before adding to result
- Correctly converts both main output and middle output to global precision
- Maintains backward compatibility when middle output is not available
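The conditional inclusion described in the bullets above can be sketched in a few lines. This is a minimal illustration under stated assumptions: to_global_precision is a hypothetical stand-in for the real dtype cast, forward is a free function rather than a method, and plain lists replace tensors.

```python
def to_global_precision(values):
    # Hypothetical stand-in for a cast to the global floating-point precision.
    return [float(v) for v in values]


def forward(common_out, var_name="energy"):
    """Build the result dict from the dict returned by a _forward_common-like call."""
    result = {var_name: to_global_precision(common_out[var_name])}
    # Only forward the intermediate output when the fitting net produced it,
    # so callers that never enable the hook see the same keys as before.
    if "middle_output" in common_out:
        result["middle_output"] = to_global_precision(common_out["middle_output"])
    return result
```

Because the extra key appears only when it is present in the input, existing callers that index the result by var_name are unaffected.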
deepmd/pt/infer/deep_eval.py (2)
133-135: Verify the impact of disabling JIT compilation.
The JIT compilation is commented out with a TODO comment. This might impact model performance during inference. Ensure this is a temporary workaround and track the issue for resolution.
683-736: LGTM! Well-implemented evaluation method following established patterns.
The eval_fitting_last_layer method is correctly implemented:
- Follows the same pattern as eval_descriptor for consistency
- Comprehensive parameter documentation matching other evaluation methods
- Proper hook management (enable -> eval -> retrieve -> disable)
- Correct type conversion to NumPy array for the return value
- Method signature consistent with other evaluation APIs
deepmd/pt/model/task/fitting.py (3)
331-331: LGTM! Clean initialization of the new attribute.
The eval_return_middle_output attribute is properly initialized to False, ensuring backward compatibility.
454-456: LGTM! Well-designed setter method.
The method provides a clean interface to toggle the middle output evaluation behavior.
605-646: ✅ Verification successful
The return type change from tensor to dictionary is a significant modification.
The logic correctly handles the middle output extraction for mixed types and appropriately raises NotImplementedError for non-mixed types. The implementation properly integrates with the call_until_last method from the network utility layer.
Consider the impact of this breaking change on existing code that expects a single tensor return value.
Verify that all callers of _forward_common have been updated to handle the dictionary return type:
🏁 Script executed:
#!/bin/bash
# Description: Find all callers of _forward_common to ensure they handle the new dictionary return type
# Expected: All callers should handle dictionary return or be updated accordingly
rg -A 10 "_forward_common" --type py
Length of output: 4275
🏁 Script executed:
#!/bin/bash
# Inspect how the Paddle invar_fitting caller handles the new dict return
rg -C5 "self._forward_common" deepmd/pd/model/task/invar_fitting.py
Length of output: 298
Dict return correctly handled in all PyTorch callers
All PyTorch-based tasks now unpack the dictionary returned by _forward_common:
- deepmd/pt/model/task/dipole.py indexes out[self.var_name]
- deepmd/pt/model/task/polarizability.py indexes out[self.var_name]
- deepmd/pt/model/task/invar_fitting.py builds its result from out[self.var_name] (and middle_output)
The Paddle implementation still returns a tensor and its callers remain unaffected. No further updates are required.
deepmd/pt/model/atomic_model/dp_atomic_model.py (1)
65-67: LGTM! Consistent attribute additions following the existing pattern.
The new attributes for fitting last layer hook follow the same design pattern as the existing descriptor hook, maintaining consistency.
Also applies to: 70-70
deepmd/infer/deep_eval.py (1)
218-258: Excellent addition of abstract interface method.
The method signature and documentation are comprehensive and consistent with existing evaluation methods. The abstract nature ensures all backends implement this functionality.
Summary by CodeRabbit