
Atom weighted loss functions and loss_func argument refactor #256

Merged: 17 commits into IntelLabs:main from atom-weighted-loss, Jul 19, 2024

Conversation

laserkelvin (Collaborator)

This PR adds MSE and L1 loss functions, weighted by the number of atoms in each graph, to the matsciml.models.losses module.
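For concreteness, here is a minimal sketch of what such a loss can look like; the actual module in matsciml.models.losses may differ in details, and the per-graph normalization shown here is an assumption:

import torch
from torch import nn


class AtomWeightedMSE(nn.Module):
    """MSE where each graph's error is scaled by its atom count (sketch)."""

    def forward(
        self,
        input: torch.Tensor,   # predicted values, shape (num_graphs,)
        target: torch.Tensor,  # reference values, shape (num_graphs,)
        atoms_per_graph: torch.Tensor,  # atom counts, shape (num_graphs,)
    ) -> torch.Tensor:
        # Divide each graph's error by its atom count so that large
        # structures do not dominate the batch loss (assumed behavior).
        error = (input - target) / atoms_per_graph.to(input.dtype)
        return error.pow(2).mean()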

To enable these functions for tasks that have both scalar (e.g. energy) and vector (e.g. force) targets, I've refactored the loss_func argument on all tasks so it also accepts a dictionary mapping, where each key corresponds to a task_key and the value is the loss function used for that target. As an example:

ForceRegressionTask(
    ...,
    loss_func={
        "energy": matsciml.models.losses.AtomWeightedMSE(),
        "force": nn.MSELoss(),
    },
    ...
)

This does not break previous specifications: if a loss module is passed by itself (e.g. loss_func = nn.MSELoss()), the task_keys setter copies the function so it is used for all targets.
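In other words, the single-loss path can be thought of as a broadcast over the task keys, roughly like the hypothetical helper below (names are illustrative, not the actual setter code):

from copy import deepcopy

from torch import nn


def broadcast_loss(loss_func, task_keys: list[str]) -> dict:
    """Expand a single loss module into a per-target dict (illustrative)."""
    if isinstance(loss_func, dict):
        return loss_func
    # Deep-copy so targets do not share one stateful module instance.
    return {key: deepcopy(loss_func) for key in task_keys}


# broadcast_loss(nn.MSELoss(), ["energy", "force"])
# -> {"energy": MSELoss(), "force": MSELoss()}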

Some refactoring was also needed in _compute_losses to allow additional arguments to be passed to the loss function, in this case the number of atoms per graph. New loss functions in losses should keep their signatures consistent with native PyTorch ones (e.g. input and target) so they can be mapped uniformly.
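One way such dispatch can work for both plain PyTorch losses and losses that need extra inputs is signature inspection, sketched below; this is an assumption about the approach, not the literal _compute_losses code:

import inspect

import torch
from torch import nn


def apply_loss(loss_fn: nn.Module, input: torch.Tensor,
               target: torch.Tensor, **extras) -> torch.Tensor:
    """Forward only the extra kwargs that loss_fn's forward() declares."""
    params = inspect.signature(loss_fn.forward).parameters
    accepted = {k: v for k, v in extras.items() if k in params}
    return loss_fn(input, target, **accepted)


# nn.MSELoss silently ignores the extra key; AtomWeightedMSE consumes it:
# apply_loss(nn.MSELoss(), pred, label, atoms_per_graph=counts)
# apply_loss(AtomWeightedMSE(), pred, label, atoms_per_graph=counts)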

@laserkelvin added the enhancement (New feature or request) and training (Issues related to model training) labels on Jul 19, 2024
@melo-gonzo (Collaborator) left a comment

Looks good overall! My only suggestion would be to add a couple pytests for the new loss modules. I'll leave it up to you if you want to do that or not. Otherwise feel free to merge when ready.

laserkelvin (Collaborator, Author) commented on Jul 19, 2024

> Looks good overall! My only suggestion would be to add a couple pytests for the new loss modules. I'll leave it up to you if you want to do that or not. Otherwise feel free to merge when ready.

Added a very superficial parametrized test in 04d359f. Will merge when tests pass!
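For reference, a parametrized test along these lines might look like the following; the actual test in 04d359f may differ, and the AtomWeightedL1 class name and call signature here are assumptions:

import pytest
import torch

from matsciml.models import losses


# AtomWeightedL1 is an assumed name for the L1 counterpart added in this PR.
@pytest.mark.parametrize("loss_cls", [losses.AtomWeightedMSE, losses.AtomWeightedL1])
def test_atom_weighted_loss(loss_cls):
    loss_fn = loss_cls()
    input = torch.rand(8)
    target = torch.rand(8)
    # Random atom counts per graph; assumed to scale the per-graph error.
    atoms_per_graph = torch.randint(1, 100, (8,)).float()
    value = loss_fn(input, target, atoms_per_graph)
    assert torch.isfinite(value)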

@laserkelvin merged commit 8f8ec4b into IntelLabs:main on Jul 19, 2024 (3 checks passed)
@laserkelvin deleted the atom-weighted-loss branch on Jul 19, 2024
@melo-gonzo mentioned this pull request on Jul 23, 2024