
feat: JAX training #4782


Draft: wants to merge 20 commits into base: devel

Conversation

njzjz
Member

@njzjz njzjz commented Jun 5, 2025

Summary by CodeRabbit

  • New Features

    • Introduced JAX backend entry points for training and model freezing, including command-line interfaces for both tasks.
    • Added support for Hessian loss computation in energy loss, with configurable prefactors and RMSE reporting.
    • Implemented a new training framework for DeePMD models using JAX, supporting checkpointing, mixed precision, and detailed logging.
  • Enhancements

    • Improved RMSE calculation and display for energy and force losses.
    • Added output statistics computation for energy fitting models.
    • Generalized learning rate scheduling to support alternative numerical libraries.
  • Other Changes

    • Updated serialization to include the current training step in JAX model files.
    • Minor adjustments to data preparation and environment matrix statistics handling.
    • Added license identifiers to new modules.
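The generalized learning-rate scheduling mentioned above can be sketched as follows. This is an illustrative stand-in, not the actual LearningRateExp.value signature; the point is that passing the array library as a parameter (xp) lets the same formula run under NumPy or jax.numpy:

```python
import numpy as np


def lr_value(step, start_lr, decay_rate, decay_steps, xp=np):
    # Exponential decay: lr(step) = start_lr * decay_rate ** (step / decay_steps).
    # Callers can pass xp=jax.numpy so the expression stays traceable under jit.
    return start_lr * xp.power(decay_rate, step / decay_steps)
```

With xp=np this behaves as a plain NumPy computation; swapping in jax.numpy requires no code changes at the call site.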

njzjz added 20 commits May 25, 2025 12:53
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
@njzjz njzjz added this to the v3.2.0 milestone Jun 5, 2025
@njzjz njzjz linked an issue Jun 5, 2025 that may be closed by this pull request
@github-actions github-actions bot added the Python label Jun 5, 2025
valid_data = None

# get training info
stop_batch = jdata["training"]["numb_steps"]

Check notice (Code scanning / CodeQL): Unused local variable. Variable stop_batch is not used.
if (
origin_type_map is not None and not origin_type_map
): # get the type_map from data if not provided
origin_type_map = get_data(

Check notice (Code scanning / CodeQL): Unused local variable. Variable origin_type_map is not used.
)
jdata_cpy = jdata.copy()
type_map = jdata["model"].get("type_map")
train_data = get_data(

Check notice (Code scanning / CodeQL): Unused local variable. Variable train_data is not used.
Contributor

coderabbitai bot commented Jun 5, 2025

📝 Walkthrough

Walkthrough

This update introduces a JAX-based training and entrypoint framework for DeePMD-Kit, including new modules for command-line handling, model freezing, and training orchestration. It adds or modifies several utilities and loss functions, enhances serialization with step tracking, and generalizes learning rate scheduling. Minor internal changes and new attributes are also introduced in descriptors and statistics utilities.

Changes

File(s) Change Summary
deepmd/backend/jax.py Updated JAXBackend.entry_point_hook to return the actual entry point (main) instead of raising NotImplementedError.
deepmd/dpmodel/descriptor/dpa1.py Added ndescrpt attribute to DescrptBlockSeAtten as self.nnei * 4.
deepmd/dpmodel/fitting/ener_fitting.py Added compute_output_stats and _compute_output_stats methods to EnergyFittingNet for output statistics computation.
deepmd/dpmodel/loss/ener.py Refined RMSE calculation for energy/force loss; added EnergyHessianLoss class for Hessian loss support.
deepmd/dpmodel/utils/env_mat_stat.py Adjusted unpacking in EnvMatStatSe.iter to ignore natoms, updated reshaping logic.
deepmd/dpmodel/utils/learning_rate.py Generalized LearningRateExp.value to accept a numerical library parameter (xp), defaulting to NumPy.
deepmd/jax/entrypoints/__init__.py, deepmd/jax/train/__init__.py Added license identifier files, no functional code.
deepmd/jax/entrypoints/freeze.py Added freeze function to convert checkpoint data into a serialized output file.
deepmd/jax/entrypoints/main.py Added main function as CLI entry point, dispatching to train or freeze commands.
deepmd/jax/entrypoints/train.py Added JAX-based training entrypoint, with SummaryPrinter, train, and update_sel functions.
deepmd/jax/train/trainer.py Introduced DPTrainer class for model training, checkpointing, and reporting; added prepare_input utility.
deepmd/jax/utils/serialization.py Included current_step in serialized model data under "@variables".

Sequence Diagram(s)

sequenceDiagram
    participant CLI_User
    participant MainEntrypoint
    participant TrainEntrypoint
    participant FreezeEntrypoint
    participant DPTrainer
    participant SerializationUtils

    CLI_User->>MainEntrypoint: main(args)
    MainEntrypoint->>MainEntrypoint: parse_args(args)
    alt command == "train"
        MainEntrypoint->>TrainEntrypoint: train(**args)
        TrainEntrypoint->>DPTrainer: DPTrainer(jdata, ...)
        DPTrainer->>DPTrainer: train(train_data, valid_data)
        DPTrainer->>SerializationUtils: save_checkpoint(...)
    else command == "freeze"
        MainEntrypoint->>FreezeEntrypoint: freeze(checkpoint_folder, output)
        FreezeEntrypoint->>SerializationUtils: serialize_from_file(folder)
        FreezeEntrypoint->>SerializationUtils: deserialize_to_file(output)
    end

Suggested labels

Python

Suggested reviewers

  • wanghan-iapcm

Warning

There were issues while running some tools. Please review the errors and either fix the tool's configuration or disable the tool if it's a critical failure.

🔧 Pylint (3.3.7)
deepmd/dpmodel/loss/ener.py

No files to lint: exiting.

deepmd/dpmodel/fitting/ener_fitting.py

No files to lint: exiting.

deepmd/dpmodel/descriptor/dpa1.py

No files to lint: exiting.

  • 10 others
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 6

🧹 Nitpick comments (10)
deepmd/dpmodel/utils/learning_rate.py (1)

48-52: Update return type annotation for backend generalization.

The addition of the xp parameter to support different array backends (like JAX) is excellent for the framework's extensibility. However, the return type annotation should be updated to reflect the actual return type.

-    def value(self, step, xp=np) -> np.float64:
+    def value(self, step, xp=np):

Or if you want to be more specific:

-    def value(self, step, xp=np) -> np.float64:
+    def value(self, step, xp=np) -> Union[np.float64, Any]:
deepmd/jax/entrypoints/freeze.py (1)

12-36: LGTM! Consider adding output validation and improved error handling.

The freeze function implementation is well-structured with proper checkpoint handling logic. The keyword-only argument design and integration with serialization utilities are excellent.

Consider these minor enhancements:

def freeze(
    *,
    checkpoint_folder: str,
    output: str,
    **kwargs,
) -> None:
    """Freeze the graph in supplied folder.

    Parameters
    ----------
    checkpoint_folder : str
        location of either the folder with checkpoint or the checkpoint prefix
    output : str
-        output file name
+        output file name (supported formats: .jax, .hlo, .savedmodel)
    **kwargs
        other arguments
    """
+    # Validate output format
+    supported_formats = ['.jax', '.hlo', '.savedmodel']
+    if not any(output.endswith(fmt) for fmt in supported_formats):
+        raise ValueError(f"Unsupported output format. Supported: {supported_formats}")
+    
    if (Path(checkpoint_folder) / "checkpoint").is_file():
        checkpoint_meta = Path(checkpoint_folder) / "checkpoint"
        checkpoint_folder = checkpoint_meta.read_text().strip()
    if Path(checkpoint_folder).is_dir():
-        data = serialize_from_file(checkpoint_folder)
-        deserialize_to_file(output, data)
+        try:
+            data = serialize_from_file(checkpoint_folder)
+            deserialize_to_file(output, data)
+        except Exception as e:
+            raise RuntimeError(f"Failed to freeze checkpoint: {e}") from e
    else:
        raise FileNotFoundError(f"Checkpoint {checkpoint_folder} does not exist.")
deepmd/dpmodel/fitting/ener_fitting.py (1)

118-124: Simplify nested loops for better readability.

The triple-nested loops can be simplified using list comprehensions or numpy operations.

Consider refactoring the nested loops:

-        sys_ener = []
-        for ss in range(len(data)):
-            sys_data = []
-            for ii in range(len(data[ss])):
-                for jj in range(len(data[ss][ii])):
-                    sys_data.append(data[ss][ii][jj])
-            sys_data = np.concatenate(sys_data)
-            sys_ener.append(np.average(sys_data))
+        sys_ener = []
+        for system_data in data:
+            # Flatten all batches and frames for this system
+            sys_data = np.concatenate([frame for batch in system_data for frame in batch])
+            sys_ener.append(np.average(sys_data))

Similarly for the mixed_type branch:

-                tmp_tynatom = []
-                for ii in range(len(data[ss])):
-                    for jj in range(len(data[ss][ii])):
-                        tmp_tynatom.append(data[ss][ii][jj].astype(np.float64))
-                tmp_tynatom = np.average(np.array(tmp_tynatom), axis=0)
+                # Flatten all batches and frames, then compute average
+                tmp_tynatom = np.average(
+                    np.array([frame.astype(np.float64) 
+                             for batch in data[ss] 
+                             for frame in batch]), 
+                    axis=0
+                )

Also applies to: 130-136

deepmd/dpmodel/loss/ener.py (1)

457-457: Clarify the ndof comment for Hessian data.

The comment # 9=3*3 --> 3N*3N=ndof*natoms*natoms is confusing since ndof is set to 1, not 9.

Update the comment to be clearer:

-                    ndof=1,  # 9=3*3 --> 3N*3N=ndof*natoms*natoms
+                    ndof=1,  # Hessian has shape (natoms, 3, natoms, 3), flattened per atom
deepmd/jax/train/trainer.py (3)

134-134: Remove unused instance variables.

The following instance variables are assigned but never used in the class:

  • self.numb_fparam (line 134)
  • self.frz_model (line 142)
  • self.ckpt_meta (line 143)
  • self.model_type (line 144)

Consider removing these unused variables to improve code clarity.

Also applies to: 142-145


283-284: Simplify .get() calls by removing redundant None default.

Apply this diff:

-                fparam=jax_data.get("fparam", None),
-                aparam=jax_data.get("aparam", None),
+                fparam=jax_data.get("fparam"),
+                aparam=jax_data.get("aparam"),

And similarly for lines 333-334:

-                            fparam=jax_valid_data.get("fparam", None),
-                            aparam=jax_valid_data.get("aparam", None),
+                            fparam=jax_valid_data.get("fparam"),
+                            aparam=jax_valid_data.get("aparam"),

Also applies to: 333-334

🧰 Tools
🪛 Ruff (0.11.9)

283-283: Use jax_data.get("fparam") instead of jax_data.get("fparam", None)

Replace jax_data.get("fparam", None) with jax_data.get("fparam")

(SIM910)


284-284: Use jax_data.get("aparam") instead of jax_data.get("aparam", None)

Replace jax_data.get("aparam", None) with jax_data.get("aparam")

(SIM910)


396-396: Simplify dictionary key iteration.

Apply this diff:

-            for k in valid_results.keys():
+            for k in valid_results:

And for line 401:

-            for k in train_results.keys():
+            for k in train_results:

Also applies to: 401-401

🧰 Tools
🪛 Ruff (0.11.9)

396-396: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)

deepmd/jax/entrypoints/train.py (3)

144-147: Simplify type map assignment with ternary operator.

Apply this diff:

-    if len(type_map) == 0:
-        ipt_type_map = None
-    else:
-        ipt_type_map = type_map
+    ipt_type_map = None if len(type_map) == 0 else type_map
🧰 Tools
🪛 Ruff (0.11.9)

144-147: Use ternary operator ipt_type_map = None if len(type_map) == 0 else type_map instead of if-else-block

Replace if-else-block with ipt_type_map = None if len(type_map) == 0 else type_map

(SIM108)


172-172: Remove unused variable stop_batch.

The variable stop_batch is assigned but never used.

Apply this diff:

-    stop_batch = jdata["training"]["numb_steps"]
🧰 Tools
🪛 Ruff (0.11.9)

172-172: Local variable stop_batch is assigned to but never used

Remove assignment to unused variable stop_batch

(F841)


201-204: Address the OOM issue in neighbor statistics calculation.

The commented code indicates an out-of-memory issue that needs to be resolved. This functionality appears to be important for updating the model's selection parameters based on neighbor statistics.

Would you like me to help investigate the OOM issue and propose a memory-efficient solution for computing neighbor statistics? I could open a new issue to track this task.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9ef43fa and 15bb506.

📒 Files selected for processing (13)
  • deepmd/backend/jax.py (1 hunks)
  • deepmd/dpmodel/descriptor/dpa1.py (1 hunks)
  • deepmd/dpmodel/fitting/ener_fitting.py (3 hunks)
  • deepmd/dpmodel/loss/ener.py (3 hunks)
  • deepmd/dpmodel/utils/env_mat_stat.py (1 hunks)
  • deepmd/dpmodel/utils/learning_rate.py (1 hunks)
  • deepmd/jax/entrypoints/__init__.py (1 hunks)
  • deepmd/jax/entrypoints/freeze.py (1 hunks)
  • deepmd/jax/entrypoints/main.py (1 hunks)
  • deepmd/jax/entrypoints/train.py (1 hunks)
  • deepmd/jax/train/__init__.py (1 hunks)
  • deepmd/jax/train/trainer.py (1 hunks)
  • deepmd/jax/utils/serialization.py (2 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (3)
deepmd/backend/jax.py (1)
deepmd/jax/entrypoints/main.py (1)
  • main (32-67)
deepmd/dpmodel/fitting/ener_fitting.py (2)
deepmd/utils/out_stat.py (1)
  • compute_stats_from_redu (15-86)
deepmd/tf/fit/ener.py (2)
  • compute_output_stats (257-273)
  • _compute_output_stats (275-321)
deepmd/jax/entrypoints/main.py (3)
deepmd/backend/suffix.py (1)
  • format_model_suffix (17-75)
deepmd/jax/entrypoints/freeze.py (1)
  • freeze (12-36)
deepmd/loggers/loggers.py (1)
  • set_log_handles (146-278)
🪛 Ruff (0.11.9)
deepmd/jax/entrypoints/train.py

144-147: Use ternary operator ipt_type_map = None if len(type_map) == 0 else type_map instead of if-else-block

Replace if-else-block with ipt_type_map = None if len(type_map) == 0 else type_map

(SIM108)


172-172: Local variable stop_batch is assigned to but never used

Remove assignment to unused variable stop_batch

(F841)


195-195: Local variable train_data is assigned to but never used

Remove assignment to unused variable train_data

(F841)

deepmd/jax/train/trainer.py

269-269: Use a context manager for opening files

(SIM115)


283-283: Use jax_data.get("fparam") instead of jax_data.get("fparam", None)

Replace jax_data.get("fparam", None) with jax_data.get("fparam")

(SIM910)


284-284: Use jax_data.get("aparam") instead of jax_data.get("aparam", None)

Replace jax_data.get("aparam", None) with jax_data.get("aparam")

(SIM910)


333-333: Use jax_valid_data.get("fparam") instead of jax_valid_data.get("fparam", None)

Replace jax_valid_data.get("fparam", None) with jax_valid_data.get("fparam")

(SIM910)


334-334: Use jax_valid_data.get("aparam") instead of jax_valid_data.get("aparam", None)

Replace jax_valid_data.get("aparam", None) with jax_valid_data.get("aparam")

(SIM910)


396-396: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


401-401: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)

⏰ Context from checks skipped due to timeout of 90000ms (29)
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Build C++ (cuda, cuda)
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Build C++ (rocm, rocm)
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Test Python (6, 3.12)
  • GitHub Check: Test Python (6, 3.9)
  • GitHub Check: Test Python (5, 3.9)
  • GitHub Check: Test Python (5, 3.12)
  • GitHub Check: Test Python (4, 3.12)
  • GitHub Check: Test Python (1, 3.9)
  • GitHub Check: Test Python (3, 3.12)
  • GitHub Check: Test Python (2, 3.12)
  • GitHub Check: Test Python (2, 3.9)
  • GitHub Check: Test Python (3, 3.9)
  • GitHub Check: Test Python (4, 3.9)
  • GitHub Check: Analyze (python)
  • GitHub Check: Test Python (1, 3.12)
  • GitHub Check: Analyze (c-cpp)
  • GitHub Check: Build wheels for cp310-manylinux_aarch64
  • GitHub Check: Build wheels for cp311-win_amd64
  • GitHub Check: Build wheels for cp311-macosx_arm64
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build wheels for cp311-macosx_x86_64
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
  • GitHub Check: Build C library (2.14, >=2.5.0rc0,<2.15, libdeepmd_c_cu11.tar.gz)
  • GitHub Check: Test C++ (true)
  • GitHub Check: Test C++ (false)
🔇 Additional comments (11)
deepmd/jax/train/__init__.py (1)

1-1: LGTM! Clean package initialization.

The license header is properly included and this follows standard Python package conventions.

deepmd/jax/entrypoints/__init__.py (1)

1-1: LGTM! Clean package initialization.

The license header is properly included and this follows standard Python package conventions.

deepmd/dpmodel/descriptor/dpa1.py (1)

784-784: LGTM! Descriptor count calculation is logical.

The calculation self.ndescrpt = self.nnei * 4 correctly computes the total descriptor elements, where each neighbor contributes 4 values (consistent with the wanted_shape = (self.ntypes, self.nnei, 4) on line 779).

deepmd/backend/jax.py (1)

63-67: LGTM! Proper implementation of the entry point hook.

The change correctly implements the JAX backend's entry point by importing and returning the main function from the new entrypoints module, replacing the previous NotImplementedError. This properly connects the backend to the CLI infrastructure.

deepmd/jax/utils/serialization.py (1)

180-180: LGTM! Excellent checkpoint step tracking implementation.

The extraction of current_step from model_def_script and its inclusion in the @variables dictionary provides proper support for training resumption. The default value of 0 and the use of .pop() to avoid duplication are both appropriate design choices.

Also applies to: 191-193
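The step-tracking change can be illustrated with a minimal sketch. The dictionary shapes below are illustrative assumptions, not the project's actual serialization format:

```python
# Hypothetical payload shapes; the real logic lives in deepmd/jax/utils/serialization.py.
model_def_script = {"type_map": ["O", "H"], "current_step": 1200}
data = {"@variables": {}}

# .pop() with a default of 0 moves the step into "@variables" without leaving a
# duplicate copy in model_def_script; fresh checkpoints default to step 0.
data["@variables"]["current_step"] = model_def_script.pop("current_step", 0)
```

A training loop resuming from this payload would read data["@variables"]["current_step"] to continue at step 1200 rather than restarting from 0.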

deepmd/dpmodel/utils/env_mat_stat.py (1)

122-130:

✅ Verification successful

Verify natoms removal doesn't affect other code.

The explicit reshaping operations improve data handling clarity and the removal of natoms unpacking simplifies the code. However, ensure that natoms isn't required elsewhere in the codebase.


🏁 Script executed:

#!/bin/bash
# Description: Check if natoms is used elsewhere in the codebase that might depend on this method.
# Expected: Find any usage of natoms that could be affected by this change.

# Search for natoms usage in environment matrix statistics contexts
rg -A 5 -B 5 "natoms.*system\[" --type py
rg -A 5 -B 5 "system.*natoms" --type py

Length of output: 28436


🏁 Script executed:

#!/bin/bash
# Locate DP model env_mat_stat file and inspect natoms usage
FILE=$(find deepmd/dpmodel -type f -name env_mat_stat.py)
echo "Found file: $FILE"
if [ -z "$FILE" ]; then
  echo "env_mat_stat.py not found in deepmd/dpmodel"
  exit 0
fi

echo -e "\n---- grep natoms occurrences ----"
grep -n "natoms" "$FILE" || echo "No natoms references found"

echo -e "\n---- Context around lines 100-160 ----"
sed -n '100,160p' "$FILE"

Length of output: 2702


No missing natoms references in deepmd/dpmodel/utils/env_mat_stat.py; removal is safe.

A search in deepmd/dpmodel/utils/env_mat_stat.py shows no remaining uses of natoms around the updated lines or elsewhere in that file. The reshaping logic now only relies on coord.shape[0], and no downstream code in this module expects natoms.

– deepmd/dpmodel/utils/env_mat_stat.py: lines 122–130 no longer mention natoms, and no grep hits for that identifier in this file.

All clear to approve.

deepmd/jax/entrypoints/main.py (1)

1-31: Well-structured imports and module setup.

The imports are properly organized and the module docstring clearly describes the purpose.

deepmd/dpmodel/fitting/ener_fitting.py (3)

9-9: Appropriate imports for statistics computation.

The numpy and compute_stats_from_redu imports are necessary for the new output statistics methods.

Also applies to: 22-25


96-112: Well-documented public interface for computing output statistics.

The method provides a clean interface that delegates to the private implementation method.


144-144:

❓ Verification inconclusive

Validate array shape before slicing.

The code slices sys_tynatom[:, 2:] without validating the array has at least 3 columns. This could cause an IndexError if the data structure is unexpected.

Consider adding a validation check:

         sys_tynatom = np.array(sys_tynatom)
         sys_tynatom = np.reshape(sys_tynatom, [nsys, -1])
+        if sys_tynatom.shape[1] < 3:
+            raise ValueError(f"Expected natoms_vec to have at least 3 columns, got {sys_tynatom.shape[1]}")
         sys_tynatom = sys_tynatom[:, 2:]

Let me verify the expected structure of natoms_vec:


🏁 Script executed:

#!/bin/bash
# Search for natoms_vec usage patterns to understand expected structure
rg -A 5 -B 5 'natoms_vec.*\[:.*2' --type py

Length of output: 45



🏁 Script executed:

#!/bin/bash
# Show surrounding lines where sys_tynatom is defined and sliced
rg -n 'sys_tynatom' -C5 deepmd/dpmodel/fitting/ener_fitting.py

Length of output: 1913



🏁 Script executed:

#!/bin/bash
# Show context where all_stat is defined or populated in ener_fitting.py
rg -n 'all_stat' -C5 deepmd/dpmodel/fitting/ener_fitting.py

Length of output: 2317


🏁 Script executed:

#!/bin/bash
# Search for definitions and usage of natoms_vec, real_natoms_vec, and make_stat_input
rg -n '"natoms_vec"' -C5 --type py
rg -n '"real_natoms_vec"' -C5 --type py
rg -n 'def make_stat_input' -C5 --type py

Length of output: 72872


Validate sys_tynatom shape before slicing

Ensure sys_tynatom has at least three columns before applying sys_tynatom[:, 2:]. While NumPy slicing won’t throw an IndexError for out‐of‐bounds slices, it will silently return an empty array if shape[1] < 3, which will break downstream logic in compute_stats_from_redu. Add a guard immediately after the reshape:

--- a/deepmd/dpmodel/fitting/ener_fitting.py
+++ b/deepmd/dpmodel/fitting/ener_fitting.py
@@ -142,6 +142,10 @@
         sys_tynatom = np.reshape(sys_tynatom, [nsys, -1])
+        if sys_tynatom.shape[1] < 3:
+            raise ValueError(
+                f"compute_output_stats requires at least 3 columns in sys_tynatom, got {sys_tynatom.shape[1]}"
+            )
         sys_tynatom = sys_tynatom[:, 2:]

• File: deepmd/dpmodel/fitting/ener_fitting.py
• Context: lines 142–144, inside _compute_output_stats

Please verify that all_stat["natoms_vec"] and all_stat["real_natoms_vec"] always yield ≥3 columns so this slice is valid.

deepmd/dpmodel/loss/ener.py (1)

180-182: Correct RMSE calculation implementation.

The changes properly compute the root mean square error by applying sqrt to the mean squared loss, providing accurate error metrics.

Also applies to: 194-196
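The corrected metric amounts to the standard RMSE definition. This sketch is illustrative, not the project's loss code:

```python
import numpy as np


def rmse(pred, label):
    # Root-mean-square error: sqrt of the mean of squared differences.
    # Without the final sqrt, the reported value is the MSE, not the RMSE.
    return float(np.sqrt(np.mean(np.square(pred - label))))
```

Applying sqrt after the mean (rather than reporting the mean squared loss directly) puts the error back in the same units as the quantity itself, e.g. eV for energies.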

Comment on lines +64 to +67
elif args.command is None:
pass
else:
raise RuntimeError(f"unknown command {args.command}")
Contributor

🛠️ Refactor suggestion

Improve command handling for better error reporting.

The current implementation silently passes when no command is provided, which might hide configuration issues. Additionally, the error message for unknown commands could be more helpful.

Apply this diff to improve error handling:

-    elif args.command is None:
-        pass
+    elif args.command is None:
+        raise RuntimeError("No command specified. Available commands: train, freeze")
     else:
-        raise RuntimeError(f"unknown command {args.command}")
+        raise RuntimeError(
+            f"Unknown command '{args.command}'. Available commands: train, freeze"
+        )
🤖 Prompt for AI Agents
In deepmd/jax/entrypoints/main.py around lines 64 to 67, replace the silent pass
when args.command is None with a clear error message indicating that no command
was provided. Also, enhance the RuntimeError message for unknown commands to
suggest checking available commands or usage. This improves error reporting by
explicitly handling missing commands and providing more informative feedback for
unknown commands.

# In this situation, we directly use these assigned energies instead of computing stats.
# This will make the loss decrease quickly
assigned_atom_ener = np.array(
[ee if ee is not None else np.nan for ee in self.atom_ener_v]
Contributor

⚠️ Potential issue

Fix undefined attribute reference.

The code references self.atom_ener_v which is not defined in the class. This will cause an AttributeError at runtime.

Based on the similar implementation in deepmd/tf/fit/ener.py, this should likely be self.atom_ener:

-                [ee if ee is not None else np.nan for ee in self.atom_ener_v]
+                [ee if ee is not None else np.nan for ee in self.atom_ener]
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-                [ee if ee is not None else np.nan for ee in self.atom_ener_v]
+                [ee if ee is not None else np.nan for ee in self.atom_ener]
🤖 Prompt for AI Agents
In deepmd/dpmodel/fitting/ener_fitting.py at line 150, the code references an
undefined attribute self.atom_ener_v, causing an AttributeError. Replace
self.atom_ener_v with self.atom_ener to match the correct attribute name used in
the class, following the pattern from deepmd/tf/fit/ener.py.

Other keyword arguments.
"""
EnergyLoss.__init__(self, **kwargs)
self.has_h = start_pref_h != 0.0 and limit_pref_h != 0.0
Contributor

⚠️ Potential issue

Fix logical condition for enabling Hessian loss.

The condition uses and, which requires both prefactors to be non-zero. This is likely too restrictive: the Hessian loss should be enabled if either prefactor is non-zero.

Apply this fix:

-        self.has_h = start_pref_h != 0.0 and limit_pref_h != 0.0
+        self.has_h = start_pref_h != 0.0 or limit_pref_h != 0.0
🤖 Prompt for AI Agents
In deepmd/dpmodel/loss/ener.py at line 409, the condition for enabling Hessian
loss uses 'and' to check if both start_pref_h and limit_pref_h are non-zero,
which is too restrictive. Change the logical operator from 'and' to 'or' so that
the Hessian loss is enabled if either start_pref_h or limit_pref_h is non-zero.
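The rationale can be seen from the prefactor schedule. DeePMD-style losses typically interpolate each prefactor between its start and limit values as the learning rate decays; the sketch below uses the common linear-in-lr form (the function name and formula are assumptions for illustration, not copied from this PR). With start_pref_h = 0.0 and limit_pref_h = 1.0, the and condition would disable a term that becomes dominant late in training.

```python
def pref(start_pref, limit_pref, lr, start_lr):
    # linear-in-lr interpolation: equals start_pref at lr == start_lr,
    # and approaches limit_pref as lr decays toward zero
    return limit_pref + (start_pref - limit_pref) * lr / start_lr

# start_pref_h == 0.0, limit_pref_h == 1.0:
early = pref(0.0, 1.0, lr=1e-3, start_lr=1e-3)  # 0.0 at the start
late = pref(0.0, 1.0, lr=1e-5, start_lr=1e-3)   # 0.99 near the end
```

The term is zero early on but grows to nearly the full limit prefactor, so gating it on start_pref_h != 0.0 and limit_pref_h != 0.0 would wrongly drop it.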

and "energy_derv_r_derv_r" in model_dict
and "hessian" in label_dict
):
find_hessian = label_dict.get("find_hessian", 0.0)
💡 Verification agent

🧩 Analysis chain

Verify find_hessian data type consistency.

The code uses label_dict.get("find_hessian", 0.0), which returns a scalar default, but the other find_* variables in the parent class appear to be arrays drawn from the labeled data.

Let me verify the expected data type for find_* variables:


🏁 Script executed:

#!/bin/bash
# Search for find_* variable usage patterns in loss calculations
rg -B 2 -A 2 'find_[a-z_]+\s*=' deepmd/dpmodel/loss/ --type py

Length of output: 1419


🏁 Script executed:

#!/bin/bash
# Search for DataRequirementItem definitions that might indicate find_* data structure
ast-grep --pattern 'DataRequirementItem(
  "find_$_",
  $$$
)'

Length of output: 64


Ensure find_hessian is accessed like the other find_* flags

The other find_* variables are pulled directly from label_dict (no default), so using .get(..., 0.0) here is inconsistent and may silently disable Hessian loss even when "hessian" is present. Please update this so that missing flags are surfaced (or, if you really intend an optional flag, default to a matching array of ones).

• File deepmd/dpmodel/loss/ener.py, line 434

-    find_hessian = label_dict.get("find_hessian", 0.0)
+    find_hessian = label_dict["find_hessian"]

If you do need a default, consider instead:

find_hessian = label_dict.get(
    "find_hessian",
    xp.ones(label_dict["hessian"].shape[0])  # match per-frame flag shape
)
🤖 Prompt for AI Agents
In deepmd/dpmodel/loss/ener.py at line 434, the assignment of find_hessian uses
label_dict.get with a scalar default 0.0, which is inconsistent with other
find_* variables that do not use defaults and are arrays. To fix this, remove
the default value so that missing keys raise an error or, if a default is
necessary, set it to an array of ones matching the shape of
label_dict["hessian"]. This ensures data type consistency and proper handling of
the find_hessian flag.

optimizer.update(grads)

start_time = time.time()
disp_file_fp = open(self.disp_file, "w")
⚠️ Potential issue

Use context manager for file operations.

The file disp_file_fp should be opened using a context manager to ensure proper cleanup even if an exception occurs.

Apply this diff:

-        disp_file_fp = open(self.disp_file, "w")
-        for step in range(self.start_step, self.num_steps):
+        with open(self.disp_file, "w") as disp_file_fp:
+            for step in range(self.start_step, self.num_steps):
             # ... (training loop)
-
-        disp_file_fp.close()

Note: You'll need to indent the entire training loop inside the context manager.

Also applies to: 382-382

🧰 Tools
🪛 Ruff (0.11.9)

269-269: Use a context manager for opening files

(SIM115)

🤖 Prompt for AI Agents
In deepmd/jax/train/trainer.py at lines 269 and 382, the file disp_file_fp is
opened without a context manager, risking resource leaks if exceptions occur.
Refactor the code to open disp_file_fp using a with statement as a context
manager, and indent the entire training loop or relevant code block inside this
with block to ensure the file is properly closed after use.
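A minimal sketch of the suggested refactor (the class and file contents are illustrative; the real trainer's loop body is elided): the with statement guarantees the display file is closed even if an exception escapes the training loop, which a bare open()/close() pair does not.

```python
import os
import tempfile

class TinyTrainer:
    def __init__(self, disp_file, start_step, num_steps):
        self.disp_file = disp_file
        self.start_step = start_step
        self.num_steps = num_steps

    def train(self):
        # the entire loop is indented under the context manager,
        # mirroring the suggested change to deepmd/jax/train/trainer.py
        with open(self.disp_file, "w") as disp_file_fp:
            for step in range(self.start_step, self.num_steps):
                # ... run one optimizer step, then log ...
                disp_file_fp.write(f"step {step}\n")
        # disp_file_fp is guaranteed closed here, even on error

path = os.path.join(tempfile.mkdtemp(), "lcurve.out")
TinyTrainer(path, 0, 3).train()
```

The only cost is one extra indentation level for the loop, as the note above points out.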

Comment on lines +320 to +323
valid_batch_data = valid_data.get_batch()
jax_valid_data = {
kk: jnp.asarray(vv) for kk, vv in valid_batch_data.items()
}
⚠️ Potential issue

Apply consistent data conversion for validation data.

The validation data conversion should handle keys starting with "find_" in the same way as the training data conversion.

Apply this diff:

                     jax_valid_data = {
-                        kk: jnp.asarray(vv) for kk, vv in valid_batch_data.items()
+                        kk: jnp.asarray(vv) if not kk.startswith("find_") else bool(vv.item())
+                        for kk, vv in valid_batch_data.items()
                     }
🤖 Prompt for AI Agents
In deepmd/jax/train/trainer.py around lines 320 to 323, the validation data
conversion should be made consistent with the training data conversion by
special-casing keys that start with "find_". Update the dictionary
comprehension so that find_* values are converted to Python bools via
bool(vv.item()) while all other values are converted with jnp.asarray.


codecov bot commented Jun 5, 2025

Codecov Report

Attention: Patch coverage is 5.76369% with 327 lines in your changes missing coverage. Please review.

Project coverage is 84.40%. Comparing base (265d094) to head (15bb506).
Report is 3 commits behind head on devel.

Files with missing lines Patch % Lines
deepmd/jax/train/trainer.py 0.00% 168 Missing ⚠️
deepmd/jax/entrypoints/train.py 0.00% 68 Missing ⚠️
deepmd/dpmodel/fitting/ener_fitting.py 10.52% 34 Missing ⚠️
deepmd/dpmodel/loss/ener.py 24.13% 22 Missing ⚠️
deepmd/jax/entrypoints/main.py 0.00% 22 Missing ⚠️
deepmd/jax/entrypoints/freeze.py 0.00% 10 Missing ⚠️
deepmd/backend/jax.py 0.00% 2 Missing ⚠️
deepmd/jax/utils/serialization.py 0.00% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##            devel    #4782      +/-   ##
==========================================
- Coverage   84.79%   84.40%   -0.40%     
==========================================
  Files         698      702       +4     
  Lines       67775    68126     +351     
  Branches     3544     3542       -2     
==========================================
+ Hits        57472    57499      +27     
- Misses       9169     9494     +325     
+ Partials     1134     1133       -1     


Development

Successfully merging this pull request may close these issues.

[Feature Request] JAX training
1 participant