Skip to content

Commit

Permalink
added safe_scatter_max_min_to_zero parameters and simple ragged sup…
Browse files Browse the repository at this point in the history
…port for forces and metrics.
  • Loading branch information
PatReis committed Feb 19, 2024
1 parent 019a9a2 commit d7b9f3f
Show file tree
Hide file tree
Showing 8 changed files with 110 additions and 17 deletions.
10 changes: 8 additions & 2 deletions changelog.md
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
v4.0.1

* Removed unused layers and added manual built in scripts and training functions, since with keras==3.0.5 the pytorch trainer tries to rebuild the model,
even if the model is already built and does it eagerly without proper tensor input, which causes crashes for almost every model in kgcnn.
* Fix Error in ``ExtensiveMolecularLabelScaler.transform`` missing default value.
* Added further benchmark results for kgcnn version 4.
* Fix error in ``kgcnn.layers.geom.PositionEncodingBasisLayer``
* Fix error in ``kgcnn.literature.GCN.make_model_weighted``
* Fix error in ``kgcnn.literature.AttentiveFP.make_model``
* Had to change serialization for activation functions since with keras>=3.0.2 custom strings are not allowed also
causing clashes with built-in functions. We catch defaults to be at least backward compatible as possible and changed to serialization dictionary. Adapted all hyperparameter.
* Renamed leaky_relu and swish in ``kgcnn.ops.activ`` to leaky_relu2 and swish2.
causing clashes with built-in functions. We catch defaults to be at least as backward compatible as possible and changed to serialization dictionary. Adapted all hyperparameter.
* Renamed leaky_relu and swish in ``kgcnn.ops.activ`` to leaky_relu2 and swish2.
* Fix error in jax scatter min/max functions.
* Added ``kgcnn.__safe_scatter_max_min_to_zero__`` for tensorflow and jax backend scattering with default to True.
* Added simple ragged support for loss and metrics.
* Added simple ragged support for ``train_force.py``

v4.0.0

Expand Down
3 changes: 3 additions & 0 deletions kgcnn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,6 @@
__indices_axis__ = 0
__index_receive__ = 0
__index_send__ = 1

# Behaviour for backend functions.
__safe_scatter_max_min_to_zero__ = True
41 changes: 37 additions & 4 deletions kgcnn/backend/_jax.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,25 @@
import numpy as np
import jax.numpy as jnp
from kgcnn import __safe_scatter_max_min_to_zero__ as global_safe_scatter_max_min_to_zero


class binfo:
kind: str = "b"
bits: int = 8 # May be different for jax.
min: bool = False
max: bool = True
dtype: np.dtype = "bool"


def dtype_infos(dtype):
if dtype.kind in ["f", "c"]:
return jnp.finfo(dtype)
elif dtype.kind in ["i", "u"]:
return jnp.iinfo(dtype)
elif dtype.kind in ["b"]:
return binfo()
else:
raise TypeError("Unknown dtype '%s' to get type info." % dtype)


def scatter_reduce_sum(indices, values, shape):
Expand All @@ -7,13 +28,25 @@ def scatter_reduce_sum(indices, values, shape):


def scatter_reduce_min(indices, values, shape):
zeros = jnp.zeros(shape, values.dtype.max, values.dtype)
return zeros.at[indices].min(values)
max_of_dtype = dtype_infos(values.dtype).max
zeros = jnp.full(shape, max_of_dtype, values.dtype)
out = zeros.at[indices].min(values)
if global_safe_scatter_max_min_to_zero:
has_scattered = jnp.zeros(shape, "bool")
has_scattered = has_scattered.at[indices].set(jnp.ones_like(values, dtype="bool"))
out = jnp.where(has_scattered, out, jnp.zeros_like(out))
return out


def scatter_reduce_max(indices, values, shape):
zeros = jnp.full(shape, values.dtype.min, values.dtype)
return zeros.at[indices].max(values)
min_of_dtype = dtype_infos(values.dtype).min
zeros = jnp.full(shape, min_of_dtype, values.dtype)
out = zeros.at[indices].max(values)
if global_safe_scatter_max_min_to_zero:
has_scattered = jnp.zeros(shape, "bool")
has_scattered = has_scattered.at[indices].set(jnp.ones_like(values, dtype="bool"))
out = jnp.where(has_scattered, out, jnp.zeros_like(out))
return out


def scatter_reduce_mean(indices, values, shape):
Expand Down
13 changes: 11 additions & 2 deletions kgcnn/backend/_tensorflow.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import tensorflow as tf
from kgcnn import __safe_scatter_max_min_to_zero__ as global_safe_scatter_max_min_to_zero


def scatter_reduce_sum(indices, values, shape):
Expand All @@ -9,13 +10,21 @@ def scatter_reduce_sum(indices, values, shape):
def scatter_reduce_min(indices, values, shape):
indices = tf.expand_dims(indices, axis=1)
target = tf.cast(tf.fill(shape, values.dtype.max), dtype=values.dtype)
return tf.tensor_scatter_nd_min(target, indices, values)
out = tf.tensor_scatter_nd_min(target, indices, values)
if global_safe_scatter_max_min_to_zero:
has_scattered = tf.scatter_nd(indices, tf.ones_like(values, dtype="bool"), tf.cast(shape, dtype="int64"))
out = tf.where(has_scattered, out, tf.zeros_like(out))
return out


def scatter_reduce_max(indices, values, shape):
indices = tf.expand_dims(indices, axis=1)
target = tf.cast(tf.fill(shape, values.dtype.min), dtype=values.dtype)
return tf.tensor_scatter_nd_max(target, indices, values)
out = tf.tensor_scatter_nd_max(target, indices, values)
if global_safe_scatter_max_min_to_zero:
has_scattered = tf.scatter_nd(indices, tf.ones_like(values, dtype="bool"), tf.cast(shape, dtype="int64"))
out = tf.where(has_scattered, out, tf.zeros_like(out))
return out


def scatter_reduce_mean(indices, values, shape):
Expand Down
17 changes: 17 additions & 0 deletions kgcnn/losses/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from keras.losses import Loss
from keras.losses import mean_absolute_error
import keras.saving
from kgcnn.ops.core import decompose_ragged_tensor


@ks.saving.register_keras_serializable(package='kgcnn', name='MeanAbsoluteError')
Expand Down Expand Up @@ -62,3 +63,19 @@ def call(self, y_true, y_pred):
y_pred = ops.where(is_nan, ops.zeros_like(y_pred), y_pred)
y_true = ops.where(is_nan, ops.zeros_like(y_true), y_true)
return super(BinaryCrossentropyNoNaN, self).call(y_true, y_pred)


@ks.saving.register_keras_serializable(package='kgcnn', name='RaggedValuesMeanAbsoluteError')
class RaggedValuesMeanAbsoluteError(Loss):

def __init__(self, reduction="sum_over_batch_size", name="mean_absolute_error", dtype=None):
super(RaggedValuesMeanAbsoluteError, self).__init__(reduction=reduction, name=name, dtype=dtype)

def call(self, y_true, y_pred):
y_true_values = decompose_ragged_tensor(y_true)[0]
y_pred_values = decompose_ragged_tensor(y_pred)[0]
return mean_absolute_error(y_true_values, y_pred_values)

def get_config(self):
config = super(RaggedValuesMeanAbsoluteError, self).get_config()
return config
20 changes: 16 additions & 4 deletions kgcnn/metrics/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,19 @@
import numpy as np
from keras import ops
import keras.saving
from kgcnn.ops.core import decompose_ragged_tensor


@ks.saving.register_keras_serializable(package='kgcnn', name='ScaledMeanAbsoluteError')
class ScaledMeanAbsoluteError(ks.metrics.MeanAbsoluteError):
"""Metric for a scaled mean absolute error (MAE), which can undo a pre-scaling of the targets. Only intended as
metric this allows to info the MAE with correct units or absolute values during fit."""

def __init__(self, scaling_shape=(), name='mean_absolute_error', dtype_scale: str = None, **kwargs):
def __init__(self, scaling_shape=(), name='mean_absolute_error', dtype_scale: str = None, ragged: bool = False,
**kwargs):
super(ScaledMeanAbsoluteError, self).__init__(name=name, **kwargs)
self.scaling_shape = scaling_shape
self._is_ragged = ragged
self.dtype_scale = dtype_scale
self.scale = self.add_variable(
shape=scaling_shape,
Expand All @@ -27,14 +30,18 @@ def reset_state(self):
v.assign(ops.zeros(v.shape, dtype=v.dtype))

def update_state(self, y_true, y_pred, sample_weight=None):
if self._is_ragged:
y_true = decompose_ragged_tensor(y_true)[0]
y_pred = decompose_ragged_tensor(y_pred)[0]
y_true = self.scale * ops.cast(y_true, dtype=self.scale.dtype)
y_pred = self.scale * ops.cast(y_pred, dtype=self.scale.dtype)
return super(ScaledMeanAbsoluteError, self).update_state(y_true, y_pred, sample_weight=sample_weight)

def get_config(self):
"""Returns the serializable config of the metric."""
conf = super(ScaledMeanAbsoluteError, self).get_config()
conf.update({"scaling_shape": self.scaling_shape, "dtype_scale": self.dtype_scale})
conf.update({"scaling_shape": self.scaling_shape, "dtype_scale": self.dtype_scale,
"ragged": self._is_ragged})
return conf

def set_scale(self, scale):
Expand All @@ -47,10 +54,12 @@ class ScaledRootMeanSquaredError(ks.metrics.RootMeanSquaredError):
"""Metric for a scaled root mean squared error (RMSE), which can undo a pre-scaling of the targets.
Only intended as metric this allows to info the MAE with correct units or absolute values during fit."""

def __init__(self, scaling_shape=(), name='root_mean_squared_error', dtype_scale: str = None, **kwargs):
def __init__(self, scaling_shape=(), name='root_mean_squared_error', dtype_scale: str = None, ragged: bool = False,
**kwargs):
super(ScaledRootMeanSquaredError, self).__init__(name=name, **kwargs)
self.scaling_shape = scaling_shape
self.dtype_scale = dtype_scale
self._is_ragged = ragged
self.scale = self.add_variable(
shape=scaling_shape,
initializer=ks.initializers.Ones(),
Expand All @@ -64,14 +73,17 @@ def reset_state(self):
v.assign(ops.zeros(v.shape, dtype=v.dtype))

def update_state(self, y_true, y_pred, sample_weight=None):
if self._is_ragged:
y_true = decompose_ragged_tensor(y_true)[0]
y_pred = decompose_ragged_tensor(y_pred)[0]
y_true = self.scale * ops.cast(y_true, dtype=self.scale.dtype)
y_pred = self.scale * ops.cast(y_pred, dtype=self.scale.dtype)
return super(ScaledRootMeanSquaredError, self).update_state(y_true, y_pred, sample_weight=sample_weight)

def get_config(self):
"""Returns the serializable config of the metric."""
conf = super(ScaledRootMeanSquaredError, self).get_config()
conf.update({"scaling_shape": self.scaling_shape, "dtype_scale": self.dtype_scale})
conf.update({"scaling_shape": self.scaling_shape, "dtype_scale": self.dtype_scale, "ragged": self._is_ragged})
return conf

def set_scale(self, scale):
Expand Down
10 changes: 8 additions & 2 deletions training/hyper/hyper_md17.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,8 @@
"kgcnn_version": "4.0.0"
}
},
# Ragged!
# The Metrics deviate from padded sum.
"EGNN.EnergyForceModel": {
"model": {
"class_name": "EnergyForceModel",
Expand Down Expand Up @@ -243,7 +245,7 @@
}
},
"outputs": {"energy": {"name": "energy", "shape": (1,)},
"force": {"name": "force", "shape": (None, 3)}}
"force": {"name": "force", "shape": (None, 3), "ragged": True}}
}
},
"training": {
Expand All @@ -256,7 +258,11 @@
},
"compile": {
"optimizer": {"class_name": "Adam", "config": {"learning_rate": 1e-03}},
"loss_weights": {"energy": 0.02, "force": 0.98}
"loss_weights": {"energy": 0.02, "force": 0.98},
"loss": {
"energy": "mean_absolute_error",
"force": {"class_name": "kgcnn>RaggedValuesMeanAbsoluteError", "config": {}}
}
},
"scaler": {"class_name": "EnergyForceExtensiveLabelScaler",
"config": {"standardize_scale": False}},
Expand Down
13 changes: 10 additions & 3 deletions training/train_force.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
parser = argparse.ArgumentParser(description='Train a GNN on an Energy-Force Dataset.')
parser.add_argument("--hyper", required=False, help="Filepath to hyper-parameter config file (.py or .json).",
default="hyper/hyper_md17.py")
parser.add_argument("--category", required=False, help="Graph model to train.", default="Schnet.EnergyForceModel")
parser.add_argument("--category", required=False, help="Graph model to train.", default="EGNN.EnergyForceModel")
parser.add_argument("--model", required=False, help="Graph model to train.", default=None)
parser.add_argument("--dataset", required=False, help="Name of the dataset.", default=None)
parser.add_argument("--make", required=False, help="Name of the class for model.", default=None)
Expand Down Expand Up @@ -133,8 +133,14 @@
# If scaler was used we add rescaled standard metrics to compile, since otherwise the keras history will not
# directly log the original target values, but the scaled ones.
scaler_scale = scaler.get_scaling()
force_output_parameter = hyper["model"]["config"]["outputs"]["force"]
is_ragged = force_output_parameter["ragged"] if "ragged" in force_output_parameter else False
mae_metric_energy = ScaledMeanAbsoluteError(scaler_scale.shape, name="scaled_mean_absolute_error")
mae_metric_force = ScaledForceMeanAbsoluteError(scaler_scale.shape, name="scaled_mean_absolute_error")
if is_ragged:
mae_metric_force = ScaledMeanAbsoluteError(
scaler_scale.shape, name="scaled_mean_absolute_error", ragged=True)
else:
mae_metric_force = ScaledForceMeanAbsoluteError(scaler_scale.shape, name="scaled_mean_absolute_error")
if scaler_scale is not None:
mae_metric_energy.set_scale(scaler_scale)
mae_metric_force.set_scale(scaler_scale)
Expand All @@ -158,7 +164,8 @@
"energy": "mean_absolute_error",
"force": ForceMeanAbsoluteError()
},
metrics=scaled_metrics))
metrics=scaled_metrics
))

# Build model with reasonable data.
model.predict(x_test)
Expand Down

0 comments on commit d7b9f3f

Please sign in to comment.