Fix jax mean for 0 counts. Added BatchNorm for padded disjoint. Upgraded keras requirements.
PatReis committed Feb 26, 2024
1 parent 4f2f7f4 commit 4fa439b
Showing 3 changed files with 13 additions and 13 deletions.
2 changes: 1 addition & 1 deletion kgcnn/backend/_jax.py
@@ -53,7 +53,7 @@ def scatter_reduce_mean(indices, values, shape):
     zeros = jnp.zeros(shape, values.dtype)
     counts = jnp.zeros(shape, values.dtype)
     counts = counts.at[indices].add(jnp.ones_like(values))
-    inverse_counts = jnp.nan_to_num(jnp.reciprocal(counts))
+    inverse_counts = jnp.nan_to_num(jnp.reciprocal(counts), posinf=0.0, neginf=0.0, nan=0.0)
     return zeros.at[indices].add(values)*inverse_counts
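Why the one-line fix works: `jnp.reciprocal(0.0)` evaluates to `inf`, not `nan`, and with default arguments `jnp.nan_to_num` maps `inf` to the largest representable float rather than to zero, so segments that received no values were not cleanly zeroed out. Explicitly mapping `posinf`, `neginf`, and `nan` to `0.0` makes the mean of an empty segment exactly 0. A minimal self-contained sketch of the fixed behaviour (for illustration; not code from the repository's test suite):

import jax.numpy as jnp

def scatter_reduce_mean(indices, values, shape):
    zeros = jnp.zeros(shape, values.dtype)
    counts = jnp.zeros(shape, values.dtype)
    counts = counts.at[indices].add(jnp.ones_like(values))
    # reciprocal of a zero count is inf, so posinf/neginf/nan must all map to 0
    inverse_counts = jnp.nan_to_num(jnp.reciprocal(counts), posinf=0.0, neginf=0.0, nan=0.0)
    return zeros.at[indices].add(values) * inverse_counts

values = jnp.array([1.0, 3.0, 5.0])
indices = jnp.array([0, 0, 2])  # segment 1 receives no values
print(scatter_reduce_mean(indices, values, (3,)))  # [2. 0. 5.] -- empty segment stays 0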
22 changes: 11 additions & 11 deletions kgcnn/layers/norm.py
@@ -3,8 +3,8 @@
 from keras import ops
 from keras import InputSpec
 from kgcnn.ops.scatter import scatter_reduce_sum
-from keras.layers import LayerNormalization as _LayerNormalization
-from keras.layers import BatchNormalization as _BatchNormalization
+from keras.layers import LayerNormalization
+from keras.layers import BatchNormalization

 global_normalization_args = {
     "GraphNormalization": (
@@ -27,11 +27,8 @@
     ),
 }

-# GraphLayerNormalization = _LayerNormalization
-# GraphBatchNormalization = _BatchNormalization
-
-
-class GraphLayerNormalization(_LayerNormalization):
+class GraphLayerNormalization(LayerNormalization):

     def __init__(self, **kwargs):
         super(GraphLayerNormalization, self).__init__(**kwargs)
@@ -42,19 +39,18 @@ def compute_output_shape(self, input_shape):
     def build(self, input_shape):
         super(GraphLayerNormalization, self).build(input_shape[0])

-    def call(self, inputs):
+    def call(self, inputs, **kwargs):
         return super(GraphLayerNormalization, self).call(inputs[0])

     def get_config(self):
         return super(GraphLayerNormalization, self).get_config()


-class GraphBatchNormalization(_BatchNormalization):
+class GraphBatchNormalization(BatchNormalization):

     def __init__(self, padded_disjoint: bool = False, **kwargs):
         super(GraphBatchNormalization, self).__init__(**kwargs)
         self.padded_disjoint = padded_disjoint
-        assert not self.padded_disjoint, "Not implemented error"

     def compute_output_shape(self, input_shape):
         return super(GraphBatchNormalization, self).compute_output_shape(input_shape[0])
@@ -67,8 +63,12 @@ def build(self, input_shape):
             InputSpec(ndim=len(input_shape[2])),
         ]

-    def call(self, inputs, **kwargs):
-        return super(GraphBatchNormalization, self).call(inputs[0])
+    def call(self, inputs, training=None, **kwargs):
+        if not self.padded_disjoint:
+            return super(GraphBatchNormalization, self).call(inputs[0], training=training)
+        else:
+            padded_mask = inputs[1] > 0
+            return super(GraphBatchNormalization, self).call(inputs[0], training=training, mask=padded_mask)

     def get_config(self):
         config = super(GraphBatchNormalization, self).get_config()
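The new `call` above switches masking on for padded disjoint batches: judging from this diff, `inputs[1]` is the per-node assignment tensor in which index 0 is reserved for padding slots, so `inputs[1] > 0` marks the real nodes, and passing it as `mask=` lets Keras exclude padded rows from the batch-norm statistics. Keras 3's `BatchNormalization.call` accepts such a `mask` argument (presumably the reason for the `keras>=3.0.5` bump below). A hedged standalone sketch of the mechanism, with illustrative data and names:

import numpy as np
from keras import layers

nodes = np.random.rand(8, 16).astype("float32")  # 8 node slots, 16 features
graph_id = np.array([0, 1, 1, 2, 2, 2, 0, 0])    # 0 marks padding slots
node_mask = graph_id > 0                         # True only for real nodes

bn = layers.BatchNormalization()
# Mean and variance are computed over the unmasked (real) rows only.
out = bn(nodes, training=True, mask=node_mask)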
2 changes: 1 addition & 1 deletion setup.py
@@ -15,7 +15,7 @@
     url="https://github.com/aimat-lab/gcnn_keras",
     install_requires=[
         # "dm-tree",
-        "keras>=3.0.0",
+        "keras>=3.0.5",
         # Backends
         # "tf-nightly-cpu==2.16.0.dev20240101",
         # "torch>=2.1.0",
