feat: JAX training #4782
base: devel
Conversation
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
```python
valid_data = None

# get training info
stop_batch = jdata["training"]["numb_steps"]
```

Check notice — Code scanning / CodeQL: Unused local variable (Note)

```python
if (
    origin_type_map is not None and not origin_type_map
):  # get the type_map from data if not provided
    origin_type_map = get_data(
```

Check notice — Code scanning / CodeQL: Unused local variable (Note)

```python
    )
jdata_cpy = jdata.copy()
type_map = jdata["model"].get("type_map")
train_data = get_data(
```

Check notice — Code scanning / CodeQL: Unused local variable (Note)
📝 Walkthrough

This update introduces a JAX-based training and entrypoint framework for DeePMD-Kit, including new modules for command-line handling, model freezing, and training orchestration. It adds or modifies several utilities and loss functions, enhances serialization with step tracking, and generalizes learning rate scheduling. Minor internal changes and new attributes are also introduced in descriptors and statistics utilities.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant CLI_User
    participant MainEntrypoint
    participant TrainEntrypoint
    participant FreezeEntrypoint
    participant DPTrainer
    participant SerializationUtils
    CLI_User->>MainEntrypoint: main(args)
    MainEntrypoint->>MainEntrypoint: parse_args(args)
    alt command == "train"
        MainEntrypoint->>TrainEntrypoint: train(**args)
        TrainEntrypoint->>DPTrainer: DPTrainer(jdata, ...)
        DPTrainer->>DPTrainer: train(train_data, valid_data)
        DPTrainer->>SerializationUtils: save_checkpoint(...)
    else command == "freeze"
        MainEntrypoint->>FreezeEntrypoint: freeze(checkpoint_folder, output)
        FreezeEntrypoint->>SerializationUtils: serialize_from_file(folder)
        FreezeEntrypoint->>SerializationUtils: deserialize_to_file(output)
    end
```
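For orientation, here is a minimal sketch of driving that dispatch from Python. The `main` entrypoint and the train/freeze commands come from this PR; the flag spellings below are assumptions for illustration, not the confirmed CLI surface.

```python
# Hypothetical driver for the flow in the diagram above. main() and the
# train/freeze commands are part of this PR; the flag names below are
# assumptions, not the confirmed CLI.
from deepmd.jax.entrypoints.main import main

main(["train", "input.json"])  # parse_args -> train(**args) -> DPTrainer
main(["freeze", "--checkpoint-folder", "model_ckpt", "--output", "frozen.hlo"])
```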
Actionable comments posted: 6
🧹 Nitpick comments (10)
deepmd/dpmodel/utils/learning_rate.py (1)
`48-52`: **Update return type annotation for backend generalization.**

The addition of the `xp` parameter to support different array backends (like JAX) is excellent for the framework's extensibility. However, the return type annotation should be updated to reflect the actual return type.

```diff
- def value(self, step, xp=np) -> np.float64:
+ def value(self, step, xp=np):
```

Or if you want to be more specific:

```diff
- def value(self, step, xp=np) -> np.float64:
+ def value(self, step, xp=np) -> Union[np.float64, Any]:
```
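For context on why the `xp` parameter matters, a minimal backend-generic sketch follows; the exponential-decay formula is an assumption for illustration and may differ from the PR's actual schedule.

```python
import numpy as np
import jax.numpy as jnp

def lr_value(start_lr, decay_rate, decay_steps, step, xp=np):
    # xp is the array namespace: the default np preserves the old NumPy
    # behavior, while passing jnp yields a JAX array that can flow through
    # jit-compiled training steps.
    return start_lr * xp.power(decay_rate, step // decay_steps)

lr_np = lr_value(1e-3, 0.95, 5000, 10000)           # NumPy scalar
lr_jax = lr_value(1e-3, 0.95, 5000, 10000, xp=jnp)  # JAX array
```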
deepmd/jax/entrypoints/freeze.py (1)

`12-36`: **LGTM! Consider adding output validation and improved error handling.**

The freeze function implementation is well-structured with proper checkpoint handling logic. The keyword-only argument design and integration with serialization utilities are excellent.

Consider these minor enhancements:

```diff
 def freeze(
     *,
     checkpoint_folder: str,
     output: str,
     **kwargs,
 ) -> None:
     """Freeze the graph in supplied folder.

     Parameters
     ----------
     checkpoint_folder : str
         location of either the folder with checkpoint or the checkpoint prefix
     output : str
-        output file name
+        output file name (supported formats: .jax, .hlo, .savedmodel)
     **kwargs
         other arguments
     """
+    # Validate output format
+    supported_formats = ['.jax', '.hlo', '.savedmodel']
+    if not any(output.endswith(fmt) for fmt in supported_formats):
+        raise ValueError(f"Unsupported output format. Supported: {supported_formats}")
+
     if (Path(checkpoint_folder) / "checkpoint").is_file():
         checkpoint_meta = Path(checkpoint_folder) / "checkpoint"
         checkpoint_folder = checkpoint_meta.read_text().strip()
     if Path(checkpoint_folder).is_dir():
-        data = serialize_from_file(checkpoint_folder)
-        deserialize_to_file(output, data)
+        try:
+            data = serialize_from_file(checkpoint_folder)
+            deserialize_to_file(output, data)
+        except Exception as e:
+            raise RuntimeError(f"Failed to freeze checkpoint: {e}") from e
     else:
         raise FileNotFoundError(f"Checkpoint {checkpoint_folder} does not exist.")
```
deepmd/dpmodel/fitting/ener_fitting.py (1)

`118-124`: **Simplify nested loops for better readability.**

The triple-nested loops can be simplified using list comprehensions or numpy operations.

Consider refactoring the nested loops:

```diff
- sys_ener = []
- for ss in range(len(data)):
-     sys_data = []
-     for ii in range(len(data[ss])):
-         for jj in range(len(data[ss][ii])):
-             sys_data.append(data[ss][ii][jj])
-     sys_data = np.concatenate(sys_data)
-     sys_ener.append(np.average(sys_data))
+ sys_ener = []
+ for system_data in data:
+     # Flatten all batches and frames for this system
+     sys_data = np.concatenate([frame for batch in system_data for frame in batch])
+     sys_ener.append(np.average(sys_data))
```

Similarly for the mixed_type branch:

```diff
- tmp_tynatom = []
- for ii in range(len(data[ss])):
-     for jj in range(len(data[ss][ii])):
-         tmp_tynatom.append(data[ss][ii][jj].astype(np.float64))
- tmp_tynatom = np.average(np.array(tmp_tynatom), axis=0)
+ # Flatten all batches and frames, then compute average
+ tmp_tynatom = np.average(
+     np.array([frame.astype(np.float64)
+               for batch in data[ss]
+               for frame in batch]),
+     axis=0,
+ )
```

Also applies to: 130-136
deepmd/dpmodel/loss/ener.py (1)
`457-457`: **Clarify the ndof comment for Hessian data.**

The comment `# 9=3*3 --> 3N*3N=ndof*natoms*natoms` is confusing since ndof is set to 1, not 9.

Update the comment to be clearer:

```diff
- ndof=1, # 9=3*3 --> 3N*3N=ndof*natoms*natoms
+ ndof=1, # Hessian has shape (natoms, 3, natoms, 3), flattened per atom
```
134-134
: Remove unused instance variables.The following instance variables are assigned but never used in the class:
self.numb_fparam
(line 134)self.frz_model
(line 142)self.ckpt_meta
(line 143)self.model_type
(line 144)Consider removing these unused variables to improve code clarity.
Also applies to: 142-145
283-284
: Simplify .get() calls by removing redundant None default.Apply this diff:
- fparam=jax_data.get("fparam", None), - aparam=jax_data.get("aparam", None), + fparam=jax_data.get("fparam"), + aparam=jax_data.get("aparam"),And similarly for lines 333-334:
- fparam=jax_valid_data.get("fparam", None), - aparam=jax_valid_data.get("aparam", None), + fparam=jax_valid_data.get("fparam"), + aparam=jax_valid_data.get("aparam"),Also applies to: 333-334
🧰 Tools
🪛 Ruff (0.11.9)
283-283: Use
jax_data.get("fparam")
instead ofjax_data.get("fparam", None)
Replace
jax_data.get("fparam", None)
withjax_data.get("fparam")
(SIM910)
284-284: Use
jax_data.get("aparam")
instead ofjax_data.get("aparam", None)
Replace
jax_data.get("aparam", None)
withjax_data.get("aparam")
(SIM910)
396-396
: Simplify dictionary key iteration.Apply this diff:
- for k in valid_results.keys(): + for k in valid_results:And for line 401:
- for k in train_results.keys(): + for k in train_results:Also applies to: 401-401
🧰 Tools
🪛 Ruff (0.11.9)
396-396: Use
key in dict
instead ofkey in dict.keys()
Remove
.keys()
(SIM118)
deepmd/jax/entrypoints/train.py (3)
`144-147`: **Simplify type map assignment with a ternary operator.**

Apply this diff:

```diff
- if len(type_map) == 0:
-     ipt_type_map = None
- else:
-     ipt_type_map = type_map
+ ipt_type_map = None if len(type_map) == 0 else type_map
```

`172-172`: **Remove unused variable `stop_batch`.**

The variable `stop_batch` is assigned but never used.

Apply this diff:

```diff
- stop_batch = jdata["training"]["numb_steps"]
```

`201-204`: **Address the OOM issue in neighbor statistics calculation.**

The commented code indicates an out-of-memory issue that needs to be resolved. This functionality appears to be important for updating the model's selection parameters based on neighbor statistics.

Would you like me to help investigate the OOM issue and propose a memory-efficient solution for computing neighbor statistics? I could open a new issue to track this task.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (13)
- deepmd/backend/jax.py (1 hunks)
- deepmd/dpmodel/descriptor/dpa1.py (1 hunks)
- deepmd/dpmodel/fitting/ener_fitting.py (3 hunks)
- deepmd/dpmodel/loss/ener.py (3 hunks)
- deepmd/dpmodel/utils/env_mat_stat.py (1 hunks)
- deepmd/dpmodel/utils/learning_rate.py (1 hunks)
- deepmd/jax/entrypoints/__init__.py (1 hunks)
- deepmd/jax/entrypoints/freeze.py (1 hunks)
- deepmd/jax/entrypoints/main.py (1 hunks)
- deepmd/jax/entrypoints/train.py (1 hunks)
- deepmd/jax/train/__init__.py (1 hunks)
- deepmd/jax/train/trainer.py (1 hunks)
- deepmd/jax/utils/serialization.py (2 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (3)
deepmd/backend/jax.py (1)
- deepmd/jax/entrypoints/main.py (1): `main` (32-67)

deepmd/dpmodel/fitting/ener_fitting.py (2)
- deepmd/utils/out_stat.py (1): `compute_stats_from_redu` (15-86)
- deepmd/tf/fit/ener.py (2): `compute_output_stats` (257-273), `_compute_output_stats` (275-321)

deepmd/jax/entrypoints/main.py (3)
- deepmd/backend/suffix.py (1): `format_model_suffix` (17-75)
- deepmd/jax/entrypoints/freeze.py (1): `freeze` (12-36)
- deepmd/loggers/loggers.py (1): `set_log_handles` (146-278)
🪛 Ruff (0.11.9)

deepmd/jax/entrypoints/train.py

- 144-147: Use ternary operator `ipt_type_map = None if len(type_map) == 0 else type_map` instead of `if`-`else`-block (SIM108)
- 172-172: Local variable `stop_batch` is assigned to but never used (F841)
- 195-195: Local variable `train_data` is assigned to but never used (F841)

deepmd/jax/train/trainer.py

- 269-269: Use a context manager for opening files (SIM115)
- 283-283: Use `jax_data.get("fparam")` instead of `jax_data.get("fparam", None)` (SIM910)
- 284-284: Use `jax_data.get("aparam")` instead of `jax_data.get("aparam", None)` (SIM910)
- 333-333: Use `jax_valid_data.get("fparam")` instead of `jax_valid_data.get("fparam", None)` (SIM910)
- 334-334: Use `jax_valid_data.get("aparam")` instead of `jax_valid_data.get("aparam", None)` (SIM910)
- 396-396: Use `key in dict` instead of `key in dict.keys()` (SIM118)
- 401-401: Use `key in dict` instead of `key in dict.keys()` (SIM118)
⏰ Context from checks skipped due to timeout of 90000ms (29)
- GitHub Check: Build C++ (cpu, cpu)
- GitHub Check: Build C++ (cuda, cuda)
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Test Python (6, 3.12)
- GitHub Check: Test Python (6, 3.9)
- GitHub Check: Test Python (5, 3.9)
- GitHub Check: Test Python (5, 3.12)
- GitHub Check: Test Python (4, 3.12)
- GitHub Check: Test Python (1, 3.9)
- GitHub Check: Test Python (3, 3.12)
- GitHub Check: Test Python (2, 3.12)
- GitHub Check: Test Python (2, 3.9)
- GitHub Check: Test Python (3, 3.9)
- GitHub Check: Test Python (4, 3.9)
- GitHub Check: Analyze (python)
- GitHub Check: Test Python (1, 3.12)
- GitHub Check: Analyze (c-cpp)
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
- GitHub Check: Build C library (2.14, >=2.5.0rc0,<2.15, libdeepmd_c_cu11.tar.gz)
- GitHub Check: Test C++ (true)
- GitHub Check: Test C++ (false)
🔇 Additional comments (11)
deepmd/jax/train/__init__.py (1)
`1-1`: **LGTM! Clean package initialization.**

The license header is properly included and this follows standard Python package conventions.
deepmd/jax/entrypoints/__init__.py (1)
`1-1`: **LGTM! Clean package initialization.**

The license header is properly included and this follows standard Python package conventions.
deepmd/dpmodel/descriptor/dpa1.py (1)
`784-784`: **LGTM! Descriptor count calculation is logical.**

The calculation `self.ndescrpt = self.nnei * 4` correctly computes the total descriptor elements, where each neighbor contributes 4 values (consistent with the `wanted_shape = (self.ntypes, self.nnei, 4)` on line 779).
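A small numeric illustration of that relationship; the `ntypes`/`nnei` values below are hypothetical:

```python
# Each neighbor contributes 4 environment-matrix components, so the
# flattened per-atom descriptor length is nnei * 4. The counts here are
# hypothetical, for illustration only.
ntypes, nnei = 2, 120
wanted_shape = (ntypes, nnei, 4)
ndescrpt = nnei * 4
assert ndescrpt == wanted_shape[1] * wanted_shape[2]  # 480
```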
deepmd/backend/jax.py (1)

`63-67`: **LGTM! Proper implementation of the entry point hook.**

The change correctly implements the JAX backend's entry point by importing and returning the `main` function from the new entrypoints module, replacing the previous `NotImplementedError`. This properly connects the backend to the CLI infrastructure.
deepmd/jax/utils/serialization.py (1)

`180-180`: **LGTM! Excellent checkpoint step tracking implementation.**

The extraction of `current_step` from `model_def_script` and its inclusion in the `@variables` dictionary provides proper support for training resumption. The default value of 0 and the use of `.pop()` to avoid duplication are both appropriate design choices.

Also applies to: 191-193
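A minimal sketch of the pop-with-default pattern described here; the dictionary layout and key names are assumptions for illustration, not the PR's exact serialization format.

```python
# Hypothetical sketch: the step counter is popped out of the model
# definition script (defaulting to 0 for checkpoints saved before this
# change) and stored once under "@variables", avoiding duplication.
model_def_script = {"type": "ener", "current_step": 12000}
current_step = model_def_script.pop("current_step", 0)
data = {
    "model_def_script": model_def_script,
    "@variables": {"current_step": current_step},
}
```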
deepmd/dpmodel/utils/env_mat_stat.py (1)
`122-130`: ✅ **Verification successful — verify natoms removal doesn't affect other code.**

The explicit reshaping operations improve data handling clarity and the removal of `natoms` unpacking simplifies the code. However, ensure that `natoms` isn't required elsewhere in the codebase.

🏁 Script executed:

```bash
#!/bin/bash
# Description: Check if natoms is used elsewhere in the codebase that might depend on this method.
# Expected: Find any usage of natoms that could be affected by this change.

# Search for natoms usage in environment matrix statistics contexts
rg -A 5 -B 5 "natoms.*system\[" --type py
rg -A 5 -B 5 "system.*natoms" --type py
```

Length of output: 28436

🏁 Script executed:

```bash
#!/bin/bash
# Locate DP model env_mat_stat file and inspect natoms usage
FILE=$(find deepmd/dpmodel -type f -name env_mat_stat.py)
echo "Found file: $FILE"
if [ -z "$FILE" ]; then
    echo "env_mat_stat.py not found in deepmd/dpmodel"
    exit 0
fi
echo -e "\n---- grep natoms occurrences ----"
grep -n "natoms" "$FILE" || echo "No natoms references found"
echo -e "\n---- Context around lines 100-160 ----"
sed -n '100,160p' "$FILE"
```

Length of output: 2702

No missing `natoms` references in deepmd/dpmodel/utils/env_mat_stat.py; the removal is safe. A search shows no remaining uses of `natoms` around the updated lines (122-130) or elsewhere in that file. The reshaping logic now relies only on `coord.shape[0]`, and no downstream code in this module expects `natoms`.

All clear to approve.
deepmd/jax/entrypoints/main.py (1)
`1-31`: **Well-structured imports and module setup.**

The imports are properly organized and the module docstring clearly describes the purpose.
deepmd/dpmodel/fitting/ener_fitting.py (3)
`9-9`: **Appropriate imports for statistics computation.**

The numpy and `compute_stats_from_redu` imports are necessary for the new output statistics methods.

Also applies to: 22-25

`96-112`: **Well-documented public interface for computing output statistics.**

The method provides a clean interface that delegates to the private implementation method.

`144-144`:
❓ **Verification inconclusive — validate array shape before slicing.**

The code slices `sys_tynatom[:, 2:]` without validating that the array has at least 3 columns. This could cause an IndexError if the data structure is unexpected.

Consider adding a validation check:

```diff
 sys_tynatom = np.array(sys_tynatom)
 sys_tynatom = np.reshape(sys_tynatom, [nsys, -1])
+if sys_tynatom.shape[1] < 3:
+    raise ValueError(f"Expected natoms_vec to have at least 3 columns, got {sys_tynatom.shape[1]}")
 sys_tynatom = sys_tynatom[:, 2:]
```

Let me verify the expected structure of natoms_vec:

🏁 Script executed:

```bash
#!/bin/bash
# Search for natoms_vec usage patterns to understand expected structure
rg -A 5 -B 5 'natoms_vec.*\[:.*2' --type py
```

Length of output: 45

🏁 Script executed:

```bash
#!/bin/bash
# Show surrounding lines where sys_tynatom is defined and sliced
rg -n 'sys_tynatom' -C5 deepmd/dpmodel/fitting/ener_fitting.py
```

Length of output: 1913

🏁 Script executed:

```bash
#!/bin/bash
# Show context where all_stat is defined or populated in ener_fitting.py
rg -n 'all_stat' -C5 deepmd/dpmodel/fitting/ener_fitting.py
```

Length of output: 2317

🏁 Script executed:

```bash
#!/bin/bash
# Search for definitions and usage of natoms_vec, real_natoms_vec, and make_stat_input
rg -n '"natoms_vec"' -C5 --type py
rg -n '"real_natoms_vec"' -C5 --type py
rg -n 'def make_stat_input' -C5 --type py
```

Length of output: 72872

**Validate `sys_tynatom` shape before slicing.**

Ensure `sys_tynatom` has at least three columns before applying `sys_tynatom[:, 2:]`. While NumPy slicing won't throw an `IndexError` for out-of-bounds slices, it will silently return an empty array if `shape[1] < 3`, which will break downstream logic in `compute_stats_from_redu`. Add a guard immediately after the reshape:

```diff
--- a/deepmd/dpmodel/fitting/ener_fitting.py
+++ b/deepmd/dpmodel/fitting/ener_fitting.py
@@ -142,6 +142,10 @@
 sys_tynatom = np.reshape(sys_tynatom, [nsys, -1])
+if sys_tynatom.shape[1] < 3:
+    raise ValueError(
+        f"compute_output_stats requires at least 3 columns in sys_tynatom, got {sys_tynatom.shape[1]}"
+    )
 sys_tynatom = sys_tynatom[:, 2:]
```

- File: `deepmd/dpmodel/fitting/ener_fitting.py`
- Context: lines 142-144, inside `_compute_output_stats`

Please verify that `all_stat["natoms_vec"]` and `all_stat["real_natoms_vec"]` always yield ≥3 columns so this slice is valid.

deepmd/dpmodel/loss/ener.py (1)
`180-182`: **Correct RMSE calculation implementation.**

The changes properly compute the root mean square error by applying `sqrt` to the mean squared loss, providing accurate error metrics.

Also applies to: 194-196
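For reference, a one-function sketch of the metric being described (function name hypothetical):

```python
import numpy as np

def rmse(pred: np.ndarray, label: np.ndarray) -> float:
    # RMSE = sqrt(mean((pred - label)^2)); taking sqrt after the mean
    # converts the squared loss back into the original units.
    return float(np.sqrt(np.mean((pred - label) ** 2)))

assert np.isclose(rmse(np.array([1.0, 2.0]), np.array([0.0, 2.0])), np.sqrt(0.5))
```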
```python
elif args.command is None:
    pass
else:
    raise RuntimeError(f"unknown command {args.command}")
```
🛠️ Refactor suggestion
Improve command handling for better error reporting.
The current implementation silently passes when no command is provided, which might hide configuration issues. Additionally, the error message for unknown commands could be more helpful.
Apply this diff to improve error handling:
```diff
- elif args.command is None:
-     pass
+ elif args.command is None:
+     raise RuntimeError("No command specified. Available commands: train, freeze")
  else:
-     raise RuntimeError(f"unknown command {args.command}")
+     raise RuntimeError(
+         f"Unknown command '{args.command}'. Available commands: train, freeze"
+     )
```
```python
# In this situation, we directly use these assigned energies instead of computing stats.
# This will make the loss decrease quickly
assigned_atom_ener = np.array(
    [ee if ee is not None else np.nan for ee in self.atom_ener_v]
```
**Fix undefined attribute reference.**

The code references `self.atom_ener_v`, which is not defined in the class. This will cause an `AttributeError` at runtime.

Based on the similar implementation in deepmd/tf/fit/ener.py, this should likely be `self.atom_ener`:

```diff
- [ee if ee is not None else np.nan for ee in self.atom_ener_v]
+ [ee if ee is not None else np.nan for ee in self.atom_ener]
```
📝 Committable suggestion

‼️ IMPORTANT: Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

```suggestion
[ee if ee is not None else np.nan for ee in self.atom_ener]
```
```python
        Other keyword arguments.
        """
        EnergyLoss.__init__(self, **kwargs)
        self.has_h = start_pref_h != 0.0 and limit_pref_h != 0.0
```
**Fix logical condition for enabling Hessian loss.**

The condition uses `and`, which requires both prefactors to be non-zero. This is likely too restrictive: the Hessian loss should be enabled if either prefactor is non-zero.

Apply this fix:

```diff
- self.has_h = start_pref_h != 0.0 and limit_pref_h != 0.0
+ self.has_h = start_pref_h != 0.0 or limit_pref_h != 0.0
```
📝 Committable suggestion

‼️ IMPORTANT: Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

```suggestion
self.has_h = start_pref_h != 0.0 or limit_pref_h != 0.0
```
and "energy_derv_r_derv_r" in model_dict | ||
and "hessian" in label_dict | ||
): | ||
find_hessian = label_dict.get("find_hessian", 0.0) |
💡 Verification agent

🧩 Analysis chain

**Verify find_hessian data type consistency.**

The code uses `label_dict.get("find_hessian", 0.0)`, which returns a scalar default, but other `find_*` variables in the parent class appear to be arrays from the labeled data.

Let me verify the expected data type for find_* variables:

🏁 Script executed:

```bash
#!/bin/bash
# Search for find_* variable usage patterns in loss calculations
rg -B 2 -A 2 'find_[a-z_]+\s*=' deepmd/dpmodel/loss/ --type py
```

Length of output: 1419

🏁 Script executed:

```bash
#!/bin/bash
# Search for DataRequirementItem definitions that might indicate find_* data structure
ast-grep --pattern 'DataRequirementItem(
  "find_$_",
  $$$
)'
```

Length of output: 64

**Ensure `find_hessian` is accessed like the other `find_*` flags.**

The other `find_*` variables are pulled directly from `label_dict` (no default), so using `.get(..., 0.0)` here is inconsistent and may silently disable Hessian loss even when `"hessian"` is present. Please update this so that missing flags are surfaced (or, if you really intend an optional flag, default to a matching array of ones).

- File: deepmd/dpmodel/loss/ener.py, line 434

```diff
- find_hessian = label_dict.get("find_hessian", 0.0)
+ find_hessian = label_dict["find_hessian"]
```

If you do need a default, consider instead:

```python
find_hessian = label_dict.get(
    "find_hessian",
    xp.ones(label_dict["hessian"].shape[0])  # match per-frame flag shape
)
```
```python
optimizer.update(grads)

start_time = time.time()
disp_file_fp = open(self.disp_file, "w")
```
**Use context manager for file operations.**

The file `disp_file_fp` should be opened using a context manager to ensure proper cleanup even if an exception occurs.

Apply this diff:

```diff
- disp_file_fp = open(self.disp_file, "w")
- for step in range(self.start_step, self.num_steps):
+ with open(self.disp_file, "w") as disp_file_fp:
+     for step in range(self.start_step, self.num_steps):
          # ... (training loop)
-
- disp_file_fp.close()
```

Note: You'll need to indent the entire training loop inside the context manager.

Also applies to: 382-382
```python
valid_batch_data = valid_data.get_batch()
jax_valid_data = {
    kk: jnp.asarray(vv) for kk, vv in valid_batch_data.items()
}
```
**Apply consistent data conversion for validation data.**

The validation data conversion should check for keys starting with `"find_"`, similar to the training data conversion.

Apply this diff:

```diff
  jax_valid_data = {
-     kk: jnp.asarray(vv) for kk, vv in valid_batch_data.items()
+     kk: jnp.asarray(vv) if not kk.startswith("find_") else bool(vv.item())
+     for kk, vv in valid_batch_data.items()
  }
```
📝 Committable suggestion

‼️ IMPORTANT: Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

```suggestion
valid_batch_data = valid_data.get_batch()
jax_valid_data = {
    kk: jnp.asarray(vv) if not kk.startswith("find_") else bool(vv.item())
    for kk, vv in valid_batch_data.items()
}
```
Codecov Report

Attention: Patch coverage is

Additional details and impacted files

```
@@            Coverage Diff             @@
##            devel    #4782      +/-   ##
==========================================
- Coverage   84.79%   84.40%   -0.40%
==========================================
  Files         698      702       +4
  Lines       67775    68126     +351
  Branches     3544     3542       -2
==========================================
+ Hits        57472    57499      +27
- Misses       9169     9494     +325
+ Partials     1134     1133       -1
```

☔ View full report in Codecov by Sentry.
Summary by CodeRabbit
New Features
Enhancements
Other Changes