Commit 019a9a2

Removed unused layers, so that the model builds fully, to prevent errors in the new pytorch trainer.

PatReis committed Feb 19, 2024
1 parent 61617b0 commit 019a9a2
Showing 13 changed files with 114 additions and 38 deletions.
1 change: 1 addition & 0 deletions changelog.md
@@ -4,6 +4,7 @@ v4.0.1
* Added further benchmark results for kgcnn version 4.
* Fix error in ``kgcnn.layers.geom.PositionEncodingBasisLayer``
* Fix error in ``kgcnn.literature.GCN.make_model_weighted``
+ * Fix error in ``kgcnn.literature.AttentiveFP.make_model``
* Had to change serialization for activation functions, since with keras>=3.0.2 custom strings are no longer allowed, which also caused clashes with built-in functions. We catch defaults to stay as backward compatible as possible and changed to a serialization dictionary. Adapted all hyperparameters.
* Renamed leaky_relu and swish in ``kgcnn.ops.activ`` to leaky_relu2 and swish2.
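As a rough illustration of the serialization entry above (a sketch that mirrors the dictionary used later in kgcnn/layers/attention.py, not the full serialization logic):

    # Before: keras < 3.0.2 tolerated custom activation strings.
    activation = "kgcnn>leaky_relu"
    # After: custom strings clash with built-in names, so a serialization dict is used.
    activation = {"class_name": "function", "config": "kgcnn>leaky_relu2"}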
8 changes: 5 additions & 3 deletions kgcnn/layers/aggr.py
@@ -191,8 +191,9 @@ def __init__(self, pooling_method: str = "scatter_sum", pooling_index: int = glo
self.normalize_by_weights = normalize_by_weights
self.pooling_index = pooling_index
self.pooling_method = pooling_method
-         self.to_aggregate = Aggregate(pooling_method=pooling_method)
-         self.to_aggregate_weights = Aggregate(pooling_method="scatter_sum")
+         # to_aggregate already made by super
+         if self.normalize_by_weights:
+             self.to_aggregate_weights = Aggregate(pooling_method="scatter_sum")
self.axis_indices = axis_indices

def build(self, input_shape):
@@ -201,7 +202,8 @@ def build(self, input_shape):
node_shape, edges_shape, edge_index_shape, weights_shape = [list(x) for x in input_shape]
edge_index_shape.pop(self.axis_indices)
self.to_aggregate.build([tuple(x) for x in [edges_shape, edge_index_shape, node_shape]])
-         self.to_aggregate_weights.build([tuple(x) for x in [weights_shape, edge_index_shape, node_shape]])
+         if self.normalize_by_weights:
+             self.to_aggregate_weights.build([tuple(x) for x in [weights_shape, edge_index_shape, node_shape]])
self.built = True

def compute_output_shape(self, input_shape):
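The pattern behind this change, as a minimal sketch: sub-layers that exist but are never called stay unbuilt, and the strict built-check in the new pytorch trainer then fails, so optional sub-layers are now only instantiated when they will actually be used. Class name and import paths below are assumptions for illustration, not the library source.

    from keras.layers import Layer           # assumed import path
    from kgcnn.layers.aggr import Aggregate  # assumed import path

    class WeightedAggregateSketch(Layer):
        """Hypothetical layer: create optional sub-layers only when used."""

        def __init__(self, normalize_by_weights: bool = False, **kwargs):
            super().__init__(**kwargs)
            self.normalize_by_weights = normalize_by_weights
            # Only instantiate the optional sub-layer if it will be called;
            # an orphaned instance would never be built.
            if self.normalize_by_weights:
                self.to_aggregate_weights = Aggregate(pooling_method="scatter_sum")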
88 changes: 74 additions & 14 deletions kgcnn/layers/attention.py
@@ -247,47 +247,97 @@ def get_config(self):
return config


- class MultiHeadGATV2Layer(AttentionHeadGATV2):  # noqa
+ class MultiHeadGATV2Layer(Layer):  # noqa
r"""Single layer for multiple Attention heads from :obj:`AttentionHeadGATV2` .
Uses concatenation or averaging of heads for final output.
"""

def __init__(self,
units: int,
num_heads: int,
activation: str = "kgcnn>leaky_relu2",
use_bias: bool = True,
concat_heads: bool = True,
+              use_edge_features=False,
+              use_final_activation=True,
+              has_self_loops=True,
+              kernel_regularizer=None,
+              bias_regularizer=None,
+              activity_regularizer=None,
+              kernel_constraint=None,
+              bias_constraint=None,
+              kernel_initializer='glorot_uniform',
+              bias_initializer='zeros',
+              normalize_softmax: bool = False,
**kwargs):
-     super(MultiHeadGATV2Layer, self).__init__(
-         units=units,
-         activation=activation,
-         use_bias=use_bias,
-         **kwargs
-     )
r"""Initialize layer.
Args:
units (int): Units for the linear trafo of node features before attention.
num_heads: Number of attention heads.
concat_heads: Whether to concatenate heads or average.
use_edge_features (bool): Append edge features to attention computation. Default is False.
use_final_activation (bool): Whether to apply the final activation for the output.
has_self_loops (bool): If the graph has self-loops. Not used here. Default is True.
activation (str): Activation. Default is "kgcnn>leaky_relu2".
use_bias (bool): Use bias. Default is True.
kernel_regularizer: Kernel regularization. Default is None.
bias_regularizer: Bias regularization. Default is None.
activity_regularizer: Activity regularization. Default is None.
kernel_constraint: Kernel constrains. Default is None.
bias_constraint: Bias constrains. Default is None.
kernel_initializer: Initializer for kernels. Default is 'glorot_uniform'.
bias_initializer: Initializer for bias. Default is 'zeros'.
"""
+     super(MultiHeadGATV2Layer, self).__init__(**kwargs)
+     # Changes in keras serialization behaviour for activations in 3.0.2.
+     # Keep string at least for default. Also renames to prevent clashes with keras leaky_relu.
+     if activation in ["kgcnn>leaky_relu", "kgcnn>leaky_relu2"]:
+         activation = {"class_name": "function", "config": "kgcnn>leaky_relu2"}
+     self.num_heads = num_heads
+     self.concat_heads = concat_heads
+     self.use_edge_features = use_edge_features
+     self.use_final_activation = use_final_activation
+     self.has_self_loops = has_self_loops
+     self.units = int(units)
+     self.normalize_softmax = normalize_softmax
+     self.use_bias = use_bias
+     kernel_args = {"kernel_regularizer": kernel_regularizer,
+                    "activity_regularizer": activity_regularizer, "bias_regularizer": bias_regularizer,
+                    "kernel_constraint": kernel_constraint, "bias_constraint": bias_constraint,
+                    "kernel_initializer": kernel_initializer, "bias_initializer": bias_initializer}

self.head_layers = []
for _ in range(num_heads):
-         lay_linear = Dense(units, activation=activation, use_bias=use_bias)
-         lay_alpha_activation = Dense(units, activation=activation, use_bias=use_bias)
-         lay_alpha = Dense(1, activation='linear', use_bias=False)
+         lay_linear = Dense(units, activation=activation, use_bias=use_bias, **kernel_args)
+         lay_alpha_activation = Dense(units, activation=activation, use_bias=use_bias, **kernel_args)
+         lay_alpha = Dense(1, activation='linear', use_bias=False, **kernel_args)

self.head_layers.append((lay_linear, lay_alpha_activation, lay_alpha))

self.lay_concat_alphas = Concatenate(axis=-2)
self.lay_concat_embeddings = Concatenate(axis=-2)
-     self.lay_pool_attention = AggregateLocalEdgesAttention()
# self.lay_pool = AggregateLocalEdges()

+     # self.lay_linear_trafo = Dense(units, activation="linear", use_bias=use_bias, **kernel_args)
+     # self.lay_alpha_activation = Dense(units, activation=activation, use_bias=use_bias, **kernel_args)
+     # self.lay_alpha = Dense(1, activation="linear", use_bias=False, **kernel_args)
+     self.lay_gather_in = GatherNodesIngoing()
+     self.lay_gather_out = GatherNodesOutgoing()
+     self.lay_concat = Concatenate(axis=-1)
+     self.lay_pool_attention = AggregateLocalEdgesAttention(normalize_softmax=normalize_softmax)
+     if self.use_final_activation:
+         self.lay_final_activ = Activation(activation=activation)

if self.concat_heads:
self.lay_combine_heads = Concatenate(axis=-1)
else:
self.lay_combine_heads = Average()

- def __call__(self, inputs, **kwargs):
+ def build(self, input_shape):
+     """Build layer."""
+     super(MultiHeadGATV2Layer, self).build(input_shape)
+
+ def call(self, inputs, **kwargs):
node, edge, edge_index = inputs

# "a_ij" is a single-channel edge attention logits tensor. "a_ijs" is consequently the list which
@@ -338,6 +388,16 @@ def __call__(self, inputs, **kwargs):
def get_config(self):
"""Update layer config."""
config = super(MultiHeadGATV2Layer, self).get_config()
config.update({"use_edge_features": self.use_edge_features, "use_bias": self.use_bias,
"units": self.units, "has_self_loops": self.has_self_loops,
"normalize_softmax": self.normalize_softmax,
"use_final_activation": self.use_final_activation})
if self.num_heads > 0:
conf_sub = self.head_layers[0][0].get_config()
for x in ["kernel_regularizer", "activity_regularizer", "bias_regularizer", "kernel_constraint",
"bias_constraint", "kernel_initializer", "bias_initializer", "activation"]:
if x in conf_sub:
config.update({x: conf_sub[x]})
config.update({
'num_heads': self.num_heads,
'concat_heads': self.concat_heads
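A hypothetical usage sketch of the refactored layer; the three input tensors are disjoint graph tensors as produced by kgcnn's input casting and are assumed to exist here:

    from kgcnn.layers.attention import MultiHeadGATV2Layer

    # node_attributes, edge_attributes, edge_indices: assumed disjoint tensors.
    layer = MultiHeadGATV2Layer(units=64, num_heads=4, concat_heads=True)
    out = layer([node_attributes, edge_attributes, edge_indices])
    # concat_heads=True concatenates the heads (feature dimension 4 * 64);
    # concat_heads=False averages them instead (feature dimension 64).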
1 change: 0 additions & 1 deletion kgcnn/layers/conv.py
@@ -181,7 +181,6 @@ def call(self, inputs, **kwargs):
Returns:
Tensor: Updated node features.
"""
-         # print(inputs)
node, edge, disjoint_indices = inputs
x = self.lay_dense1(edge, **kwargs)
x = self.lay_dense2(x, **kwargs)
2 changes: 1 addition & 1 deletion kgcnn/literature/AttentiveFP/_make.py
@@ -119,7 +119,7 @@ def make_model(inputs: list = None,
model_inputs,
input_tensor_type=input_tensor_type,
cast_disjoint_kwargs=cast_disjoint_kwargs,
-         mask_assignment=[0, 0, 1],
+         mask_assignment=[0, 1, 1],
index_assignment=[None, None, 0]
)

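As we read the template casting here (an interpretation, not documented semantics): each entry of mask_assignment ties a model input to the mask group it shares, so with inputs (node_attributes, edge_attributes, edge_indices) the corrected list masks edge attributes per edge rather than per node:

    # Sketch of how the corrected assignment reads (interpretation only):
    inputs = ["node_attributes", "edge_attributes", "edge_indices"]
    mask_assignment = [0, 1, 1]         # 0 = node mask, 1 = edge mask
    index_assignment = [None, None, 0]  # only edge_indices reference input 0 (nodes)
    for name, group in zip(inputs, mask_assignment):
        print("%s is masked like input group %d" % (name, group))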
2 changes: 1 addition & 1 deletion kgcnn/literature/DimeNetPP/_layers.py
@@ -129,7 +129,7 @@ def call(self, inputs, **kwargs):
# Transform via 2D spherical basis
sbf = self.dense_sbf1(sbf, **kwargs)
sbf = self.dense_sbf2(sbf, **kwargs)
-         x_kj = self.lay_mult1([x_kj, sbf], **kwargs)
+         x_kj = self.lay_mult2([x_kj, sbf], **kwargs)

# Aggregate interactions and up-project embeddings
x_kj = self.lay_pool([rbf, x_kj, id_expand], **kwargs)
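Why this one-token change fits the commit theme, sketched with the attribute names from the hunk (an assumption about the surrounding code): element-wise multiply layers carry no weights, so reusing lay_mult1 computes the same numbers, but the sibling lay_mult2 is then never called and never built.

    # Illustrative, not the DimeNetPP source:
    # self.lay_mult1 = Multiply()
    # self.lay_mult2 = Multiply()
    # x_kj = self.lay_mult1([x_kj, sbf])  # old: lay_mult2 stays unbuilt
    # x_kj = self.lay_mult2([x_kj, sbf])  # new: every created layer is used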
2 changes: 1 addition & 1 deletion kgcnn/literature/HamNet/_layers.py
@@ -530,7 +530,7 @@ def call(self, inputs, **kwargs):
q_u_ftr, q_v_ftr = self.gather_p([q_ftr, edi], **kwargs)
p_u_ftr, p_v_ftr = self.gather_q([p_ftr, edi], **kwargs)
p_uv_ftr = self.lazy_sub_p([p_v_ftr, p_u_ftr], **kwargs)
-         q_uv_ftr = self.lazy_sub_p([q_v_ftr, q_u_ftr], **kwargs)
+         q_uv_ftr = self.lazy_sub_q([q_v_ftr, q_u_ftr], **kwargs)

attend_ftr = self.dense_attend(hv_v_ftr, **kwargs)

1 change: 1 addition & 0 deletions kgcnn/literature/MAT/_layers.py
@@ -40,6 +40,7 @@ class MATDistanceMatrix(ks.layers.Layer):
def __init__(self, trafo: Union[str, None] = "exp", **kwargs):
super(MATDistanceMatrix, self).__init__(**kwargs)
self.trafo = trafo
+         # self._softmax = ks.layers.Softmax(axis=2)
if self.trafo not in [None, "exp", "softmax"]:
raise ValueError("`trafo` must be in [None, 'exp', 'softmax']")

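From the validation shown in this hunk, the accepted constructor values can be sketched as follows (module path taken from the file name above):

    from kgcnn.literature.MAT._layers import MATDistanceMatrix

    MATDistanceMatrix(trafo="exp")      # default transform
    MATDistanceMatrix(trafo="softmax")
    MATDistanceMatrix(trafo=None)       # keep raw distances
    # MATDistanceMatrix(trafo="other")  # raises ValueError per the check above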
1 change: 1 addition & 0 deletions kgcnn/literature/MEGAN/_make.py
@@ -1,3 +1,4 @@
+ import keras as ks
from ._model import MEGAN
from kgcnn.models.utils import update_model_kwargs
from kgcnn.layers.modules import Input
2 changes: 1 addition & 1 deletion kgcnn/literature/MXMNet/_layers.py
@@ -119,7 +119,7 @@ def __init__(self, units: int = 64, output_units: int = 1, activation: str = "sw

self.lin_rbf_out = Dense(self.dim, use_bias=False, activation="linear")

-         self.h_mlp = GraphMLP(self.dim, activation=activation)
+         # Fix for kgcnn==4.0.1: removed overwrite mlp here. Should not change model but prevents unused layers.

self.y_mlp = GraphMLP([self.dim, self.dim, self.dim], activation=activation)
self.y_W = Dense(self.output_dim, activation="linear",
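The removed line had overwritten an attribute set elsewhere in __init__, leaving a tracked but never-called first instance. A minimal self-contained demo of that failure mode, assuming keras keeps the overwritten sub-layer registered (hypothetical Demo class):

    import keras as ks

    class Demo(ks.layers.Layer):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self.h = ks.layers.Dense(4)  # first instance, tracked on assignment
            self.h = ks.layers.Dense(4)  # overwrite; only this one is ever called

        def call(self, x):
            return self.h(x)

    layer = Demo()
    layer(ks.ops.ones((2, 4)))
    # Depending on keras attribute tracking, the first Dense can remain
    # registered yet unbuilt, which is exactly what strict built-checks flag.
    print([sub.built for sub in layer._flatten_layers()])  # noqa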
14 changes: 9 additions & 5 deletions training/hyper/hyper_esol.py
@@ -478,11 +478,13 @@
"config": {
"name": "PAiNN",
"inputs": [
{"shape": [None], "name": "node_number", "dtype": "int64", "ragged": True},
{"shape": [None, 3], "name": "node_coordinates", "dtype": "float32", "ragged": True},
{"shape": [None, 2], "name": "range_indices", "dtype": "int64", "ragged": True}
{"shape": [None], "name": "node_number", "dtype": "int64"},
{"shape": [None, 3], "name": "node_coordinates", "dtype": "float32"},
{"shape": [None, 2], "name": "range_indices", "dtype": "int64"},
{"shape": (), "name": "total_nodes", "dtype": "int64"},
{"shape": (), "name": "total_ranges", "dtype": "int64"}
],
"input_tensor_type": "ragged",
"input_tensor_type": "padded",
"cast_disjoint_kwargs": {},
"input_node_embedding": {"input_dim": 95, "output_dim": 128},
"bessel_basis": {"num_radial": 20, "cutoff": 5.0, "envelope_exponent": 5},
@@ -522,7 +524,9 @@
"config": {},
"methods": [
{"set_attributes": {"add_hydrogen": True}},
{"map_list": {"method": "set_range", "max_distance": 3, "max_neighbours": 10000}}
{"map_list": {"method": "set_range", "max_distance": 3, "max_neighbours": 10000}},
{"map_list": {"method": "count_nodes_and_edges", "total_edges": "total_ranges",
"count_edges": "range_indices"}},
]
},
"data_unit": "mol/L"
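Padded inputs need per-graph sizes so the cast can strip padding again; the added count_nodes_and_edges step supplies them. An interpretation of the keys above, with a hypothetical graph dict and illustrative values:

    graph = {"range_indices": [[0, 1], [1, 0], [1, 2]]}  # 3 range edges
    # "count_edges": "range_indices" counts these rows and stores the result
    # under the new property named by "total_edges": "total_ranges".
    graph["total_ranges"] = len(graph["range_indices"])  # -> 3, scalar int64
    # total_nodes is filled analogously from the node list where configured.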
25 changes: 15 additions & 10 deletions training/hyper/hyper_mp_jdft2d.py
@@ -153,16 +153,18 @@
"config": {
'name': 'CGCNN',
'inputs': [
-         {'shape': (None,), 'name': 'node_number', 'dtype': 'int64', 'ragged': True},
-         {'shape': (None, 3), 'name': 'node_frac_coordinates', 'dtype': 'float64', 'ragged': True},
-         {'shape': (None, 2), 'name': 'range_indices', 'dtype': 'int64', 'ragged': True},
-         {'shape': (None, 3), 'name': 'range_image', 'dtype': 'float32', 'ragged': True},
-         {'shape': (3, 3), 'name': 'graph_lattice', 'dtype': 'float64', 'ragged': False},
+         {'shape': (None,), 'name': 'node_number', 'dtype': 'int64'},
+         {'shape': (None, 3), 'name': 'node_frac_coordinates', 'dtype': 'float64'},
+         {'shape': (None, 2), 'name': 'range_indices', 'dtype': 'int64'},
+         {'shape': (None, 3), 'name': 'range_image', 'dtype': 'float32'},
+         {'shape': (3, 3), 'name': 'graph_lattice', 'dtype': 'float64'},
          # For `representation="asu"`:
          # {'shape': (None, 1), 'name': 'multiplicities', 'dtype': 'float32', 'ragged': True},
          # {'shape': (None, 4, 4), 'name': 'symmops', 'dtype': 'float64', 'ragged': True},
+         {"shape": (), "name": "total_nodes", "dtype": "int64"},
+         {"shape": (), "name": "total_ranges", "dtype": "int64"}
      ],
-     "input_tensor_type": "ragged",
+     "input_tensor_type": "padded",
'input_node_embedding': {'input_dim': 95, 'output_dim': 64},
'representation': 'unit', # None, 'asu' or 'unit'
'expand_distance': True,
@@ -205,15 +207,18 @@
"config": {"with_std": True, "with_mean": True, "copy": True}
},
},
"data": {
"dataset": {
"dataset": {
"class_name": "MatProjectJdft2dDataset",
"module_name": "kgcnn.data.datasets.MatProjectJdft2dDataset",
"config": {},
"methods": [
{"map_list": {"method": "set_range_periodic", "max_distance": 6.0}}
{"map_list": {"method": "set_range_periodic", "max_distance": 6.0}},
{"map_list": {"method": "count_nodes_and_edges", "total_edges": "total_ranges",
"count_edges": "range_indices", "count_nodes": "node_number",
"total_nodes": "total_nodes"}},
]
},
+ },
+ "data": {
"data_unit": "meV/atom"
},
"info": {
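Besides the padded inputs, this hunk also lifts the "dataset" block out of "data". Reduced to the fields shown above, the resulting layout reads (a sketch of the hyperparameter structure, preprocessing steps elided):

    hyper_section = {
        "dataset": {
            "class_name": "MatProjectJdft2dDataset",
            "module_name": "kgcnn.data.datasets.MatProjectJdft2dDataset",
            "config": {},
            "methods": [],  # the map_list steps listed above
        },
        "data": {"data_unit": "meV/atom"},
    }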
5 changes: 4 additions & 1 deletion training/train_graph.py
@@ -153,7 +153,10 @@
# Model summary
model.summary()
print(" Compiled with jit: %s" % model._jit_compile) # noqa
print(" Model is built: %s" % all([layer.built for layer in model._flatten_layers()])) # noqa
print(" Model is built: %s, with unbuilt: %s" % (
all([layer.built for layer in model._flatten_layers()]), # noqa
[layer.name for layer in model._flatten_layers() if not layer.built]
))

# Run keras model-fit and take time for training.
start = time.time()
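The same diagnostic as a standalone sketch, assuming a constructed keras model named model; _flatten_layers is the private helper already used above:

    unbuilt = [layer.name for layer in model._flatten_layers() if not layer.built]  # noqa
    print(" Model is built: %s, with unbuilt: %s" % (len(unbuilt) == 0, unbuilt))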
