Skip to content

Commit

Permalink
updated docs.
Browse files Browse the repository at this point in the history
  • Loading branch information
PatReis committed Feb 27, 2024
1 parent bce2ffd commit 1b35d63
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 12 deletions.
24 changes: 14 additions & 10 deletions kgcnn/losses/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from keras.losses import mean_absolute_error
import keras.saving
from kgcnn.ops.core import decompose_ragged_tensor
# from kgcnn.ops.scatter import scatter_reduce_mean


@ks.saving.register_keras_serializable(package='kgcnn', name='MeanAbsoluteError')
Expand Down Expand Up @@ -61,32 +62,35 @@ def get_config(self):

@ks.saving.register_keras_serializable(package='kgcnn', name='DisjointForceMeanAbsoluteError')
class DisjointForceMeanAbsoluteError(Loss):
"""This is dummy class. Not working at the moment as intended.
"""This is an experimental class for force loss of disjoint output."""

We need to pass the node ids here somehow.
"""
# Ideally we want to pass batch IDs and energy shape to the loss to do scatter mean.
# However, passing as attribute similar to mask, has not been working yet.

def __init__(self, reduction="sum_over_batch_size", name="force_mean_absolute_error",
squeeze_states: bool = True, find_padded_atoms: bool = True, dtype=None):
squeeze_states: bool = True, padded_disjoint: bool = False, dtype=None):
super(DisjointForceMeanAbsoluteError, self).__init__(reduction=reduction, name=name, dtype=dtype)
self.squeeze_states = squeeze_states
self.find_padded_atoms = find_padded_atoms
self.padded_disjoint = padded_disjoint

def call(self, y_true, y_pred):
# Shape: ([N], 3, S)
if self.find_padded_atoms:
check_nonzero = ops.logical_not(
ops.all(ops.isclose(y_true, ops.convert_to_tensor(0., dtype=y_true.dtype)), axis=1))
y_pred = y_pred * ops.cast(ops.expand_dims(check_nonzero, axis=1), dtype=y_pred.dtype)
row_count = ops.sum(ops.cast(check_nonzero, dtype="int32"), axis=0)

if self.padded_disjoint:
# Mask Shape: ([N, S])
mask = ops.logical_not(ops.all(ops.isclose(y_true, ops.convert_to_tensor(0., dtype=y_true.dtype)), axis=1))
y_pred = y_pred * ops.cast(ops.expand_dims(mask, axis=1), dtype=y_pred.dtype)
row_count = ops.sum(ops.cast(mask, dtype="int32"), axis=0)
row_count = ops.where(row_count < 1, 1, row_count) # Prevent divide by 0.
# roq count shape ([S])
norm = 1 / ops.cast(row_count, dtype=y_true.dtype)
else:
norm = 1/ops.shape(y_true)[0]

diff = ops.abs(y_true-y_pred)
out = ops.mean(diff, axis=1)
out = ops.sum(out, axis=0)*norm

if not self.squeeze_states:
out = ops.mean(out, axis=-1)

Expand Down
3 changes: 1 addition & 2 deletions kgcnn/models/force.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import keras as ks
import keras.saving
from keras import ops
from typing import Union
from kgcnn.models.utils import get_model_class
from keras.saving import deserialize_keras_object, serialize_keras_object
Expand Down Expand Up @@ -87,12 +88,10 @@ def __init__(self,
# Additional parameters of io and behavior of this class.
self.ragged_validate = ragged_validate
self.coordinate_input = coordinate_input
# self.output_to_tensor = output_to_tensor
self.output_squeeze_states = output_squeeze_states
self.is_physical_force = is_physical_force
self.nested_model_config = nested_model_config
self._force_outputs = outputs
# self.use_batch_jacobian = use_batch_jacobian

self.output_as_dict = output_as_dict
if isinstance(output_as_dict, bool):
Expand Down

0 comments on commit 1b35d63

Please sign in to comment.