First fix for change in activation serialization in keras 3.0.2.
PatReis committed Feb 17, 2024
1 parent fae3029 commit d69a65f
Showing 28 changed files with 483 additions and 126 deletions.
6 changes: 5 additions & 1 deletion changelog.md
@@ -2,7 +2,11 @@ v4.0.1

* Fix Error in ``ExtensiveMolecularLabelScaler.transform`` missing default value.
* Added further benchmark results for kgcnn version 4.

+* Fix error in ``kgcnn.layers.geom.PositionEncodingBasisLayer``.
+* Fix error in ``kgcnn.literature.GCN.make_model_weighted``.
+* Had to change serialization for activation functions, since with keras>=3.0.2 custom strings are not allowed and
+  also clash with built-in functions. We catch defaults to stay as backward compatible as possible and changed to a serialization dictionary. Adapted all hyperparameters.
+* Renamed ``leaky_relu`` and ``swish`` in ``kgcnn.ops.activ`` to ``leaky_relu2`` and ``swish2``.

v4.0.0

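The serialization change described in the changelog entry above, illustrated as a minimal sketch (the activation keys are taken verbatim from the catch blocks in this commit; the surrounding config keys are illustrative):

```python
# Before (keras < 3.0.2): custom activations could be passed as plain strings
# in layer kwargs and hyperparameter dictionaries.
layer_conf_old = {"units": 64, "activation": "kgcnn>leaky_relu"}

# After (keras >= 3.0.2): custom strings are no longer allowed, so kgcnn wraps
# the (renamed) function in a serialization dictionary instead.
layer_conf_new = {
    "units": 64,
    "activation": {"class_name": "function", "config": "kgcnn>leaky_relu2"},
}
```
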
31 changes: 24 additions & 7 deletions kgcnn/layers/attention.py
@@ -3,6 +3,7 @@
from keras.layers import Dense, Concatenate, Activation, Average, Layer
from kgcnn.layers.aggr import AggregateLocalEdgesAttention
from keras import ops
+import kgcnn.ops.activ


class AttentionHeadGAT(Layer): # noqa
@@ -23,7 +24,7 @@ def __init__(self,
use_edge_features=False,
use_final_activation=True,
has_self_loops=True,
activation="kgcnn>leaky_relu",
activation="kgcnn>leaky_relu2",
use_bias=True,
kernel_regularizer=None,
bias_regularizer=None,
@@ -41,7 +42,7 @@ def __init__(self,
use_edge_features (bool): Append edge features to attention computation. Default is False.
use_final_activation (bool): Whether to apply the final activation for the output.
has_self_loops (bool): If the graph has self-loops. Not used here. Default is True.
-activation (str): Activation. Default is "kgcnn>leaky_relu",
+activation (str): Activation. Default is "kgcnn>leaky_relu2".
use_bias (bool): Use bias. Default is True.
kernel_regularizer: Kernel regularization. Default is None.
bias_regularizer: Bias regularization. Default is None.
@@ -52,6 +53,10 @@
bias_initializer: Initializer for bias. Default is 'zeros'.
"""
super(AttentionHeadGAT, self).__init__(**kwargs)
+# Changes in keras serialization behaviour for activations in 3.0.2.
+# Keep string at least for default. Also renames to prevent clashes with keras leaky_relu.
+if activation in ["kgcnn>leaky_relu", "kgcnn>leaky_relu2"]:
+    activation = {"class_name": "function", "config": "kgcnn>leaky_relu2"}
self.use_edge_features = use_edge_features
self.use_final_activation = use_final_activation
self.has_self_loops = has_self_loops
@@ -138,7 +143,7 @@ def __init__(self,
use_edge_features=False,
use_final_activation=True,
has_self_loops=True,
activation="kgcnn>leaky_relu",
activation="kgcnn>leaky_relu2",
use_bias=True,
kernel_regularizer=None,
bias_regularizer=None,
Expand All @@ -156,7 +161,7 @@ def __init__(self,
use_edge_features (bool): Append edge features to attention computation. Default is False.
use_final_activation (bool): Whether to apply the final activation for the output.
has_self_loops (bool): If the graph has self-loops. Not used here. Default is True.
-activation (str): Activation. Default is "kgcnn>leaky_relu",
+activation (str): Activation. Default is "kgcnn>leaky_relu2".
use_bias (bool): Use bias. Default is True.
kernel_regularizer: Kernel regularization. Default is None.
bias_regularizer: Bias regularization. Default is None.
@@ -167,6 +172,10 @@
bias_initializer: Initializer for bias. Default is 'zeros'.
"""
super(AttentionHeadGATV2, self).__init__(**kwargs)
+# Changes in keras serialization behaviour for activations in 3.0.2.
+# Keep string at least for default. Also renames to prevent clashes with keras leaky_relu.
+if activation in ["kgcnn>leaky_relu", "kgcnn>leaky_relu2"]:
+    activation = {"class_name": "function", "config": "kgcnn>leaky_relu2"}
self.use_edge_features = use_edge_features
self.use_final_activation = use_final_activation
self.has_self_loops = has_self_loops
@@ -243,7 +252,7 @@ class MultiHeadGATV2Layer(AttentionHeadGATV2):  # noqa
def __init__(self,
units: int,
num_heads: int,
-activation: str = 'kgcnn>leaky_relu',
+activation: str = "kgcnn>leaky_relu2",
use_bias: bool = True,
concat_heads: bool = True,
**kwargs):
@@ -253,6 +262,10 @@ def __init__(self,
use_bias=use_bias,
**kwargs
)
+# Changes in keras serialization behaviour for activations in 3.0.2.
+# Keep string at least for default. Also renames to prevent clashes with keras leaky_relu.
+if activation in ["kgcnn>leaky_relu", "kgcnn>leaky_relu2"]:
+    activation = {"class_name": "function", "config": "kgcnn>leaky_relu2"}
self.num_heads = num_heads
self.concat_heads = concat_heads

@@ -345,7 +358,7 @@ class AttentiveHeadFP(Layer):
def __init__(self,
units,
use_edge_features=False,
-activation='kgcnn>leaky_relu',
+activation="kgcnn>leaky_relu2",
activation_context="elu",
use_bias=True,
kernel_regularizer=None,
@@ -361,7 +374,7 @@ def __init__(self,
Args:
units (int): Units for the linear trafo of node features before attention.
use_edge_features (bool): Append edge features to attention computation. Default is False.
activation (str): Activation. Default is {"class_name": "kgcnn>leaky_relu", "config": {"alpha": 0.2}}.
activation (str): Activation. Default is "kgcnn>leaky_relu2".
activation_context (str): Activation function for context. Default is "elu".
use_bias (bool): Use bias. Default is True.
kernel_regularizer: Kernel regularization. Default is None.
@@ -373,6 +386,10 @@
bias_initializer: Initializer for bias. Default is 'zeros'.
"""
super(AttentiveHeadFP, self).__init__(**kwargs)
+# Changes in keras serialization behaviour for activations in 3.0.2.
+# Keep string at least for default. Also renames to prevent clashes with keras leaky_relu.
+if activation in ["kgcnn>leaky_relu", "kgcnn>leaky_relu2"]:
+    activation = {"class_name": "function", "config": "kgcnn>leaky_relu2"}
self.use_edge_features = use_edge_features
self.units = int(units)
self.use_bias = use_bias
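To make the backward-compatibility catch in the attention layers above concrete, a hedged usage sketch (assuming kgcnn at this commit; the full ``get_config`` output beyond the activation entry is not shown in this diff):

```python
from kgcnn.layers.attention import AttentionHeadGAT

# Passing the old custom string still works: the constructor catches
# "kgcnn>leaky_relu" and rewrites it to the new serialization dictionary
# with the renamed function.
layer = AttentionHeadGAT(units=32, activation="kgcnn>leaky_relu")
config = layer.get_config()
# The activation entry is then expected in dictionary form, e.g.
# {"class_name": "function", "config": "kgcnn>leaky_relu2"},
# rather than as a bare custom string.
```
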
27 changes: 20 additions & 7 deletions kgcnn/layers/conv.py
@@ -2,7 +2,7 @@
from kgcnn.layers.aggr import AggregateWeightedLocalEdges, AggregateLocalEdges
from kgcnn.layers.gather import GatherNodesOutgoing
from keras import ops
-from kgcnn.ops.activ import shifted_softplus
+import kgcnn.ops.activ


class GCN(Layer):
@@ -28,7 +28,7 @@ def __init__(self,
units,
pooling_method='scatter_sum',
normalize_by_weights=False,
-activation='kgcnn>leaky_relu',
+activation="kgcnn>leaky_relu2",
use_bias=True,
kernel_regularizer=None,
bias_regularizer=None,
Expand All @@ -45,7 +45,7 @@ def __init__(self,
pooling_method (str): Pooling method for summing edges. Default is 'segment_sum'.
normalize_by_weights (bool): Normalize the pooled output by the sum of weights. Default is False.
In this case the edge features are considered weights of dimension (...,1) and are summed for each node.
-activation (str): Activation. Default is 'kgcnn>leaky_relu'.
+activation (str): Activation. Default is "kgcnn>leaky_relu2".
use_bias (bool): Use bias. Default is True.
kernel_regularizer: Kernel regularization. Default is None.
bias_regularizer: Bias regularization. Default is None.
@@ -56,6 +56,10 @@
bias_initializer: Initializer for bias. Default is 'zeros'.
"""
super(GCN, self).__init__(**kwargs)
+# Changes in keras serialization behaviour for activations in 3.0.2.
+# Keep string at least for default. Also renames to prevent clashes with keras leaky_relu.
+if activation in ["kgcnn>leaky_relu", "kgcnn>leaky_relu2"]:
+    activation = {"class_name": "function", "config": "kgcnn>leaky_relu2"}
self.normalize_by_weights = normalize_by_weights
self.pooling_method = pooling_method
self.units = units
@@ -133,7 +137,7 @@ def __init__(self, units,
units (int): Units for Dense layer.
cfconv_pool (str): Pooling method. Default is 'segment_sum'.
use_bias (bool): Use bias. Default is True.
-activation (str): Activation function. Default is 'kgcnn>shifted_softplus'.
+activation (str): Activation function. Default is "kgcnn>shifted_softplus".
kernel_regularizer: Kernel regularization. Default is None.
bias_regularizer: Bias regularization. Default is None.
activity_regularizer: Activity regularization. Default is None.
@@ -143,6 +147,10 @@
bias_initializer: Initializer for bias. Default is 'zeros'.
"""
super(SchNetCFconv, self).__init__(**kwargs)
+# Changes in keras serialization behaviour for activations in 3.0.2.
+# Keep string at least for default.
+if activation in ["kgcnn>shifted_softplus"]:
+    activation = {"class_name": "function", "config": "kgcnn>shifted_softplus"}
self.cfconv_pool = cfconv_pool
self.units = units
self.use_bias = use_bias
@@ -173,6 +181,7 @@ def call(self, inputs, **kwargs):
Returns:
Tensor: Updated node features.
"""
+# print(inputs)
node, edge, disjoint_indices = inputs
x = self.lay_dense1(edge, **kwargs)
x = self.lay_dense2(x, **kwargs)
@@ -202,7 +211,7 @@ def __init__(self,
units=128,
cfconv_pool='scatter_sum',
use_bias=True,
-activation='kgcnn>shifted_softplus',
+activation="kgcnn>shifted_softplus",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
@@ -215,9 +224,9 @@
Args:
units (int): Dimension of node embedding. Default is 128.
-cfconv_pool (str): Pooling method information for SchNetCFconv layer. Default is'segment_sum'.
+cfconv_pool (str): Pooling method information for SchNetCFconv layer. Default is 'scatter_sum'.
use_bias (bool): Use bias in last layers. Default is True.
-activation (str): Activation function. Default is 'kgcnn>shifted_softplus'.
+activation (str): Activation function. Default is "kgcnn>shifted_softplus".
kernel_regularizer: Kernel regularization. Default is None.
bias_regularizer: Bias regularization. Default is None.
activity_regularizer: Activity regularization. Default is None.
@@ -227,6 +236,10 @@
bias_initializer: Initializer for bias. Default is 'zeros'.
"""
super(SchNetInteraction, self).__init__(**kwargs)
+# Changes in keras serialization behaviour for activations in 3.0.2.
+# Keep string at least for default.
+if activation in ["kgcnn>shifted_softplus"]:
+    activation = {"class_name": "function", "config": "kgcnn>shifted_softplus"}
self.cfconv_pool = cfconv_pool
self.use_bias = use_bias
self.units = units
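For background on the ``kgcnn>shifted_softplus`` default used by the SchNet layers above, a minimal sketch of the usual SchNet definition (the exact ``kgcnn.ops.activ`` implementation is not part of this diff):

```python
import math
from keras import ops

def shifted_softplus(x):
    # Shifted softplus as used in SchNet: softplus(x) - ln(2),
    # which makes the activation pass through zero at x = 0.
    return ops.softplus(x) - math.log(2.0)
```
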
10 changes: 9 additions & 1 deletion kgcnn/layers/gather.py
@@ -1,3 +1,4 @@
+import keras as ks
from typing import Union
from keras.layers import Layer, Concatenate
from keras import ops
@@ -50,7 +51,8 @@ def __init__(self, split_indices=(0, 1),
self.split_indices = split_indices
self.concat_axis = concat_axis
self.axis_indices = axis_indices
-self._concat = Concatenate(axis=concat_axis)
+if self.concat_axis is not None:
+    self._concat = Concatenate(axis=concat_axis)

def _compute_gathered_shape(self, input_shape):
assert len(input_shape) == 2
@@ -76,6 +78,12 @@ def compute_output_shape(self, input_shape):
xs = self._concat.compute_output_shape(xs)
return xs

+def compute_output_spec(self, inputs_spec):
+    output_shape = self.compute_output_shape([x.shape for x in inputs_spec])
+    if self.concat_axis is not None:
+        return ks.KerasTensor(output_shape, dtype=inputs_spec[0].dtype)
+    return [ks.KerasTensor(s, dtype=inputs_spec[0].dtype) for s in output_shape]

def call(self, inputs, **kwargs):
r"""Forward pass.
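A note on the added ``compute_output_spec``: with ``concat_axis=None`` the layer returns one gathered tensor per split index instead of a single concatenated tensor, so the symbolic spec must be a list of ``KerasTensor`` specs. A generic sketch of that pattern (independent of kgcnn, assuming Keras 3):

```python
import keras as ks

class ListOrSingleSpec(ks.layers.Layer):
    """Toy layer mirroring the spec logic above."""

    def __init__(self, concat_axis=None, **kwargs):
        super().__init__(**kwargs)
        self.concat_axis = concat_axis

    def compute_output_spec(self, inputs_spec):
        shapes = [x.shape for x in inputs_spec]
        if self.concat_axis is not None:
            # One spec for the concatenated result (shape arithmetic elided).
            return ks.KerasTensor(shapes[0], dtype=inputs_spec[0].dtype)
        # Otherwise one spec per input, matching a list-valued call().
        return [ks.KerasTensor(s, dtype=inputs_spec[0].dtype) for s in shapes]
```
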
20 changes: 12 additions & 8 deletions kgcnn/layers/geom.py
@@ -48,11 +48,16 @@ def __init__(self, selection_index: list = None, **kwargs):

def build(self, input_shape):
"""Build layer."""
-super(NodePosition, self).build(input_shape)
+self.layer_gather.build(input_shape)
+self.built = True

def compute_output_shape(self, input_shape):
return self.layer_gather.compute_output_shape(input_shape)

+def compute_output_spec(self, inputs_spec):
+    output_shape = self.compute_output_shape([x.shape for x in inputs_spec])
+    return [ks.KerasTensor(s, dtype=inputs_spec[0].dtype) for s in output_shape]

def call(self, inputs, **kwargs):
r"""Forward pass of :obj:`NodePosition`.
@@ -194,7 +199,7 @@ def _compute_euclidean_norm(inputs, axis: int = -1, keepdims: bool = False, inve
if not square_norm:
out = ops.sqrt(out)
if invert_norm:
-out = 1/out
+out = 1 / out
if no_nan:
out = ops.where(ops.isnan(out), ops.convert_to_tensor(0, dtype=out.dtype), out)
return out
@@ -580,7 +585,7 @@ def call(self, inputs, **kwargs):
Tensor: Expanded distance. Shape is `([K], bins)`.
"""
return self._compute_gauss_basis(inputs,
-        offset=self.offset, gamma=self.gamma, bins=self.bins, distance=self.distance)
+    offset=self.offset, gamma=self.gamma, bins=self.bins, distance=self.distance)

def get_config(self):
"""Update config."""
Expand Down Expand Up @@ -695,10 +700,9 @@ def call(self, inputs, **kwargs):
Returns:
Tensor: Expanded distance. Shape is `([K], bins)`.
"""
-return self.map_values(
-    self._compute_fourier_encoding, inputs, dim_half=self.dim_half, wave_length_min=self.wave_length_min,
-    num_mult=self.num_mult, include_frequencies=self.include_frequencies,
-    interleave_sin_cos=self.interleave_sin_cos)
+return self._compute_fourier_encoding(inputs, dim_half=self.dim_half, wave_length_min=self.wave_length_min,
+    num_mult=self.num_mult, include_frequencies=self.include_frequencies,
+    interleave_sin_cos=self.interleave_sin_cos)

def get_config(self):
"""Update config."""
@@ -762,7 +766,7 @@ def freq_init(shape, dtype):

self.frequencies = self.add_weight(
name="frequencies",
-shape=(self.num_radial, ),
+shape=(self.num_radial,),
dtype=self.dtype,
initializer=freq_init,
trainable=True
9 changes: 7 additions & 2 deletions kgcnn/layers/pooling.py
@@ -2,6 +2,7 @@
from keras.layers import Layer, Dense, Concatenate, GRUCell, Activation
from kgcnn.layers.gather import GatherState
from keras import ops
+import kgcnn.ops.activ
from kgcnn.ops.scatter import scatter_reduce_softmax
from kgcnn.layers.aggr import Aggregate

@@ -180,7 +181,7 @@ def __init__(self,
units,
depth=3,
pooling_method="sum",
-activation='kgcnn>leaky_relu',
+activation="kgcnn>leaky_relu2",
activation_context="elu",
use_bias=True,
kernel_regularizer=None,
@@ -204,7 +205,7 @@ def __init__(self,
units (int): Units for the linear trafo of node features before attention.
pooling_method(str): Initial pooling before iteration. Default is "sum".
depth (int): Number of iterations for graph embedding. Default is 3.
activation (str): Activation. Default is {"class_name": "kgcnn>leaky_relu", "config": {"alpha": 0.2}}.
activation (str): Activation. Default is "kgcnn>leaky_relu2".
activation_context (str): Activation function for context. Default is "elu".
use_bias (bool): Use bias. Default is True.
kernel_regularizer: Kernel regularization. Default is None.
@@ -216,6 +217,10 @@
bias_initializer: Initializer for bias. Default is 'zeros'.
"""
super(PoolingNodesAttentive, self).__init__(**kwargs)
+# Changes in keras serialization behaviour for activations in 3.0.2.
+# Keep string at least for default. Also renames to prevent clashes with keras leaky_relu.
+if activation in ["kgcnn>leaky_relu", "kgcnn>leaky_relu2"]:
+    activation = {"class_name": "function", "config": "kgcnn>leaky_relu2"}
self.pooling_method = pooling_method
self.depth = depth
self.units = int(units)
2 changes: 1 addition & 1 deletion kgcnn/layers/update.py
@@ -116,7 +116,7 @@ class ResidualLayer(Layer):

def __init__(self, units,
use_bias=True,
-activation='kgcnn>swish',
+activation='swish',
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
2 changes: 1 addition & 1 deletion kgcnn/literature/CGCNN/_model.py
@@ -1,5 +1,5 @@
import keras as ks
-from kgcnn.ops.activ import *
+import kgcnn.ops.activ
from kgcnn.layers.geom import (
DisplacementVectorsUnitCell,
DisplacementVectorsASU, NodePosition, FracToRealCoordinates,
[Diff truncated: remaining changed files not shown.]
