Atom weighted loss functions and `loss_func` argument refactor (#256)
Conversation
Signed-off-by: Lee, Kin Long Kelvin <kin.long.kelvin.lee@intel.com>
…ceregressiontask" This reverts commit 84950bc. Realizing it's probably better to make `loss_func` a mapping instead.
Looks good overall! My only suggestion would be to add a couple pytests for the new loss modules. I'll leave it up to you if you want to do that or not. Otherwise feel free to merge when ready.
Signed-off-by: Lee, Kin Long Kelvin <kin.long.kelvin.lee@intel.com>
Added a very superficial parametrized test in 04d359f. Will merge when tests pass!
This PR adds support for MSE and L1 loss functions that are weighted by the number of atoms in each graph, implemented in the `matsciml.models.losses` module. To enable these functions for tasks that have both scalar (e.g. energy) and vector (e.g. force) targets, I've refactored the `loss_func` argument to all tasks to support a dictionary mapping, whereby each key corresponds to a `task_key` and the passed function is the loss for that corresponding target. This does not break previous specifications: if a loss module is passed by itself (e.g. `loss_func = nn.MSELoss()`), the `task_keys.setter` method copies the function to be used for all targets. Some refactoring was also needed in
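A dictionary-style `loss_func` specification might look like the following. This is a hedged sketch: the target keys (`"energy"`, `"force"`) are illustrative assumptions, not necessarily the exact keys used in matsciml.

```python
# Sketch of the dictionary-mapping `loss_func` specification described
# above. The keys "energy" and "force" are hypothetical task_key names.
from torch import nn

loss_func = {
    "energy": nn.MSELoss(),  # scalar target
    "force": nn.L1Loss(),    # vector target
}

# The previous single-module usage still works; per the description,
# the `task_keys.setter` method copies it to all targets.
legacy_loss_func = nn.MSELoss()
```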
`_compute_losses` to allow additional arguments to be passed into the loss function, e.g. in this case the number of atoms per graph. New loss functions in `losses` should keep their function signatures consistent with native PyTorch ones (e.g. `input` and `target`) for consistent mapping.
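To illustrate the signature convention, here is a hedged sketch of an atom-weighted MSE module. The class name `AtomWeightedMSE`, the `atoms_per_graph` argument, and the choice to normalize the squared error by the atom count are assumptions for illustration, not a copy of the matsciml implementation.

```python
# Hypothetical atom-weighted MSE loss following the PyTorch
# (input, target) signature convention, with an extra
# `atoms_per_graph` argument as described in the PR.
import torch
from torch import nn


class AtomWeightedMSE(nn.Module):
    """MSE where each graph's error is normalized by its atom count.

    Note: the normalization scheme here (dividing the squared error
    by the atom count) is one plausible choice, not necessarily the
    one used in matsciml.models.losses.
    """

    def forward(
        self,
        input: torch.Tensor,
        target: torch.Tensor,
        atoms_per_graph: torch.Tensor,
    ) -> torch.Tensor:
        squared_error = (input - target) ** 2
        weights = atoms_per_graph.to(input.dtype)
        # Broadcast per-graph atom counts against per-target errors,
        # e.g. when forces have shape (num_graphs, 3).
        while weights.dim() < squared_error.dim():
            weights = weights.unsqueeze(-1)
        return (squared_error / weights).mean()
```

Keeping `input` and `target` as the first two arguments means the same `_compute_losses` call site can dispatch to either a native PyTorch loss or an atom-weighted one.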