Skip to content

Commit

Permalink
Merge pull request #46 from NREL/gb/minmax
Browse files Browse the repository at this point in the history
added custom functional layer for basic tensorflow functional layers …
  • Loading branch information
grantbuster committed Nov 29, 2023
2 parents a7ca20d + 28eda8e commit e39b1c4
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 1 deletion.
43 changes: 43 additions & 0 deletions phygnn/layers/custom_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,3 +785,46 @@ def call(self, x, hi_res_feature):
Output tensor with the hi_res_feature added to x.
"""
return tf.concat((x, hi_res_feature), axis=-1)


class FunctionalLayer(tf.keras.layers.Layer):
"""Custom layer to implement the tensorflow layer functions (e.g., add,
subtract, multiply, maximum, and minimum) with a constant value. These
cannot be implemented in phygnn as normal layers because they need to
operate on two tensors of equal shape."""

def __init__(self, name, value):
"""
Parameters
----------
name : str
Name of the tensorflow layer function to be implemented, options
are (all lower-case): add, subtract, multiply, maximum, and minimum
value : float
Constant value to use in the function operation
"""

options = ('add', 'subtract', 'multiply', 'maximum', 'minimum')
msg = (f'FunctionalLayer input `name` must be one of "{options}" '
f'but received "{name}"')
assert name in options, msg

super().__init__(name=name)
self.value = value
self.fun = getattr(tf.keras.layers, self.name)

def call(self, x):
"""Operates on x with the specified function
Parameters
----------
x : tf.Tensor
Input tensor
Returns
-------
x : tf.Tensor
Output tensor operated on by the specified function
"""
const = tf.constant(value=self.value, shape=x.shape, dtype=x.dtype)
return self.fun((x, const))
2 changes: 1 addition & 1 deletion phygnn/version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# -*- coding: utf-8 -*-
"""Physics Guided Neural Network version."""

__version__ = '0.0.25'
__version__ = '0.0.26'
21 changes: 21 additions & 0 deletions tests/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
SkipConnection,
SpatioTemporalExpansion,
TileLayer,
FunctionalLayer,
)
from phygnn.layers.handlers import HiddenLayers, Layers

Expand Down Expand Up @@ -423,3 +424,23 @@ def test_fno_3d():
x = layer(x)
with pytest.raises(tf.errors.InvalidArgumentError):
tf.assert_equal(x_in, x)


def test_functional_layer():
"""Test the generic functional layer"""

layer = FunctionalLayer('maximum', 1)
x = np.random.normal(0.5, 3, size=(1, 4, 4, 6, 3))
assert layer(x).numpy().min() == 1.0

# make sure layer works with input of arbitrary shape
x = np.random.normal(0.5, 3, size=(2, 8, 8, 4, 1))
assert layer(x).numpy().min() == 1.0

layer = FunctionalLayer('multiply', 1.5)
x = np.random.normal(0.5, 3, size=(1, 4, 4, 6, 3))
assert np.allclose(layer(x).numpy(), x * 1.5)

with pytest.raises(AssertionError) as excinfo:
FunctionalLayer('bad_arg', 0)
assert "must be one of" in str(excinfo.value)

0 comments on commit e39b1c4

Please sign in to comment.